Introduction to Deep Learning with PyTorch
Chapter 7: Building and Training a Simple Classification Model
Divide your training dataset into batches!
Dividing Our Dataset into Batches
Let’s create a function that divides our dataset (containing 60,000 images) into 1,875 batches of 32 images.
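As a quick check of these numbers, 60,000 is exactly 1,875 × 32, so every batch holds exactly 32 images. The plain-Python snippet below verifies this with `divmod`. For a hypothetical dataset of 100 samples, it also shows the batch sizes you would get by stepping through the data 32 samples at a time when the size is not a multiple of 32.

```python
number_samples = 60_000
batch_size = 32

# 60,000 divides evenly by 32, so no samples are left over
number_batches, remainder = divmod(number_samples, batch_size)
print(number_batches, remainder)  # 1875 0

# Hypothetical dataset of 100 samples (NOT a multiple of 32):
# stepping through it 32 samples at a time leaves a smaller final batch
sizes = [min(batch_size, 100 - start) for start in range(0, 100, batch_size)]
print(sizes)  # [32, 32, 32, 4]
```

This is why the function below assumes the number of samples is a multiple of 32: otherwise the last batch would be smaller than the others.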
Here is the skeleton of divide_in_batches_32, with its docstring:
def divide_in_batches_32(tensor_dataset):
    """Divides tensor_dataset into small batches of size 32.
    We assume that the number of samples in tensor_dataset (tensor_dataset.size()[0]) is a multiple of 32.
    Args:
        tensor_dataset (torch.Tensor): Tensor containing full dataset of samples
    Returns:
        List[torch.Tensor] where each torch.Tensor is of size (32, ...)
    """
    pass  # Your implementation goes here
Try to do it yourself!
Here is one possible way to implement it:
def divide_in_batches_32(tensor_dataset):
    """Divides tensor_dataset into small batches of size 32.
    We assume that the number of samples in tensor_dataset (tensor_dataset.size()[0]) is a multiple of 32.
    Args:
        tensor_dataset (torch.Tensor): Tensor containing full dataset of samples
    Returns:
        List[torch.Tensor] where each torch.Tensor is of size (32, ...)
    """
    number_samples = tensor_dataset.size()[0]
    step = 32
    list_batches_dataset = []
    for index in range(0, number_samples, step):
        # Take samples index, index+1, ..., index+31 along the first (sample) dimension
        new_batch = tensor_dataset[index:index+step]
        list_batches_dataset.append(new_batch)
    return list_batches_dataset
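PyTorch also has a built-in function, `torch.split(tensor, 32)`, that chunks a tensor into pieces of 32 along its first dimension. The sketch below checks that it gives the same 1,875 batches as the slicing loop. It uses a random `uint8` tensor of shape (60000, 28, 28) as an assumed stand-in for the MNIST training images; the real dataset tensor in this chapter may have a different shape or dtype.

```python
import torch

# Assumed stand-in for the training images: 60,000 random 28x28 uint8 "images"
tensor_dataset = torch.randint(0, 256, (60000, 28, 28), dtype=torch.uint8)

# Slicing loop, same logic as divide_in_batches_32
batches_loop = [tensor_dataset[i:i + 32] for i in range(0, tensor_dataset.size()[0], 32)]

# Built-in alternative: torch.split chunks along dim 0 and returns a tuple of views
batches_split = torch.split(tensor_dataset, 32)

print(len(batches_loop), len(batches_split))  # 1875 1875
print(batches_split[0].shape)  # torch.Size([32, 28, 28])
print(all(torch.equal(a, b) for a, b in zip(batches_loop, batches_split)))  # True
```

Note that `torch.split` returns a tuple rather than a list. It also does not require the size to be a multiple of 32: as in the loop, the last chunk would simply be smaller.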
We can then obtain the list of image batches and the list of label batches:
list_batches_images = divide_in_batches_32(dataset_training_images)
list_batches_labels = divide_in_batches_32(dataset_training_labels)
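Because images and labels are cut at the same indices, the i-th image batch and the i-th label batch describe the same 32 samples. PyTorch's `torch.utils.data` module offers the same pairing and batching through `TensorDataset` and `DataLoader`. The sketch below uses random tensors as assumed stand-ins for `dataset_training_images` and `dataset_training_labels`, with shapes (60000, 28, 28) and (60000,).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumed stand-ins for dataset_training_images / dataset_training_labels
images = torch.randint(0, 256, (60000, 28, 28), dtype=torch.uint8)
labels = torch.randint(0, 10, (60000,))

# TensorDataset pairs the i-th image with the i-th label;
# DataLoader then groups consecutive pairs into batches of 32
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=False)

print(len(loader))  # 1875
batch_image, batch_label = next(iter(loader))
print(batch_image.shape, batch_label.shape)  # torch.Size([32, 28, 28]) torch.Size([32])
```

One design difference is worth noting: with `shuffle=True`, the `DataLoader` draws a new sample order every time it is iterated. The fixed lists of batches built above always present the batches in the same order.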
Training with Batches
Previously, we performed one gradient step per training step.
Now, at each training step, we iterate through all the small batches of our dataset and perform one gradient step per batch!
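To make this concrete, the sketch below runs a batch loop of this kind for 3 training steps on 10 hypothetical batches and counts the parameter updates. Several choices here are illustrative assumptions, not necessarily the chapter's:

- the random data (32 flattened 28×28 images per batch)
- the single `nn.Linear(784, 10)` classifier
- `nn.CrossEntropyLoss`
- SGD with `lr=0.1`

With the 1,875 batches of the full dataset, the same loop performs 1,875 updates per training step instead of one.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: 10 batches of 32 flattened 28x28 images, labels in 0..9
list_batches_images = [torch.rand(32, 784) for _ in range(10)]
list_batches_labels = [torch.randint(0, 10, (32,)) for _ in range(10)]

# Hypothetical classifier, loss and optimiser
mnist_classifier = nn.Linear(784, 10)
loss = nn.CrossEntropyLoss()
optimiser = torch.optim.SGD(mnist_classifier.parameters(), lr=0.1)

number_training_steps = 3
number_gradient_steps = 0
for _ in range(number_training_steps):
    for batch_image, batch_label in zip(list_batches_images, list_batches_labels):
        optimiser.zero_grad()  # clear gradients left over from the previous batch
        value_loss = loss(mnist_classifier(batch_image), batch_label)
        value_loss.backward()  # gradients computed on this batch only
        optimiser.step()       # one parameter update per batch
        number_gradient_steps += 1

# 3 training steps x 10 batches = 30 parameter updates
print(number_gradient_steps)  # 30
```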
def train_classifier_batches(mnist_classifier, loss, optimiser, list_batches_images, list_batch_labels, number_training_steps):
    for _ in range(number_training_steps):
        running_loss = 0.0
        for batch_image, batch_label in zip(list_batches_images, list_batch_labels):
            # Reset the gradients accumulated at the previous batch
            optimiser.zero_grad()
            # Compute Loss
            estimator_predictions = mnist_classifier(batch_image)
            value_loss = loss.forward(input=estimator_predictions,
                                      target=batch_label)
            running_loss += value_loss.item()  # just to record the current loss value
            # Compute Gradients
            value_loss.backward()
            # Gradient step: update the parameters using the gradients of this batch
            optimiser.step()
        running_loss = running_loss / len(list_batches_images)
        print("running loss:", running_loss)  # Mean of the per-batch losses over the full dataset
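The value printed at the end of each training step is the mean of the per-batch losses. Suppose every batch has the same size (here exactly 32) and the loss averages over its batch, as PyTorch losses such as `nn.CrossEntropyLoss` do by default (`reduction="mean"`). Then, for fixed predictions, this average equals the mean loss over the whole dataset. The sketch below checks that on hypothetical random logits and labels. During training, however, the parameters change after every batch, so `running_loss` averages losses from slightly different models rather than measuring the final model.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical fixed predictions (logits) and labels: 320 samples, 10 classes
logits = torch.randn(320, 10)
labels = torch.randint(0, 10, (320,))
loss = nn.CrossEntropyLoss()  # default reduction="mean": averages over the batch

# Mean loss over the whole dataset in one go
full_mean = loss(logits, labels)

# Average of the per-batch mean losses, using 10 batches of exactly 32 samples
batch_means = [loss(logits[i:i + 32], labels[i:i + 32]).item() for i in range(0, 320, 32)]
average_of_batches = sum(batch_means) / len(batch_means)

print(abs(full_mean.item() - average_of_batches) < 1e-4)  # True
```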