Control Training using Callbacks in TensorFlow Keras
In this tutorial, let us find out how to stop training once a particular point is reached in TensorFlow Keras. The Keras training loop supports callbacks: at defined points in every epoch it can call back into your own code, which can inspect the current metrics and act on them. Let us see how to control training using callbacks with a simple example.
Introduction:
A callback is an object passed to the training loop that Keras invokes when specific events occur, such as the start or end of an epoch or batch. You can implement one by subclassing tf.keras.callbacks.Callback and overriding the methods for the events you care about; the same callback can handle multiple events and be reused across models. Let us see an example below of how to control training using callbacks.
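As a minimal sketch of the idea (the tiny model and data here are illustrative assumptions, not part of the tutorial's example), a custom callback subclasses tf.keras.callbacks.Callback and overrides an event method such as on_epoch_end:

```python
import numpy as np
import tensorflow as tf

# Minimal custom callback: Keras invokes on_epoch_end automatically
# at the end of every epoch and passes the current metrics in 'logs'.
class LossLogger(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print(f"Epoch {epoch}: loss = {logs.get('loss'):.4f}")

# Tiny illustrative model and data, just to trigger the callback.
model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
x = np.array([1.0, 2.0, 3.0], dtype=float)
y = np.array([2.0, 4.0, 6.0], dtype=float)
history = model.fit(x, y, epochs=3, verbose=0, callbacks=[LossLogger()])
```

Passing the callback instance in the `callbacks` list of `fit` is all that is needed to hook it into the training loop.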
A simple example:
We will see a simple example that predicts housing prices and stops training when the loss reaches a particular value.
1. Importing necessary libraries.
First let us import the libraries required for this example.
import tensorflow as tf
import numpy as np
from tensorflow import keras
We have imported TensorFlow, NumPy, and Keras from the TensorFlow package.
2. The callback function.
Let us define a callback. Here training stops once the loss drops below 280, so further epochs need not be executed; the corresponding housing price will then be printed.
class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Stop training once the loss drops below 280
        if logs.get('loss', float('inf')) < 280:
            print("\nReached loss less than 280 so cancelling training!")
            self.model.stop_training = True
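For the common case of stopping when a metric stops improving, Keras also ships a built-in callback, tf.keras.callbacks.EarlyStopping, so a custom class is not always required (the parameter values below are illustrative assumptions, not from this example):

```python
import tensorflow as tf

# Built-in alternative to a hand-written stopping callback.
# These parameter values are illustrative, not tuned for this problem.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='loss',   # metric to watch (training loss here)
    min_delta=1.0,    # improvements smaller than this do not count
    patience=3,       # stop after 3 epochs without improvement
)
# It is passed to fit() exactly like a custom callback:
# model.fit(x, y, epochs=100, callbacks=[early_stop])
```

EarlyStopping watches for a plateau rather than a fixed threshold, so the custom class above remains the simpler choice when you want to stop at an exact loss value.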
3. Training the model.
Now we pass in some values to train on. Here 'x' represents the number of bedrooms and 'y' represents the cost of each house.
There is only a single neuron in this model. The loss function measures how good the current guess is and passes that measurement to the optimizer.
The optimizer then tries to make the next guess better than the one before.
Here the loss is 'mean squared error' and the optimizer is 'stochastic gradient descent'; the TensorFlow documentation can be checked for more details.
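The string shorthands 'sgd' and 'mean_squared_error' map to optimizer and loss objects. An equivalent compile call can pass those objects explicitly, which makes hyperparameters such as the learning rate tunable (the rate shown is the Keras SGD default):

```python
import tensorflow as tf

# Equivalent to model.compile(optimizer='sgd', loss='mean_squared_error'),
# but with explicit objects so hyperparameters can be adjusted.
model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),  # default SGD rate
    loss=tf.keras.losses.MeanSquaredError(),
)
```

With the string form Keras constructs these same objects with their default settings behind the scenes.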
Here we stop training once the loss drops below a particular value.
Notice that we have an additional parameter 'callbacks' in the fit call.
The training takes place in the fit call, which fits the model to predict 'y' from 'x'.
callbacks = myCallback()

def house_model(y_new):
    # 'y_new' is the number of bedrooms to predict a price for
    x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0], dtype=float)
    y = np.array([100.0, 150.0, 200.0, 250.0, 300.0, 350.0, 450.0, 500.0, 550.0, 600.0, 650.0, 700.0], dtype=float)
    model = tf.keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=100, callbacks=[callbacks])
    return model.predict(y_new)
4. Predicting the value.
Here we try to predict the price of a 7-bedroom house. From the data it is evident that it should cost 400, but since training stops once the loss drops below 280, a slightly less accurate price is printed.
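The training data above lie exactly on the line y = 50x + 50, which is why the fully converged answer for 7 bedrooms would be 400. A quick NumPy check, independent of the Keras code, confirms this:

```python
import numpy as np

# The tutorial's training data: bedrooms (x) vs. price (y).
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0])
y = np.array([100.0, 150.0, 200.0, 250.0, 300.0, 350.0, 450.0,
              500.0, 550.0, 600.0, 650.0, 700.0])

# A least-squares line through the points recovers y = 50x + 50.
slope, intercept = np.polyfit(x, y, 1)
print(slope * 7.0 + intercept)  # approximately 400
```

Any gap between this exact value and the Keras prediction comes from stopping gradient descent before the loss reaches zero.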
prediction = house_model([7.0])
print(prediction)
Output:
Epoch 1/100
12/12 [==============================] - 0s 11ms/sample - loss: 209838.7500
Epoch 2/100
12/12 [==============================] - 0s 117us/sample - loss: 19126.5352
Epoch 3/100
12/12 [==============================] - 0s 97us/sample - loss: 2111.4070
Epoch 4/100
12/12 [==============================] - 0s 94us/sample - loss: 589.9004
Epoch 5/100
12/12 [==============================] - 0s 96us/sample - loss: 450.4402
Epoch 6/100
12/12 [==============================] - 0s 95us/sample - loss: 434.2926
Epoch 7/100
12/12 [==============================] - 0s 86us/sample - loss: 429.1788
Epoch 8/100
12/12 [==============================] - 0s 87us/sample - loss: 425.0830
Epoch 9/100
12/12 [==============================] - 0s 89us/sample - loss: 421.1125
Epoch 10/100
12/12 [==============================] - 0s 108us/sample - loss: 417.1864
Epoch 11/100
12/12 [==============================] - 0s 118us/sample - loss: 413.2977
Epoch 12/100
12/12 [==============================] - 0s 96us/sample - loss: 409.4454
Epoch 13/100
12/12 [==============================] - 0s 95us/sample - loss: 405.6286
Epoch 14/100
12/12 [==============================] - 0s 95us/sample - loss: 401.8479
Epoch 15/100
12/12 [==============================] - 0s 96us/sample - loss: 398.1024
Epoch 16/100
12/12 [==============================] - 0s 96us/sample - loss: 394.3917
Epoch 17/100
12/12 [==============================] - 0s 95us/sample - loss: 390.7153
Epoch 18/100
12/12 [==============================] - 0s 97us/sample - loss: 387.0735
Epoch 19/100
12/12 [==============================] - 0s 123us/sample - loss: 383.4654
Epoch 20/100
12/12 [==============================] - 0s 103us/sample - loss: 379.8912
Epoch 21/100
12/12 [==============================] - 0s 96us/sample - loss: 376.3503
Epoch 22/100
12/12 [==============================] - 0s 96us/sample - loss: 372.8422
Epoch 23/100
12/12 [==============================] - 0s 83us/sample - loss: 369.3668
Epoch 24/100
12/12 [==============================] - 0s 90us/sample - loss: 365.9237
Epoch 25/100
12/12 [==============================] - 0s 86us/sample - loss: 362.5131
Epoch 26/100
12/12 [==============================] - 0s 83us/sample - loss: 359.1340
Epoch 27/100
12/12 [==============================] - 0s 78us/sample - loss: 355.7863
Epoch 28/100
12/12 [==============================] - 0s 91us/sample - loss: 352.4701
Epoch 29/100
12/12 [==============================] - 0s 83us/sample - loss: 349.1844
Epoch 30/100
12/12 [==============================] - 0s 93us/sample - loss: 345.9299
Epoch 31/100
12/12 [==============================] - 0s 94us/sample - loss: 342.7055
Epoch 32/100
12/12 [==============================] - 0s 94us/sample - loss: 339.5108
Epoch 33/100
12/12 [==============================] - 0s 86us/sample - loss: 336.3465
Epoch 34/100
12/12 [==============================] - 0s 91us/sample - loss: 333.2111
Epoch 35/100
12/12 [==============================] - 0s 84us/sample - loss: 330.1054
Epoch 36/100
12/12 [==============================] - 0s 90us/sample - loss: 327.0284
Epoch 37/100
12/12 [==============================] - 0s 87us/sample - loss: 323.9803
Epoch 38/100
12/12 [==============================] - 0s 82us/sample - loss: 320.9605
Epoch 39/100
12/12 [==============================] - 0s 92us/sample - loss: 317.9686
Epoch 40/100
12/12 [==============================] - 0s 91us/sample - loss: 315.0049
Epoch 41/100
12/12 [==============================] - 0s 96us/sample - loss: 312.0686
Epoch 42/100
12/12 [==============================] - 0s 115us/sample - loss: 309.1598
Epoch 43/100
12/12 [==============================] - 0s 91us/sample - loss: 306.2781
Epoch 44/100
12/12 [==============================] - 0s 94us/sample - loss: 303.4231
Epoch 45/100
12/12 [==============================] - 0s 83us/sample - loss: 300.5948
Epoch 46/100
12/12 [==============================] - 0s 83us/sample - loss: 297.7930
Epoch 47/100
12/12 [==============================] - 0s 80us/sample - loss: 295.0171
Epoch 48/100
12/12 [==============================] - 0s 78us/sample - loss: 292.2673
Epoch 49/100
12/12 [==============================] - 0s 83us/sample - loss: 289.5433
Epoch 50/100
12/12 [==============================] - 0s 81us/sample - loss: 286.8441
Epoch 51/100
12/12 [==============================] - 0s 111us/sample - loss: 284.1707
Epoch 52/100
12/12 [==============================] - 0s 91us/sample - loss: 281.5216
Epoch 53/100
12/12 [==============================] - 0s 91us/sample - loss: 278.8977
Reached loss less than 280 so cancelling training!
[[373.512112]]