导入:
最近在做一个AlphaZero下棋的毕设,在Paddle上copy了代码,对其中蒙特卡洛搜索有很多不懂的地方。在查了一些资料,结合代码,勉强是捋顺了这个过程,因此打算记录下来。
AlphaZero的论文中是这样描述蒙特卡洛搜索的:
https://zhuanlan.zhihu.com/p/34433581
a.每次模拟通过选择具有最大行动价值Q的边加上取决于所存储的先验概率P和该边的访问计数N(每次访问都被增加一次)的上限置信区间U来遍历树。
b.展开叶子节点,通过神经网络
(
P
(
s
,
⋅
)
,
V
(
s
)
)
=
f
θ
(
s
)
(P(s,·),V(s))=f_{\theta}(s)
(P(s,⋅),V(s))=fθ(s)来评估局面s;向量P的值存储在叶子结点扩展的边上。
c.更新行动价值Q等于在该行动下的子树中的所有评估值V的均值。
d.一旦MCTS搜索完成,返回局面s下的落子概率π,与
N
1
/
τ
N^{1/\tau}
N1/τ成正比,其中N是从根状态每次移动的访问计数, τ是控制温度的参数。
笔记
这个“上限置信区间”在AlphaZero采用的是一种变体(Upper Confidence Bound applied to Trees, UCB1 for Trees 或简称 UCB for Trees)来选择要扩展的节点。这个公式结合了节点的胜率(或称为“平均回报”)和访问次数,以在探索和利用之间找到平衡。
其公式为:
UCB
(
s
,
a
)
=
W
(
s
,
a
)
N
(
s
,
a
)
+
c
p
u
c
t
⋅
P
(
s
,
a
)
⋅
∑
b
N
(
s
,
b
)
1
+
N
(
s
,
a
)
\text {UCB}(s,a)=\frac {W(s,a)}{N(s,a)}+c_{puct}·P(s,a)·\frac {\sqrt{\sum_{b}N(s,b)}}{1+N(s,a)}
UCB(s,a)=N(s,a)W(s,a)+cpuct⋅P(s,a)⋅1+N(s,a)∑bN(s,b)
s
s
s是当前的状态(棋盘上的配置)。
a
a
a是从状态
s
s
s可以采取的一个动作(或称为“边”或“移动”)。
W
(
s
,
a
)
W(s,a)
W(s,a)是从状态
s
s
s采取动作
a
a
a后,通过模拟得到的总回报。在AlphaZero中,
W
(
s
,
a
)
W(s,a)
W(s,a)的值代表节点在次访问过程内,每次访问
t
t
t时刻下模拟的最新的回报
Q
(
s
t
,
a
t
)
Q(s_{t},a_{t})
Q(st,at)之和,即
W
(
s
,
a
)
=
∑
t
=
0
N
(
s
,
a
)
Q
(
s
t
,
a
t
)
W(s,a)=\sum_{t=0}^{N(s,a)}Q(s_{t},a_{t})
W(s,a)=∑t=0N(s,a)Q(st,at)。
- 当在状态 s t s_{t} st下选择动作 a t a_{t} at后,如果对局结束,则 Q ( s t , a t ) Q(s_{t},a_{t}) Q(st,at)根据胜负结果其值为1(胜利)、-1(失败)、0(平局)
- 而当在 s t s_{t} st下选择动作 a t a_{t} at后,对局还没有结束,则 Q ( s t , a t ) Q(s_{t},a_{t}) Q(st,at)的值由神经网络价值头预测,其预测值在[0,1]内,符号由当前玩家决定。
N
(
s
,
a
)
N(s,a)
N(s,a)是从状态s选择动作a的次数(即该边的访问次数)。
c
p
u
c
t
c_{puct}
cpuct是一个常数,用于控制探索和利用之间的权衡。这个常数通常是通过超参数调整来确定的。
P
(
s
,
a
)
P(s,a)
P(s,a)是由神经网络策略头给出的从状态
s
s
s采取动作
a
a
a的先验概率。这个概率是基于当前状态
s
s
s的神经网络策略头输出。
∑
b
N
(
s
,
b
)
\sum_{b}N(s,b)
∑bN(s,b)是从状态s出发的所有合法动作的访问次数之和。
搜索时每次进行模拟的步骤:
每次从一个棋局状态开始蒙特卡洛搜索时,先令这个状态 s s s为根节点。其所有的合法动作 a a a,组成 s − a s-a s−a的对,称作一个节点或者一条边。
def _playout(self, state):
"""
进行一次搜索,根据叶节点的评估值进行反向更新树节点的参数
注意:state已就地修改,因此必须提供副本
"""
node = self._root
while True:
if node.is_leaf():
break
# 贪心算法选择下一步行动
action, node = node.select(self._c_puct)
state.do_move(action)
# 使用网络评估叶子节点,网络输出(动作,概率)元组p的列表以及当前玩家视角的得分[-1, 1]
action_probs, leaf_value = self._policy(state)
# 查看游戏是否结束
end, winner = state.game_end()
if not end:
node.expand(action_probs)
else:
# 对于结束状态,将叶子节点的值换成1或-1
if winner == -1: # Tie
leaf_value = 0.0
else:
leaf_value = (
1.0 if winner == state.get_current_player_id() else -1.0
)
# 在本次遍历中更新节点的值和访问次数
# 必须添加符号,因为两个玩家共用一个搜索树
node.update_recursive(-leaf_value)
①选择:贪心算法计算所有合法叶子节点的 UCB ( s , a ) \text {UCB}(s,a) UCB(s,a),选择最大的节点访问。
def select(self, c_puct):
"""
在子节点中选择能够提供最大的Q+U的节点
return: (action, next_node)的二元组
"""
return max(self._children.items(),
key=lambda act_node: act_node[1].get_value(c_puct))
def get_value(self, c_puct):
"""
计算并返回此节点的值,它是节点评估Q和此节点的先验的组合
c_puct: 控制相对影响(0, inf)
"""
self._u = (c_puct * self._P *
np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
return self._Q + self._u
此处的选择的公式即为上文的UCB1公式。
②扩展:在选择了这个叶子节点到达新的状态 s ′ s' s′下,预测后续的合法动作 a ′ a' a′的概率,得到延伸的新的叶子节点 s ′ − a ′ s'-a' s′−a′。
def expand(self, action_priors): # 这里把不合法的动作概率全部设置为0
"""通过创建新子节点来展开树"""
for action, prob in action_priors:
if action not in self._children:
self._children[action] = TreeNode(self, prob)
③检查游戏状态:检查当前访问这个节点下棋局状态是否结束(胜利、失败、平局)。
④求值:根据游戏状态,得到当前节点的 Q ( s , a ) Q(s,a) Q(s,a),详见上文 Q ( s , a ) Q(s,a) Q(s,a)的定义。
⑤回溯:逐层根据 Q ( s , a ) Q(s,a) Q(s,a)更新父节点,直到根节点。
def update_recursive(self, leaf_value):
"""就像调用update()一样,但是对所有直系节点进行更新"""
# 如果它不是根节点,则应首先更新此节点的父节点
if self._parent:
self._parent.update_recursive(-leaf_value)
self.update(leaf_value)
def update(self, leaf_value):
"""
从叶节点评估中更新节点值
leaf_value: 这个子节点的评估值来自当前玩家的视角
"""
# 统计访问次数
self._n_visits += 1
# 更新Q值,取决于所有访问次数的平均树,使用增量式更新方式
self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits
这里的更新公式的推导如下:
https://zhuanlan.zhihu.com/p/53843114
因为蒙特卡洛树这里的节点价值,相当于通过求所有观测的回报值(价值)的平均值来估计值函数。
因此其一般形式可以表示为:
Q
n
=
R
1
+
R
2
+
R
3
+
.
.
.
+
R
n
−
1
n
−
1
Q_{n}=\frac{R_{1}+R_{2}+R_{3}+...+R_{n-1}}{n-1}
Qn=n−1R1+R2+R3+...+Rn−1
则下一次
Q
n
+
1
Q_{n+1}
Qn+1可以表示为
Q
n
+
1
=
1
n
∑
i
=
1
n
R
i
=
1
n
(
∑
i
=
1
n
−
1
R
i
+
R
n
)
=
1
n
(
R
n
+
(
n
−
1
)
1
n
−
1
∑
i
=
1
n
−
1
R
i
)
=
1
n
(
R
n
+
(
n
−
1
)
Q
n
)
=
1
n
(
R
n
+
n
Q
n
−
Q
n
)
=
Q
n
+
1
n
[
R
n
−
Q
n
]
\begin{split} Q_{n+1} &=\frac{1}{n} \sum_{i=1}^{n}R_{i}\\ &=\frac{1}{n} (\sum_{i=1}^{n-1}R_{i}+R_{n})\\ &=\frac{1}{n} (R_{n}+(n-1)\frac{1}{n-1}\sum_{i=1}^{n-1}R_{i})\\ &=\frac{1}{n}(R_{n}+(n-1)Q_{n})\\ &=\frac{1}{n}(R_{n}+nQ_{n}-Q_{n})\\ &=Q_{n}+\frac{1}{n}[R_{n}-Q_{n}] \end{split}
Qn+1=n1i=1∑nRi=n1(i=1∑n−1Ri+Rn)=n1(Rn+(n−1)n−11i=1∑n−1Ri)=n1(Rn+(n−1)Qn)=n1(Rn+nQn−Qn)=Qn+n1[Rn−Qn]
AlphaZero对每步棋都进行成千上万次这样的模拟,最后根据节点的访问次数
N
(
s
,
a
)
N(s,a)
N(s,a)计算概率,来决定执行的棋步。
参考代码和文章:
时间女神 https://aistudio.baidu.com/projectdetail/4215743
程世东 https://zhuanlan.zhihu.com/p/34433581
Alvin https://zhuanlan.zhihu.com/p/53843114
UQI-LIUWJ https://blog.csdn.net/qq_40206371/article/details/125131928