
Why we use fit_transform() for Training Data but only transform() for Testing Data


When you start working with machine learning pipelines, one of the first things you’ll notice is that we often write code like this:

from sklearn.preprocessing import StandardScaler

# X_train and X_test are assumed to come from an earlier train/test split
scaler = StandardScaler()

# Training set
X_train_scaled = scaler.fit_transform(X_train)

# Test set
X_test_scaled = scaler.transform(X_test)
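The snippet above assumes `X_train` and `X_test` already exist. Here is a minimal self-contained sketch. It assumes synthetic NumPy data and scikit-learn's `train_test_split`, neither of which appears in the original, just so the two arrays exist:

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Hypothetical data (not from the original post): 100 samples, 3 features
rng = np.random.default_rng(0)
X = rng.normal(loc=50.0, scale=10.0, size=(100, 3))

# Split first, so the scaler never sees the test rows while fitting
X_train, X_test = train_test_split(X, test_size=0.2, random_state=42)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # learn mean/std from training data, then scale it
X_test_scaled = scaler.transform(X_test)        # reuse the training mean/std on the test data

print(X_train_scaled.shape, X_test_scaled.shape)  # (80, 3) (20, 3)
```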

At first glance, this can be confusing.
Why do we call fit_transform() on the training set, but only transform() on the test set?
Let’s break it down step by step:


1. What does fit() mean?

The fit() step is where the model or transformer learns parameters from data.
For example, StandardScaler calculates:

  • The mean of each feature in X_train

  • The standard deviation of each feature in X_train

These values are stored inside the scaler object (as the attributes mean_ and scale_) and later used to scale data.

Think of fit() as “learn from the data.”
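To see what fit() actually learns, here is a small sketch on a hypothetical four-row training set (the numbers are made up for illustration). It checks that the stored `mean_` and `scale_` match NumPy's column-wise mean and standard deviation. StandardScaler uses the population standard deviation (ddof=0), which is what `np.std` computes by default.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical training set: 4 samples, 2 features
X_train = np.array([[1.0, 10.0],
                    [2.0, 20.0],
                    [3.0, 30.0],
                    [4.0, 40.0]])

scaler = StandardScaler()
scaler.fit(X_train)  # only learns parameters; no scaled data is returned here

print(scaler.mean_)   # per-feature mean (2.5 and 25.0 for this data)
print(scaler.scale_)  # per-feature standard deviation (population, ddof=0)

# The learned values match NumPy's column-wise statistics
assert np.allclose(scaler.mean_, X_train.mean(axis=0))
assert np.allclose(scaler.scale_, X_train.std(axis=0))
```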


2. What does transform() mean?

The transform() step actually applies the learned parameters to scale or modify the dataset.

  • In StandardScaler, it subtracts the mean and divides by the standard deviation, both of which were already learned in fit().

Think of transform() as “apply what was learned.”
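A quick way to convince yourself that transform() only applies stored parameters is to reproduce it by hand. This sketch uses the same hypothetical training data as before plus two made-up new rows, and compares `transform()` with `(x - mean_) / scale_`:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

X_train = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])
X_new = np.array([[2.5, 25.0], [5.0, 50.0]])  # hypothetical new rows

scaler = StandardScaler().fit(X_train)  # parameters are learned once, here

# transform() applies (x - mean_) / scale_ using the stored training statistics
scaled = scaler.transform(X_new)
manual = (X_new - scaler.mean_) / scaler.scale_

assert np.allclose(scaled, manual)
print(scaled[0])  # this row equals the training mean, so it maps to 0 in every feature
```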


3. Why fit_transform() on Training Data?

When working with the training dataset:

  • We need to learn the parameters (mean and standard deviation for StandardScaler, min/max for MinMaxScaler, etc.)

  • And then apply them to scale the training data

Instead of calling fit() and then transform() separately, scikit-learn provides the shortcut fit_transform() to do both in one step.

So:

X_train_scaled = scaler.fit_transform(X_train)

This ensures the scaler learns the statistics of the training set only.
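As a sanity check, this sketch verifies on hypothetical random data that fit_transform() gives the same result as fit() followed by transform(), both for StandardScaler and for MinMaxScaler (the scaler that learns the min/max mentioned above):

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler, StandardScaler

rng = np.random.default_rng(1)
X_train = rng.uniform(0, 100, size=(20, 3))  # hypothetical training data

for Scaler in (StandardScaler, MinMaxScaler):
    one_step = Scaler().fit_transform(X_train)           # learn + apply in one call
    two_step = Scaler().fit(X_train).transform(X_train)  # the same work, in two calls
    assert np.allclose(one_step, two_step)
    print(Scaler.__name__, "fit_transform matches fit + transform")
```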


4. Why only transform() on Test Data?

When scaling the test dataset, we must not refit.
Here’s why:

  • The model should be evaluated on unseen data using the same transformation learned from training.

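  • If we call fit() on the test set, the scaler recalculates the mean and standard deviation from the test data. The test set is then scaled differently from the data the model was trained on, and statistics from data that should be unseen end up shaping the evaluation. That is “cheating”: information leaks from the test set into the pipeline.

  • This leakage makes performance metrics (accuracy, precision, etc.) unreliable. They no longer reflect how the model would perform on genuinely new data, and they can look better than they should.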

So instead, we just call:

X_test_scaled = scaler.transform(X_test)

This applies the training set parameters to the test data, keeping the process fair.
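The effect is easy to see when the test data differs from the training data. In this sketch the test set is deliberately shifted by 5 units (hypothetical synthetic data). Using transform() keeps that shift visible on the training scale, while refitting on the test set recenters it to zero and hides the difference from the model:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Hypothetical data: the test set is shifted upward by 5 units
X_train = rng.normal(loc=0.0, scale=1.0, size=(200, 1))
X_test = rng.normal(loc=5.0, scale=1.0, size=(50, 1))

scaler = StandardScaler().fit(X_train)

# Correct: reuse the training parameters, so the shift stays visible
correct = scaler.transform(X_test)

# Wrong: refitting recomputes mean/std from X_test itself
refit = StandardScaler().fit_transform(X_test)

print("transform only:", correct.mean())  # roughly 5 on the training scale
print("refit on test: ", refit.mean())    # about 0: the shift has been erased
```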


5. A Real-Life Analogy

Imagine you’re designing a suit:

  • First, you measure yourself (fit on training data).

  • Then, you tailor your clothes to those measurements (transform training data).

  • Later, when your friend borrows your suit (test data), nobody takes new measurements. They just wear the suit as it was already made (transform test data).

If you re-measured for each friend, the “model” (the suit) would keep changing, which defeats the purpose.


6. Summary

  • fit() → Learn parameters (mean, std, etc.) from training data.

  • transform() → Apply those learned parameters to scale data.

  • fit_transform() → Convenient combo of both steps (used on training set).

  • Never fit on test data → avoids information leakage and ensures fair evaluation.

Always remember:

  • Training data → fit_transform()

  • Test data → transform()
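In larger machine learning pipelines, scikit-learn's Pipeline can apply this rule for you: `fit()` calls fit_transform() on the scaler using the training data, and `predict()`/`score()` only call transform(). The sketch below assumes synthetic data and a LogisticRegression classifier, neither of which comes from the original post:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical binary classification data
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# fit(): the scaler is fit_transform()-ed on the training data only
# score(): the stored training parameters are applied to X_test via transform()
model = make_pipeline(StandardScaler(), LogisticRegression())
model.fit(X_train, y_train)
print("test accuracy:", model.score(X_test, y_test))

# The fitted scaler holds training statistics, not test statistics
fitted_scaler = model.named_steps["standardscaler"]
assert np.allclose(fitted_scaler.mean_, X_train.mean(axis=0))
```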