Introduction to Deep Learning with PyTorch
Chapter 8: Building and Training an AutoEncoder
Training our Auto-Encoder
Training our Auto-Encoder is very similar to training our previous classifier.
The only difference is how the loss is computed.
Before, we compared the predictions of the model with the expected labels.
Here, with our Auto-Encoder, there are no labels. Instead, we want the reconstructions to be as close as possible to the input images, so the input batch itself serves as the target.
The loss is therefore computed as follows:
reconstruction = mnist_autoencoder(batch_image)
value_loss = loss.forward(input=reconstruction,
                          target=batch_image)
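To make the role of the target concrete, here is a minimal sketch of what such a reconstruction loss measures. It assumes the loss is a mean squared error (`nn.MSELoss`), which this section does not state explicitly, and it uses random tensors as hypothetical stand-ins for a batch of flattened 28x28 images and its reconstruction.

```python
import torch
from torch import nn

# Assumption: the reconstruction loss is mean squared error.
loss = nn.MSELoss()

# Hypothetical stand-ins: 4 flattened 28x28 images and their reconstructions.
torch.manual_seed(0)
batch_image = torch.rand(4, 784)
reconstruction = torch.rand(4, 784)

# Same call pattern as in the text: the input batch is the target.
value_loss = loss.forward(input=reconstruction, target=batch_image)

# The same quantity by hand: mean of the squared pixel-wise differences.
manual_loss = ((reconstruction - batch_image) ** 2).mean()
print(torch.allclose(value_loss, manual_loss))  # True

# A perfect reconstruction gives a loss of exactly zero.
perfect_loss = loss.forward(input=batch_image, target=batch_image)
print(perfect_loss.item())  # 0.0
```

No label appears anywhere in this computation: the loss only depends on how far each reconstructed pixel is from the corresponding input pixel.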
And the full training function becomes:
def train_autoencoder_batches(mnist_autoencoder, loss, optimiser, list_batches_images, number_training_steps):
    for _ in range(number_training_steps):
        running_loss = 0.0
        for batch_image in list_batches_images:
            # Reset the gradients accumulated during the previous step
            optimiser.zero_grad()
            # Compute Loss (the input batch is also the target)
            reconstruction = mnist_autoencoder(batch_image)
            value_loss = loss.forward(input=reconstruction,
                                      target=batch_image)
            # Backpropagate and update the parameters
            value_loss.backward()
            optimiser.step()
            running_loss += value_loss.item()
        # Average loss over all batches for this pass over the data
        running_loss = running_loss / len(list_batches_images)
        print("running loss:", running_loss)
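As an illustration of how this function might be called, the sketch below trains a deliberately tiny model on random data. Everything outside the training loop is an assumption made for the example: the 784 -> 32 -> 784 architecture, the choice of `nn.MSELoss` and `Adam`, the learning rate, and the random tensors standing in for MNIST batches. The loop is also modified to return the per-pass average losses instead of printing them, so that they can be inspected afterwards.

```python
import torch
from torch import nn


def train_autoencoder_batches(mnist_autoencoder, loss, optimiser,
                              list_batches_images, number_training_steps):
    # Same loop as in the text, except that the average loss of each pass
    # is collected and returned (an addition made for this sketch).
    history = []
    for _ in range(number_training_steps):
        running_loss = 0.0
        for batch_image in list_batches_images:
            optimiser.zero_grad()
            reconstruction = mnist_autoencoder(batch_image)
            value_loss = loss.forward(input=reconstruction,
                                      target=batch_image)
            value_loss.backward()
            optimiser.step()
            running_loss += value_loss.item()
        history.append(running_loss / len(list_batches_images))
    return history


torch.manual_seed(0)

# Hypothetical tiny autoencoder: encoder 784 -> 32, decoder 32 -> 784.
mnist_autoencoder = nn.Sequential(
    nn.Linear(784, 32),
    nn.ReLU(),
    nn.Linear(32, 784),
)

# Assumed loss and optimiser; the chapter may use different ones.
loss = nn.MSELoss()
optimiser = torch.optim.Adam(mnist_autoencoder.parameters(), lr=1e-3)

# Random data standing in for MNIST: 5 batches of 16 flattened images.
list_batches_images = [torch.rand(16, 784) for _ in range(5)]

history = train_autoencoder_batches(mnist_autoencoder, loss, optimiser,
                                    list_batches_images,
                                    number_training_steps=20)
print("first pass:", history[0], "last pass:", history[-1])
```

On random noise the model can learn little more than the average pixel value, so the loss drops and then flattens quickly; on real MNIST batches the reconstructions have actual structure to capture.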