学习笔记:AlphaZero中的蒙特卡洛搜索过程

导入:

最近在做一个AlphaZero下棋的毕设,在Paddle上copy了代码,对其中蒙特卡洛搜索有很多不懂的地方。在查了一些资料,结合代码,勉强是捋顺了这个过程,因此打算记录下来。


AlphaZero的论文中是这样描述蒙特卡洛搜索的:

https://zhuanlan.zhihu.com/p/34433581
a.每次模拟通过选择具有最大行动价值Q的边加上取决于所存储的先验概率P和该边的访问计数N(每次访问都被增加一次)的上限置信区间U来遍历树。
b.展开叶子节点,通过神经网络 ( P ( s , ⋅ ) , V ( s ) ) = f θ ( s ) (P(s,·),V(s))=f_{\theta}(s) (P(s,⋅),V(s))=fθ(s)来评估局面s;向量P的值存储在叶子结点扩展的边上。
c.更新行动价值Q等于在该行动下的子树中的所有评估值V的均值。
d.一旦MCTS搜索完成,返回局面s下的落子概率π,与 N 1 / τ N^{1/\tau} N1/τ成正比,其中N是从根状态每次移动的访问计数, τ是控制温度的参数。


笔记

这个“上限置信区间”在AlphaZero采用的是一种变体(Upper Confidence Bound applied to Trees, UCB1 for Trees 或简称 UCB for Trees)来选择要扩展的节点。这个公式结合了节点的胜率(或称为“平均回报”)和访问次数,以在探索和利用之间找到平衡。
其公式为: UCB ( s , a ) = W ( s , a ) N ( s , a ) + c p u c t ⋅ P ( s , a ) ⋅ ∑ b N ( s , b ) 1 + N ( s , a ) \text {UCB}(s,a)=\frac {W(s,a)}{N(s,a)}+c_{puct}·P(s,a)·\frac {\sqrt{\sum_{b}N(s,b)}}{1+N(s,a)} UCB(s,a)=N(s,a)W(s,a)+cpuctP(s,a)1+N(s,a)bN(s,b)

s s s是当前的状态(棋盘上的配置)。
a a a是从状态 s s s可以采取的一个动作(或称为“边”或“移动”)。
W ( s , a ) W(s,a) W(s,a)是从状态 s s s采取动作 a a a后,通过模拟得到的总回报。在AlphaZero中, W ( s , a ) W(s,a) W(s,a)的值代表节点在次访问过程内,每次访问 t t t时刻下模拟的最新的回报 Q ( s t , a t ) Q(s_{t},a_{t}) Q(st,at)之和,即 W ( s , a ) = ∑ t = 0 N ( s , a ) Q ( s t , a t ) W(s,a)=\sum_{t=0}^{N(s,a)}Q(s_{t},a_{t}) W(s,a)=t=0N(s,a)Q(st,at)

  • 当在状态 s t s_{t} st下选择动作 a t a_{t} at后,如果对局结束,则 Q ( s t , a t ) Q(s_{t},a_{t}) Q(st,at)根据胜负结果其值为1(胜利)、-1(失败)、0(平局)
  • 而当在 s t s_{t} st下选择动作 a t a_{t} at后,对局还没有结束,则 Q ( s t , a t ) Q(s_{t},a_{t}) Q(st,at)的值由神经网络价值头预测,其预测值在[0,1]内,符号由当前玩家决定。

N ( s , a ) N(s,a) N(s,a)是从状态s选择动作a的次数(即该边的访问次数)。
c p u c t c_{puct} cpuct是一个常数,用于控制探索和利用之间的权衡。这个常数通常是通过超参数调整来确定的。
P ( s , a ) P(s,a) P(s,a)是由神经网络策略头给出的从状态 s s s采取动作 a a a的先验概率。这个概率是基于当前状态 s s s的神经网络策略头输出。
∑ b N ( s , b ) \sum_{b}N(s,b) bN(s,b)是从状态s出发的所有合法动作的访问次数之和。


搜索时每次进行模拟的步骤:

每次从一个棋局状态开始蒙特卡洛搜索时,先令这个状态 s s s为根节点。其所有的合法动作 a a a,组成 s − a s-a sa的对,称作一个节点或者一条边。

    def _playout(self, state):
        """
        进行一次搜索,根据叶节点的评估值进行反向更新树节点的参数
        注意:state已就地修改,因此必须提供副本
        """
        node = self._root
        while True:
            if node.is_leaf():
                break
            # 贪心算法选择下一步行动
            action, node = node.select(self._c_puct)
            state.do_move(action)

        # 使用网络评估叶子节点,网络输出(动作,概率)元组p的列表以及当前玩家视角的得分[-1, 1]
        action_probs, leaf_value = self._policy(state)
        # 查看游戏是否结束
        end, winner = state.game_end()
        if not end:
            node.expand(action_probs)
        else:
            # 对于结束状态,将叶子节点的值换成1或-1
            if winner == -1:  # Tie
                leaf_value = 0.0
            else:
                leaf_value = (
                    1.0 if winner == state.get_current_player_id() else -1.0
                )
        # 在本次遍历中更新节点的值和访问次数
        # 必须添加符号,因为两个玩家共用一个搜索树
        node.update_recursive(-leaf_value)

