Learn how to classify CIFAR100 images using ResNets, Regularization, and Data Augmentation in PyTorch. This tutorial covers techniques like data normalization, residual connections, batch normalization, and more. Plus, learn how to use a GPU for faster training.
Classifying CIFAR100 images using ResNets, Regularization and Data Augmentation in PyTorch
In this notebook, we'll use the following techniques to train a high-accuracy model that classifies images from the CIFAR100 dataset:
- Data normalization
- Data augmentation
- Residual connections
- Batch normalization
- Learning rate scheduling
- Weight decay
- Gradient clipping
- Adam optimizer
Using a GPU for faster training
You can use a Graphics Processing Unit (GPU) to train your models faster if your execution platform is connected to a GPU manufactured by NVIDIA. Follow these instructions to use a GPU on the platform of your choice:
- Google Colab: Use the menu option "Runtime > Change Runtime Type" and select "GPU" from the "Hardware Accelerator" dropdown.
- Kaggle: Use the button on the top-right to open the sidebar. Then, in the "Settings" section of the sidebar, select "GPU" from the "Accelerator" dropdown.
- Binder: Notebooks running on Binder cannot use a GPU, as the machines powering Binder aren't connected to any GPUs.
- Linux: If your laptop/desktop has an NVIDIA GPU (graphics card), make sure you have installed the NVIDIA CUDA drivers.
- Windows: If your laptop/desktop has an NVIDIA GPU (graphics card), make sure you have installed the NVIDIA CUDA drivers.
- macOS: Current versions of macOS are not compatible with NVIDIA GPUs, so CUDA cannot be used there.
If you do not have access to a GPU or aren't sure what it is, don't worry: you can execute all the code in this tutorial without one. Training will simply take longer on a CPU.
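To check whether PyTorch can actually see a GPU on your platform, a minimal sketch like the one below may help. The helper name `get_default_device` is our own choice for illustration and is not part of PyTorch; the calls it relies on (`torch.cuda.is_available` and `torch.device`) are standard PyTorch APIs.

```python
import torch

def get_default_device():
    """Hypothetical helper: pick the CUDA device if available, else the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

device = get_default_device()
print(device)

# Move a small tensor to the chosen device to confirm that it works.
x = torch.ones(2, 3).to(device)
print(x.device)
```

If this prints `cpu`, the rest of the code should still run, only more slowly.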
Let's begin by importing the required libraries.
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder, CIFAR100
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

matplotlib.rcParams['figure.facecolor'] = '#ffffff'