Reference link
Attention is essentially just a weighted sum.
import numpy as np
import tensorflow as tf

# query: a 3x4 float matrix holding 0..11; key/value: the same matrix shifted by 3
a = np.arange(3 * 4, dtype=np.float64).reshape(3, 4)
b = a + 3.0

# Keras Attention takes [query, value]; with no explicit key, value doubles as key
katten = tf.keras.layers.Attention()([a, b])
print('keras attention=', katten)
Understood at the algorithmic level, attention is just three matrix operations: a dot-product score, a softmax, and a weighted sum.
# 1) scores: dot product of each query row with each key row
weight = a @ b.T
# 2) softmax normalizes each score row into attention weights
weight1 = tf.nn.softmax(weight)
# 3) weighted sum of the value rows
attent = weight1 @ b
print('my attention=', attent)
keras attention= tf.Tensor(
[[11. 12. 13. 14.]
[11. 12. 13. 14.]
[11. 12. 13. 14.]], shape=(3, 4), dtype=float32)
my attention= tf.Tensor(
[[11. 12. 13. 14.]
[11. 12. 13. 14.]
[11. 12. 13. 14.]], shape=(3, 4), dtype=float64)