higher
是一个用于 元学习(Meta-Learning) 和 高阶导数(Higher-order gradients) 的 Python 库,专为 PyTorch 设计。它扩展了 PyTorch 的自动微分机制,使得在训练过程中可以动态地计算参数的梯度更新,并把这些更新过程纳入到更高阶的梯度计算中。
🧠 一、主要用途
higher
主要用于以下场景:
- 元学习(Meta-Learning)
- 比如 MAML(Model-Agnostic Meta-Learning)
- 内循环优化(Inner-loop optimization)
- 在少样本学习(Few-shot learning)、超参数优化等任务中
- 支持高阶导数(Second-order gradients)
- 例如在强化学习、神经网络架构搜索中使用二阶梯度信息
📦 二、安装方式
import higher
higher需要自己手动安装,否则报错。
No module named ‘higher’。
你可以通过 pip 安装:
pip install higher
或者从 GitHub 安装最新版本:
git clone https://github.com/facebookresearch/higher
cd higher
pip install -e .
🚀 三、核心功能与用法
✅ 1. 替换模型参数以支持梯度更新追踪(Differentiable parameter updates)
通常在 PyTorch 中,当你进行 optimizer.step()
时,参数的更新是“in-place”的,不会被自动微分系统记录。而 higher
提供了一种方法,可以在不修改原始模型的情况下,对参数更新进行跟踪,从而实现高阶导数的计算。
示例:使用 higher
进行一次可微分的参数更新
import torch
import torch.nn as nn
import torch.optim as optim
import higher
# 定义一个简单的模型
model = nn.Linear(2, 1)
opt = optim.SGD(model.parameters(), lr=0.1)
# 使用 higher 包装 optimizer,允许参数更新被记录
with higher.gradient_tracking():
tr_opt = higher.get_trainable_optim(opt, model.parameters())
# 假设我们有一个简单的损失函数
x = torch.randn(4, 2)
y = torch.randn(4, 1)
# 前向传播
pred = model(x)
loss = ((pred - y) ** 2).mean()
# 计算梯度并更新参数(这些更新将被记录用于高阶导数)
tr_opt.zero_grad()
grads = torch.autograd.grad(loss, list(model.parameters()))
tr_opt.step(grads)
# 现在你可以继续在这个更新后的模型上做新的前向/反向传播
new_pred = tr_opt.current_params() @ x.t()
final_loss = ((new_pred - y.t()) ** 2).mean()
# 反向传播最终损失会回传到初始参数和学习率上
final_loss.backward()
✅ 2. 快速构建 meta-learner(比如 MAML)
higher
特别适合用于实现像 MAML 这样的算法,因为它可以动态生成“更新后的模型”,并且这些更新是可微的。
示例:MAML 风格代码片段
for task in tasks:
fmodel = copy.deepcopy(model)
with higher.gradient_tracking():
tr_opt = higher.get_trainable_optim(optimizer, fmodel.parameters())
# Inner loop: 在当前任务上做几步梯度更新
for _ in range(inner_steps):
pred = fmodel(x_support)
loss = loss_fn(pred, y_support)
grads = torch.autograd.grad(loss, list(fmodel.parameters()))
tr_opt.step(grads)
# Outer loop: 测试在 query 数据上的表现,并反向传播
pred_query = fmodel(x_query)
outer_loss = loss_fn(pred_query, y_query)
outer_loss.backward() # 这个 loss 对原始 model.parameters() 的梯度会被正确计算
🧩 四、关键 API 函数
函数 | 描述 |
---|---|
higher.get_trainable_optim(optimizer, params) | 将普通 optimizer 转换为 trainable optimizer,支持梯度更新追踪 |
tr_opt.step(grads) | 手动执行一步带梯度的参数更新 |
tr_opt.current_params() | 获取当前更新后的参数列表 |
higher.gradient_tracking() | 上下文管理器,启用梯度追踪模式 |
🧪 五、适用领域 & 实际应用
- Few-shot Learning(少样本学习)
- 如 MAML、Reptile
- Continual Learning(持续学习)
- Hyperparameter Optimization(超参数优化)
- Neural Architecture Search(神经结构搜索)
- Reinforcement Learning(强化学习)