python怎么把线性回归训练模型可视化
时间: 2025-04-18 22:32:49 浏览: 22
### Python 线性回归模型可视化的方法
为了更好地理解线性回归的结果,可以利用 `matplotlib` 和其他辅助库来绘制图形。下面是一个完整的例子,展示了一元线性回归以及多元线性回归的可视化过程。
#### 一元线性回归可视化示例
假设有一个简单的数据集,其中包含两个变量:广告投入量 (`X`) 和产品销售量 (`Y`)。可以通过如下方式加载并预处理这些数据:
```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
# 加载数据
data = {'Advertising': [1, 2, 3, 4, 5], 'Sales': [1, 1.2, 1.8, 2.5, 3]}
df = pd.DataFrame(data)
# 划分特征和标签
X = df[['Advertising']]
y = df['Sales']
# 创建训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 训练线性回归模型
model = LinearRegression()
model.fit(X_train, y_train)
# 预测
predictions = model.predict(X)
# 绘制原始数据点及拟合直线
plt.scatter(df['Advertising'], df['Sales'], color='blue', label="Data points") # 散点图显示实际观测值
plt.plot(df['Advertising'], predictions, color='red', linewidth=2, label="Fitted line") # 连接预测值得到最佳拟合线
plt.xlabel('Advertising')
plt.ylabel('Sales')
plt.title('Linear Regression of Advertising vs Sales')
plt.legend()
plt.show() # 显示图像
```
这段代码展示了如何读取数据、划分训练集与测试集、建立线性回归模型,并最终通过散点图加拟合曲线的形式展现出来[^1]。
#### 多元线性回归可视化示例
对于多维的数据来说,虽然无法像一维那样直观地画出一条直线,但是仍然能够借助三维空间中的曲面或者其他形式来进行表示。这里给出一个多因素影响下的房价预测案例作为说明:
```python
from mpl_toolkits.mplot3d import Axes3D
# 构建更复杂的数据结构,比如考虑多个自变量的影响
data_multivariate = {
"Size": [852, 976, 1080, 1200],
"Bedrooms": [2, 3, 3, 4],
"Price": [160000, 220000, 240000, 300000]
}
df_multi = pd.DataFrame(data_multivariate)
# 定义输入输出向量
X_multi = df_multi[["Size", "Bedrooms"]]
y_multi = df_multi["Price"]
# 使用相同的流程训练模型...
multi_model = LinearRegression().fit(X_multi, y_multi)
predicted_prices = multi_model.predict(X_multi)
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')
scatter = ax.scatter(
xs=df_multi['Size'],
ys=df_multi['Bedrooms'],
zs=y_multi,
c='b',
marker='o'
)
trisurf = ax.plot_trisurf(
X_multi['Size'].values.flatten(),
X_multi['Bedrooms'].values.flatten(),
predicted_prices.flatten(),
alpha=.5,
cmap=plt.cm.viridis
)
ax.set_xlabel('House Size (sq ft)')
ax.set_ylabel('Number of Bedrooms')
ax.set_zlabel('Home Price ($)')
plt.show()
```
此部分代码实现了对房屋大小(`Size`)和卧室数量(`Bedrooms`)这两个属性共同作用下价格变化趋势的模拟,并将其以三维立体图的方式呈现给读者。
阅读全文
相关推荐


















