Introduction to Deep Learning with PyTorch
Chapter 3: PyTorch Tensors
Extracting sub-tensors
All PyTorch tensors support slicing, using the same `start:stop` syntax as Python lists and NumPy arrays.
Let’s first define a tensor `t` to illustrate this:
import torch
t = torch.randn(size=(3, 4))
Here is the value of `t` (since `torch.randn` samples from a standard normal distribution, your values will differ):
tensor([[-1.2328, -0.2615, -0.6309, 1.0880],
[-1.0982, 0.1157, 0.0263, 0.7285],
[-0.5299, -0.7179, -0.9029, 0.8168]])
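To follow along with reproducible numbers, here is a minimal sketch of two options: seed the random generator with `torch.manual_seed` before calling `torch.randn`, or build a deterministic tensor whose values are easy to check by eye. The name `d` and the choice of `torch.arange(12).reshape(3, 4)` are ours, for illustration only.

```python
import torch

# Fixing the seed makes torch.randn return the same values on every run
# (for a given PyTorch version and device).
torch.manual_seed(0)
t = torch.randn(size=(3, 4))
print(t.shape)  # torch.Size([3, 4])

# Hypothetical deterministic alternative: the integers 0..11 laid out
# row by row in a 3 x 4 grid.
d = torch.arange(12).reshape(3, 4)
print(d)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])
```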
Extracting a single row or column
Suppose we only want to extract the second row of the tensor. We can simply write:
index_row = 1 # remember that indexes start at 0!
row = t[index_row, :]
Here is the value of `row`:
tensor([-1.0982, 0.1157, 0.0263, 0.7285])
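The `:` for the column dimension is optional when selecting a row, and negative indices count from the end. A small sketch, assuming a hypothetical deterministic tensor `d` in place of the random `t` so that the printed values are predictable:

```python
import torch

# Hypothetical deterministic tensor (not the random t from the text).
d = torch.arange(12).reshape(3, 4)

row_a = d[1, :]   # explicit: row 1, every column
row_b = d[1]      # trailing dimensions can be left out
last = d[-1]      # negative indices count from the end: the last row

print(row_a)                      # tensor([4, 5, 6, 7])
print(torch.equal(row_a, row_b))  # True
print(last)                       # tensor([ 8,  9, 10, 11])
```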
Similarly, if we want to get the 3rd column:
index_col = 2 # remember that indexes start at 0!
col = t[:, index_col]
Here is the value of `col`:
tensor([-0.6309, 0.0263, -0.9029])
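Note that the result is 1-D: indexing with an integer drops that dimension. If you need a 3 x 1 column instead, a slice of length 1 keeps the dimension. A short sketch, again using a hypothetical deterministic tensor `d` for predictable output:

```python
import torch

# Hypothetical deterministic tensor (not the random t from the text).
d = torch.arange(12).reshape(3, 4)

col = d[:, 2]       # integer index: the column dimension is dropped
col_2d = d[:, 2:3]  # length-1 slice: the column dimension is kept

print(col)           # tensor([ 2,  6, 10])
print(col.shape)     # torch.Size([3])
print(col_2d.shape)  # torch.Size([3, 1])
```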
Extracting several rows and/or columns
For example, if we want to get the 2nd and 3rd columns:
col_lower = 1 # remember that indexes start at 0!
col_upper = 3
block = t[:, col_lower:col_upper] # the upper bound col_upper is excluded from the slice
`block` equals:
tensor([[-0.2615, -0.6309],
[ 0.1157, 0.0263],
[-0.7179, -0.9029]])
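Because the upper bound is excluded, a slice `lower:upper` always contains exactly `upper - lower` elements. Either bound may also be omitted, and a third value sets a step. A sketch on a hypothetical deterministic tensor `d`; the variable names are ours:

```python
import torch

# Hypothetical deterministic tensor (not the random t from the text).
d = torch.arange(12).reshape(3, 4)

col_lower, col_upper = 1, 3
block = d[:, col_lower:col_upper]
# Half-open interval [col_lower, col_upper): col_upper - col_lower columns.
print(block.shape)  # torch.Size([3, 2])

left = d[:, :2]         # omitted lower bound: columns 0 and 1
right = d[:, 2:]        # omitted upper bound: columns 2 to the end
every_other = d[:, ::2] # step of 2: columns 0 and 2

# Unlike NumPy, PyTorch rejects negative steps such as d[:, ::-1];
# torch.flip reverses the order along the given dimensions instead.
reversed_cols = torch.flip(d, dims=[1])
```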
And if we want the same block as before, but restricted to the first 2 rows:
col_lower = 1 # remember that indexes start at 0!
col_upper = 3
row_lower = 0
row_upper = 2
block = t[row_lower:row_upper, col_lower:col_upper] # the upper bounds row_upper and col_upper are excluded
producing:
tensor([[-0.2615, -0.6309],
[ 0.1157, 0.0263]])
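One detail worth knowing: slicing does not copy the data. The sub-tensor is a view that shares memory with the original, so writing into it also modifies the original; calling `.clone()` produces an independent copy. A minimal sketch on a hypothetical deterministic tensor `d`:

```python
import torch

# Hypothetical deterministic tensor (not the random t from the text).
d = torch.arange(12).reshape(3, 4)

block = d[0:2, 1:3]  # a view: shares memory with d
block[0, 0] = 100    # writes through to d[0, 1]
print(d[0, 1])       # tensor(100)

safe = d[0:2, 1:3].clone()  # an independent copy of the same block
safe[0, 0] = -1             # d is left untouched
print(d[0, 1])              # tensor(100)
```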