Fix number of quantisation buckets (#182)

Fixes https://github.com/amazon-science/chronos-forecasting/issues/181.

Chronos' tokenizer has a vocabulary size of `n_tokens`. Among these,
there are `n_special_tokens` reserved for EOS, PAD, etc. and `n_tokens -
n_special_tokens` allocated to numerical values. However, the provided
`MeanScaleUniformBins` tokenizer creates `n_tokens - n_special_tokens + 1`
different buckets, resulting in a total of `n_tokens + 1` possible tokens.
This causes training and inference errors when a data point falls into the
largest bucket, since the model requires `0 <= token_id < n_tokens`.
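
For intuition, here is a minimal sketch of the off-by-one (it uses a simplified `torch.linspace` boundary layout and toy numbers, not the tokenizer's actual boundary construction): with `B` boundaries, `torch.bucketize` can return any index in `[0, B]`, i.e. `B + 1` buckets.

```python
import torch

# Toy vocabulary (illustrative numbers only): 10 tokens, 2 of them special.
n_tokens, n_special_tokens = 10, 2
boundaries = torch.linspace(-1.0, 1.0, n_tokens - n_special_tokens)  # 8 boundaries

# Values beyond the last boundary land in bucket index 8, so after adding
# the special-token offset the largest possible id equals n_tokens itself.
token_ids = torch.bucketize(torch.tensor([-5.0, 0.0, 5.0]), boundaries) + n_special_tokens
print(token_ids)  # tensor([ 2,  6, 10]) -> 10 == n_tokens, outside the valid range
```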

This PR updates the `MeanScaleUniformBins` tokenizer so that token ids are
clamped to at most `n_tokens - 1`. The bucket layout itself is left unchanged,
to stay consistent with the already-trained models.
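
Continuing the toy numbers from the sketch above (again just an illustration, not the library's code path), the added clamp maps the overflowing id back into the vocabulary:

```python
import torch

# Ids from the sketch above; 10 is one past the valid range [0, n_tokens - 1].
n_tokens = 10
token_ids = torch.tensor([2, 6, 10])
token_ids.clamp_(0, n_tokens - 1)  # mirrors token_ids.clamp_(0, self.config.n_tokens - 1)
print(token_ids)  # tensor([2, 6, 9])
```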

---

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Lorenzo Stella <lorenzostella@gmail.com>
Alvaro Perez-Diaz 2024-10-04 23:00:42 +02:00 committed by GitHub
parent eb7bdfc047
commit ac6ee36ace
3 changed files with 63 additions and 3 deletions

View File

@@ -1,6 +1,6 @@
[project]
name = "chronos"
-version = "1.2.0"
+version = "1.2.1"
requires-python = ">=3.8"
license = { file = "LICENSE" }
dependencies = [

View File

@@ -5,7 +5,6 @@ import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

-import chronos
import torch
import torch.nn as nn
from transformers import (
@@ -16,6 +15,8 @@ from transformers import (
    PreTrainedModel,
)

+import chronos
+

@dataclass
class ChronosConfig:
@@ -187,6 +188,9 @@ class MeanScaleUniformBins(ChronosTokenizer):
            )
            + self.config.n_special_tokens
        )
+
+        token_ids.clamp_(0, self.config.n_tokens - 1)
+
        token_ids[~attention_mask] = self.config.pad_token_id

        return token_ids, attention_mask, scale

View File

@@ -4,8 +4,8 @@
from pathlib import Path
from typing import Tuple

-import torch
import pytest
+import torch

from chronos import ChronosConfig, ChronosPipeline, MeanScaleUniformBins
@@ -244,3 +244,59 @@ def test_pipeline_embed(torch_dtype: str):
    embedding, scale = pipeline.embed(context[0, ...])
    validate_tensor(embedding, (1, expected_embed_length, d_model))
    validate_tensor(scale, (1,))
+
+
+@pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
+def test_tokenizer_number_of_buckets(n_tokens):
+    config = ChronosConfig(
+        tokenizer_class="MeanScaleUniformBins",
+        tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
+        n_tokens=n_tokens,
+        n_special_tokens=2,
+        pad_token_id=0,
+        eos_token_id=1,
+        use_eos_token=True,
+        model_type="seq2seq",
+        context_length=512,
+        prediction_length=64,
+        num_samples=20,
+        temperature=1.0,
+        top_k=50,
+        top_p=1.0,
+    )
+    tokenizer = config.create_tokenizer()
+
+    n_numerical_tokens = config.n_tokens - config.n_special_tokens
+
+    # The tokenizer has one bucket too many as a result of an early bug. In order to
+    # keep consistent with the original trained models, this is kept as it is. However,
+    # token ids are clipped to a maximum of `n_tokens - 1` to avoid out-of-bounds errors.
+    assert len(tokenizer.centers) == (n_numerical_tokens - 1)
+    assert len(tokenizer.boundaries) == n_numerical_tokens
+
+
+@pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
+def test_token_clipping(n_tokens):
+    config = ChronosConfig(
+        tokenizer_class="MeanScaleUniformBins",
+        tokenizer_kwargs={"low_limit": -15, "high_limit": 15},
+        n_tokens=n_tokens,
+        n_special_tokens=2,
+        pad_token_id=0,
+        eos_token_id=1,
+        use_eos_token=True,
+        model_type="seq2seq",
+        context_length=512,
+        prediction_length=64,
+        num_samples=20,
+        temperature=1.0,
+        top_k=50,
+        top_p=1.0,
+    )
+    tokenizer = config.create_tokenizer()
+
+    huge_value = 1e22  # this large value is assigned to the largest bucket
+    token_ids, _, _ = tokenizer._input_transform(
+        context=torch.tensor([[huge_value]]), scale=torch.tensor(([1]))
+    )
+    assert token_ids[0, 0] == config.n_tokens - 1  # and it's clipped to n_tokens - 1