文章目录
0. BasicSeq2Seq
先从入口看起,BasicSeq2Seq
类继承的是Seq2SeqModel
类,下面是关于解码的部分。可以看到训练和预测阶段的解码方式是不同的。
@templatemethod("decode")
def decode(self, encoder_output, features, labels):
decoder = self._create_decoder(encoder_output, features, labels)
if self.use_beam_search:
decoder = self._get_beam_search_decoder(decoder)
bridge = self._create_bridge(
encoder_outputs=encoder_output,
decoder_state_size=decoder.cell.state_size)
if self.mode == tf.contrib.learn.ModeKeys.INFER:
return self._decode_infer(decoder, bridge, encoder_output, features,
labels)
else:
return self._decode_train(decoder, bridge, encoder_output, features,
labels)
了解了上面这个函数之后,我们接下来会从两方面继续介绍,一个当然是我们这篇文章要介绍的BeamSearchDecoder
了,它通过_get_beam_search_decoder
返回;另一个则是bridge
,因为这个变量在论文中并没有体现,我们就先来研究一下他是什么吧。
1.Bridge类
这个我是在代码中看到的,论文中并没有。
bridge
定义了信息在编码器、解码器之间是如何传递的,所以在编码器和解码器之间是有很多bridge
链接的。
比如,encoder
之后的是一个 [ b a t c h , m ] [batch, m] [batch,m]的向量 V e V_e Ve,而decoder
却需要一个[batch size, n]
的输入向量 V d V_d Vd, m m m和 n n n是可以不一样的。这时就需要bridge
类通过不同的逻辑,将 V e V_e Ve转化为 V d V_d Vd.
来看一下基类的实现:
@six.add_metaclass(abc.ABCMeta)
class Bridge(Configurable):
"""一个抽象类,定义信息如何在解码器编码器之间传输。
Args:
encoder_outputs: A namedtuple that corresponds to the the encoder outputs.
decoder_state_size: An integer or tuple of integers defining the
state size of the decoder.
"""
def __init__(self, encoder_outputs, decoder_state_size, params, mode):
Configurable.__init__(self, params, mode)
self.encoder_outputs = encoder_outputs
self.decoder_state_size = decoder_state_size
self.batch_size = tf.shape(
nest.flatten(self.encoder_outputs.final_state)[0])[0]
def __call__(self):
<