chronos-forecasting/test
Lorenzo Stella d2eef92009
Force context scaling and quantization in float32, add assertions to tests (#197)
*Issue #, if available:* Fixes #193

*Description of changes:* Passing in contexts with a precision lower than
float32 (such as bfloat16 or float16) may reduce forecast accuracy. This
change ensures that the tokenizer, which performs scaling and quantization,
always operates on a float32 batch.
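Below is a minimal sketch of the idea. It assumes a mean-scaling tokenizer that divides each series by its mean absolute value and then assigns the scaled values to uniform bins. The function name `mean_scale_quantize`, the bin count and the `[-15, 15]` limits are illustrative assumptions and not the actual `chronos` API. The relevant step is the cast to float32 before any arithmetic.

```python
import torch


def mean_scale_quantize(
    context: torch.Tensor,
    n_bins: int = 4094,   # illustrative assumption
    low: float = -15.0,   # illustrative assumption
    high: float = 15.0,   # illustrative assumption
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: mean scaling + uniform-bin quantization.

    The batch is cast to float32 first, so scaling and binning never run
    in bfloat16/float16 even if the caller passed a lower-precision context.
    """
    context = context.to(dtype=torch.float32)

    # One scale per series: mean absolute value over non-NaN entries
    observed = ~torch.isnan(context)
    abs_sum = torch.nansum(context.abs(), dim=-1)
    scale = abs_sum / observed.sum(dim=-1).clamp(min=1)
    # Avoid dividing by zero for all-zero series
    scale = torch.where(scale > 0, scale, torch.ones_like(scale))
    scaled = context / scale.unsqueeze(-1)

    # Uniform bin centers in [low, high]; boundaries are the midpoints
    centers = torch.linspace(low, high, n_bins)
    boundaries = (centers[1:] + centers[:-1]) / 2
    token_ids = torch.bucketize(scaled, boundaries)  # int64 ids in [0, n_bins - 1]
    return token_ids, scale


context = torch.tensor([[112.0, 118.0, 132.0, 129.0]], dtype=torch.bfloat16)
token_ids, scale = mean_scale_quantize(context)
# scale is float32 even though the context is bfloat16
```

Casting up front keeps the mean and the division in float32 for bfloat16 or float16 inputs. It cannot recover precision that was already lost when the context was stored in a lower-precision dtype.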

Tested with models on CPU and GPU and with different context and model dtypes, using:

```python
from itertools import product

import pandas as pd
import torch
from chronos import ChronosPipeline

import matplotlib.pyplot as plt  # requires: pip install matplotlib
import numpy as np

df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")

for context_dtype, context_device, model_dtype, model_device in product(
    [torch.bfloat16, torch.float16, torch.float32],
    ["cpu"],  # only cpu input supported at the moment
    [torch.bfloat16, torch.float16, torch.float32],
    ["cpu", "cuda"],  # "cuda" requires a CUDA-capable GPU
):
    pipeline = ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-tiny",
        device_map=model_device,
        torch_dtype=model_dtype,
    )

    forecast = pipeline.predict(
        context=torch.tensor(df["#Passengers"]).to(dtype=context_dtype, device=context_device),
        prediction_length=65,
        num_samples=20,
        limit_prediction_length=False,
    )

    assert forecast.dtype == context_dtype, f"{forecast.dtype=} but {context_dtype=}"
    assert str(forecast.device) == context_device, f"{forecast.device=} but {context_device=}"

    forecast_index = range(len(df), len(df) + 65)
    low, median, high = np.quantile(forecast[0].to(device="cpu", dtype=torch.float32).numpy(), [0.1, 0.5, 0.9], axis=0)

    plt.figure(figsize=(8, 4))
    plt.plot(df["#Passengers"], color="royalblue", label="historical data")
    plt.plot(forecast_index, median, color="tomato", label="median forecast")
    plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
    plt.legend()
    plt.grid()
    plt.show()
```
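The two assertions in the script check that forecasts come back in the same dtype and on the same device as the input context. Below is a minimal sketch of that contract. It assumes a hypothetical helper `float32_roundtrip` that wraps any forecasting function, here a toy last-value forecaster (also hypothetical) standing in for the model.

```python
from typing import Callable

import torch


def float32_roundtrip(
    context: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
    """Hypothetical helper: run `fn` on a float32 copy of `context`,
    then cast the result back to the context's original dtype and device."""
    result = fn(context.to(dtype=torch.float32))
    return result.to(dtype=context.dtype, device=context.device)


def last_value_forecast(ctx: torch.Tensor, horizon: int = 3) -> torch.Tensor:
    """Toy stand-in for a model: repeat the last observed value `horizon` times."""
    return ctx[..., -1:].expand(*ctx.shape[:-1], horizon).clone()


context = torch.tensor([112.0, 118.0, 132.0], dtype=torch.bfloat16)
forecast = float32_roundtrip(context, last_value_forecast)
# Same checks as in the script above
assert forecast.dtype == context.dtype
assert forecast.device == context.device
```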


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
2024-11-18 09:55:54 +01:00
dummy-chronos-model Upload code 2024-03-13 09:58:39 +01:00
test_chronos.py Force context scaling and quantization in float32, add assertions to tests (#197) 2024-11-18 09:55:54 +01:00