Upload code

This commit is contained in:
Lorenzo Stella 2024-02-29 14:39:05 +01:00
parent 2420c10232
commit 7ba945c995
11 changed files with 964 additions and 6 deletions

25
.github/workflows/ci.yml vendored Normal file

@@ -0,0 +1,25 @@
name: CI

on: [push, pull_request]

jobs:
  test:
    strategy:
      max-parallel: 4
      fail-fast: false
      matrix:
        python-version: ['3.11']
        platform: [ubuntu-latest]
    runs-on: ${{ matrix.platform }}
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: pip install ".[test]"
      - name: Test with pytest
        run: pytest

163
.gitignore vendored Normal file

@@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# macOS stuff
.DS_store

README.md

@@ -1,11 +1,87 @@
-## My Project
-TODO: Fill this README out!
-Be sure to:
-* Change the title in this README
-* Edit your repository description on GitHub

# Chronos: Learning the Language of Time Series

Chronos is a family of **pretrained time series forecasting models** based on language model architectures. A time series is transformed into a sequence of tokens via scaling and quantization, and a language model is trained on these tokens using the cross-entropy loss. Once trained, probabilistic forecasts are obtained by sampling multiple future trajectories given the historical context. Chronos models have been trained on a large corpus of publicly available time series data, as well as synthetic data generated using Gaussian processes.

For details on Chronos models, training data and procedures, and experimental results, please refer to the paper [Chronos: Learning the Language of Time Series](https://arxiv.org/abs/2403.07815).

<p align="center">
<img src="figures/main-figure.png" width="100%">
<br />
<span>
Fig. 1: High-level depiction of Chronos. (<b>Left</b>) The input time series is scaled and quantized to obtain a sequence of tokens. (<b>Center</b>) The tokens are fed into a language model which may either be an encoder-decoder or a decoder-only model. The model is trained using the cross-entropy loss. (<b>Right</b>) During inference, we autoregressively sample tokens from the model and map them back to numerical values. Multiple trajectories are sampled to obtain a predictive distribution.
</span>
</p>
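For intuition, here is a minimal, self-contained sketch of the scaling-and-quantization step depicted in Fig. 1. It is a simplified stand-in for the `MeanScaleUniformBins` tokenizer added in this commit (no special tokens, no attention mask), and the bin range and bin count are illustrative only:

```python
import torch

series = torch.tensor([112.0, 118.0, 132.0, 129.0, 121.0])

scale = series.abs().mean()                      # mean scaling
scaled = series / scale

centers = torch.linspace(-15.0, 15.0, 4094)      # uniform bin centers (illustrative range/count)
boundaries = (centers[1:] + centers[:-1]) / 2
token_ids = torch.bucketize(scaled, boundaries)  # one token per time step

# Decoding maps each token back to its bin center and rescales it.
reconstructed = centers[token_ids] * scale
```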
---
## Architecture

The models in this repository are based on the [T5 architecture](https://arxiv.org/abs/1910.10683). The only difference is in the vocabulary size: Chronos-T5 models use 4096 different tokens, compared to 32128 of the original T5 models, resulting in fewer parameters.

| Model | Parameters | Based on |
| ---------------------------------------------------------------------- | ---------- | ---------------------------------------------------------------------- |
| [**chronos-t5-tiny**](https://huggingface.co/amazon/chronos-t5-tiny) | 8M | [t5-efficient-tiny](https://huggingface.co/google/t5-efficient-tiny) |
| [**chronos-t5-mini**](https://huggingface.co/amazon/chronos-t5-mini) | 20M | [t5-efficient-mini](https://huggingface.co/google/t5-efficient-mini) |
| [**chronos-t5-small**](https://huggingface.co/amazon/chronos-t5-small) | 46M | [t5-efficient-small](https://huggingface.co/google/t5-efficient-small) |
| [**chronos-t5-base**](https://huggingface.co/amazon/chronos-t5-base) | 200M | [t5-efficient-base](https://huggingface.co/google/t5-efficient-base) |
| [**chronos-t5-large**](https://huggingface.co/amazon/chronos-t5-large) | 710M | [t5-efficient-large](https://huggingface.co/google/t5-efficient-large) |
## Usage
To perform inference with Chronos models, install this package by running:
```
pip install git+https://github.com/amazon-science/chronos-forecasting.git
```
A minimal example showing how to perform inference using Chronos models:
```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map="cuda",  # use "cpu" if no GPU is available
    torch_dtype=torch.bfloat16,
)

df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv")

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
context = torch.tensor(df["#Passengers"])
prediction_length = 12
forecast = pipeline.predict(context, prediction_length)  # shape [num_series, num_samples, prediction_length]

# visualize the forecast
forecast_index = range(len(df), len(df) + prediction_length)
low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(df["#Passengers"], color="royalblue", label="historical data")
plt.plot(forecast_index, median, color="tomato", label="median forecast")
plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
plt.legend()
plt.grid()
plt.show()
```
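Continuing the example above, `context` can also be a list of 1D tensors of different lengths; the pipeline left-pads them internally into a single batch. The two series below are hypothetical and simply reuse the AirPassengers data:

```python
series_batch = [
    torch.tensor(df["#Passengers"].values, dtype=torch.float32),        # full history
    torch.tensor(df["#Passengers"].values[-60:], dtype=torch.float32),  # shorter series
]
batch_forecast = pipeline.predict(series_batch, prediction_length)
print(batch_forecast.shape)  # (2, num_samples, prediction_length); num_samples defaults to the model config
```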
## Citation
If you find Chronos models useful for your research, please consider citing the associated [paper](https://arxiv.org/abs/2403.07815):
```
@article{ansari2024chronos,
author  = {Ansari, Abdul Fatir and Stella, Lorenzo and Turkmen, Caner and Zhang, Xiyuan and Mercado, Pedro and Shen, Huibin and Shchur, Oleksandr and Rangapuram, Syama Sundar and Pineda Arango, Sebastian and Kapoor, Shubham and Zschiegner, Jasper and Maddix, Danielle C. and Mahoney, Michael W. and Torkkola, Kari and Gordon Wilson, Andrew and Bohlke-Schneider, Michael and Wang, Yuyang},
title = {Chronos: Learning the Language of Time Series},
journal = {arXiv preprint arXiv:2403.07815},
year = {2024}
}
```
## Security
@@ -14,4 +90,3 @@ See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
## License
This project is licensed under the Apache-2.0 License.

BIN
figures/main-figure.png Normal file

Binary file not shown.


19
pyproject.toml Normal file

@@ -0,0 +1,19 @@
[project]
name = "chronos"
version = "1.0.0"
requires-python = ">=3.8"
license = {file = "LICENSE"}
dependencies = [
  "torch~=2.1",  # package was tested on 2.2
  "transformers~=4.31",
  "accelerate"
]

[project.optional-dependencies]
test = [
  "pytest~=8.0",
  "numpy~=1.21"
]

[tool.mypy]
ignore_missing_imports = true

18
src/chronos/__init__.py Normal file

@@ -0,0 +1,18 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from .chronos import (
    ChronosConfig,
    ChronosModel,
    ChronosPipeline,
    ChronosTokenizer,
    MeanScaleUniformBins,
)

__all__ = [
    "ChronosConfig",
    "ChronosModel",
    "ChronosPipeline",
    "ChronosTokenizer",
    "MeanScaleUniformBins",
]

424
src/chronos/chronos.py Normal file

@@ -0,0 +1,424 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
GenerationConfig,
PreTrainedModel,
)
@dataclass
class ChronosConfig:
"""
This class holds all the configuration parameters to be used
by ``ChronosTokenizer`` and ``ChronosModel``.
"""
tokenizer_class: str
tokenizer_kwargs: Dict[str, Any]
n_tokens: int
n_special_tokens: int
pad_token_id: int
eos_token_id: int
use_eos_token: bool
model_type: Literal["causal", "seq2seq"]
context_length: int
prediction_length: int
num_samples: int
temperature: float
top_k: int
top_p: float
def __post_init__(self):
assert (
self.pad_token_id < self.n_special_tokens
and self.eos_token_id < self.n_special_tokens
), f"Special token id's must be smaller than {self.n_special_tokens=}"
def create_tokenizer(self) -> "ChronosTokenizer":
if self.tokenizer_class == "MeanScaleUniformBins":
return MeanScaleUniformBins(**self.tokenizer_kwargs, config=self)
raise ValueError
class ChronosTokenizer:
"""
A ``ChronosTokenizer`` defines how time series are mapped into token IDs
and back.
For details, see the ``input_transform`` and ``output_transform`` methods,
which concrete classes must implement.
"""
def input_transform(
self, context: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, Any]:
"""
Turn a batch of time series into token IDs, attention map, and scale.
Parameters
----------
context
A tensor shaped (batch_size, time_length), containing the
timeseries to forecast. Use left-padding with ``torch.nan``
to align time series of different lengths.
Returns
-------
token_ids
A tensor of integers, shaped (batch_size, time_length + 1)
if ``config.use_eos_token`` and (batch_size, time_length)
otherwise, containing token IDs for the input series.
attention_mask
A boolean tensor, same shape as ``token_ids``, indicating
which input observations are not ``torch.nan`` (i.e. not
missing nor padding).
decoding_context
An object that will be passed to ``output_transform``.
Contains the relevant context to decode output samples into
real values, such as location and scale parameters.
"""
raise NotImplementedError()
def output_transform(
self, samples: torch.Tensor, decoding_context: Any
) -> torch.Tensor:
"""
Turn a batch of sample token IDs into real values.
Parameters
----------
samples
A tensor of integers, shaped (batch_size, num_samples, time_length),
containing token IDs of sample trajectories.
decoding_context
An object returned by ``input_transform`` containing
relevant context to decode samples, such as location and scale.
The nature of this depends on the specific tokenizer.
Returns
-------
forecasts
A real tensor, shaped (batch_size, num_samples, time_length),
containing forecasted sample paths.
"""
raise NotImplementedError()
class MeanScaleUniformBins(ChronosTokenizer):
def __init__(
self, low_limit: float, high_limit: float, config: ChronosConfig
) -> None:
self.config = config
self.centers = torch.linspace(
low_limit,
high_limit,
config.n_tokens - config.n_special_tokens - 1,
)
self.boundaries = torch.concat(
(
torch.tensor([-1e20], device=self.centers.device),
(self.centers[1:] + self.centers[:-1]) / 2,
torch.tensor([1e20], device=self.centers.device),
)
)
def input_transform(
self, context: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size, length = context.shape
if length > self.config.context_length:
context = context[..., -self.config.context_length :]
elif length < self.config.context_length:
padding_size = (
*context.shape[:-1],
self.config.context_length - length,
)
padding = torch.full(size=padding_size, fill_value=torch.nan)
context = torch.concat((padding, context), dim=-1)
attention_mask = ~torch.isnan(context)
scale = torch.nansum(
torch.abs(context) * attention_mask, dim=-1
) / torch.nansum(attention_mask, dim=-1)
scale[~(scale > 0)] = 1.0
scaled_context = context / scale.unsqueeze(dim=-1)
token_ids = (
torch.bucketize(
input=scaled_context,
boundaries=self.boundaries,
# buckets are open to the right, see:
# https://pytorch.org/docs/2.1/generated/torch.bucketize.html#torch-bucketize
right=True,
)
+ self.config.n_special_tokens
)
token_ids[~attention_mask] = self.config.pad_token_id
if self.config.use_eos_token:
eos_tokens = torch.full(
(batch_size, 1), fill_value=self.config.eos_token_id
)
token_ids = torch.concat((token_ids, eos_tokens), dim=1)
eos_mask = torch.full((batch_size, 1), fill_value=True)
attention_mask = torch.concat((attention_mask, eos_mask), dim=1)
return token_ids, attention_mask, scale
def output_transform(
self, samples: torch.Tensor, scale: torch.Tensor
) -> torch.Tensor:
scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
indices = torch.clamp(
samples - self.config.n_special_tokens,
min=0,
max=len(self.centers) - 1,
)
return self.centers[indices] * scale_unsqueezed
class ChronosModel(nn.Module):
"""
A ``ChronosModel`` wraps a ``PreTrainedModel`` object from ``transformers``
and uses it to predict sample paths for time series tokens.
Parameters
----------
config
The configuration to use.
model
The pre-trained model to use.
"""
def __init__(self, config: ChronosConfig, model: PreTrainedModel) -> None:
super().__init__()
self.config = config
self.model = model
self.device = model.device
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
prediction_length: Optional[int] = None,
num_samples: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
) -> torch.Tensor:
"""
Predict future sample tokens for the given token sequences.
Arguments ``prediction_length``, ``num_samples``, ``temperature``,
``top_k``, ``top_p`` can be used to customize the model inference,
and default to the corresponding attributes in ``self.config`` if
not provided.
Returns
-------
samples
A tensor of integers, shaped (batch_size, num_samples, time_length),
containing forecasted sample paths.
"""
if prediction_length is None:
prediction_length = self.config.prediction_length
if num_samples is None:
num_samples = self.config.num_samples
if temperature is None:
temperature = self.config.temperature
if top_k is None:
top_k = self.config.top_k
if top_p is None:
top_p = self.config.top_p
preds = self.model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
generation_config=GenerationConfig(
min_new_tokens=prediction_length,
max_new_tokens=prediction_length,
do_sample=True,
num_return_sequences=num_samples,
eos_token_id=self.config.eos_token_id,
pad_token_id=self.config.pad_token_id,
temperature=temperature,
top_k=top_k,
top_p=top_p,
),
)
if self.config.model_type == "seq2seq":
preds = preds[..., 1:] # remove the decoder start token
else:
assert self.config.model_type == "causal"
assert preds.size(-1) == input_ids.size(-1) + prediction_length
preds = preds[..., -prediction_length:]
return preds.reshape(input_ids.size(0), num_samples, -1)
def left_pad_and_stack_1D(tensors: List[torch.Tensor]):
max_len = max(len(c) for c in tensors)
padded = []
for c in tensors:
assert isinstance(c, torch.Tensor)
assert c.ndim == 1
padding = torch.full(
size=(max_len - len(c),), fill_value=torch.nan, device=c.device
)
padded.append(torch.concat((padding, c), dim=-1))
return torch.stack(padded)
class ChronosPipeline:
"""
A ``ChronosPipeline`` uses the given tokenizer and model to forecast
input time series.
Use the ``from_pretrained`` class method to load serialized models.
Use the ``predict`` method to get forecasts.
Parameters
----------
tokenizer
The tokenizer object to use.
model
The model to use.
"""
tokenizer: ChronosTokenizer
model: ChronosModel
def __init__(self, tokenizer, model):
self.tokenizer = tokenizer
self.model = model
def predict(
self,
context: Union[torch.Tensor, List[torch.Tensor]],
prediction_length: Optional[int] = None,
num_samples: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
limit_prediction_length: bool = True,
) -> torch.Tensor:
"""
Get forecasts for the given time series.
Parameters
----------
context
Input series. This is either a 1D tensor, or a list
of 1D tensors, or a 2D tensor whose first dimension
is batch. In the latter case, use left-padding with
``torch.nan`` to align series of different lengths.
prediction_length
Time steps to predict. Defaults to what specified
in ``self.model.config``.
num_samples
Number of sample paths to predict. Defaults to what
specified in ``self.model.config``.
temperature
Temperature to use for generating sample tokens.
Defaults to what specified in ``self.model.config``.
top_k
Top-k parameter to use for generating sample tokens.
Defaults to what specified in ``self.model.config``.
top_p
Top-p parameter to use for generating sample tokens.
Defaults to what specified in ``self.model.config``.
limit_prediction_length
Force prediction length smaller or equal than the
built-in prediction length from the model. True by
default. When true, fail loudly if longer predictions
are requested, otherwise longer predictions are allowed.
Returns
-------
samples
Tensor of sample forecasts, of shape
(batch_size, num_samples, prediction_length).
"""
if isinstance(context, list):
context = left_pad_and_stack_1D(context)
assert isinstance(context, torch.Tensor)
if context.ndim == 1:
context = context.unsqueeze(0)
assert context.ndim == 2
if prediction_length is None:
prediction_length = self.model.config.prediction_length
if prediction_length > self.model.config.prediction_length:
msg = (
f"We recommend keeping prediction length <= {self.model.config.prediction_length}. "
f"The quality of longer predictions may degrade since the model is not optimized for it. "
)
if limit_prediction_length:
msg += "You can turn off this check by setting `limit_prediction_length=False`."
raise ValueError(msg)
warnings.warn(msg)
predictions = []
remaining = prediction_length
while remaining > 0:
token_ids, attention_mask, scale = self.tokenizer.input_transform(context)
samples = self.model(
token_ids.to(self.model.device),
attention_mask.to(self.model.device),
min(remaining, self.model.config.prediction_length),
num_samples,
temperature,
top_k,
top_p,
)
prediction = self.tokenizer.output_transform(
samples.to(scale.device), scale
)
predictions.append(prediction)
remaining -= prediction.shape[-1]
if remaining <= 0:
break
context = torch.cat([context, prediction.median(dim=1).values], dim=-1)
return torch.cat(predictions, dim=-1)
@classmethod
def from_pretrained(cls, *args, **kwargs):
"""
Load the model, either from a local path or from the HuggingFace Hub.
Supports the same arguments as ``AutoConfig`` and ``AutoModel``
from ``transformers``.
"""
config = AutoConfig.from_pretrained(*args, **kwargs)
assert hasattr(config, "chronos_config"), "Not a Chronos config file"
chronos_config = ChronosConfig(**config.chronos_config)
if chronos_config.model_type == "seq2seq":
inner_model = AutoModelForSeq2SeqLM.from_pretrained(*args, **kwargs)
else:
assert chronos_config.model_type == "causal"
inner_model = AutoModelForCausalLM.from_pretrained(*args, **kwargs)
return cls(
tokenizer=chronos_config.create_tokenizer(),
model=ChronosModel(config=chronos_config, model=inner_model),
)
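As a usage note for the tokenizer defined above, the sketch below runs a small round trip through ``MeanScaleUniformBins`` via ``ChronosConfig.create_tokenizer``. The numeric settings are illustrative only and are not claimed to match any released model's configuration:

```python
import torch

from chronos import ChronosConfig

# Illustrative settings; only the field names come from ChronosConfig.
config = ChronosConfig(
    tokenizer_class="MeanScaleUniformBins",
    tokenizer_kwargs=dict(low_limit=-15.0, high_limit=15.0),
    n_tokens=4096,
    n_special_tokens=2,
    pad_token_id=0,
    eos_token_id=1,
    use_eos_token=True,
    model_type="seq2seq",
    context_length=512,
    prediction_length=64,
    num_samples=20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
)
tokenizer = config.create_tokenizer()

context = torch.tensor([[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]])
token_ids, attention_mask, scale = tokenizer.input_transform(context)
# token_ids has shape (2, context_length + 1): each series is left-padded with
# the pad token up to context_length, and an EOS token is appended.

# Decoding maps token IDs back to real values using the per-series scale.
samples = token_ids[:, -5:-1].unsqueeze(1)  # pretend these are sampled tokens
values = tokenizer.output_transform(samples, scale)
print(values.shape)  # torch.Size([2, 1, 4]), approximately recovering the inputs
```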

test/dummy-chronos-model/config.json Normal file

@@ -0,0 +1,48 @@
{
"architectures": [
"T5ForConditionalGeneration"
],
"d_ff": 32,
"d_kv": 16,
"d_model": 64,
"decoder_start_token_id": 0,
"dense_act_fn": "relu",
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "relu",
"initializer_factor": 0.05,
"is_encoder_decoder": true,
"is_gated_act": false,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"n_positions": 512,
"num_decoder_layers": 1,
"num_heads": 1,
"num_layers": 1,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"torch_dtype": "bfloat16",
"transformers_version": "4.31.0",
"use_cache": true,
"vocab_size": 32,
"chronos_config": {
"tokenizer_class": "MeanScaleUniformBins",
"tokenizer_kwargs": {
"low_limit": -15.0,
"high_limit": 15.0
},
"n_tokens": 32,
"n_special_tokens": 2,
"pad_token_id": 0,
"eos_token_id": 1,
"use_eos_token": true,
"model_type": "seq2seq",
"context_length": 512,
"prediction_length": 64,
"num_samples": 20,
"temperature": 1.0,
"top_k": 50,
"top_p": 1.0
}
}
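For reference, the `chronos_config` block above is what `ChronosPipeline.from_pretrained` consumes: it requires the loaded Hugging Face config to carry a `chronos_config` attribute and rebuilds a `ChronosConfig` from it. A minimal sketch, assuming this dummy model lives under `test/dummy-chronos-model` as the test file below suggests:

```python
from pathlib import Path

from transformers import AutoConfig

from chronos import ChronosConfig

# Load the dummy model's Hugging Face config and rebuild the ChronosConfig,
# mirroring what ChronosPipeline.from_pretrained does internally.
hf_config = AutoConfig.from_pretrained(Path("test") / "dummy-chronos-model")
assert hasattr(hf_config, "chronos_config"), "Not a Chronos config file"

chronos_config = ChronosConfig(**hf_config.chronos_config)
print(chronos_config.model_type, chronos_config.context_length)  # seq2seq 512
```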

test/dummy-chronos-model/generation_config.json Normal file

@@ -0,0 +1,7 @@
{
"_from_model_config": true,
"decoder_start_token_id": 0,
"eos_token_id": 1,
"pad_token_id": 0,
"transformers_version": "4.31.0"
}

Binary file not shown.

179
test/test_chronos.py Normal file

@@ -0,0 +1,179 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Tuple
import torch
import pytest
from chronos import ChronosConfig, ChronosPipeline
@pytest.mark.xfail
@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
@pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
@pytest.mark.parametrize("use_eos_token", [False, True])
def test_tokenizer_fixed_data(
n_numerical_tokens: int, n_special_tokens: int, use_eos_token: bool
):
n_tokens = n_numerical_tokens + n_special_tokens
context_length = 3
config = ChronosConfig(
tokenizer_class="MeanScaleUniformBins",
tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
n_tokens=n_tokens,
n_special_tokens=n_special_tokens,
pad_token_id=0,
eos_token_id=1,
use_eos_token=use_eos_token,
model_type="seq2seq",
context_length=512,
prediction_length=64,
num_samples=20,
temperature=1.0,
top_k=50,
top_p=1.0,
)
tokenizer = config.create_tokenizer()
context = torch.tensor(
[
[-3.7, 3.7],
[-42.0, 42.0],
]
)
batch_size, _ = context.shape
token_ids, attention_mask, scale = tokenizer.input_transform(context)
assert token_ids.shape == (batch_size, context_length + 1 * use_eos_token)
assert all(token_ids[:, 0] == torch.tensor([0]).repeat(batch_size))
assert all(token_ids[:, 1] == torch.tensor([n_special_tokens]).repeat(batch_size))
assert all(token_ids[:, 2] == torch.tensor([n_tokens - 1]).repeat(batch_size))
if use_eos_token:
assert all(token_ids[:, 3] == torch.tensor([1]).repeat(batch_size))
samples = tokenizer.output_transform(
torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1),
decoding_context=scale,
)
assert (samples[:, 0, [0, -1]] == context).all()
@pytest.mark.xfail
@pytest.mark.parametrize("use_eos_token", [False, True])
def test_tokenizer_random_data(use_eos_token: bool):
context_length = 8
n_tokens = 256
n_special_tokens = 2
config = ChronosConfig(
tokenizer_class="MeanScaleUniformBins",
tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
n_tokens=n_tokens,
n_special_tokens=n_special_tokens,
pad_token_id=0,
eos_token_id=1,
use_eos_token=use_eos_token,
model_type="seq2seq",
context_length=context_length,
prediction_length=64,
num_samples=20,
temperature=1.0,
top_k=50,
top_p=1.0,
)
tokenizer = config.create_tokenizer()
context = torch.tensor(
[
[torch.nan, torch.nan, 1.0, 1.1, torch.nan, 2.0],
[3.0, torch.nan, 3.9, 4.0, 4.1, 4.9],
]
)
token_ids, attention_mask, scale = tokenizer.input_transform(context)
assert token_ids.shape == (
*context.shape[:-1],
context_length + 1 * use_eos_token,
)
assert attention_mask.shape == (
*context.shape[:-1],
context_length + 1 * use_eos_token,
)
assert scale.shape == context.shape[:1]
sample_ids = torch.randint(low=n_special_tokens, high=n_tokens, size=(2, 10, 4))
sample_ids[0, 0, 0] = n_special_tokens
sample_ids[-1, -1, -1] = n_tokens - 1
samples = tokenizer.output_transform(sample_ids, scale)
assert samples.shape == (2, 10, 4)
def validate_samples(samples: torch.Tensor, shape: Tuple[int, int, int]) -> None:
assert isinstance(samples, torch.Tensor)
assert samples.shape == shape
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
def test_pipeline(torch_dtype: torch.dtype):
pipeline = ChronosPipeline.from_pretrained(
Path(__file__).parent / "dummy-chronos-model",
device_map="cpu",
torch_dtype=torch_dtype,
)
context = 10 * torch.rand(size=(4, 16)) + 10
# input: tensor of shape (batch_size, context_length)
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
validate_samples(samples, (4, 12, 3))
with pytest.raises(ValueError):
samples = pipeline.predict(context, num_samples=7, prediction_length=65)
samples = pipeline.predict(
context, num_samples=7, prediction_length=65, limit_prediction_length=False
)
validate_samples(samples, (4, 7, 65))
# input: batch_size-long list of tensors of shape (context_length,)
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
validate_samples(samples, (4, 12, 3))
with pytest.raises(ValueError):
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
samples = pipeline.predict(
list(context),
num_samples=7,
prediction_length=65,
limit_prediction_length=False,
)
validate_samples(samples, (4, 7, 65))
# input: tensor of shape (context_length,)
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
validate_samples(samples, (1, 12, 3))
with pytest.raises(ValueError):
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
samples = pipeline.predict(
context[0, ...],
num_samples=7,
prediction_length=65,
limit_prediction_length=False,
)
validate_samples(samples, (1, 7, 65))