diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..0c8cba5 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,25 @@ +name: CI + +on: [push, pull_request] + +jobs: + test: + strategy: + max-parallel: 4 + fail-fast: false + matrix: + python-version: ['3.11'] + platform: [ubuntu-latest] + + runs-on: ${{ matrix.platform }} + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: pip install ".[test]" + - name: Test with pytest + run: pytest diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8cf79d6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# macOS stuff +.DS_store \ No newline at end of file diff --git a/README.md b/README.md index 847260c..cc45129 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,87 @@ -## My Project +# Chronos: Learning the Language of Time Series -TODO: Fill this README out! +Chronos is a family of **pretrained time series forecasting models** based on language model architectures. A time series is transformed into a sequence of tokens via scaling and quantization, and a language model is trained on these tokens using the cross-entropy loss. Once trained, probabilistic forecasts are obtained by sampling multiple future trajectories given the historical context. Chronos models have been trained on a large corpus of publicly available time series data, as well as synthetic data generated using Gaussian processes. -Be sure to: +For details on Chronos models, training data and procedures, and experimental results, please refer to the paper [Chronos: Learning the Language of Time Series](https://arxiv.org/abs/2403.07815). -* Change the title in this README -* Edit your repository description on GitHub +

+ +
+ + Fig. 1: High-level depiction of Chronos. (Left) The input time series is scaled and quantized to obtain a sequence of tokens. (Center) The tokens are fed into a language model which may either be an encoder-decoder or a decoder-only model. The model is trained using the cross-entropy loss. (Right) During inference, we autoregressively sample tokens from the model and map them back to numerical values. Multiple trajectories are sampled to obtain a predictive distribution. + +

