mmd loss
时间: 2025-05-10 13:31:23 浏览: 13
### MMD Loss 的定义
Maximum Mean Discrepancy (MMD) 是一种用于衡量两个分布之间差异的度量方法。它通过计算特征空间中样本均值向量之间的距离来评估两组数据是否来自相同的分布[^3]。
具体来说,MMD 定义如下:
给定两个概率分布 \( P \) 和 \( Q \),以及一个再生核希尔伯特空间(Reproducing Kernel Hilbert Space, RKHS),其对应的正定核函数为 \( k(x, y) \),则 MMD 可表示为:
\[
\text{MMD}^2(P, Q) = \| \mu_P - \mu_Q \|^2_{\mathcal{H}}
\]
其中,\( \mu_P \) 和 \( \mu_Q \) 分别是分布在 RKHS 中的均值嵌入。
展开上述表达式可以得到更具体的公式形式:
\[
\text{MMD}^2(P, Q) = \mathbb{E}_{x,x' \sim P}[k(x, x')] + \mathbb{E}_{y,y' \sim Q}[k(y, y')] - 2\mathbb{E}_{x \sim P, y \sim Q}[k(x, y)]
\]
对于有限样本集 \( X=\{x_1,\dots,x_m\} \) 和 \( Y=\{y_1,\dots,y_n\} \),可以通过经验估计近似该值:
\[
\widehat{\text{MMD}}^2(X,Y) = \frac{1}{m(m-1)}\sum_{i=1}^{m}\sum_{j\neq i}^{m}k(x_i, x_j) + \frac{1}{n(n-1)}\sum_{i=1}^{n}\sum_{j\neq i}^{n}k(y_i, y_j) - \frac{2}{mn}\sum_{i=1}^{m}\sum_{j=1}^{n}k(x_i, y_j)
\][^4]
---
### 计算方法
在实际应用中,通常采用高斯径向基函数(Gaussian RBF kernel)作为核函数 \( k(x, y) \),即:
\[
k(x, y) = \exp(-\|x-y\|^2 / (2\sigma^2))
\]
这里 \( \sigma \) 表示带宽参数,需根据具体情况调整。
以下是 Python 实现的一个简单例子:
```python
import numpy as np
def rbf_kernel(x, y, sigma=1.0):
""" 高斯径向基函数 """
pairwise_distances_sq = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
return np.exp(-pairwise_distances_sq / (2 * sigma ** 2))
def compute_mmd(x, y, kernel_func=rbf_kernel, sigma=1.0):
""" 计算 MMD 距离 """
m, n = len(x), len(y)
K_xx = kernel_func(x, x, sigma=sigma)
K_yy = kernel_func(y, y, sigma=sigma)
K_xy = kernel_func(x, y, sigma=sigma)
# 去除对角线项的影响
K_xx -= np.diag(np.diagonal(K_xx))
K_yy -= np.diag(np.diagonal(K_yy))
mmd_squared = (
(np.sum(K_xx) / (m * (m - 1)) +
np.sum(K_yy) / (n * (n - 1)) -
2 * np.mean(K_xy)))
)
return max(0, mmd_squared)[^5]
```
---
### 应用场景
#### 1. 数据域适应(Domain Adaptation)
当源域和目标域的数据分布不一致时,MMD 可以用来最小化两者间的差距,从而提高模型泛化能力。例如,在迁移学习领域,常用的方法之一就是优化以下损失函数:
\[
L(\theta) = L_\text{task}(X_S, Y_S; \theta) + \lambda \cdot \text{MMD}^2(f(X_S;\theta), f(X_T;\theta))
\]
此处 \( L_\text{task} \) 表示任务特定的损失函数;\( X_S, Y_S \) 和 \( X_T \) 分别代表源域输入及其标签、目标域输入[^6]。
#### 2. 对抗生成网络中的判别器替代方案
在 GANs 中,传统判别器可能面临模式崩溃等问题。而利用 MMD 来代替判别器能够缓解这些问题并提升训练稳定性[^7]。
#### 3. 测试分布一致性
除了机器学习之外,MMD 还广泛应用于统计学领域,比如检验两批独立采样是否来自于同一未知分布等情况下的假设验证问题。
---
阅读全文
相关推荐






