代码阅读-官方tf版BeamSearch

本文详细探讨了TensorFlow官方实现的BeamSearch解码过程,从BasicSeq2Seq模型入手,分析了Bridge类的ZeroBridge、PassThroughBridge和InitialStateBridge三种子类在不同情况下的应用。接着,重点讲解了BeamSearchDecoder类的step函数,包括其内部的长度惩罚因子和如何选择下一个最优beam。最后,总结了BeamSearch解码的原理和在实际编码中的应用技巧。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

官方代码

0. BasicSeq2Seq

先从入口看起,BasicSeq2Seq类继承的是Seq2SeqModel类,下面是关于解码的部分。可以看到训练和预测阶段的解码方式是不同的。

@templatemethod("decode")
  def decode(self, encoder_output, features, labels):
    decoder = self._create_decoder(encoder_output, features, labels)
    if self.use_beam_search:
      decoder = self._get_beam_search_decoder(decoder)

    bridge = self._create_bridge(
        encoder_outputs=encoder_output,
        decoder_state_size=decoder.cell.state_size)
    if self.mode == tf.contrib.learn.ModeKeys.INFER:
      return self._decode_infer(decoder, bridge, encoder_output, features,
                                labels)
    else:
      return self._decode_train(decoder, bridge, encoder_output, features,
                                labels)

了解了上面这个函数之后,我们接下来会从两方面继续介绍,一个当然是我们这篇文章要介绍的BeamSearchDecoder了,它通过_get_beam_search_decoder返回;另一个则是bridge,因为这个变量在论文中并没有体现,我们就先来研究一下他是什么吧。

1.Bridge类

这个我是在代码中看到的,论文中并没有。

bridge定义了信息在编码器、解码器之间是如何传递的,所以在编码器和解码器之间是有很多bridge链接的。

比如,encoder之后的是一个 [ b a t c h , m ] [batch, m] [batch,m]的向量 V e V_e Ve,而decoder却需要一个[batch size, n]的输入向量 V d V_d Vd m m m n n n是可以不一样的。这时就需要bridge类通过不同的逻辑,将 V e V_e Ve转化为 V d V_d Vd.

来看一下基类的实现:

@six.add_metaclass(abc.ABCMeta)
class Bridge(Configurable):
  """一个抽象类,定义信息如何在解码器编码器之间传输。
  
  Args:
    encoder_outputs: A namedtuple that corresponds to the the encoder outputs.
    decoder_state_size: An integer or tuple of integers defining the
      state size of the decoder.
  """

  def __init__(self, encoder_outputs, decoder_state_size, params, mode):
    Configurable.__init__(self, params, mode)
    self.encoder_outputs = encoder_outputs
    self.decoder_state_size = decoder_state_size
    self.batch_size = tf.shape(
        nest.flatten(self.encoder_outputs.final_state)[0])[0]

  def __call__(self):
    <
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值