+ +--- + +## Architecture + +The models in this repository are based on the [T5 architecture](https://arxiv.org/abs/1910.10683). The only difference is in the vocabulary size: Chronos-T5 models use 4096 different tokens, compared to 32128 of the original T5 models, resulting in fewer parameters. + +| Model | Parameters | Based on | +| ---------------------------------------------------------------------- | ---------- | ---------------------------------------------------------------------- | +| [**chronos-t5-tiny**](https://huggingface.co/amazon/chronos-t5-tiny) | 8M | [t5-efficient-tiny](https://huggingface.co/google/t5-efficient-tiny) | +| [**chronos-t5-mini**](https://huggingface.co/amazon/chronos-t5-mini) | 20M | [t5-efficient-mini](https://huggingface.co/google/t5-efficient-mini) | +| [**chronos-t5-small**](https://huggingface.co/amazon/chronos-t5-small) | 46M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) | +| [**chronos-t5-base**](https://huggingface.co/amazon/chronos-t5-base) | 200M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) | +| [**chronos-t5-large**](https://huggingface.co/amazon/chronos-t5-large) | 710M | [t5-efficient-large](https://huggingface.co/google/t5-efficient-large) | + +## Usage + +To perform inference with Chronos models, install this package by running: + +``` +pip install git+https://github.com/amazon-science/chronos-forecasting.git +``` + +A minimal example showing how to perform inference using Chronos models: + +```python +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +from chronos import ChronosPipeline + +pipeline = ChronosPipeline.from_pretrained( + "amazon/chronos-t5-small", + device_map="cuda", + torch_dtype=torch.bfloat16, +) + +df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv") + +# context must be either a 1D tensor, a list of 1D tensors, +# or a left-padded 2D tensor with batch as the first dimension +context = torch.tensor(df["#Passengers"]) +prediction_length = 12 +forecast = pipeline.predict(context, prediction_length) # shape [num_series, num_samples, prediction_length] + +# visualize the forecast +forecast_index = range(len(df), len(df) + prediction_length) +low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0) + +plt.figure(figsize=(8, 4)) +plt.plot(df["#Passengers"], color="royalblue", label="historical data") +plt.plot(forecast_index, median, color="tomato", label="median forecast") +plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval") +plt.legend() +plt.grid() +plt.show() +``` + +## Citation + +If you find Chronos models useful for your research, please consider citing the associated [paper](https://arxiv.org/abs/2403.07815): + +``` +@article{ansari2024chronos, + author = {Ansari, Abdul Fatir and Stella, Lorenzo and Turkmen, Caner and Zhang, Xiyuan, and Mercado, Pedro and Shen, Huibin and Shchur, Oleksandr and Rangapuram, Syama Syndar and Pineda Arango, Sebastian and Kapoor, Shubham and Zschiegner, Jasper and Maddix, Danielle C. and Mahoney, Michael W. and Torkkola, Kari and Gordon Wilson, Andrew and Bohlke-Schneider, Michael and Wang, Yuyang}, + title = {Chronos: Learning the Language of Time Series}, + journal = {arXiv preprint arXiv:2403.07815}, + year = {2024} +} +``` ## Security @@ -14,4 +90,3 @@ See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more inform ## License This project is licensed under the Apache-2.0 License. - diff --git a/figures/main-figure.png b/figures/main-figure.png new file mode 100644 index 0000000..329b890 Binary files /dev/null and b/figures/main-figure.png differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c268579 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "chronos" +version = "1.0.0" +requires-python = ">=3.8" +license = {file = "LICENSE"} +dependencies = [ + "torch~=2.1", # package was tested on 2.2 + "transformers~=4.31", + "accelerate" +] + +[project.optional-dependencies] +test = [ + "pytest~=8.0", + "numpy~=1.21" +] + +[tool.mypy] +ignore_missing_imports = true diff --git a/src/chronos/__init__.py b/src/chronos/__init__.py new file mode 100644 index 0000000..4474e8e --- /dev/null +++ b/src/chronos/__init__.py @@ -0,0 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from .chronos import ( + ChronosConfig, + ChronosModel, + ChronosPipeline, + ChronosTokenizer, + MeanScaleUniformBins, +) + +__all__ = [ + "ChronosConfig", + "ChronosModel", + "ChronosPipeline", + "ChronosTokenizer", + "MeanScaleUniformBins", +] diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py new file mode 100644 index 0000000..efa6bbf --- /dev/null +++ b/src/chronos/chronos.py @@ -0,0 +1,424 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import warnings +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + GenerationConfig, + PreTrainedModel, +) + + +@dataclass +class ChronosConfig: + """ + This class holds all the configuration parameters to be used + by ``ChronosTokenizer`` and ``ChronosModel``. + """ + + tokenizer_class: str + tokenizer_kwargs: Dict[str, Any] + n_tokens: int + n_special_tokens: int + pad_token_id: int + eos_token_id: int + use_eos_token: bool + model_type: Literal["causal", "seq2seq"] + context_length: int + prediction_length: int + num_samples: int + temperature: float + top_k: int + top_p: float + + def __post_init__(self): + assert ( + self.pad_token_id < self.n_special_tokens + and self.eos_token_id < self.n_special_tokens + ), f"Special token id's must be smaller than {self.n_special_tokens=}" + + def create_tokenizer(self) -> "ChronosTokenizer": + if self.tokenizer_class == "MeanScaleUniformBins": + return MeanScaleUniformBins(**self.tokenizer_kwargs, config=self) + raise ValueError + + +class ChronosTokenizer: + """ + A ``ChronosTokenizer`` definines how time series are mapped into token IDs + and back. + + For details, see the ``input_transform`` and ``output_transform`` methods, + which concrete classes must implement. + """ + + def input_transform( + self, context: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, Any]: + """ + Turn a batch of time series into token IDs, attention map, and scale. + + Parameters + ---------- + context + A tensor shaped (batch_size, time_length), containing the + timeseries to forecast. Use left-padding with ``torch.nan`` + to align time series of different lengths. + + Returns + ------- + token_ids + A tensor of integers, shaped (batch_size, time_length + 1) + if ``config.use_eos_token`` and (batch_size, time_length) + otherwise, containing token IDs for the input series. + attention_mask + A boolean tensor, same shape as ``token_ids``, indicating + which input observations are not ``torch.nan`` (i.e. not + missing nor padding). + decoding_context + An object that will be passed to ``output_transform``. + Contains the relevant context to decode output samples into + real values, such as location and scale parameters. + """ + raise NotImplementedError() + + def output_transform( + self, samples: torch.Tensor, decoding_context: Any + ) -> torch.Tensor: + """ + Turn a batch of sample token IDs into real values. + + Parameters + ---------- + samples + A tensor of integers, shaped (batch_size, num_samples, time_length), + containing token IDs of sample trajectories. + decoding_context + An object returned by ``input_transform`` containing + relevant context to decode samples, such as location and scale. + The nature of this depends on the specific tokenizer. + + Returns + ------- + forecasts + A real tensor, shaped (batch_size, num_samples, time_length), + containing forecasted sample paths. + """ + raise NotImplementedError() + + +class MeanScaleUniformBins(ChronosTokenizer): + def __init__( + self, low_limit: float, high_limit: float, config: ChronosConfig + ) -> None: + self.config = config + self.centers = torch.linspace( + low_limit, + high_limit, + config.n_tokens - config.n_special_tokens - 1, + ) + self.boundaries = torch.concat( + ( + torch.tensor([-1e20], device=self.centers.device), + (self.centers[1:] + self.centers[:-1]) / 2, + torch.tensor([1e20], device=self.centers.device), + ) + ) + + def input_transform( + self, context: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch_size, length = context.shape + + if length > self.config.context_length: + context = context[..., -self.config.context_length :] + elif length < self.config.context_length: + padding_size = ( + *context.shape[:-1], + self.config.context_length - length, + ) + padding = torch.full(size=padding_size, fill_value=torch.nan) + context = torch.concat((padding, context), dim=-1) + + attention_mask = ~torch.isnan(context) + scale = torch.nansum( + torch.abs(context) * attention_mask, dim=-1 + ) / torch.nansum(attention_mask, dim=-1) + scale[~(scale > 0)] = 1.0 + scaled_context = context / scale.unsqueeze(dim=-1) + token_ids = ( + torch.bucketize( + input=scaled_context, + boundaries=self.boundaries, + # buckets are open to the right, see: + # https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize + right=True, + ) + + self.config.n_special_tokens + ) + token_ids[~attention_mask] = self.config.pad_token_id + + if self.config.use_eos_token: + eos_tokens = torch.full( + (batch_size, 1), fill_value=self.config.eos_token_id + ) + token_ids = torch.concat((token_ids, eos_tokens), dim=1) + eos_mask = torch.full((batch_size, 1), fill_value=True) + attention_mask = torch.concat((attention_mask, eos_mask), dim=1) + + return token_ids, attention_mask, scale + + def output_transform( + self, samples: torch.Tensor, scale: torch.Tensor + ) -> torch.Tensor: + scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1) + indices = torch.clamp( + samples - self.config.n_special_tokens, + min=0, + max=len(self.centers) - 1, + ) + return self.centers[indices] * scale_unsqueezed + + +class ChronosModel(nn.Module): + """ + A ``ChronosModel`` wraps a ``PreTrainedModel`` object from ``transformers`` + and uses it to predict sample paths for time series tokens. + + Parameters + ---------- + config + The configuration to use. + model + The pre-trained model to use. + """ + + def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None: + super().__init__() + self.config = config + self.model = model + self.device = model.device + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + prediction_length: Optional[int] = None, + num_samples: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + ) -> torch.Tensor: + """ + Predict future sample tokens for the given token sequences. + + Arguments ``prediction_length``, ``num_samples``, ``temperature``, + ``top_k``, ``top_p`` can be used to customize the model inference, + and default to the corresponding attributes in ``self.config`` if + not provided. + + Returns + ------- + samples + A tensor of integers, shaped (batch_size, num_samples, time_length), + containing forecasted sample paths. + """ + if prediction_length is None: + prediction_length = self.config.prediction_length + if num_samples is None: + num_samples = self.config.num_samples + if temperature is None: + temperature = self.config.temperature + if top_k is None: + top_k = self.config.top_k + if top_p is None: + top_p = self.config.top_p + + preds = self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + generation_config=GenerationConfig( + min_new_tokens=prediction_length, + max_new_tokens=prediction_length, + do_sample=True, + num_return_sequences=num_samples, + eos_token_id=self.config.eos_token_id, + pad_token_id=self.config.pad_token_id, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ), + ) + + if self.config.model_type == "seq2seq": + preds = preds[..., 1:] # remove the decoder start token + else: + assert self.config.model_type == "causal" + assert preds.size(-1) == input_ids.size(-1) + prediction_length + preds = preds[..., -prediction_length:] + + return preds.reshape(input_ids.size(0), num_samples, -1) + + +def left_pad_and_stack_1D(tensors: List[torch.Tensor]): + max_len = max(len(c) for c in tensors) + padded = [] + for c in tensors: + assert isinstance(c, torch.Tensor) + assert c.ndim == 1 + padding = torch.full( + size=(max_len - len(c),), fill_value=torch.nan, device=c.device + ) + padded.append(torch.concat((padding, c), dim=-1)) + return torch.stack(padded) + + +class ChronosPipeline: + """ + A ``ChronosPipeline`` uses the given tokenizer and model to forecast + input time series. + + Use the ``from_pretrained`` class method to load serialized models. + Use the ``predict`` method to get forecasts. + + Parameters + ---------- + tokenizer + The tokenizer object to use. + model + The model to use. + """ + + tokenizer: ChronosTokenizer + model: ChronosModel + + def __init__(self, tokenizer, model): + self.tokenizer = tokenizer + self.model = model + + def predict( + self, + context: Union[torch.Tensor, List[torch.Tensor]], + prediction_length: Optional[int] = None, + num_samples: Optional[int] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + limit_prediction_length: bool = True, + ) -> torch.Tensor: + """ + Get forecasts for the given time series. + + Parameters + ---------- + context + Input series. This is either a 1D tensor, or a list + of 1D tensors, or a 2D tensor whose first dimension + is batch. In the latter case, use left-padding with + ``torch.nan`` to align series of different lengths. + prediction_length + Time steps to predict. Defaults to what specified + in ``self.model.config``. + num_samples + Number of sample paths to predict. Defaults to what + specified in ``self.model.config``. + temperature + Temperature to use for generating sample tokens. + Defaults to what specified in ``self.model.config``. + top_k + Top-k parameter to use for generating sample tokens. + Defaults to what specified in ``self.model.config``. + top_p + Top-p parameter to use for generating sample tokens. + Defaults to what specified in ``self.model.config``. + limit_prediction_length + Force prediction length smaller or equal than the + built-in prediction length from the model. True by + default. When true, fail loudly if longer predictions + are requested, otherwise longer predictions are allowed. + + Returns + ------- + samples + Tensor of sample forecasts, of shape + (batch_size, num_samples, prediction_length). + """ + if isinstance(context, list): + context = left_pad_and_stack_1D(context) + assert isinstance(context, torch.Tensor) + if context.ndim == 1: + context = context.unsqueeze(0) + assert context.ndim == 2 + + if prediction_length is None: + prediction_length = self.model.config.prediction_length + + if prediction_length > self.model.config.prediction_length: + msg = ( + f"We recommend keeping prediction length <= {self.model.config.prediction_length}. " + f"The quality of longer predictions may degrade since the model is not optimized for it. " + ) + if limit_prediction_length: + msg += "You can turn off this check by setting `limit_prediction_length=False`." + raise ValueError(msg) + warnings.warn(msg) + + predictions = [] + remaining = prediction_length + + while remaining > 0: + token_ids, attention_mask, scale = self.tokenizer.input_transform(context) + samples = self.model( + token_ids.to(self.model.device), + attention_mask.to(self.model.device), + min(remaining, self.model.config.prediction_length), + num_samples, + temperature, + top_k, + top_p, + ) + prediction = self.tokenizer.output_transform( + samples.to(scale.device), scale + ) + + predictions.append(prediction) + remaining -= prediction.shape[-1] + + if remaining <= 0: + break + + context = torch.cat([context, prediction.median(dim=1).values], dim=-1) + + return torch.cat(predictions, dim=-1) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + """ + Load the model, either from a local path or from the HuggingFace Hub. + Supports the same arguments as ``AutoConfig`` and ``AutoModel`` + from ``transformers``. + """ + + config = AutoConfig.from_pretrained(*args, **kwargs) + + assert hasattr(config, "chronos_config"), "Not a Chronos config file" + + chronos_config = ChronosConfig(**config.chronos_config) + + if chronos_config.model_type == "seq2seq": + inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs) + else: + assert config.model_type == "causal" + inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs) + + return cls( + tokenizer=chronos_config.create_tokenizer(), + model=ChronosModel(config=chronos_config, model=inner_model), + ) diff --git a/test/dummy-chronos-model/config.json b/test/dummy-chronos-model/config.json new file mode 100644 index 0000000..5b2e399 --- /dev/null +++ b/test/dummy-chronos-model/config.json @@ -0,0 +1,48 @@ +{ + "architectures": [ + "T5ForConditionalGeneration" + ], + "d_ff": 32, + "d_kv": 16, + "d_model": 64, + "decoder_start_token_id": 0, + "dense_act_fn": "relu", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "relu", + "initializer_factor": 0.05, + "is_encoder_decoder": true, + "is_gated_act": false, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "n_positions": 512, + "num_decoder_layers": 1, + "num_heads": 1, + "num_layers": 1, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "torch_dtype": "bfloat16", + "transformers_version": "4.31.0", + "use_cache": true, + "vocab_size": 32, + "chronos_config": { + "tokenizer_class": "MeanScaleUniformBins", + "tokenizer_kwargs": { + "low_limit": -15.0, + "high_limit": 15.0 + }, + "n_tokens": 32, + "n_special_tokens": 2, + "pad_token_id": 0, + "eos_token_id": 1, + "use_eos_token": true, + "model_type": "seq2seq", + "context_length": 512, + "prediction_length": 64, + "num_samples": 20, + "temperature": 1.0, + "top_k": 50, + "top_p": 1.0 + } +} diff --git a/test/dummy-chronos-model/generation_config.json b/test/dummy-chronos-model/generation_config.json new file mode 100644 index 0000000..7528dbb --- /dev/null +++ b/test/dummy-chronos-model/generation_config.json @@ -0,0 +1,7 @@ +{ + "_from_model_config": true, + "decoder_start_token_id": 0, + "eos_token_id": 1, + "pad_token_id": 0, + "transformers_version": "4.31.0" +} diff --git a/test/dummy-chronos-model/pytorch_model.bin b/test/dummy-chronos-model/pytorch_model.bin new file mode 100644 index 0000000..42e9a79 Binary files /dev/null and b/test/dummy-chronos-model/pytorch_model.bin differ diff --git a/test/test_chronos.py b/test/test_chronos.py new file mode 100644 index 0000000..85b8669 --- /dev/null +++ b/test/test_chronos.py @@ -0,0 +1,179 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import Tuple + +import torch +import pytest + +from chronos import ChronosConfig, ChronosPipeline + + +@pytest.mark.xfail +@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27]) +@pytest.mark.parametrize("n_special_tokens", [2, 5, 13]) +@pytest.mark.parametrize("use_eos_token", [False, True]) +def test_tokenizer_fixed_data( + n_numerical_tokens: int, n_special_tokens: int, use_eos_token: bool +): + n_tokens = n_numerical_tokens + n_special_tokens + context_length = 3 + + config = ChronosConfig( + tokenizer_class="MeanScaleUniformBins", + tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0), + n_tokens=n_tokens, + n_special_tokens=n_special_tokens, + pad_token_id=0, + eos_token_id=1, + use_eos_token=use_eos_token, + model_type="seq2seq", + context_length=512, + prediction_length=64, + num_samples=20, + temperature=1.0, + top_k=50, + top_p=1.0, + ) + + tokenizer = config.create_tokenizer() + + context = torch.tensor( + [ + [-3.7, 3.7], + [-42.0, 42.0], + ] + ) + batch_size, _ = context.shape + + token_ids, attention_mask, scale = tokenizer.input_transform(context) + + assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token) + assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size)) + assert all(token_ids[:, 1] == torch.tensor([n_special_tokens]).repeat(batch_size)) + assert all(token_ids[:, 2] == torch.tensor([n_tokens - 1]).repeat(batch_size)) + + if use_eos_token: + assert all(token_ids[:, 3] == torch.tensor([1]).repeat(batch_size)) + + samples = tokenizer.output_transform( + torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1), + decoding_context=scale, + ) + + assert (samples[:, 0, [0, -1]] == context).all() + + +@pytest.mark.xfail +@pytest.mark.parametrize("use_eos_token", [False, True]) +def test_tokenizer_random_data(use_eos_token: bool): + context_length = 8 + n_tokens = 256 + n_special_tokens = 2 + + config = ChronosConfig( + tokenizer_class="MeanScaleUniformBins", + tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0), + n_tokens=n_tokens, + n_special_tokens=n_special_tokens, + pad_token_id=0, + eos_token_id=1, + use_eos_token=use_eos_token, + model_type="seq2seq", + context_length=context_length, + prediction_length=64, + num_samples=20, + temperature=1.0, + top_k=50, + top_p=1.0, + ) + + tokenizer = config.create_tokenizer() + + context = torch.tensor( + [ + [torch.nan, torch.nan, 1.0, 1.1, torch.nan, 2.0], + [3.0, torch.nan, 3.9, 4.0, 4.1, 4.9], + ] + ) + + token_ids, attention_mask, scale = tokenizer.input_transform(context) + + assert token_ids.shape == ( + *context.shape[:-1], + context_length + 1 * use_eos_token, + ) + assert attention_mask.shape == ( + *context.shape[:-1], + context_length + 1 * use_eos_token, + ) + assert scale.shape == context.shape[:1] + + sample_ids = torch.randint(low=n_special_tokens, high=n_tokens, size=(2, 10, 4)) + sample_ids[0, 0, 0] = n_special_tokens + sample_ids[-1, -1, -1] = n_tokens - 1 + + samples = tokenizer.output_transform(sample_ids, scale) + + assert samples.shape == (2, 10, 4) + + +def validate_samples(samples: torch.Tensor, shape: Tuple[int, int, int]) -> None: + assert isinstance(samples, torch.Tensor) + assert samples.shape == shape + + +@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16]) +def test_pipeline(torch_dtype: str): + pipeline = ChronosPipeline.from_pretrained( + Path(__file__).parent / "dummy-chronos-model", + device_map="cpu", + torch_dtype=torch_dtype, + ) + context = 10 * torch.rand(size=(4, 16)) + 10 + + # input: tensor of shape (batch_size, context_length) + + samples = pipeline.predict(context, num_samples=12, prediction_length=3) + validate_samples(samples, (4, 12, 3)) + + with pytest.raises(ValueError): + samples = pipeline.predict(context, num_samples=7, prediction_length=65) + + samples = pipeline.predict( + context, num_samples=7, prediction_length=65, limit_prediction_length=False + ) + validate_samples(samples, (4, 7, 65)) + + # input: batch_size-long list of tensors of shape (context_length,) + + samples = pipeline.predict(list(context), num_samples=12, prediction_length=3) + validate_samples(samples, (4, 12, 3)) + + with pytest.raises(ValueError): + samples = pipeline.predict(list(context), num_samples=7, prediction_length=65) + + samples = pipeline.predict( + list(context), + num_samples=7, + prediction_length=65, + limit_prediction_length=False, + ) + validate_samples(samples, (4, 7, 65)) + + # input: tensor of shape (context_length,) + + samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3) + validate_samples(samples, (1, 12, 3)) + + with pytest.raises(ValueError): + samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65) + + samples = pipeline.predict( + context[0, ...], + num_samples=7, + prediction_length=65, + limit_prediction_length=False, + ) + validate_samples(samples, (1, 7, 65))