Data Augmentation using Keras in Python

In this blog, you will learn:

  • What is Data Augmentation
  • Why use Data Augmentation
  • How to do Data Augmentation using Keras in Python

So, let’s dive into it.

What is Data Augmentation?

Data augmentation is a technique that increases the diversity of the training set by applying random transformations, such as rotations, shifts, zooms, and flips, to the existing training data. Each transformed copy is a new, slightly different example, so augmentation also effectively increases the size of the training set. As a result, we get a new dataset that contains the data together with its transformed variants.
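To make this concrete, here is a minimal NumPy-only sketch (not how Keras implements it) that turns one image into several training examples. The function name make_augmented_copies and the tiny 4x4 array are hypothetical stand-ins for a real photo; note that np.roll wraps pixels around the edge, whereas Keras fills the uncovered area instead.

```python
import numpy as np

def make_augmented_copies(image: np.ndarray) -> list:
    """Return the original image plus a few transformed copies.

    image: array of shape (height, width, channels).
    """
    flipped = image[:, ::-1, :]                  # mirror left-right
    shifted = np.roll(image, 2, axis=1)          # move content 2 px right (wraps around)
    rotated = np.rot90(image, k=1, axes=(0, 1))  # rotate 90 degrees counter-clockwise
    return [image, flipped, shifted, rotated]

# A tiny 4x4 RGB "image" stands in for a real photo.
tiny = np.arange(4 * 4 * 3, dtype=np.float32).reshape(4, 4, 3)
copies = make_augmented_copies(tiny)
print(len(copies))  # 4 training examples from 1 original
```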

Note: Data augmentation is applied only to the training set, not to the validation or test set. The training set is used to train the model, while the validation and test sets are used to evaluate it, so they should contain real, unaltered data.
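The rule above can be sketched with plain NumPy: split first, then augment only the training split. Everything here is hypothetical for illustration (the random 8x8 "images", the 8/2 split, and the helper augment_with_flips, which simply adds a mirrored copy of each image in NHWC layout).

```python
import numpy as np

def augment_with_flips(x):
    """Return x plus a left-right mirrored copy of every image (NHWC layout)."""
    return np.concatenate([x, x[:, :, ::-1, :]], axis=0)

rng = np.random.default_rng(0)
# Hypothetical dataset: 10 RGB images of size 8x8, split into train/validation.
images = rng.random((10, 8, 8, 3)).astype(np.float32)
x_train, x_val = images[:8], images[8:]

x_train = augment_with_flips(x_train)  # augment the training split only
# x_val is left as-is so validation scores reflect real, unaltered images
print(x_train.shape, x_val.shape)  # (16, 8, 8, 3) (2, 8, 8, 3)
```

In a real pipeline the training labels would be concatenated in the same way so that each flipped image keeps its label.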

Why use Data Augmentation?

Deep learning algorithms are data-hungry: we need a lot of data to build a good deep learning model. One way to get more training data is web scraping, where we collect data such as images from the internet. But this is often not cost-effective when you are building software. So, to get more data, we use data augmentation, which creates an artificial but realistic dataset from the data we already have.

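In this blog, we will perform data augmentation on images using the Keras ImageDataGenerator class. It generates batches of tensor image data with real-time data augmentation, meaning that new random transformations are applied each time a batch is produced instead of being computed once and stored.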
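A rough way to picture real-time augmentation is a Python generator that applies fresh random transformations each time a batch is requested, so every pass over the data sees different variants. The sketch below is a hypothetical NumPy-only stand-in for ImageDataGenerator.flow() that uses only random horizontal flips; it also samples each batch at random, whereas Keras shuffles the data once per epoch.

```python
import numpy as np

def augmenting_batches(images, batch_size, seed=None):
    """Yield batches forever, drawing a fresh random flip for every batch.

    images: array of shape (samples, height, width, channels).
    """
    rng = np.random.default_rng(seed)
    n = len(images)
    while True:
        idx = rng.choice(n, size=batch_size, replace=False)
        batch = images[idx]                     # fancy indexing returns a copy
        flip = rng.random(batch_size) < 0.5     # decide per image whether to flip
        batch[flip] = batch[flip][:, :, ::-1, :]
        yield batch

images = np.random.default_rng(1).random((6, 4, 4, 3))
gen = augmenting_batches(images, batch_size=2, seed=0)
first = next(gen)
second = next(gen)
print(first.shape)  # (2, 4, 4, 3)
```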

How to do Data Augmentation in Python using Keras TensorFlow API?

First, we will import the Python libraries needed for this task.

import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator

 

Next, we will load an image by passing either its file name (if the image is in the current working directory) or its complete path.

# loading the image
img = image.load_img("bird.jpg")

# converting the image to a NumPy array and scaling pixel values to [0, 1]
img = image.img_to_array(img) / 255.0

 

Let us first see what the original image looks like.

#plotting the image

plt.imshow(img)
plt.axis("off")
plt.show()
OUTPUT: [image: the original bird image]

 

Next, we will perform data augmentation. ImageDataGenerator supports various transformations, such as width shift, zoom, flip, and many more; we will apply some of them to the image.
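Each of these transformations is controlled by an ImageDataGenerator argument, and a new random value is drawn for every image. The sketch below approximates how one random set of parameters could be drawn for the values used in this post (rotation_range=25, width_shift_range=0.2, height_shift_range=0.2, zoom_range=0.1, horizontal_flip=True). It is a hypothetical helper modelled on the documented meaning of those arguments, not Keras source code, and the 224x224 image size is an assumption.

```python
import numpy as np

def sample_transform_params(rng, height, width,
                            rotation_range=25, width_shift_range=0.2,
                            height_shift_range=0.2, zoom_range=0.1,
                            horizontal_flip=True):
    """Draw one random set of transform parameters for an image of the
    given size (an approximation of the documented behaviour)."""
    angle_deg = rng.uniform(-rotation_range, rotation_range)
    # Float shift ranges below 1 are fractions of the image size.
    shift_x_px = rng.uniform(-width_shift_range, width_shift_range) * width
    shift_y_px = rng.uniform(-height_shift_range, height_shift_range) * height
    # A float zoom_range z means zoom factors are drawn from [1 - z, 1 + z].
    zoom_x, zoom_y = rng.uniform(1 - zoom_range, 1 + zoom_range, size=2)
    # Flip with probability 0.5, but only if flipping is enabled.
    flip = bool(horizontal_flip and rng.random() < 0.5)
    return {"angle_deg": angle_deg, "shift_x_px": shift_x_px,
            "shift_y_px": shift_y_px, "zoom_x": zoom_x, "zoom_y": zoom_y,
            "flip": flip}

rng = np.random.default_rng(42)
params = sample_transform_params(rng, height=224, width=224)
print(params)
```

Because the parameters are redrawn for every image, repeated calls on the same photo produce visibly different results.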

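# data augmentation: each argument enables one random transformation
augmentation = ImageDataGenerator(rotation_range=25,       # rotate by up to ±25 degrees
                                  width_shift_range=0.2,   # shift horizontally by up to 20% of the width
                                  height_shift_range=0.2,  # shift vertically by up to 20% of the height
                                  zoom_range=0.1,          # zoom by a factor between 0.9 and 1.1
                                  horizontal_flip=True)    # flip left-right at random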

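# img is already a NumPy array of shape (height, width, 3), so it does not
# need converting again; flow() expects a batch of images, so we add a
# leading batch dimension, giving shape (1, height, width, 3)
aug_img = np.expand_dims(img, axis=0)

# applying the transformations; flow() returns an iterator that yields a
# new randomly transformed batch every time it is called
transformed_image = augmentation.flow(aug_img, batch_size=1)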
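ImageDataGenerator.flow() expects a rank-4 array of shape (samples, height, width, channels), while a single loaded image has only three axes. This small NumPy check, using a hypothetical all-zero 100x150 array in place of bird.jpg, shows what np.expand_dims does to the shape and why indexing a batch with [0] gives back the single image.

```python
import numpy as np

# Hypothetical stand-in for the loaded bird image: height 100, width 150, 3 channels.
img = np.zeros((100, 150, 3), dtype=np.float32)

batch = np.expand_dims(img, axis=0)   # add a leading "samples" axis
print(img.shape, "->", batch.shape)   # (100, 150, 3) -> (1, 100, 150, 3)

# batch[0] recovers the single image from a batch of one.
single = batch[0]
print(single.shape)                   # (100, 150, 3)
```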

 

Let us look at the transformed images.

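# plotting nine randomly transformed versions of the image
plt.figure(figsize=(15, 15))
for i in range(1, 10):
    plt.subplot(3, 3, i)
    # draw the next randomly transformed batch and take its only image
    x = next(transformed_image)
    # pixel values are floats in [0, 1], which plt.imshow displays directly;
    # casting them to uint8 would turn the whole image black
    plt.imshow(x[0])
    plt.axis("off")
plt.show()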
OUTPUT: [image: a 3x3 grid of nine augmented versions of the bird image]

 

In the output, we can observe the following transformations:

  1. Width Shift
  2. Height Shift
  3. Rotation
  4. Horizontal Flip
  5. Zoom
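To see what two of these transformations do at the pixel level, here is a minimal NumPy-only sketch of a horizontal flip and an integer width shift. It is an approximation, not the Keras implementation: Keras applies the transformations as one affine transform with interpolation, and its default fill_mode="nearest" repeats edge pixels into the uncovered area, which is what the hypothetical helper width_shift_nearest imitates for whole-pixel shifts.

```python
import numpy as np

def horizontal_flip(image):
    """Mirror an (H, W, C) image left-right."""
    return image[:, ::-1, :]

def width_shift_nearest(image, shift_px):
    """Shift an (H, W, C) image right by shift_px pixels (left if negative),
    filling the uncovered columns by repeating the nearest edge column."""
    width = image.shape[1]
    # Source column for every output column, clamped to the valid range.
    cols = np.clip(np.arange(width) - shift_px, 0, width - 1)
    return image[:, cols, :]

# A 1x5 single-channel "image" makes the effect easy to read.
row = np.array([10, 20, 30, 40, 50]).reshape(1, 5, 1)
print(horizontal_flip(row)[0, :, 0])         # [50 40 30 20 10]
print(width_shift_nearest(row, 2)[0, :, 0])  # [10 10 10 20 30]
```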

Conclusion

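Wrapping up! We have learned what data augmentation is, why it is used, and how to apply it with Keras. It is a very useful technique for reducing overfitting, because the model sees many varied versions of each training image instead of memorizing the originals.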

Feel free to ask questions or share suggestions in the comments.
