# Updated 4 years ago
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
# Render matplotlib figures inline when running under IPython/Jupyter.
# `%matplotlib inline` is IPython-only syntax (a SyntaxError in a plain .py
# file); invoking the same magic through get_ipython() keeps the notebook
# behaviour and is simply skipped when IPython is not running.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')  # noqa: F821
except NameError:
    pass  # get_ipython is only defined inside an IPython session

# Use a white background for all figures.
matplotlib.rcParams['figure.facecolor'] = '#ffffff'
project_name = 'cifar100-cnn'
from torchvision.datasets.utils import download_url
# Download the dataset archive into the current directory.
dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/cifar100.tgz"
download_url(dataset_url, '.')

# Extract the archive into ./data.
# The archive comes from the network, so use tarfile's "data" extraction
# filter (rejects absolute paths, ".." traversal, links escaping the target
# directory and special files) whenever this Python version provides it.
with tarfile.open('./cifar100.tgz', 'r:gz') as tar:
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(path='./data', filter='data')
    else:
        tar.extractall(path='./data')
# Look into the data directory: list the top-level entries, then the
# super-class folders under train/, then the fine-class folders inside each.
data_dir = './data/cifar100'
print(os.listdir(data_dir))

train_dir = os.path.join(data_dir, 'train')
# Keep only directories so stray files (e.g. OS metadata files) do not make
# the os.listdir call below fail.
super_classes = [
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
]
print(super_classes)

# Total number of fine-grained classes = sum of class folders per super-class.
num_classes = 0
for super_class in super_classes:
    super_dir = os.path.join(train_dir, super_class)
    classes = [
        entry for entry in os.listdir(super_dir)
        if os.path.isdir(os.path.join(super_dir, entry))
    ]
    print(classes)
    num_classes += len(classes)
print(num_classes)
# Output:
# Using downloaded and verified file: ./cifar100.tgz
# ['train', 'test']
# ['fish', 'vehicles_2', 'aquatic_mammals', 'reptiles', 'trees', 'large_omnivores_and_herbivores', 'large_natural_outdoor_scenes', 'large_man-made_outdoor_things', 'medium_mammals', 'large_carnivores', 'small_mammals', 'household_furniture', 'insects', 'vehicles_1', 'food_containers', 'household_electrical_devices', 'fruit_and_vegetables', 'non-insect_invertebrates', 'people', 'flowers']
# ['trout', 'flatfish', 'aquarium_fish', 'ray', 'shark']
# ['lawn_mower', 'streetcar', 'rocket', 'tractor', 'tank']
# ['dolphin', 'whale', 'seal', 'otter', 'beaver']
# ['lizard', 'dinosaur', 'turtle', 'snake', 'crocodile']
# ['maple_tree', 'pine_tree', 'willow_tree', 'oak_tree', 'palm_tree']
# ['kangaroo', 'camel', 'cattle', 'chimpanzee', 'elephant']
# ['mountain', 'sea', 'plain', 'cloud', 'forest']
# ['road', 'bridge', 'skyscraper', 'castle', 'house']
# ['skunk', 'fox', 'raccoon', 'porcupine', 'possum']
# ['bear', 'lion', 'wolf', 'leopard', 'tiger']
# ['rabbit', 'squirrel', 'mouse', 'shrew', 'hamster']
# ['bed', 'couch', 'chair', 'wardrobe', 'table']
# ['caterpillar', 'beetle', 'bee', 'butterfly', 'cockroach']
# ['pickup_truck', 'train', 'bus', 'motorcycle', 'bicycle']
# ['cup', 'bowl', 'can', 'plate', 'bottle']
# ['telephone', 'lamp', 'clock', 'television', 'keyboard']
# ['apple', 'pear', 'mushroom', 'sweet_pepper', 'orange']
# ['spider', 'crab', 'lobster', 'worm', 'snail']
# ['woman', 'baby', 'man', 'girl', 'boy']
# ['poppy', 'tulip', 'sunflower', 'orchid', 'rose']
# 100
def get_mean_sd(loader):
    """Compute the per-channel mean and standard deviation over all images in *loader*.

    Sums are accumulated over every pixel rather than averaging per-batch
    means, so batches of different sizes (e.g. a smaller final batch) are
    weighted correctly.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs. ``images`` is
            reduced over dims 0, 2 and 3, i.e. it is assumed to be laid out
            as (batch, channel, height, width).

    Returns:
        ``(mean, sd)``: 1-D tensors with one value per channel. Their dtype
        matches the images when those are floating point, otherwise torch's
        default floating dtype.

    Raises:
        ValueError: If *loader* yields no pixels at all.
    """
    channel_sum = 0.0
    channel_squared_sum = 0.0
    pixels_per_channel = 0
    out_dtype = None
    for images, _ in loader:
        if out_dtype is None:
            out_dtype = (images.dtype if images.is_floating_point()
                         else torch.get_default_dtype())
        # float64 accumulation avoids precision loss across many batches and
        # overflow when squaring integer (e.g. uint8) images.
        x = images.to(torch.float64)
        channel_sum = channel_sum + x.sum(dim=[0, 2, 3])
        channel_squared_sum = channel_squared_sum + (x * x).sum(dim=[0, 2, 3])
        pixels_per_channel += x.shape[0] * x.shape[2] * x.shape[3]
    if pixels_per_channel == 0:
        raise ValueError("get_mean_sd: loader produced no images")
    mean = channel_sum / pixels_per_channel
    # E[x^2] - E[x]^2 can round to a tiny negative number; clamp so the square
    # root never produces NaN.
    var = (channel_squared_sum / pixels_per_channel - mean ** 2).clamp(min=0)
    return mean.to(out_dtype), var.sqrt().to(out_dtype)
# Per-channel statistics of the images served by train_loader.
# NOTE(review): the recorded result (mean ~0, sd ~1 per channel) suggests
# train_loader already applies normalization; if these values are meant to
# choose Normalize(mean, sd) constants, compute them from a loader that only
# applies ToTensor -- confirm.
mean, sd = get_mean_sd(train_loader)
print(mean, sd, sep='\n')
# Output:
# tensor([-0.0028, -0.0044, -0.0051])
# tensor([0.9977, 0.9971, 0.9965])