# Reference: https://github.com/ctxj/Time-Series-Transformer-Pytorch/tree/main
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import copy
import math
import time
import matplotlib.pyplot as plt
from torchinfo import summary
from torch.utils.data import Dataset, DataLoader
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from m0_position import PositionalEncoding


class Transformer(nn.Module):
    def __init__(self, feature_size=200, num_layers=2, dropout=0.1):
        # feature_size equals the embedding dimension (d_model)
        super().__init__()
        self.model_type = 'Transformer'
        self.src_mask = None

        self.pos_encoder = PositionalEncoding(feature_size)

        # Multi-head self-attention with nhead heads
        # per-head dims: d_key = d_query = d_value = d_model // nhead
        self.encoder_layer = TransformerEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)
        # Stack num_layers encoder layers
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        # For simple time-series prediction, the "decoder" is just a fully connected layer
        self.decoder = nn.Linear(feature_size, 1)
        self._init_weights()

    def _init_weights(self):
        init_range = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-init_range, init_range)

    def forward(self, src):
        # src shape: (seq_len, batch_size, feature_size); the causal mask is
        # cached and rebuilt only when the sequence length changes
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output

    def _generate_square_subsequent_mask(self, size):
        mask = torch.tril(torch.ones(size, size) == 1)  # Lower-triangular boolean matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf'))  # Future positions -> -inf (masked out)
        mask = mask.masked_fill(mask == 1, float(0.0))     # Allowed positions -> 0
        return mask
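

# --- Minimal sanity-check sketch (illustrative, not part of the original module) ---
# Assumes m0_position.PositionalEncoding accepts a (seq_len, batch, d_model)
# tensor and returns one of the same shape; the dummy shapes below are
# arbitrary example values.
if __name__ == "__main__":
    seq_len, batch_size, feature_size = 50, 16, 200
    model = Transformer(feature_size=feature_size, num_layers=2, dropout=0.1)

    # Dummy input in the (seq_len, batch_size, feature_size) layout expected by
    # TransformerEncoder with batch_first=False (the default)
    dummy_src = torch.randn(seq_len, batch_size, feature_size)

    with torch.no_grad():
        out = model(dummy_src)

    # One scalar prediction per time step per sequence
    print(out.shape)  # expected: torch.Size([50, 16, 1])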