Handling Imbalanced Datasets with scikit-learn

Imbalanced datasets are a common challenge in the field of machine learning and data science. They occur when the number of instances in one class significantly outnumbers the instances in another class, leading to a disproportionate ratio between the classes. This imbalance can have a negative impact on the performance of machine learning models, as they may become biased towards the majority class and fail to adequately capture the patterns of the minority class.

For example, in a binary classification problem, if we have 95% of samples belonging to class A and only 5% belonging to class B, our dataset is highly imbalanced. Such scenarios are typical in real-world applications like fraud detection, medical diagnosis, and churn prediction where the event of interest is rare compared to the normal occurrences.

Why is this a problem? Most learning algorithms minimize an average error over all training samples, which in practice rewards overall accuracy. With the 95/5 split above, a model that always predicts class A reaches 95% accuracy while never detecting class B. The accuracy figure looks good, but the model is not useful because it misses the minority class, which is often the class of interest.

Take, for instance, a dataset for fraud detection where the majority of transactions are non-fraudulent. A naive model might predict that all transactions are non-fraudulent to achieve high accuracy, but such a model would be useless in a real-world scenario because it fails to identify any fraudulent transactions.
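The effect is easy to reproduce. The sketch below is only an illustration: the 95/5 dataset is synthetic, and the names X_demo, y_demo, and baseline are made up for this example. It uses scikit-learn's DummyClassifier to stand in for the naive "always predict the majority class" model.

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score
from sklearn.model_selection import train_test_split

# Hypothetical 95/5 dataset: class 0 plays "legitimate", class 1 plays "fraud"
X_demo, y_demo = make_classification(n_samples=2000, n_features=10,
                                     weights=[0.95, 0.05], flip_y=0,
                                     random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.3,
                                          stratify=y_demo, random_state=0)

# A "model" that always predicts the most frequent training class
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X_tr, y_tr)
y_hat = baseline.predict(X_te)

baseline_accuracy = accuracy_score(y_te, y_hat)
baseline_recall = recall_score(y_te, y_hat)  # recall on class 1, the rare class

print("Accuracy:", baseline_accuracy)
print("Minority recall:", baseline_recall)
```

Accuracy comes out close to the majority share of the data, while recall on the rare class is zero. A baseline like this is a useful reference point: any real model should beat it on minority-class metrics, not just on accuracy.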

It is important to address dataset imbalance to prevent model bias and to ensure that the predictive performance of a model is genuine and reliable, especially for the minority class which often represents critical outcomes in a problem domain.

Understanding the nature and impact of imbalanced datasets is the first step towards handling them effectively. In the following sections, we will explore various techniques and strategies to manage imbalanced datasets and improve model performance using scikit-learn, a powerful machine learning library for Python.

Techniques for Handling Imbalanced Datasets

There are several techniques to handle imbalanced datasets effectively. These techniques can be broadly categorized into two groups: data-level methods and algorithm-level methods.

Data-level methods involve modifying the dataset to balance the class distribution before feeding it to the machine learning model. This can be achieved through resampling techniques, which include oversampling the minority class, undersampling the majority class, or a combination of both. A refinement of plain oversampling is to generate synthetic minority samples instead of duplicating existing ones, using algorithms like SMOTE (Synthetic Minority Over-sampling Technique).
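To make the idea of synthetic samples concrete, here is a rough sketch of the interpolation step at the core of SMOTE: pick a minority point, pick one of its k nearest minority neighbours, and place a new point somewhere on the line segment between them. This is a simplified illustration, not the imbalanced-learn implementation; the function smote_like_samples and the toy array X_min are hypothetical names used only here.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)

# Hypothetical minority-class points in two dimensions
X_min = rng.normal(loc=0.0, scale=1.0, size=(20, 2))

def smote_like_samples(X_minority, n_new, k=5, rng=None):
    """Create n_new points by interpolating between minority neighbours."""
    rng = np.random.default_rng() if rng is None else rng
    # Ask for k + 1 neighbours: with distinct points, column 0 is the point itself
    nn = NearestNeighbors(n_neighbors=k + 1).fit(X_minority)
    _, idx = nn.kneighbors(X_minority)
    # Choose a random base point and one of its k neighbours for each new sample
    base = rng.integers(0, len(X_minority), size=n_new)
    neighbour = idx[base, rng.integers(1, k + 1, size=n_new)]
    # Random position along the segment between base point and neighbour
    lam = rng.random((n_new, 1))
    return X_minority[base] + lam * (X_minority[neighbour] - X_minority[base])

X_new = smote_like_samples(X_min, n_new=30, k=5, rng=rng)
print(X_new.shape)
```

Because every new point lies between two existing minority points, the synthetic samples stay inside the region the minority class already occupies, which is what distinguishes this approach from simply duplicating rows.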

Algorithm-level methods involve modifying existing learning algorithms to make them more sensitive to the minority class. This could be through cost-sensitive learning where higher misclassification costs are assigned to the minority class, or by using ensemble methods like bagging and boosting which can improve model performance on imbalanced datasets.
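In scikit-learn, the most common form of cost-sensitive learning is the class_weight parameter. With class_weight='balanced', each class receives a weight of n_samples / (n_classes * count_of_that_class), so rarer classes count for more in the loss. The short sketch below checks that formula on a made-up 90/10 label vector (y_toy is a hypothetical name for this example).

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical labels: 90 samples of class 0 and 10 samples of class 1
y_toy = np.array([0] * 90 + [1] * 10)
classes = np.unique(y_toy)

# Weights as scikit-learn computes them for class_weight='balanced'
weights = compute_class_weight(class_weight="balanced", classes=classes, y=y_toy)

# The same heuristic written out by hand
manual = len(y_toy) / (len(classes) * np.bincount(y_toy))

print(dict(zip(classes.tolist(), weights.tolist())))
print(dict(zip(classes.tolist(), manual.tolist())))
```

Here an error on a minority sample costs nine times as much as an error on a majority sample, which mirrors the 9:1 class ratio.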

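To illustrate how these techniques can be implemented, let’s take a look at some code examples. The resampling examples rely on imbalanced-learn (imported as imblearn), a companion library that follows the scikit-learn API; the cost-sensitive and ensemble examples use scikit-learn itself: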

Oversampling the minority class:

from imblearn.over_sampling import RandomOverSampler
from collections import Counter
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Create a synthetic imbalanced classification dataset
# weights=[0.1, 0.9] makes class 0 the minority (about 10% of the samples)
X, y = make_classification(n_classes=2, class_sep=2,
                           weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
                           n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)

print('Original dataset shape %s' % Counter(y))

# Oversample the minority class
ros = RandomOverSampler(random_state=42)
X_res, y_res = ros.fit_resample(X, y)

print('Resampled dataset shape %s' % Counter(y_res))

Using SMOTE to generate synthetic samples:

from imblearn.over_sampling import SMOTE

# Apply SMOTE
sm = SMOTE(random_state=42)
X_res, y_res = sm.fit_resample(X, y)

print('Resampled dataset shape %s' % Counter(y_res))

Cost-sensitive learning:

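from sklearn.svm import SVC

# Hold out a test set first; X_train and y_train are reused in the examples below
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Train an SVM with class_weight='balanced'
svm = SVC(kernel='linear', class_weight='balanced', C=1.0, random_state=42)
svm.fit(X_train, y_train)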

Ensemble methods such as Random Forest with class_weight adjustment:

from sklearn.ensemble import RandomForestClassifier

# Train a RandomForestClassifier with class_weight='balanced'
rf = RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=42)
rf.fit(X_train, y_train)

By applying these techniques, we can mitigate the effects of imbalanced datasets and build models that are more robust and fair, giving due consideration to both majority and minority classes.

Implementing Resampling Methods with scikit-learn

Now that we have seen various techniques to handle imbalanced datasets, let’s delve deeper into implementing resampling methods. Resampling is among the simplest data-level methods for balancing the class distribution. It involves either adding more copies of instances from the minority class or removing instances from the majority class. Scikit-learn, in combination with the imbalanced-learn library, provides simple and effective tools for implementing these strategies.

Let’s start with implementing undersampling of the majority class. The idea is to reduce the number of instances from the majority class to make the dataset more balanced. We can achieve this using the RandomUnderSampler class from imbalanced-learn. Here’s how you can do it:

from imblearn.under_sampling import RandomUnderSampler

# Initialize the RandomUnderSampler
rus = RandomUnderSampler(random_state=42)

# Resample the dataset
X_rus, y_rus = rus.fit_resample(X, y)

print('Resampled dataset shape (undersampling) %s' % Counter(y_rus))

After applying undersampling, we can see that the number of instances in both classes is now equal, which means that the dataset is balanced.

Similarly, we can combine both oversampling and undersampling to balance the dataset. This technique is known as hybrid sampling. Here’s an example of how to implement this using SMOTEENN which is a combination of SMOTE and Edited Nearest Neighbors (ENN) technique:

from imblearn.combine import SMOTEENN

sme = SMOTEENN(random_state=42)
X_sme, y_sme = sme.fit_resample(X, y)

print('Resampled dataset shape (SMOTEENN) %s' % Counter(y_sme))

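By applying hybrid sampling, we not only add synthetic samples to the minority class but also clean the dataset afterwards. The ENN step removes samples whose label disagrees with that of their nearest neighbours, which are typically noisy points near the decision boundary. With SMOTEENN’s default settings this cleaning is applied to both classes, not only the majority class, so the resulting class counts are usually close to balanced rather than exactly equal.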

It’s important to note that when implementing these resampling techniques, they should only be applied to the training set and not to the test set. The test set should reflect the true distribution of the classes in the real-world scenario. Therefore, you should split your dataset into a training set and a test set before applying any resampling method:

# Split the dataset into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Apply resampling only to the training data
X_resampled, y_resampled = ros.fit_resample(X_train, y_train)
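One detail worth considering when splitting an imbalanced dataset is stratification. A purely random split can leave the test set with noticeably more or fewer minority samples than the full dataset. Passing stratify to train_test_split keeps the class ratio roughly the same in both parts. The sketch below uses its own synthetic data (X_s, y_s are hypothetical names) so that it runs on its own.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Hypothetical dataset in which class 1 makes up about 10% of the samples
X_s, y_s = make_classification(n_samples=1000, weights=[0.9, 0.1],
                               flip_y=0, random_state=10)

# stratify=y_s preserves the class proportions in both splits
X_tr, X_te, y_tr, y_te = train_test_split(X_s, y_s, test_size=0.3,
                                          stratify=y_s, random_state=42)

# Labels are 0/1, so the mean is the share of class 1
train_ratio = y_tr.mean()
test_ratio = y_te.mean()

print("Minority share in training set:", train_ratio)
print("Minority share in test set:", test_ratio)
```

The resampler is then fitted on the training portion only, exactly as in the snippet above, and the stratified test set remains an untouched sample of the original distribution.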

By using scikit-learn and imbalanced-learn, you can easily implement resampling methods to create a more balanced training set. This often improves how well a model detects the minority class, although the gain is not guaranteed and should always be confirmed on an untouched test set.

Evaluating Model Performance on Imbalanced Datasets

When it comes to evaluating model performance on imbalanced datasets, traditional metrics like accuracy may not be the best indicators of a model’s effectiveness. Instead, metrics that provide more insight into the classification of the minority class, such as precision, recall, and the F1 score, are more appropriate. The confusion matrix is also a valuable tool for assessing model performance in the context of imbalanced classes.

Let’s take a closer look at how these metrics can be calculated using scikit-learn:

from sklearn.metrics import confusion_matrix, classification_report

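# Assume `model` is a fitted classifier (for example, rf from above)
# and X_test, y_test are the held-out test set
y_pred = model.predict(X_test)

# Generate the confusion matrix (rows are true classes, columns are predicted classes)
conf_mat = confusion_matrix(y_test, y_pred)
print(conf_mat)

# Generate a classification report
class_report = classification_report(y_test, y_pred)
print(class_report)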

The confusion matrix provides a breakdown of the true positives, false positives, true negatives, and false negatives. This allows us to see how well the model is identifying the minority class, as opposed to just labeling everything as the majority class.
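For a binary problem, the four counts can be read directly from the matrix. In scikit-learn, rows correspond to true classes and columns to predicted classes, so flattening the 2x2 matrix yields the counts in the order TN, FP, FN, TP. The labels below are a small made-up example (y_true_toy and y_pred_toy are hypothetical names), with class 1 as the positive class.

```python
from sklearn.metrics import confusion_matrix

# Hypothetical labels: eight negatives (0) and two positives (1)
y_true_toy = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_pred_toy = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0]

# ravel() flattens the 2x2 matrix row by row: [[TN, FP], [FN, TP]]
tn, fp, fn, tp = confusion_matrix(y_true_toy, y_pred_toy).ravel()

print("TN:", tn, "FP:", fp, "FN:", fn, "TP:", tp)
```

In this toy case the model is 80% accurate overall, yet it finds only one of the two positives, which the FN count makes immediately visible.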

The classification report, on the other hand, gives us precision, recall, and F1 score for each class. Precision is the fraction of samples predicted as positive that really are positive, recall is the fraction of actual positive samples that the model finds, and the F1 score is the harmonic mean of precision and recall.
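These definitions can be written out by hand and compared with scikit-learn's own functions. The following sketch uses a small made-up label vector (all names are hypothetical) and treats class 1 as the positive class.

```python
from sklearn.metrics import f1_score, precision_score, recall_score

# Hypothetical labels: six negatives (0) and four positives (1)
y_true_toy = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
y_pred_toy = [0, 0, 0, 0, 0, 1, 1, 1, 0, 0]

# Count the outcomes for the positive class
pairs = list(zip(y_true_toy, y_pred_toy))
tp = sum(1 for t, p in pairs if t == 1 and p == 1)
fp = sum(1 for t, p in pairs if t == 0 and p == 1)
fn = sum(1 for t, p in pairs if t == 1 and p == 0)

# Metrics computed from their definitions
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

# The same metrics computed by scikit-learn
sk_precision = precision_score(y_true_toy, y_pred_toy)
sk_recall = recall_score(y_true_toy, y_pred_toy)
sk_f1 = f1_score(y_true_toy, y_pred_toy)

print("By hand:     ", precision, recall, f1)
print("scikit-learn:", sk_precision, sk_recall, sk_f1)
```

Because the harmonic mean is pulled toward the smaller of its two inputs, a model cannot reach a high F1 score by excelling at only one of precision and recall.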

Another useful metric is the Area Under the Receiver Operating Characteristic Curve (AUC-ROC). The ROC curve plots the true positive rate against the false positive rate at various threshold settings, and the AUC-ROC represents the probability that the model ranks a random positive instance higher than a random negative instance.

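from sklearn.metrics import roc_auc_score

# Calculate AUC-ROC from continuous scores rather than hard class labels.
# This assumes the model exposes predict_proba; otherwise use model.decision_function(X_test)
y_score = model.predict_proba(X_test)[:, 1]
roc_auc = roc_auc_score(y_test, y_score)
print('AUC-ROC:', roc_auc)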
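The ranking interpretation of AUC-ROC can be checked directly: compare every positive sample's score with every negative sample's score and count how often the positive one is higher, with ties counted as half. The sketch below does this on made-up scores (y_true_toy and scores_toy are hypothetical) and compares the result with roc_auc_score.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical true labels and model scores (higher means "more likely positive")
y_true_toy = np.array([0, 0, 0, 0, 1, 1])
scores_toy = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.9])

pos = scores_toy[y_true_toy == 1]
neg = scores_toy[y_true_toy == 0]

# Score difference for every (positive, negative) pair
diff = pos[:, None] - neg[None, :]

# Share of pairs where the positive is ranked higher; ties count as half
pairwise_auc = (diff > 0).mean() + 0.5 * (diff == 0).mean()
sklearn_auc = roc_auc_score(y_true_toy, scores_toy)

print("Pairwise estimate:", pairwise_auc)
print("roc_auc_score:", sklearn_auc)
```

Seven of the eight positive/negative pairs are ordered correctly here, so both values equal 0.875. This also shows why the metric needs scores rather than hard 0/1 predictions: with only two distinct values, most of the ranking information is lost.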

It is important to remember that no single metric can fully capture the performance of a model on an imbalanced dataset. It’s always advisable to look at a range of metrics to get a complete picture of how the model is performing.

When evaluating models on imbalanced datasets, it’s important to move beyond accuracy and consider metrics that focus on the model’s ability to predict the minority class. Tools like confusion matrices, classification reports, and the AUC-ROC score in scikit-learn make it possible to get a comprehensive understanding of model performance in the presence of class imbalance.

