Sklearn Objects – fit() vs transform() vs fit_transform() vs predict()

What is the difference between fit(), transform(), fit_transform() and predict()?

This question comes up a lot, and almost every machine learning beginner struggles with these functions. So here we explain the main differences between these functions from the scikit-learn Python module.

Before we proceed, note that a typical machine learning workflow begins with data preprocessing and then moves on to training the model.

In data preprocessing, we will use the fit(), transform() and the fit_transform() functions.

Meanwhile, in model training, we use the fit() and predict() functions.


Data Preprocessing:

In data preprocessing, we use those three functions as mentioned above.

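The fit() function computes the values needed for the transformation from the data. For the StandardScaler used in this tutorial, these are the mean and the standard deviation of each column (feature) of the dataset.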

The transform() function is used after the fit() function: it mathematically transforms the values using the mean and the standard deviation derived by fit().
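To make the two steps concrete, here is a minimal sketch with StandardScaler. The toy array is made up for illustration and is not the tutorial's data set; `mean_` and `scale_` are the attributes where the scaler stores what fit() learned.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data: four samples, one feature
data = np.array([[1.0], [2.0], [3.0], [4.0]])

scaler = StandardScaler()
scaler.fit(data)  # only learns the statistics; the data is not changed

print(scaler.mean_)   # mean of each feature, learned by fit()
print(scaler.scale_)  # standard deviation of each feature, learned by fit()

# transform() applies (value - mean) / standard deviation
scaled = scaler.transform(data)
manual = (data - scaler.mean_) / scaler.scale_
print(np.allclose(scaled, manual))
```

The last line compares the output of transform() with the same formula written by hand.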

There is also a function that combines both of these steps: fit_transform() first ‘fits’ the dataset and then ‘transforms’ it, all in a single call.
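As a quick check, the sketch below compares the two-step and the one-step versions on a small made-up array (the numbers are only for illustration).

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data: three samples, two features
data = np.array([[10.0, 1.0], [20.0, 3.0], [30.0, 5.0]])

# Two steps: fit() returns the fitted scaler, so transform() can be chained
two_steps = StandardScaler().fit(data).transform(data)

# One step: fit_transform() does both
one_step = StandardScaler().fit_transform(data)

same = np.allclose(two_steps, one_step)
print(same)
```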

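However, we use fit_transform() on the train dataset, while only transform() is used on the test dataset. This way the test data is scaled with the mean and standard deviation learned from the training data, and nothing is learned from the test data itself.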
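The following sketch shows this train/test rule on made-up numbers. The arrays `train` and `test` are hypothetical and stand in for the real splits.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical splits with one feature
train = np.array([[1.0], [2.0], [3.0]])
test = np.array([[4.0], [5.0]])

scaler = StandardScaler()
train_scaled = scaler.fit_transform(train)  # statistics come from train only
test_scaled = scaler.transform(test)        # reuses the train statistics

# The stored mean is the train mean, not the test mean
print(scaler.mean_)
print(test_scaled)
```

Because the test values are larger than anything seen in training, their scaled values are all positive rather than centered on zero.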

To use data preprocessing, we write the following line of code to import the StandardScaler class:

from sklearn.preprocessing import StandardScaler

After importing the class, we can use these functions.


Now we will take a data set that represents the offices in the Middle East.

The data set is:

As you can see, the data set has been loaded using pandas. Before we begin with data preprocessing, we will split the data into training and testing sets using the train_test_split() function to get x_train, x_test, y_train, y_test.

We can write the following code to get that:

from sklearn.model_selection import train_test_split

#train_test_split returns both x splits first, then both y splits
x_train, x_test, y_train, y_test = train_test_split(x, y)

print(f"x train dataset: {x_train}\n")
print(f"y train dataset: {y_train}\n")
print(f"x test dataset: {x_test}\n")
print(f"y test dataset: {y_test}")
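The order of the returned values is easy to mix up: train_test_split() returns the train and test parts of the first array, then the train and test parts of the second array. Here is a small sketch on made-up arrays; `test_size=3` and `random_state=0` are our own choices for a repeatable example.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical data: 10 samples with 2 features; row i holds [2*i, 2*i + 1]
x = np.arange(20).reshape(10, 2)
y = np.arange(10)

# An integer test_size is the absolute number of test samples
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=3, random_state=0
)

print(x_train.shape, x_test.shape)
print(y_train.shape, y_test.shape)

# Rows of x and y stay paired after shuffling
print((x_train[:, 0] == 2 * y_train).all())
```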


Now we scale the train and test values of x by writing the following piece of code:

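from numpy import array

datascale = StandardScaler()

#We will apply the fit_transform function first, on the training data only

x_train_scaled = datascale.fit_transform(x_train)

#Reshape the test values into 3 rows and 2 columns, then scale them with the fitted scaler
x_test_reshape = array(x_test).reshape(3,2)
x_trans = datascale.transform(x_test_reshape)

#Here x_trans is the scaled test value of x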

Model Training:

Now we will start with model training. We will use the fit() and predict() functions. Here the fit() function has a different role: instead of computing a mean and standard deviation, it learns the model's parameters from the training data (for a linear model, the best-fit line). The predict() function then uses the trained model to predict values for new data.
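Before the walkthrough, here is a small self-contained sketch of fit() and predict() on a model. The data is made up and deliberately easy to separate; `coef_` and `intercept_` are the attributes where LogisticRegression stores what fit() learned.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical data: one feature, two clearly separated classes
x_train = np.array([[0.0], [1.0], [2.0], [8.0], [9.0], [10.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])

model = LogisticRegression()
model.fit(x_train, y_train)  # learns coef_ and intercept_ from the data

print(model.coef_, model.intercept_)

# predict() applies the learned parameters to new samples
pred = model.predict(np.array([[0.5], [9.5]]))
print(pred)
```

Note that predict() can only be called after fit(); calling it on an unfitted model raises an error.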

We begin by importing the LogisticRegression class.

#Time for training the model
from sklearn.linear_model import LogisticRegression

Then we will use the fit() function:

model = LogisticRegression()

#We will fit the model

y_train = array(y_train).flatten()
model.fit(x_train_scaled, y_train)

Then comes the predict() function usage:

pred_val = model.predict(x_trans)

print(f"The predicted value is: {pred_val}")

After this, we get to see this output:

The predicted value is: [5610000 5610000 5610000]

Hence, we are done with the code.



Data Preprocessing:

fit() is used for finding the mean and standard deviation, and transform() is used to transform the data values using that mean and standard deviation. fit_transform() is a convenience function which performs both fit() and transform() in one call.


Model training:

Here fit() is used to train the model (find the best-fit line) and predict() is used for predicting the outcomes.

Thank you for reading the tutorial. We will bring more such tutorials for you. Till then, happy learning!
