Suppose you have a tensor t_images containing several images of size 28x28 pixels. t_images is thus of size (N_batch, 1, 28, 28).

Write a function flatten_images that converts any tensor of size (N_batch, 1, 28, 28) into a tensor containing flattened images (i.e. a tensor of size (N_batch, 1 * 28 * 28)).
def flatten_images(tensor_images):
    """
    Args:
        tensor_images (Tensor): assumed to be of size (N_batch, 1, 28, 28), where N_batch is unknown
    Returns:
        the same tensor as given as input, but of size (N_batch, 1 * 28 * 28)
    """
Sample solution:
Here is a possible implementation:
def flatten_images(tensor_images):
    """
    Args:
        tensor_images (Tensor): assumed to be of size (N_batch, 1, 28, 28), where N_batch is unknown
    Returns:
        the same tensor as given as input, but of size (N_batch, 1 * 28 * 28)
    """
    # -1 lets PyTorch infer N_batch from the tensor's total number of elements.
    # Note: view requires a contiguous tensor; use .reshape(...) if the input
    # might not be contiguous.
    return tensor_images.view(-1, 1 * 28 * 28)
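A quick sanity check of the expected behavior (a minimal sketch, assuming PyTorch is installed; the batch size 5 is arbitrary):

```python
import torch

def flatten_images(tensor_images):
    # -1 lets PyTorch infer N_batch from the tensor's total number of elements.
    return tensor_images.view(-1, 1 * 28 * 28)

# A batch of 5 dummy grayscale 28x28 images.
t_images = torch.randn(5, 1, 28, 28)

flat = flatten_images(t_images)
print(flat.shape)  # torch.Size([5, 784])
```

Each row of the result is one image, unrolled into a vector of 784 values, which is the layout expected by a fully connected layer.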