Weighted Soft Attention Mechanism
### Overview of the Weighted Soft Attention Mechanism
The weighted soft attention mechanism is a technique for improving the performance of neural network models and is widely used in natural language processing (NLP) and other fields. It allows a model to focus on different parts of an input sequence without being restricted to one fixed position[^1].
In practice, weighted soft attention typically involves the following steps:
#### Computing the Weight Distribution
To quantify the importance of each input element, i.e. the so-called "attention scores", a compatibility (scoring) function \(f(q, k)\) is applied, where \(q\) is the query vector and \(k\) is a member of the set of key vectors. This can be implemented in several ways, for example by a dot product (as in the snippet below) or by concatenating the two vectors and applying a linear transformation (a variant is sketched after the snippet)[^3].
```python
import torch
import torch.nn.functional as F


def compute_attention_scores(query, keys):
    """
    Compute the attention weights between a query and multiple keys.

    Args:
        query (Tensor): Query tensor of shape [batch_size, hidden_dim].
        keys (Tensor): Keys tensor of shape [batch_size, seq_len, hidden_dim].

    Returns:
        Tensor: Normalised attention weights with shape [batch_size, seq_len].
    """
    # Compatibility function: dot product between each key and the query.
    scores = torch.bmm(keys, query.unsqueeze(-1)).squeeze(-1)
    # Normalise the scores into a probability distribution over positions.
    return F.softmax(scores, dim=-1)
```
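The dot-product scoring above is only one of the options mentioned earlier. A concatenation-based (additive) variant first concatenates the query with each key and then applies a learned linear transformation. The sketch below is a minimal illustration of that idea; the class name `AdditiveAttentionScorer` and its internal layer names are assumptions for this example, not part of any particular library.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAttentionScorer(nn.Module):
    """Concatenation-based scoring: score = v^T · tanh(W [q; k])."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.attn_proj = nn.Linear(2 * hidden_dim, hidden_dim)   # plays the role of W
        self.score_proj = nn.Linear(hidden_dim, 1, bias=False)   # plays the role of v

    def forward(self, query, keys):
        # query: [batch_size, hidden_dim], keys: [batch_size, seq_len, hidden_dim]
        seq_len = keys.size(1)
        expanded_query = query.unsqueeze(1).expand(-1, seq_len, -1)
        concat = torch.cat([expanded_query, keys], dim=-1)             # [B, S, 2H]
        scores = self.score_proj(torch.tanh(self.attn_proj(concat)))   # [B, S, 1]
        return F.softmax(scores.squeeze(-1), dim=-1)                   # [B, S]
```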
#### Forming the Context Vector
Once the weight of each element has been obtained, a new combined representation, the context vector, can be formed. This is usually done by taking a weighted sum of the original inputs using those weights.
```python
def apply_weighted_sum(values, weights):
    """
    Apply a weighted sum over values according to the given weights.

    Args:
        values (Tensor): Values tensor of shape [batch_size, seq_len, hidden_dim].
        weights (Tensor): Weights tensor of shape [batch_size, seq_len].

    Returns:
        Tensor: Context vectors after the weighted sum, shaped [batch_size, hidden_dim].
    """
    # Broadcast the per-position weights across the hidden dimension,
    # then sum over the sequence dimension to get one vector per example.
    expanded_weights = weights.unsqueeze(-1).expand_as(values)
    context_vector = (expanded_weights * values).sum(dim=1)
    return context_vector
```
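Assuming the two helper functions above (and `torch`) are in scope, a single soft-attention step can be sanity-checked on random tensors; the shapes below are arbitrary and only illustrate the expected dimensions.
```python
# Quick shape check with random data (values are illustrative only).
batch_size, seq_len, hidden_dim = 2, 5, 8
query = torch.randn(batch_size, hidden_dim)          # e.g. the current decoder state
keys = torch.randn(batch_size, seq_len, hidden_dim)  # e.g. the encoder hidden states

weights = compute_attention_scores(query, keys)      # [2, 5], each row sums to 1
context = apply_weighted_sum(keys, weights)          # [2, 8]
print(weights.sum(dim=-1))  # ≈ tensor([1., 1.])
print(context.shape)        # torch.Size([2, 8])
```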
#### Application Example: Machine Translation
In a typical machine translation setting, the encoder converts the source sentence into a sequence of hidden states, and the decoder uses these hidden states as background information to guide the selection of target-language words. Introducing weighted soft attention at this stage helps the model capture the correspondence between the source context and the target expression, which in turn improves overall translation quality[^2].
```python
import random

import torch
import torch.nn as nn


class TranslatorWithSoftAttention(nn.Module):
    def __init__(self, encoder_hidden_size, decoder_hidden_size, vocab_size):
        super().__init__()
        # EncoderRNN / DecoderRNN are assumed to be defined elsewhere.
        self.encoder = EncoderRNN(input_size=vocab_size,
                                  hidden_size=encoder_hidden_size)
        self.decoder = DecoderRNN(output_size=vocab_size,
                                  hidden_size=decoder_hidden_size)

    def forward(self, source_sentence, target_sentence=None):
        encoded_states = self.encoder(source_sentence)

        if not self.training:
            # Inference: greedy decoding until EOS or the length limit.
            translated_words = []
            current_word_idx = SOS_TOKEN  # Start-Of-Sentence token index
            while True:
                next_word_probabilities = self.decoder(
                    previous_output=current_word_idx,
                    context_vectors=self.apply_soft_attention(encoded_states))
                predicted_next_word_idx = next_word_probabilities.argmax(dim=-1)
                translated_words.append(predicted_next_word_idx.item())
                if predicted_next_word_idx.item() == EOS_TOKEN or len(translated_words) > MAX_LENGTH:
                    break
                current_word_idx = predicted_next_word_idx
            return translated_words

        # Training: decode step by step, feeding either the gold token
        # (teacher forcing) or the model's own previous prediction.
        teacher_forcing_ratio = 0.5
        use_teacher_forcing = random.random() < teacher_forcing_ratio
        outputs = []
        input_token = torch.full_like(target_sentence[:, 0], SOS_TOKEN)
        for t in range(target_sentence.shape[1]):
            output_t = self.decoder(
                previous_output=input_token.unsqueeze(1),
                context_vectors=self.apply_soft_attention(encoded_states))
            outputs.append(output_t.squeeze(1))
            if use_teacher_forcing:
                input_token = target_sentence[:, t]
            else:
                topv, topi = output_t.topk(1)
                input_token = topi.squeeze().detach()
        return torch.stack(outputs, dim=1), None

    def apply_soft_attention(self, encoder_outputs):
        # The query is the decoder's most recent hidden state; how it is
        # retrieved depends on the DecoderRNN implementation.
        last_decoder_state = ...  # e.g. the decoder's last LSTM/GRU hidden state
        attn_weights = compute_attention_scores(last_decoder_state, encoder_outputs)
        context_vecs = apply_weighted_sum(encoder_outputs, attn_weights)
        return context_vecs
```
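Two remarks on this sketch: the teacher-forcing ratio of 0.5 means that, on average, half of the training steps feed the gold target token to the decoder while the other half feed the model's own previous prediction; and how `last_decoder_state` is retrieved inside `apply_soft_attention` is left as a placeholder, since it depends on the specific `DecoderRNN` implementation.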