#define ENABLE_MODULE_DECISION_TREE
#include "Kinematrix.h"
// Global classifier instance, shared by setup() (training) and loop()
// (prediction). Constructor arguments, in declaration order:
//   max_features, max_samples, max_depth, min_samples_split, min_samples_leaf
DecisionTree tree(/*max_features=*/4, /*max_samples=*/50, /*max_depth=*/10,
                  /*min_samples_split=*/2, /*min_samples_leaf=*/1);
// One-time initialisation: open the serial port, load a small labelled Iris
// sample set into the global `tree`, train it, and report the outcome.
void setup() {
  Serial.begin(115200);

  // Training data for Iris classification, grouped by class label.
  // Each row holds four feature values (feature1..feature4); presumably the
  // usual Iris order sepal length/width, petal length/width -- TODO confirm.
  //
  // Rows are stored in named arrays instead of `(float[]){...}` compound
  // literals: compound literals are not standard C++, and g++ rejects a
  // non-const array compound literal decaying to a pointer inside a function
  // ("taking address of temporary array"). `static` storage also keeps the
  // rows valid after setup() returns, in case addTrainingData() retains the
  // pointer rather than copying the values (not visible from this sketch).
  static float setosa[][4] = {
      {5.1, 3.5, 1.4, 0.2},
      {4.9, 3.0, 1.4, 0.2},
      {4.7, 3.2, 1.3, 0.2},
  };
  static float versicolor[][4] = {
      {7.0, 3.2, 4.7, 1.4},
      {6.4, 3.2, 4.5, 1.5},
      {6.9, 3.1, 4.9, 1.5},
  };
  static float virginica[][4] = {
      {6.3, 3.3, 6.0, 2.5},
      {5.8, 2.7, 5.1, 1.9},
      {7.1, 3.0, 5.9, 2.1},
  };

  // Feed samples in the same order as before: all setosa rows, then
  // versicolor, then virginica.
  for (unsigned int i = 0; i < sizeof(setosa) / sizeof(setosa[0]); ++i) {
    tree.addTrainingData("setosa", setosa[i]);
  }
  for (unsigned int i = 0; i < sizeof(versicolor) / sizeof(versicolor[0]); ++i) {
    tree.addTrainingData("versicolor", versicolor[i]);
  }
  for (unsigned int i = 0; i < sizeof(virginica) / sizeof(virginica[0]); ++i) {
    tree.addTrainingData("virginica", virginica[i]);
  }

  Serial.print("Training samples: ");
  Serial.println(tree.getSampleCount());

  // Train the model; on failure, print the library's error message.
  if (tree.train(GINI, COST_COMPLEXITY)) {
    Serial.println("Training successful!");
    tree.printTreeStructure();
  } else {
    Serial.print("Training failed: ");
    Serial.println(tree.getErrorMessage());
  }
}
// Repeatedly classify one fixed query flower and print the predicted label
// together with the model's confidence expressed as a percentage.
void loop() {
  // Fixed query sample, four feature values in the same layout as training.
  float query[] = {5.0, 3.4, 1.5, 0.2};

  const char* label = tree.predictClass(query);
  float score = tree.getPredictionConfidence(query);
  float percent = score * 100;

  Serial.print("Prediction: ");
  Serial.print(label);
  Serial.print(" (confidence: ");
  Serial.print(percent);
  Serial.println("%)");

  // Pause five seconds between reports.
  delay(5000);
}