from torch import nn
from torchsummary import summary  # NOTE(review): not referenced in the code shown here


class NeuralNetwork(nn.Module):
    """Single-layer LSTM followed by a linear projection back to ``num_feature``.

    The LSTM's final hidden state is passed through a fully connected layer
    whose output width equals the input feature count.
    """

    def __init__(self, num_feature: int, hidden_size: int = 64):
        """Build the LSTM and output layers.

        Args:
            num_feature: Number of input features per time step; also the
                width of the model output.
            hidden_size: Number of LSTM hidden units. Defaults to 64, the
                value that was previously hard-coded.
        """
        super().__init__()
        # batch_first=True: batched input is laid out as (batch, seq, feature).
        self.lstm = nn.LSTM(num_feature, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_feature)

    def forward(self, x):
        """Run the LSTM over ``x`` and project its final hidden state.

        Args:
            x: Batched input sequence of shape (batch, seq_len, num_feature).

        Returns:
            ``fc`` applied to the LSTM's final hidden state ``h_n``.
            ``batch_first`` does not apply to ``h_n``, so for batched input
            the result has shape (1, batch, num_feature): a leading
            num_layers dimension of size 1.
        """
        # Only the final hidden state is used; the per-step outputs and the
        # final cell state are discarded.
        _, (hidden, _) = self.lstm(x)
        # NOTE(review): callers wanting (batch, num_feature) probably need
        # hidden[-1] here (or .squeeze(0) on the result). Left unchanged to
        # preserve the existing output shape; confirm against callers.
        return self.fc(hidden)