股价预测pytorch
时间: 2025-05-31 13:32:27 浏览: 14
### 使用 PyTorch 实现股价预测模型
为了实现一个基于 PyTorch 的 LSTM 股价预测模型,以下是详细的说明和代码示例:
#### 数据预处理
在构建模型之前,需要对时间序列数据进行标准化和分割。这一步骤对于提升模型性能至关重要。
```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler
def preprocess_data(data, seq_length):
scaler = MinMaxScaler(feature_range=(-1, 1))
data_scaled = scaler.fit_transform(data.reshape(-1, 1))
X, y = [], []
for i in range(len(data_scaled) - seq_length):
X.append(data_scaled[i:i + seq_length])
y.append(data_scaled[i + seq_length])
X = np.array(X).reshape(-1, seq_length, 1)
y = np.array(y)
return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32), scaler
```
此函数将输入的时间序列数据转换为适合 LSTM 输入的形式,并对其进行缩放[^1]。
---
#### 定义 LSTM 模型结构
下面是一个简单的 LSTM 模型定义,该模型接受历史价格作为输入并输出未来的预测值。
```python
import torch.nn as nn
class LSTMModel(nn.Module):
def __init__(self, input_size=1, hidden_layer_size=50, output_size=1):
super(LSTMModel, self).__init__()
self.hidden_layer_size = hidden_layer_size
self.lstm = nn.LSTM(input_size, hidden_layer_size)
self.linear = nn.Linear(hidden_layer_size, output_size)
def forward(self, inputs):
lstm_out, _ = self.lstm(inputs.view(len(inputs), 1, -1))
predictions = self.linear(lstm_out[-1].view(1, -1))
return predictions
```
这段代码定义了一个具有单层 LSTM 和线性回归输出的神经网络架构[^1]。
---
#### 训练过程
训练阶段涉及损失计算、反向传播以及参数更新。
```python
model = LSTMModel()
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100): # 假设迭代次数为100次
model.train()
optimizer.zero_grad()
outputs = model(X_train)
loss = loss_function(outputs, y_train)
loss.backward()
optimizer.step()
if (epoch + 1) % 10 == 0:
print(f'Epoch {epoch+1}, Loss: {loss.item()}')
```
以上代码展示了如何通过梯度下降法最小化均方误差来优化模型权重[^1]。
---
#### 测试与可视化
完成训练后,在测试集上验证模型表现,并绘制实际值与预测值对比图。
```python
with torch.no_grad():
model.eval()
test_predictions = model(X_test)
fig = plt.figure(figsize=(10, 6), dpi=100)
plt.plot(test_predictions.numpy().flatten(), label="Predicted", color="red")
plt.plot(y_test.numpy().flatten(), label="Actual", color="blue")
plt.title("Stock Price Prediction vs Actual")
plt.xlabel("Time Step")
plt.ylabel("Price")
plt.legend()
plt.show()
```
这部分代码用于展示模型在未见过的数据上的泛化能力[^3]。
---
#### 总结
上述流程涵盖了从数据准备到模型部署的关键环节。值得注意的是,虽然当前模型能够捕捉部分趋势特征,但在实际应用中仍需考虑更多因素如市场情绪、宏观经济指标等以进一步增强其鲁棒性和准确性[^2]。
阅读全文
相关推荐















