Stratified Split

In machine learning, a stratified split is a method for dividing a dataset into training and test sets. It is typically used when the target variable is categorical and its categories are imbalanced. The split is performed so that the proportion of each category in both the training and test sets matches its proportion in the original dataset.

Definition

A stratified split is a data-splitting method, most commonly used in machine learning, that preserves the class proportions of the original dataset in both the training and test sets. This is particularly useful when the data is imbalanced, i.e., when one class has many more samples than the others. For example, if 10% of the samples in the full dataset belong to the positive class, roughly 10% of the training set and 10% of the test set will as well.
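
The definition above can be sketched by hand with NumPy: shuffle the indices of each class separately, then send the same fraction of every class to the test set. This is only an illustrative sketch, not how any particular library implements it; the helper name stratified_indices and its parameters are hypothetical.

```python
import numpy as np

def stratified_indices(y, test_size=0.2, seed=0):
    """Return (train_idx, test_idx) so each class keeps roughly the same share."""
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for label in np.unique(y):
        idx = np.flatnonzero(y == label)   # positions of this class
        rng.shuffle(idx)                   # shuffle within the class only
        n_test = int(round(test_size * len(idx)))
        test_idx.extend(idx[:n_test])      # first n_test go to the test set
        train_idx.extend(idx[n_test:])     # the rest go to the training set
    return np.array(train_idx), np.array(test_idx)

# Toy target with a 9:1 imbalance (assumed data, for illustration only)
y = np.array([0] * 900 + [1] * 100)
train_idx, test_idx = stratified_indices(y, test_size=0.2)

train_share = y[train_idx].mean()  # share of class 1 in the training set
test_share = y[test_idx].mean()    # share of class 1 in the test set
print(train_share, test_share)
```

Because the per-class test count is rounded, a very small class can end up with no test samples in this sketch, which is one reason library implementations handle the allocation more carefully.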

Pros

  1. Maintains class distribution: This is the main advantage of stratified split. It ensures that the train and test sets have the same class distribution. This is particularly important when dealing with imbalanced datasets.

  2. More reliable performance metrics: Because the test set has the same class distribution as the full dataset, metrics computed on it are less affected by a class being over- or under-represented by chance. This matters most for rare classes.
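
The two points above can be checked empirically by repeating a plain random split and a stratified split over several seeds and comparing how much the minority-class share of the test set varies. The snippet below is a minimal sketch on assumed toy data; the helper minority_shares and the choice of 50 repeats are illustrative, not part of the original example.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data with a 9:1 imbalance; the features are placeholders
y = np.array([0] * 900 + [1] * 100)
X = np.zeros((len(y), 1))

def minority_shares(stratify, n_repeats=50):
    """Share of class 1 in the test set for each of n_repeats random seeds."""
    shares = []
    for seed in range(n_repeats):
        _, _, _, y_test = train_test_split(
            X, y, test_size=0.2, random_state=seed,
            stratify=y if stratify else None,
        )
        shares.append(y_test.mean())
    return np.array(shares)

random_shares = minority_shares(stratify=False)
stratified_shares = minority_shares(stratify=True)

print("random     min/max:", random_shares.min(), random_shares.max())
print("stratified min/max:", stratified_shares.min(), stratified_shares.max())
```

With stratification the minority share stays fixed across seeds, while the plain random split lets it drift from one seed to the next.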

Cons

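  1. Not suitable for all datasets: Stratified split does not apply directly to regression problems, where the target is continuous, or to multilabel classification problems, where each sample can carry several labels. It is also hard to apply to datasets with a very large number of classes or with classes that have very few samples, because every class needs enough samples to be represented in both the training and test sets.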

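  2. Requires careful handling of data: Samples should be shuffled within each class when splitting. If the data is ordered (for example, by collection time) and is split without shuffling, the resulting sets can be biased even though the class proportions match. Note that scikit-learn's train_test_split shuffles by default and does not support the stratify parameter together with shuffle=False.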
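
The limitation on classes with very few samples can be seen directly in scikit-learn: when a class has only one sample, it cannot appear in both the training and test sets, and train_test_split raises a ValueError if asked to stratify on it. The snippet below is a small sketch on assumed toy data; the flag split_failed exists only for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data: class 1 has a single sample (assumed for illustration)
X = np.zeros((11, 1))
y = np.array([0] * 10 + [1])

try:
    train_test_split(X, y, test_size=0.2, stratify=y, random_state=0)
    split_failed = False
except ValueError as err:
    # scikit-learn refuses to stratify when a class has too few members
    split_failed = True
    print("Stratified split failed:", err)
```

In such cases, possible options include collecting more samples of the rare class, merging rare classes, or falling back to a non-stratified split.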

Comparison to Other Methods

Stratified split is one of many methods to split a dataset into training and test sets. Other methods include random split, time-based split, and scaffold-based split. The best method depends on the specific task and data at hand. Compared to these methods, stratified split is specifically designed for classification problems where the class distribution is important.

Example in Python

In this example, we will use the train_test_split function from the sklearn.model_selection module with the stratify parameter. This function splits arrays or matrices into random train and test subsets. By specifying the stratify parameter, we can ensure that the class distribution in the train and test sets is the same as in the original data.

First, we need to install the necessary libraries. If you haven’t installed scikit-learn and Matplotlib, you can do so by running the following command (NumPy is installed automatically as a scikit-learn dependency):

!pip install scikit-learn matplotlib

Please note that this example assumes you have a dataset with a categorical target variable.

# Import necessary libraries
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

# Let's assume we have a dataset with 1000 samples and a binary target variable with a 9:1 imbalance
X = np.random.rand(1000, 10)
y = np.concatenate([np.ones(100), np.zeros(900)])

# Perform a stratified split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Check the class distribution in the train and test sets
train_counts = np.bincount(y_train.astype('int'))
test_counts = np.bincount(y_test.astype('int'))

# Plot the class distribution in the train and test sets
plt.figure(figsize=(8, 6))
plt.bar(
    [0, 1, 2, 3],
    np.concatenate([train_counts, test_counts]),
    tick_label=['Train Class 0', 'Train Class 1', 'Test Class 0', 'Test Class 1'],
)
plt.title('Class distribution in the train and test sets')
plt.ylabel('Count')
plt.show()
[Figure: bar chart of the class distribution in the train and test sets]
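
The bar chart shows raw counts, so it can also help to check the proportions numerically. The sketch below rebuilds the same kind of toy data and prints the class shares of the full, training, and test sets; the helper class_shares and the seeded random generator are assumptions added for this check.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Same setup as the example: 1000 samples, binary target with a 9:1 imbalance
rng = np.random.default_rng(0)
X = rng.random((1000, 10))
y = np.concatenate([np.ones(100, dtype=int), np.zeros(900, dtype=int)])

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

def class_shares(labels):
    """Fraction of samples in class 0 and class 1."""
    counts = np.bincount(labels, minlength=2)
    return counts / counts.sum()

for name, labels in [("full", y), ("train", y_train), ("test", y_test)]:
    print(name, class_shares(labels))
```

If the split is stratified as intended, the three printed rows should agree.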