y = wx + b 扩展为包含多项式项(如 x²、x³)的形式

线性模型 y = wx + b 扩展为包含多项式项(如 )的形式。这种情况下,你需要修改模型的输入特征或前向传播逻辑。以下是两种实现方法:

方法一:手动扩展输入特征(推荐)

如果你想实现多项式回归(例如 y = w₁x + w₂x² + w₃x³ + b),可以在输入时手动扩展特征维度,然后使用线性层处理这些扩展后的特征。

python

运行

import torch
import torch.nn as nn

class PolynomialModel(nn.Module):
    def __init__(self, degree=3):
        """
        degree: 多项式的最高次数(如3表示包含x、x²、x³)
        """
        super(PolynomialModel, self).__init__()
        self.degree = degree
        # 输入维度为degree(例如3),输出维度为1
        self.linear = nn.Linear(degree, 1)
    
    def forward(self, x):
        """
        x: 输入张量,形状为 [batch_size, 1]
        """
        # 手动扩展x为[x, x², x³, ...]
        batch_size = x.shape[0]
        x_powers = torch.cat([x**i for i in range(1, self.degree+1)], dim=1)
        # x_powers 形状变为 [batch_size, degree]
        
        # 通过线性层计算输出
        return self.linear(x_powers)

# 测试模型
model = PolynomialModel(degree=3)  # 创建3次多项式模型
x = torch.tensor([[2.0]])  # 输入 x=2,形状为 [1, 1]
print(f"输入 x={x.item()}")
print(f"输出 y={model(x).item()}")

方法二:自定义参数(不使用 nn.Linear)

如果你想更明确地控制每个多项式项的权重,可以自定义参数而不使用 nn.Linear

python

运行

import torch
import torch.nn as nn

class CustomPolynomialModel(nn.Module):
    def __init__(self, degree=3):
        super(CustomPolynomialModel, self).__init__()
        self.degree = degree
        # 为每个多项式项创建独立的权重
        self.weights = nn.ParameterList([
            nn.Parameter(torch.randn(1)) for _ in range(degree)
        ])
        self.bias = nn.Parameter(torch.randn(1))
    
    def forward(self, x):
        # 计算 y = w₁x + w₂x² + w₃x³ + ... + b
        output = self.bias
        for i in range(self.degree):
            output += self.weights[i] * (x ** (i+1))
        return output

# 测试模型
model = CustomPolynomialModel(degree=3)
x = torch.tensor([[2.0]])
print(f"输入 x={x.item()}")
print(f"输出 y={model(x).item()}")

两种方法的对比

方法优点缺点适用场景
手动扩展输入利用 PyTorch 内置层,结构清晰需要手动处理特征扩展希望复用线性层优化和初始化逻辑
自定义参数直观控制每个参数需手动实现参数管理需要明确解释每个多项式项的权重

可视化:多项式次数对模型的影响

不同次数的多项式可以拟合不同复杂度的曲线:

  • 1 次多项式:y = w₁x + b(直线)
  • 2 次多项式:y = w₁x + w₂x² + b(抛物线)
  • 3 次多项式:y = w₁x + w₂x² + w₃x³ + b(S 型曲线)

通过调整 degree 参数,你可以控制模型的复杂度。

实际使用建议

  1. 数据预处理:在扩展特征时,建议对输入进行归一化(如 x = (x - mean) / std),避免因  导致数值范围差异过大。
  2. 过拟合风险:高次多项式容易过拟合,建议配合正则化(如 L2 正则化)或早停策略使用。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值