Python机器学习元学习库higher

higher 是一个用于 元学习(Meta-Learning)高阶导数(Higher-order gradients) 的 Python 库,专为 PyTorch 设计。它扩展了 PyTorch 的自动微分机制,使得在训练过程中可以动态地计算参数的梯度更新,并把这些更新过程纳入到更高阶的梯度计算中。


🧠 一、主要用途

higher 主要用于以下场景:

  • 元学习(Meta-Learning)
    • 比如 MAML(Model-Agnostic Meta-Learning)
  • 内循环优化(Inner-loop optimization)
    • 在少样本学习(Few-shot learning)、超参数优化等任务中
  • 支持高阶导数(Second-order gradients)
    • 例如在强化学习、神经网络架构搜索中使用二阶梯度信息

📦 二、安装方式

import higher

higher需要自己手动安装,否则报错。

No module named ‘higher’。

你可以通过 pip 安装:

pip install higher

或者从 GitHub 安装最新版本:

git clone https://github.com/facebookresearch/higher
cd higher
pip install -e .

🚀 三、核心功能与用法

✅ 1. 替换模型参数以支持梯度更新追踪(Differentiable parameter updates)

通常在 PyTorch 中,当你进行 optimizer.step() 时,参数的更新是“in-place”的,不会被自动微分系统记录。而 higher 提供了一种方法,可以在不修改原始模型的情况下,对参数更新进行跟踪,从而实现高阶导数的计算。

示例:使用 higher 进行一次可微分的参数更新
import torch
import torch.nn as nn
import torch.optim as optim
import higher

# 定义一个简单的模型
model = nn.Linear(2, 1)
opt = optim.SGD(model.parameters(), lr=0.1)

# 使用 higher 包装 optimizer,允许参数更新被记录
with higher.gradient_tracking():
    tr_opt = higher.get_trainable_optim(opt, model.parameters())

# 假设我们有一个简单的损失函数
x = torch.randn(4, 2)
y = torch.randn(4, 1)

# 前向传播
pred = model(x)
loss = ((pred - y) ** 2).mean()

# 计算梯度并更新参数(这些更新将被记录用于高阶导数)
tr_opt.zero_grad()
grads = torch.autograd.grad(loss, list(model.parameters()))
tr_opt.step(grads)

# 现在你可以继续在这个更新后的模型上做新的前向/反向传播
new_pred = tr_opt.current_params() @ x.t()
final_loss = ((new_pred - y.t()) ** 2).mean()

# 反向传播最终损失会回传到初始参数和学习率上
final_loss.backward()

✅ 2. 快速构建 meta-learner(比如 MAML)

higher 特别适合用于实现像 MAML 这样的算法,因为它可以动态生成“更新后的模型”,并且这些更新是可微的。

示例:MAML 风格代码片段
for task in tasks:
    fmodel = copy.deepcopy(model)
    with higher.gradient_tracking():
        tr_opt = higher.get_trainable_optim(optimizer, fmodel.parameters())

    # Inner loop: 在当前任务上做几步梯度更新
    for _ in range(inner_steps):
        pred = fmodel(x_support)
        loss = loss_fn(pred, y_support)
        grads = torch.autograd.grad(loss, list(fmodel.parameters()))
        tr_opt.step(grads)

    # Outer loop: 测试在 query 数据上的表现,并反向传播
    pred_query = fmodel(x_query)
    outer_loss = loss_fn(pred_query, y_query)
    outer_loss.backward()  # 这个 loss 对原始 model.parameters() 的梯度会被正确计算

🧩 四、关键 API 函数

函数描述
higher.get_trainable_optim(optimizer, params)将普通 optimizer 转换为 trainable optimizer,支持梯度更新追踪
tr_opt.step(grads)手动执行一步带梯度的参数更新
tr_opt.current_params()获取当前更新后的参数列表
higher.gradient_tracking()上下文管理器,启用梯度追踪模式

🧪 五、适用领域 & 实际应用

  • Few-shot Learning(少样本学习)
    • 如 MAML、Reptile
  • Continual Learning(持续学习)
  • Hyperparameter Optimization(超参数优化)
  • Neural Architecture Search(神经结构搜索)
  • Reinforcement Learning(强化学习)

🔗 六、相关资源

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

音程

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值