Kinematrix provides two powerful machine learning algorithms optimized for embedded systems: Decision Trees for interpretable models with mixed data types, and K-Nearest Neighbors for real-time classification.

ML Algorithms

Decision Tree

Information gain-based classifier with pruning and mixed data type support

K-Nearest Neighbors

Distance-based classification with multiple metrics and cross-validation

Decision Tree Classifier

Decision Trees create interpretable models by learning decision rules from training data. The Kinematrix implementation supports mixed data types, pruning, and cross-validation.

Basic Usage

#define ENABLE_MODULE_DECISION_TREE
#include "Kinematrix.h"

// DecisionTree(max_features, max_samples, max_depth, min_samples_split, min_samples_leaf)
DecisionTree tree(4, 50, 10, 2, 1);

void setup() {
  Serial.begin(115200);
  
  // Add training data for Iris classification
  // Format: (class_label, {feature1, feature2, feature3, feature4})
  
  // Iris setosa
  tree.addTrainingData("setosa", (float[]){5.1, 3.5, 1.4, 0.2});
  tree.addTrainingData("setosa", (float[]){4.9, 3.0, 1.4, 0.2});
  tree.addTrainingData("setosa", (float[]){4.7, 3.2, 1.3, 0.2});
  
  // Iris versicolor
  tree.addTrainingData("versicolor", (float[]){7.0, 3.2, 4.7, 1.4});
  tree.addTrainingData("versicolor", (float[]){6.4, 3.2, 4.5, 1.5});
  tree.addTrainingData("versicolor", (float[]){6.9, 3.1, 4.9, 1.5});
  
  // Iris virginica
  tree.addTrainingData("virginica", (float[]){6.3, 3.3, 6.0, 2.5});
  tree.addTrainingData("virginica", (float[]){5.8, 2.7, 5.1, 1.9});
  tree.addTrainingData("virginica", (float[]){7.1, 3.0, 5.9, 2.1});
  
  Serial.print("Training samples: ");
  Serial.println(tree.getSampleCount());
  
  // Train the model
  if (tree.train(GINI, COST_COMPLEXITY)) {
    Serial.println("Training successful!");
    tree.printTreeStructure();
  } else {
    Serial.print("Training failed: ");
    Serial.println(tree.getErrorMessage());
  }
}

void loop() {
  // Make predictions
  float sample[] = {5.0, 3.4, 1.5, 0.2};
  const char* prediction = tree.predictClass(sample);
  float confidence = tree.getPredictionConfidence(sample);
  
  Serial.print("Prediction: ");
  Serial.print(prediction);
  Serial.print(" (confidence: ");
  Serial.print(confidence * 100);
  Serial.println("%)");
  
  delay(5000);
}

Mixed Data Types

Decision Trees can handle numeric, categorical, ordinal, and binary features:

void setupMixedData() {
  DecisionTree tree(5, 100);
  
  // Configure feature types
  tree.setFeatureType(0, NUMERIC);       // Temperature
  tree.setFeatureType(1, CATEGORICAL);   // Weather type
  tree.setFeatureType(2, ORDINAL);       // Wind level (1-5)
  tree.setFeatureType(3, BINARY);        // Raining (yes/no)
  tree.setFeatureType(4, NUMERIC);       // Humidity
  
  // Set feature names
  tree.setFeatureName(0, "Temperature");
  tree.setFeatureName(1, "Weather");
  tree.setFeatureName(2, "Wind");
  tree.setFeatureName(3, "Rain");
  tree.setFeatureName(4, "Humidity");
  
  // Add categorical values
  tree.addCategoricalValue(1, "sunny");
  tree.addCategoricalValue(1, "cloudy");
  tree.addCategoricalValue(1, "overcast");
  
  // Add training data with mixed types
  FeatureValue features[5];
  features[0] = FeatureValue(25.5f);      // Numeric
  features[1] = FeatureValue("sunny");    // Categorical
  features[2] = FeatureValue(3, true);    // Ordinal
  features[3] = FeatureValue(false);      // Binary
  features[4] = FeatureValue(60.0f);      // Numeric
  
  tree.addTrainingData(features, "play");
}
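The feature type determines how a node tests a feature: numeric and ordinal features are ordered, so a split compares against a threshold, while categorical and binary features are unordered, so a split tests for equality. A schematic illustration of that distinction (plain C++ with hypothetical names, not the library's actual split logic):

```cpp
enum FType { NUMERIC_F, ORDINAL_F, CATEGORICAL_F, BINARY_F };

// Does a sample go down the left branch of a split? Ordered types use a
// threshold comparison; unordered types use equality against a category id.
bool goesLeft(FType type, float value, float splitPoint) {
    switch (type) {
        case NUMERIC_F:
        case ORDINAL_F:     return value <= splitPoint;  // ordered comparison
        case CATEGORICAL_F:
        case BINARY_F:      return value == splitPoint;  // equality test
    }
    return false;
}
```

This is why declaring feature types matters: treating a categorical code like "sunny = 0, cloudy = 1, overcast = 2" as numeric would let the tree learn meaningless threshold splits over arbitrary category ids.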

Training Options

// Split criteria
tree.train(GINI);          // Gini impurity (classification)
tree.train(ENTROPY);       // Information gain (classification)
tree.train(MSE);           // Mean squared error (regression)
tree.train(MAE);           // Mean absolute error (regression)

// Pruning methods
tree.train(GINI, NO_PRUNING);          // No pruning
tree.train(GINI, COST_COMPLEXITY);     // CART α-pruning
tree.train(GINI, REDUCED_ERROR);       // Reduced error pruning
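Gini impurity and entropy both measure how mixed the class labels at a node are; training picks the split that reduces them most. A standalone sketch of the two formulas (plain C++, not the library internals):

```cpp
#include <cmath>
#include <cstddef>

// Gini impurity: 1 - sum(p_i^2).
// 0 for a pure node, 0.5 for a 50/50 two-class node.
float gini(const float* probs, size_t n) {
    float sum = 0;
    for (size_t i = 0; i < n; i++) sum += probs[i] * probs[i];
    return 1.0f - sum;
}

// Entropy in bits: -sum(p_i * log2(p_i)).
// 0 for a pure node, 1 bit for a 50/50 two-class node.
float entropy(const float* probs, size_t n) {
    float h = 0;
    for (size_t i = 0; i < n; i++)
        if (probs[i] > 0) h -= probs[i] * std::log2(probs[i]);
    return h;
}
```

Both reach zero only when a node contains a single class; Gini is slightly cheaper to compute, which is one reason it is a common default on microcontrollers.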

Model Analysis

void analyzeModel() {
  // Print tree structure
  tree.printTreeStructure();
  
  // Print tree statistics
  tree.printTreeStatistics();
  
  // Print model summary
  tree.printModelSummary();
  
  // Get feature importance
  float* importance = tree.getFeatureImportance();
  Serial.println("Feature Importance:");
  for (int i = 0; i < 4; i++) {
    Serial.print("Feature ");
    Serial.print(i);
    Serial.print(": ");
    Serial.println(importance[i]);
  }
}

Cross-Validation

void validateModel() {
  // Perform 5-fold cross-validation
  float accuracy = tree.crossValidate(5);
  
  Serial.print("Cross-validation accuracy: ");
  Serial.print(accuracy * 100);
  Serial.println("%");
}

Model Persistence (ESP32)

#ifdef ESP32
void saveAndLoadModel() {
  // Save trained model
  if (tree.saveModel("/decision_tree.dat")) {
    Serial.println("Model saved to SPIFFS");
  }
  
  // Load model
  DecisionTree loadedTree(4, 50);
  if (loadedTree.loadModel("/decision_tree.dat")) {
    Serial.println("Model loaded successfully");
    
    // Use loaded model
    float sample[] = {5.0, 3.4, 1.5, 0.2};
    const char* pred = loadedTree.predictClass(sample);
    Serial.print("Prediction from loaded model: ");
    Serial.println(pred);
  }
}
#endif

K-Nearest Neighbors (KNN)

KNN is a simple yet powerful classification algorithm that finds the K closest training samples and uses majority voting to classify new data.
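Conceptually, each prediction reduces to three steps: measure the distance from the query to every stored sample, keep the K smallest, and take a majority vote over their labels. A minimal, library-independent sketch of that idea (plain C++, not the Kinematrix internals):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

struct Sample {
    std::vector<float> features;
    int label;
};

// Classify `query` by majority vote among the k nearest samples,
// using Euclidean distance. Illustrative only.
int knnPredict(const std::vector<Sample>& data,
               const std::vector<float>& query, size_t k) {
    std::vector<std::pair<float, int>> dist;  // (distance, label)
    for (const Sample& s : data) {
        float d = 0;
        for (size_t i = 0; i < query.size(); i++) {
            float diff = s.features[i] - query[i];
            d += diff * diff;
        }
        dist.push_back({std::sqrt(d), s.label});
    }
    // Keep only the k smallest distances at the front.
    std::partial_sort(dist.begin(), dist.begin() + k, dist.end());

    // Majority vote among the k nearest labels.
    int best = -1, bestCount = 0;
    for (size_t i = 0; i < k; i++) {
        int count = 0;
        for (size_t j = 0; j < k; j++)
            if (dist[j].second == dist[i].second) count++;
        if (count > bestCount) { bestCount = count; best = dist[i].second; }
    }
    return best;
}
```

Note that there is no training phase: all the work happens at prediction time, which is why KNN keeps every sample in RAM and why distance metric and normalization choices matter so much.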

Basic Usage

#define ENABLE_MODULE_KNN
#include "Kinematrix.h"

// KNN(k_neighbors, max_features, max_training_samples)
KNN knn(3, 4, 50);

void setup() {
  Serial.begin(115200);
  
  // Add training data
  // Iris setosa
  knn.addTrainingData("setosa", (float[]){5.1, 3.5, 1.4, 0.2});
  knn.addTrainingData("setosa", (float[]){4.9, 3.0, 1.4, 0.2});
  knn.addTrainingData("setosa", (float[]){4.7, 3.2, 1.3, 0.2});
  
  // Iris versicolor
  knn.addTrainingData("versicolor", (float[]){7.0, 3.2, 4.7, 1.4});
  knn.addTrainingData("versicolor", (float[]){6.4, 3.2, 4.5, 1.5});
  knn.addTrainingData("versicolor", (float[]){6.9, 3.1, 4.9, 1.5});
  
  // Iris virginica
  knn.addTrainingData("virginica", (float[]){6.3, 3.3, 6.0, 2.5});
  knn.addTrainingData("virginica", (float[]){5.8, 2.7, 5.1, 1.9});
  knn.addTrainingData("virginica", (float[]){7.1, 3.0, 5.9, 2.1});
  
  Serial.print("Training samples: ");
  Serial.println(knn.getDataCount());
  
  // Calculate feature ranges for normalization
  knn.calculateFeatureRanges();
}

void loop() {
  float sample[] = {5.0, 3.4, 1.5, 0.2};
  
  const char* prediction = knn.predict(sample);
  float confidence = knn.getPredictionConfidence(sample);
  
  Serial.print("Predicted: ");
  Serial.print(prediction);
  Serial.print(" (confidence: ");
  Serial.print(confidence * 100);
  Serial.println("%)");
  
  // Get nearest neighbors
  int indices[3];
  float distances[3];
  knn.getNearestNeighbors(sample, indices, distances, 3);
  
  Serial.println("Nearest neighbors:");
  for (int i = 0; i < 3; i++) {
    Serial.print("  ");
    Serial.print(i + 1);
    Serial.print(". Distance: ");
    Serial.println(distances[i]);
  }
  
  delay(5000);
}

Distance Metrics

void setupDistanceMetrics() {
  KNN knn(3, 4, 50);
  
  // Euclidean distance (default) - L2 norm
  knn.setDistanceMetric(EUCLIDEAN);
  
  // Manhattan distance - L1 norm, good for grid-like data
  knn.setDistanceMetric(MANHATTAN);
  
  // Cosine distance - good for text/vector similarity
  knn.setDistanceMetric(COSINE);
}
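The three metrics differ in what they measure: Euclidean penalizes large differences quadratically, Manhattan sums absolute differences, and cosine compares direction while ignoring magnitude. As plain C++ (an illustrative sketch, not the library's implementation):

```cpp
#include <cmath>
#include <cstddef>

// Euclidean (L2): straight-line distance. Sensitive to feature scale.
float euclidean(const float* a, const float* b, size_t n) {
    float sum = 0;
    for (size_t i = 0; i < n; i++) sum += (a[i] - b[i]) * (a[i] - b[i]);
    return std::sqrt(sum);
}

// Manhattan (L1): sum of absolute differences. Good for grid-like data.
float manhattan(const float* a, const float* b, size_t n) {
    float sum = 0;
    for (size_t i = 0; i < n; i++) sum += std::fabs(a[i] - b[i]);
    return sum;
}

// Cosine distance: 1 - cos(angle between vectors). 0 when the vectors
// point the same way, 1 when orthogonal. Undefined for zero vectors.
float cosineDist(const float* a, const float* b, size_t n) {
    float dot = 0, na = 0, nb = 0;
    for (size_t i = 0; i < n; i++) {
        dot += a[i] * b[i];
        na  += a[i] * a[i];
        nb  += b[i] * b[i];
    }
    return 1.0f - dot / (std::sqrt(na) * std::sqrt(nb));
}
```

Because Euclidean and Manhattan distances are scale-sensitive, enable normalization whenever features live on different ranges (e.g. temperature in °C vs raw ADC counts).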

Advanced Features

void setupAdvancedKNN() {
  KNN knn(5, 4, 100);
  
  // Enable weighted voting (closer neighbors have more influence)
  knn.setWeightedVoting(true);
  
  // Enable feature normalization (recommended)
  knn.enableNormalization(true);
  
  // Enable low memory mode (slower but uses less RAM)
  knn.setLowMemoryMode(true);
  
  // Enable debug output
  knn.setDebugMode(true);
  
  // Add training data...
  
  // Calculate feature ranges for normalization
  knn.calculateFeatureRanges();
}

Cross-Validation

void validateKNN() {
  // Perform 5-fold cross-validation
  float accuracy = knn.crossValidate(5);
  
  Serial.print("5-fold cross-validation accuracy: ");
  Serial.print(accuracy * 100);
  Serial.println("%");
}

Training Data Management

void manageTrainingData() {
  // Get total sample count
  int total = knn.getDataCount();
  
  // Get count by label
  int setosaCount = knn.getDataCountByLabel("setosa");
  
  // Remove specific sample
  knn.removeTrainingData(5);  // Remove sample at index 5
  
  // Clear all training data
  knn.clearTrainingData();
}

Complete Example: Gesture Recognition

#define ENABLE_MODULE_KNN
#include "Kinematrix.h"

const int ACCEL_X_PIN = 34;
const int ACCEL_Y_PIN = 35;
const int ACCEL_Z_PIN = 32;
const int BUTTON_PIN = 14;

// K=5, 3 features (x,y,z), max 30 gestures
KNN gestureClassifier(5, 3, 30);

bool trainingMode = false;
String currentGesture = "";

void setup() {
  Serial.begin(115200);
  pinMode(BUTTON_PIN, INPUT_PULLUP);
  
  // Configure KNN
  gestureClassifier.setDistanceMetric(EUCLIDEAN);
  gestureClassifier.enableNormalization(true);
  gestureClassifier.setWeightedVoting(true);
  
  // Load existing model or train new
  #ifdef ESP32
  if (!gestureClassifier.loadModel("/gestures.dat")) {
    Serial.println("No saved model, starting fresh");
  }
  #endif
  
  Serial.println("Gesture Recognition System");
  Serial.println("Commands:");
  Serial.println("  'train <gesture_name>' - Enter training mode");
  Serial.println("  'save' - Save model to SPIFFS");
  Serial.println("  'test' - Test recognition");
}

void loop() {
  // Check for serial commands
  if (Serial.available()) {
    String command = Serial.readStringUntil('\n');
    command.trim();
    
    if (command.startsWith("train ")) {
      currentGesture = command.substring(6);
      trainingMode = true;
      Serial.print("Training mode for gesture: ");
      Serial.println(currentGesture);
      Serial.println("Press button to capture samples");
    }
    else if (command == "save") {
      #ifdef ESP32
      if (gestureClassifier.saveModel("/gestures.dat")) {
        Serial.println("Model saved!");
      }
      #endif
    }
    else if (command == "test") {
      trainingMode = false;
      Serial.println("Recognition mode");
    }
  }
  
  // Button pressed - capture sample
  if (digitalRead(BUTTON_PIN) == LOW) {
    delay(50);  // Debounce
    
    float features[3];
    captureGestureSample(features);
    
    if (trainingMode) {
      // Add training sample
      gestureClassifier.addTrainingData(currentGesture.c_str(), features);
      Serial.print("Sample added. Total: ");
      Serial.println(gestureClassifier.getDataCount());
      
      // Recalculate ranges after adding data
      gestureClassifier.calculateFeatureRanges();
    }
    else {
      // Recognize gesture
      const char* prediction = gestureClassifier.predict(features);
      float confidence = gestureClassifier.getPredictionConfidence(features);
      
      Serial.print("Recognized: ");
      Serial.print(prediction);
      Serial.print(" (");
      Serial.print(confidence * 100);
      Serial.println("% confidence)");
    }
    
    // Wait for button release
    while (digitalRead(BUTTON_PIN) == LOW) delay(10);
  }
  
  delay(10);
}

void captureGestureSample(float* features) {
  // Capture accelerometer data
  // Average multiple readings for stability
  float xSum = 0, ySum = 0, zSum = 0;
  const int samples = 10;
  
  for (int i = 0; i < samples; i++) {
    xSum += analogRead(ACCEL_X_PIN);
    ySum += analogRead(ACCEL_Y_PIN);
    zSum += analogRead(ACCEL_Z_PIN);
    delay(10);
  }
  
  features[0] = xSum / samples;
  features[1] = ySum / samples;
  features[2] = zSum / samples;
  
  Serial.print("Captured: [");
  Serial.print(features[0]);
  Serial.print(", ");
  Serial.print(features[1]);
  Serial.print(", ");
  Serial.print(features[2]);
  Serial.println("]");
}

API Reference

Decision Tree

Constructor

DecisionTree(int maxFeatures, int maxSamples, int maxDepth = 10,
             int minSamplesSplit = 2, int minSamplesLeaf = 1)

Training

Method                                Description
addTrainingData(label, features)      Add classification sample
addTrainingData(value, features)      Add regression sample
train(criterion, pruning)             Train the model
clearTrainingData()                   Remove all samples
getSampleCount()                      Get total sample count

Prediction

Method                                Description
predictClass(features)                Predict class label
predictRegression(features)           Predict numeric value
getPredictionConfidence(features)     Get confidence score
getClassProbabilities(features)       Get probability distribution

Analysis

Method                                Description
getFeatureImportance()                Get importance scores
printTreeStructure()                  Visualize tree
printTreeStatistics()                 Print tree metrics
printModelSummary()                   Print model overview

Validation

Method                                Description
crossValidate(folds)                  K-fold cross-validation
evaluateAccuracy(testSamples, count)  Test set accuracy

K-Nearest Neighbors

Constructor

KNN(int k, int maxFeatures, int maxData)

Configuration

Method                                Description
setDistanceMetric(metric)             Set EUCLIDEAN, MANHATTAN, or COSINE
setWeightedVoting(weighted)           Enable distance-weighted voting
enableNormalization(enable)           Enable feature normalization
setLowMemoryMode(enable)              Trade speed for memory
setDebugMode(enable)                  Enable debug output

Training

Method                                Description
addTrainingData(label, features)      Add training sample
calculateFeatureRanges()              Compute normalization ranges
clearTrainingData()                   Remove all samples
removeTrainingData(index)             Remove specific sample
getDataCount()                        Get total sample count
getDataCountByLabel(label)            Get count for specific class

Prediction

Method                                                      Description
predict(features)                                           Predict class label
getPredictionConfidence(features)                           Get confidence score
getNearestNeighbors(features, indices, distances, count)    Get the K nearest neighbors

Validation

Method                                              Description
crossValidate(folds)                                K-fold cross-validation
evaluateAccuracy(testFeatures, testLabels, count)   Test set accuracy

Choosing K for KNN:
  • Small K (1-3): more sensitive to noise, better for complex boundaries
  • Medium K (5-7): a good balance for most applications
  • Large K (10+): smoother boundaries, more robust to noise
Rule of thumb: K ≈ √(number of samples)

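The rule of thumb above is easy to compute directly; rounding to an odd K also avoids tied votes in two-class problems. A small helper sketch (a hypothetical convenience function, not part of the library API):

```cpp
#include <cmath>

// Suggest K ≈ sqrt(sampleCount), rounded to the nearest odd number.
// Odd K prevents 50/50 voting ties in binary classification.
int suggestK(int sampleCount) {
    int k = (int)std::lround(std::sqrt((double)sampleCount));
    if (k < 1) k = 1;
    if (k % 2 == 0) k += 1;  // bump even K up to the next odd value
    return k;
}
```

Treat the result as a starting point and confirm it with crossValidate() on your own data.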
Memory Considerations:
  • Decision Trees grow with training data and tree depth
  • KNN stores all training samples in RAM
  • Enable low memory mode on resource-constrained devices
  • Use model persistence on ESP32 to avoid retraining
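Because KNN keeps every training sample in RAM, its feature-storage cost can be estimated up front. A back-of-the-envelope sketch (a lower bound only; actual per-sample overhead such as label strings and bookkeeping depends on the implementation):

```cpp
#include <cstddef>

// Rough lower bound on KNN training-data RAM: one float per feature per
// sample. Real usage adds label storage and bookkeeping on top of this.
size_t knnFeatureBytes(size_t maxSamples, size_t maxFeatures) {
    return maxSamples * maxFeatures * sizeof(float);
}
```

For the Iris-style configuration used earlier (50 samples × 4 features), that is at least 800 bytes of feature storage before labels, which is why low memory mode and ESP32 model persistence matter on small MCUs.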
