Introduction to Deep Learning with PyTorch
Chapter 7: Building and Training a Simple Classification Model
Getting the Dataset
Loading the Dataset
Before doing anything else, we need to load the MNIST images and their corresponding labels:
import torch
import torchvision
import torchvision.transforms as transforms


def load_training_dataset():
    """
    Returns:
        dataset_images (torch.Tensor): single tensor with all images in training dataset: of size (N_dataset, 1, 28, 28)
        dataset_labels (torch.Tensor): single tensor with all labels for training dataset: of size (N_dataset,)
    """
    # ------------------------------------------------------------------------
    # Initialise the Data Loader:
    # YOU DO NOT NEED TO UNDERSTAND THIS FUNCTION
    # ------------------------------------------------------------------------
    # Define transformation to apply to the images. In our case:
    # - transforms.ToTensor() -> convert a PIL Image or numpy.ndarray to a tensor with values in [0, 1]
    # - transforms.Normalize((0.5,), (0.5,)) -> subtract a mean of 0.5 and divide by a standard
    #   deviation of 0.5, which maps pixel values from [0, 1] to [-1, 1]
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))])

    # Load training set
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

    # Create loader for manipulating the training set easily (more information about that part in the following section).
    # batch_size=len(trainset) means the whole dataset comes back as one single batch.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=len(trainset), shuffle=True, num_workers=2)

    # Printing different classes (in our case: the digits we are trying to recognise)
    classes = [str(i) for i in range(10)]
    print("Here are all the different classes of the problem (each image belongs to one of these categories): \n" + str(classes))

    # Fetch the single batch, which contains every image and every label
    dataset_training_images, dataset_training_labels = next(iter(trainloader))
    return dataset_training_images, dataset_training_labels
You don’t need to understand the code above. The only thing you need to know is that it returns two tensors: one containing every image in the training set, and one containing the matching labels.
There are cleaner ways to process the data (iterating over mini-batches with a DataLoader, for instance), but our goal here is to keep the code as simple as possible.
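For comparison, here is a minimal sketch of what mini-batch iteration with a DataLoader could look like. It uses a small synthetic dataset (the names fake_images, fake_labels and the batch size of 32 are illustrative assumptions, not part of this chapter's code) so that it runs without downloading MNIST.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST (assumption): 100 random "images" with random labels.
fake_images = torch.randn(100, 1, 28, 28)
fake_labels = torch.randint(0, 10, (100,))
fake_dataset = TensorDataset(fake_images, fake_labels)

# batch_size=32 is an arbitrary illustrative choice.
loader = DataLoader(fake_dataset, batch_size=32, shuffle=True)

# Each iteration yields one mini-batch of images and the matching labels.
batch_shapes = []
for images, labels in loader:
    batch_shapes.append((tuple(images.shape), tuple(labels.shape)))

# 100 samples split into batches of 32: three full batches and a final batch of 4.
print(batch_shapes)
```

In this chapter, the loader is instead asked for one batch as large as the whole dataset, which is why a single call to next(iter(trainloader)) returns everything at once.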
The dataset_training_images and dataset_training_labels tensors can be obtained by calling the function above:
dataset_training_images, dataset_training_labels = load_training_dataset()
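As a side note on the values stored in these tensors: the Normalize step inside load_training_dataset computes (pixel - mean) / std for every pixel. The sketch below reproduces that arithmetic by hand on a few hypothetical pixel values, assuming the pixels start in [0, 1] as produced by ToTensor.

```python
import torch

# A few hypothetical pixel values in [0, 1], as ToTensor would produce.
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

# Same constants as in load_training_dataset.
mean, std = 0.5, 0.5

# This is the per-pixel arithmetic that Normalize applies.
normalised = (pixels - mean) / std

# The values become -1.0, -0.5, 0.0 and 1.0: the range [0, 1] is mapped to [-1, 1].
print(normalised)
```

So the images returned by load_training_dataset hold values between -1 and 1 rather than between 0 and 1.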
Analysing the Data shape
dataset_training_images
Let’s have a look at the shape of dataset_training_images:
print(dataset_training_images.size())
torch.Size([60000, 1, 28, 28])
The 1st dimension corresponds to the number of images in the dataset: 60000.
The 2nd dimension corresponds to the number of data streams (channels). Here, it corresponds to the number of color channels. As all images are grayscale, there is only one channel.
The 3rd dimension corresponds to the image height in pixels: 28.
The 4th dimension corresponds to the image width in pixels: 28.
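The four dimensions can be unpacked and indexed directly. The following is a small sketch on a synthetic tensor with the same layout (8 images instead of 60000, an assumption made to keep the example light).

```python
import torch

# Synthetic tensor with the same layout as dataset_training_images,
# but only 8 images (assumption, to keep the example light).
images = torch.zeros(8, 1, 28, 28)

# Unpack the four dimensions in order: images, channels, height, width.
n_images, n_channels, height, width = images.shape
print(n_images, n_channels, height, width)  # 8 1 28 28

# Indexing the 1st dimension selects one image, channel dimension included.
first_image = images[0]
print(first_image.shape)  # torch.Size([1, 28, 28])

# Indexing the 2nd dimension as well leaves a plain 2D grid of pixels.
first_image_pixels = images[0, 0]
print(first_image_pixels.shape)  # torch.Size([28, 28])
```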
dataset_training_labels
dataset_training_labels contains the ground-truth label (0, 1, …, 9) corresponding to each image in dataset_training_images.
Hence, the size of dataset_training_labels is:
torch.Size([60000])
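To make the one-label-per-image layout concrete, here is a short sketch on a hand-written label tensor (the eight label values are made up for illustration). It checks that the tensor is one-dimensional and counts how many times each digit appears.

```python
import torch

# Hypothetical labels for 8 images (made up for illustration).
labels = torch.tensor([3, 0, 9, 3, 1, 3, 0, 7])

# One integer label per image, so the tensor has a single dimension.
print(labels.shape)  # torch.Size([8])

# The label at index i belongs to the image at index i.
print(labels[2].item())  # 9

# Count how many examples fall in each of the 10 classes.
counts = torch.bincount(labels, minlength=10)
print(counts.tolist())  # [2, 1, 0, 3, 0, 0, 0, 1, 0, 1]
```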