This is an archived version of the course. Please find the latest version of the course on the main webpage.

Chapter 3: PyTorch Tensors

Manipulating Tensors

Luca Grillotti

Reshaping tensors with view(...)

One of the methods you will use most often in torch is view(*shape). It returns a tensor with the same data as the original, arranged in the requested shape; the total number of elements must stay the same.

import torch

tensor_shape_3_2 = torch.Tensor([[1, 2],
                                 [3, 4],
                                 [5, 6]])  # shape = (3, 2)

tensor_shape_2_3 = tensor_shape_3_2.view(2, 3)  # shape = (2, 3)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])

tensor_shape_1_6_1 = tensor_shape_3_2.view(1, 6, 1)
# tensor([[[1.],
#          [2.],
#          [3.],
#          [4.],
#          [5.],
#          [6.]]])

tensor_shape_6 = tensor_shape_3_2.view(6) # shape = (6,)
# tensor([1., 2., 3., 4., 5., 6.])
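
Two properties of view are worth checking for yourself: the reshaped tensor shares its memory with the original, and the requested shape must hold exactly as many elements as the original. The snippet below is a minimal sketch; the names original, reshaped and mismatch_raised are chosen for illustration and do not come from the course code.

```python
import torch

# Hypothetical tensor for illustration: values 0..5 arranged as shape (3, 2).
original = torch.arange(6, dtype=torch.float32).view(3, 2)

# view does not copy the data: `reshaped` is backed by the same memory.
reshaped = original.view(2, 3)
reshaped[0, 0] = 100.0
print(original[0, 0])  # the change is visible through `original` too

# The requested shape must contain exactly 6 elements.
try:
    original.view(4, 2)  # 8 elements requested, only 6 available
    mismatch_raised = False
except RuntimeError as error:
    mismatch_raised = True
    print(error)
```

If you need an independent copy instead of shared memory, calling clone() on the result is one option.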

If you know the sizes of all dimensions except one, you can pass -1 for the unknown dimension, and torch will infer it from the total number of elements.

For example, with the tensor defined above, tensor_shape_3_2.view(-1, 3) replaces the -1 with 2, since the tensor holds 6 elements.

import torch

tensor_shape_3_2 = torch.Tensor([[1, 2],
                                 [3, 4],
                                 [5, 6]])  # shape = (3, 2)

tensor_shape_2_3 = tensor_shape_3_2.view(-1, 3)  # shape = (-1, 3) = (2, 3) here
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])

tensor_shape_6 = tensor_shape_3_2.view(-1)  # shape = (6,)
# tensor([1., 2., 3., 4., 5., 6.])
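
A common use of -1 is flattening everything except the first dimension. The sketch below assumes a hypothetical tensor named batch of shape (8, 3, 4); it also shows two cases where the inference cannot work: more than one -1, and a total number of elements that is not divisible by the known sizes.

```python
import torch

# Hypothetical batch of 8 items, each of shape (3, 4).
batch = torch.zeros(8, 3, 4)

# Keep the first dimension and let torch infer the rest: 3 * 4 = 12.
flattened = batch.view(batch.size(0), -1)
print(flattened.size())  # torch.Size([8, 12])

# Only one dimension can be inferred.
try:
    batch.view(-1, -1)
    ambiguous_raised = False
except RuntimeError:
    ambiguous_raised = True

# 96 elements cannot be split into rows of 5.
try:
    batch.view(-1, 5)
    indivisible_raised = False
except RuntimeError:
    indivisible_raised = True
```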

Concatenating tensors with torch.cat(...)

torch.cat concatenates tensors along the specified dimension.

Examples:

Concatenating tensors along dimension 0

import torch
t1 = torch.randn(size=(1, 4))
t2 = torch.randn(size=(2, 4))
concatenation = torch.cat(tensors=(t1, t2), dim=0)

print(concatenation)
print(concatenation.size())

produces this kind of output:

tensor([[-1.2413,  0.1362,  0.9370,  2.1812],
        [ 0.5601,  0.0252,  0.4164, -0.6447],
        [-0.4758, -0.2737, -0.0152,  1.5531]])
torch.Size([3, 4])

Note: the code above runs because all tensors have the same size on every dimension other than 0. If t1 were of size (1, 3) instead of (1, 4), torch.cat would raise an error.
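
To see the failure case from the note in practice, here is a small sketch that gives t1 the size (1, 3) and catches the resulting error. The flag cat_failed is introduced only for illustration.

```python
import torch

t1 = torch.randn(size=(1, 3))  # 3 columns
t2 = torch.randn(size=(2, 4))  # 4 columns: does not match t1 on dimension 1

try:
    torch.cat(tensors=(t1, t2), dim=0)
    cat_failed = False
except RuntimeError as error:
    cat_failed = True
    print(error)  # the message reports the mismatching sizes
```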

Concatenating tensors along dimension 1

import torch
t1 = torch.randn(size=(3, 1))
t2 = torch.randn(size=(3, 5))
concatenation = torch.cat(tensors=(t1, t2), dim=1)

print(concatenation)
print(concatenation.size())

gives an output similar to this:

tensor([[-0.1497,  0.0853, -0.6608, -1.1509,  0.3870,  0.2287],
        [ 0.3432,  0.6032,  0.0454, -0.3627, -0.6101,  1.1735],
        [ 0.3677, -1.5225, -0.0834,  0.6458,  0.9340,  0.0303]])
torch.Size([3, 6])
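
The size rule generalises to any number of tensors: along the concatenation dimension the sizes add up, and every other dimension stays unchanged. The following sketch adds a third, hypothetical tensor t3 to illustrate this, and also shows that for 2-D tensors dim=-1 designates the same dimension as dim=1.

```python
import torch

t1 = torch.randn(size=(3, 1))
t2 = torch.randn(size=(3, 5))
t3 = torch.randn(size=(3, 2))  # hypothetical extra tensor for illustration

concatenation = torch.cat(tensors=(t1, t2, t3), dim=1)

# Sizes along dim=1 add up: 1 + 5 + 2 = 8.
expected_columns = t1.size(1) + t2.size(1) + t3.size(1)
print(concatenation.size())  # torch.Size([3, 8])

# Negative dims count from the end, so dim=-1 is dim=1 for 2-D tensors.
same_result = torch.equal(concatenation, torch.cat(tensors=(t1, t2, t3), dim=-1))
```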

Extracting sub-tensors

All PyTorch tensors support slicing.

Let’s first define a tensor t to illustrate this:

import torch
t = torch.randn(size=(3, 4))
Here is the value of t:
tensor([[-1.2328, -0.2615, -0.6309,  1.0880],
        [-1.0982,  0.1157,  0.0263,  0.7285],
        [-0.5299, -0.7179, -0.9029,  0.8168]])

Extracting a single row or column

Suppose we want to extract only the second row of the tensor. We can simply do:

index_row = 1  # remember that indexes start at 0!
row = t[index_row, :]
Here is the value of row:
tensor([-1.0982,  0.1157,  0.0263,  0.7285])

Similarly, if we want to get the 3rd column:

index_col = 2  # remember that indexes start at 0!
col = t[:, index_col]
Here is the value of col:
tensor([-0.6309,  0.0263, -0.9029])
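
Note that indexing with a single integer removes that dimension from the result, which is why row and col above are 1-D tensors. If you want to keep a 2-D result, a slice of length 1 is one way to do it. The sketch below uses a tensor built with torch.arange (an assumption made here so that the values are predictable) rather than random values.

```python
import torch

# Deterministic stand-in for t:
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])
t = torch.arange(12).view(3, 4)

row = t[1, :]       # integer index: dimension 0 is dropped, shape (4,)
row_2d = t[1:2, :]  # slice of length 1: dimension 0 is kept, shape (1, 4)

col = t[:, 2]       # shape (3,)
col_2d = t[:, 2:3]  # shape (3, 1)

print(row.size(), row_2d.size())
print(col.size(), col_2d.size())
```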

Extracting several rows and/or columns

For example, if we want to get the 2nd and 3rd columns:

col_lower = 1  # remember that indexes start at 0!
col_upper = 3
block = t[:, col_lower:col_upper]  # note that slicing never includes col_upper
Then block equals:
tensor([[-0.2615, -0.6309],
        [ 0.1157,  0.0263],
        [-0.7179, -0.9029]])

And if we want the same block as before, but restricted to the first 2 rows:

col_lower = 1  # remember that indexes start at 0!
col_upper = 3
row_lower = 0
row_upper = 2
block = t[row_lower:row_upper, col_lower:col_upper]  # note that slicing never includes row_upper or col_upper
producing:
tensor([[-0.2615, -0.6309],
        [ 0.1157,  0.0263]])
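
One point to keep in mind when extracting blocks: a slice shares its memory with the original tensor, so writing into the slice also changes the original. The sketch below illustrates this on a deterministic tensor (built with torch.arange, an assumption made for readability), and uses clone() to obtain an independent copy; the name independent is chosen for illustration.

```python
import torch

# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])
t = torch.arange(12).view(3, 4)

block = t[0:2, 1:3]  # rows 0 and 1, columns 1 and 2
print(block)
# tensor([[1, 2],
#         [5, 6]])

# Writing into the slice also modifies t, since both share memory.
block[0, 0] = -1
print(t[0, 1])  # tensor(-1)

# clone() makes a copy that can be modified without touching t.
independent = t[0:2, 1:3].clone()
independent[0, 0] = 99
print(t[0, 1])  # still tensor(-1)
```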