Learn practical skills, build real-world projects, and advance your career
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

matplotlib.rcParams['figure.facecolor'] = '#ffffff'
project_name = 'cifar100-cnn'
from torchvision.datasets.utils import download_url


# Download the dataset (download_url reuses an existing verified file, as the
# recorded output "Using downloaded and verified file" shows).
dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/cifar100.tgz"
download_url(dataset_url, '.')

# Extract from archive. The archive comes from the network, so when the
# running Python supports extraction filters (PEP 706) use the 'data' filter,
# which rejects absolute paths, '..' components and other unsafe members.
with tarfile.open('./cifar100.tgz', 'r:gz') as tar:
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(path='./data', filter='data')
    else:
        tar.extractall(path='./data')

# Look into the data directory. Layout on disk: train/<superclass>/<class>/...
data_dir = './data/cifar100'
print(os.listdir(data_dir))
train_dir = os.path.join(data_dir, 'train')
super_classes = os.listdir(train_dir)
print(super_classes)

# Count fine-grained classes by summing the class folders of every superclass.
num_classes = 0
for super_class in super_classes:
    classes = os.listdir(os.path.join(train_dir, super_class))
    print(classes)
    num_classes += len(classes)

num_classes
Using downloaded and verified file: ./cifar100.tgz ['train', 'test'] ['fish', 'vehicles_2', 'aquatic_mammals', 'reptiles', 'trees', 'large_omnivores_and_herbivores', 'large_natural_outdoor_scenes', 'large_man-made_outdoor_things', 'medium_mammals', 'large_carnivores', 'small_mammals', 'household_furniture', 'insects', 'vehicles_1', 'food_containers', 'household_electrical_devices', 'fruit_and_vegetables', 'non-insect_invertebrates', 'people', 'flowers'] ['trout', 'flatfish', 'aquarium_fish', 'ray', 'shark'] ['lawn_mower', 'streetcar', 'rocket', 'tractor', 'tank'] ['dolphin', 'whale', 'seal', 'otter', 'beaver'] ['lizard', 'dinosaur', 'turtle', 'snake', 'crocodile'] ['maple_tree', 'pine_tree', 'willow_tree', 'oak_tree', 'palm_tree'] ['kangaroo', 'camel', 'cattle', 'chimpanzee', 'elephant'] ['mountain', 'sea', 'plain', 'cloud', 'forest'] ['road', 'bridge', 'skyscraper', 'castle', 'house'] ['skunk', 'fox', 'raccoon', 'porcupine', 'possum'] ['bear', 'lion', 'wolf', 'leopard', 'tiger'] ['rabbit', 'squirrel', 'mouse', 'shrew', 'hamster'] ['bed', 'couch', 'chair', 'wardrobe', 'table'] ['caterpillar', 'beetle', 'bee', 'butterfly', 'cockroach'] ['pickup_truck', 'train', 'bus', 'motorcycle', 'bicycle'] ['cup', 'bowl', 'can', 'plate', 'bottle'] ['telephone', 'lamp', 'clock', 'television', 'keyboard'] ['apple', 'pear', 'mushroom', 'sweet_pepper', 'orange'] ['spider', 'crab', 'lobster', 'worm', 'snail'] ['woman', 'baby', 'man', 'girl', 'boy'] ['poppy', 'tulip', 'sunflower', 'orchid', 'rose']
100
def get_mean_sd(loader):
    """Compute the per-channel mean and standard deviation over a data loader.

    Statistics are accumulated as pixel-weighted sums over the whole loader,
    so a final batch that is smaller than the others does not skew the result
    (averaging per-batch means would over-weight such a batch).

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs. ``images`` is
            reduced over dims 0, 2 and 3, i.e. it is expected to be shaped
            ``(N, C, H, W)``; labels are ignored.

    Returns:
        ``(mean, sd)``: two tensors of shape ``(C,)``. ``sd`` is the
        population standard deviation ``sqrt(E[x^2] - E[x]^2)``. The result
        has the images' dtype when it is floating point, otherwise torch's
        default float dtype.

    Raises:
        ValueError: if ``loader`` yields no images.
    """
    reduce_dims = [0, 2, 3]  # every axis except the channel axis
    total_sum = 0
    total_sq_sum = 0
    pixel_count = 0  # per-channel element count seen so far (sum of N*H*W)
    out_dtype = None
    for images, _ in loader:
        if out_dtype is None:
            out_dtype = (images.dtype if images.is_floating_point()
                         else torch.get_default_dtype())
        # Accumulate in float64 to limit cancellation in E[x^2] - E[x]^2.
        x = images.to(torch.float64)
        total_sum = total_sum + x.sum(dim=reduce_dims)
        total_sq_sum = total_sq_sum + (x * x).sum(dim=reduce_dims)
        pixel_count += x.numel() // x.shape[1]

    if pixel_count == 0:
        raise ValueError("get_mean_sd: loader yielded no images")

    mean = total_sum / pixel_count
    # Rounding can push the variance slightly below zero; clamp to avoid NaN.
    variance = (total_sq_sum / pixel_count - mean * mean).clamp_min(0)
    return mean.to(out_dtype), variance.sqrt().to(out_dtype)
# Per-channel statistics of the training loader, printed mean first, then sd.
# NOTE(review): the recorded output (mean ~0, sd ~1) suggests train_loader's
# transforms already normalize the images — confirm raw stats were intended.
mean, sd = get_mean_sd(train_loader)
for stat in (mean, sd):
    print(stat)
tensor([-0.0028, -0.0044, -0.0051]) tensor([0.9977, 0.9971, 0.9965])