Transfer Learning to classify images of cats & dogs

Author

Gunardi

Published

October 22, 2023

In this Jupyter Notebook, I build a predictor that classifies images of cats and dogs using transfer learning from a pre-trained MobileNetV2 network (developed at Google).

The idea behind transfer learning is to reuse a model that was previously trained on a large-scale dataset. This pretrained model serves as a generic model: by itself it can already classify many different objects (including cats and dogs) with some accuracy. Since the goal of this notebook is to classify cats and dogs specifically, the pretrained model needs to be customised (fine-tuned) to achieve higher accuracy on this task.

To customise a pretrained model, we can:

1. Add a new classifier on top of the pretrained model. Only this new classifier is trained, while the pretrained model itself is kept frozen.
2. Afterwards, fine-tune the pretrained model by unfreezing a few of its top layers and jointly training them together with the newly added classifier layers from the previous step.

import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

Step 1: Data preprocessing

1.1 Download Dataset

Download the zip file containing the images of cats and dogs. After extracting it, create tf.data.Dataset objects for the training and validation sets using tf.keras.utils.image_dataset_from_directory. For more details, see the corresponding TensorFlow tutorial.

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = tf.keras.utils.image_dataset_from_directory(train_dir,
                                                            shuffle=True,
                                                            batch_size=BATCH_SIZE,
                                                            image_size=IMG_SIZE)
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68606236/68606236 [==============================] - 0s 0us/step
Found 2000 files belonging to 2 classes.
print("path_to_zip:", path_to_zip)
print("Parent directory contains:", os.listdir(os.path.dirname(path_to_zip)))
print("PATH:", PATH)
print("train_dir:", train_dir)
print("validation_dir:", validation_dir)
path_to_zip: /root/.keras/datasets/cats_and_dogs.zip
Parent directory contains: ['cats_and_dogs_filtered', 'cats_and_dogs.zip']
PATH: /root/.keras/datasets/cats_and_dogs_filtered
train_dir: /root/.keras/datasets/cats_and_dogs_filtered/train
validation_dir: /root/.keras/datasets/cats_and_dogs_filtered/validation
validation_dataset = tf.keras.utils.image_dataset_from_directory(validation_dir,
                                                                 shuffle=True,
                                                                 batch_size=BATCH_SIZE,
                                                                 image_size=IMG_SIZE)
Found 1000 files belonging to 2 classes.

Check the first nine images and labels from the training set:

class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

The original dataset doesn’t contain a test set. Therefore, we hold out 20% (one fifth) of the validation batches as a test set and keep the rest for validation. tf.data.experimental.cardinality gives the number of batches in the dataset, which we then split with take() and skip().

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
Number of validation batches: 26
Number of test batches: 6
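The batch counts above follow from simple arithmetic. A minimal plain-Python sketch, assuming the final partial batch is kept (which the 26 + 6 = 32 batches printed above are consistent with):

```python
import math

num_val_images = 1000   # "Found 1000 files belonging to 2 classes."
batch_size = 32         # BATCH_SIZE

# The final partial batch is kept, so the batch count rounds up
val_batches = math.ceil(num_val_images / batch_size)   # 32
test_batches = val_batches // 5                         # first 6 batches -> test set
remaining_val_batches = val_batches - test_batches      # remaining 26 batches -> validation

# Only the last batch is partial: 1000 - 31 * 32 images
last_batch_size = num_val_images - (val_batches - 1) * batch_size

print(val_batches, test_batches, remaining_val_batches, last_batch_size)
```

One caveat, offered as an assumption worth checking for your TensorFlow version: validation_dataset was created with shuffle=True, and image_dataset_from_directory may reshuffle on every pass over the data. If so, take() and skip() would pick different images on each pass and the test and validation sets would not stay disjoint. Creating the validation dataset with shuffle=False before splitting rules this out.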

1.2 Preprocess the dataset to increase training performance

The tf.data performance guide describes multiple ways to preprocess a dataset for performance. Here we apply buffered prefetching so that loading images from disk does not become an I/O bottleneck.

AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

1.3 Use data augmentation

Next, we randomly flip the images horizontally and rotate them to artificially increase the diversity of the training data. This exposes the model to different aspects of the training data and helps reduce overfitting. For more, see the TensorFlow tutorial on data augmentation.

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.RandomFlip('horizontal'),
  tf.keras.layers.RandomRotation(0.2),
])

Note: These layers are active only during training, when we call Model.fit. They are inactive when the model is used in inference mode in Model.evaluate, Model.predict, or Model.call.
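To make the two augmentation layers less of a black box, here is a conceptual NumPy sketch (not the Keras implementation) of a random horizontal flip, plus the angle range that RandomRotation(0.2) samples from. According to the Keras documentation the rotation factor is a fraction of a full turn (2π), so 0.2 means angles between -72° and +72°. The helper random_horizontal_flip is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

def random_horizontal_flip(image, rng):
    """Hypothetical helper: flip an (H, W, C) image left-right with probability 0.5."""
    if rng.random() < 0.5:
        return image[:, ::-1, :]   # reverse the width axis
    return image

# RandomRotation(factor) samples angles in [-factor * 2*pi, factor * 2*pi]
factor = 0.2
max_angle_deg = factor * 360.0                   # 72 degrees
angle_deg = rng.uniform(-max_angle_deg, max_angle_deg)

toy_image = np.arange(6).reshape(2, 3, 1)        # a tiny 2x3 single-channel "image"
flipped = toy_image[:, ::-1, :]                  # every row reversed
print(toy_image[..., 0])
print(flipped[..., 0])
print(random_horizontal_flip(toy_image, rng).shape)
```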

Let’s repeatedly apply these layers to the same image and see the result.

for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
    plt.imshow(augmented_image[0] / 255)
    plt.axis('off')

1.4 Rescale pixel values

Next, we will download the pretrained tf.keras.applications.MobileNetV2 model as our base model. It expects pixel values in [-1, 1], but our images still have pixel values in [0, 255]. To rescale them, we need to add a preprocessing step to the model.

Note: If using other tf.keras.applications, be sure to check the API doc to determine if they expect pixels in [-1, 1] or [0, 1], or use the included preprocess_input function.

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

Note: Alternatively, we could rescale pixel values from [0, 255] to [-1, 1] using tf.keras.layers.Rescaling.

rescale = tf.keras.layers.Rescaling(1./127.5, offset=-1)
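Both options apply the same affine map, x / 127.5 - 1, which is also why the model summary later shows tf.math.truediv and tf.math.subtract ops right after the augmentation layer. A quick NumPy check of the arithmetic (a sketch of the formulas, not a call into Keras):

```python
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])

# preprocess_input for MobileNetV2: divide by 127.5, then subtract 1
via_preprocess = pixels / 127.5 - 1.0

# Rescaling(scale=1./127.5, offset=-1) computes inputs * scale + offset
via_rescaling = pixels * (1.0 / 127.5) + (-1.0)

print(via_preprocess)                              # endpoints map to -1 and 1
print(np.allclose(via_preprocess, via_rescaling))  # the two formulas agree
```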

Step 2: Use pre-trained MobileNet V2 convnets as base model

In this step, we use the MobileNet V2 model as our base model. This model is pre-trained on the ImageNet dataset, which consists of 1.4M images and 1000 classes covering a wide range of categories, such as goldfish and banana. This generic base model will help us classify the cats and dogs in our specific dataset.

To achieve that, we need to:

  • Load MobileNet V2 as our base model.
  • When loading, exclude the final classification layers (also called the “top” layers, since most diagrams of machine learning models go from bottom to top). This is done with include_top=False.

The top layers are not useful for our task, so we rely only on the very last layer before the flatten operation. This layer is also called the “bottleneck layer”, and its features retain more generality than those of the final top layers.

# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9406464/9406464 [==============================] - 0s 0us/step

This base model (feature extractor) converts each 160x160x3 image into a 5x5x1280 block of features:

image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)
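The 5x5 spatial size comes from MobileNetV2's overall downsampling factor of 32 (five stride-2 stages, the same factor that turns a 224x224 input into 7x7). A small arithmetic sketch, treating the factor of 32 as a known property of the architecture rather than something read from this notebook:

```python
img_size = 160          # IMG_SIZE is (160, 160)
total_stride = 2 ** 5   # five stride-2 stages in MobileNetV2
channels = 1280         # depth of the final feature map

spatial = img_size // total_stride              # 160 / 32
features_per_image = spatial * spatial * channels

print(spatial, features_per_image)              # 5 32000
```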

Step 3: Freeze the convolutional base

In this step, the imported base model is frozen. This is important so that the pretrained weights from MobileNet V2 are not modified during training.

This base model will be used as a feature extractor.

base_model.trainable = False

Important note about BatchNormalization layers

Many models contain tf.keras.layers.BatchNormalization layers. This layer is a special case and precautions should be taken in the context of fine-tuning, as shown later in this notebook.

When we set layer.trainable = False, the BatchNormalization layer will run in inference mode, and will not update its mean and variance statistics.

When we unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, we should keep the BatchNormalization layers in inference mode by passing training=False when calling the base model. Otherwise, the updates applied to the non-trainable weights (the moving mean and variance) will destroy what the model has learned.

For more details, see the Transfer learning guide.
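The difference between the two modes can be sketched in NumPy. This is a simplified, hypothetical re-implementation (a single feature, Keras' default momentum=0.99 and epsilon=1e-3), not the actual Keras layer. In training mode the layer normalises with batch statistics and updates its moving mean and variance; in inference mode it uses the stored statistics and leaves them untouched.

```python
import numpy as np

def batch_norm(x, gamma, beta, moving_mean, moving_var, training,
               momentum=0.99, eps=1e-3):
    """Simplified BatchNormalization forward pass over axis 0 (the batch axis)."""
    if training:
        mean, var = x.mean(axis=0), x.var(axis=0)
        # Training mode also moves the non-trainable statistics toward this batch
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    else:
        # Inference mode: use the stored statistics and leave them unchanged
        mean, var = moving_mean, moving_var
    y = gamma * (x - mean) / np.sqrt(var + eps) + beta
    return y, moving_mean, moving_var

# A new-domain batch whose mean (12) differs from the stored mean (0)
x = np.array([[10.0], [12.0], [14.0]])
gamma, beta = np.ones(1), np.zeros(1)
stored_mean, stored_var = np.zeros(1), np.ones(1)

_, mean_inf, var_inf = batch_norm(x, gamma, beta, stored_mean, stored_var, training=False)
_, mean_train, var_train = batch_norm(x, gamma, beta, stored_mean, stored_var, training=True)
print(mean_inf, mean_train)   # stored mean unchanged vs. pulled toward 12
```

Repeated updates like mean_train, driven by a small new dataset, are what could drift the pretrained statistics away from what MobileNetV2 learned.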

# Let's take a look at the base model architecture
base_model.summary()

# The output of this cell is very long and thus not shown in the HTML file.

The output shape from the last layer: (None, 5, 5, 1280)

Step 4: Add a classification layer

In this step, we will add new classification layers on top of our base model. We will later train these new layers.

These new classification layers will generate predictions using the input from the base model.

This is done by first applying a tf.keras.layers.GlobalAveragePooling2D layer, which converts the features into a single 1280-element vector per image by averaging over the 5x5 spatial locations.

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)
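For a channels-last tensor, global average pooling is just a mean over the two spatial axes. A NumPy equivalent using random stand-in features (not the real MobileNetV2 output):

```python
import numpy as np

# Random stand-in for feature_batch: (batch, height, width, channels)
rng = np.random.default_rng(0)
features = rng.normal(size=(32, 5, 5, 1280)).astype(np.float32)

# GlobalAveragePooling2D averages each channel over the 5x5 spatial grid
pooled = features.mean(axis=(1, 2))
print(pooled.shape)   # (32, 1280)
```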

The output now has shape (32, 1280), i.e. a batch of 32 images, each represented by a 1280-element feature vector.

These 1280 features do not by themselves say whether the image is a cat or a dog. Therefore, we apply a tf.keras.layers.Dense layer to convert them into a single prediction per image. We don’t need an activation function here because this prediction will be treated as a logit, i.e. a raw prediction value: positive numbers predict class 1, negative numbers predict class 0.

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)
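Because the Dense layer outputs raw logits, a sigmoid turns each logit into a probability, and a logit above 0 is the same as a probability above 0.5. A NumPy sketch with hypothetical logit values:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

logits = np.array([-2.0, -0.1, 0.3, 4.0])   # hypothetical Dense(1) outputs
probs = sigmoid(logits)                     # probabilities of class 1
classes = (logits > 0).astype(int)          # positive logit -> class 1

print(np.round(probs, 3))
print(classes)   # [0 0 1 1]
```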

Build a model by chaining together the data augmentation, rescaling, base_model and feature extractor layers using the Keras Functional API. As mentioned earlier, call the base model with training=False because it contains BatchNormalization layers.

inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 160, 160, 3)]     0         
                                                                 
 sequential (Sequential)     (None, 160, 160, 3)       0         
                                                                 
 tf.math.truediv (TFOpLambd  (None, 160, 160, 3)       0         
 a)                                                              
                                                                 
 tf.math.subtract (TFOpLamb  (None, 160, 160, 3)       0         
 da)                                                             
                                                                 
 mobilenetv2_1.00_160 (Func  (None, 5, 5, 1280)        2257984   
 tional)                                                         
                                                                 
 global_average_pooling2d (  (None, 1280)              0         
 GlobalAveragePooling2D)                                         
                                                                 
 dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 1)                 1281      
                                                                 
=================================================================
Total params: 2259265 (8.62 MB)
Trainable params: 1281 (5.00 KB)
Non-trainable params: 2257984 (8.61 MB)
_________________________________________________________________

The 2+ million parameters in MobileNet are frozen, but there are 1.2 thousand trainable parameters in the Dense layer. These are divided between two tf.Variable objects, the weights and biases.

len(model.trainable_variables)
2
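The parameter counts in the summary can be checked by hand: a Dense layer with 1280 inputs and one unit has 1280 weights plus one bias, and at 4 bytes per float32 parameter the totals match the sizes Keras prints (which divide by 1024, so "MB" here means MiB). A plain-Python sketch:

```python
inputs = 1280                 # GlobalAveragePooling2D output size
units = 1                     # Dense(1)
dense_params = inputs * units + units       # kernel + bias
frozen_params = 2257984       # MobileNetV2 parameters from model.summary()
total_params = frozen_params + dense_params

bytes_per_param = 4           # float32
total_mib = total_params * bytes_per_param / 1024 ** 2
dense_kib = dense_params * bytes_per_param / 1024

print(dense_params, total_params)                  # 1281 2259265
print(f"{total_mib:.2f} MB, {dense_kib:.2f} KB")   # 8.62 MB, 5.00 KB
```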
tf.keras.utils.plot_model(model, show_shapes=True)

Step 5: Compile and train the model

5.1 Compile

Compile the model before training it. Since there are two classes, use the tf.keras.losses.BinaryCrossentropy loss with from_logits=True since the model provides a linear output.

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])
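With from_logits=True, the loss applies the sigmoid internally in a numerically stable way, and BinaryAccuracy(threshold=0) compares logits against 0, which is equivalent to comparing probabilities against 0.5. A NumPy sketch of both, using the stable formula max(z, 0) - z*y + log(1 + exp(-|z|)) documented for tf.nn.sigmoid_cross_entropy_with_logits, with hypothetical labels and logits:

```python
import numpy as np

def bce_from_logits(y, z):
    # Stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]
    return np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z))))

def bce_from_probs(y, p):
    # Naive form; can overflow or underflow for large |z|
    return np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p)))

y = np.array([0.0, 1.0, 1.0, 0.0])        # hypothetical labels
z = np.array([-1.5, 2.0, -0.5, 0.8])      # hypothetical logits
p = 1.0 / (1.0 + np.exp(-z))

loss_stable = bce_from_logits(y, z)
loss_naive = bce_from_probs(y, p)

# BinaryAccuracy(threshold=0) counts a prediction as class 1 when the logit > 0
accuracy = np.mean((z > 0).astype(float) == y)

print(loss_stable, loss_naive, accuracy)
```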

5.2 Train the model

initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)
26/26 [==============================] - 4s 65ms/step - loss: 0.8981 - accuracy: 0.3663
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 0.90
initial accuracy: 0.37
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
Epoch 1/10
63/63 [==============================] - 8s 64ms/step - loss: 0.8094 - accuracy: 0.5110 - val_loss: 0.6164 - val_accuracy: 0.6473
Epoch 2/10
63/63 [==============================] - 5s 73ms/step - loss: 0.6036 - accuracy: 0.6750 - val_loss: 0.4551 - val_accuracy: 0.8094
Epoch 3/10
63/63 [==============================] - 4s 56ms/step - loss: 0.4784 - accuracy: 0.7815 - val_loss: 0.3550 - val_accuracy: 0.8787
Epoch 4/10
63/63 [==============================] - 4s 57ms/step - loss: 0.4047 - accuracy: 0.8260 - val_loss: 0.2901 - val_accuracy: 0.9146
Epoch 5/10
63/63 [==============================] - 4s 62ms/step - loss: 0.3406 - accuracy: 0.8655 - val_loss: 0.2445 - val_accuracy: 0.9257
Epoch 6/10
63/63 [==============================] - 4s 56ms/step - loss: 0.3110 - accuracy: 0.8710 - val_loss: 0.2103 - val_accuracy: 0.9418
Epoch 7/10
63/63 [==============================] - 4s 60ms/step - loss: 0.2888 - accuracy: 0.8825 - val_loss: 0.1891 - val_accuracy: 0.9480
Epoch 8/10
63/63 [==============================] - 4s 63ms/step - loss: 0.2602 - accuracy: 0.8975 - val_loss: 0.1753 - val_accuracy: 0.9517
Epoch 9/10
63/63 [==============================] - 4s 57ms/step - loss: 0.2523 - accuracy: 0.9080 - val_loss: 0.1559 - val_accuracy: 0.9592
Epoch 10/10
63/63 [==============================] - 5s 79ms/step - loss: 0.2327 - accuracy: 0.9085 - val_loss: 0.1477 - val_accuracy: 0.9616

After training for 10 epochs, we see about 96% accuracy on the validation set.

5.3 Visualise learning curves

Let’s take a look at the learning curves of the training and validation accuracy/loss when using the MobileNetV2 base model as a fixed feature extractor.

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

Note: The validation metrics are clearly better than the training metrics mainly because layers like tf.keras.layers.BatchNormalization and tf.keras.layers.Dropout (as well as the data augmentation layers) affect accuracy during training. They are turned off when calculating the validation loss.

To a lesser extent, it is also because training metrics report the average for an epoch, while validation metrics are evaluated after the epoch, so validation metrics see a model that has trained slightly longer.
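Dropout illustrates the first point well. A minimal NumPy sketch of inverted dropout, the scheme the Keras Dropout layer documents (zero a fraction rate of the units during training and scale the survivors by 1 / (1 - rate)); the dropout helper is hypothetical:

```python
import numpy as np

def dropout(x, rate, training, rng):
    """Hypothetical helper mimicking inverted dropout."""
    if not training:
        return x                      # identity during evaluate/predict
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(0)
x = np.ones(10)

train_out = dropout(x, rate=0.2, training=True, rng=rng)
eval_out = dropout(x, rate=0.2, training=False, rng=rng)
print(train_out)   # some zeros, survivors scaled to 1.25
print(eval_out)    # unchanged
```

Because the training-time forward pass is noisier, the training metrics it produces tend to look slightly worse than metrics computed in inference mode on validation data.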

Step 6: Fine tuning

In the previous training, we only trained the new classifier layers, which were added above the base model. The weights of MobileNetV2 base model were not updated during the training.

In this step, we want to increase performance further by training (fine-tuning) the top layers of the pretrained model alongside the newly added classifier layers. This training process forces the weights to be tuned from generic feature maps to features associated specifically with our dataset.

Note: This should only be attempted after we have trained the top-level classifier with the pre-trained model set to non-trainable. If we add a randomly initialized classifier on top of a pre-trained model and attempt to train all layers jointly, the magnitude of the gradient updates will be too large (due to the random weights from the classifier) and our pre-trained model will forget what it has learned.

Also, we should try to fine-tune a small number of top layers rather than the whole MobileNet model. In most convolutional networks, the higher up a layer is, the more specialized it is. The first few layers learn very simple and generic features that generalize to almost all types of images. As we go higher up, the features are increasingly more specific to the dataset on which the model was trained. The goal of fine-tuning is to adapt these specialized features to work with the new dataset, rather than overwrite the generic learning.

6.1 Un-freeze the top layers of the base model

All we need to do is unfreeze the base_model and then set the bottom layers back to non-trainable.

base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable = False
Number of layers in the base model:  154
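Slicing base_model.layers at fine_tune_at splits the 154 layers into a frozen prefix and a trainable suffix. The same slicing applied to plain layer indices (a stand-in for the real layer objects) shows the counts:

```python
num_layers = 154        # len(base_model.layers), printed above
fine_tune_at = 100

layer_indices = list(range(num_layers))
frozen = layer_indices[:fine_tune_at]      # layers 0..99 stay frozen
trainable = layer_indices[fine_tune_at:]   # layers 100..153 are fine-tuned

print(len(frozen), len(trainable))   # 100 54
print(trainable[0], trainable[-1])   # 100 153
```

Not every one of these 54 layers holds weights (activation and add layers, for example, have none), so the number of trainable variables reported later need not match the layer count.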

6.2 Compile the model

Afterwards, we need to recompile the model (necessary for these changes to take effect) and resume training.

We need to use a lower learning rate at this stage, because we are training a much larger model and want to readapt the pretrained weights. Otherwise, the model will overfit very quickly.

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')])
model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_2 (InputLayer)        [(None, 160, 160, 3)]     0         
                                                                 
 sequential (Sequential)     (None, 160, 160, 3)       0         
                                                                 
 tf.math.truediv (TFOpLambd  (None, 160, 160, 3)       0         
 a)                                                              
                                                                 
 tf.math.subtract (TFOpLamb  (None, 160, 160, 3)       0         
 da)                                                             
                                                                 
 mobilenetv2_1.00_160 (Func  (None, 5, 5, 1280)        2257984   
 tional)                                                         
                                                                 
 global_average_pooling2d (  (None, 1280)              0         
 GlobalAveragePooling2D)                                         
                                                                 
 dropout (Dropout)           (None, 1280)              0         
                                                                 
 dense (Dense)               (None, 1)                 1281      
                                                                 
=================================================================
Total params: 2259265 (8.62 MB)
Trainable params: 1862721 (7.11 MB)
Non-trainable params: 396544 (1.51 MB)
_________________________________________________________________
len(model.trainable_variables)
56

6.3 Retrain the model

Let’s continue training the same model from earlier. This step should improve the accuracy a little further.

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
Epoch 10/20
63/63 [==============================] - 19s 95ms/step - loss: 0.1670 - accuracy: 0.9345 - val_loss: 0.0741 - val_accuracy: 0.9728
Epoch 11/20
63/63 [==============================] - 4s 60ms/step - loss: 0.1286 - accuracy: 0.9515 - val_loss: 0.0553 - val_accuracy: 0.9790
Epoch 12/20
63/63 [==============================] - 4s 61ms/step - loss: 0.1187 - accuracy: 0.9540 - val_loss: 0.0562 - val_accuracy: 0.9839
Epoch 13/20
63/63 [==============================] - 5s 79ms/step - loss: 0.1065 - accuracy: 0.9565 - val_loss: 0.0567 - val_accuracy: 0.9740
Epoch 14/20
63/63 [==============================] - 4s 61ms/step - loss: 0.1006 - accuracy: 0.9620 - val_loss: 0.0719 - val_accuracy: 0.9691
Epoch 15/20
63/63 [==============================] - 4s 68ms/step - loss: 0.0903 - accuracy: 0.9630 - val_loss: 0.0392 - val_accuracy: 0.9876
Epoch 16/20
63/63 [==============================] - 4s 61ms/step - loss: 0.0779 - accuracy: 0.9690 - val_loss: 0.0445 - val_accuracy: 0.9839
Epoch 17/20
63/63 [==============================] - 4s 61ms/step - loss: 0.0680 - accuracy: 0.9725 - val_loss: 0.0433 - val_accuracy: 0.9814
Epoch 18/20
63/63 [==============================] - 5s 79ms/step - loss: 0.0716 - accuracy: 0.9730 - val_loss: 0.0415 - val_accuracy: 0.9790
Epoch 19/20
63/63 [==============================] - 9s 141ms/step - loss: 0.0505 - accuracy: 0.9770 - val_loss: 0.0349 - val_accuracy: 0.9814
Epoch 20/20
63/63 [==============================] - 7s 106ms/step - loss: 0.0638 - accuracy: 0.9770 - val_loss: 0.0461 - val_accuracy: 0.9802
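The log starts at "Epoch 10/20" and runs for 11 epochs, not 10. history.epoch holds 0-based epoch indices, so history.epoch[-1] is 9, and fit(epochs=20, initial_epoch=9) runs epoch indices 9 through 19. A plain-Python sketch of this bookkeeping:

```python
initial_epochs = 10
fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs   # 20

history_epoch = list(range(initial_epochs))        # history.epoch == [0, 1, ..., 9]
initial_epoch = history_epoch[-1]                  # 9

# fit() runs epoch indices initial_epoch .. total_epochs - 1; the log shows them 1-based
fine_tune_indices = list(range(initial_epoch, total_epochs))

print(len(fine_tune_indices))                                  # 11
print(fine_tune_indices[0] + 1, fine_tune_indices[-1] + 1)     # 10 20
```

As a result, the concatenated learning curves contain 21 points. Passing initial_epoch=initial_epochs instead would run exactly fine_tune_epochs additional epochs.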

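In this run, the validation loss stays below the training loss (for the reasons noted in 5.3), but the gap between them is much smaller than at the end of the feature-extraction phase (0.0461 vs 0.0638 in the last epoch here, compared with 0.1477 vs 0.2327 before), and the validation metrics fluctuate from epoch to epoch. This could be an early sign of overfitting.

Some overfitting is to be expected at this stage, because the new training set is relatively small and similar to the datasets MobileNet V2 was originally trained on.

After fine-tuning, the model reaches nearly 98% accuracy on the validation set.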

6.4 Visualise learning curves

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

Step 7: Predict

Finally, we can verify the performance on the held-out test dataset.

loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
6/6 [==============================] - 0s 40ms/step - loss: 0.0274 - accuracy: 0.9792
Test accuracy : 0.9791666865348816
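Assuming the six test batches are all full (the single partial batch of 8 images is the last of the 32 validation batches, so skip() leaves it in the validation set), the test set holds 192 images and the reported accuracy corresponds to 188 correct predictions:

```python
test_batches = 6
batch_size = 32
test_images = test_batches * batch_size        # 192

reported_accuracy = 0.9791666865348816         # from model.evaluate above
correct = round(reported_accuracy * test_images)

print(test_images, correct, correct / test_images)
```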

Now we are all set to use this model to predict whether a pet is a cat or a dog.

# Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].astype("uint8"))
  plt.title(class_names[predictions[i]])
  plt.axis("off")
Predictions:
 [0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1 1 1 1 1 1 1 0 0 1 0 1 1]
Labels:
 [0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 1 1 0 1 1 1 1 0 1 1 1 0 0 1 0 1 1]
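Comparing the two printed arrays element by element shows a single disagreement in this batch of 32, at index 22, which is outside the nine images plotted. The arrays below are copied from the output above:

```python
import numpy as np

predictions = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1,
                        1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1])
labels = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1,
                   1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1])

mismatches = np.flatnonzero(predictions != labels)
batch_accuracy = (predictions == labels).mean()

print(mismatches)       # [22]
print(batch_accuracy)   # 0.96875
```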

Summary

In this notebook we have applied the following techniques to improve the predictor’s performance:

  • Using a pretrained model for feature extraction: When working with a small dataset, it is common practice to take advantage of a model pretrained on a larger dataset. This is done by adding classifier layers on top of the pretrained model. During training, the pretrained model is frozen and only the weights of the classifier are updated.

  • Fine-tuning the pre-trained model: To further improve performance, we can unfreeze the top layers of the pretrained model while keeping the other layers frozen. This way, the top layers of the pretrained model learn high-level features specific to our dataset. This fine-tuning should only be done when the dataset is very similar to the original dataset used to train the pretrained model.

Further learning resources for transfer learning: link

Clean Up

Run the following cell (Google Colab only) to release the runtime, terminating the kernel and freeing its memory resources:

from google.colab import runtime
runtime.unassign()