Introduction to Deep Learning with PyTorch
Chapter 8: Building and Training an AutoEncoder
Refactoring our Auto-Encoder
Our implementation of MNISTAutoEncoder is a bit too long. Let's try to make a more flexible implementation: MNISTAutoEncoderRefactored.
To do that, we will separate our model implementation into 3 different modules:
- MNISTEncoder, which compresses an image into a latent encoding
- MNISTDecoder, which tries to reconstruct the original image from that latent encoding
- MNISTAutoEncoderRefactored, which relies on MNISTEncoder and MNISTDecoder to encode and reconstruct an image
Encoder
The implementation of MNISTEncoder just performs the first operations of our MNISTAutoEncoder:
class MNISTEncoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_encoder_1 = torch.nn.Linear(in_features=1 * 28 * 28, out_features=64)
        self.linear_encoder_2 = torch.nn.Linear(in_features=64, out_features=64)
        self.linear_encoder_final = torch.nn.Linear(in_features=64, out_features=feature_space_dimensionality)

    def forward(self, tensor_images):
        # Flattening images
        x = tensor_images.view(-1, 1 * 28 * 28)
        # Encoder --------------------
        # Encoder - Layer 1
        x = self.linear_encoder_1(x)
        x = torch.relu(x)
        # Encoder - Layer 2
        x = self.linear_encoder_2(x)
        x = torch.relu(x)
        # Encoder - Final Layer
        x = self.linear_encoder_final(x)
        # Note that there is no activation function here
        return x
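As a quick sanity check (a hypothetical snippet, not part of the chapter's code), we can pass a batch of random MNIST-shaped images through the encoder and confirm that the output has one latent vector per image. The class is restated compactly here so the snippet runs on its own:

```python
import torch

# Compact restatement of MNISTEncoder so this snippet is self-contained
class MNISTEncoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_encoder_1 = torch.nn.Linear(1 * 28 * 28, 64)
        self.linear_encoder_2 = torch.nn.Linear(64, 64)
        self.linear_encoder_final = torch.nn.Linear(64, feature_space_dimensionality)

    def forward(self, tensor_images):
        x = tensor_images.view(-1, 1 * 28 * 28)
        x = torch.relu(self.linear_encoder_1(x))
        x = torch.relu(self.linear_encoder_2(x))
        return self.linear_encoder_final(x)  # no activation on the final layer

encoder = MNISTEncoder(feature_space_dimensionality=16)
batch = torch.randn(8, 1, 28, 28)  # 8 random MNIST-shaped images
encoding = encoder(batch)
print(encoding.shape)  # torch.Size([8, 16])
```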
Decoder
The implementation of MNISTDecoder just performs the last operations of our MNISTAutoEncoder:
class MNISTDecoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_decoder_1 = torch.nn.Linear(in_features=feature_space_dimensionality, out_features=64)
        self.linear_decoder_2 = torch.nn.Linear(in_features=64, out_features=64)
        self.linear_decoder_final = torch.nn.Linear(in_features=64, out_features=1 * 28 * 28)

    def forward(self, tensor_encoding):
        # Decoder - Layer 1
        x = self.linear_decoder_1(tensor_encoding)
        x = torch.relu(x)
        # Decoder - Layer 2
        x = self.linear_decoder_2(x)
        x = torch.relu(x)
        # Decoder - Final Layer
        reconstruction = self.linear_decoder_final(x)
        # Reshaping the reconstruction to the shape of the original images
        reconstruction = reconstruction.view(-1, 1, 28, 28)
        return reconstruction
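Symmetrically, we can check (again, a hypothetical standalone snippet with the class restated compactly) that the decoder maps a batch of latent vectors back to image-shaped tensors:

```python
import torch

# Compact restatement of MNISTDecoder so this snippet is self-contained
class MNISTDecoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_decoder_1 = torch.nn.Linear(feature_space_dimensionality, 64)
        self.linear_decoder_2 = torch.nn.Linear(64, 64)
        self.linear_decoder_final = torch.nn.Linear(64, 1 * 28 * 28)

    def forward(self, tensor_encoding):
        x = torch.relu(self.linear_decoder_1(tensor_encoding))
        x = torch.relu(self.linear_decoder_2(x))
        reconstruction = self.linear_decoder_final(x)
        return reconstruction.view(-1, 1, 28, 28)  # back to image shape

decoder = MNISTDecoder(feature_space_dimensionality=16)
fake_encodings = torch.randn(8, 16)  # 8 random latent vectors
reconstruction = decoder(fake_encodings)
print(reconstruction.shape)  # torch.Size([8, 1, 28, 28])
```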
MNISTAutoEncoderRefactored
__init__
We only need to declare our MNISTEncoder and MNISTDecoder modules in the __init__:
def __init__(self, feature_space_dimensionality):
    super().__init__()
    self.encoder = MNISTEncoder(feature_space_dimensionality)
    self.decoder = MNISTDecoder(feature_space_dimensionality)
forward and get_encoding
Then our implementations of forward and get_encoding only need to use self.encoder and self.decoder:
def forward(self, tensor_images):
    encoding = self.encoder(tensor_images)
    reconstruction = self.decoder(encoding)
    return reconstruction

def get_encoding(self, tensor_images):
    encoding = self.encoder(tensor_images)
    return encoding
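Putting everything together, here is an end-to-end sanity check of the refactored model: images go in, latent codes come out of get_encoding, and forward returns a reconstruction with the same shape as the input. This is a hypothetical standalone snippet, with the three modules restated compactly so it runs on its own:

```python
import torch

# Compact restatements of the three modules, self-contained for this check
class MNISTEncoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_encoder_1 = torch.nn.Linear(1 * 28 * 28, 64)
        self.linear_encoder_2 = torch.nn.Linear(64, 64)
        self.linear_encoder_final = torch.nn.Linear(64, feature_space_dimensionality)

    def forward(self, tensor_images):
        x = tensor_images.view(-1, 1 * 28 * 28)
        x = torch.relu(self.linear_encoder_1(x))
        x = torch.relu(self.linear_encoder_2(x))
        return self.linear_encoder_final(x)

class MNISTDecoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_decoder_1 = torch.nn.Linear(feature_space_dimensionality, 64)
        self.linear_decoder_2 = torch.nn.Linear(64, 64)
        self.linear_decoder_final = torch.nn.Linear(64, 1 * 28 * 28)

    def forward(self, tensor_encoding):
        x = torch.relu(self.linear_decoder_1(tensor_encoding))
        x = torch.relu(self.linear_decoder_2(x))
        return self.linear_decoder_final(x).view(-1, 1, 28, 28)

class MNISTAutoEncoderRefactored(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.encoder = MNISTEncoder(feature_space_dimensionality)
        self.decoder = MNISTDecoder(feature_space_dimensionality)

    def forward(self, tensor_images):
        return self.decoder(self.encoder(tensor_images))

    def get_encoding(self, tensor_images):
        return self.encoder(tensor_images)

model = MNISTAutoEncoderRefactored(feature_space_dimensionality=16)
images = torch.randn(4, 1, 28, 28)        # 4 random MNIST-shaped images
encoding = model.get_encoding(images)     # latent codes: one 16-d vector per image
reconstruction = model(images)            # same shape as the input images
print(encoding.shape, reconstruction.shape)
```

Because the encoder and decoder are plain submodules, model.parameters() still collects all their weights, so training code from the previous sections works unchanged.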
Way shorter and cleaner, isn’t it? ^^