This is an archived version of the course. Please find the latest version of the course on the main webpage.

Chapter 3: PyTorch Tensors

Reshaping tensors

Luca Grillotti

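One of the tensor methods you will use most often is view(*shape). It returns a tensor containing the same elements as the original, arranged in the requested shape. The new shape must contain the same total number of elements as the original tensor.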

import torch

tensor_shape_3_2 = torch.Tensor([[1, 2],
                                 [3, 4],
                                 [5, 6]])  # shape = (3, 2)

tensor_shape_2_3 = tensor_shape_3_2.view(2, 3)  # shape = (2, 3)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])

tensor_shape_1_6_1 = tensor_shape_3_2.view(1, 6, 1)  # shape = (1, 6, 1)
# tensor([[[1.],
#          [2.],
#          [3.],
#          [4.],
#          [5.],
#          [6.]]])

tensor_shape_6 = tensor_shape_3_2.view(6)  # shape = (6,)
# tensor([1., 2., 3., 4., 5., 6.])
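Because view only rearranges existing elements, the target shape has to multiply out to the same element count as the source. The minimal sketch below (the variable names are illustrative) checks this with numel() and shows that an incompatible shape such as (4, 2) is rejected with a RuntimeError. The exact wording of the error message may differ between PyTorch versions.

```python
import torch

t = torch.Tensor([[1, 2],
                  [3, 4],
                  [5, 6]])  # shape = (3, 2), 6 elements in total

# Every valid target shape multiplies out to t.numel() == 6
for shape in [(2, 3), (1, 6, 1), (6,)]:
    reshaped = t.view(*shape)
    assert reshaped.numel() == t.numel()
    print(tuple(reshaped.shape))  # (2, 3), then (1, 6, 1), then (6,)

# (4, 2) would need 8 elements, so view refuses it
raised = False
try:
    t.view(4, 2)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```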

If you know the sizes of all dimensions except one, you can pass -1 for that dimension and PyTorch will infer its size from the total number of elements. Only one dimension can be inferred this way.

For example, with the tensor defined above, tensor_shape_3_2.view(-1, 3) replaces -1 with 2: the 6 elements split into rows of 3 give 2 rows.

import torch

tensor_shape_3_2 = torch.Tensor([[1, 2],
                                 [3, 4],
                                 [5, 6]])  # shape = (3, 2)

tensor_shape_2_3 = tensor_shape_3_2.view(-1, 3)  # -1 is inferred as 2, so shape = (2, 3)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])

tensor_shape_6 = tensor_shape_3_2.view(-1)  # shape = (6,)
# tensor([1., 2., 3., 4., 5., 6.])
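A common use of -1 is flattening everything except a leading batch dimension. The following is a minimal sketch, assuming a made-up batch of four tensors of shape (3, 2) built with torch.arange; it also shows that asking PyTorch to infer two dimensions at once fails.

```python
import torch

# Hypothetical batch: 4 tensors of shape (3, 2), 24 elements in total
batch = torch.arange(24, dtype=torch.float32).view(4, 3, 2)

# Keep the batch dimension and let PyTorch infer the rest: 24 / 4 = 6
flat = batch.view(batch.size(0), -1)
print(flat.shape)  # torch.Size([4, 6])

# Two -1 entries are ambiguous, so view raises a RuntimeError
try:
    batch.view(-1, -1)
    two_inferred_ok = True
except RuntimeError:
    two_inferred_ok = False
print(two_inferred_ok)  # False
```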

Note that view(...) does not copy any data: the reshaped tensor shares its underlying storage with the original tensor. If you change any element of the view, the corresponding element of the original tensor changes too (and vice versa).

import torch

tensor_shape_3_2 = torch.Tensor([[1, 2],
                                 [3, 4],
                                 [5, 6]])  # shape = (3, 2)
# tensor([[1., 2.],
#         [3., 4.],
#         [5., 6.]])

tensor_shape_2_3 = tensor_shape_3_2.view(2, 3)  # shape = (2, 3)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])

tensor_shape_2_3[0, 0] = 9

print(tensor_shape_2_3)
# tensor([[9., 2., 3.],
#         [4., 5., 6.]])

print(tensor_shape_3_2)
# tensor([[9., 2.],
#         [3., 4.],
#         [5., 6.]])
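To check the sharing directly, or to avoid it, the minimal sketch below compares data_ptr() for the two tensors and calls clone() to get an independent copy before reshaping. It also illustrates a related caveat: view needs a memory layout compatible with the new shape, so it fails on a transposed (non-contiguous) tensor, whereas reshape() falls back to copying the data when it has to. The variable names are illustrative only.

```python
import torch

original = torch.Tensor([[1, 2],
                         [3, 4],
                         [5, 6]])  # shape = (3, 2)

# view: same underlying storage, so both tensors start at the same address
shared = original.view(2, 3)
print(shared.data_ptr() == original.data_ptr())  # True

# clone() first if the reshaped tensor should be independent of the original
independent = original.clone().view(2, 3)
independent[0, 0] = 9
print(original[0, 0])  # tensor(1.)

# A transposed tensor is not contiguous, so view(6) raises an error here
transposed = original.t()  # shape = (2, 3), non-contiguous
try:
    transposed.view(6)
    view_worked = True
except RuntimeError:
    view_worked = False
print(view_worked)  # False

# reshape copies the data when a view is impossible
print(transposed.reshape(6))  # tensor([1., 3., 5., 2., 4., 6.])
```

When you are unsure whether a tensor is contiguous, reshape() is the safer choice, but unlike view it does not guarantee that the result shares memory with the input.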