论文阅读笔记 Prototypical Networks for Few-shot Learning

本文介绍了原型网络,一种解决小样本学习问题的模型,尤其适用于只有少量样例的新类别。该网络通过学习将数据转换到特征空间,使样本聚集在每个类别的单一原型周围进行分类。与匹配网络相比,原型网络更简单,且在多个基准任务上表现出色。此外,作者还探讨了其在零样本学习的应用,通过元数据转换为特征向量作为类别原型进行分类。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

小样本学习的原型网络

论文原文链接:https://arxiv.org/abs/1703.05175

摘要

作者提出了一种小样本分类问题的原型网络,在这种网络中,分类器必须推广到训练集中没有的新类别,每个新类别只有少量样例。该原型网络学习一个度量空间,通过计算每个类的原型表示的距离进行分类。与最近的小样本学习方法相比,该方法反映出更简单的归纳偏好,这有益于这种有限数据的状况,因此取得了出色的效果。作者给出的分析表明,某些简单的设计决定能够产生实质的改进,超过了最近的方法,包括复杂的架构选择和元学习。作者把该原型网络进一步扩展到零样本学习,在 CU-Birds 数据集上取得了最高水准的结果。

1 引言

在小样本分类[20,16,13]任务中,分类器必须适应训练中没见过的新类别,每个类别只有少量样例。朴素的方法,例如,在新数据上再次训练模型,会严重过拟合。虽然这个问题很难,但事实表明人类能够处理,即使是单样例分类,每个类别只有一个样例,也能达到很高的准确率[16]。

在小样本学习领域,最近有两种方法取得了重大进展。Vinyals 等人[29]提出的匹配网络(matching networks),首先在带标签的数据集(支持集)上学到特征向量,然后对该特征向量使用注意力机制来预测无标签数据(查询集)的类别。匹配网络可解释为,一个加权最近邻分类器应用于一个特征空间。尤其是,这个模型在训练期间利用采样得到的小批量样例(称为片段),设计的每个片段是用来模拟小样本任务,这种片段是通过对类别和样例进行二段抽样得到的。片段的使用使得模型训练对于测试环境来说更加准确可靠,因此提高了模型的泛化能力。Ravi 等人[22]进一步发展了片段训练的思想,提出了小样本学习的元学习方法,该方法会训练一个LSTM[9]来产生分类器的更新,给定一个片段,使得该片段很好地泛化到测试集(这句话不太理解)。这里,不是在多个片段上训练一个单独的模型,而是LSTM元学习器学习为每个片段训练一个量身定制的模型。

作者通过解决过拟合的关键问题,来解决小样本学习问题。由于数据非常有限,作者假设分类器应该具有非常简单的归纳偏好。作者提出的原型网络的核心思想是,存在一个特征空间,在该特征空间中,样例簇拥在每个类的单一原型表示周围。为了做到这一点,作者使用神经网络把所有原始数据转换到一个特征空间,然后把支持集样例特征向量的均值作为类别原型,只是通过寻找最近的类别原型来对查询样例进行分类。

作者采用同样的方法来解决零样本学习问题,在零样本学习中,没有带标签的样例,每个类别只有元数据给出的粗略描述。因此作者把这些元数据转换为特征向量,作为每个类别的原型。对查询样例的分类方法和小样本场景一样,都是寻找距离查询样例最近的类别原型。
在这里插入图片描述

图1 小样本和零样本场景中的原型网络。左图:小样本原型是每个类别的支持样例的嵌入特征的均值。右图:零样本原型是每个类别的原数据的嵌入特征。不管在哪种情况下,对查询样例的分类方法都是相同的,即先计算查询样例的嵌入特征与每个类别原型的距离,然后取softmax值进行分类。

本文阐明了小样本学习的原型网络,以及零样本学习的原型网络。绘制了单样本学习中匹配网络的关系(这句话不太理解),分析了模型中使用的距离函数。特别地,当距离用布雷格曼散度(Bregman divergence)计算时,例如,平方欧式距离,为了证明使用类别均值作为原型更合理,作者把该原型网络与聚类联系起来。作者从经验上发现距离的选择至关重要,欧式距离远远好于更常用的余弦相似度。在一些基准任务上,达到了最高性能。原型网络比最近的元学习算法更简单、更有效,这使得原型网络成为吸引人的方法。

