Control Training using Callbacks in TensorFlow Keras

In this tutorial, we will see how to stop training in TensorFlow Keras once a chosen condition is met. The Keras training loop supports callbacks, so at the end of every epoch it can call a function you supply, which checks the metrics and stops training when they reach the required value. Let us see how to control training using callbacks with a simple example.

Introduction:

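A callback is a function, or an object with methods, that you pass to other code so that it can be called back when a particular event occurs. In Keras, a callback is an object that subclasses tf.keras.callbacks.Callback, and the training loop calls its methods, such as on_epoch_end, at the corresponding points of training. The same callback class can handle several events, and several callbacks can be passed to one training run. Let us see an example below on how to control training using callbacks.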
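The idea of a training loop calling back into user code at the end of each epoch can be sketched in plain Python, without TensorFlow. The names `Callback`, `StopBelowLoss` and `run_training` below are hypothetical stand-ins written for this illustration, not the Keras classes, and the list of losses is made up. Note one simplification: here the stop flag lives on the callback itself, whereas in Keras it is set on the model.

```python
class Callback:
    """Minimal stand-in for a callback base class (hypothetical, not Keras)."""

    def on_epoch_end(self, epoch, logs=None):
        pass


class StopBelowLoss(Callback):
    """Requests a stop once the reported loss drops below a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.stop_training = False

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss < self.threshold:
            self.stop_training = True


def run_training(losses, callbacks):
    """Pretend to train: each entry of `losses` is the loss after one epoch."""
    epochs_run = 0
    for epoch, loss in enumerate(losses):
        epochs_run += 1
        # The loop "calls back" into user code at the end of every epoch.
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs={"loss": loss})
        # Stop as soon as any callback has asked for it.
        if any(getattr(cb, "stop_training", False) for cb in callbacks):
            break
    return epochs_run


epochs_run = run_training([500.0, 300.0, 279.0, 100.0], [StopBelowLoss(280.0)])
print(epochs_run)
```

The fourth epoch is never run, because the callback asks for a stop as soon as the loss of 279.0 is reported.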

A simple example:

We will look at a simple example that predicts housing prices and stops training when the loss reaches a particular value.

1. Importing necessary libraries.

First let us import the libraries required for this example.

import tensorflow as tf
import numpy as np
from tensorflow import keras

We have imported TensorFlow and NumPy, along with Keras from the TensorFlow package.

2. The callback function.

Let us define a callback. Here training stops once the loss falls below 280, so the remaining epochs are not executed. The model's prediction for the house price is printed afterwards.

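class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Keras passes the epoch's metrics in `logs`; guard against a missing value.
        loss = (logs or {}).get('loss')
        if loss is not None and loss < 280:
            print("\nReached loss less than 280 so cancelling training!")
            self.model.stop_training = True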

3. Training the data.

Now we pass in some values to train the model. Here ‘x’ is the number of bedrooms and ‘y’ is the cost of the corresponding house.
The model has only a single neuron. The loss function measures how good the current guess is and passes that measure to the optimizer.
The optimizer then adjusts the model so that the next guess should be better than the one before.
Here the loss is ‘mean squared error’ and the optimizer is ‘stochastic gradient descent’; see the TensorFlow documentation for more details.
We stop training once the loss reaches a particular value.
Notice the additional parameter ‘callbacks’ in the fit command.

Training takes place in the fit command, which fits the values of ‘y’ to ‘x’.

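callbacks = myCallback()

def house_model(x_new):
    # Number of bedrooms (x) and the matching house price (y); 7 bedrooms is left out.
    x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0], dtype=float)
    y = np.array([100.0, 150.0, 200.0, 250.0, 300.0, 350.0, 450.0, 500.0, 550.0, 600.0, 650.0, 700.0], dtype=float)

    # A single Dense unit learns a straight line y = w * x + b.
    model = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=100, callbacks=[callbacks])
    # Return the raw prediction so it matches the price printed in the output.
    return model.predict(np.array(x_new, dtype=float))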
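To make the roles of the loss and the optimizer more concrete, here is a minimal sketch of mean squared error and one gradient descent update for a single neuron y = w * x + b, written in plain Python. It is a simplification and not what Keras runs internally: it uses the whole data set for the step (full-batch rather than stochastic), the four data points are a subset of the tutorial's data, and the learning rate of 0.01 and the helper names `mse` and `gradient_step` are assumptions made for this illustration.

```python
def mse(w, b, xs, ys):
    """Mean squared error of the line y = w * x + b on the data."""
    return sum((w * x + b - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def gradient_step(w, b, xs, ys, lr):
    """One gradient descent update of w and b using all samples."""
    n = len(xs)
    grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return w - lr * grad_w, b - lr * grad_b


xs = [1.0, 2.0, 3.0, 4.0]
ys = [100.0, 150.0, 200.0, 250.0]

w, b = 0.0, 0.0                      # start from an uninformed guess
loss_before = mse(w, b, xs, ys)
w, b = gradient_step(w, b, xs, ys, lr=0.01)
loss_after = mse(w, b, xs, ys)

print(loss_before, loss_after)       # the loss should go down after the step
```

Repeating the step many times is what the epochs in fit amount to, and the callback simply inspects the loss after each one.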

4. Predicting the value.

Here we try to predict the price of a 7-bedroom house. It is evident from the data that it costs 400, but training stops as soon as the loss goes below 280, and the price predicted at that point is printed.

prediction = house_model([7.0])
print(prediction)

Output:

Epoch 1/100
12/12 [==============================] - 0s 11ms/sample - loss: 209838.7500
Epoch 2/100
12/12 [==============================] - 0s 117us/sample - loss: 19126.5352
Epoch 3/100
12/12 [==============================] - 0s 97us/sample - loss: 2111.4070
Epoch 4/100
12/12 [==============================] - 0s 94us/sample - loss: 589.9004
Epoch 5/100
12/12 [==============================] - 0s 96us/sample - loss: 450.4402
Epoch 6/100
12/12 [==============================] - 0s 95us/sample - loss: 434.2926
Epoch 7/100
12/12 [==============================] - 0s 86us/sample - loss: 429.1788
Epoch 8/100
12/12 [==============================] - 0s 87us/sample - loss: 425.0830
Epoch 9/100
12/12 [==============================] - 0s 89us/sample - loss: 421.1125
Epoch 10/100
12/12 [==============================] - 0s 108us/sample - loss: 417.1864
Epoch 11/100
12/12 [==============================] - 0s 118us/sample - loss: 413.2977
Epoch 12/100
12/12 [==============================] - 0s 96us/sample - loss: 409.4454
Epoch 13/100
12/12 [==============================] - 0s 95us/sample - loss: 405.6286
Epoch 14/100
12/12 [==============================] - 0s 95us/sample - loss: 401.8479
Epoch 15/100
12/12 [==============================] - 0s 96us/sample - loss: 398.1024
Epoch 16/100
12/12 [==============================] - 0s 96us/sample - loss: 394.3917
Epoch 17/100
12/12 [==============================] - 0s 95us/sample - loss: 390.7153
Epoch 18/100
12/12 [==============================] - 0s 97us/sample - loss: 387.0735
Epoch 19/100
12/12 [==============================] - 0s 123us/sample - loss: 383.4654
Epoch 20/100
12/12 [==============================] - 0s 103us/sample - loss: 379.8912
Epoch 21/100
12/12 [==============================] - 0s 96us/sample - loss: 376.3503
Epoch 22/100
12/12 [==============================] - 0s 96us/sample - loss: 372.8422
Epoch 23/100
12/12 [==============================] - 0s 83us/sample - loss: 369.3668
Epoch 24/100
12/12 [==============================] - 0s 90us/sample - loss: 365.9237
Epoch 25/100
12/12 [==============================] - 0s 86us/sample - loss: 362.5131
Epoch 26/100
12/12 [==============================] - 0s 83us/sample - loss: 359.1340
Epoch 27/100
12/12 [==============================] - 0s 78us/sample - loss: 355.7863
Epoch 28/100
12/12 [==============================] - 0s 91us/sample - loss: 352.4701
Epoch 29/100
12/12 [==============================] - 0s 83us/sample - loss: 349.1844
Epoch 30/100
12/12 [==============================] - 0s 93us/sample - loss: 345.9299
Epoch 31/100
12/12 [==============================] - 0s 94us/sample - loss: 342.7055
Epoch 32/100
12/12 [==============================] - 0s 94us/sample - loss: 339.5108
Epoch 33/100
12/12 [==============================] - 0s 86us/sample - loss: 336.3465
Epoch 34/100
12/12 [==============================] - 0s 91us/sample - loss: 333.2111
Epoch 35/100
12/12 [==============================] - 0s 84us/sample - loss: 330.1054
Epoch 36/100
12/12 [==============================] - 0s 90us/sample - loss: 327.0284
Epoch 37/100
12/12 [==============================] - 0s 87us/sample - loss: 323.9803
Epoch 38/100
12/12 [==============================] - 0s 82us/sample - loss: 320.9605
Epoch 39/100
12/12 [==============================] - 0s 92us/sample - loss: 317.9686
Epoch 40/100
12/12 [==============================] - 0s 91us/sample - loss: 315.0049
Epoch 41/100
12/12 [==============================] - 0s 96us/sample - loss: 312.0686
Epoch 42/100
12/12 [==============================] - 0s 115us/sample - loss: 309.1598
Epoch 43/100
12/12 [==============================] - 0s 91us/sample - loss: 306.2781
Epoch 44/100
12/12 [==============================] - 0s 94us/sample - loss: 303.4231
Epoch 45/100
12/12 [==============================] - 0s 83us/sample - loss: 300.5948
Epoch 46/100
12/12 [==============================] - 0s 83us/sample - loss: 297.7930
Epoch 47/100
12/12 [==============================] - 0s 80us/sample - loss: 295.0171
Epoch 48/100
12/12 [==============================] - 0s 78us/sample - loss: 292.2673
Epoch 49/100
12/12 [==============================] - 0s 83us/sample - loss: 289.5433
Epoch 50/100
12/12 [==============================] - 0s 81us/sample - loss: 286.8441
Epoch 51/100
12/12 [==============================] - 0s 111us/sample - loss: 284.1707
Epoch 52/100
12/12 [==============================] - 0s 91us/sample - loss: 281.5216
Epoch 53/100
12/12 [==============================] - 0s 91us/sample - loss: 278.8977
Reached loss less than 280 so cancelling training!
[[373.512112]]
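The printed prediction of about 373.5 is short of the expected 400 because training was stopped while the loss was still near 279. As a check on the expected value, the following sketch fits a straight line to the same data with the closed-form least-squares formulas instead of gradient descent. It is a separate technique from the Keras model above, used here only for verification, and the variable names are chosen for this illustration.

```python
# Same data as in house_model: bedrooms (xs) and prices (ys), 7 bedrooms left out.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0]
ys = [100.0, 150.0, 200.0, 250.0, 300.0, 350.0,
      450.0, 500.0, 550.0, 600.0, 650.0, 700.0]

n = len(xs)
mean_x = sum(xs) / n
mean_y = sum(ys) / n

# Ordinary least squares for one feature: slope = cov(x, y) / var(x).
slope = (sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
         / sum((x - mean_x) ** 2 for x in xs))
intercept = mean_y - slope * mean_x

predicted = slope * 7.0 + intercept
print(slope, intercept, predicted)
```

The data lie exactly on the line y = 50x + 50, so a fully converged model would reach a loss of zero and predict 400 for 7 bedrooms. Lowering the threshold in the callback, or letting all epochs run, trades extra training time for a prediction closer to that value.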

 
