Softmax Regression with Keras

In this tutorial, we will learn about softmax regression, a generalization of logistic regression to problems with more than two classes.

Deep learning can be done with many frameworks, such as TensorFlow, Caffe, and Theano. Here we will use Keras, the high-level API bundled with the popular Python framework TensorFlow, to show how softmax regression works.

Let us first look at the concepts behind softmax regression before implementing it. Logistic regression handles two classes: it outputs a single probability between 0 and 1 that an input belongs to the positive class, and the probability of the other class is one minus that value. Softmax regression extends this to any number of classes by outputting one probability per class. For example, suppose we have 4 classes representing a dog, a cat, a cow, or nothing in the picture. The model assigns each of the four classes a probability between 0 and 1, and these probabilities add up to 1.
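To make this concrete, here is a minimal NumPy sketch of the softmax function itself. Given raw scores (logits) z_1, ..., z_K for K classes, softmax computes p_i = exp(z_i) / (exp(z_1) + ... + exp(z_K)), so every p_i is positive and the p_i sum to 1. The four scores below are made up for illustration; a trained model would produce them from the image.

```python
import numpy as np

def softmax(logits):
    """Convert a vector of raw class scores (logits) into probabilities."""
    # Subtracting the max does not change the result but avoids overflow in exp().
    shifted = logits - np.max(logits)
    exps = np.exp(shifted)
    return exps / np.sum(exps)

# Hypothetical raw scores for the four classes in the example above.
classes = ["dog", "cat", "cow", "nothing"]
scores = np.array([2.0, 1.0, 0.1, -1.0])

probs = softmax(scores)
for name, p in zip(classes, probs):
    print(f"{name}: {p:.3f}")
print("sum:", probs.sum())  # 1, up to floating-point rounding

# The predicted class is the one with the highest probability.
print("prediction:", classes[int(np.argmax(probs))])  # prediction: dog
```

With only two classes, softmax reduces to the sigmoid used in logistic regression, which is why softmax regression is described as its multi-class generalization.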


First, let’s import all the libraries to implement softmax regression.

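import numpy as np
import tensorflow as tf
from tensorflow import keras
# Import from TensorFlow's bundled Keras so all code uses the same Keras version.
from tensorflow.keras.models import Sequential
from tensorflow.keras.models import Model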

Here we import NumPy, which helps us work with vectors and matrices, and TensorFlow together with its bundled Keras API. From the Keras models module we import Sequential, which builds a model as a linear stack of layers, and Model, the general Keras model class that provides the compile(), fit(), and evaluate() methods.

Now we will load the dataset we will use for softmax regression: the MNIST handwritten digit dataset. It contains 28×28 grayscale images of handwritten digits from 0 to 9, split into 60,000 training images and 10,000 test images. Our model has to predict the digit shown in each image, so there are 10 classes, one per digit. Because there are more than two classes, we use softmax regression.

mnist = tf.keras.datasets.mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()

Now that we have loaded the dataset, we apply a basic normalization: dividing by 255 scales each pixel value from the range 0 to 255 into the range 0 to 1. You can experiment with other preprocessing techniques to try to improve accuracy.

training_images = training_images / 255.0
test_images = test_images / 255.0
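The effect of this division can be seen on a tiny made-up array standing in for an MNIST image (real MNIST images are 28×28 arrays of 8-bit unsigned integers). This sketch shows that dividing by the float 255.0 also converts the integer pixels to floating point, so fractional values are kept.

```python
import numpy as np

# A tiny stand-in for an MNIST image: 8-bit grayscale pixels in [0, 255].
# (Made-up values for illustration only.)
image = np.array([[0, 64, 128],
                  [191, 255, 32]], dtype=np.uint8)

# Dividing by the Python float 255.0 promotes the array to floating point,
# so values such as 64 / 255 become fractions instead of being truncated.
scaled = image / 255.0
print(scaled.dtype)               # float64
print(scaled.min(), scaled.max()) # 0.0 1.0

# Optional: cast to float32, Keras's default float type, to halve memory use.
scaled32 = scaled.astype(np.float32)
print(scaled32.dtype)             # float32
```

Keras converts float64 inputs to its default float32 on its own, so the explicit cast is a memory optimization rather than a requirement.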

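Now let us build the model. It is a Sequential model, which means the layers are stacked in order and the output of each layer is the input of the next. The first layer, Flatten, reshapes each 28×28 image into a vector of 784 values. Next comes a Dense (fully connected) layer of 128 neurons with the ReLU activation function. The final Dense layer has 10 neurons, one per digit, and uses the softmax activation, which turns their outputs into class probabilities that sum to 1. This output layer is where softmax regression happens.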

model = keras.models.Sequential([
    tf.keras.layers.Flatten(),                            # 28x28 image -> 784-element vector
    tf.keras.layers.Dense(128, activation=tf.nn.relu),    # hidden layer with ReLU
    tf.keras.layers.Dense(10, activation=tf.nn.softmax),  # one probability per digit
])
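To show what these three layers compute, here is a NumPy sketch of one forward pass through the same architecture: Flatten, Dense(128) with ReLU, then Dense(10) with softmax. The random weights and fake images are stand-ins, not trained values, and the variable names (W1, b1, W2, b2) are chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# A batch of 4 fake 28x28 "images" with pixel values already scaled to [0, 1].
batch = rng.random((4, 28, 28))

# Flatten: reshape each 28x28 image into a 784-element vector.
x = batch.reshape(batch.shape[0], -1)             # shape (4, 784)

# Dense(128, relu): weights times inputs plus bias, then max(0, .).
W1 = rng.normal(0.0, 0.05, size=(784, 128))
b1 = np.zeros(128)
h = np.maximum(0.0, x @ W1 + b1)                  # shape (4, 128)

# Dense(10, softmax): one score per digit, turned into probabilities per row.
W2 = rng.normal(0.0, 0.05, size=(128, 10))
b2 = np.zeros(10)
logits = h @ W2 + b2                              # shape (4, 10)
logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

print(probs.shape)        # (4, 10)
print(probs.sum(axis=1))  # each row sums to 1, up to rounding

# Trainable parameters: weights plus biases of both Dense layers
# (Flatten has none): 784*128 + 128 + 128*10 + 10.
n_params = W1.size + b1.size + W2.size + b2.size
print(n_params)           # 101770
```

Once the Keras model has been built on 28×28 inputs (for example, after training), `model.summary()` should report the same total of 101,770 parameters.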

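Next we compile the model, specifying the Adam optimizer and a loss function, and then train it for five epochs. Because there are more than two classes, binary cross-entropy (the loss used for logistic regression) does not apply; instead we use categorical cross-entropy. We choose its sparse variant, sparse_categorical_crossentropy, because the MNIST labels are stored as integers (0 to 9) rather than as one-hot vectors.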
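To see what the 'sparse_categorical_crossentropy' loss computes, here is a minimal NumPy sketch. For each example it takes the negative log of the probability the model assigned to the true class, then averages over the batch. The probabilities are made up; Keras additionally clips probabilities away from 0 before taking the log, which this sketch omits.

```python
import numpy as np

# Made-up predicted probabilities for 3 examples over 10 digit classes.
# Each row: nine entries of 0.02 plus one of 0.82, so it sums to 1.
probs = np.full((3, 10), 0.02)
probs[0, 7] = 0.82
probs[1, 2] = 0.82
probs[2, 1] = 0.82

# Integer labels, as MNIST provides them.
labels = np.array([7, 2, 4])

# Sparse categorical cross-entropy: -log(probability of the true class), averaged.
sparse_loss = -np.mean(np.log(probs[np.arange(len(labels)), labels]))

# Categorical cross-entropy expects one-hot labels instead; the value is the same.
one_hot = np.eye(10)[labels]
categorical_loss = -np.mean(np.sum(one_hot * np.log(probs), axis=1))

print(sparse_loss, categorical_loss)
```

The third example is misclassified (its true class 4 only received 0.02), and it dominates the loss, which is the pressure that training reduces. The sparse variant simply saves us from converting labels to one-hot vectors first.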

model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',  # integer labels 0-9
              metrics=['accuracy'])
model.fit(training_images, training_labels, epochs=5)
Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2601 - accuracy: 0.9254
Epoch 2/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.1136 - accuracy: 0.9663
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0780 - accuracy: 0.9764
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0580 - accuracy: 0.9825
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0443 - accuracy: 0.9865

At the end of training, the accuracy on the training data is 0.9865 (about 98.7%), which is good for only 5 epochs. Training accuracy can overstate how well the model generalizes, so let us now evaluate it on the test set, which it has not seen.

model.evaluate(test_images, test_labels)
313/313 [==============================] - 0s 1ms/step - loss: 0.0780 - accuracy: 0.9756

[0.07798175513744354, 0.975600004196167]

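The first value is the loss and the second is the accuracy on the test set. The test accuracy of 97.56% is slightly lower than the final training accuracy, which is typical for data the model did not see during training.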

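We have implemented digit prediction from images, with softmax regression performed by the output layer. For each image, the model outputs a vector of 10 probabilities, and the predicted digit is the class with the highest probability. Writing the forward pass, the loss, the gradient computation, and the training loop from scratch would take far more code; Keras reduces it to a few lines, which is what makes it such a convenient framework to work with.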
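Turning probability vectors into predicted digits is a single argmax per row. The sketch below uses a made-up probability matrix standing in for the output of `model.predict` on four test images, along with hypothetical true labels, and computes accuracy as the fraction of predictions that match.

```python
import numpy as np

# Stand-in for model.predict(...) on 4 images: one row of 10 probabilities each.
probs = np.full((4, 10), 0.01)
for row, digit in enumerate([2, 0, 9, 3]):
    probs[row, digit] = 0.91  # 9 * 0.01 + 0.91 = 1.0

# Hypothetical true labels; the third image is actually a 4.
true_labels = np.array([2, 0, 4, 3])

# Predicted digit = index of the highest probability in each row.
predicted = np.argmax(probs, axis=1)
print(predicted)  # [2 0 9 3]

# Accuracy = fraction of predictions equal to the true label.
accuracy = np.mean(predicted == true_labels)
print(accuracy)   # 0.75
```

With the trained model from this tutorial, `probs = model.predict(test_images)` would supply real probabilities, and the accuracy computed this way should agree with the one reported by `model.evaluate`.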