附Pytorch实践 一图解密AlphaZero( 二 )


那我们来看一下蒙特卡罗搜索树在这里面时如何实现的 。首先是其中的节点:
class Node:def __init__(self, parent=None, proba=None, move=None):self.p = probaself.n = 0self.w = 0self.q = 0self.children = []self.parent = parentself.move = move
其中主要为之前所说的4个属性以及父子节点的指针 。而最后一个move指出了在当前状态下的合法下棋步骤 。在训练的过程中 , 这些值都会被更新 , 那么在更新之后如何通过他们来进行动作的选择呢?
def select(nodes, c_puct=C_PUCT):" Optimized version of the selection based of the PUCT formula "total_count = 0for i in range(nodes.shape[0]):total_count += nodes[i][1]action_scores = np.zeros(nodes.shape[0])for i in range(nodes.shape[0]):action_scores[i] = nodes[i][0] + c_puct * nodes[i][2] * \(np.sqrt(total_count) / (1 + nodes[i][1])) equals = np.where(action_scores == np.max(action_scores))[0]if equals.shape[0] > 0:return np.random.choice(equals)return equals[0]
这里表示的是对于任何一个节点 , 从其所有的子节点当中 , 通过PUCT算法找出最大得分的那个节点 。在这个得分[i]的计算过程中 , 网络预测的概率和该节点被访问的次数都有被考虑 。对于被访问到的非叶子节点继续进行扩展;而如果是叶子节点则进行最终的评估 。至于其中的残差网络模块 , 价值网络 , 策略网络就不再一一叙述了 。详细参考:
?
【附Pytorch实践一图解密AlphaZero】: