import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
    """Multi-head attention: per-head linear projections around an attention core.

    Only the constructor is present in this excerpt; no ``forward`` method is
    defined here.

    Args:
        n_head: Number of attention heads; multiplies the per-head sizes to
            give the width of each projection.
        d_k_: Input feature size of the query and key projections.
        d_v_: Input feature size of the value projection.
        d_k: Per-head query/key size (projections emit ``n_head * d_k``).
        d_v: Per-head value size (projection emits ``n_head * d_v``).
        d_o: Output feature size of the final projection.
    """

    def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Query and key share the same input and output widths; value has its own.
        self.fc_q = nn.Linear(d_k_, n_head * d_k)
        self.fc_k = nn.Linear(d_k_, n_head * d_k)
        self.fc_v = nn.Linear(d_v_, n_head * d_v)

        # sqrt(d_k) computed with ``**`` instead of ``np.power``: numpy is not
        # imported in this file, so the original call raised NameError.
        # ScaledDotProductAttention is expected to be defined elsewhere.
        self.attention = ScaledDotProductAttention(scale=d_k ** 0.5)

        # NOTE(review): the original statement was cut off at ``self.fc_o``.
        # Reconstructed as the projection from the concatenated heads
        # (n_head * d_v) to d_o, mirroring fc_q/fc_k/fc_v -- confirm against
        # the full source.
        self.fc_o = nn.Linear(n_head * d_v, d_o)
多头注意力(Multi-Head Attention)的 Python 实现
于 2024-03-14 17:57:10 首次发布