Lab 8 - Augmentation
Let's keep working with images of Cats v Dogs.
The model is similar to the previous models that you have used, but there are now 4 convolutional layers with 32, 64, 128 and 128 filters respectively.
It will also train for 100 epochs so that the loss and accuracy curves can be plotted over a long run.
A. Training without augmentation
You should be familiar with the code below from the previous lab. Run this code and follow the instructions in the lab sheet.
!wget --no-check-certificate \
https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip \
-O /tmp/cats_and_dogs_filtered.zip
import os
import zipfile
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator
local_zip = '/tmp/cats_and_dogs_filtered.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('/tmp')
zip_ref.close()
base_dir = '/tmp/cats_and_dogs_filtered'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
# Directory with our training cat pictures
train_cats_dir = os.path.join(train_dir, 'cats')
# Directory with our training dog pictures
train_dogs_dir = os.path.join(train_dir, 'dogs')
# Directory with our validation cat pictures
validation_cats_dir = os.path.join(validation_dir, 'cats')
# Directory with our validation dog pictures
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(150, 150, 3)),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(loss='binary_crossentropy',
optimizer=RMSprop(learning_rate=1e-4),
metrics=['accuracy'])
# All images will be rescaled by 1./255
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
# Flow training images in batches of 20 using train_datagen generator
train_generator = train_datagen.flow_from_directory(
train_dir, # This is the source directory for training images
target_size=(150, 150), # All images will be resized to 150x150
batch_size=20,
# Since we use binary_crossentropy loss, we need binary labels
class_mode='binary')
# Flow validation images in batches of 20 using test_datagen generator
validation_generator = test_datagen.flow_from_directory(
validation_dir,
target_size=(150, 150),
batch_size=20,
class_mode='binary')
history = model.fit(
train_generator,
steps_per_epoch=100, # 2000 images = batch_size * steps
epochs=100,
validation_data=validation_generator,
validation_steps=50, # 1000 images = batch_size * steps
verbose=2)
--2020-11-01 07:02:54-- https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.98.128, 74.125.195.128, 74.125.28.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.98.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 68606236 (65M) [application/zip]
Saving to: ‘/tmp/cats_and_dogs_filtered.zip’
/tmp/cats_and_dogs_ 100%[===================>] 65.43M 95.6MB/s in 0.7s
2020-11-01 07:02:55 (95.6 MB/s) - ‘/tmp/cats_and_dogs_filtered.zip’ saved [68606236/68606236]
Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
Epoch 1/100
100/100 - 9s - loss: 0.6887 - accuracy: 0.5390 - val_loss: 0.6661 - val_accuracy: 0.6140
Epoch 2/100
100/100 - 9s - loss: 0.6556 - accuracy: 0.6065 - val_loss: 0.6432 - val_accuracy: 0.6240
Epoch 3/100
100/100 - 9s - loss: 0.5994 - accuracy: 0.6670 - val_loss: 0.5902 - val_accuracy: 0.7010
Epoch 4/100
100/100 - 9s - loss: 0.5636 - accuracy: 0.7035 - val_loss: 0.5862 - val_accuracy: 0.6730
Epoch 5/100
100/100 - 9s - loss: 0.5387 - accuracy: 0.7190 - val_loss: 0.6407 - val_accuracy: 0.6620
Epoch 6/100
100/100 - 9s - loss: 0.5101 - accuracy: 0.7405 - val_loss: 0.5490 - val_accuracy: 0.7120
Epoch 7/100
100/100 - 9s - loss: 0.4768 - accuracy: 0.7755 - val_loss: 0.6093 - val_accuracy: 0.6870
Epoch 8/100
100/100 - 9s - loss: 0.4526 - accuracy: 0.7915 - val_loss: 0.5535 - val_accuracy: 0.7240
Epoch 9/100
100/100 - 9s - loss: 0.4315 - accuracy: 0.7945 - val_loss: 0.5369 - val_accuracy: 0.7230
Epoch 10/100
100/100 - 9s - loss: 0.4029 - accuracy: 0.8160 - val_loss: 0.5343 - val_accuracy: 0.7400
Epoch 11/100
100/100 - 9s - loss: 0.3760 - accuracy: 0.8370 - val_loss: 0.5351 - val_accuracy: 0.7390
Epoch 12/100
100/100 - 9s - loss: 0.3543 - accuracy: 0.8480 - val_loss: 0.5546 - val_accuracy: 0.7230
Epoch 13/100
100/100 - 9s - loss: 0.3306 - accuracy: 0.8570 - val_loss: 0.5372 - val_accuracy: 0.7440
Epoch 14/100
100/100 - 9s - loss: 0.2993 - accuracy: 0.8775 - val_loss: 0.5744 - val_accuracy: 0.7370
Epoch 15/100
100/100 - 9s - loss: 0.2783 - accuracy: 0.8830 - val_loss: 0.5523 - val_accuracy: 0.7590
Epoch 16/100
100/100 - 9s - loss: 0.2570 - accuracy: 0.8970 - val_loss: 0.5904 - val_accuracy: 0.7500
Epoch 17/100
100/100 - 9s - loss: 0.2317 - accuracy: 0.9110 - val_loss: 0.6289 - val_accuracy: 0.7270
Epoch 18/100
100/100 - 9s - loss: 0.2158 - accuracy: 0.9250 - val_loss: 0.5982 - val_accuracy: 0.7450
Epoch 19/100
100/100 - 9s - loss: 0.1924 - accuracy: 0.9295 - val_loss: 0.6375 - val_accuracy: 0.7290
Epoch 20/100
100/100 - 9s - loss: 0.1792 - accuracy: 0.9350 - val_loss: 0.6548 - val_accuracy: 0.7400
Epoch 21/100
100/100 - 9s - loss: 0.1549 - accuracy: 0.9505 - val_loss: 0.6729 - val_accuracy: 0.7390
Epoch 22/100
100/100 - 9s - loss: 0.1327 - accuracy: 0.9590 - val_loss: 0.6801 - val_accuracy: 0.7330
Epoch 23/100
100/100 - 9s - loss: 0.1189 - accuracy: 0.9620 - val_loss: 0.7183 - val_accuracy: 0.7270
Epoch 24/100
100/100 - 9s - loss: 0.1037 - accuracy: 0.9675 - val_loss: 0.7122 - val_accuracy: 0.7330
Epoch 25/100
100/100 - 9s - loss: 0.0887 - accuracy: 0.9755 - val_loss: 0.8348 - val_accuracy: 0.7260
Epoch 26/100
100/100 - 9s - loss: 0.0797 - accuracy: 0.9755 - val_loss: 0.7955 - val_accuracy: 0.7470
Epoch 27/100
100/100 - 9s - loss: 0.0629 - accuracy: 0.9825 - val_loss: 1.1187 - val_accuracy: 0.7000
Epoch 28/100
100/100 - 9s - loss: 0.0538 - accuracy: 0.9870 - val_loss: 1.0044 - val_accuracy: 0.7190
Epoch 29/100
100/100 - 9s - loss: 0.0439 - accuracy: 0.9885 - val_loss: 0.9298 - val_accuracy: 0.7220
Epoch 30/100
100/100 - 9s - loss: 0.0441 - accuracy: 0.9875 - val_loss: 0.8928 - val_accuracy: 0.7420
Epoch 31/100
100/100 - 9s - loss: 0.0330 - accuracy: 0.9915 - val_loss: 0.9735 - val_accuracy: 0.7450
Epoch 32/100
100/100 - 9s - loss: 0.0273 - accuracy: 0.9935 - val_loss: 1.0589 - val_accuracy: 0.7430
Epoch 33/100
100/100 - 9s - loss: 0.0281 - accuracy: 0.9930 - val_loss: 1.1209 - val_accuracy: 0.7310
Epoch 34/100
100/100 - 9s - loss: 0.0242 - accuracy: 0.9940 - val_loss: 1.0591 - val_accuracy: 0.7360
Epoch 35/100
100/100 - 9s - loss: 0.0178 - accuracy: 0.9960 - val_loss: 1.2023 - val_accuracy: 0.7300
Epoch 36/100
100/100 - 9s - loss: 0.0259 - accuracy: 0.9950 - val_loss: 1.2731 - val_accuracy: 0.7350
Epoch 37/100
100/100 - 9s - loss: 0.0178 - accuracy: 0.9955 - val_loss: 1.2055 - val_accuracy: 0.7290
Epoch 38/100
100/100 - 9s - loss: 0.0120 - accuracy: 0.9965 - val_loss: 1.2500 - val_accuracy: 0.7240
Epoch 39/100
100/100 - 9s - loss: 0.0127 - accuracy: 0.9970 - val_loss: 1.2449 - val_accuracy: 0.7350
Epoch 40/100
100/100 - 9s - loss: 0.0183 - accuracy: 0.9960 - val_loss: 1.2425 - val_accuracy: 0.7400
Epoch 41/100
100/100 - 9s - loss: 0.0082 - accuracy: 0.9970 - val_loss: 1.4533 - val_accuracy: 0.7180
Epoch 42/100
100/100 - 9s - loss: 0.0080 - accuracy: 0.9985 - val_loss: 1.6004 - val_accuracy: 0.7160
Epoch 43/100
100/100 - 9s - loss: 0.0120 - accuracy: 0.9965 - val_loss: 1.3962 - val_accuracy: 0.7390
Epoch 44/100
100/100 - 9s - loss: 0.0058 - accuracy: 0.9985 - val_loss: 1.5095 - val_accuracy: 0.7370
Epoch 45/100
100/100 - 9s - loss: 0.0083 - accuracy: 0.9970 - val_loss: 1.4252 - val_accuracy: 0.7360
Epoch 46/100
100/100 - 9s - loss: 0.0146 - accuracy: 0.9965 - val_loss: 1.5492 - val_accuracy: 0.7290
Epoch 47/100
100/100 - 9s - loss: 0.0059 - accuracy: 0.9985 - val_loss: 1.6130 - val_accuracy: 0.7330
Epoch 48/100
100/100 - 9s - loss: 0.0129 - accuracy: 0.9950 - val_loss: 1.5720 - val_accuracy: 0.7260
Epoch 49/100
100/100 - 9s - loss: 0.0022 - accuracy: 1.0000 - val_loss: 1.6314 - val_accuracy: 0.7350
Epoch 50/100
100/100 - 9s - loss: 0.0093 - accuracy: 0.9975 - val_loss: 1.6602 - val_accuracy: 0.7400
Epoch 51/100
100/100 - 9s - loss: 0.0038 - accuracy: 0.9990 - val_loss: 1.6663 - val_accuracy: 0.7320
Epoch 52/100
100/100 - 9s - loss: 0.0098 - accuracy: 0.9975 - val_loss: 1.7785 - val_accuracy: 0.7260
Epoch 53/100
100/100 - 9s - loss: 0.0049 - accuracy: 0.9980 - val_loss: 1.6516 - val_accuracy: 0.7360
Epoch 54/100
100/100 - 9s - loss: 0.0105 - accuracy: 0.9980 - val_loss: 1.7263 - val_accuracy: 0.7420
Epoch 55/100
100/100 - 9s - loss: 0.0025 - accuracy: 0.9995 - val_loss: 1.7574 - val_accuracy: 0.7470
Epoch 56/100
100/100 - 9s - loss: 0.0100 - accuracy: 0.9980 - val_loss: 1.7999 - val_accuracy: 0.7340
Epoch 57/100
100/100 - 9s - loss: 0.0032 - accuracy: 0.9990 - val_loss: 1.7603 - val_accuracy: 0.7370
Epoch 58/100
100/100 - 9s - loss: 0.0091 - accuracy: 0.9965 - val_loss: 1.8534 - val_accuracy: 0.7350
Epoch 59/100
100/100 - 9s - loss: 0.0049 - accuracy: 0.9985 - val_loss: 2.3367 - val_accuracy: 0.7140
Epoch 60/100
100/100 - 9s - loss: 0.0076 - accuracy: 0.9985 - val_loss: 1.9854 - val_accuracy: 0.7280
Epoch 61/100
100/100 - 9s - loss: 5.1690e-04 - accuracy: 1.0000 - val_loss: 1.8946 - val_accuracy: 0.7380
Epoch 62/100
100/100 - 9s - loss: 0.0060 - accuracy: 0.9980 - val_loss: 1.8314 - val_accuracy: 0.7410
Epoch 63/100
100/100 - 9s - loss: 0.0098 - accuracy: 0.9975 - val_loss: 1.8731 - val_accuracy: 0.7430
Epoch 64/100
100/100 - 9s - loss: 0.0022 - accuracy: 0.9995 - val_loss: 2.0208 - val_accuracy: 0.7330
Epoch 65/100
100/100 - 9s - loss: 0.0069 - accuracy: 0.9980 - val_loss: 1.9617 - val_accuracy: 0.7380
Epoch 66/100
100/100 - 9s - loss: 0.0057 - accuracy: 0.9990 - val_loss: 2.1242 - val_accuracy: 0.7430
Epoch 67/100
100/100 - 9s - loss: 1.9257e-04 - accuracy: 1.0000 - val_loss: 2.0680 - val_accuracy: 0.7380
Epoch 68/100
100/100 - 9s - loss: 0.0042 - accuracy: 0.9985 - val_loss: 2.5419 - val_accuracy: 0.7040
Epoch 69/100
100/100 - 9s - loss: 0.0048 - accuracy: 0.9985 - val_loss: 2.0973 - val_accuracy: 0.7230
Epoch 70/100
100/100 - 9s - loss: 0.0060 - accuracy: 0.9990 - val_loss: 2.0317 - val_accuracy: 0.7340
Epoch 71/100
100/100 - 9s - loss: 0.0039 - accuracy: 0.9985 - val_loss: 2.0859 - val_accuracy: 0.7380
Epoch 72/100
100/100 - 9s - loss: 0.0012 - accuracy: 0.9990 - val_loss: 2.1255 - val_accuracy: 0.7290
Epoch 73/100
100/100 - 9s - loss: 0.0028 - accuracy: 0.9985 - val_loss: 2.1020 - val_accuracy: 0.7280
Epoch 74/100
100/100 - 9s - loss: 1.7059e-04 - accuracy: 1.0000 - val_loss: 2.2865 - val_accuracy: 0.7240
Epoch 75/100
100/100 - 9s - loss: 0.0031 - accuracy: 0.9985 - val_loss: 2.1766 - val_accuracy: 0.7330
Epoch 76/100
100/100 - 9s - loss: 0.0049 - accuracy: 0.9980 - val_loss: 2.4057 - val_accuracy: 0.7280
Epoch 77/100
100/100 - 9s - loss: 8.6864e-05 - accuracy: 1.0000 - val_loss: 2.2848 - val_accuracy: 0.7360
Epoch 78/100
100/100 - 9s - loss: 0.0093 - accuracy: 0.9980 - val_loss: 2.2741 - val_accuracy: 0.7410
Epoch 79/100
100/100 - 9s - loss: 0.0086 - accuracy: 0.9975 - val_loss: 2.3698 - val_accuracy: 0.7250
Epoch 80/100
100/100 - 9s - loss: 5.1133e-05 - accuracy: 1.0000 - val_loss: 2.2561 - val_accuracy: 0.7340
Epoch 81/100
100/100 - 9s - loss: 0.0048 - accuracy: 0.9980 - val_loss: 2.2609 - val_accuracy: 0.7420
Epoch 82/100
100/100 - 9s - loss: 0.0101 - accuracy: 0.9985 - val_loss: 2.2378 - val_accuracy: 0.7430
Epoch 83/100
100/100 - 9s - loss: 0.0032 - accuracy: 0.9990 - val_loss: 2.1067 - val_accuracy: 0.7420
Epoch 84/100
100/100 - 9s - loss: 0.0068 - accuracy: 0.9980 - val_loss: 2.1805 - val_accuracy: 0.7490
Epoch 85/100
100/100 - 9s - loss: 5.0398e-05 - accuracy: 1.0000 - val_loss: 2.2126 - val_accuracy: 0.7460
Epoch 86/100
100/100 - 9s - loss: 0.0023 - accuracy: 0.9990 - val_loss: 2.4285 - val_accuracy: 0.7470
Epoch 87/100
100/100 - 9s - loss: 0.0047 - accuracy: 0.9985 - val_loss: 2.2399 - val_accuracy: 0.7470
Epoch 88/100
100/100 - 8s - loss: 0.0085 - accuracy: 0.9980 - val_loss: 2.3226 - val_accuracy: 0.7360
Epoch 89/100
100/100 - 9s - loss: 0.0036 - accuracy: 0.9990 - val_loss: 2.3412 - val_accuracy: 0.7390
Epoch 90/100
100/100 - 9s - loss: 3.2244e-05 - accuracy: 1.0000 - val_loss: 2.3509 - val_accuracy: 0.7420
Epoch 91/100
100/100 - 9s - loss: 0.0022 - accuracy: 0.9995 - val_loss: 3.0298 - val_accuracy: 0.7010
Epoch 92/100
100/100 - 9s - loss: 0.0019 - accuracy: 0.9990 - val_loss: 2.3150 - val_accuracy: 0.7440
Epoch 93/100
100/100 - 9s - loss: 0.0044 - accuracy: 0.9990 - val_loss: 2.4564 - val_accuracy: 0.7380
Epoch 94/100
100/100 - 8s - loss: 0.0017 - accuracy: 0.9990 - val_loss: 2.5354 - val_accuracy: 0.7490
Epoch 95/100
100/100 - 8s - loss: 2.4129e-05 - accuracy: 1.0000 - val_loss: 2.8400 - val_accuracy: 0.7430
Epoch 96/100
100/100 - 8s - loss: 0.0041 - accuracy: 0.9985 - val_loss: 2.5172 - val_accuracy: 0.7430
Epoch 97/100
100/100 - 8s - loss: 7.8073e-04 - accuracy: 0.9995 - val_loss: 2.4677 - val_accuracy: 0.7380
Epoch 98/100
100/100 - 8s - loss: 5.6894e-05 - accuracy: 1.0000 - val_loss: 2.5393 - val_accuracy: 0.7350
Epoch 99/100
100/100 - 9s - loss: 0.0027 - accuracy: 0.9990 - val_loss: 2.6687 - val_accuracy: 0.7380
Epoch 100/100
100/100 - 8s - loss: 0.0057 - accuracy: 0.9980 - val_loss: 2.5604 - val_accuracy: 0.7390
import matplotlib.pyplot as plt
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(len(acc))
plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()
plt.figure()
plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, val_loss, 'b', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
A2: After about 15 epochs, validation accuracy reaches its peak, while the model gets close to 100% accuracy on the training data. This happens because the model fits the training data too closely and fails to generalise to unseen images -> overfitting.
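You can read the peak off the `history` object directly rather than eyeballing the plot. A minimal sketch (the helper name `best_epoch` is just for illustration, not part of the lab code):

```python
def best_epoch(val_acc):
    """Return the (1-based) epoch with the highest validation accuracy."""
    best = max(range(len(val_acc)), key=lambda i: val_acc[i])
    return best + 1, val_acc[best]

# With the model trained above you would call:
#   epoch, acc = best_epoch(history.history['val_accuracy'])
# Demo with a short made-up accuracy curve:
epoch, acc = best_epoch([0.61, 0.70, 0.76, 0.74, 0.73])
print(epoch, acc)  # -> 3 0.76
```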
Let's augment the images a bit to see if it can improve the model's performance on the validation set. If you think about it, most pictures of a cat are very similar -- the ears are at the top, then the eyes, then the mouth etc. Things like the distance between the eyes and ears will always be quite similar too.
What if we tweak the images to change this up a bit -- rotate the image, squash it, etc.? That's what image augmentation is all about. Here are some options for augmentation, as discussed in the lectures.
- rotation_range is a value in degrees (0–180), a range within which to randomly rotate pictures.
- width_shift and height_shift are ranges (as a fraction of total width or height) within which to randomly translate pictures horizontally or vertically.
- shear_range is for randomly applying shearing transformations.
- zoom_range is for randomly zooming inside pictures.
- horizontal_flip is for randomly flipping half of the images horizontally. This is relevant when there are no assumptions of horizontal asymmetry (e.g. real-world pictures).
- fill_mode is the strategy used for filling in newly created pixels, which can appear after a rotation or a width/height shift.
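Putting these options together, the training generator from part A can be rebuilt with augmentation enabled. The parameter values below (40 degrees, 0.2, and so on) are common starting points for this dataset, not the only valid choices -- feel free to experiment:

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Augmented training generator: each batch is randomly transformed on the fly,
# so the model never sees exactly the same image twice.
train_datagen = ImageDataGenerator(
    rescale=1./255,          # normalise pixel values to [0, 1]
    rotation_range=40,       # random rotations of up to 40 degrees
    width_shift_range=0.2,   # horizontal shifts of up to 20% of the width
    height_shift_range=0.2,  # vertical shifts of up to 20% of the height
    shear_range=0.2,         # random shearing transformations
    zoom_range=0.2,          # random zooming of up to 20%
    horizontal_flip=True,    # randomly mirror images left-to-right
    fill_mode='nearest')     # fill newly created pixels from the nearest ones

# The validation data should NOT be augmented -- only rescaled -- so that
# validation accuracy is measured on unmodified images.
test_datagen = ImageDataGenerator(rescale=1./255)
```

The rest of the pipeline (`flow_from_directory` and `model.fit`) stays exactly as in part A; only the `train_datagen` definition changes.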