From ac6ee36acee1e47446cd66f72f540c87f1f1fbe4 Mon Sep 17 00:00:00 2001
From: Alvaro Perez-Diaz
Date: Fri, 4 Oct 2024 23:00:42 +0200
Subject: [PATCH] Fix number of quantisation buckets (#182)

Fixes https://github.com/amazon-science/chronos-forecasting/issues/181.

Chronos' tokenizer has a vocabulary size of `n_tokens`. Of these,
`n_special_tokens` are reserved for EOS, PAD, etc., and
`n_tokens - n_special_tokens` are allocated to numerical values. However,
the provided `MeanScaleUniformBins` tokenizer creates
`n_tokens - n_special_tokens + 1` different buckets, resulting in a total
of `n_tokens + 1` possible tokens. This causes training and inference
errors when a data point is assigned to the largest bucket, since the
model requires `0 <= token_id < n_tokens`.

This PR updates the `MeanScaleUniformBins` tokenizer to clamp token ids
to at most `n_tokens - 1`; the bucket construction itself is left
unchanged, to stay consistent with the already-trained models.

---

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your choice.

---------

Co-authored-by: Lorenzo Stella
---
 pyproject.toml         |  2 +-
 src/chronos/chronos.py |  6 ++++-
 test/test_chronos.py   | 58 +++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 63 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 7c3af18..55dd210 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "chronos"
-version = "1.2.0"
+version = "1.2.1"
 requires-python = ">=3.8"
 license = { file = "LICENSE" }
 dependencies = [
diff --git a/src/chronos/chronos.py b/src/chronos/chronos.py
index 3b17502..c8ba344 100644
--- a/src/chronos/chronos.py
+++ b/src/chronos/chronos.py
@@ -5,7 +5,6 @@
 import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union
-import chronos
 import torch
 import torch.nn as nn
 from transformers import (
@@ -16,6 +15,8 @@ from transformers import (
     PreTrainedModel,
 )
 
+import chronos
+
 
 @dataclass
 class ChronosConfig:
@@ -187,6 +188,9 @@ class MeanScaleUniformBins(ChronosTokenizer):
             )
             + self.config.n_special_tokens
         )
+
+        token_ids.clamp_(0, self.config.n_tokens - 1)
+
         token_ids[~attention_mask] = self.config.pad_token_id
 
         return token_ids, attention_mask, scale
diff --git a/test/test_chronos.py b/test/test_chronos.py
index 9cd039c..a84c57d 100644
--- a/test/test_chronos.py
+++ b/test/test_chronos.py
@@ -4,8 +4,8 @@
 from pathlib import Path
 from typing import Tuple
 
-import torch
 import pytest
+import torch
 
 from chronos import ChronosConfig, ChronosPipeline, MeanScaleUniformBins
 
@@ -244,3 +244,59 @@ def test_pipeline_embed(torch_dtype: str):
     embedding, scale = pipeline.embed(context[0, ...])
     validate_tensor(embedding, (1, expected_embed_length, d_model))
     validate_tensor(scale, (1,))
+
+
+@pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
+def test_tokenizer_number_of_buckets(n_tokens):
+    config = ChronosConfig(
+        tokenizer_class="MeanScaleUniformBins",
+        tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
+        n_tokens=n_tokens,
+        n_special_tokens=2,
+        pad_token_id=0,
+        eos_token_id=1,
+        use_eos_token=True,
+        model_type="seq2seq",
+        context_length=512,
+        prediction_length=64,
+        num_samples=20,
+        temperature=1.0,
+        top_k=50,
+        top_p=1.0,
+    )
+    tokenizer = config.create_tokenizer()
+
+    n_numerical_tokens = config.n_tokens - config.n_special_tokens
+
+    # The tokenizer has one bucket too many as a result of an early bug. In order to
+    # keep consistent with the original trained models, this is kept as it is. However,
+    # token ids are clipped to a maximum of `n_tokens - 1` to avoid out-of-bounds errors.
+    assert len(tokenizer.centers) == (n_numerical_tokens - 1)
+    assert len(tokenizer.boundaries) == n_numerical_tokens
+
+
+@pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
+def test_token_clipping(n_tokens):
+    config = ChronosConfig(
+        tokenizer_class="MeanScaleUniformBins",
+        tokenizer_kwargs={"low_limit": -15, "high_limit": 15},
+        n_tokens=n_tokens,
+        n_special_tokens=2,
+        pad_token_id=0,
+        eos_token_id=1,
+        use_eos_token=True,
+        model_type="seq2seq",
+        context_length=512,
+        prediction_length=64,
+        num_samples=20,
+        temperature=1.0,
+        top_k=50,
+        top_p=1.0,
+    )
+    tokenizer = config.create_tokenizer()
+
+    huge_value = 1e22  # this large value is assigned to the largest bucket
+    token_ids, _, _ = tokenizer._input_transform(
+        context=torch.tensor([[huge_value]]), scale=torch.tensor(([1]))
+    )
+    assert token_ids[0, 0] == config.n_tokens - 1  # and it's clipped to n_tokens - 1
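
A minimal standalone sketch of the off-by-one this patch guards against (illustration only, not part of the patch; it assumes a simplified uniform grid of boundaries built with `torch.linspace` rather than the exact `MeanScaleUniformBins` bucket construction). `torch.bucketize` over B boundaries can return indices 0 through B, i.e. B + 1 buckets, so after adding `n_special_tokens` the largest values map to token id `n_tokens` unless the ids are clamped:

    import torch

    n_tokens, n_special_tokens = 10, 2
    n_numerical = n_tokens - n_special_tokens             # 8 ids intended for numerical values
    boundaries = torch.linspace(-1.0, 1.0, n_numerical)   # simplified stand-in boundaries

    # A value beyond the last boundary lands in bucket index len(boundaries) == 8,
    # so the token id becomes 8 + 2 == 10 == n_tokens, outside the valid vocabulary.
    token_ids = torch.bucketize(torch.tensor([1e22]), boundaries, right=True) + n_special_tokens
    assert token_ids.item() == n_tokens

    # The patch clamps ids into [0, n_tokens - 1], the range the model actually accepts.
    token_ids.clamp_(0, n_tokens - 1)
    assert token_ids.item() == n_tokens - 1

The real tokenizer keeps the extra bucket so that already-trained models stay compatible and only clamps the resulting token ids, which is what the new `test_token_clipping` test verifies.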