《A Self-Attentive model for Knowledge Tracing》技术核心模块详解
一、模型输入与输出
(一)输入序列
模型的输入为学生过去的互动序列
(
X
=
(
x
1
,
x
2
,
…
,
x
t
)
)
(X = (x_1, x_2, \dots, x_t))
(X=(x1,x2,…,xt)),其中每个互动元素
(
x
t
=
(
e
t
,
r
t
)
)
(x_t = (e_t, r_t))
(xt=(et,rt)),
(
e
t
)
(e_t)
(et)表示学生在时间戳 (t) 尝试的练习,
(
r
t
)
(r_t)
(rt) 表示学生回答的正确性。
为了将问题转化为序列建模问题,将输入序列处理为
(
x
1
,
x
2
,
…
,
x
t
−
1
)
(x_1, x_2, \dots, x_{t-1})
(x1,x2,…,xt−1),对应的练习序列则为
(
e
2
,
e
3
,
…
,
e
t
)
(e_2, e_3, \dots, e_t)
(e2,e3,…,et),输出为对这些练习正确性的预测
(
r
2
,
r
3
,
…
,
r
t
)
(r_2, r_3, \dots, r_t)
(r2,r3,…,rt)。这种处理方式确保了模型在预测第 (t+1) 次练习的正确性时,仅依据前 (t) 次的互动信息。
嵌入层: 嵌入学生正在尝试的当前练习和他过去的交互。在每次标记 t + 1时,使用练习嵌入将当前问题
e
t
+
1
e_{t+1}
et+1 嵌入到查询空间中,使用交互嵌入将过去交互的元素
x
t
x_t
xt 嵌入到键和值空间中。
模型目的: 根据学生1到 t 时刻 的习题作答情况,(即交互序列 X = x 1 , x 2 , . . . , x t X = x_1, x_2, ..., x_t X=x1,x2,...,xt) 预测在t+1时刻,习题 e t + 1 e_{t+1} et+1 的回答情况(即预测出真实情况,正确的概率)。
交互元组:
x
t
=
(
e
t
,
r
t
)
x_t = (e_t, r_t)
xt=(et,rt) : t 时刻习题
e
t
e_t
et 的作答情况
r
t
r_t
rt 构成的。
x
t
x_t
xt 编号化时,用两者来表示,:
y
t
=
e
t
+
r
t
×
E
y_t = e_t + r_t × E
yt=et+rt×E,E是题目数量,可以看出交互编号,回答错误 时和题目编号同
y
t
=
e
t
y_t = e_t
yt=et,回答正确时,编号加上题目总数
y
t
=
e
t
+
E
y_t = e_t + E
yt=et+E。
eg:设 E=100 即一共有 100 道题。那么 99 表示第 99 道题做错了,199 则表示第 99 道题做对了
(二)输出序列
模型输出为学生在各个练习上的正确性预测概率,通过全连接网络结合 Sigmoid 激活函数得到。对于每个练习 ( e i ) (e_i) (ei),预测其正确性的概率 ( p i ) (p_i) (pi),即:
[ p i = Sigmoid ( F i w + b ) ] [p_i = \text{Sigmoid}(\mathbf{F}_i \mathbf{w} + \mathbf{b})] [pi=Sigmoid(Fiw+b)]
其中, ( F i ) (F_i) (Fi) 是经过模型多层处理后的特征向量的第 (i) 行。
二、嵌入层(Embedding Layer)
嵌入层负责将离散的练习和互动信息转换为连续的向量表示,便于模型进行后续的复杂运算和特征学习。
交互序列需要划分处理,保证所以的交互序列的长度一致,多则截断,短则填充。
因此交互序列由
y
=
(
y
1
,
y
2
,
.
.
.
,
y
t
)
y = (y_1, y_2, ...,y_t)
y=(y1,y2,...,yt)变为
s
=
(
s
1
,
s
2
,
.
.
.
,
s
n
)
s = (s_1,s_2,...,s_n)
s=(s1,s2,...,sn)。
训练一个交互嵌入矩阵 :
M
∈
R
2
E
×
d
M \in \mathbb{R}^{2E \times d}
M∈R2E×d,其中 d 是潜在维度,用于获取交互嵌入。
s
i
s_i
si的嵌入表示为
M
s
i
M_{s_i}
Msi
训练一个练习嵌入矩阵:
(
E
∈
R
E
×
d
)
(E \in \mathbb{R}^{E \times d})
(E∈RE×d), 用户获取练习嵌入。
e
i
e_i
ei的嵌入表示为
E
e
i
E_{e_i}
Eei
(一)交互嵌入
训练一个交互嵌入矩阵 ( M ∈ R 2 E × d ) (M \in \mathbb{R}^{2E \times d}) (M∈R2E×d),其中 (E) 为练习的总数,(d) 为潜在向量的维度。对于输入序列中的每个元素 ( s i ) (s_i) (si)(经过处理后的互动信息),通过查找交互嵌入矩阵 (M) 得到其嵌入表示 ( M s i ) (M_{s_i}) (Msi)。
(二)练习嵌入
训练一个练习嵌入矩阵 ( E ∈ R E × d ) (E \in \mathbb{R}^{E \times d}) (E∈RE×d),用于将每个练习嵌入到 (d) 维空间中。对于练习序列中的每个练习 ( e i ) (e_i) (ei),通过查找练习嵌入矩阵 (E) 得到其嵌入表示 ( E e i ) (E_{e_i}) (Eei)。
公式
[ M ^ = [ M s 1 + P 1 M s 2 + P 2 … M s n + P n ] , E ^ = [ E s 1 E s 2 … E s n ] ] [ \hat{\mathbf{M}} = \begin{bmatrix} \mathbf{M}_{s_1} + \mathbf{P}_1 \\ \mathbf{M}_{s_2} + \mathbf{P}_2 \\ \ldots \\ \mathbf{M}_{s_n} + \mathbf{P}_n \end{bmatrix}, \quad \hat{\mathbf{E}} = \begin{bmatrix} \mathbf{E}_{s_1} \\ \mathbf{E}_{s_2} \\ \ldots \\ \mathbf{E}_{s_n} \end{bmatrix} ] [M^= Ms1+P1Ms2+P2…Msn+Pn ,E^= Es1Es2…Esn ]
参数说明
- M \mathbf{M} M:交互嵌入矩阵(Interaction Embedding Matrix),表示学生历史交互的嵌入(如练习和答案对)。
- P \mathbf{P} P:位置嵌入矩阵(Positional Embedding Matrix),捕获序列中元素的顺序信息。
- E \mathbf{E} E:练习嵌入矩阵(Exercise Embedding Matrix),表示每个练习的潜在向量。
- s i s_i si:序列中第 i i i个时间步的索引, n n n为最大序列长度。
功能
- M ^ \hat{\mathbf{M}} M^:将每个历史交互的嵌入( M s i \mathbf{M}_{s_i} Msi)与对应的位置嵌入( P i \mathbf{P}_i Pi)相加,形成包含时序信息的交互表示。
- E ^ \hat{\mathbf{E}} E^:将当前待预测的练习(Query)嵌入到向量空间,用于后续注意力计算。
代码
self.M = Embedding(self.num_q * 2, self.d)
self.E = Embedding(self.num_q, d)
self.P = Parameter(torch.Tensor(self.n, self.d))
kaiming_normal_(self.P)
self.M
对应论文中的交互嵌入矩阵 ( M ∈ R 2 E × d ) (M \in \mathbb{R}^{2E \times d}) (M∈R2E×d),其中(E)是练习总数, d d d是潜在向量维度。它用于将交互序列中的每个元素嵌入到(d)维空间。self.E
对应论文中的练习嵌入矩阵 ( E ∈ R E × d ) (E \in \mathbb{R}^{E \times d}) (E∈RE×d),用于将练习嵌入到相应的行。self.P
对应论文中的位置嵌入参数 ( P ∈ R n × d ) (P \in \mathbb{R}^{n \times d}) (P∈Rn×d),其中 n n n是模型能够处理的序列最大长度,用于编码序列中元素的位置信息。
三、位置编码层(Position Encoding)
位置编码层旨在为模型提供序列中元素的顺序信息,因为在知识追踪任务中,学生知识状态的演变具有时间顺序性。
引入一个位置嵌入矩阵 ( P ∈ R n × d ) (P \in \mathbb{R}^{n \times d}) (P∈Rn×d),其中 (n) 为模型能够处理的序列的最大长度。在训练过程中,将位置嵌入矩阵的第 (i) 行 ( P i ) (P_i) (Pi) 加到交互序列中第 (i) 个元素的交互嵌入向量上,从而为每个元素赋予位置信息。
四、自注意力层(Self-attention Layer)
自注意力层是 SAKT 模型的核心组件,用于捕捉序列中不同元素之间的相关性,从而识别出与当前预测相关的过去的练习。
(一) 查询(Query)、键(Key)、值(Value)投影
通过线性变换将练习嵌入和交互嵌入投影到不同的空间,生成查询(Query)、键(Key)和值(Value)向量:
[ Q = E ^ W Q , K = M ^ W K , V = M ^ W V ] [ \mathbf{Q} = \hat{\mathbf{E}}\mathbf{W}^Q, \quad \mathbf{K} = \hat{\mathbf{M}}\mathbf{W}^K, \quad \mathbf{V} = \hat{\mathbf{M}}\mathbf{W}^V ] [Q=E^WQ,K=M^WK,V=M^WV]
Q: 习题嵌入
K:作答交互嵌入
V :作答交互嵌入
- W Q , W K , W V ∈ R d × d \mathbf{W}^Q, \mathbf{W}^K, \mathbf{W}^V \in \mathbb{R}^{d \times d} WQ,WK,WV∈Rd×d:可学习的投影矩阵,将输入映射到不同的子空间。
- E ^ \hat{\mathbf{E}} E^ 和 M ^ \hat{\mathbf{M}} M^: 分别为经过位置编码后的练习嵌入和交互嵌入。
- d d d:潜在向量的维度。
(二)缩放点积注意力(Scaled Dot-Product Attention)
利用缩放点积注意力机制计算查询与键之间的相关性权重:
[ Attention ( Q , K , V ) = softmax ( Q K T d ) V ] [\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d}}\right)\mathbf{V}] [Attention(Q,K,V)=softmax(dQKT)V]
- d \sqrt{d} d:缩放因子,防止点积结果过大导致梯度消失。
- Softmax:归一化注意力权重,表示历史交互对当前预测的重要性。
该公式首先计算查询向量 Q \mathbf{Q} Q 和键向量 K \mathbf{K} K 的转置的点积,然后除以 ( d ) (\sqrt{d}) (d) 进行缩放,接着通过 Softmax 函数将结果转换为概率分布形式的注意力权重,最后将这些权重与值向量 V \mathbf{V} V 相乘,得到加权后的值向量,表示每个历史互动对当前预测的重要性。
五、多头注意力(Multi-head Attention)
为了同时关注不同表示子空间中的信息,模型采用了多头注意力机制。具体来说,将查询、键和值通过不同的投影矩阵分别线性投影 (h) 次,形成多个注意力头:
[
MultiHead
(
M
^
,
E
^
)
=
Concat
(
head
1
,
…
,
head
h
)
W
O
]
[ \text{MultiHead}(\hat{\mathbf{M}}, \hat{\mathbf{E}}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}^O ]
[MultiHead(M^,E^)=Concat(head1,…,headh)WO]
- head i = Attention ( E ^ W i Q , M ^ W i K , M ^ W i V ) \text{head}_i = \text{Attention}(\hat{\mathbf{E}}\mathbf{W}_i^Q, \hat{\mathbf{M}}\mathbf{W}_i^K, \hat{\mathbf{M}}\mathbf{W}_i^V) headi=Attention(E^WiQ,M^WiK,M^WiV):每个头独立计算注意力。
- h h h:头的数量,默认设为5。
- W O ∈ R h d × d \mathbf{W}^O \in \mathbb{R}^{hd \times d} WO∈Rhd×d:将拼接后的多头输出映射回原始维度。
功能
- 捕捉长程依赖:通过注意力权重动态选择相关历史交互。
- 并行计算:多头机制允许模型在不同子空间中学习交互模式。
代码
self.attn = MultiheadAttention(
self.d, self.num_attn_heads, dropout=self.dropout
)
self.attn
对应论文中的多头自注意力机制,用于计算查询、键和值之间的相关性权重。
六、因果性处理(Causality)
在预测第 (t+1) 次练习的结果时,模型只能参考前 (t) 次的互动信息。因此,通过因果层对注意力权重进行掩码处理,确保查询向量 Q i Q_i Qi 只与键向量 K j K_j Kj ,( j ≤ i j \leq i j≤i)计算权重,忽略未来时间步的键向量。
实现方式
- 掩码机制:在计算注意力权重时,屏蔽未来时间步的信息(即 j > i j > i j>i的位置),确保预测仅依赖历史数据。
- 公式中通过将未来位置的注意力权重设为 − ∞ -\infty −∞,使Softmax后趋近于0。
代码
causal_mask = torch.triu(
torch.ones([E.shape[0], M.shape[0]]), diagonal=1
).bool()
- 通过
causal_mask
确保在预测第(t+1)个练习时,只考虑前(t)个交互
七、前馈层(Feed Forward Layer)
前馈层用于在自注意力层的线性组合输出基础上引入非线性,增强模型的表达能力:
公式
[ F = FFN ( S ) = ReLU ( S W ( 1 ) + b ( 1 ) ) W ( 2 ) + b ( 2 ) ] [ \mathbf{F} = \text{FFN}(\mathbf{S}) = \text{ReLU}(\mathbf{S}\mathbf{W}^{(1)} + \mathbf{b}^{(1)}) \mathbf{W}^{(2)} + \mathbf{b}^{(2)} ] [F=FFN(S)=ReLU(SW(1)+b(1))W(2)+b(2)]
- W ( 1 ) , W ( 2 ) ∈ R d × d \mathbf{W}^{(1)}, \mathbf{W}^{(2)} \in \mathbb{R}^{d \times d} W(1),W(2)∈Rd×d:可学习的权重矩阵。
- b ( 1 ) , b ( 2 ) b^{(1)}, b^{(2)} b(1),b(2):为偏置项。
- ReLU:引入非线性激活函数。
功能
- 增强非线性表达能力:弥补自注意力层线性组合的不足。
代码
self.FFN = Sequential(
Linear(self.d, self.d),
ReLU(),
Dropout(self.dropout),
Linear(self.d, self.d),
Dropout(self.dropout),
)
self.FFN
对应论文中的前馈神经网络,用于引入非线性并捕捉不同潜在维度之间的交互
八、残差连接与层归一化(Residual Connections and Layer Normalization)
(一)残差连接
残差连接将较低层的特征直接传播到较高层,有助于模型在训练过程中保持梯度的稳定流动,避免梯度消失问题。在自注意力层和前馈层之后分别应用残差连接。
- 残差连接(Residual Connection):将子层(如自注意力、FFN)的输入与输出相加,保留底层特征。
[ S out = S in + SubLayer ( S in ) ] [ \mathbf{S}_{\text{out}} = \mathbf{S}_{\text{in}} + \text{SubLayer}(\mathbf{S}_{\text{in}})] [Sout=Sin+SubLayer(Sin)]
(二)层归一化
层归一化对输入进行归一化处理,稳定模型的训练过程并加速收敛。在自注意力层和前馈层之后均应用层归一化。
- 层归一化(Layer Normalization):对每个样本的特征维度进行归一化,加速训练收敛。
代码
S = self.attn_layer_norm(S + M + E)
F = self.FFN_layer_norm(F + S)
- 残差连接通过将输入加到输出上实现,层归一化则通过对每一层的输出进行归一化来稳定训练过程。
九、预测层(Prediction Layer)
预测层通过全连接网络结合 Sigmoid 激活函数,将经过多层处理后的特征向量映射到学生正确作答练习的概率:
[ p i = Sigmoid ( F i w + b ) ] [p_i = \text{Sigmoid}(\mathbf{F}_i \mathbf{w} + \mathbf{b})] [pi=Sigmoid(Fiw+b)]
- F i \mathbf{F}_i Fi:前馈网络输出的第 i i i行向量。
- w , b \mathbf{w}, \mathbf{b} w,b:可学习的权重和偏置。
- Sigmoid:将输出映射到[0,1],表示学生回答正确的概率。
- P i \mathbf{P}_i Pi:表示学生正确回答练习 e i e_i ei 的概率。
代码
self.pred = Linear(self.d, 1)
self.pred
对应论文中的全连接层,用于将模型的输出映射到单个值,表示学生对练习的正确响应概率。
十、模型训练(Network Training)
模型的训练目标是最小化学生响应序列在模型下的负对数似然,即最小化预测概率与实际结果之间的交叉熵损失:
[ L = − ∑ t ( r t log ( p t ) + ( 1 − r t ) log ( 1 − p t ) ) ] [\mathcal{L} = -\sum_{t} \left( r_t \log(p_t) + (1 - r_t) \log(1 - p_t) \right)] [L=−∑t(rtlog(pt)+(1−rt)log(1−pt))]
- r t ∈ { 0 , 1 } r_t \in \{0, 1\} rt∈{0,1}:学生实际回答是否正确。
- 交叉熵损失:最小化预测概率与真实标签的差异。
- P t \mathbf{P}_t Pt:模型预测的正确性概率。
代码
loss = binary_cross_entropy(p, t)
- 使用二元交叉熵损失函数计算预测值
p
和真实标签t
之间的差异。
以上公式和组件共同构成了 SAKT 模型的技术核心,使其能够有效地从学生的历史互动中识别出相关的知识概念,并准确预测其在下一个练习中的表现。
其公式设计紧密围绕稀疏数据场景下的泛化需求,并通过可视化注意力权重增强了模型的可解释性。