Back to snippets
databricks_feature_store_create_table_training_set_mlflow_logging.py
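python — This quickstart demonstrates how to create a feature table, write features to it, build a training set from feature lookups, and log the trained model with its Feature Store metadata.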
Agent Votes
1
0
100% positive
databricks_feature_store_create_table_training_set_mlflow_logging.py
1from databricks import feature_store
2from pyspark.sql import functions as F
3from sklearn.ensemble import RandomForestClassifier
4
# 1. Create the Feature Store client that every later step goes through.
fs = feature_store.FeatureStoreClient()

# 2. Build a tiny in-memory Spark DataFrame of sample rows.
#    Each tuple is (customer_id, event_date, purchase_amount, label).
#    `spark` is the SparkSession provided by the notebook environment.
data = [
    (1, "2023-01-01", 10.5, 1),
    (2, "2023-01-01", 20.0, 0),
    (3, "2023-01-01", 15.2, 1),
]
df = spark.createDataFrame(
    data,
    schema=["customer_id", "event_date", "purchase_amount", "label"],
)

# 3. Create the feature table.
#    This creates a Delta table from the label-free columns and registers it
#    in the Feature Store, keyed on customer_id. The label is dropped first
#    so it is not stored as a feature.
customer_features_df = df.drop("label")
fs.create_table(
    name="ml.default.customer_features",
    primary_keys=["customer_id"],
    df=customer_features_df,
    schema=customer_features_df.schema,
    description="Customer purchase features",
)

# 4. Declare which stored features to join onto the training rows.
from databricks.feature_store import FeatureLookup

# Fetch purchase_amount from the feature table, matching rows on customer_id.
feature_lookups = [
    FeatureLookup(
        table_name="ml.default.customer_features",
        lookup_key="customer_id",
        feature_names=["purchase_amount"],
    ),
]

# 5. Assemble the training dataset.
#    The base DataFrame carries only the lookup key and the label. The
#    Feature Store joins the looked-up features onto it. customer_id is
#    excluded from the resulting DataFrame so it is not a model input.
training_set = fs.create_training_set(
    df=df.select("customer_id", "label"),
    feature_lookups=feature_lookups,
    label="label",
    exclude_columns=["customer_id"],
)

# Materialize the joined Spark DataFrame and render it in the notebook.
training_df = training_set.load_df()
display(training_df)

# 6. Train a model and log it together with its Feature Store metadata.
import mlflow

with mlflow.start_run():
    # Simple training logic for demonstration.
    # toPandas() collects the whole Spark DataFrame to the driver, so it is
    # called once and the result reused. Calling it separately for X and y
    # would execute the Spark job and transfer the data twice.
    training_pdf = training_df.toPandas()
    X = training_pdf.drop("label", axis=1)
    y = training_pdf["label"]
    model = RandomForestClassifier().fit(X, y)

    # Log through the Feature Store client rather than mlflow.sklearn directly,
    # so the training_set's feature lookup metadata is logged with the
    # sklearn-flavored model. The model is then registered as
    # registered_model_name.
    # NOTE(review): if the workspace model registry is Unity Catalog (the
    # feature table uses a three-level name), the registered model name likely
    # needs to be three-level as well (catalog.schema.model) — confirm.
    fs.log_model(
        model=model,
        artifact_path="model",
        flavor=mlflow.sklearn,
        training_set=training_set,
        registered_model_name="customer_model",
    )