# stock-transformer/d2_datasets.py
# 2024-08-05 14:15:06 +08:00
# (56 lines, 1.7 KiB, Python — file-listing metadata from the original paste,
#  commented out so the module is importable)
# Reference: https://github.com/ctxj/Time-Series-Transformer-Pytorch/tree/main
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import copy
import math
import time
import matplotlib.pyplot as plt
# from torchinfo import summary
from torchsummary import summary
from a0_config import device, output_window,input_window,batch_size,USE_CUDA
def create_inout_sequences(input_data, input_window, horizon=None):
    """Slide a window over a 1-D series and build (input, label) pairs.

    Each pair is (input_data[i : i+input_window],
                  input_data[i+horizon : i+input_window+horizon]) — the label
    is the same-length window shifted forward by `horizon` steps.

    Args:
        input_data: 1-D sequence (list / ndarray / Series) of numeric values.
        input_window: length of each input (and label) window.
        horizon: forecast shift between input and label windows; defaults to
            the module-level `output_window` from a0_config.

    Returns:
        torch.FloatTensor of shape (num_pairs, 2, input_window); empty when
        the series is too short for a single full pair.
    """
    if horizon is None:
        horizon = output_window  # module-level default from a0_config
    inout_seq = []
    L = len(input_data)
    # Stop early enough that the label window never runs past the end of the
    # data: the original `range(L - input_window)` produced ragged (shorter)
    # labels for the last horizon-1 indices whenever horizon > 1, which
    # np.array cannot stack into a rectangular array.
    for i in range(L - input_window - horizon + 1):
        train_seq = input_data[i:i + input_window]
        train_label = input_data[i + horizon: i + input_window + horizon]
        inout_seq.append((train_seq, train_label))
    return torch.FloatTensor(np.array(inout_seq))
def get_data(data_raw, split):
    """Split a raw series into train/test halves, preprocess, and window it.

    Args:
        data_raw: 1-D series supporting `.cumsum()` (ndarray / pandas Series).
        split: train fraction in [0, 1]; the cut index is round(split * len).

    Returns:
        (train_sequence, test_sequence): windowed FloatTensors on `device`,
        each built with `create_inout_sequences` and trimmed by
        `output_window` pairs at the tail.

    NOTE(review): only the training half is scaled by 2 after the cumulative
    sum — the test half is left unscaled. Confirm this train/test
    preprocessing asymmetry is intentional (it matches the referenced
    tutorial, but is unusual).
    """
    cut = round(split * len(data_raw))
    raw_train, raw_test = data_raw[:cut], data_raw[cut:]

    series_train = 2 * raw_train.cumsum()  # training data scaling
    series_test = raw_test.cumsum()

    train_sequence = create_inout_sequences(series_train, input_window)[:-output_window]
    test_sequence = create_inout_sequences(series_test, input_window)[:-output_window]
    return train_sequence.to(device), test_sequence.to(device)
def get_batch(source, i, batch_size):
    """Slice one mini-batch out of a windowed sequence tensor.

    Args:
        source: output of `create_inout_sequences` — indexable so that
            source[k][0] is an input window and source[k][1] its label,
            each of length `input_window`.
        i: start offset into `source`.
        batch_size: maximum number of pairs in the batch (clipped near the
            end of `source`).

    Returns:
        (data_in, target): each of shape (input_window, effective_batch, 1),
        i.e. time-major with a trailing feature dimension, as the
        transformer model expects.
    """
    effective = min(batch_size, len(source) - 1 - i)
    window = source[i:i + effective]

    # Stack the batch along dim 0, then chunk along the time axis and
    # re-stack so that time becomes the leading dimension.
    features = torch.stack([pair[0] for pair in window])
    labels = torch.stack([pair[1] for pair in window])
    data_in = torch.stack(features.chunk(input_window, 1))
    target = torch.stack(labels.chunk(input_window, 1))
    return data_in, target