Short Text Matching Model: ESIM

Paper link: http://tongtianta.site/paper/11096

Text matching is a key step in intelligent (community) question answering: it decides whether two sentences are semantically similar. In a machine QA / FAQ system, a newly entered text (for example, speech converted to text) is matched against the sentences already stored in the dialogue library, and the answer of the matched question is returned. The problem studied here is how to compute the semantic similarity between two sentences.

I. Principle

Enhanced LSTM for Natural Language Inference (ESIM) is a relatively recent text-similarity model that performs very well and has done strongly in many competitions. This post introduces the basic principle of the model and a PyTorch implementation.

As the full name suggests, ESIM is an enhanced LSTM designed specifically for natural language inference. The original paper shows that this carefully designed sequential inference model built on chained LSTMs can outperform many far more complicated previous models.

The architecture of ESIM is shown in Figure 1. Its main components are input encoding, local inference modeling, and inference composition, which are laid out vertically as the three major stages. Horizontally, the left side of Figure 1 is the sequential inference model, named ESIM, and the right side is the variant that incorporates syntactic parsing information through tree-LSTMs.

Figure 1: The ESIM model architecture

Let a = (a_1, \ldots, a_{l_a}) and b = (b_1, \ldots, b_{l_b}) denote the two sentences, where l_a is the length of sentence a and l_b the length of sentence b. In the original paper, a is also called the premise and b the hypothesis.

1. Input Encoding

A BiLSTM (bidirectional LSTM) is used for input encoding: it encodes each of the two input sentences separately, as shown in equations (1) and (2):

    \bar{a}_i = BiLSTM(a, i),\ \forall i \in [1, \ldots, l_a]          (1)

    \bar{b}_j = BiLSTM(b, j),\ \forall j \in [1, \ldots, l_b]          (2)

We write the hidden (output) state generated by the BiLSTM at time step i over the input sequence a as \bar{a}_i; the same applies to b with \bar{b}_j.
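As a rough sketch of this step (the batch size, sentence lengths and hidden sizes below are illustrative assumptions, and padding is ignored; the Seq2SeqEncoder class in Section II handles variable-length sequences properly), the two sentences are passed through one shared BiLSTM:

import torch
import torch.nn as nn

# Toy sizes for illustration only.
embedding_dim, hidden_size = 300, 300
bilstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True, bidirectional=True)

a_embed = torch.randn(8, 20, embedding_dim)   # embedded premise:    (batch, l_a, embedding_dim)
b_embed = torch.randn(8, 15, embedding_dim)   # embedded hypothesis: (batch, l_b, embedding_dim)

a_bar, _ = bilstm(a_embed)                    # (batch, l_a, 2 * hidden_size)
b_bar, _ = bilstm(b_embed)                    # (batch, l_b, 2 * hidden_size)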

2. Local Inference Modeling

The attention weights between \bar{a} and \bar{b} are computed with a simple dot product, as shown in equation (3). The authors mention that they also tried more complex functions for the attention weights, but these did not yield better results.

    e_{ij} = \bar{a}_i^{T} \bar{b}_j          (3)

The interaction (soft-aligned) representations between the two sentences are then computed:

    \tilde{a}_i = \sum_{j=1}^{l_b} \frac{\exp(e_{ij})}{\sum_{k=1}^{l_b} \exp(e_{ik})} \bar{b}_j,\ \forall i \in [1, \ldots, l_a]          (4)

    \tilde{b}_j = \sum_{i=1}^{l_a} \frac{\exp(e_{ij})}{\sum_{k=1}^{l_a} \exp(e_{kj})} \bar{a}_i,\ \forall j \in [1, \ldots, l_b]          (5)

where \tilde{a}_i is a weighted sum of \{\bar{b}_j\}_{j=1}^{l_b}, and \tilde{b}_j is likewise a weighted sum of \{\bar{a}_i\}_{i=1}^{l_a}.
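Continuing the sketch above (a_bar and b_bar are the illustrative tensors from the previous snippet, and masking of padded positions is again omitted), equations (3)-(5) reduce to a batched matrix product followed by softmaxes along the two directions:

# e[i, j] = a_bar_i . b_bar_j for every pair of time steps, per batch element.
e = torch.bmm(a_bar, b_bar.transpose(1, 2))                            # (batch, l_a, l_b)

# Eq. (4): soft-align each premise position to the hypothesis.
a_tilde = torch.bmm(torch.softmax(e, dim=2), b_bar)                    # (batch, l_a, 2 * hidden_size)
# Eq. (5): soft-align each hypothesis position to the premise.
b_tilde = torch.bmm(torch.softmax(e, dim=1).transpose(1, 2), a_bar)    # (batch, l_b, 2 * hidden_size)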

Local inference information is then enhanced by computing the difference and the element-wise product of the tuples <\bar{a}, \tilde{a}> and <\bar{b}, \tilde{b}> and concatenating them with the original vectors. The enhanced local inference information is:

    m_a = [\bar{a}; \tilde{a}; \bar{a} - \tilde{a}; \bar{a} \odot \tilde{a}]          (6)

    m_b = [\bar{b}; \tilde{b}; \bar{b} - \tilde{b}; \bar{b} \odot \tilde{b}]          (7)
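In code, the enhancement in (6) and (7) is a single concatenation along the feature dimension (still using the illustrative tensors from the sketches above):

m_a = torch.cat([a_bar, a_tilde, a_bar - a_tilde, a_bar * a_tilde], dim=-1)   # (batch, l_a, 8 * hidden_size)
m_b = torch.cat([b_bar, b_tilde, b_bar - b_tilde, b_bar * b_tilde], dim=-1)   # (batch, l_b, 8 * hidden_size)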

3. Inference Composition

Another BiLSTM is used to compose the local inference information in m_a and m_b. For the resulting sequences, the authors argue that summation is sensitive to sequence length and therefore not robust, so they use both max pooling and average pooling and concatenate the pooled results into a fixed-length vector, as follows:

    v_{a,avg} = \sum_{i=1}^{l_a} \frac{v_{a,i}}{l_a},\quad v_{a,max} = \max_{i=1}^{l_a} v_{a,i}          (8)

    v_{b,avg} = \sum_{j=1}^{l_b} \frac{v_{b,j}}{l_b},\quad v_{b,max} = \max_{j=1}^{l_b} v_{b,j}          (9)

    v = [v_{a,avg}; v_{a,max}; v_{b,avg}; v_{b,max}]          (10)

Finally, v is fed into a fully connected classifier to produce the classification result.
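A sketch of the composition stage, continuing the toy example (the layer sizes are illustrative and masking is omitted; the actual implementation in Section II applies masked average and max pooling):

projection = nn.Sequential(nn.Linear(8 * hidden_size, hidden_size), nn.ReLU())
composition = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True)

v_a, _ = composition(projection(m_a))   # (batch, l_a, 2 * hidden_size)
v_b, _ = composition(projection(m_b))   # (batch, l_b, 2 * hidden_size)

# Eq. (8)-(10): average pooling and max pooling over time, then concatenation.
v = torch.cat([v_a.mean(dim=1), v_a.max(dim=1)[0],
               v_b.mean(dim=1), v_b.max(dim=1)[0]], dim=-1)   # (batch, 8 * hidden_size)

# Final MLP classifier over the fixed-length vector v.
classifier = nn.Sequential(nn.Linear(8 * hidden_size, hidden_size), nn.Tanh(),
                           nn.Linear(hidden_size, 2))
logits = classifier(v)   # (batch, num_classes)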

II. PyTorch Implementation

There are many implementations on GitHub, but in practice quite a few of them have problems. The most obvious one is that some of them feed padded batches into the BiLSTM without the pack_padded_sequence / pad_packed_sequence step, so the features extracted by the BiLSTM are wrong. This mistake bothered me for a long time when I worked on NER, so pay attention to how BiLSTM is used in PyTorch (see the short pack/pad sketch right after this paragraph). The detailed ESIM implementation follows; please leave a comment if you have questions.
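To make the pitfall concrete, here is a minimal sketch of the correct pack/pad pattern (padded_batch, seq_lengths and bilstm are placeholder names; this is essentially what the Seq2SeqEncoder class below wraps):

# Sort by length (descending), pack, run the BiLSTM, unpack, restore the order.
lengths, sort_idx = seq_lengths.sort(descending=True)
packed = nn.utils.rnn.pack_padded_sequence(padded_batch[sort_idx], lengths, batch_first=True)
outputs, _ = bilstm(packed)
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
outputs = outputs[sort_idx.argsort()]   # back to the original batch order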


 
import torch
import torch.nn as nn
import numpy as np

from utils.utils import sort_by_seq_lens, masked_softmax, weighted_sum, get_mask, replace_masked


# Class widely inspired from:
# https://github.com/allenai/allennlp/blob/master/allennlp/modules/input_variational_dropout.py
class RNNDropout(nn.Dropout):
    """
    Dropout layer for the inputs of RNNs.
    Apply the same dropout mask to all the elements of the same sequence in
    a batch of sequences of size (batch, sequences_length, embedding_dim).
    """

    def forward(self, sequences_batch):
        """
        Apply dropout to the input batch of sequences.
        Args:
            sequences_batch: A batch of sequences of vectors that will serve
                as input to an RNN.
                Tensor of size (batch, sequences_length, embedding_dim).
        Returns:
            A new tensor on which dropout has been applied.
        """
        ones = sequences_batch.data.new_ones(sequences_batch.shape[0],
                                             sequences_batch.shape[-1])
        dropout_mask = nn.functional.dropout(ones, self.p, self.training,
                                             inplace=False)
        return dropout_mask.unsqueeze(1) * sequences_batch
class Seq2SeqEncoder(nn.Module):
    """
    RNN taking variable length padded sequences of vectors as input and
    encoding them into padded sequences of vectors of the same length.
    This module is useful to handle batches of padded sequences of vectors
    that have different lengths and that need to be passed through a RNN.
    The sequences are sorted in descending order of their lengths, packed,
    passed through the RNN, and the resulting sequences are then padded and
    permuted back to the original order of the input sequences.
    """

    def __init__(self,
                 rnn_type,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 bias=True,
                 dropout=0.0,
                 bidirectional=False):
        """
        Args:
            rnn_type: The type of RNN to use as encoder in the module.
                Must be a class inheriting from torch.nn.RNNBase
                (such as torch.nn.LSTM for example).
            input_size: The number of expected features in the input of the
                module.
            hidden_size: The number of features in the hidden state of the RNN
                used as encoder by the module.
            num_layers: The number of recurrent layers in the encoder of the
                module. Defaults to 1.
            bias: If False, the encoder does not use bias weights b_ih and
                b_hh. Defaults to True.
            dropout: If non-zero, introduces a dropout layer on the outputs
                of each layer of the encoder except the last one, with dropout
                probability equal to 'dropout'. Defaults to 0.0.
            bidirectional: If True, the encoder of the module is bidirectional.
                Defaults to False.
        """
        assert issubclass(rnn_type, nn.RNNBase), \
            "rnn_type must be a class inheriting from torch.nn.RNNBase"

        super(Seq2SeqEncoder, self).__init__()

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.bidirectional = bidirectional

        self._encoder = rnn_type(input_size,
                                 hidden_size,
                                 num_layers=num_layers,
                                 bias=bias,
                                 batch_first=True,
                                 dropout=dropout,
                                 bidirectional=bidirectional)

    def forward(self, sequences_batch, sequences_lengths):
        """
        Args:
            sequences_batch: A batch of variable length sequences of vectors.
                The batch is assumed to be of size
                (batch, sequence, vector_dim).
            sequences_lengths: A 1D tensor containing the sizes of the
                sequences in the input batch.
        Returns:
            reordered_outputs: The outputs (hidden states) of the encoder for
                the sequences in the input batch, in the same order.
        """
        sorted_batch, sorted_lengths, _, restoration_idx = \
            sort_by_seq_lens(sequences_batch, sequences_lengths)
        packed_batch = nn.utils.rnn.pack_padded_sequence(sorted_batch,
                                                         sorted_lengths,
                                                         batch_first=True)
        outputs, _ = self._encoder(packed_batch, None)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs,
                                                      batch_first=True)
        reordered_outputs = outputs.index_select(0, restoration_idx)

        return reordered_outputs
class SoftmaxAttention(nn.Module):
    """
    Attention layer taking premises and hypotheses encoded by an RNN as input
    and computing the soft attention between their elements.
    The dot product of the encoded vectors in the premises and hypotheses is
    first computed. The softmax of the result is then used in a weighted sum
    of the vectors of the premises for each element of the hypotheses, and
    conversely for the elements of the premises.
    """

    def forward(self,
                premise_batch,
                premise_mask,
                hypothesis_batch,
                hypothesis_mask):
        """
        Args:
            premise_batch: A batch of sequences of vectors representing the
                premises in some NLI task. The batch is assumed to have the
                size (batch, sequences, vector_dim).
            premise_mask: A mask for the sequences in the premise batch, to
                ignore padding data in the sequences during the computation of
                the attention.
            hypothesis_batch: A batch of sequences of vectors representing the
                hypotheses in some NLI task. The batch is assumed to have the
                size (batch, sequences, vector_dim).
            hypothesis_mask: A mask for the sequences in the hypotheses batch,
                to ignore padding data in the sequences during the computation
                of the attention.
        Returns:
            attended_premises: The sequences of attention vectors for the
                premises in the input batch.
            attended_hypotheses: The sequences of attention vectors for the
                hypotheses in the input batch.
        """
        # Dot product between premises and hypotheses in each sequence of
        # the batch.
        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
                                                              .contiguous())

        # Softmax attention weights.
        prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
        hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2)
                                                        .contiguous(),
                                       premise_mask)

        # Weighted sums of the hypotheses for the premises attention,
        # and vice-versa for the attention of the hypotheses.
        attended_premises = weighted_sum(hypothesis_batch,
                                         prem_hyp_attn,
                                         premise_mask)
        attended_hypotheses = weighted_sum(premise_batch,
                                           hyp_prem_attn,
                                           hypothesis_mask)

        return attended_premises, attended_hypotheses
class ESIM(nn.Module):
    """
    Implementation of the ESIM model presented in the paper "Enhanced LSTM for
    Natural Language Inference" by Chen et al.
    """

    def __init__(self,
                 vocab_size,
                 embedding_dim,
                 hidden_size,
                 embeddings=None,
                 padding_idx=0,
                 dropout=0.1,
                 num_classes=2):
        """
        Args:
            vocab_size: The size of the vocabulary of embeddings in the model.
            embedding_dim: The dimension of the word embeddings.
            hidden_size: The size of all the hidden layers in the network.
            embeddings: A (vectors, words) pair whose vectors have shape
                (vocab_size, embedding_dim) and contain pretrained word
                embeddings. If None, word embeddings are initialised randomly.
                Defaults to None.
            padding_idx: The index of the padding token in the premises and
                hypotheses passed as input to the model. Defaults to 0.
            dropout: The dropout rate to use between the layers of the network.
                A dropout rate of 0 corresponds to using no dropout at all.
                Defaults to 0.1.
            num_classes: The number of classes in the output of the network.
                Defaults to 2.
        """
        super(ESIM, self).__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.dropout = dropout

        self._word_embedding = nn.Embedding(self.vocab_size,
                                            self.embedding_dim,
                                            padding_idx=padding_idx)
        if embeddings:
            embvecs, embwords = embeddings
            self._word_embedding.weight.data.copy_(torch.from_numpy(np.asarray(embvecs)))
            self._word_embedding.weight.requires_grad = False

        if self.dropout:
            self._rnn_dropout = RNNDropout(p=self.dropout)
            # self._rnn_dropout = nn.Dropout(p=self.dropout)

        self._encoding = Seq2SeqEncoder(nn.LSTM,
                                        self.embedding_dim,
                                        self.hidden_size,
                                        bidirectional=True)
        self._attention = SoftmaxAttention()
        self._projection = nn.Sequential(nn.Linear(4 * 2 * self.hidden_size,
                                                   self.hidden_size),
                                         nn.ReLU())
        self._composition = Seq2SeqEncoder(nn.LSTM,
                                           self.hidden_size,
                                           self.hidden_size,
                                           bidirectional=True)
        self._classification = nn.Sequential(nn.Dropout(p=self.dropout),
                                             nn.Linear(2 * 4 * self.hidden_size,
                                                       self.hidden_size),
                                             nn.Tanh(),
                                             nn.Dropout(p=self.dropout),
                                             nn.Linear(self.hidden_size,
                                                       self.num_classes))

        # Initialize all weights and biases in the model.
        self.apply(_init_esim_weights)
    def forward(self,
                text_a_ids,
                text_b_ids,
                q1_char_inputs=None,
                q2_char_inputs=None,
                q1_lens=None,
                q2_lens=None,
                device='cpu'):
        """
        Args:
            text_a_ids: A batch of variable length sequences of word indices
                representing premises. The batch is assumed to be of size
                (batch, premises_length).
            text_b_ids: A batch of variable length sequences of word indices
                representing hypotheses. The batch is assumed to be of size
                (batch, hypotheses_length).
            q1_char_inputs, q2_char_inputs: Unused in this implementation.
            q1_lens: A 1D tensor containing the lengths of the premises.
            q2_lens: A 1D tensor containing the lengths of the hypotheses.
            device: The name of the device on which the model is being
                executed. Defaults to 'cpu'.
        Returns:
            logits: A tensor of size (batch, num_classes) containing the
                logits for each output class of the model.
        """
        premises = text_a_ids
        premises_lengths = q1_lens
        hypotheses = text_b_ids
        hypotheses_lengths = q2_lens

        premises_mask = get_mask(premises, premises_lengths).to(device)
        hypotheses_mask = get_mask(hypotheses, hypotheses_lengths).to(device)

        embedded_premises = self._word_embedding(premises)
        embedded_hypotheses = self._word_embedding(hypotheses)

        if self.dropout:
            embedded_premises = self._rnn_dropout(embedded_premises)
            embedded_hypotheses = self._rnn_dropout(embedded_hypotheses)

        encoded_premises = self._encoding(embedded_premises,
                                          premises_lengths)
        encoded_hypotheses = self._encoding(embedded_hypotheses,
                                            hypotheses_lengths)

        attended_premises, attended_hypotheses = \
            self._attention(encoded_premises, premises_mask,
                            encoded_hypotheses, hypotheses_mask)

        enhanced_premises = torch.cat([encoded_premises,
                                       attended_premises,
                                       encoded_premises - attended_premises,
                                       encoded_premises * attended_premises],
                                      dim=-1)
        enhanced_hypotheses = torch.cat([encoded_hypotheses,
                                         attended_hypotheses,
                                         encoded_hypotheses - attended_hypotheses,
                                         encoded_hypotheses * attended_hypotheses],
                                        dim=-1)

        projected_premises = self._projection(enhanced_premises)
        projected_hypotheses = self._projection(enhanced_hypotheses)

        if self.dropout:
            projected_premises = self._rnn_dropout(projected_premises)
            projected_hypotheses = self._rnn_dropout(projected_hypotheses)

        v_ai = self._composition(projected_premises, premises_lengths)
        v_bj = self._composition(projected_hypotheses, hypotheses_lengths)

        v_a_avg = torch.sum(v_ai * premises_mask.unsqueeze(1).transpose(2, 1), dim=1) \
            / torch.sum(premises_mask, dim=1, keepdim=True)
        v_b_avg = torch.sum(v_bj * hypotheses_mask.unsqueeze(1).transpose(2, 1), dim=1) \
            / torch.sum(hypotheses_mask, dim=1, keepdim=True)

        v_a_max, _ = replace_masked(v_ai, premises_mask, -1e7).max(dim=1)
        v_b_max, _ = replace_masked(v_bj, hypotheses_mask, -1e7).max(dim=1)

        v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)

        logits = self._classification(v)
        # probabilities = nn.functional.softmax(logits, dim=-1)
        return logits
def _init_esim_weights(module):
    """
    Initialise the weights of the ESIM model.
    """
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight.data)
        nn.init.constant_(module.bias.data, 0.0)
    elif isinstance(module, nn.LSTM):
        nn.init.kaiming_normal_(module.weight_ih_l0.data)
        nn.init.orthogonal_(module.weight_hh_l0.data)
        nn.init.constant_(module.bias_ih_l0.data, 0.0)
        nn.init.constant_(module.bias_hh_l0.data, 0.0)
        hidden_size = module.bias_hh_l0.data.shape[0] // 4
        module.bias_hh_l0.data[hidden_size:(2 * hidden_size)] = 1.0

        if module.bidirectional:
            nn.init.kaiming_normal_(module.weight_ih_l0_reverse.data)
            nn.init.orthogonal_(module.weight_hh_l0_reverse.data)
            nn.init.constant_(module.bias_ih_l0_reverse.data, 0.0)
            nn.init.constant_(module.bias_hh_l0_reverse.data, 0.0)
            module.bias_hh_l0_reverse.data[hidden_size:(2 * hidden_size)] = 1.0
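A hedged usage sketch of the model above (the vocabulary size, token ids, lengths and labels are all made up for illustration; get_mask in utils assumes index 0 is the padding token):

model = ESIM(vocab_size=10000, embedding_dim=300, hidden_size=300, num_classes=2)

text_a = torch.tensor([[12, 45, 7, 0, 0], [3, 8, 9, 21, 5]])   # padded word-id batch (premises)
text_b = torch.tensor([[4, 98, 0, 0], [17, 6, 32, 9]])         # padded word-id batch (hypotheses)
len_a = torch.tensor([3, 5])
len_b = torch.tensor([2, 4])

logits = model(text_a, text_b, q1_lens=len_a, q2_lens=len_b)   # (batch, num_classes)
loss = nn.functional.cross_entropy(logits, torch.tensor([1, 0]))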

The utils code:


 
import torch
import torch.nn as nn


# Code widely inspired from:
# https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py.
def sort_by_seq_lens(batch, sequences_lengths, descending=True):
    """
    Sort a batch of padded variable length sequences by length.
    Args:
        batch: A batch of padded variable length sequences. The batch should
            have the dimensions (batch_size x max_sequence_length x *).
        sequences_lengths: A tensor containing the lengths of the sequences in the
            input batch. The tensor should be of size (batch_size).
        descending: A boolean value indicating whether to sort the sequences
            by their lengths in descending order. Defaults to True.
    Returns:
        sorted_batch: A tensor containing the input batch reordered by
            sequences lengths.
        sorted_seq_lens: A tensor containing the sorted lengths of the
            sequences in the input batch.
        sorting_idx: A tensor containing the indices used to permute the input
            batch in order to get 'sorted_batch'.
        restoration_idx: A tensor containing the indices that can be used to
            restore the order of the sequences in 'sorted_batch' so that it
            matches the input batch.
    """
    sorted_seq_lens, sorting_index = \
        sequences_lengths.sort(0, descending=descending)
    sorted_batch = batch.index_select(0, sorting_index)
    idx_range = \
        sequences_lengths.new_tensor(torch.arange(0, len(sequences_lengths)))
    _, reverse_mapping = sorting_index.sort(0, descending=False)
    restoration_index = idx_range.index_select(0, reverse_mapping)

    return sorted_batch, sorted_seq_lens, sorting_index, restoration_index
def get_mask(sequences_batch, sequences_lengths):
    """
    Get the mask for a batch of padded variable length sequences.
    Args:
        sequences_batch: A batch of padded variable length sequences
            containing word indices. Must be a 2-dimensional tensor of size
            (batch, sequence).
        sequences_lengths: A tensor containing the lengths of the sequences in
            'sequences_batch'. Must be of size (batch).
    Returns:
        A mask of size (batch, max_sequence_length), where max_sequence_length
        is the length of the longest sequence in the batch.
    """
    batch_size = sequences_batch.size()[0]
    max_length = torch.max(sequences_lengths)
    mask = torch.ones(batch_size, max_length, dtype=torch.float)
    mask[sequences_batch[:, :max_length] == 0] = 0.0
    return mask


def get_atten_mask(sequences_batch, sequences_lengths):
    """
    Get the mask for attention.
    Args:
        sequences_batch: A batch of padded variable length sequences
            containing word indices. Must be a 2-dimensional tensor of size
            (batch, sequence).
        sequences_lengths: A tensor containing the lengths of the sequences in
            'sequences_batch'. Must be of size (batch).
    Returns:
        A mask of size (batch, max_sequence_length), where max_sequence_length
        is the length of the longest sequence in the batch.
    """
    batch_size = sequences_batch.size()[0]
    max_length = torch.max(sequences_lengths)
    mask = torch.zeros(batch_size, max_length, dtype=torch.float)
    mask[sequences_batch[:, :max_length] == 0] = 1.0
    return mask
# Code widely inspired from:
# https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py.
def masked_softmax(tensor, mask):
    """
    Apply a masked softmax on the last dimension of a tensor.
    The input tensor and mask should be of size (batch, *, sequence_length).
    Args:
        tensor: The tensor on which the softmax function must be applied along
            the last dimension.
        mask: A mask of the same size as the tensor with 0s in the positions of
            the values that must be masked and 1s everywhere else.
    Returns:
        A tensor of the same size as the inputs containing the result of the
        softmax.
    """
    tensor_shape = tensor.size()
    reshaped_tensor = tensor.view(-1, tensor_shape[-1])

    # Reshape the mask so it matches the size of the input tensor.
    while mask.dim() < tensor.dim():
        mask = mask.unsqueeze(1)
    mask = mask.expand_as(tensor).contiguous().float()
    reshaped_mask = mask.view(-1, mask.size()[-1])

    result = nn.functional.softmax(reshaped_tensor * reshaped_mask, dim=-1)
    result = result * reshaped_mask
    # 1e-13 is added to avoid divisions by zero.
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)

    return result.view(*tensor_shape)
# Code widely inspired from:
# https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py.
def weighted_sum(tensor, weights, mask):
    """
    Apply a weighted sum on the vectors along the last dimension of 'tensor',
    and mask the vectors in the result with 'mask'.
    Args:
        tensor: A tensor of vectors on which a weighted sum must be applied.
        weights: The weights to use in the weighted sum.
        mask: A mask to apply on the result of the weighted sum.
    Returns:
        A new tensor containing the result of the weighted sum after the mask
        has been applied on it.
    """
    weighted_sum = weights.bmm(tensor)

    while mask.dim() < weighted_sum.dim():
        mask = mask.unsqueeze(1)
    mask = mask.transpose(-1, -2)
    mask = mask.expand_as(weighted_sum).contiguous().float()

    return weighted_sum * mask
# Code inspired from:
# https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py.
def replace_masked(tensor, mask, value):
    """
    Replace all the values of vectors in 'tensor' that are masked out in
    'mask' by 'value'.
    Args:
        tensor: The tensor in which the masked vectors must have their values
            replaced.
        mask: A mask indicating the vectors which must have their values
            replaced.
        value: The value to place in the masked vectors of 'tensor'.
    Returns:
        A new tensor of the same size as 'tensor' where the values of the
        vectors masked in 'mask' were replaced by 'value'.
    """
    mask = mask.unsqueeze(1).transpose(2, 1)
    reverse_mask = 1.0 - mask
    values_to_add = value * reverse_mask
    return tensor * mask + values_to_add
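A small example of how the masking utilities above behave (the values are made up for illustration):

batch = torch.tensor([[5, 7, 2, 0], [9, 1, 6, 3]])   # 0 is the padding index
lengths = torch.tensor([3, 4])
mask = get_mask(batch, lengths)                      # [[1., 1., 1., 0.], [1., 1., 1., 1.]]

scores = torch.randn(2, 3, 4)                        # e.g. a (batch, l_a, l_b) attention matrix
attn = masked_softmax(scores, mask)                  # padded hypothesis positions get ~0 weight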

 
