Mirror of https://github.com/dupenf/stock-lstm.git
Synced: 2024-11-25 16:22:36 +08:00
File size: 20 lines, 438 B (Python)
from torch import nn
|
|
from torchsummary import summary
|
|
|
|
|
|
class NeuralNetwork(nn.Module):
    """Single-layer LSTM followed by a linear projection back to the input width.

    The network runs the input sequence through an LSTM and projects the
    LSTM's final hidden state to ``num_feature`` outputs.

    Args:
        num_feature: Number of features per time step. Used both as the LSTM
            input size and as the output width of the linear head.
        hidden_size: Width of the LSTM hidden state (and input width of the
            linear head). Defaults to 64, the value previously hard-coded.
    """

    def __init__(self, num_feature, hidden_size=64):
        super().__init__()
        # batch_first=True: batched input is laid out as (batch, seq_len, num_feature).
        self.lstm = nn.LSTM(num_feature, hidden_size, batch_first=True)
        # Projects the hidden state back to the feature width; its input size
        # must match the LSTM hidden size, hence the shared parameter.
        self.fc = nn.Linear(hidden_size, num_feature)

    def forward(self, x):
        """Return the linear projection of the LSTM's final hidden state.

        Args:
            x: Input tensor of shape (batch, seq_len, num_feature).

        Returns:
            Tensor of shape (1, batch, num_feature). ``batch_first`` does not
            apply to the LSTM's returned hidden state, so its leading
            num_layers dimension (1 here) is kept; callers wanting
            (batch, num_feature) should squeeze dim 0.
        """
        # Only the final hidden state is used; per-step outputs and the cell
        # state are discarded.
        _, (hidden, _) = self.lstm(x)
        return self.fc(hidden)
|
|
|
|
|
|
|
|
|
|
|