Chapter 7: Building and Training a Simple Classification Model

Predict Labels of Images

Luca Grillotti

Predicting the Label of an Image

Now that we have trained our MNISTClassifier, we would like to use it to predict the label of an image.

Let’s take the first image of our training dataset (index_image_dataset = 0):

index_image_dataset = 0

# reshape to (batch_size=1, channels=1, height=28, width=28)
image = dataset_training_images[index_image_dataset].view(1, 1, 28, 28)
label = dataset_training_labels[index_image_dataset]
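The call to view(1, 1, 28, 28) turns a single image into a batch containing one single-channel 28×28 image, the (batch, channels, height, width) layout that PyTorch image modules such as nn.Conv2d work with. Here is a minimal sketch, assuming each entry of dataset_training_images is a 28×28 tensor; a random tensor stands in for a real MNIST image:

```python
import torch

# Hypothetical stand-in for one MNIST image: a 28x28 grayscale tensor
# (assumption: each entry of dataset_training_images has this shape)
single_image = torch.rand(28, 28)

# Add a batch dimension and a channel dimension:
# (batch_size=1, channels=1, height=28, width=28)
batched_image = single_image.view(1, 1, 28, 28)

print(single_image.shape)   # torch.Size([28, 28])
print(batched_image.shape)  # torch.Size([1, 1, 28, 28])

# unsqueeze gives the same result by inserting the size-1 dimensions explicitly
same_image = single_image.unsqueeze(0).unsqueeze(0)
print(torch.equal(batched_image, same_image))  # True
```

view only changes how the data is indexed, not the data itself, which is why both tensors hold exactly the same values.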

Let’s now output the result returned by our neural network:

prediction = mnist_classifier(image)
print("prediction:", prediction)
print("label:", label.item())
prediction: tensor([[  0.4664, -11.1482,   1.9592,  -2.9771,   3.6853,   3.2510,  10.4968, -4.8413,  -2.6138,  -2.6875]])
label: 6
(the dataset is randomly initialised and training is stochastic, so your label and prediction may differ from mine)

If we take the argmax of the prediction tensor (the index of its largest score), we get the expected label!

argmax_prediction = torch.argmax(prediction).item()
print("argmax_prediction:", argmax_prediction)
argmax_prediction: 6
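The argmax is simply the position of the largest of the 10 scores, one per digit class. The sketch below reuses the prediction values printed above. The softmax step assumes the network outputs unnormalised scores (logits), which is not stated here; it is only included to show that normalising the scores does not change which class wins.

```python
import torch

# The prediction printed above: one row of 10 scores, one per digit class 0-9
prediction = torch.tensor([[0.4664, -11.1482, 1.9592, -2.9771, 3.6853,
                            3.2510, 10.4968, -4.8413, -2.6138, -2.6875]])

# Without dim, argmax looks at the flattened tensor; with a single row,
# that is the index of the largest class score
print(torch.argmax(prediction).item())  # 6

# For a batch of images, take the argmax along the class dimension (dim=1)
print(torch.argmax(prediction, dim=1))  # tensor([6])

# Assumption: the scores are logits. softmax maps them to probabilities
# that sum to 1 while preserving their order, so the argmax is unchanged
probabilities = torch.softmax(prediction, dim=1)
print(torch.argmax(probabilities, dim=1).item())  # 6
print(probabilities.sum().item())  # approximately 1.0
```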

Avoid Computing Gradients if possible!

As we saw in the previous chapters, PyTorch records the operations performed on tensors that require gradients (such as the parameters of our network), so that it can later compute gradients with backward().
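As a reminder, here is a small self-contained example of this recording: the result of an operation involving a tensor with requires_grad=True carries a grad_fn, which backward() later uses to compute gradients. The tensors are illustrative and unrelated to MNISTClassifier.

```python
import torch

# A tensor that requires gradients, like the parameters of a neural network
weights = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
inputs = torch.tensor([4.0, 5.0, 6.0])

# PyTorch records this computation so that it can back-propagate through it
output = (weights * inputs).sum()
print(output.requires_grad)        # True
print(output.grad_fn is not None)  # True: the result is a node of the recorded graph

# backward() walks the recorded graph to compute d(output)/d(weights) = inputs
output.backward()
print(weights.grad)  # tensor([4., 5., 6.])
```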

But in the previous section, we only made a prediction and did not need to compute any gradient, so recording those operations only cost extra memory and computation.

To disable gradient calculation, we can use the torch.no_grad() context manager: inside it, the result of every computation has requires_grad=False, even when its inputs require gradients.
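A minimal sketch of the difference, using illustrative tensors: the same computation produces a tensor with a grad_fn outside the context manager, and a tensor with requires_grad=False and no grad_fn inside it.

```python
import torch

weights = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
inputs = torch.tensor([4.0, 5.0, 6.0])

# Outside torch.no_grad(): the operation is recorded
output_with_grad = (weights * inputs).sum()

# Inside torch.no_grad(): the same operation is not recorded
with torch.no_grad():
    output_no_grad = (weights * inputs).sum()

print(output_with_grad.requires_grad, output_with_grad.grad_fn is None)  # True False
print(output_no_grad.requires_grad, output_no_grad.grad_fn is None)      # False True

# Both results hold the same value: only the gradient bookkeeping differs
print(output_with_grad.item(), output_no_grad.item())  # 32.0 32.0
```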

When using torch.no_grad(), the prediction code from the previous section becomes:

with torch.no_grad():  # operations inside this block are not recorded for gradients
    index_image_dataset = 0

    image = dataset_training_images[index_image_dataset].view(1, 1, 28, 28)
    label = dataset_training_labels[index_image_dataset]
    prediction = mnist_classifier(image)
    print("prediction:", prediction)
    print("label:", label.item())
    argmax_prediction = torch.argmax(prediction).item()
    print("argmax_prediction:", argmax_prediction)
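The same pattern extends to predicting the labels of several images at once. The sketch below uses a hypothetical TinyClassifier as a stand-in for MNISTClassifier (whose architecture is defined elsewhere) and random tensors instead of real MNIST images. Note the dim=1 in torch.argmax, which picks one label per row of the (batch_size, 10) output.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for MNISTClassifier: flatten each 28x28 image
# and map it to 10 class scores with a single linear layer
class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.linear(self.flatten(x))

classifier = TinyClassifier()

# Hypothetical batch of 5 random "images": (batch, channels, height, width)
images = torch.rand(5, 1, 28, 28)

with torch.no_grad():
    scores = classifier(images)                     # shape (5, 10)
    predicted_labels = torch.argmax(scores, dim=1)  # one label per image

print(scores.shape)            # torch.Size([5, 10])
print(scores.requires_grad)    # False
print(predicted_labels.shape)  # torch.Size([5])
```

Since the classifier is untrained here, the predicted labels themselves are meaningless; only the shapes and the absence of gradient tracking matter.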