
Lab 8 - Augmentation

Let's keep working with images of Cats v Dogs.

The model is similar to the previous models that you have used, but there are now four convolutional layers with 32, 64, 128 and 128 filters respectively.

Also, the model will train for 100 epochs so that the loss and accuracy curves can be plotted over a long run.

A. Training without augmentation

You should be familiar with the code below from the previous lab. Run this code and follow the instructions in the lab sheet.

!wget --no-check-certificate \
    https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip \
    -O /tmp/cats_and_dogs_filtered.zip
  
import os
import zipfile
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.preprocessing.image import ImageDataGenerator

local_zip = '/tmp/cats_and_dogs_filtered.zip'
zip_ref = zipfile.ZipFile(local_zip, 'r')
zip_ref.extractall('/tmp')
zip_ref.close()

base_dir = '/tmp/cats_and_dogs_filtered'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

# Directory with our training cat pictures
train_cats_dir = os.path.join(train_dir, 'cats')

# Directory with our training dog pictures
train_dogs_dir = os.path.join(train_dir, 'dogs')

# Directory with our validation cat pictures
validation_cats_dir = os.path.join(validation_dir, 'cats')

# Directory with our validation dog pictures
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(150, 150, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2,2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(loss='binary_crossentropy',
              optimizer=RMSprop(learning_rate=1e-4),
              metrics=['accuracy'])

# All images will be rescaled by 1./255
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# Flow training images in batches of 20 using train_datagen generator
train_generator = train_datagen.flow_from_directory(
        train_dir,  # This is the source directory for training images
        target_size=(150, 150),  # All images will be resized to 150x150
        batch_size=20,
        # Since we use binary_crossentropy loss, we need binary labels
        class_mode='binary')

# Flow validation images in batches of 20 using test_datagen generator
validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='binary')

history = model.fit(
      train_generator,
      steps_per_epoch=100,  # 2000 images = batch_size * steps
      epochs=100,
      validation_data=validation_generator,
      validation_steps=50,  # 1000 images = batch_size * steps
      verbose=2)

--2020-11-01 07:02:54--  https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.98.128, 74.125.195.128, 74.125.28.128, ...
HTTP request sent, awaiting response... 200 OK
Length: 68606236 (65M) [application/zip]
2020-11-01 07:02:55 (95.6 MB/s) - '/tmp/cats_and_dogs_filtered.zip' saved [68606236/68606236]

Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.
Epoch 1/100
100/100 - 9s - loss: 0.6887 - accuracy: 0.5390 - val_loss: 0.6661 - val_accuracy: 0.6140
Epoch 2/100
100/100 - 9s - loss: 0.6556 - accuracy: 0.6065 - val_loss: 0.6432 - val_accuracy: 0.6240
...
Epoch 15/100
100/100 - 9s - loss: 0.2783 - accuracy: 0.8830 - val_loss: 0.5523 - val_accuracy: 0.7590
...
Epoch 100/100
100/100 - 8s - loss: 0.0057 - accuracy: 0.9980 - val_loss: 2.5604 - val_accuracy: 0.7390
import matplotlib.pyplot as plt
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')

plt.figure()

plt.plot(epochs, loss, 'bo', label='Training Loss')
plt.plot(epochs, val_loss, 'b', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()
[Plot: Training and validation accuracy]
[Plot: Training and validation loss]

A2: After about 15 epochs, validation accuracy reaches its peak while training accuracy climbs towards 100%. The model has fit the training data too closely and fails to generalise -> overfitting.
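You can locate that peak programmatically from the `history` object. As a quick sketch (a short synthetic curve stands in for the real `history.history['val_accuracy']` list from your run):

```python
# Find the epoch with the best validation accuracy.
# Replace this synthetic curve with history.history['val_accuracy']
# from your own training run.
val_acc = [0.61, 0.67, 0.72, 0.74, 0.73, 0.72, 0.71]

best_epoch = max(range(len(val_acc)), key=lambda i: val_acc[i])
print(f"Best val_accuracy {val_acc[best_epoch]:.2f} at epoch {best_epoch + 1}")
```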

Let's augment the images a bit to see if it can improve the model's performance on the validation set. If you think about it, most pictures of a cat are very similar -- the ears are at the top, then the eyes, then the mouth, and so on. Things like the distance between the eyes and ears will always be quite similar too.

What if we tweak the images to change this up a bit -- rotate the image, squash it, etc.? That's what image augmentation is all about. Here are some options for augmentation as discussed in the lectures.

  • rotation_range is a value in degrees (0–180), a range within which to randomly rotate pictures.
  • width_shift and height_shift are ranges (as a fraction of total width or height) within which to randomly translate pictures horizontally or vertically, respectively.
  • shear_range is for randomly applying shearing transformations.
  • zoom_range is for randomly zooming inside pictures.
  • horizontal_flip is for randomly flipping half of the images horizontally. This is relevant when there are no assumptions of horizontal asymmetry (e.g. real-world pictures).
  • fill_mode is the strategy used for filling in newly created pixels, which can appear after a rotation or a width/height shift.
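As a sketch of how these options fit together, the training generator's ImageDataGenerator can take them as keyword arguments. The specific values below are common illustrative starting points, not values prescribed by the lab sheet -- tune them yourself:

```python
# Illustrative augmentation settings for the training ImageDataGenerator.
augmentation_kwargs = dict(
    rescale=1./255,
    rotation_range=40,        # rotate by up to 40 degrees
    width_shift_range=0.2,    # shift horizontally by up to 20% of width
    height_shift_range=0.2,   # shift vertically by up to 20% of height
    shear_range=0.2,          # shear by up to 20%
    zoom_range=0.2,           # zoom by up to 20%
    horizontal_flip=True,     # randomly mirror images
    fill_mode='nearest',      # fill newly created pixels with the nearest value
)

# In the notebook:
# train_datagen = ImageDataGenerator(**augmentation_kwargs)
```

Note that only the training generator should be augmented; the validation generator keeps just `rescale=1./255` so that validation measures performance on unmodified images.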

B. Copy and paste the code in the first code cell into this code cell and MODIFY it such that the TRAINING IMAGES ARE AUGMENTED as they are loaded into the generator.