Introduction to Deep Learning with PyTorch
Chapter 8: Building and Training an AutoEncoder
Implementing a Simple Auto-Encoder
Before implementing anything…
There are a few questions we need to address before we start implementing our MNISTAutoEncoder module.
What is the input to our model?
Our model should take as input a batch of images of size torch.Size([N_batch, 1, 28, 28]) (one channel, since MNIST images are grayscale, of 28 × 28 pixels), where N_batch represents the number of images in the batch.
What is the output of our model?
Output Shape
As explained before, for each image we expect our model to output an accurate reconstruction of that image.
So, when we give N_batch images as input to the model, we expect it to return N_batch images.
Thus, if the input is of shape torch.Size([N_batch, 1, 28, 28]), the output should also be of shape torch.Size([N_batch, 1, 28, 28]).
Which loss function should we consider?
Ideally, if our model were perfect, it would reconstruct exactly the same image it was given as input.
So, given an image I, we intend to minimise the distance between I and its reconstruction I' produced by the auto-encoder.
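Written out for a batch of N_batch images of 28 × 28 = 784 pixels each, with I_{n,p} the value of pixel p in image n and I'_{n,p} the corresponding reconstructed pixel, the mean squared error averages the squared pixel differences over the whole batch:

$$
\mathcal{L}_{\mathrm{MSE}}(I, I') = \frac{1}{784\,N_{\text{batch}}} \sum_{n=1}^{N_{\text{batch}}} \sum_{p=1}^{784} \left( I_{n,p} - I'_{n,p} \right)^2
$$

It is equal to 0 only when every reconstructed pixel matches the original one.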
In PyTorch, the loss torch.nn.MSELoss calculates the mean squared error between two tensors (in this case, between two batches of images).
This is the loss we will use to compare original images with their reconstructions.
In the end, by minimising the mean squared error loss, we aim to obtain more accurate reconstructions.
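As a quick sanity check of what torch.nn.MSELoss computes, here is a minimal sketch comparing it with a hand-written mean of squared differences. The two batches are random placeholders standing in for original images and reconstructions, not real MNIST data:

```python
import torch

# Two placeholder batches: "original" images and their "reconstructions"
n_batch = 4
images = torch.rand(size=(n_batch, 1, 28, 28))
reconstructions = torch.rand(size=(n_batch, 1, 28, 28))

loss_fn = torch.nn.MSELoss()  # reduction="mean" by default
loss = loss_fn(reconstructions, images)

# Same value computed by hand: average of squared differences over every pixel
manual_loss = ((reconstructions - images) ** 2).mean()
print(loss.item(), manual_loss.item())

# A perfect reconstruction gives a loss of exactly 0
print(loss_fn(images, images).item())  # 0.0
```

With the default reduction, the average is taken over all N_batch × 1 × 28 × 28 elements, so the loss does not grow with the batch size.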
And in PyTorch?
We will use the same structure as before with a slight difference: we would like the number of dimensions of our encoding (feature) space to be flexible.
So we add a parameter feature_space_dimensionality to our __init__.
Also, in this example, we want to be able to inspect the encodings themselves, so we add a get_encoding method:
```python
import torch


class MNISTAutoEncoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        ...

    def forward(self, tensor_images):
        """
        Args:
            tensor_images: tensor of shape (N_batch, 1, 28, 28)
        """
        ...

    def get_encoding(self, tensor_images):
        """
        Args:
            tensor_images: tensor of shape (N_batch, 1, 28, 28)
        """
        ...
```
We now have all the tools to fill in this module and adapt it to our problem!
The __init__ method
As before, here we define all our layers.
Both the encoder and the decoder consist of 3 Linear operations (with hidden layers of size 64).
```python
def __init__(self, feature_space_dimensionality):
    super().__init__()
    # Encoder: 784 pixels -> 64 -> 64 -> feature_space_dimensionality
    self.linear_encoder_1 = torch.nn.Linear(in_features=1 * 28 * 28, out_features=64)
    self.linear_encoder_2 = torch.nn.Linear(in_features=64, out_features=64)
    self.linear_encoder_final = torch.nn.Linear(in_features=64, out_features=feature_space_dimensionality)
    # Decoder: feature_space_dimensionality -> 64 -> 64 -> 784 pixels
    self.linear_decoder_1 = torch.nn.Linear(in_features=feature_space_dimensionality, out_features=64)
    self.linear_decoder_2 = torch.nn.Linear(in_features=64, out_features=64)
    self.linear_decoder_final = torch.nn.Linear(in_features=64, out_features=1 * 28 * 28)
```
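To get a feel for the size of this model, here is a small sketch that rebuilds the same six Linear layers and counts their parameters. It assumes feature_space_dimensionality=2, an arbitrary illustrative value (the text leaves it open):

```python
import torch

feature_space_dimensionality = 2  # arbitrary illustrative value

# Same six Linear layers as in __init__, gathered in a ModuleList for counting
layers = torch.nn.ModuleList([
    torch.nn.Linear(in_features=1 * 28 * 28, out_features=64),
    torch.nn.Linear(in_features=64, out_features=64),
    torch.nn.Linear(in_features=64, out_features=feature_space_dimensionality),
    torch.nn.Linear(in_features=feature_space_dimensionality, out_features=64),
    torch.nn.Linear(in_features=64, out_features=64),
    torch.nn.Linear(in_features=64, out_features=1 * 28 * 28),
])

# Each Linear layer holds in_features * out_features weights + out_features biases
for layer in layers:
    n_params = sum(p.numel() for p in layer.parameters())
    assert n_params == layer.in_features * layer.out_features + layer.out_features
    print(layer, n_params)

total = sum(p.numel() for p in layers.parameters())
print(total)  # 109842
```

Almost all of these parameters sit in the first encoder layer and the last decoder layer, the two layers connected to the 784 pixels.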
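The forward method
There is one subtlety here.
We said that our model should output reconstructions of the original images, so the output should be of shape (N_batch, 1, 28, 28).
However, Linear layers operate on flat vectors: we first flatten each image into a vector of 1 * 28 * 28 = 784 values, and the last decoder layer also returns vectors of 784 values.
We therefore need to reshape the output back to (N_batch, 1, 28, 28) at the end.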
```python
def forward(self, tensor_images):
    """
    Args:
        tensor_images: tensor of shape (N_batch, 1, 28, 28)
    """
    # Flattening images
    x = tensor_images.view(-1, 1 * 28 * 28)

    # Encoder --------------------
    # Encoder - Layer 1
    x = self.linear_encoder_1(x)
    x = torch.relu(x)
    # Encoder - Layer 2
    x = self.linear_encoder_2(x)
    x = torch.relu(x)
    # Encoder - Final Layer
    encoding = self.linear_encoder_final(x)
    # Note that there is no activation function here

    # Decoder --------------------
    # Decoder - Layer 1
    x = self.linear_decoder_1(encoding)
    x = torch.relu(x)
    # Decoder - Layer 2
    x = self.linear_decoder_2(x)
    x = torch.relu(x)
    # Decoder - Final Layer
    reconstruction = self.linear_decoder_final(x)

    # Putting reconstruction on the right shape (as original images)
    reconstruction = reconstruction.view(-1, 1, 28, 28)
    return reconstruction
```
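The two view calls are what make the shapes work out. A small sketch on a random batch (n_batch=3 is an arbitrary choice) shows that flattening and then un-flattening is a pure reshape that leaves the pixel values unchanged:

```python
import torch

n_batch = 3
tensor_images = torch.randn(size=(n_batch, 1, 28, 28))

# Flatten: (N_batch, 1, 28, 28) -> (N_batch, 784), as at the start of forward
flat = tensor_images.view(-1, 1 * 28 * 28)
print(flat.size())  # torch.Size([3, 784])

# Un-flatten: (N_batch, 784) -> (N_batch, 1, 28, 28), as at the end of forward
restored = flat.view(-1, 1, 28, 28)
print(restored.size())  # torch.Size([3, 1, 28, 28])

# view only changes the shape: the pixel values are untouched
assert torch.equal(restored, tensor_images)
```

One caveat: view requires a compatible (for example, contiguous) memory layout; reshape is the more permissive alternative if the input tensor might not be contiguous.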
get_encoding
get_encoding simply performs the encoder operations of the forward method, and returns the encoding instead of the reconstruction.
```python
def get_encoding(self, tensor_images):
    # Flattening images
    x = tensor_images.view(-1, 1 * 28 * 28)

    # Encoder --------------------
    # Encoder - Layer 1
    x = self.linear_encoder_1(x)
    x = torch.relu(x)
    # Encoder - Layer 2
    x = self.linear_encoder_2(x)
    x = torch.relu(x)
    # Encoder - Final Layer
    encoding = self.linear_encoder_final(x)
    # Note that there is no activation function here
    return encoding
```
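Since get_encoding repeats the encoder half of forward, one possible refactoring groups the layers into two torch.nn.Sequential blocks so that forward can reuse get_encoding. This is a design alternative, not the structure used in this chapter, and the class name SequentialMNISTAutoEncoder is hypothetical:

```python
import torch


class SequentialMNISTAutoEncoder(torch.nn.Module):
    """Hypothetical variant: same layer sizes as MNISTAutoEncoder, grouped in Sequential blocks."""

    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(1 * 28 * 28, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, feature_space_dimensionality),  # no activation on the encoding
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(feature_space_dimensionality, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 1 * 28 * 28),
        )

    def get_encoding(self, tensor_images):
        # Flatten (N_batch, 1, 28, 28) -> (N_batch, 784), then encode
        return self.encoder(tensor_images.view(-1, 1 * 28 * 28))

    def forward(self, tensor_images):
        # The encoder logic lives in one place: get_encoding
        encoding = self.get_encoding(tensor_images)
        reconstruction = self.decoder(encoding)
        # Back to image shape (N_batch, 1, 28, 28)
        return reconstruction.view(-1, 1, 28, 28)


model = SequentialMNISTAutoEncoder(feature_space_dimensionality=2)
batch = torch.randn(size=(5, 1, 28, 28))
print(model.get_encoding(batch).size())  # torch.Size([5, 2])
print(model(batch).size())  # torch.Size([5, 1, 28, 28])
```

The trade-off: individually named layers such as linear_encoder_1 make every step explicit, which suits a first implementation, while Sequential blocks avoid writing the encoder twice.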
In summary:
```python
import torch


class MNISTAutoEncoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_encoder_1 = torch.nn.Linear(in_features=1 * 28 * 28, out_features=64)
        self.linear_encoder_2 = torch.nn.Linear(in_features=64, out_features=64)
        self.linear_encoder_final = torch.nn.Linear(in_features=64, out_features=feature_space_dimensionality)
        self.linear_decoder_1 = torch.nn.Linear(in_features=feature_space_dimensionality, out_features=64)
        self.linear_decoder_2 = torch.nn.Linear(in_features=64, out_features=64)
        self.linear_decoder_final = torch.nn.Linear(in_features=64, out_features=1 * 28 * 28)

    def forward(self, tensor_images):
        # Flattening images
        x = tensor_images.view(-1, 1 * 28 * 28)
        # Encoder --------------------
        # Encoder - Layer 1
        x = self.linear_encoder_1(x)
        x = torch.relu(x)
        # Encoder - Layer 2
        x = self.linear_encoder_2(x)
        x = torch.relu(x)
        # Encoder - Final Layer
        encoding = self.linear_encoder_final(x)
        # Note that there is no activation function here
        # Decoder --------------------
        # Decoder - Layer 1
        x = self.linear_decoder_1(encoding)
        x = torch.relu(x)
        # Decoder - Layer 2
        x = self.linear_decoder_2(x)
        x = torch.relu(x)
        # Decoder - Final Layer
        reconstruction = self.linear_decoder_final(x)
        # Putting reconstruction on the right shape (as original images)
        reconstruction = reconstruction.view(-1, 1, 28, 28)
        return reconstruction

    def get_encoding(self, tensor_images):
        # Flattening images
        x = tensor_images.view(-1, 1 * 28 * 28)
        # Encoder --------------------
        # Encoder - Layer 1
        x = self.linear_encoder_1(x)
        x = torch.relu(x)
        # Encoder - Layer 2
        x = self.linear_encoder_2(x)
        x = torch.relu(x)
        # Encoder - Final Layer
        encoding = self.linear_encoder_final(x)
        # Note that there is no activation function here
        return encoding
```
Let’s test!
It is always a good idea to test our model on random input tensors: this quickly reveals mistakes such as mismatched layer sizes or wrong output shapes.
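```python
import torch

# feature_space_dimensionality is a required argument; 2 is an arbitrary choice here
mnist_auto_encoder = MNISTAutoEncoder(feature_space_dimensionality=2)
n_batch = 42
random_tensor = torch.randn(size=(n_batch, 1, 28, 28))
print(mnist_auto_encoder(random_tensor).size())
```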
And we get:
```
torch.Size([42, 1, 28, 28])
```
So the result is of size (n_batch, 1, 28, 28), as expected! \o/
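The random-tensor test only checks shapes. As a possible next step, here is a minimal sketch of how the reconstruction loss torch.nn.MSELoss could drive training. Assumptions, none of which come from the text: the Adam optimiser with lr=1e-3, a random placeholder batch instead of real MNIST images, and a compact torch.nn.Sequential stand-in with the same layer sizes (feature_space_dimensionality=2) so the snippet runs on its own. With the real class, you would use MNISTAutoEncoder(feature_space_dimensionality=...) as the model instead.

```python
import torch

torch.manual_seed(0)

# Compact stand-in with the same layer sizes as MNISTAutoEncoder (encoding size 2, arbitrary)
model = torch.nn.Sequential(
    torch.nn.Flatten(),                          # (N, 1, 28, 28) -> (N, 784)
    torch.nn.Linear(784, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 2),                      # encoding, no activation
    torch.nn.Linear(2, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 64), torch.nn.ReLU(),
    torch.nn.Linear(64, 784),
    torch.nn.Unflatten(1, (1, 28, 28)),          # (N, 784) -> (N, 1, 28, 28)
)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Random placeholder batch; in practice this would come from an MNIST DataLoader
images = torch.rand(size=(42, 1, 28, 28))

losses = []
for step in range(200):
    optimizer.zero_grad()
    reconstructions = model(images)
    loss = loss_fn(reconstructions, images)  # compare reconstructions with the originals
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

Note that the images themselves serve as the targets: no labels are needed to train an auto-encoder.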