Using the Linear operator instead of torch.nn.Parameter
Remember the way we defined our model ModelNumberQuestions?
import torch

class ModelNumberQuestions(torch.nn.Module):
    def __init__(self):
        super().__init__()
        initial_theta_0 = torch.Tensor([1])
        initial_theta_1 = torch.Tensor([2])
        self.theta_0 = torch.nn.Parameter(initial_theta_0)
        self.theta_1 = torch.nn.Parameter(initial_theta_1)

    def forward(self, tensor_number_tasks):
        return self.theta_1 * tensor_number_tasks + self.theta_0
You can see that the model above performs a linear operation:
f_\theta(n_T) = \theta_1 n_T + \theta_0
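For example, with the initial values \theta_0 = 1 and \theta_1 = 2 defined above, we can check the formula on a made-up input of 3 tasks (the input value here is only for illustration):

model = ModelNumberQuestions()
tensor_number_tasks = torch.Tensor([3.0])  # made-up input: 3 tasks
# f(3) = 2 * 3 + 1 = 7
print(model(tensor_number_tasks))  # tensor([7.], grad_fn=...)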
PyTorch already provides an implementation of that operation (and the name is quite straightforward): torch.nn.Linear
This linear operation takes two required arguments (a short usage sketch follows the list):

in_features: Number of features as input to the linear operation. In our case, the input has only one feature: n_T. So we set in_features=1.
out_features: Number of features as output by the linear operation. In our case, the output has only one feature: \widehat{n_Q}. So we set out_features=1.
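To make this concrete, here is a minimal sketch of torch.nn.Linear used on its own (the input values are made up for illustration):

import torch

linear = torch.nn.Linear(in_features=1, out_features=1)
# Inputs have shape (number of examples, in_features): three examples, one feature each
tensor_number_tasks = torch.Tensor([[1.0], [2.0], [3.0]])
output = linear(tensor_number_tasks)  # computes weight * input + bias for each example
print(output.shape)         # torch.Size([3, 1])
print(linear.weight.shape)  # torch.Size([1, 1])
print(linear.bias.shape)    # torch.Size([1])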
In the end, our class ModelNumberQuestions becomes way simpler!
import torch

class ModelNumberQuestions(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=1,
                                      out_features=1)
        # self.linear is automatically initialised

    def forward(self, tensor_number_tasks):
        return self.linear(tensor_number_tasks)
You can see all the parameters of ModelNumberQuestions by printing named_parameters():
net = ModelNumberQuestions()
print(list(net.named_parameters()))
You should get something like this:
[('linear.weight', Parameter containing:
tensor([[0.1731]], requires_grad=True)), ('linear.bias', Parameter containing:
tensor([0.7059], requires_grad=True))]
You can notice that we still have two parameters (a quick check follows the list):

linear.weight, corresponding to our former theta_1
linear.bias, corresponding to our former theta_0
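As a quick sanity check (a minimal sketch, using a made-up input value), you can verify that the new model still computes \theta_1 n_T + \theta_0, with linear.weight playing the role of \theta_1 and linear.bias the role of \theta_0:

net = ModelNumberQuestions()
tensor_number_tasks = torch.Tensor([[3.0]])  # made-up input: 3 tasks
manual = net.linear.weight * tensor_number_tasks + net.linear.bias
print(torch.allclose(net(tensor_number_tasks), manual))  # True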
Exercise: Try your implementation of train_parameters_linear_regression with our new implementation of ModelNumberQuestions. The results should remain unchanged! :D
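If you want something to compare against, here is a minimal training-loop sketch (your train_parameters_linear_regression will likely look different; the data, loss, optimiser, learning rate, and number of iterations below are assumptions made up for illustration):

import torch

net = ModelNumberQuestions()
# Made-up data following n_Q = 2 * n_T + 1
tensor_number_tasks = torch.Tensor([[1.0], [2.0], [3.0], [4.0]])
tensor_number_questions = torch.Tensor([[3.0], [5.0], [7.0], [9.0]])

criterion = torch.nn.MSELoss()                          # assumed loss: mean squared error
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)  # assumed optimiser and learning rate

for _ in range(1000):
    optimizer.zero_grad()                                   # reset accumulated gradients
    prediction = net(tensor_number_tasks)                   # forward pass
    loss = criterion(prediction, tensor_number_questions)   # compare to targets
    loss.backward()                                         # compute gradients
    optimizer.step()                                        # update linear.weight and linear.bias

print(list(net.named_parameters()))  # weight should end up close to 2, bias close to 1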