Introduction to Deep Learning with PyTorch
Chapter 7: Building and Training a Simple Classification Model
Predict Labels of Images
Predicting the Label of an Image
Now that we have trained our MNISTClassifier, we would like to use it to predict the label of an image.
Let’s take an image from our training dataset (the first one, with index_image_dataset=0):
index_image_dataset = 0
# Reshape the image into a batch of one single-channel 28x28 image
image = dataset_training_images[index_image_dataset].view(1, 1, 28, 28)
label = dataset_training_labels[index_image_dataset]
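The call to view(1, 1, 28, 28) reshapes a single image into a batch containing one single-channel 28x28 image, the (batch, channels, height, width) layout that PyTorch's 2D convolution layers expect and, presumably, the layout MNISTClassifier was built for. The sketch below assumes each stored image is a 28x28 tensor (view would equally accept a flat 784-element vector); the random dataset_training_images is a hypothetical stand-in for the real training tensor.

```python
import torch

# Hypothetical stand-in for dataset_training_images: 5 random 28x28 "images"
dataset_training_images = torch.rand(5, 28, 28)

index_image_dataset = 0
single_image = dataset_training_images[index_image_dataset]
print(single_image.shape)  # torch.Size([28, 28])

# Add a batch dimension and a channel dimension: (batch=1, channels=1, 28, 28)
image = single_image.view(1, 1, 28, 28)
print(image.shape)  # torch.Size([1, 1, 28, 28])

# Equivalent: insert the two size-1 dimensions explicitly
image_unsqueezed = single_image.unsqueeze(0).unsqueeze(0)
print(torch.equal(image, image_unsqueezed))  # True
```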
Then, let’s print the output of our neural network:
prediction = mnist_classifier(image)
print("prediction:", prediction)
print("label:", label.item())
prediction: tensor([[ 0.4664, -11.1482, 1.9592, -2.9771, 3.6853, 3.2510, 10.4968, -4.8413, -2.6138, -2.6875]])
label: 6
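The ten values in prediction are raw scores, one per digit from 0 to 9. Assuming they are unnormalized scores (logits), as is usual for a classifier trained with a cross-entropy loss, torch.softmax turns them into probabilities that sum to 1. This sketch reuses the scores printed above.

```python
import torch

# Scores copied from the output above: one value per digit class 0..9
prediction = torch.tensor([[0.4664, -11.1482, 1.9592, -2.9771, 3.6853,
                            3.2510, 10.4968, -4.8413, -2.6138, -2.6875]])

# Softmax over the class dimension (dim=1) turns scores into probabilities
probabilities = torch.softmax(prediction, dim=1)

print(probabilities.sum().item())          # approximately 1.0
print(probabilities.argmax(dim=1).item())  # 6
```

Softmax is monotonic, so it never changes which class has the largest value: taking the argmax of the probabilities or of the raw scores gives the same label.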
If we take the argmax of the prediction tensor, that is, the index of its largest value (10.4968, at index 6), we get the expected label:
argmax_prediction = torch.argmax(prediction).item()
print("argmax_prediction:", argmax_prediction)
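Called without a dim argument, torch.argmax flattens its input and returns a single index. That works for a batch containing one image, but to classify several images at once we need dim=1 to get one label per row. The batch of scores below is made up purely for illustration.

```python
import torch

# Hypothetical batch of scores for 3 images (rows) over 10 classes (columns)
predictions = torch.tensor([
    [0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0],  # largest at index 6
    [3.0, 0.2, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],  # largest at index 0
    [0.1, 0.2, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 9.0],  # largest at index 9
])

# Without dim, argmax flattens the tensor and returns a single index
print(torch.argmax(predictions).item())  # 29 (row 2 * 10 columns + column 9)

# With dim=1, argmax returns one predicted label per row (per image)
print(torch.argmax(predictions, dim=1).tolist())  # [6, 0, 9]
```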
Avoid Computing Gradients if possible!
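As we saw in the previous chapters, PyTorch records the operations performed on tensors that require gradients (such as the parameters of our model) so that it can compute gradients during backpropagation.
In the previous section, however, we only made a prediction and did not need to compute any gradient.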
To tell PyTorch to disable gradient calculation, we can use the torch.no_grad() context manager. Since no operations are recorded inside it, this saves memory and computation.
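A quick way to see what torch.no_grad() changes is to compare an output computed with and without it. This sketch uses a small nn.Linear layer as a hypothetical stand-in for mnist_classifier; any module with trainable parameters behaves the same way.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; its parameters have requires_grad=True by default
model = nn.Linear(4, 2)
x = torch.rand(1, 4)

# Outside no_grad: the output is linked to the recorded graph through grad_fn
y = model(x)
print(y.requires_grad, y.grad_fn is not None)  # True True

# Inside no_grad: no graph is recorded for the output
with torch.no_grad():
    y_no_grad = model(x)
print(y_no_grad.requires_grad, y_no_grad.grad_fn)  # False None
```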
With torch.no_grad(), the code above becomes:
with torch.no_grad():
    # No operations inside this block are recorded for gradient computation
    index_image_dataset = 0
    image = dataset_training_images[index_image_dataset].view(1, 1, 28, 28)
    label = dataset_training_labels[index_image_dataset]
    prediction = mnist_classifier(image)
    print("prediction:", prediction)
    print("label:", label.item())
    argmax_prediction = torch.argmax(prediction).item()
    print("argmax_prediction:", argmax_prediction)
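If we need to predict labels repeatedly, these steps can be wrapped in a small function. In this sketch, predict_label is a hypothetical helper and the nn.Sequential model is a stand-in for MNISTClassifier. It also calls model.eval(), which the original code does not do: it switches layers such as dropout and batch normalization to their inference behavior, and does not change the outputs of a model that contains neither.

```python
import torch
import torch.nn as nn

def predict_label(model: nn.Module, image: torch.Tensor) -> int:
    """Return the predicted digit for one 28x28 image (hypothetical helper)."""
    model.eval()  # inference behavior for layers such as dropout/batch norm
    with torch.no_grad():  # skip recording operations for autograd
        prediction = model(image.view(1, 1, 28, 28))
    return torch.argmax(prediction, dim=1).item()

# Hypothetical stand-in for MNISTClassifier: flatten, then one linear layer
stand_in_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
label = predict_label(stand_in_model, torch.rand(28, 28))
print(label)  # some integer between 0 and 9 (the model is untrained)
```

Because eval() changes the model's mode in place, call model.train() before resuming training.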