
Image Classification using Logistic Regression in PyTorch

Part 3 of "PyTorch: Zero to GANs"

This post is the third in a series of tutorials on building deep learning models with PyTorch, an open source neural networks library. Check out the full series:

  1. PyTorch Basics: Tensors & Gradients
  2. Linear Regression & Gradient Descent
  3. Image Classification using Logistic Regression
  4. Training Deep Neural Networks on a GPU
  5. Image Classification using Convolutional Neural Networks
  6. Data Augmentation, Regularization and ResNets
  7. Generating Images using Generative Adversarial Networks

In this tutorial, we'll use our existing knowledge of PyTorch and linear regression to solve a very different kind of problem: image classification. We'll use the famous MNIST Handwritten Digits Database as our training dataset. It consists of 28px by 28px grayscale images of handwritten digits (0 to 9), along with labels for each image indicating which digit it represents. Here are some sample images from the dataset:


Exploring the Data

We begin by importing torch and torchvision. torchvision provides utilities for working with image data, along with helper classes that automatically download and load popular datasets like MNIST.

# Imports
import torch
import torchvision
from torchvision.datasets import MNIST
# Download the MNIST training dataset into data/
# (download=True fetches the files only if they are not already present)
dataset = MNIST(root='data/', download=True)