Introduction to Deep Learning with PyTorch
Chapter 8: Building and Training an AutoEncoder
Refactoring our Auto-Encoder
Our implementation of MNISTAutoEncoder is a bit too long.
Let’s try to make a more flexible implementation: MNISTAutoEncoderRefactored.
To do that, we will separate our model implementation into 3 different modules:
- MNISTEncoder, which compresses an image into a latent encoding
- MNISTDecoder, which tries to reconstruct the original image from that latent encoding
- MNISTAutoEncoderRefactored, which relies on MNISTEncoder and MNISTDecoder to encode and reconstruct an image
Encoder
The implementation of MNISTEncoder simply performs the first half of the operations of our original MNISTAutoEncoder:
class MNISTEncoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_encoder_1 = torch.nn.Linear(in_features=1 * 28 * 28, out_features=64)
        self.linear_encoder_2 = torch.nn.Linear(in_features=64, out_features=64)
        self.linear_encoder_final = torch.nn.Linear(in_features=64, out_features=feature_space_dimensionality)

    def forward(self, tensor_images):
        # Flattening images
        x = tensor_images.view(-1, 1 * 28 * 28)

        # Encoder --------------------
        # Encoder - Layer 1
        x = self.linear_encoder_1(x)
        x = torch.relu(x)

        # Encoder - Layer 2
        x = self.linear_encoder_2(x)
        x = torch.relu(x)

        # Encoder - Final Layer
        x = self.linear_encoder_final(x)
        # Note that there is no activation function here

        return x
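The view(-1, 1 * 28 * 28) call at the top of forward flattens each image into a 784-element vector, with -1 letting PyTorch infer the batch dimension. A minimal sketch (the batch size of 16 is an arbitrary choice for illustration):

```python
import torch

batch = torch.randn(16, 1, 28, 28)  # fake batch of 16 MNIST-sized images
flat = batch.view(-1, 1 * 28 * 28)  # -1 lets PyTorch infer the batch dimension
print(flat.shape)  # torch.Size([16, 784])
```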
Decoder
The implementation of MNISTDecoder simply performs the second half of the operations of our original MNISTAutoEncoder:
class MNISTDecoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_decoder_1 = torch.nn.Linear(in_features=feature_space_dimensionality, out_features=64)
        self.linear_decoder_2 = torch.nn.Linear(in_features=64, out_features=64)
        self.linear_decoder_final = torch.nn.Linear(in_features=64, out_features=1 * 28 * 28)

    def forward(self, tensor_encoding):
        # Decoder - Layer 1
        x = self.linear_decoder_1(tensor_encoding)
        x = torch.relu(x)

        # Decoder - Layer 2
        x = self.linear_decoder_2(x)
        x = torch.relu(x)

        # Decoder - Final Layer
        reconstruction = self.linear_decoder_final(x)

        # Putting reconstruction in the right shape (same as original images)
        reconstruction = reconstruction.view(-1, 1, 28, 28)

        return reconstruction
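The final view undoes the flattening done by the encoder, turning each 784-element row back into a (channel, height, width) image. A small sketch of that last step in isolation, using a stand-alone linear layer (the batch size of 16 mirrors the earlier illustration):

```python
import torch

final_layer = torch.nn.Linear(in_features=64, out_features=1 * 28 * 28)
x = torch.randn(16, 64)  # hypothetical activations from the previous layer
reconstruction = final_layer(x).view(-1, 1, 28, 28)  # back to image shape
print(reconstruction.shape)  # torch.Size([16, 1, 28, 28])
```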
MNISTAutoEncoderRefactored
__init__
We only need to declare our MNISTEncoder and MNISTDecoder modules in the __init__:
def __init__(self, feature_space_dimensionality):
    super().__init__()
    self.encoder = MNISTEncoder(feature_space_dimensionality)
    self.decoder = MNISTDecoder(feature_space_dimensionality)
forward and get_encoding
Then our implementations of forward and get_encoding only need to use self.encoder and self.decoder:
def forward(self, tensor_images):
    encoding = self.encoder(tensor_images)
    reconstruction = self.decoder(encoding)
    return reconstruction

def get_encoding(self, tensor_images):
    encoding = self.encoder(tensor_images)
    return encoding
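Putting the three modules together, here is a sketch of how the refactored model can be exercised end to end, re-stating the classes from this chapter in compact form. The batch size of 16 and the latent dimensionality of 2 are arbitrary choices for illustration:

```python
import torch

class MNISTEncoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_encoder_1 = torch.nn.Linear(1 * 28 * 28, 64)
        self.linear_encoder_2 = torch.nn.Linear(64, 64)
        self.linear_encoder_final = torch.nn.Linear(64, feature_space_dimensionality)

    def forward(self, tensor_images):
        x = tensor_images.view(-1, 1 * 28 * 28)
        x = torch.relu(self.linear_encoder_1(x))
        x = torch.relu(self.linear_encoder_2(x))
        return self.linear_encoder_final(x)

class MNISTDecoder(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.linear_decoder_1 = torch.nn.Linear(feature_space_dimensionality, 64)
        self.linear_decoder_2 = torch.nn.Linear(64, 64)
        self.linear_decoder_final = torch.nn.Linear(64, 1 * 28 * 28)

    def forward(self, tensor_encoding):
        x = torch.relu(self.linear_decoder_1(tensor_encoding))
        x = torch.relu(self.linear_decoder_2(x))
        return self.linear_decoder_final(x).view(-1, 1, 28, 28)

class MNISTAutoEncoderRefactored(torch.nn.Module):
    def __init__(self, feature_space_dimensionality):
        super().__init__()
        self.encoder = MNISTEncoder(feature_space_dimensionality)
        self.decoder = MNISTDecoder(feature_space_dimensionality)

    def forward(self, tensor_images):
        return self.decoder(self.encoder(tensor_images))

    def get_encoding(self, tensor_images):
        return self.encoder(tensor_images)

model = MNISTAutoEncoderRefactored(feature_space_dimensionality=2)
images = torch.randn(16, 1, 28, 28)  # fake batch standing in for MNIST data
print(model.get_encoding(images).shape)  # torch.Size([16, 2])
print(model(images).shape)               # torch.Size([16, 1, 28, 28])
```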
Way shorter and cleaner, isn’t it? ^^