mirror of
https://github.com/amazon-science/chronos-forecasting.git
synced 2024-11-25 16:51:05 +08:00
Force context scaling and quantization in float32, add assertions to tests (#197)
*Issue #, if available:* Fixes #193 *Description of changes:* Passing in contexts in lower precision than float32 may result in a drop of accuracy. This change ensures that the tokenizer (which does scaling and quantization) operates on a float32 batch. Tested across GPU/CPU and different context dtypes with ```python from itertools import product import pandas as pd import torch from chronos import ChronosPipeline import matplotlib.pyplot as plt # requires: pip install matplotlib import numpy as np df = pd.read_csv("https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv") for context_dtype, context_device, model_dtype, model_device in product( [torch.bfloat16, torch.float16, torch.float32], ["cpu"], # only cpu input supported at the moment [torch.bfloat16, torch.float16, torch.float32], ["cpu", "cuda"], ): pipeline = ChronosPipeline.from_pretrained( "amazon/chronos-t5-tiny", device_map=model_device, torch_dtype=model_dtype, ) forecast = pipeline.predict( context=torch.tensor(df["#Passengers"]).to(dtype=context_dtype, device=context_device), prediction_length=65, num_samples=20, limit_prediction_length=False, ) assert forecast.dtype == context_dtype, f"{forecast.dtype=} but {context_dtype=}" assert str(forecast.device) == context_device, f"{forecast.device=} but {context_device=}" forecast_index = range(len(df), len(df) + 65) low, median, high = np.quantile(forecast[0].to(device="cpu", dtype=torch.float32).numpy(), [0.1, 0.5, 0.9], axis=0) plt.figure(figsize=(8, 4)) plt.plot(df["#Passengers"], color="royalblue", label="historical data") plt.plot(forecast_index, median, color="tomato", label="median forecast") plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval") plt.legend() plt.grid() plt.show() ``` By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
This commit is contained in:
parent
ac6ee36ace
commit
d2eef92009
@ -169,6 +169,7 @@ class MeanScaleUniformBins(ChronosTokenizer):
|
||||
def _input_transform(
|
||||
self, context: torch.Tensor, scale: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
context = context.to(dtype=torch.float32)
|
||||
attention_mask = ~torch.isnan(context)
|
||||
|
||||
if scale is None:
|
||||
@ -373,7 +374,7 @@ def left_pad_and_stack_1D(tensors: List[torch.Tensor]) -> torch.Tensor:
|
||||
size=(max_len - len(c),), fill_value=torch.nan, device=c.device
|
||||
)
|
||||
padded.append(torch.concat((padding, c), dim=-1))
|
||||
return torch.stack(padded)
|
||||
return torch.stack(padded).to(tensors[0])
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -506,6 +507,9 @@ class ChronosPipeline:
|
||||
raise ValueError(msg)
|
||||
warnings.warn(msg)
|
||||
|
||||
input_dtype = context_tensor.dtype
|
||||
input_device = context_tensor.device
|
||||
|
||||
predictions = []
|
||||
remaining = prediction_length
|
||||
|
||||
@ -536,7 +540,7 @@ class ChronosPipeline:
|
||||
[context_tensor, prediction.median(dim=1).values], dim=-1
|
||||
)
|
||||
|
||||
return torch.cat(predictions, dim=-1)
|
||||
return torch.cat(predictions, dim=-1).to(dtype=input_dtype, device=input_device)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
|
@ -157,24 +157,26 @@ def test_tokenizer_random_data(use_eos_token: bool):
|
||||
assert samples.shape == (2, 10, 4)
|
||||
|
||||
|
||||
def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None:
|
||||
assert isinstance(samples, torch.Tensor)
|
||||
assert samples.shape == shape
|
||||
def validate_tensor(a: torch.Tensor, shape: Tuple[int, ...], dtype) -> None:
|
||||
assert isinstance(a, torch.Tensor)
|
||||
assert a.shape == shape
|
||||
assert a.dtype == dtype
|
||||
|
||||
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
|
||||
def test_pipeline_predict(torch_dtype: str):
|
||||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
|
||||
def test_pipeline_predict(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=torch_dtype,
|
||||
torch_dtype=model_dtype,
|
||||
)
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
|
||||
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
samples = pipeline.predict(context, num_samples=12, prediction_length=3)
|
||||
validate_tensor(samples, (4, 12, 3))
|
||||
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(context, num_samples=7, prediction_length=65)
|
||||
@ -182,12 +184,12 @@ def test_pipeline_predict(torch_dtype: str):
|
||||
samples = pipeline.predict(
|
||||
context, num_samples=7, prediction_length=65, limit_prediction_length=False
|
||||
)
|
||||
validate_tensor(samples, (4, 7, 65))
|
||||
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
samples = pipeline.predict(list(context), num_samples=12, prediction_length=3)
|
||||
validate_tensor(samples, (4, 12, 3))
|
||||
validate_tensor(samples, shape=(4, 12, 3), dtype=input_dtype)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(list(context), num_samples=7, prediction_length=65)
|
||||
@ -198,12 +200,12 @@ def test_pipeline_predict(torch_dtype: str):
|
||||
prediction_length=65,
|
||||
limit_prediction_length=False,
|
||||
)
|
||||
validate_tensor(samples, (4, 7, 65))
|
||||
validate_tensor(samples, shape=(4, 7, 65), dtype=input_dtype)
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
|
||||
samples = pipeline.predict(context[0, ...], num_samples=12, prediction_length=3)
|
||||
validate_tensor(samples, (1, 12, 3))
|
||||
validate_tensor(samples, shape=(1, 12, 3), dtype=input_dtype)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
samples = pipeline.predict(context[0, ...], num_samples=7, prediction_length=65)
|
||||
@ -214,36 +216,43 @@ def test_pipeline_predict(torch_dtype: str):
|
||||
prediction_length=65,
|
||||
limit_prediction_length=False,
|
||||
)
|
||||
validate_tensor(samples, (1, 7, 65))
|
||||
validate_tensor(samples, shape=(1, 7, 65), dtype=input_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.float32, torch.bfloat16])
|
||||
def test_pipeline_embed(torch_dtype: str):
|
||||
@pytest.mark.parametrize("model_dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.bfloat16])
|
||||
def test_pipeline_embed(model_dtype: torch.dtype, input_dtype: torch.dtype):
|
||||
pipeline = ChronosPipeline.from_pretrained(
|
||||
Path(__file__).parent / "dummy-chronos-model",
|
||||
device_map="cpu",
|
||||
torch_dtype=torch_dtype,
|
||||
torch_dtype=model_dtype,
|
||||
)
|
||||
d_model = pipeline.model.model.config.d_model
|
||||
context = 10 * torch.rand(size=(4, 16)) + 10
|
||||
context = 10 * torch.rand(size=(4, 16), dtype=input_dtype) + 10
|
||||
expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0)
|
||||
|
||||
# input: tensor of shape (batch_size, context_length)
|
||||
|
||||
embedding, scale = pipeline.embed(context)
|
||||
validate_tensor(embedding, (4, expected_embed_length, d_model))
|
||||
validate_tensor(scale, (4,))
|
||||
validate_tensor(
|
||||
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
|
||||
)
|
||||
validate_tensor(scale, shape=(4,), dtype=torch.float32)
|
||||
|
||||
# input: batch_size-long list of tensors of shape (context_length,)
|
||||
|
||||
embedding, scale = pipeline.embed(list(context))
|
||||
validate_tensor(embedding, (4, expected_embed_length, d_model))
|
||||
validate_tensor(scale, (4,))
|
||||
validate_tensor(
|
||||
embedding, shape=(4, expected_embed_length, d_model), dtype=model_dtype
|
||||
)
|
||||
validate_tensor(scale, shape=(4,), dtype=torch.float32)
|
||||
|
||||
# input: tensor of shape (context_length,)
|
||||
embedding, scale = pipeline.embed(context[0, ...])
|
||||
validate_tensor(embedding, (1, expected_embed_length, d_model))
|
||||
validate_tensor(scale, (1,))
|
||||
validate_tensor(
|
||||
embedding, shape=(1, expected_embed_length, d_model), dtype=model_dtype
|
||||
)
|
||||
validate_tensor(scale, shape=(1,), dtype=torch.float32)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_tokens", [10, 1000, 10000])
|
||||
|
Loading…
Reference in New Issue
Block a user