Saliency maps for model interpretability in Python
INTRODUCTION
In this tutorial, we will look at how to use saliency maps, a deep learning visualisation technique, for better model interpretability in Python with the help of the TensorFlow and Keras APIs. Please refer to my previous post to learn how to use transfer learning to train a model on the pneumonia dataset. The link is here. We will reuse the same model here and visualise its predictions to interpret the results. Let’s dive into the tutorial.
TRAINING DATASET
We will use the pneumonia dataset from Kaggle for training. The link to the dataset is here. The training data consists of roughly 5,000 chest X-ray images divided into two subcategories (normal and pneumonia). Test and validation data are also included in the dataset.
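If you want to sanity-check the download before training, a quick sketch like the one below counts the images per split and class. It assumes the archive has been extracted to a chest_xray folder with train/test/val subfolders and NORMAL/PNEUMONIA class folders (the layout used later in this tutorial); adjust the path to your own setup.

import os
from glob import glob

data_path = '/content/drive/MyDrive/Kaggle/chest_xray'  # adjust to where you extracted the dataset

# Count the JPEG files in each split/class folder
for split in ['train', 'test', 'val']:
    for label in ['NORMAL', 'PNEUMONIA']:
        files = glob(os.path.join(data_path, split, label, '*.jpeg'))
        print(f'{split}/{label}: {len(files)} images')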
IMPORTING THE PYTHON LIBRARIES
!pip -qq install tf_keras_vis
%matplotlib inline

import matplotlib.pyplot as plt
plt.style.use('ggplot')

import pandas as pd
import numpy as np
import seaborn as sns
import warnings
import cv2
import glob
import os
import pickle
import zipfile

import tensorflow as tf
from tensorflow import keras
from keras.preprocessing import image
from keras.applications import imagenet_utils

from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils import normalize
from tf_keras_vis.gradcam import Gradcam

from matplotlib import cm
from PIL import Image
Note that we have imported quite a few libraries here. cv2 (OpenCV) is mainly used for image processing. Seaborn, like Matplotlib, is used for plotting graphs. tf_keras_vis provides the visualisation utilities for deep learning models. Please take a look at the documentation of these libraries to learn more about their functions.
TRAINING THE MODEL
As mentioned in the introduction, please look at my previous tutorial on using transfer learning to train on the pneumonia dataset. The link is here. There I have explained the training procedure in detail. We will use the same DenseNet model here.
data_path = '/content/drive/MyDrive/Kaggle/chest_xray'
train_dir = os.path.join(data_path, 'train')
test_dir = os.path.join(data_path, 'test')
val_dir = os.path.join(data_path, 'val')
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.densenet import preprocess_input
from tensorflow.keras.applications import DenseNet201

# Data augmentation plus the DenseNet-specific preprocessing; 10% of the training data is held out for validation
train_datagen = ImageDataGenerator(rotation_range=20,
                                   width_shift_range=0.3,
                                   preprocessing_function=preprocess_input,
                                   validation_split=0.1)

train_generator = train_datagen.flow_from_directory(train_dir,
                                                    target_size=(224, 224),
                                                    class_mode='categorical',
                                                    subset='training',
                                                    shuffle=True)

val_generator = train_datagen.flow_from_directory(train_dir,
                                                  target_size=(224, 224),
                                                  class_mode='categorical',
                                                  subset='validation',
                                                  batch_size=3,
                                                  shuffle=True)
base_model = DenseNet201(input_shape=[224, 224, 3], weights='imagenet', include_top=False)
base_model.trainable = False  # freeze the pre-trained convolutional base

# Add a classification head on top of the frozen base
x = base_model.output
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(512, activation='relu')(x)
preds = keras.layers.Dense(2, activation='softmax')(x)

# Specify the inputs and outputs of the full model
DenseNet = keras.models.Model(inputs=[base_model.input], outputs=[preds])
DenseNet.compile(loss='binary_crossentropy',
                 optimizer=keras.optimizers.Adam(learning_rate=0.001),
                 metrics=['accuracy'])

history = DenseNet.fit(train_generator,
                       validation_data=val_generator,
                       steps_per_epoch=20,
                       validation_steps=20,
                       epochs=5)
Epoch 1/5
20/20 [==============================] - 466s 23s/step - loss: 0.6512 - accuracy: 0.7159 - val_loss: 0.3199 - val_accuracy: 0.8667
Epoch 2/5
20/20 [==============================] - 353s 18s/step - loss: 0.1986 - accuracy: 0.9253 - val_loss: 0.2079 - val_accuracy: 0.9000
Epoch 3/5
20/20 [==============================] - 285s 14s/step - loss: 0.1566 - accuracy: 0.9321 - val_loss: 0.1736 - val_accuracy: 0.9333
Epoch 4/5
20/20 [==============================] - 224s 11s/step - loss: 0.1277 - accuracy: 0.9603 - val_loss: 0.1299 - val_accuracy: 0.9667
Epoch 5/5
20/20 [==============================] - 186s 9s/step - loss: 0.1329 - accuracy: 0.9446 - val_loss: 0.2394 - val_accuracy: 0.8833
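If you want to double-check how training went, the learning curves stored in the history object returned by fit() can be plotted with a quick sketch like this:

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Plot accuracy and loss per epoch side by side
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
ax[0].plot(acc, label='train accuracy')
ax[0].plot(val_acc, label='val accuracy')
ax[0].set_title('Accuracy per epoch')
ax[0].legend()
ax[1].plot(loss, label='train loss')
ax[1].plot(val_loss, label='val loss')
ax[1].set_title('Loss per epoch')
ax[1].legend()
plt.show()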
We have trained for 5 epochs; now let’s use a saliency map to interpret how the model makes its predictions.
SALIENCY MAP
A saliency map is an image that highlights the regions a model pays attention to when making a prediction. For example, when a CNN classifies images of birds, how can we tell whether it is looking at bird-related pixels or at the leaves and other background objects in the image? A saliency map helps answer that question.
A saliency map can be viewed as a heat map, where "hot" regions are the pixels that have the strongest influence on the predicted class. The purpose of the saliency map is to find the regions that are prominent or noticeable at every location in the visual field and to guide the selection of attended locations, based on the spatial distribution of saliency. In practice, a basic saliency map is computed as the gradient of the class score with respect to the input pixels: pixels with a large gradient magnitude are the ones the prediction is most sensitive to.
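To make that concrete, here is a minimal sketch of a vanilla gradient saliency map computed directly with tf.GradientTape. The names model, img and class_idx are placeholders: it assumes a trained Keras classifier and a preprocessed image with a batch dimension. The tf-keras-vis library we use below wraps this same idea and adds smoothing.

import numpy as np
import tensorflow as tf

def vanilla_saliency(model, img, class_idx):
    """Gradient of the class score w.r.t. the input pixels."""
    img = tf.convert_to_tensor(img, dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(img)                       # track gradients w.r.t. the input image
        preds = model(img, training=False)
        score = preds[:, class_idx]           # score of the class we want to explain
    grads = tape.gradient(score, img)         # shape: (1, H, W, 3)
    sal = tf.reduce_max(tf.abs(grads), axis=-1)[0]  # max magnitude over colour channels
    sal = (sal - tf.reduce_min(sal)) / (tf.reduce_max(sal) - tf.reduce_min(sal) + 1e-8)
    return sal.numpy()                        # normalised to [0, 1] for plotting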
First, let’s prepare and pre-process a test image to give input to the saliency graph function.
# Load a sample test X-ray and resize it to the model's input size
pneumonia = cv2.imread('/content/drive/MyDrive/Kaggle/chest_xray/test/PNEUMONIA/person100_bacteria_475.jpeg')
r_pneumonia = cv2.resize(pneumonia, (224, 224), interpolation=cv2.INTER_AREA)
plt.imshow(r_pneumonia)
Here, we took a sample test image and resized it to (224, 224) pixels to match the input size of our DenseNet model. You can take a look at the image at this link.
from tensorflow.keras.applications.densenet import preprocess_input

pneumonia = preprocess_input(r_pneumonia)
pneumonia = np.expand_dims(pneumonia, axis=0)  # add a batch dimension: shape (1, 224, 224, 3)
keys = {0: 'normal', 1: 'pneumonia'}
Next, we import the DenseNet preprocessing function and use it to preprocess the sample image, adding a batch dimension so it can be fed to the model. Now we will build the saliency map utility functions to visualise our model's behaviour.
def model_modifier(m):
    # Replace the final softmax with a linear activation so the raw class score is used
    m.layers[-1].activation = tf.keras.activations.linear
    return m

# Defining a function to generate saliency graphs for the top predicted classes (up to 3)
def saliency_graphs(model, img):
    # Create the Saliency object
    saliency = Saliency(model, model_modifier)
    # Input image that is already pre-processed (with a batch dimension)
    input_image = img
    # Predict on the input image
    y_pred = model.predict(input_image)
    # Indices of the classes in decreasing order of predicted probability
    class_idxs_sorted = np.argsort(y_pred.flatten())[::-1]
    for i, class_idx in enumerate(class_idxs_sorted[:3]):
        # Define the loss function for the class label.
        # The 'output' variable refers to the output of the model
        loss = lambda output: tf.keras.backend.mean(output[:, class_idx])
        # Generate the saliency map with smoothing. Smoothing reduces noise in the saliency map;
        # smooth_samples is the number of noisy samples over which the gradients are averaged
        saliency_map = saliency(loss, input_image[0, ...], smooth_samples=20, smooth_noise=0.20)
        saliency_map = normalize(saliency_map)
        plot_saliency_map(saliency_map, img, y_pred, i, class_idx)

# Defining a function to plot the input image next to its saliency map
def plot_saliency_map(sal_map, img, y_pred, i, class_idx):
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
    ax[0].imshow(r_pneumonia)
    j = ax[1].imshow(sal_map[0], cmap="jet", alpha=0.8)
    fig.colorbar(j)
    for axe in ax:
        axe.grid(False)
        axe.axis('off')
    plt.suptitle("Predicted(class={}) = {:5.2f}".format(keys[class_idx], y_pred[0, class_idx]))
Here, we first defined a function (model_modifier) that replaces the final softmax activation with a linear one, so that the saliency is computed from the raw class scores rather than the normalised probabilities. Next, we defined a saliency map function (saliency_graphs) to generate the visualisation. The enumeration loops over the top predicted classes, so a separate saliency map is plotted for each class.
We have also defined a function (plot_saliency_map) that displays the input image and the saliency map side by side in subplots. The saliency map code can be reused for any deep learning model that takes images as training data. The resulting image links are here.
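As a quick usage example, assuming the preprocessed pneumonia array from above (which already has a batch dimension), generating the maps for our test X-ray is a single call:

# Generate and plot saliency maps for the top predicted classes of the test image
saliency_graphs(DenseNet, pneumonia)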
CONCLUSION
In this tutorial, we have looked at a deep learning visualisation technique known as the saliency map and how it can be used to interpret model predictions. Model interpretability can be very useful when applied in deep learning projects. I hope that it will be useful to you.