Introduction to Deep Learning with PyTorch
Chapter 7: Building and Training a Simple Classification Model
Divide your training dataset into batches!
Dividing our Dataset into Batches
Let’s write a function divide_in_batches_32 that divides our dataset (containing 60,000 images) into 1,875 batches of 32 images:
def divide_in_batches_32(tensor_dataset):
    """
    Divides tensor_dataset into small batches of size 32.

    We assume that the number of samples in tensor_dataset
    (tensor_dataset.size()[0]) is a multiple of 32.

    Args:
        tensor_dataset (torch.Tensor): Tensor containing the full dataset of samples
    Returns:
        List[torch.Tensor] where each torch.Tensor is of size (32, ...)
    """
Try to do it yourself!
To test your function, you can build the list of batches of images and the list of batches of labels:
list_batches_images = divide_in_batches_32(dataset_training_images)
list_batches_labels = divide_in_batches_32(dataset_training_labels)
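You can then quickly verify that the batching worked as expected. Here is a minimal sanity check, assuming the full MNIST training set of 60,000 samples:

assert len(list_batches_images) == 60000 // 32   # 1,875 batches in total
assert list_batches_images[0].size()[0] == 32    # each batch contains 32 samples
assert len(list_batches_images) == len(list_batches_labels)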
Here is one possible way to implement it:
def divide_in_batches_32(tensor_dataset):
    """
    Divides tensor_dataset into small batches of size 32.

    We assume that the number of samples in tensor_dataset
    (tensor_dataset.size()[0]) is a multiple of 32.

    Args:
        tensor_dataset (torch.Tensor): Tensor containing the full dataset of samples
    Returns:
        List[torch.Tensor] where each torch.Tensor is of size (32, ...)
    """
    number_samples = tensor_dataset.size()[0]
    step = 32
    list_batches_dataset = []
    for index in range(0, number_samples, step):
        # Slice out the next 32 samples along the first dimension
        new_batch = tensor_dataset[index:index+step]
        list_batches_dataset.append(new_batch)
    return list_batches_dataset
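As a side note, PyTorch provides torch.split, which performs the same slicing along the first dimension in a single call (and if the dataset size were not a multiple of 32, the last batch would simply be smaller). The loop above is therefore equivalent to this one-liner:

list_batches_dataset = list(torch.split(tensor_dataset, 32))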
Training with Batches
Previously, we performed one gradient step per training step.
Now, for each training step, we iterate through all the small batches of our dataset and perform one gradient step per batch!
def train_classifier_batches(mnist_classifier, loss, optimiser,
                             list_batches_images, list_batches_labels,
                             number_training_steps):
    for _ in range(number_training_steps):
        running_loss = 0.0
        for batch_images, batch_labels in zip(list_batches_images, list_batches_labels):
            optimiser.zero_grad()  # reset the gradients accumulated on the previous batch
            # Compute the loss on this batch
            estimator_predictions = mnist_classifier(batch_images)
            value_loss = loss(estimator_predictions, batch_labels)
            value_loss.backward()  # backpropagate the gradients
            optimiser.step()       # one gradient step per batch
            running_loss += value_loss.item()  # accumulate the loss value for logging
        running_loss = running_loss / len(list_batches_images)
        print("running loss:", running_loss)  # mean loss over all batches for this training step