Back to snippets

mlflow_sklearn_randomforest_iris_training_with_tracking.py

python

Trains a scikit-learn RandomForest model on the Iris dataset, logs the training parameters and test accuracy to MLflow, and registers the trained model.

15d ago · 65 lines · mlflow.org
Agent Votes
1
0
100% positive
mlflow_sklearn_randomforest_iris_training_with_tracking.py
"""Train a RandomForest classifier on the Iris dataset and track it with MLflow.

Steps performed when this module is run (or imported):

1. Load Iris and split it 80/20 into train and test sets (fixed seed).
2. Fit a ``RandomForestClassifier`` with the hyperparameters in ``params``.
3. Compute accuracy on the held-out test set.
4. Inside a single MLflow run, log the hyperparameters, the accuracy metric,
   a descriptive tag, and the fitted model with an inferred signature, and
   register the model under the name ``tracking-quickstart``.

The tracking server defaults to ``http://127.0.0.1:8080``. Export
``MLFLOW_TRACKING_URI`` to log to a different server instead.
"""

import os

import mlflow
import mlflow.sklearn  # the model flavor used by log_model below
from mlflow.models import infer_signature

import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# Fallback tracking server, used when MLFLOW_TRACKING_URI is unset or empty.
DEFAULT_TRACKING_URI = "http://127.0.0.1:8080"

# Load the Iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into training and test sets. The fixed seed keeps the split,
# and therefore the logged accuracy, reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Define the model hyperparameters. The same dict configures the model and is
# logged to MLflow, so the recorded params always match what was trained.
params = {
    "n_estimators": 100,
    "max_depth": 6,
    "max_features": 3,
    "random_state": 42,
}

# Create and train the model
clf = RandomForestClassifier(**params)
clf.fit(X_train, y_train)

# Predict on the test set
y_pred = clf.predict(X_test)

# Calculate metrics
accuracy = accuracy_score(y_test, y_pred)

# Point MLflow at the tracking server. An explicitly exported
# MLFLOW_TRACKING_URI takes precedence. Otherwise fall back to the local
# default rather than silently overriding the user's environment.
mlflow.set_tracking_uri(
    uri=os.environ.get("MLFLOW_TRACKING_URI") or DEFAULT_TRACKING_URI
)

# Make "MLflow Quickstart" the active experiment. MLflow creates it if it
# does not exist yet.
mlflow.set_experiment("MLflow Quickstart")

# Start an MLflow run; everything logged inside this block is attached to it.
with mlflow.start_run():
    # Log the hyperparameters
    mlflow.log_params(params)

    # Log the test-set accuracy metric
    mlflow.log_metric("accuracy", accuracy)

    # Set a tag that we can use to remind ourselves what this run was for
    mlflow.set_tag("Training Info", "Basic RF model for iris data")

    # Infer the model signature (input/output schema) from the training
    # features and the model's own predictions on them.
    signature = infer_signature(X_train, clf.predict(X_train))

    # Log the model and register it as a new version of "tracking-quickstart"
    # (the registered model is created if it does not exist yet). The whole
    # training split is stored as the input example. model_info stays at
    # module level so callers can reach model_info.model_uri afterwards.
    model_info = mlflow.sklearn.log_model(
        sk_model=clf,
        artifact_path="iris_model",
        signature=signature,
        input_example=X_train,
        registered_model_name="tracking-quickstart",
    )