2 原型网络

2.1 符号定义

支持集 S = { ( x 1 , y 1 ) , . . . , ( x n , y n ) ∣ x i ∈ R D , y i ∈ { 1 , . . . , k } } S=\{(x_1, y_1), ..., (x_n, y_n) | x_i\in R^D, y_i\in\{1, ..., k\}\} S={ (x1,y1),...,(xn,yn)xiRD,yi{ 1,...,k}} 有 n 个样例,k 个类别,其中 x i x_i xi 是样例的 D 维特征向量, y i y_i yi 是对应的类别标签。

S k S_k Sk 表示类别 k 的带标签样例集。

2.2 模型

原型网络通过嵌入特征函数 f ϕ : R D → R M f_\phi :R^D\rightarrow R^M fϕ:RDRM ϕ \phi ϕ 是可学习的参数)为每个类计算一个 M 维的特征向量 c k ∈ R M c_k\in R^M ckRM 称为原型。每个类的原型是该类别的支持样例的嵌入特征的均值。

c k = 1 ∣ S k ∣ ∑ ( x i , y i ) ∈ S k f ϕ ( x i ) c_k=\frac {1}{|S_k|}\sum_{(x_i,y_i)\in S_k}f_\phi(x_i) ck=Sk1(xi,yi)Skf

### 关于Prototypical Networks for Few-Shot Learning论文 Prototypical Networks 是一种针对小样本学习(Few-Shot Learning)设计的方法,旨在通过构建类别原型来完成新类别的快速识别[^1]。该方法的核心思想是在嵌入空间中为每个类别定义一个原型向量,并利用这些原型向量计算输入样本与各个类别的相似度。 #### 论文下载方式 Prototypical Networks 的原始论文《Prototypical Networks for Few-shot Learning》可以在以下链接找到并下载: - 原始论文地址:[https://arxiv.org/abs/1703.05175](https://arxiv.org/abs/1703.05175) 如果无法直接访问上述链接,可以尝试通过其他学术资源网站获取PDF版本,例如Google Scholar 或者 ResearchGate。此外,一些开源项目也提供了对该论文的解读和实现细节说明[^2][^3]。 #### 可视化的实现 对于可视化部分,作者展示了在嵌入空间中的 t-SNE 图像,其中类别原型用黑色表示,而错误分类的字符则用红色高亮显示[^4]。要实现类似的可视化效果,可以通过以下步骤: 1. **提取嵌入特征** 使用训练好的模型对测试数据进行前向传播,得到每张图像对应的嵌入向量。 2. **应用t-SNE降维** 利用 `sklearn.manifold.TSNE` 将高维嵌入向量投影到二维或三维空间以便绘制图形。 3. **绘图展示** 使用 Matplotlib 或 Seaborn 绘制散点图,在图中标记出不同类别的原型位置以及误分类的数据点。 以下是基于PyTorch实现的一个简单代码示例: ```python import torch from sklearn.manifold import TSNE import matplotlib.pyplot as plt def visualize_embeddings(embeddings, labels, prototypes): """ embeddings: 测试样本的嵌入向量 (NxD) labels: 对应的真实标签 (Nx1) prototypes: 类别原型向量 (KxD) """ all_data = torch.cat([embeddings, prototypes], dim=0).cpu().numpy() tsne = TSNE(n_components=2, random_state=42) embedded_data = tsne.fit_transform(all_data) # 提取样本点和原型点的位置 sample_points = embedded_data[:len(embeddings)] prototype_points = embedded_data[len(embeddings):] # 绘制散点图 plt.figure(figsize=(8, 6)) unique_labels = set(labels.cpu().numpy()) colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels))) for i, color in zip(unique_labels, colors): idx = np.where(labels == i)[0] plt.scatter(sample_points[idx, 0], sample_points[idx, 1], c=color, label=f'Class {i}') # 添加原型点标记 plt.scatter(prototype_points[:, 0], prototype_points[:, 1], marker='*', s=200, c='black', label='Prototypes') plt.legend() plt.title('t-SNE Visualization of Embedding Space') plt.show() ``` 此函数接受嵌入向量、真实标签以及类别原型作为输入参数,并生成相应的t-SNE可视化图表。 --- ###
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值