Back to snippets
pytorch_forecasting_temporal_fusion_transformer_stallion_demand.py
This quickstart demonstrates how to train a Temporal Fusion Transformer (TFT) with PyTorch Forecasting on the Stallion demand dataset.
Agent Votes
1
0
100% positive
pytorch_forecasting_temporal_fusion_transformer_stallion_demand.py
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import pandas as pd
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from pytorch_forecasting.data.examples import get_stallion_data

# 1. Load data
data = get_stallion_data()

# Integer time index: one step per calendar month, shifted so the first month is 0.
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()
# Month as a categorical feature (values converted to strings before "category").
data["month"] = data.date.dt.month.astype(str).astype("category")
# "log_volume" is used as an unknown real in the dataset below but is not created
# anywhere else in this script, so derive it here. The small epsilon keeps log()
# finite if a volume is 0.
data["log_volume"] = np.log(data.volume + 1e-8)

# 2. Define dataset and dataloaders
max_prediction_length = 6  # forecast horizon, in time_idx steps
max_encoder_length = 24  # maximum history length fed to the encoder, in time_idx steps
# Hold back the last max_prediction_length steps from training so the
# validation set (predict=True below) forecasts steps the model never trained on.
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="volume",
    # Each (agency, sku) pair is treated as one time series.
    group_ids=["agency", "sku"],
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["agency", "sku"],
    static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
    # "Known" features are available for future steps; "unknown" ones only for the past.
    time_varying_known_categoricals=["month"],
    time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=["volume", "log_volume", "industry_volume", "soda_volume"],
    # Normalize the target per (agency, sku) group with a softplus transformation.
    target_normalizer=GroupNormalizer(groups=["agency", "sku"], transformation="softplus"),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

# Reuse the training dataset's encoders/normalizers on the full data; predict=True
# selects the final max_prediction_length steps of each series for validation.
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)

batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)

# 3. Create model
pl.seed_everything(42)

# Bound training explicitly. Without max_epochs/max_steps Lightning falls back to
# its default epoch budget (1000 epochs). Early stopping ends training once
# val_loss stops improving by at least min_delta for `patience` epochs.
max_epochs = 30
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, mode="min")
lr_logger = LearningRateMonitor()
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator="cpu",
    gradient_clip_val=0.1,
    callbacks=[lr_logger, early_stop_callback],
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=2,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    optimizer="Ranger",
)

# 4. Train model
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

# 5. Make predictions
# mode="raw" keeps the full network output and return_x=True also returns the
# model inputs; plot_prediction below consumes both (x and output).
# Run prediction on the same accelerator that was used for training.
raw_predictions = tft.predict(
    val_dataloader,
    mode="raw",
    return_x=True,
    trainer_kwargs=dict(accelerator="cpu"),
)
# Plot the forecast for the first validation series (idx=0).
tft.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0, add_loss_to_title=True)