# mirror of https://github.com/dupenf/stock-transformer.git
# synced 2024-11-25 16:08:34 +08:00
# Reference: https://github.com/ctxj/Time-Series-Transformer-Pytorch/tree/main

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import copy
import math
import time
import matplotlib.pyplot as plt

from torchinfo import summary
from torch.utils.data import Dataset, DataLoader
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from m0_position import PositionalEncoding

class Transformer(nn.Module):
    """Encoder-only Transformer for univariate time-series prediction.

    A causal (lower-triangular) additive attention mask restricts each
    position to attend only to itself and earlier positions; a single
    linear layer then maps every encoded position to a scalar output.
    """

    def __init__(self, feature_size=200, num_layers=2, dropout=0.1):
        """
        Args:
            feature_size: embedding dimension (d_model). Must be divisible
                by the fixed head count (nhead=10).
            num_layers: number of stacked TransformerEncoder layers.
            dropout: dropout probability inside each encoder layer.
        """
        super().__init__()
        self.model_type = 'Transformer'

        # Causal mask is built lazily in forward() and cached between calls.
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(feature_size)

        # Apply nhead multi-head attention; d_key = d_query = d_value = d_model // nhead.
        self.encoder_layer = TransformerEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)

        # Stack num_layers identical encoder layers.
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        # For simple time-series prediction the "decoder" is just a linear head.
        self.decoder = nn.Linear(feature_size, 1)
        self._init_weights()

    def _init_weights(self):
        # Small uniform init for the output head; zero bias.
        init_range = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-init_range, init_range)

    def forward(self, src):
        """Run the causal encoder over `src` and project to one value per position.

        Args:
            src: input sequence tensor; dim 0 is the sequence axis (its
                length sizes the causal mask). Presumably (seq, batch,
                feature_size) per nn.TransformerEncoder's default layout —
                TODO confirm against the caller.

        Returns:
            Tensor with the same leading dims as `src` and last dim 1.
        """
        # Rebuild the cached mask when the sequence length OR the device
        # changes. BUGFIX: the original checked only the length, so a model
        # moved to another device (e.g. .to('cuda')) with a same-length
        # cached mask would fail with a device-mismatch error.
        if (self.src_mask is None
                or self.src_mask.size(0) != len(src)
                or self.src_mask.device != src.device):
            self.src_mask = self._generate_square_subsequent_mask(len(src)).to(src.device)

        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output

    def _generate_square_subsequent_mask(self, size):
        """Return a (size, size) additive causal mask.

        Entries are 0.0 where attention is allowed (column <= row) and
        -inf where it is blocked (column > row), ready to be added to the
        attention logits.
        """
        mask = torch.tril(torch.ones(size, size))         # 1.0 on/below diagonal
        mask = mask.masked_fill(mask == 0, float('-inf')) # blocked positions
        mask = mask.masked_fill(mask == 1, 0.0)           # allowed positions
        return mask