
Class Signature

class StratifiedKFold {
  constructor(options?: StratifiedKFoldOptions)
  split<TX>(X: TX[], y: number[]): FoldIndices[]
}

Constructor

options
StratifiedKFoldOptions
Configuration options for Stratified K-Fold cross-validation

Methods

split

Generate train/test indices to split data into k stratified folds.
split<TX>(X: TX[], y: number[]): FoldIndices[]
X
TX[]
required
Feature array to split
y
number[]
required
Target class labels. Each fold will preserve the percentage of samples for each class.
Returns: Array of FoldIndices objects, each containing:
  • trainIndices: number[] - Indices for the training set
  • testIndices: number[] - Indices for the test set
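Based on the return description above, the FoldIndices shape can be sketched as a plain TypeScript interface (an illustration inferred from this reference, not copied from the library's source):

```typescript
// Hypothetical sketch of the FoldIndices shape, inferred from the
// return description above (not the library's actual declaration):
interface FoldIndices {
  trainIndices: number[]; // indices for the training set
  testIndices: number[];  // indices for the test set
}

// With 6 samples and nSplits = 3, one fold might look like:
const fold: FoldIndices = {
  trainIndices: [0, 1, 3, 4],
  testIndices: [2, 5],
};
```

Every sample index appears in exactly one of the two arrays for a given fold, so their lengths always sum to the dataset size.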

Description

StratifiedKFold is a variation of K-Fold that returns stratified folds: each fold contains approximately the same percentage of samples of each target class as the complete set. This is particularly useful for:
  • Imbalanced classification problems - ensures each fold has representative samples from all classes
  • Small datasets - maximizes the use of minority class samples
  • Fair model evaluation - prevents biased performance metrics
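One common way to achieve this stratification is to group sample indices by class and deal each group round-robin across the folds. This is a minimal sketch of the idea only, not the library's actual implementation (which also handles shuffling and edge cases):

```typescript
// Illustrative sketch of stratified fold assignment: group indices by
// class, then deal each class's indices round-robin into k test folds,
// so every fold's class proportions mirror the full dataset.
function stratifiedTestFolds(y: number[], k: number): number[][] {
  // Group sample indices by class label
  const byClass = new Map<number, number[]>();
  y.forEach((label, i) => {
    const bucket = byClass.get(label) ?? [];
    bucket.push(i);
    byClass.set(label, bucket);
  });

  // Deal each class's indices round-robin across the k folds
  const folds: number[][] = Array.from({ length: k }, () => []);
  for (const indices of byClass.values()) {
    indices.forEach((idx, pos) => folds[pos % k].push(idx));
  }
  return folds;
}

// 8 samples split 50/50 across two classes, k = 4:
// every test fold receives exactly one index from each class.
const demoFolds = stratifiedTestFolds([0, 0, 0, 0, 1, 1, 1, 1], 4);
```

Because each class is distributed independently, a class with c samples contributes either ⌊c/k⌋ or ⌈c/k⌉ samples to every fold, which is what keeps the per-fold class proportions close to the overall distribution.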

Example

import { StratifiedKFold, LogisticRegression } from 'bun-scikit';

// Imbalanced dataset: 80% class 0, 20% class 1
const X = [
  [1, 2], [2, 3], [3, 4], [4, 5], [5, 6],
  [6, 7], [7, 8], [8, 9], [9, 10], [10, 11]
];
const y = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1];

// Create stratified 2-fold cross-validator
// (nSplits cannot exceed the smallest class count, which is 2 here)
const skf = new StratifiedKFold({ nSplits: 2, shuffle: true, randomState: 42 });
const folds = skf.split(X, y);

// Each fold maintains ~80/20 class distribution
for (let i = 0; i < folds.length; i++) {
  const fold = folds[i];
  const yTest = fold.testIndices.map(idx => y[idx]);
  const class0Count = yTest.filter(label => label === 0).length;
  const class1Count = yTest.filter(label => label === 1).length;
  
  console.log(`Fold ${i + 1}: ${class0Count} class 0, ${class1Count} class 1`);
}

// Perform stratified cross-validation
const scores: number[] = [];

for (const fold of folds) {
  const XTrain = fold.trainIndices.map(i => X[i]);
  const yTrain = fold.trainIndices.map(i => y[i]);
  const XTest = fold.testIndices.map(i => X[i]);
  const yTest = fold.testIndices.map(i => y[i]);

  const model = new LogisticRegression();
  model.fit(XTrain, yTrain);
  const score = model.score(XTest, yTest);
  scores.push(score);
}

const avgScore = scores.reduce((a, b) => a + b, 0) / scores.length;
console.log('Average accuracy:', avgScore);

Comparison with KFold

import { KFold, StratifiedKFold } from 'bun-scikit';

const X = [[1], [2], [3], [4], [5], [6], [7], [8]];
const y = [0, 0, 0, 0, 1, 1, 1, 1]; // 50/50 balanced

// Regular KFold - may create imbalanced folds
const kf = new KFold({ nSplits: 4 });
const regularFolds = kf.split(X, y);

// StratifiedKFold - guarantees balanced folds
const skf = new StratifiedKFold({ nSplits: 4 });
const stratifiedFolds = skf.split(X, y);

// Each stratified fold will have exactly 1 sample from each class
// Regular folds might have 2 from one class, 0 from another

Multi-class Example

import { StratifiedKFold } from 'bun-scikit';

// Multi-class problem with 3 classes
const X = [
  [1], [2], [3], [4], [5], [6],
  [7], [8], [9], [10], [11], [12]
];
const y = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];

const skf = new StratifiedKFold({ nSplits: 3, shuffle: true });
const folds = skf.split(X, y);

// Each fold's test set draws near-equal samples from all 3 classes
folds.forEach((fold, idx) => {
  const yTest = fold.testIndices.map(i => y[i]);
  console.log(`Fold ${idx + 1} test set:`, yTest);
  // Each 4-label test set includes every class, e.g. [0, 0, 1, 2]
});

Notes

  • Requires at least 2 distinct classes in y
  • nSplits cannot exceed the number of samples in the smallest class
  • The target array y must contain integer class labels
  • When shuffle=true, samples within each class are shuffled before distribution to folds
  • For regression problems or when class distribution is not important, use KFold instead
  • This is the recommended cross-validator for classification tasks
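The first two notes above imply a simple pre-flight check: count samples per class and verify nSplits fits the smallest class. The sketch below illustrates that check; `canStratify` is a hypothetical helper, and the library's own validation may differ:

```typescript
// Hedged sketch of the validity check implied by the notes:
// stratification needs at least 2 classes, and nSplits must not
// exceed the smallest class count (each fold needs >= 1 sample
// per class in its test set).
function canStratify(y: number[], nSplits: number): boolean {
  const counts = new Map<number, number>();
  for (const label of y) {
    counts.set(label, (counts.get(label) ?? 0) + 1);
  }
  if (counts.size < 2) return false; // need at least 2 distinct classes
  const smallest = Math.min(...counts.values());
  return nSplits <= smallest;
}
```

For example, the imbalanced dataset from the first example (8 samples of class 0, 2 of class 1) supports at most 2 stratified folds.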