①选择:贪心算法计算所有合法叶子节点的 UCB ( s , a ) \text {UCB}(s,a) UCB(s,a),选择最大的节点访问。

    def select(self, c_puct):
        """
        在子节点中选择能够提供最大的Q+U的节点
        return: (action, next_node)的二元组
        """
        return max(self._children.items(),
                   key=lambda act_node: act_node[1].get_value(c_puct))
    def get_value(self, c_puct):
        """
        计算并返回此节点的值,它是节点评估Q和此节点的先验的组合
        c_puct: 控制相对影响(0, inf)
        """
        self._u = (c_puct * self._P *
                   np.sqrt(self._parent._n_visits) / (1 + self._n_visits))
        return self._Q + self._u

此处的选择的公式即为上文的UCB1公式。

②扩展:在选择了这个叶子节点到达新的状态 s ′ s' s下,预测后续的合法动作 a ′ a' a的概率,得到延伸的新的叶子节点 s ′ − a ′ s'-a' sa

    def expand(self, action_priors):  # 这里把不合法的动作概率全部设置为0
        """通过创建新子节点来展开树"""
        for action, prob in action_priors:
            if action not in self._children:
                self._children[action] = TreeNode(self, prob)

③检查游戏状态:检查当前访问这个节点下棋局状态是否结束(胜利、失败、平局)。

④求值:根据游戏状态,得到当前节点的 Q ( s , a ) Q(s,a) Q(s,a),详见上文 Q ( s , a ) Q(s,a) Q(s,a)的定义。

⑤回溯:逐层根据 Q ( s , a ) Q(s,a) Q(s,a)更新父节点,直到根节点。

    def update_recursive(self, leaf_value):
        """就像调用update()一样,但是对所有直系节点进行更新"""
        # 如果它不是根节点,则应首先更新此节点的父节点
        if self._parent:
            self._parent.update_recursive(-leaf_value)
        self.update(leaf_value)
    def update(self, leaf_value):
        """
        从叶节点评估中更新节点值
        leaf_value: 这个子节点的评估值来自当前玩家的视角
        """
        # 统计访问次数
        self._n_visits += 1
        # 更新Q值,取决于所有访问次数的平均树,使用增量式更新方式
        self._Q += 1.0 * (leaf_value - self._Q) / self._n_visits

这里的更新公式的推导如下:
https://zhuanlan.zhihu.com/p/53843114
因为蒙特卡洛树这里的节点价值,相当于通过求所有观测的回报值(价值)的平均值来估计值函数。
因此其一般形式可以表示为: Q n = R 1 + R 2 + R 3 + . . . + R n − 1 n − 1 Q_{n}=\frac{R_{1}+R_{2}+R_{3}+...+R_{n-1}}{n-1} Qn=n1R1+R2+R3+...+Rn1

则下一次 Q n + 1 Q_{n+1} Qn+1可以表示为
Q n + 1 = 1 n ∑ i = 1 n R i = 1 n ( ∑ i = 1 n − 1 R i + R n ) = 1 n ( R n + ( n − 1 ) 1 n − 1 ∑ i = 1 n − 1 R i ) = 1 n ( R n + ( n − 1 ) Q n ) = 1 n ( R n + n Q n − Q n ) = Q n + 1 n [ R n − Q n ] \begin{split} Q_{n+1} &=\frac{1}{n} \sum_{i=1}^{n}R_{i}\\ &=\frac{1}{n} (\sum_{i=1}^{n-1}R_{i}+R_{n})\\ &=\frac{1}{n} (R_{n}+(n-1)\frac{1}{n-1}\sum_{i=1}^{n-1}R_{i})\\ &=\frac{1}{n}(R_{n}+(n-1)Q_{n})\\ &=\frac{1}{n}(R_{n}+nQ_{n}-Q_{n})\\ &=Q_{n}+\frac{1}{n}[R_{n}-Q_{n}] \end{split} Qn+1=n1i=1nRi=n1(i=1n1Ri+Rn)=n1(Rn+(n1)n11i=1n1Ri)=n1(Rn+(n1)Qn)=n1(Rn+nQnQn)=Qn+n1[RnQn]
AlphaZero对每步棋都进行成千上万次这样的模拟,最后根据节点的访问次数 N ( s , a ) N(s,a) N(s,a)计算概率,来决定执行的棋步。


参考代码和文章:
时间女神 https://aistudio.baidu.com/projectdetail/4215743
程世东 https://zhuanlan.zhihu.com/p/34433581
Alvin https://zhuanlan.zhihu.com/p/53843114
UQI-LIUWJ https://blog.csdn.net/qq_40206371/article/details/125131928

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值