Introduction to Deep Learning with PyTorch
Chapter 3: PyTorch Tensors
Reshaping tensors
One of the tensor methods you will use most often is view(*shape). It returns a tensor with the same data arranged in the requested shape. The requested shape must hold the same total number of elements as the original tensor.
import torch
tensor_shape_3_2 = torch.Tensor([[1, 2],
                                 [3, 4],
                                 [5, 6]])  # shape = (3, 2)
tensor_shape_2_3 = tensor_shape_3_2.view(2, 3)  # shape = (2, 3)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])
tensor_shape_1_6_1 = tensor_shape_3_2.view(1, 6, 1)  # shape = (1, 6, 1)
# tensor([[[1.],
#          [2.],
#          [3.],
#          [4.],
#          [5.],
#          [6.]]])
tensor_shape_6 = tensor_shape_3_2.view(6)  # shape = (6,)
# tensor([1., 2., 3., 4., 5., 6.])
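Because view() never pads or truncates data, a shape with a different element count is rejected. The short sketch below, assuming a standard PyTorch install, checks this with numel(), which returns the total number of elements, and catches the RuntimeError raised by an incompatible shape. The names t, v and failed are illustrative only.

```python
import torch

t = torch.Tensor([[1, 2],
                  [3, 4],
                  [5, 6]])  # shape = (3, 2), 6 elements

# A valid view keeps the element count: 2 * 3 == 6
v = t.view(2, 3)
print(v.shape, v.numel())  # torch.Size([2, 3]) 6

# An invalid view changes the element count: 4 * 2 == 8, not 6
failed = False
try:
    t.view(4, 2)
except RuntimeError as err:
    failed = True
    print("view failed:", err)
```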
If you know the sizes of all dimensions except one, you can pass -1 for the unknown dimension and PyTorch will infer it from the total number of elements. For example, with the code above, tensor_shape_3_2.view(-1, 3) splits 6 elements into rows of 3, so PyTorch replaces the -1 with 6 / 3 = 2. Only one dimension can be -1.
import torch
tensor_shape_3_2 = torch.Tensor([[1, 2],
                                 [3, 4],
                                 [5, 6]])  # shape = (3, 2)
tensor_shape_2_3 = tensor_shape_3_2.view(-1, 3)  # -1 is inferred as 2, so shape = (2, 3)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])
tensor_shape_6 = tensor_shape_3_2.view(-1)  # shape = (6,)
# tensor([1., 2., 3., 4., 5., 6.])
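To make the inference rule concrete, the sketch below re-implements it in plain Python with a hypothetical helper, infer_shape (not part of PyTorch), and compares the result with the shape that view actually produces. It mirrors the rule described above (total elements divided by the product of the known sizes) rather than PyTorch's internal code, and it requires Python 3.8+ for math.prod.

```python
import math
import torch

def infer_shape(shape, numel):
    """Hypothetical helper: resolve a single -1 in `shape` the way view() does."""
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be -1")
    if -1 not in shape:
        return tuple(shape)
    # Product of all sizes that are already known
    known = math.prod(d for d in shape if d != -1)
    if known == 0 or numel % known != 0:
        raise ValueError(f"shape {list(shape)} is invalid for {numel} elements")
    return tuple(numel // known if d == -1 else d for d in shape)

t = torch.Tensor([[1, 2],
                  [3, 4],
                  [5, 6]])  # 6 elements

for shape in [(-1, 3), (-1,), (3, -1), (1, -1, 1)]:
    ours = infer_shape(shape, t.numel())
    theirs = tuple(t.view(*shape).shape)
    print(shape, "->", ours, theirs)
```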
Note that view(...) does not copy the data: the returned tensor shares the same underlying storage as the original and only presents it with a different shape. If you change any element through the view, the corresponding element of the original tensor changes as well.
import torch
tensor_shape_3_2 = torch.Tensor([[1, 2],
                                 [3, 4],
                                 [5, 6]])  # shape = (3, 2)
# tensor([[1., 2.],
#         [3., 4.],
#         [5., 6.]])
tensor_shape_2_3 = tensor_shape_3_2.view(2, 3)  # shape = (2, 3)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])
tensor_shape_2_3[0, 0] = 9  # modify an element through the view
print(tensor_shape_2_3)
# tensor([[9., 2., 3.],
#         [4., 5., 6.]])
print(tensor_shape_3_2)  # the original tensor reflects the change
# tensor([[9., 2.],
#         [3., 4.],
#         [5., 6.]])
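One way to confirm that a view shares memory with its source is to compare data_ptr(), which returns the address of a tensor's first element. The sketch below, assuming a standard PyTorch install, contrasts view with clone(), which makes an independent copy. It also shows that view refuses a transposed (non-contiguous) tensor; there, reshape() works but returns a copy rather than a view. The variable names are illustrative.

```python
import torch

t = torch.Tensor([[1, 2],
                  [3, 4],
                  [5, 6]])  # shape = (3, 2)

v = t.view(2, 3)          # shares storage with t
c = t.clone().view(2, 3)  # independent copy, then reshaped

print(v.data_ptr() == t.data_ptr())  # True: same memory
print(c.data_ptr() == t.data_ptr())  # False: separate memory

c[0, 0] = 100          # changes only the copy
print(t[0, 0].item())  # 1.0

# t.t() is a transposed, non-contiguous view; view() cannot flatten it
transposed = t.t()
try:
    transposed.view(6)
    view_ok = True
except RuntimeError:
    view_ok = False

flat = transposed.reshape(6)  # returns a copy here because the data is not contiguous
print(flat)
```

If you need view semantics on such a tensor, calling transposed.contiguous() first produces a contiguous copy that view can then reshape.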