Wasserstein 距离
时间: 2025-06-21 14:27:18 浏览: 15
### Wasserstein距离的概念
Wasserstein距离,也称为Earth Mover's Distance (EMD),衡量两个概率分布之间的差异。具体来说,在最优传输理论框架下,该距离计算将一个分布转换成另一个所需的最小工作量。对于任意的概率测度 \( P \) 和 \( Q \),以及成本函数 \( c(x,y) \),Wasserstein 距离定义如下:
\[ W(P,Q)=\inf_{\gamma\in\Pi(P,Q)}\int_{X\times Y}c(x,y)\mathrm{d}\gamma(x,y)[^1]
这里 \( \Pi(P,Q) \) 表示所有边缘分别为 \( P \) 和 \( Q \) 的联合分布集合。
### 应用于机器学习中的场景
#### 生成对抗网络(GAN)
在生成对抗网络中,传统方法采用Jensen-Shannon散度来评估生成样本的质量;但是这种方法可能导致训练过程不稳定或梯度消失等问题。为此提出了使用Wasserstein距离作为替代方案之一,即所谓的WGAN模型[^2]。通过引入Lipschitz约束条件下的批评家(critic),可以更稳定有效地优化目标函数并改善最终效果。
#### 数据增强与迁移学习
除了图像生成领域外,Wasserstein距离还在其他方面有着广泛的应用价值。例如,在数据集之间存在较大差距的情况下实施迁移学习时,可以通过最小化源域和目标域间特征表示上的Wasserstein距离实现更好的泛化性能[^3]。
```python
import torch
from torch import nn
class Critic(nn.Module):
def __init__(self, input_dim=784, hidden_dim=512):
super(Critic, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, x):
return self.model(x)
def wasserstein_loss(real_scores, fake_scores):
"""Compute the Wasserstein loss."""
return -(torch.mean(real_scores) - torch.mean(fake_scores))
```
阅读全文
相关推荐



















