mirror of
https://github.com/dupenf/stock-transformer.git
synced 2024-11-25 16:08:34 +08:00
33 lines
781 B
Python
33 lines
781 B
Python
# Reference: https://github.com/ctxj/Time-Series-Transformer-Pytorch/tree/main
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
import pandas as pd
|
|
import copy
|
|
import math
|
|
import time
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
from m3_train import predict
|
|
from d2_datasets import get_batch, get_data
|
|
from d1_features import log_features
|
|
|
|
|
|
|
|
# Load the trained model, run it over the held-out split, and plot the
# forecast against the real series.

# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code. Only load "saved_weights.pt" when it comes from a trusted source.
# NOTE(review): the loaded object is passed straight to predict() as the
# model, so the checkpoint presumably holds a whole pickled model rather than
# a state_dict -- confirm against the training script.
model = torch.load("saved_weights.pt")

log_prices = log_features()
# NOTE(review): 0.9 is presumably the train/test split fraction -- confirm
# against d2_datasets.get_data.
train_data, test_data = get_data(log_prices, 0.9)
predicted_seq, real_seq = predict(model, test_data)

fig2, ax2 = plt.subplots(1, 1)

# Attach each label to its own line. The previous positional legend list
# (['Actual', 'Forecast']) was applied in plotting order, which labelled the
# predicted series "Actual" and the real series "Forecast".
ax2.plot(predicted_seq, color='red', alpha=0.7, label='Forecast')
ax2.plot(real_seq, color='blue', linewidth=0.7, label='Actual')
ax2.legend()
ax2.set_xlabel('Time Steps')
ax2.set_ylabel('Log Prices')

fig2.tight_layout()
# NOTE(review): nothing in the visible script shows or saves the figure; when
# run outside a notebook, a plt.show() or fig2.savefig(...) call is needed to
# actually see the plot.
|
|
|