Saliency maps for model interpretability in Python

INTRODUCTION

In this tutorial, we will look at how to use saliency maps, a deep learning visualisation technique, to make a model more interpretable, using Python with TensorFlow and the Keras API. Please refer to my previous post to learn how to use transfer learning to train a model on the pneumonia dataset. The link is here. We will reuse the same model here and use its predictions to visualise and interpret the results. Let’s dive into the tutorial.

TRAINING DATASET

We will use the chest X-ray pneumonia dataset from Kaggle for training. The link to the dataset is here. The training set consists of almost 5000 images divided into two classes (normal and pneumonia). The dataset also includes separate test and validation sets.
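Before training, it can help to confirm how many images each split and class contains. The sketch below assumes the folder layout implied by the paths used later in this post (chest_xray/train, chest_xray/test and chest_xray/val, each with one subfolder per class, such as PNEUMONIA); the function name count_images is hypothetical.

```python
import os

# Assumed layout, matching the paths used later in this post:
# chest_xray/{train,test,val}/{CLASS_NAME}/*.jpeg
IMAGE_EXTENSIONS = ('.jpeg', '.jpg', '.png')

def count_images(data_path):
    """Return {split: {class_name: number_of_images}} for a split/class folder tree."""
    counts = {}
    for split in sorted(os.listdir(data_path)):
        split_dir = os.path.join(data_path, split)
        if not os.path.isdir(split_dir):
            continue  # skip stray files at the top level
        counts[split] = {}
        for class_name in sorted(os.listdir(split_dir)):
            class_dir = os.path.join(split_dir, class_name)
            if os.path.isdir(class_dir):
                # count only files with an image extension
                counts[split][class_name] = sum(
                    1 for f in os.listdir(class_dir)
                    if f.lower().endswith(IMAGE_EXTENSIONS)
                )
    return counts

# Example, using the dataset path from the training code:
# print(count_images('/content/drive/MyDrive/Kaggle/chest_xray'))
```

A noticeably unbalanced class count is worth knowing about before reading accuracy figures.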

IMPORTING THE PYTHON LIBRARIES

!pip -qq install tf_keras_vis
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import pandas as pd
import numpy as np
import seaborn as sns
import warnings
import cv2
import glob
import os
import pickle
import zipfile
import tensorflow as tf
from tensorflow import keras
# Use tf.keras consistently; the standalone `keras` package can be a different version.
# (`from __future__ import print_function` is unnecessary in Python 3 and is a
# SyntaxError when it is not the first statement in a cell, so it was dropped.)
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import imagenet_utils
from tf_keras_vis.saliency import Saliency
from tf_keras_vis.utils import normalize
from matplotlib import cm
from tf_keras_vis.gradcam import Gradcam
from PIL import Image

Note that we import quite a few libraries here. cv2 (OpenCV) is mainly used for image processing. Seaborn is a statistical plotting library built on top of matplotlib. tf_keras_vis provides visualisation tools for tf.keras models, such as saliency maps and Grad-CAM. See each library's documentation to learn more about its functions.

TRAINING THE MODEL

As mentioned in the introduction, please see my previous tutorial on using transfer learning to train a model on the pneumonia dataset. The link is here. I explained the training procedure in detail there. We will train the same DenseNet201-based model.
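In the training code, base_model.trainable = False freezes the DenseNet201 backbone, so only the new classification head is trained. As a quick sanity check (a sketch, not part of the original tutorial), the hypothetical helpers below rebuild the same architecture with weights=None, which skips the ImageNet download without changing any layer shapes, and count trainable versus frozen parameters. Since DenseNet201's final feature map has 1920 channels, the head should contribute 1920 × 512 + 512 + 512 × 2 + 2 = 984,578 trainable parameters.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras.applications import DenseNet201

def build_densenet_classifier(weights='imagenet', num_classes=2):
    """Frozen DenseNet201 backbone plus a small trainable head, as in the training code."""
    base_model = DenseNet201(input_shape=[224, 224, 3], weights=weights, include_top=False)
    base_model.trainable = False  # backbone weights are not updated during training
    x = keras.layers.GlobalAveragePooling2D()(base_model.output)
    x = keras.layers.Dense(512, activation='relu')(x)
    preds = keras.layers.Dense(num_classes, activation='softmax')(x)
    return keras.models.Model(inputs=base_model.input, outputs=preds)

def count_params(weights):
    """Total number of scalar parameters in a list of weight variables."""
    return int(sum(np.prod(tuple(w.shape)) for w in weights))

# weights=None avoids downloading ImageNet weights; parameter counts are unaffected
model = build_densenet_classifier(weights=None)
print('trainable:', count_params(model.trainable_weights))      # 984578
print('frozen:   ', count_params(model.non_trainable_weights))
```

Training only about a million parameters on top of a frozen backbone is what makes a few short epochs feasible on a dataset of this size.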

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.densenet import preprocess_input
from tensorflow.keras.applications import DenseNet201

data_path = '/content/drive/MyDrive/Kaggle/chest_xray'
train_dir = os.path.join(data_path, 'train')
test_dir = os.path.join(data_path, 'test')
val_dir = os.path.join(data_path, 'val')

# Augment the training images and hold out 10% of train_dir for validation
train_datagen = ImageDataGenerator(rotation_range=20, width_shift_range=0.3,
                                   preprocessing_function=preprocess_input,
                                   validation_split=0.1)
train_generator = train_datagen.flow_from_directory(train_dir, target_size=(224, 224),
                                                    class_mode='categorical',
                                                    subset='training', shuffle=True)
val_generator = train_datagen.flow_from_directory(train_dir, target_size=(224, 224),
                                                  class_mode='categorical',
                                                  subset='validation', batch_size=3,
                                                  shuffle=True)

# Frozen ImageNet DenseNet201 backbone with a small trainable classification head
base_model = DenseNet201(input_shape=[224, 224, 3], weights='imagenet', include_top=False)
base_model.trainable = False
x = base_model.output
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(512, activation='relu')(x)
preds = keras.layers.Dense(2, activation='softmax')(x)
DenseNet = keras.models.Model(inputs=[base_model.input], outputs=[preds])  # specify the inputs and outputs

# `lr` is deprecated (and rejected by recent Keras versions); use `learning_rate`
DenseNet.compile(loss='BinaryCrossentropy',
                 optimizer=keras.optimizers.Adam(learning_rate=0.001),
                 metrics=['accuracy'])

history = DenseNet.fit(train_generator, validation_data=val_generator,
                       steps_per_epoch=20, validation_steps=20, epochs=5)
Epoch 1/5 20/20 [==============================] - 466s 23s/step - loss: 0.6512 - accuracy: 0.7159 - val_loss: 0.3199 - val_accuracy: 0.8667 
Epoch 2/5 20/20 [==============================] - 353s 18s/step - loss: 0.1986 - accuracy: 0.9253 - val_loss: 0.2079 - val_accuracy: 0.9000 
Epoch 3/5 20/20 [==============================] - 285s 14s/step - loss: 0.1566 - accuracy: 0.9321 - val_loss: 0.1736 - val_accuracy: 0.9333 
Epoch 4/5 20/20 [==============================] - 224s 11s/step - loss: 0.1277 - accuracy: 0.9603 - val_loss: 0.1299 - val_accuracy: 0.9667 
Epoch 5/5 20/20 [==============================] - 186s 9s/step - loss: 0.1329 - accuracy: 0.9446 - val_loss: 0.2394 - val_accuracy: 0.8833
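Note that the fit call fixes steps_per_epoch=20. The training generator uses flow_from_directory's default batch size of 32, so each "epoch" here covers 20 × 32 = 640 images rather than a full pass over the training subset, and validation covers 20 × 3 = 60 images. A small sketch of this arithmetic, with hypothetical helper names and 4500 as an illustrative (not measured) training-subset size:

```python
import math

def images_per_epoch(steps_per_epoch, batch_size):
    """Number of images the model sees in one epoch when steps_per_epoch is fixed."""
    return steps_per_epoch * batch_size

def steps_for_full_pass(num_images, batch_size):
    """Steps needed to see every image once (the last batch may be smaller)."""
    return math.ceil(num_images / batch_size)

print(images_per_epoch(steps_per_epoch=20, batch_size=32))   # 640 training images per epoch
print(images_per_epoch(steps_per_epoch=20, batch_size=3))    # 60 validation images per epoch
# 4500 is an illustrative training-subset size, not a measured value
print(steps_for_full_pass(num_images=4500, batch_size=32))   # 141
```

With only 60 validation images, epoch-to-epoch swings such as the drop in val_accuracy in the last epoch should be read with caution.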

We trained for 5 epochs. Now let’s use a saliency map to interpret how our model makes its predictions.

SALIENCY MAP

A saliency map is an image that highlights which pixels of an input image most influence the model's prediction. For example, when a CNN classifies images of birds, does it focus on the bird-related pixels and ignore the leaves and other background in the image? A saliency map lets us check.

Saliency maps are often displayed as heat maps, where the "hot" regions are those that most strongly affect the prediction of the class the object belongs to. More generally, the purpose of a saliency map is to find the regions that are prominent or noticeable at every location in the visual field, and to guide the selection of attended locations based on the spatial distribution of saliency.
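A common way to compute such a map, and the idea behind tf_keras_vis's Saliency class, is to differentiate the class score with respect to the input pixels: pixels with a large absolute gradient are those where a small change would most change the score. SmoothGrad reduces the noise in these maps by averaging gradients over several noisy copies of the input, which is what the smooth_samples and smooth_noise arguments used later control. The sketch below implements both directly with tf.GradientTape. The function names are hypothetical, and details such as when the absolute value is taken may differ from tf_keras_vis's own implementation. It assumes the model outputs raw class scores (for the DenseNet model, that means replacing the final softmax with a linear activation).

```python
import numpy as np
import tensorflow as tf

def gradient_saliency(model, image, class_idx):
    """Vanilla gradient saliency: |d score / d pixel|, max over colour channels.

    image: float array of shape (H, W, C), already preprocessed for the model.
    Returns an (H, W) array.
    """
    x = tf.convert_to_tensor(image[np.newaxis, ...], dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(x)  # x is a plain tensor, not a Variable, so watch it explicitly
        score = model(x, training=False)[:, class_idx]
    grads = tape.gradient(score, x)[0]                    # (H, W, C)
    return tf.reduce_max(tf.abs(grads), axis=-1).numpy()

def smoothgrad_saliency(model, image, class_idx, samples=20, noise_level=0.2, seed=0):
    """SmoothGrad-style saliency: average gradients over noisy copies of the image.

    The noise standard deviation is noise_level * (max(image) - min(image)).
    """
    rng = np.random.default_rng(seed)
    sigma = noise_level * (image.max() - image.min())
    noisy = image[np.newaxis, ...] + rng.normal(0.0, sigma, size=(samples,) + image.shape)
    x = tf.convert_to_tensor(noisy, dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(x)
        # each sample's score depends only on its own input, so the gradient of
        # the summed scores gives the per-sample gradients
        scores = model(x, training=False)[:, class_idx]
    grads = tape.gradient(scores, x)                      # (samples, H, W, C)
    mean_grads = tf.reduce_mean(grads, axis=0)            # average over noisy copies
    return tf.reduce_max(tf.abs(mean_grads), axis=-1).numpy()

# Usage on a tiny stand-in model with a linear (raw-score) output
tiny = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 3)),
    tf.keras.layers.Conv2D(4, 3, activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(2),
])
img = np.random.default_rng(1).random((8, 8, 3)).astype('float32')
print(gradient_saliency(tiny, img, class_idx=1).shape)    # (8, 8)
print(smoothgrad_saliency(tiny, img, class_idx=1).shape)  # (8, 8)
```

Taking the maximum over channels collapses the RGB gradients into one value per pixel, which is why the saliency maps plotted later are single-channel images.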

First, let’s load and pre-process a test image to use as input to the saliency map function.

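pneumonia = cv2.imread('/content/drive/MyDrive/Kaggle/chest_xray/test/PNEUMONIA/person100_bacteria_475.jpeg')
# OpenCV loads images in BGR order; convert to RGB to match the training generator and matplotlib
pneumonia = cv2.cvtColor(pneumonia, cv2.COLOR_BGR2RGB)

# The third positional argument of cv2.resize is dst, so interpolation must be passed by keyword
r_pneumonia = cv2.resize(pneumonia, (224, 224), interpolation=cv2.INTER_AREA)
plt.imshow(r_pneumonia)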

Here, we took a sample test image and resized it to 224 × 224 pixels, the input size our DenseNet model expects. You can take a look at the image at this link.

from tensorflow.keras.applications.densenet import preprocess_input
pneumonia=preprocess_input(r_pneumonia)
keys={0:'normal',1:'pneumonia'}

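Next, we import the DenseNet preprocessing function and use it to preprocess the sample image. The keys dictionary maps the model's output indices to class names: flow_from_directory assigns indices in alphanumeric order of the class folder names, which you can confirm with train_generator.class_indices. Now we will build the saliency map utility functions to visualise which parts of the image drive the model's predictions.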
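For DenseNet, Keras' preprocess_input uses "torch"-style preprocessing: pixels are scaled to [0, 1] and then standardised per channel with the ImageNet mean and standard deviation in RGB order. The sketch below reimplements that in NumPy to make the transformation explicit; densenet_preprocess is a hypothetical name, and it should be treated as an approximation of the Keras function rather than a replacement. Because the statistics are in RGB order, converting OpenCV's BGR image to RGB first matters for colour images (for grayscale X-rays stored as three identical channels it makes little difference).

```python
import numpy as np

# ImageNet channel statistics (RGB order) used by "torch"-style preprocessing
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def densenet_preprocess(rgb_uint8):
    """Approximate keras.applications.densenet.preprocess_input for an RGB image."""
    x = rgb_uint8.astype(np.float32) / 255.0     # scale pixels to [0, 1]
    return (x - IMAGENET_MEAN) / IMAGENET_STD    # per-channel standardisation

# Hypothetical check against Keras, assuming r_pneumonia from above:
# print(np.allclose(densenet_preprocess(r_pneumonia), preprocess_input(r_pneumonia), atol=1e-4))

# model.predict expects a batch, so add a leading batch dimension
batch = np.expand_dims(densenet_preprocess(np.zeros((224, 224, 3), dtype=np.uint8)), axis=0)
print(batch.shape)  # (1, 224, 224, 3)
```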

def model_modifier(m):
  # Replace the final softmax with a linear activation so gradients are taken
  # with respect to the raw class score rather than the softmax probability
  m.layers[-1].activation = tf.keras.activations.linear
  return m

# Defining a function to generate saliency maps for the top predicted classes
# (up to 3; this model has only 2 classes, so both are plotted)
def saliency_graphs(model, img):

  # Create Saliency object
  saliency = Saliency(model, model_modifier)

  # input image that is pre-processed; add a batch dimension if it is missing
  input_image = img if img.ndim == 4 else np.expand_dims(img, axis=0)

  # predict on the input image
  y_pred = model.predict(input_image)

  # return the class indices in decreasing order of predicted probability
  class_idxs_sorted = np.argsort(y_pred.flatten())[::-1]

  for i, class_idx in enumerate(class_idxs_sorted[:3]):

    # Define loss function for the class label.
    # The 'output' variable refers to the output of the model
    loss = lambda output: tf.keras.backend.mean(output[:, class_idx])

    # Generate saliency map with smoothing. Smoothing reduces noise in the saliency map.
    # smooth_samples is the number of noisy copies whose gradients are averaged;
    # smooth_noise sets the amount of noise added to each copy
    saliency_map = saliency(loss, input_image[0, ...], smooth_samples=20, smooth_noise=0.20)
    saliency_map = normalize(saliency_map)
    plot_saliency_map(saliency_map, img, y_pred, i, class_idx)


# Defining a function to plot saliency map
def plot_saliency_map(sal_map, img, y_pred, i, class_idx):

  fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
  # Show the resized, un-preprocessed test image for reference
  ax[0].imshow(r_pneumonia)
  j = ax[1].imshow(sal_map[0], cmap="jet", alpha=0.8)
  fig.colorbar(j)
  for axe in ax:
    axe.grid(False)
    axe.axis('off')
  plt.suptitle("Predicted(class={}) = {:5.2f}".format(keys[class_idx], y_pred[0, class_idx]))


# Run the visualisation on the preprocessed test image
saliency_graphs(DenseNet, pneumonia)

Here, we first defined a function (model_modifier) that replaces the final softmax activation with a linear one, so that gradients are computed with respect to the raw class score rather than the softmax probability. Next, we defined a saliency map function to generate the visualisation. The loop enumerates the predicted classes in decreasing order of probability (at most three; this model has two) and plots a separate saliency map for each class.
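Why use the raw score instead of the probability? For p = softmax(z), the derivative of a class probability with respect to its own logit is p(1 − p), which shrinks towards zero as the model becomes confident, and the probability also depends on the other classes' logits. The NumPy sketch below shows this for a hypothetical confident two-class prediction; the logits are made up for illustration.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())   # subtract the max for numerical stability
    return e / e.sum()

def softmax_jacobian(z):
    """Matrix of d p_i / d z_j for p = softmax(z): diag(p) - p p^T."""
    p = softmax(z)
    return np.diag(p) - np.outer(p, p)

logits = np.array([8.0, 0.0])    # hypothetical, confident prediction for class 0
J = softmax_jacobian(logits)
print(softmax(logits)[0])        # ≈ 0.99966
print(J[0, 0])                   # d p_0 / d z_0 = p_0 (1 - p_0) ≈ 0.000335
```

With a linear output, the derivative of the class-0 score with respect to its logit is exactly 1, so the gradients that reach the input pixels are not squashed by the saturated softmax.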

We also defined a function (plot_saliency_map) that shows the input image next to its saliency map in two subplots. The saliency map code can be reused for other Keras image classification models. The resulting image links are here.
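A side-by-side layout makes it hard to see exactly which structures the hot regions fall on. One alternative is to draw the saliency map semi-transparently on top of the image. The sketch below uses only standard matplotlib calls; overlay_saliency is a hypothetical helper, not part of tf_keras_vis, and it assumes the map has already been normalised to [0, 1].

```python
import numpy as np
import matplotlib.pyplot as plt

def overlay_saliency(image, sal_map, alpha=0.5, cmap='jet'):
    """Draw a saliency map on top of the image it explains.

    image: (H, W, 3) uint8 RGB image for display (e.g. r_pneumonia)
    sal_map: (H, W) array normalised to [0, 1]
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(image)
    # fixed colour limits keep colours comparable across classes and images
    heat = ax.imshow(sal_map, cmap=cmap, alpha=alpha, vmin=0.0, vmax=1.0)
    fig.colorbar(heat, ax=ax)
    ax.axis('off')
    return fig, ax

# Inside saliency_graphs, one could call this instead of plot_saliency_map:
# overlay_saliency(r_pneumonia, saliency_map[0])
```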

CONCLUSION

In this tutorial, we looked at a deep learning visualisation technique known as the saliency map and how it can be used to interpret a model's predictions. Model interpretability is valuable in deep learning projects because it helps check whether a model relies on the relevant parts of an image. I hope you find it useful.

 

 
