# stock-lstm/m1_model.py
# 2024-08-04 19:05:34 +08:00
#
# 20 lines
# 438 B
# Python

from torch import nn
from torchsummary import summary
class NeuralNetwork(nn.Module):
    """LSTM followed by a linear head applied to the final hidden state.

    The LSTM consumes a sequence of ``num_feature``-dimensional vectors. The
    linear layer projects the final hidden state of the top LSTM layer back to
    ``num_feature`` values. NOTE(review): presumably this predicts the next
    time step's feature vector (the project is "stock-lstm"); confirm against
    the training code.

    Args:
        num_feature: Size of each input feature vector; also the output size.
        hidden_size: Width of the LSTM hidden state. Defaults to 64, the value
            that was previously hard-coded.
        num_layers: Number of stacked LSTM layers. Defaults to 1, as before.
    """

    def __init__(self, num_feature: int, *, hidden_size: int = 64,
                 num_layers: int = 1) -> None:
        super().__init__()
        self.lstm = nn.LSTM(num_feature, hidden_size, num_layers=num_layers,
                            batch_first=True)
        # The head maps the hidden width back to the input feature width, so
        # both layers must share the same hidden_size.
        self.fc = nn.Linear(hidden_size, num_feature)

    def forward(self, x):
        """Run ``x`` through the LSTM and project the final hidden state.

        Args:
            x: Input tensor. The LSTM is built with ``batch_first=True`` and
                ``input_size=num_feature``, so a batched input is expected as
                ``(batch, seq_len, num_feature)``.

        Returns:
            ``fc`` applied to the top layer's final hidden state. For batched
            input the shape is ``(1, batch, num_feature)``. The leading
            singleton axis comes from ``h_n``, which ``batch_first`` does not
            affect, and is kept for backward compatibility. Callers wanting
            ``(batch, num_feature)`` should ``squeeze(0)``.
        """
        # Per-step outputs and the final cell state are not needed here.
        _, (hidden, _) = self.lstm(x)
        # h_n is (num_layers, batch, hidden_size). Keep only the top layer, but
        # slice instead of indexing so the leading axis of size 1 survives.
        # For num_layers=1 this is exactly the original ``self.fc(hidden)``.
        return self.fc(hidden[-1:])