Data Augmentation using Keras in Python
In this blog, you will learn:
- What is Data Augmentation
- Why use Data Augmentation
- How to do Data Augmentation using Keras in Python
So, let's dive into it.
What is Data Augmentation?
Data Augmentation is a technique used to increase the diversity of the training set by applying various transformations (such as rotations, shifts, zooms, and flips) to the existing data, which also increases the amount of data available for training. As a result, a new dataset is created that contains transformed versions of the original data.
Note: Data augmentation is applied only to the training set, not to the validation or test set, because the training set is used to fit the model while the validation and test sets are used to evaluate it.
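To make the idea concrete, here is a minimal NumPy-only sketch (not Keras) that augments a small batch of images by adding a horizontally flipped copy of each one. The function name augment_with_flips and the random arrays standing in for real images are hypothetical. Only the training array is augmented; the validation array is left unchanged.

```python
import numpy as np

def augment_with_flips(images):
    """Return the original images plus a horizontally flipped copy of each.

    images: array of shape (num_images, height, width, channels).
    """
    # Reversing the width axis (axis 2) mirrors each image left-to-right
    flipped = images[:, :, ::-1, :]
    # Stack the originals and the flipped copies along the batch axis
    return np.concatenate([images, flipped], axis=0)

# Hypothetical stand-ins for real data: random 8x8 "images" with 3 channels
rng = np.random.default_rng(seed=0)
x_train = rng.random((10, 8, 8, 3))
x_val = rng.random((4, 8, 8, 3))

# Augment the training set only; the validation set stays untouched
x_train_aug = augment_with_flips(x_train)

print(x_train.shape, "->", x_train_aug.shape)  # (10, 8, 8, 3) -> (20, 8, 8, 3)
print("validation:", x_val.shape)              # validation: (4, 8, 8, 3)
```

Keras's ImageDataGenerator works differently in one respect: it applies random transformations on the fly while training, instead of storing an enlarged copy of the dataset in memory.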
Why use Data Augmentation?
Deep Learning algorithms are data-hungry: we need a lot of data to build a good deep learning model. One way to increase the amount of training data is web scraping, in which we collect data such as images from the internet. However, this is often not cost-effective, especially when you are building software. So, to get more data, we use data augmentation, which creates new, artificial but realistic samples from the data we already have.
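In this blog, we will perform Data Augmentation on images using the Keras ImageDataGenerator class. It generates batches of tensor image data with real-time data augmentation, meaning the random transformations are applied on the fly each time a batch is drawn rather than being stored as a larger dataset.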
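The phrase "real-time data augmentation" can be illustrated with a plain Python generator. The sketch below is a simplified, hypothetical analogue written with NumPy, not Keras's actual implementation: it never builds an enlarged dataset, and every batch it yields contains freshly randomized variants (here, only random horizontal flips). The function name random_flip_batches and the random input array are assumptions for illustration.

```python
import numpy as np

def random_flip_batches(images, batch_size, rng):
    """Yield batches forever, flipping each image horizontally with probability 0.5.

    Hypothetical, simplified analogue of real-time augmentation: nothing extra is
    stored, and a new random variant is produced every time a batch is drawn.
    """
    num_images = images.shape[0]
    while True:
        # Pick a random batch of indices from the original (unaugmented) images
        idx = rng.choice(num_images, size=batch_size, replace=False)
        batch = images[idx].copy()
        # Decide independently for each image whether to flip it left-to-right
        flip_mask = rng.random(batch_size) < 0.5
        batch[flip_mask] = batch[flip_mask][:, :, ::-1, :]
        yield batch

rng = np.random.default_rng(seed=42)
x_train = rng.random((16, 8, 8, 3))  # hypothetical stand-in for real images
batches = random_flip_batches(x_train, batch_size=4, rng=rng)

first = next(batches)
second = next(batches)
print(first.shape, second.shape)  # (4, 8, 8, 3) (4, 8, 8, 3)
```

Because the randomness is drawn per batch, a model trained for several epochs sees a slightly different version of the data each time, which is the main benefit of the real-time approach.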
How to do Data Augmentation in Python using Keras TensorFlow API?
First, we will import the Python libraries required for the task.
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
Next, we load an image by passing either its file name (if it is in the working directory) or its complete path.
# loading the image
img = image.load_img("bird.jpg")
# converting the image to a NumPy array and scaling pixel values to [0, 1]
img = image.img_to_array(img) / 255.0
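img_to_array returns a float32 array whose values keep the original 0 to 255 pixel range, so dividing by 255.0 maps them into [0, 1], a range plt.imshow accepts for float images. The NumPy sketch below uses a hypothetical 2x2 array in place of a real photo to show this scaling, and why casting the scaled array straight to uint8 loses almost all information unless it is multiplied by 255 first.

```python
import numpy as np

# Hypothetical 2x2 RGB image with uint8 pixel values, standing in for a loaded photo
pixels = np.array([[[0, 128, 255], [64, 64, 64]],
                   [[255, 255, 255], [10, 20, 30]]], dtype=np.uint8)

# Like img_to_array, convert to float32 (values still in 0..255)
img = pixels.astype(np.float32)
# Scale to [0, 1], as in the tutorial
img_scaled = img / 255.0

# Casting the scaled floats directly to uint8 truncates every value below 1.0 to 0
wrong = img_scaled.astype(np.uint8)
# Multiplying by 255 first (and rounding) recovers the original pixel values
right = np.round(img_scaled * 255.0).astype(np.uint8)

print(img_scaled.min(), img_scaled.max())  # 0.0 1.0
print(wrong.max())                          # 1
print(np.array_equal(right, pixels))        # True
```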
Let us first see what the original image looks like.
# plotting the image
plt.imshow(img)
plt.axis("off")
plt.show()
Finally, we will perform data augmentation. ImageDataGenerator supports various transformations, such as width shift, zoom, flip, and many more; we will apply some of them to the image.
# defining the data augmentation transformations
augmentation = ImageDataGenerator(rotation_range=25,
                                  width_shift_range=0.2,
                                  height_shift_range=0.2,
                                  zoom_range=0.1,
                                  horizontal_flip=True)
# img is already a NumPy array; add a batch dimension:
# (height, width, 3) -> (1, height, width, 3)
aug_img = np.expand_dims(img, axis=0)
# flow() returns an iterator that yields a randomly transformed batch on each call
transformed_image = augmentation.flow(aug_img, batch_size=1)
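To interpret the arguments passed to ImageDataGenerator, the plain Python sketch below converts them into the concrete ranges that, according to the Keras documentation, the random transformations are drawn from: a rotation angle in degrees between minus and plus rotation_range, shifts given as a fraction of the image width or height when the value is below 1, and zoom factors in [1 - zoom_range, 1 + zoom_range]. (horizontal_flip=True additionally mirrors images at random.) The helper describe_augmentation and the 224x224 image size are hypothetical.

```python
def describe_augmentation(rotation_range, width_shift_range, height_shift_range,
                          zoom_range, image_width, image_height):
    """Hypothetical helper: turn ImageDataGenerator-style float arguments into
    the concrete ranges that random transformations are sampled from."""
    return {
        # Rotation angle in degrees, drawn from [-rotation_range, +rotation_range]
        "rotation_degrees": (-rotation_range, rotation_range),
        # A float < 1 is a fraction of the width/height, so convert it to pixels
        "width_shift_pixels": (-width_shift_range * image_width,
                               width_shift_range * image_width),
        "height_shift_pixels": (-height_shift_range * image_height,
                                height_shift_range * image_height),
        # A float zoom_range z means zoom factors in [1 - z, 1 + z]
        "zoom_factor": (1 - zoom_range, 1 + zoom_range),
    }

ranges = describe_augmentation(rotation_range=25, width_shift_range=0.2,
                               height_shift_range=0.2, zoom_range=0.1,
                               image_width=224, image_height=224)
for name, (low, high) in ranges.items():
    print(f"{name}: {low:.1f} to {high:.1f}")
# rotation_degrees: -25.0 to 25.0
# width_shift_pixels: -44.8 to 44.8
# height_shift_pixels: -44.8 to 44.8
# zoom_factor: 0.9 to 1.1
```

So, for a 224-pixel-wide image, width_shift_range=0.2 allows shifts of up to about 45 pixels in either direction.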
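Let us see the transformed images.
# plotting nine random transformations of the image in a 3x3 grid
plt.figure(figsize=(15, 15))
for i in range(1, 10):
    plt.subplot(3, 3, i)
    # each call returns a batch of shape (1, height, width, 3)
    x = next(transformed_image)
    # pixel values are floats in [0, 1], which imshow can display directly
    plt.imshow(x[0])
    plt.axis("off")
plt.show()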
In the output, we can observe transformations such as:
- Width Shift
- Height Shift
- Horizontal Flip
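To see what each of these transformations does to the pixels, here is a simplified NumPy sketch, not Keras's own implementation, that applies a fixed width shift, height shift, and horizontal flip to a tiny 3x3 single-channel array. It assumes integer pixel shifts and fills the uncovered border by repeating the nearest edge pixels, which mimics ImageDataGenerator's default fill_mode="nearest". The function names are hypothetical.

```python
import numpy as np

def shift_width(img, dx):
    """Shift an (H, W, C) image right by dx pixels (left if dx < 0), edge-filled."""
    w = img.shape[1]
    pad = abs(dx)
    # Repeat edge columns on both sides, then crop a window of the original width
    padded = np.pad(img, ((0, 0), (pad, pad), (0, 0)), mode="edge")
    start = pad - dx
    return padded[:, start:start + w, :]

def shift_height(img, dy):
    """Shift an (H, W, C) image down by dy pixels (up if dy < 0), edge-filled."""
    h = img.shape[0]
    pad = abs(dy)
    # Repeat edge rows on both sides, then crop a window of the original height
    padded = np.pad(img, ((pad, pad), (0, 0), (0, 0)), mode="edge")
    start = pad - dy
    return padded[start:start + h, :, :]

def flip_horizontal(img):
    """Mirror an (H, W, C) image left-to-right."""
    return img[:, ::-1, :]

# Tiny 3x3 single-channel "image" with values 0..8
img = np.arange(9).reshape(3, 3, 1)

print(shift_width(img, 1)[:, :, 0].tolist())   # [[0, 0, 1], [3, 3, 4], [6, 6, 7]]
print(shift_height(img, 1)[:, :, 0].tolist())  # [[0, 1, 2], [0, 1, 2], [3, 4, 5]]
print(flip_horizontal(img)[:, :, 0].tolist())  # [[2, 1, 0], [5, 4, 3], [8, 7, 6]]
```

In ImageDataGenerator the shift amounts are random (and usually fractional), and rotation and zoom are combined with the shifts into a single transformation, but the visible effect on the image is the same kind of displacement and mirroring shown here.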
Wrapping up! We have learned what data augmentation is, why it is used, and how to apply it with Keras. It is a very useful technique for reducing overfitting.
Feel free to ask questions or share suggestions in the comments.