Tensorflow 的NCE-Loss的实现和word2vec
这两天因为实现mxnet的nce-loss,因此研究了一下tensorflow的nce-loss的实现。所以总结一下。
先看看tensorflow的nce-loss的API:
def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
num_true=1,
sampled_values=None,
remove_accidental_hits=False,
partition_strategy="mod",
name="nce_loss")
假设nce_loss之前的输入数据是K维的,一共有N个类,那么
- weight.shape = (N, K)
- bias.shape = (N)
- inputs.shape = (batch_size, K)
- labels.shape = (batch_size, num_true)
- num_true : 实际的正样本个数
- num_sampled: 采样出多少个负样本
- num_classes = N
- sampled_values: 采样出的负样本,如果是None,就会用不同的sampl