How to use Keras fit and fit_generator in Python

In this tutorial, we will look closely at the Keras fit() and fit_generator() functions in Python: how each of them works, and how they differ from and resemble each other. The theory is accompanied by model-training code so that the concepts are understood in depth.
The Keras deep learning library provides the following model training functions:

  • .fit
  • .fit_generator
  • .train_on_batch

In this tutorial we will discuss the fit() and fit_generator() functions.

Keras fit() function

The syntax for using the fit function is as shown below:

model.fit(x=None, y=None, batch_size=None, epochs=1, verbose=1,
  callbacks=None, validation_split=0.0, validation_data=None,
  shuffle=True, class_weight=None, sample_weight=None,
  initial_epoch=0, steps_per_epoch=None, validation_steps=None)

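Some of the parameters are explained below:

  • batch_size
    The number of samples per gradient update.
  • callbacks
    A list of callbacks to apply during training.
  • verbose
    Verbosity mode (0 = silent, 1 = progress bar, 2 = one line per epoch).
  • epochs
    The number of times to loop over the training data.
  • validation_split
    The fraction of the training data to set aside for validation, a value between 0 and 1 (e.g. 0.2 holds out 20% of the samples, which are then not used for training).
  • validation_data
    Data on which the loss and metrics are evaluated at the end of each epoch; the model is not trained on it.
  • shuffle
    A True/False value stating whether to shuffle the training data before each epoch.
  • class_weight
    A dictionary mapping class indices to weights, used to weight the loss function during training. Giving a class a larger weight tells the model to pay more attention to samples from that class, which is useful when some classes are under-represented.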
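
According to the Keras documentation, the samples held out by validation_split are taken from the end of the x and y data provided, before any shuffling. The pure-Python sketch below mimics that behaviour on plain lists. The helper name split_validation is hypothetical, and rounding the split point down is an assumption rather than something taken from the Keras source.

```python
import math


def split_validation(x, y, validation_split):
    """Hypothetical helper: hold out the LAST `validation_split` fraction of
    the samples, as Keras describes for validation_split (the tail of the
    data is used, before shuffling). Rounding down is an assumption."""
    if not 0.0 <= validation_split < 1.0:
        raise ValueError("validation_split must be in [0, 1)")
    split_at = math.floor(len(x) * (1.0 - validation_split))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


x = list(range(10))
y = [i % 2 for i in x]
(train_x, train_y), (val_x, val_y) = split_validation(x, y, 0.2)
print(train_x)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(val_x)    # [8, 9]
```

Because the tail is used, a dataset sorted by class would leave the validation set with only the last class, so such data should be shuffled before it is passed to fit().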

We can train a model with the Keras fit() function as follows:

model.fit(X_train, y_train, batch_size=30, epochs=40)

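where X_train and y_train are the training inputs (features) and targets (labels) respectively, batch_size=30 is the number of samples in each batch, and epochs=40 is the number of complete passes over the training data. The weights are updated once per batch, so each epoch consists of ceil(len(X_train) / 30) updates, and this is repeated 40 times.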
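
To make the meaning of batch_size concrete, the sketch below slices an in-memory dataset into consecutive batches, roughly the way fit() walks over the data in each epoch (shuffling is left out). The 1,000-sample dataset and the make_batches helper are hypothetical, chosen only for illustration.

```python
import math


def make_batches(x, y, batch_size):
    """Hypothetical helper: split in-memory lists into consecutive
    (inputs, targets) batches. The final batch is smaller when len(x)
    is not a multiple of batch_size."""
    return [
        (x[start:start + batch_size], y[start:start + batch_size])
        for start in range(0, len(x), batch_size)
    ]


# Hypothetical in-memory dataset of 1,000 samples.
X_train = list(range(1000))
y_train = [i % 2 for i in X_train]

batches = make_batches(X_train, y_train, batch_size=30)
print(len(batches))                  # 34 weight updates per epoch
print(math.ceil(len(X_train) / 30))  # 34, the same count computed directly
print(len(batches[-1][0]))           # 10 samples in the last batch
```

With 1,000 samples, batch_size=30 gives 33 full batches plus one batch of 10, so epochs=40 would amount to 34 × 40 = 1,360 weight updates.
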
fit() is used under the assumption that the entire training set fits into RAM and that no data augmentation is required, i.e. the input data is used as raw data. The data is not manipulated or changed while training runs with fit().

Since the data fits completely in RAM, there is no need to swap an old batch out for a new one. fit() can also be called multiple times on the same model: the weights are not reinitialized between calls, so each call continues training from where the previous one stopped. This has to be managed carefully (for example with the initial_epoch argument) so that the model trains as intended. In short, fit() is the right choice when data generators are not required.

Keras .fit_generator() function

The syntax of the fit_generator() function is:

model.fit_generator(generator, steps_per_epoch=None, epochs=1,
  verbose=1, callbacks=None, validation_data=None,
  validation_steps=None, class_weight=None, max_queue_size=10,
  workers=1, use_multiprocessing=False, shuffle=True,
  initial_epoch=0)

Some of the parameters are explained below:

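  • generator:
    A generator whose output must be a tuple of one of these forms:
    -> (inputs, targets)
    -> (inputs, targets, sample_weights)
    Each tuple output by the generator makes up a single batch, so all arrays in the tuple must have the same length (equal to the size of that batch). The generator is expected to loop over its data indefinitely.
  • steps_per_epoch:
    The number of batches (steps) to draw from the generator before declaring one epoch finished and starting the next. It is typically the number of training samples divided by the batch size.
  • epochs:
    The number of times we want to train the model over the data.
  • verbose:
    Verbosity mode (0 = silent, 1 = progress bar, 2 = one line per epoch).
  • callbacks:
    A list of callbacks to apply to the model during training.
  • validation_data:
    Either
    -> a tuple (inputs, targets) or (inputs, targets, sample_weights)
    -> a generator
  • validation_steps:
    Only relevant if validation_data is a generator: the number of batches to draw from it for validation at the end of each epoch.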
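
The generator contract described above can be illustrated without Keras. The sketch below is a hypothetical batch_generator that loops over in-memory data forever and yields (inputs, targets) tuples, the shape of output fit_generator() expects; steps_per_epoch is then the number of batches that make up one pass over the data. Plain Python lists stand in for NumPy arrays, an assumption made only to keep the example dependency-free.

```python
import math
import random


def batch_generator(x, y, batch_size, shuffle=True, seed=None):
    """Hypothetical generator: yields (inputs, targets) batches forever,
    reshuffling the sample order at the start of every pass."""
    rng = random.Random(seed)
    indices = list(range(len(x)))
    while True:  # never stops, like the generators fit_generator() consumes
        if shuffle:
            rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch_idx = indices[start:start + batch_size]
            yield [x[i] for i in batch_idx], [y[i] for i in batch_idx]


x = list(range(100))
y = [v * 2 for v in x]
batch_size = 32
steps_per_epoch = math.ceil(len(x) / batch_size)  # 4 batches: 32 + 32 + 32 + 4

gen = batch_generator(x, y, batch_size, seed=0)
inputs, targets = next(gen)
print(len(inputs), len(targets))  # 32 32
```

Rounding up here counts the final partial batch as a step; with integer division instead, that last partial batch would be drawn at the start of the following epoch.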


An example of training a model with fit_generator() and data augmentation:

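# Assumes an older Keras 2.x release that still provides fit_generator(),
# and that `model`, X_train, y_train, X_test and y_test are already defined.
from keras.preprocessing.image import ImageDataGenerator

# initialize the number of epochs and batch size
epochs = 100
bs = 32

# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
      width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
      horizontal_flip=True, fill_mode="nearest")

# train the model on augmented batches; each epoch draws len(X_train) // bs
# batches, and the un-augmented test set is used for validation
history = model.fit_generator(aug.flow(X_train, y_train, batch_size=bs),
          validation_data=(X_test, y_test), steps_per_epoch=len(X_train) // bs,
          epochs=epochs)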

Image Data Generation

Keras ImageDataGenerator is used for data augmentation: it alters pixel positions and colour properties of the images, for example by rescaling, rotating, shifting, zooming, flipping or brightening them. When augmentation is applied, the dataset is no longer static: every time a batch is selected for fitting, its images are transformed anew according to the parameters given to ImageDataGenerator. Augmenting images in this way helps the model generalise better and acts as a simple regulariser that further reduces overfitting.
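
ImageDataGenerator's internals are not reproduced here. As a minimal, framework-free stand-in, the sketch below yields batches forever and applies one random transformation, a horizontal flip with probability 0.5, to each image as it is drawn. The helper names and the list-of-rows image format are assumptions for illustration only. Because the flip is re-drawn every time, the same source image can reach the model in different forms on different epochs, which is the regularising effect described above.

```python
import random


def horizontal_flip(image):
    """Mirror a 2-D image (a list of pixel rows) left to right."""
    return [row[::-1] for row in image]


def augmenting_generator(images, labels, batch_size, flip_prob=0.5, seed=None):
    """Hypothetical stand-in for an augmenting generator: yields
    (inputs, targets) batches forever, randomly flipping each image."""
    rng = random.Random(seed)
    while True:
        for start in range(0, len(images), batch_size):
            batch_x = [
                horizontal_flip(img) if rng.random() < flip_prob else img
                for img in images[start:start + batch_size]
            ]
            yield batch_x, labels[start:start + batch_size]


# Two tiny 2x3 "images"; labels are not affected by augmentation.
images = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [0, 1, 2]]]
labels = [0, 1]

gen = augmenting_generator(images, labels, batch_size=2, seed=42)
batch_x, batch_y = next(gen)
print(batch_y)  # [0, 1]: targets pass through unchanged
```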

Because the data changes with every new batch, it cannot simply be passed to the regular fit() function as fixed arrays. This is why we turn to fit_generator(), which relies on an underlying generator that produces the data for us even as it changes. The generator yields one batch at a time to fit_generator(), which backpropagates on that batch to update the weights, repeating until the requested number of epochs has been completed. (In TensorFlow 2.1 and later, fit() accepts generators directly and fit_generator() is deprecated.)

As a further note, the syntax of fit_generator() includes a steps_per_epoch argument that is needed when using this function. This is because the generator passed to fit_generator() loops indefinitely without returning or exiting, so Keras cannot tell on its own where one epoch ends. steps_per_epoch, usually the length of the dataset divided by the batch size, tells Keras how many batches make up one epoch: once that many batches have been drawn, Keras ends the current epoch and starts the next one.
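
The role of steps_per_epoch can be seen in a simplified model of the training loop: because the generator never stops, the loop itself has to decide how many batches form one epoch. The sketch below is not Keras's implementation; endless_batches and run_training are hypothetical helpers, and train_on_batch here is just a callback that records batch sizes in place of a real gradient update.

```python
import itertools
import math


def endless_batches(n_samples, batch_size):
    """Hypothetical infinite generator yielding the size of each batch."""
    while True:
        for start in range(0, n_samples, batch_size):
            yield min(batch_size, n_samples - start)


def run_training(generator, steps_per_epoch, epochs, train_on_batch):
    """Draw exactly steps_per_epoch batches per epoch from an endless generator."""
    for epoch in range(epochs):
        for batch in itertools.islice(generator, steps_per_epoch):
            train_on_batch(batch)  # stand-in for a real weight update
        print(f"epoch {epoch + 1}/{epochs} finished")


seen = []
run_training(endless_batches(100, 32), steps_per_epoch=math.ceil(100 / 32),
             epochs=2, train_on_batch=seen.append)
print(seen)  # [32, 32, 32, 4, 32, 32, 32, 4]
```

Without the islice limit the inner loop would never end, which is exactly why fit_generator() needs steps_per_epoch.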

Summary

So we have learned that fit() and fit_generator() are both used to train a neural network. fit() needs no data augmentation and expects the whole dataset to fit in RAM at once, whereas fit_generator() trains on batches produced by a generator, which allows the data to be augmented on the fly, one batch at a time.
