# stock-transformer/m1_transformer.py
# Reference: https://github.com/ctxj/Time-Series-Transformer-Pytorch/tree/main
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import copy
import math
import time
import matplotlib.pyplot as plt
from torchinfo import summary
from torch.utils.data import Dataset, DataLoader
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from m0_position import PositionalEncoding


class Transformer(nn.Module):
    def __init__(self, feature_size=200, num_layers=2, dropout=0.1):
        # feature_size equals the embedding dimension (d_model)
        super().__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(feature_size)
        # Multi-head self-attention with nhead heads; nhead must evenly divide
        # d_model, so d_key = d_query = d_value = d_model // nhead = 20 here
        self.encoder_layer = TransformerEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)
        # Stack num_layers identical encoder layers
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        # For simple time-series prediction, the decoder is just a fully connected layer
        self.decoder = nn.Linear(feature_size, 1)
        self._init_weights()

    def _init_weights(self):
        # Zero the decoder bias and draw its weights uniformly from [-0.1, 0.1]
        init_range = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-init_range, init_range)

    def forward(self, src):
        # Lazily (re)build the causal mask whenever the sequence length changes
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output

    def _generate_square_subsequent_mask(self, size):
        mask = torch.tril(torch.ones(size, size) == 1)  # Lower triangular matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf'))  # Convert zeros to -inf
        mask = mask.masked_fill(mask == 1, float(0.0))  # Convert ones to 0
        return mask
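

# A minimal usage sketch (an addition for illustration, not part of the original
# file). It assumes the PositionalEncoding module from m0_position accepts a
# (seq_len, batch, d_model) tensor, matching the default batch_first=False layout
# of TransformerEncoder; the window length and batch size below are arbitrary.
if __name__ == "__main__":
    seq_len, batch_size, feature_size = 30, 16, 200
    model = Transformer(feature_size=feature_size, num_layers=2, dropout=0.1)
    src = torch.randn(seq_len, batch_size, feature_size)  # dummy input window
    out = model(src)
    # One scalar prediction per time step and per batch element
    print(out.shape)  # expected: torch.Size([30, 16, 1])
    # The causal mask for a length-3 window, for reference:
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])
    print(model._generate_square_subsequent_mask(3))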