Supervised Learning in scikit-learn
Overview
Supervised learning algorithms learn patterns from labeled training data to make predictions on new data. Scikit-learn organizes supervised learning into 17 major categories.
Linear Models
Regression
- LinearRegression: Ordinary least squares regression
- Ridge: L2-regularized regression, good for multicollinearity
- Lasso: L1-regularized regression, performs feature selection
- ElasticNet: Combined L1/L2 regularization
- LassoLars: Lasso using Least Angle Regression algorithm
- BayesianRidge: Bayesian ridge regression that estimates the regularization from the data (see ARDRegression for automatic relevance determination)
Classification
- LogisticRegression: Binary and multiclass classification
- RidgeClassifier: Ridge regression for classification
- SGDClassifier: Linear classifiers with SGD training
Use cases: Baseline models, interpretable predictions, high-dimensional data, when linear relationships are expected
Key parameters:
- alpha: Regularization strength (higher = more regularization)
- fit_intercept: Whether to calculate the intercept
- solver: Optimization algorithm ('lbfgs', 'saga', 'liblinear')
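As a quick illustration, here is a minimal sketch comparing ordinary least squares with its L2- and L1-regularized variants; the synthetic dataset and alpha values are arbitrary choices, not recommendations:

```python
# Compare OLS, Ridge (L2), and Lasso (L1) on the same synthetic data.
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=500, n_features=20, noise=10.0, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

for model in (LinearRegression(), Ridge(alpha=1.0), Lasso(alpha=1.0)):
    model.fit(X_train, y_train)
    print(type(model).__name__, round(model.score(X_test, y_test), 3))  # held-out R^2
```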
Support Vector Machines (SVM)
- SVC: Support Vector Classification
- SVR: Support Vector Regression
- LinearSVC: Linear SVM using liblinear (faster for large datasets)
- OneClassSVM: Unsupervised outlier detection
Use cases: Complex non-linear decision boundaries, high-dimensional spaces, when clear margin of separation exists
Key parameters:
- kernel: 'linear', 'poly', 'rbf', 'sigmoid'
- C: Regularization parameter (lower = more regularization)
- gamma: Kernel coefficient ('scale', 'auto', or float)
- degree: Polynomial degree (for poly kernel)
Performance tip: SVMs don't scale well beyond tens of thousands of samples; use LinearSVC for large datasets when a linear kernel suffices.
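A minimal sketch of an RBF-kernel SVC follows. The pipeline standardizes features first, since SVMs are sensitive to feature scale; the C and gamma values are illustrative defaults:

```python
# RBF-kernel SVC with feature scaling in a single pipeline.
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = make_pipeline(StandardScaler(), SVC(kernel="rbf", C=1.0, gamma="scale"))
clf.fit(X_train, y_train)
print(round(clf.score(X_test, y_test), 3))  # mean accuracy on the test split
```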
Decision Trees
- DecisionTreeClassifier: Classification tree
- DecisionTreeRegressor: Regression tree
- ExtraTreeClassifier/Regressor: Extremely randomized tree
Use cases: Non-linear relationships, feature importance analysis, interpretable rules, handling mixed data types
Key parameters:
- max_depth: Maximum tree depth (controls overfitting)
- min_samples_split: Minimum samples to split a node
- min_samples_leaf: Minimum samples in a leaf node
- max_features: Number of features to consider for splits
- criterion: 'gini', 'entropy' (classification); 'squared_error', 'absolute_error' (regression)
Overfitting prevention: Limit max_depth, increase min_samples_split/min_samples_leaf, or apply cost-complexity pruning via ccp_alpha.
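A minimal sketch of the overfitting-prevention point: an unconstrained tree versus a depth-limited one under cross-validation. The depth and leaf settings are starting points, not tuned values:

```python
# Unconstrained trees tend to memorize training data; limiting growth
# usually generalizes better.
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
unconstrained = DecisionTreeClassifier(random_state=0)
limited = DecisionTreeClassifier(max_depth=3, min_samples_leaf=5, random_state=0)

print(round(cross_val_score(unconstrained, X, y, cv=5).mean(), 3))
print(round(cross_val_score(limited, X, y, cv=5).mean(), 3))
```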
Ensemble Methods
Random Forests
- RandomForestClassifier: Ensemble of decision trees
- RandomForestRegressor: Regression variant
Use cases: Robust general-purpose algorithm, reduces overfitting vs single trees, handles non-linear relationships
Key parameters:
- n_estimators: Number of trees (higher = better but slower)
- max_depth: Maximum tree depth
- max_features: Features per split ('sqrt', 'log2', int, float)
- bootstrap: Whether to use bootstrap samples
- n_jobs: Parallel processing (-1 uses all cores)
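A minimal sketch of a parallelized forest, including the impurity-based feature importances; n_estimators is an arbitrary illustrative value:

```python
# Random forest trained on all cores, then inspect feature importances.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

rf = RandomForestClassifier(n_estimators=300, max_features="sqrt",
                            n_jobs=-1, random_state=0)
rf.fit(X_train, y_train)
print(round(rf.score(X_test, y_test), 3))
print(rf.feature_importances_[:5])  # impurity-based importances, first 5 features
```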
Gradient Boosting
- HistGradientBoostingClassifier/Regressor: Histogram-based, fast for large datasets (>10k samples)
- GradientBoostingClassifier/Regressor: Traditional implementation, better for small datasets
Use cases: High-performance predictions, winning Kaggle competitions, structured/tabular data
Key parameters:
- n_estimators: Number of boosting stages
- learning_rate: Shrinks the contribution of each tree
- max_depth: Maximum tree depth (typically 3-8)
- subsample: Fraction of samples per tree (enables stochastic gradient boosting)
- early_stopping: Stop when the validation score stops improving
Performance tip: HistGradientBoosting can be orders of magnitude faster than the traditional implementation on large datasets.
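A minimal sketch of the histogram-based estimator with its built-in early stopping (which holds out an internal validation split); the dataset size and learning rate are illustrative:

```python
# Histogram-based gradient boosting with early stopping.
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=50_000, n_features=30, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

gb = HistGradientBoostingClassifier(learning_rate=0.1, max_iter=500,
                                    early_stopping=True, random_state=0)
gb.fit(X_train, y_train)
print(gb.n_iter_, round(gb.score(X_test, y_test), 3))  # iterations actually run, accuracy
```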
AdaBoost
- AdaBoostClassifier/Regressor: Adaptive boosting
Use cases: Boosting weak learners into a strong ensemble; often resistant to overfitting in practice, though sensitive to noisy data and outliers
Key parameters:
- estimator: Base estimator (default: DecisionTreeClassifier with max_depth=1)
- n_estimators: Number of boosting iterations
- learning_rate: Weight applied to each classifier
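A minimal sketch using the default depth-1 decision stumps as weak learners; the iteration count and learning rate are arbitrary:

```python
# AdaBoost over decision stumps (the default base estimator).
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import AdaBoostClassifier
from sklearn.model_selection import cross_val_score

X, y = load_breast_cancer(return_X_y=True)
ada = AdaBoostClassifier(n_estimators=200, learning_rate=0.5, random_state=0)
print(round(cross_val_score(ada, X, y, cv=5).mean(), 3))
```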
Bagging
- BaggingClassifier/Regressor: Bootstrap aggregating with any base estimator
Use cases: Reducing variance of unstable models, parallel ensemble creation
Key parameters:
- estimator: Base estimator to fit
- n_estimators: Number of estimators
- max_samples: Samples to draw per estimator
- bootstrap: Whether to sample with replacement
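A minimal sketch bagging an unstable base estimator (a fully grown tree) to reduce variance. The estimator keyword assumes scikit-learn >= 1.2 (earlier versions call it base_estimator), and max_samples is illustrative:

```python
# Bootstrap aggregation over deep decision trees.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import BaggingClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
bag = BaggingClassifier(estimator=DecisionTreeClassifier(),
                        n_estimators=100, max_samples=0.8,
                        bootstrap=True, n_jobs=-1, random_state=0)
print(round(cross_val_score(bag, X, y, cv=5).mean(), 3))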
Voting & Stacking
- VotingClassifier/Regressor: Combines different model types
- StackingClassifier/Regressor: Meta-learner trained on base predictions
Use cases: Combining diverse models, leveraging different model strengths
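A minimal stacking sketch: three deliberately different base models feed a logistic-regression meta-learner. The choice of base estimators is arbitrary; scale-sensitive models are wrapped with a scaler:

```python
# Stack a linear model, a forest, and an SVM behind a meta-learner.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier, StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_breast_cancer(return_X_y=True)
stack = StackingClassifier(
    estimators=[
        ("lr", make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))),
        ("rf", RandomForestClassifier(n_estimators=100, random_state=0)),
        ("svc", make_pipeline(StandardScaler(), SVC())),
    ],
    final_estimator=LogisticRegression(),
)
print(round(cross_val_score(stack, X, y, cv=5).mean(), 3))
```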
Neural Networks
- MLPClassifier: Multi-layer perceptron classifier
- MLPRegressor: Multi-layer perceptron regressor
Use cases: Complex non-linear patterns, when gradient boosting is too slow, deep feature learning
Key parameters:
- hidden_layer_sizes: Tuple of hidden layer sizes (e.g., (100, 50))
- activation: 'relu', 'tanh', 'logistic'
- solver: 'adam', 'lbfgs', 'sgd'
- alpha: L2 regularization term
- learning_rate: Learning rate schedule
- early_stopping: Stop when the validation score stops improving
Important: Feature scaling is critical for neural networks. Always use StandardScaler or similar.
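A minimal sketch that bakes the scaling requirement into a pipeline; the layer sizes and other settings are illustrative starting points:

```python
# MLP behind a StandardScaler so inputs are standardized before training.
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

mlp = make_pipeline(
    StandardScaler(),
    MLPClassifier(hidden_layer_sizes=(100, 50), activation="relu",
                  solver="adam", alpha=1e-4, early_stopping=True,
                  max_iter=500, random_state=0),
)
mlp.fit(X_train, y_train)
print(round(mlp.score(X_test, y_test), 3))
```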
Nearest Neighbors
- KNeighborsClassifier/Regressor: K-nearest neighbors
- RadiusNeighborsClassifier/Regressor: Radius-based neighbors
- NearestCentroid: Classification using class centroids
Use cases: Simple baseline, irregular decision boundaries, when interpretability isn't critical
Key parameters:
- n_neighbors: Number of neighbors (typically 3-11)
- weights: 'uniform' or 'distance' (distance-weighted voting)
- metric: Distance metric ('euclidean', 'manhattan', 'minkowski')
- algorithm: 'auto', 'ball_tree', 'kd_tree', 'brute'
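A minimal sketch of distance-weighted k-NN; scaling matters here too because the model is distance-based, and k=5 is just a common default:

```python
# Distance-weighted k-NN behind a scaler.
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)
knn = make_pipeline(StandardScaler(),
                    KNeighborsClassifier(n_neighbors=5, weights="distance"))
print(round(cross_val_score(knn, X, y, cv=5).mean(), 3))
```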
Naive Bayes
- GaussianNB: Assumes Gaussian distribution of features
- MultinomialNB: For discrete counts (text classification)
- BernoulliNB: For binary/boolean features
- CategoricalNB: For categorical features
- ComplementNB: Adapted for imbalanced datasets
Use cases: Text classification, fast baseline, when features are independent, small training sets
Key parameters:
- alpha: Smoothing parameter (Laplace/Lidstone smoothing)
- fit_prior: Whether to learn class prior probabilities
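A minimal sketch of the classic pairing: MultinomialNB on bag-of-words counts. The toy corpus is made up purely for illustration:

```python
# Text classification with counts + MultinomialNB.
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

texts = ["the rocket launched into orbit", "engine thrust and fuel tanks",
         "the team scored a late goal", "midfielder passes and shots on goal"]
labels = [0, 0, 1, 1]  # 0 = space, 1 = sports (toy labels)

nb = make_pipeline(CountVectorizer(), MultinomialNB(alpha=1.0))
nb.fit(texts, labels)
print(nb.predict(["fuel for the rocket engine"]))  # expected: [0]
```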
Linear/Quadratic Discriminant Analysis
- LinearDiscriminantAnalysis: Linear decision boundary with dimensionality reduction
- QuadraticDiscriminantAnalysis: Quadratic decision boundary
Use cases: When classes have Gaussian distributions, dimensionality reduction, when covariance assumptions hold
Gaussian Processes
- GaussianProcessClassifier: Probabilistic classification
- GaussianProcessRegressor: Probabilistic regression with uncertainty estimates
Use cases: When uncertainty quantification is important, small datasets, smooth function approximation
Key parameters:
- kernel: Covariance function (RBF, Matern, RationalQuadratic, etc.)
- alpha: Noise level
Limitation: Doesn't scale well to large datasets (training cost grows as O(n³) in the number of samples)
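A minimal sketch of the uncertainty-quantification use case: fit a noisy sine curve and ask for per-point predictive standard deviations. The kernel and noise level are illustrative:

```python
# GP regression returning mean and uncertainty for query points.
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(40, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=40)

gp = GaussianProcessRegressor(kernel=RBF(length_scale=1.0), alpha=0.01)
gp.fit(X, y)
mean, std = gp.predict([[2.5], [9.0]], return_std=True)
print(mean, std)  # predictive mean and standard deviation per query point
```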
Stochastic Gradient Descent
- SGDClassifier: Linear classifiers with SGD
- SGDRegressor: Linear regressors with SGD
Use cases: Very large datasets (>100k samples), online learning, when data doesn't fit in memory
Key parameters:
- loss: Loss function ('hinge', 'log_loss', 'squared_error', etc.)
- penalty: Regularization ('l2', 'l1', 'elasticnet')
- alpha: Regularization strength
- learning_rate: Learning rate schedule
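A minimal sketch of the out-of-core pattern: partial_fit is fed the data in chunks, as you would when streaming batches that don't fit in memory (here the chunking is simulated with array_split):

```python
# Simulated out-of-core training with partial_fit.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier

X, y = make_classification(n_samples=100_000, n_features=20, random_state=0)
clf = SGDClassifier(loss="log_loss", penalty="l2", alpha=1e-4, random_state=0)

classes = np.unique(y)  # partial_fit must be told all classes up front
for X_chunk, y_chunk in zip(np.array_split(X, 10), np.array_split(y, 10)):
    clf.partial_fit(X_chunk, y_chunk, classes=classes)
print(round(clf.score(X, y), 3))
```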
Semi-Supervised Learning
- SelfTrainingClassifier: Self-training with any base classifier
- LabelPropagation: Label propagation through graph
- LabelSpreading: Label spreading (modified label propagation)
Use cases: When labeled data is scarce but unlabeled data is abundant
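A minimal sketch of self-training where most labels are hidden; -1 marks unlabeled samples, per the scikit-learn convention, and the 70% masking rate is arbitrary:

```python
# Self-training: the base classifier labels its own confident predictions.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.semi_supervised import SelfTrainingClassifier

X, y = load_iris(return_X_y=True)
rng = np.random.default_rng(0)
y_partial = y.copy()
y_partial[rng.random(len(y)) < 0.7] = -1  # hide ~70% of the labels

self_train = SelfTrainingClassifier(LogisticRegression(max_iter=1000))
self_train.fit(X, y_partial)
print(round(self_train.score(X, y), 3))  # scored against the true labels
```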
Feature Selection
- VarianceThreshold: Remove low-variance features
- SelectKBest: Select K highest scoring features
- SelectPercentile: Select top percentile of features
- RFE: Recursive feature elimination
- RFECV: RFE with cross-validation
- SelectFromModel: Select features based on importance
- SequentialFeatureSelector: Forward/backward feature selection
Use cases: Reducing dimensionality, removing irrelevant features, improving interpretability, reducing overfitting
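A minimal sketch of model-based selection with SelectFromModel, keeping only the features whose L1-penalized logistic coefficients survive the penalty; C=0.1 is an illustrative strength:

```python
# Select features via an L1-penalized linear model.
from sklearn.datasets import load_breast_cancer
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import LogisticRegression

X, y = load_breast_cancer(return_X_y=True)
selector = SelectFromModel(
    LogisticRegression(penalty="l1", solver="liblinear", C=0.1)
)
X_reduced = selector.fit_transform(X, y)
print(X.shape[1], "->", X_reduced.shape[1])  # feature count before/after
```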
Probability Calibration
- CalibratedClassifierCV: Calibrate classifier probabilities
Use cases: When probability estimates are important (not just class predictions), especially with SVM and Naive Bayes
Methods:
- sigmoid: Platt scaling
- isotonic: Isotonic regression (more flexible, needs more data)
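A minimal sketch calibrating an SVM (which has no native predict_proba) into usable probabilities via cross-validated Platt scaling:

```python
# Sigmoid calibration of a linear SVM's decision scores.
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

base = make_pipeline(StandardScaler(), LinearSVC())
calibrated = CalibratedClassifierCV(base, method="sigmoid", cv=5)
calibrated.fit(X_train, y_train)
print(calibrated.predict_proba(X_test)[:3])  # calibrated class probabilities
```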
Multi-Output Methods
- MultiOutputClassifier: Fit one classifier per target
- MultiOutputRegressor: Fit one regressor per target
- ClassifierChain: Models dependencies between targets
- RegressorChain: Regression variant
Use cases: Predicting multiple related targets simultaneously
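A minimal sketch of the one-estimator-per-target wrapper on a synthetic problem with three targets; the base regressor is an arbitrary choice:

```python
# Fit one Ridge regressor per target.
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.multioutput import MultiOutputRegressor

X, y = make_regression(n_samples=500, n_features=10, n_targets=3, random_state=0)
multi = MultiOutputRegressor(Ridge(alpha=1.0))
multi.fit(X, y)
print(multi.predict(X[:2]).shape)  # (2, 3): one prediction per target
```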
Specialized Regression
- IsotonicRegression: Monotonic regression
- QuantileRegressor: Quantile regression for prediction intervals
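A minimal sketch of the prediction-interval idea: two quantile fits bracket a rough 90% interval. QuantileRegressor assumes scikit-learn >= 1.0, and the quantile levels are illustrative:

```python
# Bracket a rough 90% prediction interval with two quantile fits.
from sklearn.datasets import make_regression
from sklearn.linear_model import QuantileRegressor

X, y = make_regression(n_samples=500, n_features=5, noise=20.0, random_state=0)
lower = QuantileRegressor(quantile=0.05, alpha=0.0).fit(X, y)
upper = QuantileRegressor(quantile=0.95, alpha=0.0).fit(X, y)
print(lower.predict(X[:1]), upper.predict(X[:1]))  # interval endpoints
```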
Algorithm Selection Guidelines
Start with:
- Logistic Regression (classification) or LinearRegression/Ridge (regression) as baseline
- RandomForestClassifier/Regressor for general non-linear problems
- HistGradientBoostingClassifier/Regressor when best performance is needed
Consider dataset size:
- Small (<1k samples): SVM, Gaussian Processes, any algorithm
- Medium (1k-100k): Random Forests, Gradient Boosting, Neural Networks
- Large (>100k): SGD, HistGradientBoosting, LinearSVC
Consider interpretability needs:
- High interpretability: Linear models, Decision Trees, Naive Bayes
- Medium: Random Forests (feature importance), Rule extraction
- Low (black box acceptable): Gradient Boosting, Neural Networks, SVM with RBF kernel
Consider training time:
- Fast: Linear models, Naive Bayes, Decision Trees
- Medium: Random Forests (parallelizable), SVM (small data)
- Slow: Gradient Boosting, Neural Networks, SVM (large data), Gaussian Processes