stock-transformer/m3_train.py
2024-08-05 14:15:06 +08:00

75 lines
2.5 KiB
Python

# Reference: https://github.com/ctxj/Time-Series-Transformer-Pytorch/tree/main
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import copy
import math
import time
import matplotlib.pyplot as plt
from d2_datasets import get_batch
# def train(model,train_data, optimizer,scheduler, batch_size,):
# model.train() # Turn on the evaluation mode
# total_loss = 0.
# start_time = time.time()
# for batch, i in enumerate(range(0, len(train_data) - 1, batch_size)):
# data, targets = get_batch(train_data, i, batch_size)
# optimizer.zero_grad()
# output = model(data)
# loss = criterion(output, targets)
# loss.backward()
# nn.utils.clip_grad_norm_(model.parameters(), 0.7)
# optimizer.step()
# total_loss = total_loss + loss.item()
# log_interval = int(len(train_data) / batch_size / 5)
# if batch % log_interval == 0 and batch > 0:
# cur_loss = total_loss / log_interval
# elapsed = time.time() - start_time
# print('| epoch {:3d} | {:5d}/{:5d} batches | '
# 'lr {:02.10f} | {:5.2f} ms | '
# 'loss {:5.7f}'.format(
# epoch, batch, len(train_data) // batch_size, scheduler.get_lr()[0],
# elapsed * 1000 / log_interval,
# cur_loss))
# total_loss = 0
# start_time = time.time()
def evaluate(model, data_source,criterion):
model.eval() # Turn on the evaluation mode
total_loss = 0.
eval_batch_size = 1000
with torch.no_grad():
for i in range(0, len(data_source) - 1, eval_batch_size):
data, targets = get_batch(data_source, i, eval_batch_size)
output = model(data)
total_loss = total_loss + len(data[0]) * criterion(output, targets).cpu().item()
return total_loss / len(data_source)
def predict(model, sequences):
start_timer = time.time()
model.eval()
predicted_seq = torch.Tensor(0)
real_seq = torch.Tensor(0)
with torch.no_grad():
for i in range(0, len(sequences) - 1):
data, target = get_batch(sequences, i, 1)
output = model(data)
predicted_seq = torch.cat((predicted_seq, output[-1].view(-1).cpu()), 0)
real_seq = torch.cat((real_seq, target[-1].view(-1).cpu()), 0)
timed = time.time() - start_timer
print(f"{timed} sec")
return predicted_seq, real_seq