train_model_teamwins() reads the team_game_stats table, fits a RandomForestClassifier, prints test accuracy, and saves the model to disk. You must run this before making any predictions.
The team_game_stats table must exist before training. Run generate_features_teamwins() first to populate it.

How to train

1. Ensure the feature table is ready

Run the data pipeline to fetch games and generate rolling-window features:
from Data.fetch_games import fetch_games_teamwins
from Data.generate_features import generate_features_teamwins

fetch_games_teamwins("2025-26")
generate_features_teamwins()

2. Call train_model_teamwins()

Import and call the training function:
from prediction_ai import train_model_teamwins

train_model_teamwins()
# Test Accuracy: 0.XXX  (value depends on season data)
The function prints test accuracy to stdout and writes the model file to models/nba_model.pkl.

Model architecture

The classifier is a RandomForestClassifier from scikit-learn:
model = RandomForestClassifier(n_estimators=200, random_state=0)
| Parameter | Value | Description |
| --- | --- | --- |
| n_estimators | 200 | Number of decision trees in the ensemble |
| random_state | 0 | Seed for reproducibility |

Train/test split

The dataset is split before fitting:
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
| Parameter | Value | Description |
| --- | --- | --- |
| test_size | 0.3 | 30% of rows held out for evaluation |
| random_state | 42 | Seed for reproducibility |
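
To see how the split and classifier settings fit together, here is a minimal sketch on synthetic data. The random feature matrix and the label rule are assumptions made purely for illustration; they are not the contents of team_game_stats, and this is not the body of train_model_teamwins().

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in data (an assumption): 200 rows, 6 numeric features.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 6))
# Hypothetical label rule: "win" when the first feature is positive.
y = (X[:, 0] > 0).astype(int)

# Same split settings as documented: 30% held out, fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# Same classifier settings as documented.
model = RandomForestClassifier(n_estimators=200, random_state=0)
model.fit(X_train, y_train)

# Accuracy is measured only on the held-out rows.
test_accuracy = accuracy_score(y_test, model.predict(X_test))
print(f"Rows: train={len(X_train)}, test={len(X_test)}")
print(f"Test Accuracy: {test_accuracy:.3f}")
```

Because both seeds are fixed, repeated runs on the same input produce the same split and the same fitted trees, which makes the printed accuracy comparable between runs.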

Input features

The model uses six numeric columns from team_game_stats:
| Feature | Type | Description |
| --- | --- | --- |
| points_diff | float | Rolling average points differential (team minus opponent) |
| team_reb_roll | float | Team’s rolling average rebounds |
| opponent_reb_roll | float | Opponent’s rolling average rebounds |
| team_ast_roll | float | Team’s rolling average assists |
| opponent_ast_roll | float | Opponent’s rolling average assists |
| home | int | 1 if the team is at home, 0 if away |

Target variable

The target column is win, a binary integer (1 = team won, 0 = team lost). The feature matrix X and the target y are selected from the DataFrame as follows:
X = df[['points_diff', 'team_reb_roll', 'opponent_reb_roll',
        'team_ast_roll', 'opponent_ast_roll', 'home']]
y = df['win']
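
If the feature table is incomplete, the column selection above raises a pandas KeyError. One possible way to fail earlier with a clearer message is sketched below. The helper split_features_target and the two example rows are hypothetical and are not part of prediction_ai.py.

```python
import pandas as pd

FEATURE_COLUMNS = [
    "points_diff", "team_reb_roll", "opponent_reb_roll",
    "team_ast_roll", "opponent_ast_roll", "home",
]
TARGET_COLUMN = "win"


def split_features_target(df):
    """Hypothetical helper: return (X, y), failing early on missing columns."""
    required = FEATURE_COLUMNS + [TARGET_COLUMN]
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(f"team_game_stats is missing columns: {missing}")
    return df[FEATURE_COLUMNS], df[TARGET_COLUMN]


# Two made-up rows, for illustration only.
example = pd.DataFrame({
    "points_diff": [4.2, -1.5],
    "team_reb_roll": [44.0, 41.3],
    "opponent_reb_roll": [42.1, 45.0],
    "team_ast_roll": [25.4, 22.8],
    "opponent_ast_roll": [23.9, 26.1],
    "home": [1, 0],
    "win": [1, 0],
})

X, y = split_features_target(example)
print(X.shape, y.tolist())
```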

Model persistence

After fitting, the model is serialized with pickle to models/nba_model.pkl relative to prediction_ai.py:
os.makedirs(os.path.dirname(MODELPATH), exist_ok=True)
pickle.dump(model, open(MODELPATH, "wb"))
The models/ directory is created automatically if it does not exist.
Calling train_model_teamwins() again overwrites the existing models/nba_model.pkl file. There is no automatic backup.
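
Since retraining overwrites the file, you may want to copy the existing model aside first. The sketch below is one way to do that with the standard library. The backup_model helper and its timestamped naming scheme are assumptions, not something prediction_ai.py provides, and the demo uses a temporary placeholder file rather than a real model.

```python
import os
import shutil
import tempfile
import time


def backup_model(model_path):
    """Hypothetical helper: copy model_path to a timestamped sibling file.

    Returns the backup path, or None when there is nothing to back up.
    """
    if not os.path.exists(model_path):
        return None
    stamp = time.strftime("%Y%m%d-%H%M%S")
    root, ext = os.path.splitext(model_path)
    backup_path = f"{root}.{stamp}{ext}"
    shutil.copy2(model_path, backup_path)  # copies contents and metadata
    return backup_path


# Demo against a temporary placeholder file (not a real model).
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "nba_model.pkl")
    with open(demo_path, "wb") as fh:
        fh.write(b"placeholder bytes")
    copy_path = backup_model(demo_path)
    copy_exists = os.path.exists(copy_path)
    missing_result = backup_model(os.path.join(tmp, "absent.pkl"))
```

In practice you would call backup_model() with the path to models/nba_model.pkl before calling train_model_teamwins(); the exact path to pass depends on where prediction_ai.py lives relative to your working directory.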
