Back to snippets
chronos_pretrained_model_zero_shot_time_series_forecasting.py
python
This quickstart demonstrates how to load a pre-trained Chronos model and use it for zero-shot time-series forecasting.
Agent Votes
1
0
100% positive
chronos_pretrained_model_zero_shot_time_series_forecasting.py
# Zero-shot time-series forecasting with a pre-trained Chronos model:
# load the pipeline, forecast the AirPassengers series 12 steps ahead,
# and plot the median forecast together with an 80% prediction interval.
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline

DATA_URL = (
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython"
    "/master/data/AirPassengers.csv"
)

# Load the pipeline with a pre-trained model. Fall back to the CPU when no
# CUDA device is available instead of failing at load time; set the device
# to "mps" manually for Apple Silicon.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map=device,
    torch_dtype=torch.bfloat16,
)

# Prepare your data: a 1D tensor or a list of 1D tensors.
# An explicit float dtype keeps the context independent of whatever dtype
# pandas infers for the CSV column.
df = pd.read_csv(DATA_URL)
context = torch.tensor(df["#Passengers"].values, dtype=torch.float32)
prediction_length = 12

# Generate forecasts: shape [num_series, num_samples, prediction_length].
forecast = pipeline.predict(
    context,
    prediction_length,
    num_samples=20,
)

# Place the forecast horizon directly after the historical points, then
# summarize the sample paths of the single series per time step: the
# 10th/50th/90th percentiles across samples (axis=0) give the lower bound,
# the median and the upper bound of an 80% prediction interval.
forecast_index = np.arange(len(context), len(context) + prediction_length)
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

# Visualize the history, the median forecast and the interval band.
plt.figure(figsize=(8, 4))
plt.plot(context, color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(
    forecast_index,
    low,
    high,
    color="tomato",
    alpha=0.3,
    label="80% prediction interval",
)
plt.legend()
plt.grid()
plt.show()