目录

20250306-笔记-精读class-CVRPEnvstepself,-selected

20250306-笔记-精读class CVRPEnv:step(self, selected)


前言

class CVRPEnv:step(self, selected) 函数是强化学习代码实现中的核心。

精读该代码的目标:

  1. 熟悉每一个参数的shape。
  2. 熟悉每个参数之间的关系(剪切,扩展,等)。

一、时间步小于 4

1.1 控制时间步的递增

            # 控制时间步的递增
            self.time_step=self.time_step+1
            self.selectex_count = self.selected_count+1
参数Shape含义
self.time_step标量用来控制时间步数
self.selectex_count(batch, pomo)表示每个批次、每个智能体已选择的节点数量。

self.time_step=self.time_step+1

增加 self.time_step 的值,用来控制时间步数。

  • self.time_step 是一个标量(单个整数),表示当前的时间步数。

self.selected_count = self.selected_count + 1

这一行代码的目的是增加 self.selected_count 的值,表示在当前时间步中,智能体已经选择的节点数量增加了。

  • self.selected_count 是一个形状为 (batch_size, pomo_size) 的张量,表示每个批次、每个智能体已选择的节点数量。

1.2 判断是否在配送中心

            #判断是否在配送中心
            self.at_the_depot = (selected == 0)
参数Shape含义
self.at_the_depot(batch, pomo)布尔张量,表示每个智能体是否位于配送中心
selected(batch, pomo)表示每个批次和每个智能体选择的节点编号

这行代码的目的是更新 self.at_the_depot 张量,用来表示每个智能体是否位于配送中心(通常是节点 0)。如果智能体选择的节点编号是 0(配送中心节点),则 self.at_the_depot 对应的位置为 True ,否则为 False

1.3 特定时间步的操作

            if self.time_step==3:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
            if self.time_step == 4:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
                self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot)&(self.last_current_node!=0)] = 0
参数Shape含义
self.time_step标量用来控制时间步数
self.current_node(batch, pomo)示每个批次、每个智能体当前访问的节点。
self.last_current_node(batch, pomo)self.current_node.clone()
self.load(batch, pomo)表示每个智能体当前的负载状态。
self.last_load(batch, pomo)self.load.clone()
self.visited_ninf_flag(batch, pomo, problem + 1)记录每个智能体对每个节点的访问标志。
self.at_the_depot(batch, pomo)布尔张量,表示每个智能体是否位于配送中心

if self.time_step == 3:

self.last_current_node = self.current_node.clone()

self.last_load = self.load.clone()

在时间步为 3 时,保存当前节点( self.current_node )和负载( self.load )的状态。

  • self.current_nodeself.load 在时间步 3 时被保存为 self.last_current_nodeself.last_load 。它们的形状仍然是 (batch_size, pomo_size)

if self.time_step == 4:

self.last_current_node = self.current_node.clone()

self.last_load = self.load.clone()

在时间步为 4 时,再次保存当前节点和负载状态,并更新 self.visited_ninf_flag ,修改智能体在配送中心以外的访问状态。

  • self.current_nodeself.load 在时间步 4 时被保存为 self.last_current_nodeself.last_load 。它们的形状仍然是 (batch_size, pomo_size)

self.visited_ninf_flag[:, :, self.problem_size + 1][(~self.at_the_depot) & (self.last_current_node != 0)] = 0

  • self.visited_ninf_flag :这是一个形状为 (batch_size, pomo_size, problem_size + 1) 的张量,记录每个智能体对每个节点的访问标志。

  • self.visited_ninf_flag[:, :, self.problem_size + 1] :表示对 visited_ninf_flag 张量的切片操作,选取所有批次、所有智能体,并指定第 problem_size + 1 个节点的位置。

    • self.problem_size + 1 指定的是配送中心。
  • (~self.at_the_depot) & (self.last_current_node != 0) :这部分是一个布尔索引,用来筛选符合特定条件的位置。

    • self.at_the_depot 是一个布尔张量,表示每个智能体是否在配送中心( True 表示在配送中心, False 表示不在)。
    • ~self.at_the_depotself.at_the_depot 进行布尔取反,表示哪些智能体不在配送中心。
    • self.last_current_node != 0 判断哪些智能体在时间步 3 时没有选择 配送中心(节点 0) 。
    • (~self.at_the_depot) & (self.last_current_node != 0) 综合起来表示,选择那些不在配送中心并且在时间步 3 时没有选择配送中心的智能体。

1.4更新

1.4.1 更新当前节点和已选择节点列表

            #更新当前节点和已选择节点列表
            self.current_node = selected
            self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
参数Shape含义
self.current_node(batch, pomo)
self.selected_node_list(batch, pomo,0~)

注: 0~ 表示第三维度逐渐增加

self.selected_node_list 的shape:

https://i-blog.csdnimg.cn/direct/8efd0138756b49c7b9cbdbb5b0940f3c.png

self.current_node 的shape:

https://i-blog.csdnimg.cn/direct/4cda4b4a612a4e7ba2787e12fa63ead1.jpeg

self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2) ,表示先将 self.current_node 扩展为三维数据,再将 self.current_node 沿着 self.selected_node_list 的第三维度( dim=2 )进行依次剪切进去。

1.4.2 更新需求和负载

            #更新需求和负载
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            gathering_index = selected[:, :, None]
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            self.load -= selected_demand
            self.load[self.at_the_depot] = 1  # refill loaded at the depot
参数Shape含义g
self.depot_node_demand(batch, problem + 1)表示每个批次中,每个问题(包括配送中心)对应的节点需求
demand_list(batch, pomo, problem + 1)包含每个节点需求的张量
selected(batch, pomo)表示每个批次中的每个智能体所选择的节点编号(这些节点是从节点集合中选择的)
selected_demand(batch, pomo)示每个智能体所选择节点的需求。

demand_list = self.depot_node_demand[:, None,:].expand(self.batch_size, self.pomo_size, -1)

  • [:, None, :] :先在 self.depot_node_demand 的第二维(即问题维度)上增加一个新的维度,使其变为 (batch_size, 1, problem_size + 1)
  • .expand(self.batch_size, self.pomo_size, -1) :将数据self.depot_node_demand扩展为 (batch_size, pomo_size, problem_size + 1) ,表示每个批次中的每个 POMO 智能体都有一份相同的需求数据。

https://i-blog.csdnimg.cn/direct/a069d22632754fca824e946b40feb5ea.png

gathering_index = selected[:, :, None]

  • selected 进行维度扩展

    https://i-blog.csdnimg.cn/direct/5b2b1d4a3a724e02ae1f900acccc5a8d.png


selected_demand = demand_list.gather(dim=2,index=gathering_index).squeeze(dim=2)

  • demand_list 的 shape 是 (batch_size, pomo_size, problem_size + 1) ,包含了所有节点的需求数据。
  • gather(dim=2, index=gathering_index) 会按照 gathering_index (即 selected 中存储的节点编号)从 demand_list 中选择出对应的节点需求。 dim=2 表示沿着第三维(即问题维度)进行选择。
  • gather 的结果是一个 shape 为 (batch_size, pomo_size, 1) 的张量。
  • .squeeze(dim=2) 去掉了多余的第三维,最终得到 selected_demand ,其 shape 是 (batch_size, pomo_size) ,表示每个智能体所选择节点的需求。

https://i-blog.csdnimg.cn/direct/01c045d769b242fcb90c3f71c936bf10.png

1.4.3 更新访问标记

            #更新访问标记(防止重复选择已访问的节点)
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot
参数Shape含义
self.visited_ninf_flag(batch, pomo, problem+ 1)记录了每 个智能体(POMO)在每个批次中已访问的节点的信息,标记某些节点是否已经被访问(用负无穷表示)。
self.BATCH_IDX(batch, pomo)批次索引的张量
self.POMO_IDX(batch, pomo)智能体(POMO)索引的张量
selected(batch, pomo)表示每个批次中的每个智能体所选择的节点编号(这些节点是从节点集合中选择的)
self.at_the_depot(batch, pomo)一个布尔型张量,表示每个智能体是否处于配送中心(即该智能体是否在节点 0,通常是配送中心)。

https://i-blog.csdnimg.cn/direct/4fc977549e8f4c21961ce3556316c0e4.png


self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] =float(‘-inf’)

self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] 表示从 visited_ninf_flag 张量中选择出对应批次和智能体的对应位置,并设置为 float('-inf') ,表示这些节点已经被访问过。

举例:

假设我们有以下参数:

  • batch_size = 2,即有 2 个批次。
  • pomo_size = 3,即每个批次有 3 个智能体(POMO)。
  • problem_size = 4,即有 4 个节点(包含配送中心)。
self.visited_ninf_flag = [
    [[  0.,   0.,   0.,   0.,   0.],  # 第一个批次(batch 0)
     [  0.,   0.,   0.,   0.,   0.],  # POMO 0, POMO 1, POMO 2 各自对节点的访问标志
     [  0.,   0.,   0.,   0.,   0.]],
    
    [[  0.,   0.,   0.,   0.,   0.],  # 第二个批次(batch 1)
     [  0.,   0.,   0.,   0.,   0.],
     [  0.,   0.,   0.,   0.,   0.]]
]

self.BATCH_IDX(批次索引):

self.BATCH_IDX = [
    [0, 0, 0],  # 第一个批次
    [1, 1, 1]   # 第二个批次
]

self.POMO_IDX(POMO 索引):

self.POMO_IDX = [
    [0, 1, 2],  # 每个批次中三个智能体的索引
    [0, 1, 2]
]

selected(每个智能体选择的节点):

selected = [
    [1, 2, 0],  # 第一个批次中,智能体选择的节点:POMO 0 选择节点 1,POMO 1 选择节点 2,POMO 2 选择节点 0
    [3, 1, 2]   # 第二个批次中,智能体选择的节点:POMO 0 选择节点 3,POMO 1 选择节点 1,POMO 2 选择节点 2
]

执行这一行代码 self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')

对于第一个批次( BATCH_IDX[0] ),我们有三个智能体( POMO_IDX[0] ),选择了节点 [1, 2, 0] ,分别是:

  • selected[0][0] = 1 表示 POMO 0 选择了节点 1。

  • selected[0][1] = 2 表示 POMO 1 选择了节点 2。

  • selected[0][2] = 0 表示 POMO 2 选择了节点 0。

    对于第二个批次( BATCH_IDX[1] ),我们同样有三个智能体( POMO_IDX[1] ),选择了节点 [3, 1, 2] ,分别是:

  • selected[1][0] = 3 表示 POMO 0 选择了节点 3。

  • selected[1][1] = 1 表示 POMO 1 选择了节点 1。

  • selected[1][2] = 2 表示 POMO 2 选择了节点 2。

更新 visited_ninf_flag : 根据批次索引和 POMO 索引,我们更新了对应位置的值为负无穷 -inf

  • 对于 BATCH_IDX[0]POMO_IDX[0, 1, 2] ,我们将 selected[0][0] = 1selected[0][1] = 2selected[0][2] = 0 位置标记为 -inf
  • 对于 BATCH_IDX[1]POMO_IDX[0, 1, 2] ,我们将 selected[1][0] = 3selected[1][1] = 1selected[1][2] = 2 位置标记为 -inf

self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0

self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0 ,我们将所有不在配送中心的智能体的配送中心访问标志设置为 0

[:, :, 0] 是一个切片操作,表示我们提取张量中的第一个节点(通常是配送中心节点)。

  • ~self.at_the_depot 是对 self.at_the_depot 张量的布尔取反操作,将 True 变为 False ,将 False 变为 True

1.4.4 更新负无穷掩码

            #更新负无穷掩码(屏蔽需求量超过当前负载的节点)
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            _2=torch.full((demand_too_large.shape[0],demand_too_large.shape[1],1),False)
            demand_too_large = torch.cat((demand_too_large, _2), dim=2)
            self.ninf_mask[demand_too_large] = float('-inf')
参数Shape含义
self.visited_ninf_flag(batch, pomo, problem+ 1)记录了每 个智能体(POMO)在每个批次中已访问的节点的信息,标记某些节点是否已经被访问(用负无穷表示)。
self.ninf_mask(batch, pomo, problem+ 1)self.visited_ninf_flag.clone()
demand_too_large(batch, pomo, problem + 1)每个智能体负载与节点需求的比较结果
_2(batch, pomo, 1)张量 _2 用于扩展 demand_too_large 张量的形状。

self.ninf_mask = self.visited_ninf_flag.clone()

复制 visited_ninf_flag 张量的内容,初始化 ninf_maskself.visited_ninf_flag 是一个形状为 (batch_size, pomo_size, problem_size + 1) 的张量,记录了每个智能体对每个节点的访问状态。

round_error_epsilon = 0.00001

demand_too_large = self.load[:, :, None]+ round_error_epsilon < demand_list

定义一个小的数值误差 round_error_epsilon ,用来避免浮点数运算时的小数误差。

检查每个智能体的负载是否小于当前节点的需求量。

  • self.load[:, :, None] 的形状为 (batch_size, pomo_size, 1) ,是 每个批次中每个智能体的负载 。通过 [:, :, None] 进行扩展,将其转换为三维张量,第三维用于后续与 demand_list 对比。
  • demand_list 是一个形状为 (batch_size, pomo_size, problem_size + 1) 的张量,表示每个智能体选择的节点的需求量。
  • round_error_epsilon 是用于避免计算中的浮动误差。
  • demand_too_large 是一个布尔张量,形状为 (batch_size, pomo_size, problem_size + 1) ,其值为 True 表示该节点的需求量大于当前负载(包括误差修正),为 False 表示需求量不大于负载。

demand_too_large 的形状继承于 demand_list

https://i-blog.csdnimg.cn/direct/22651a2965c140d2b102abf11838e917.jpeg

_2=torch.full((demand_too_large.shape[0],demand_too_large.shape[1],1),False)

demand_too_large = torch.cat((demand_too_large, _2), dim=2)

创建一个形状为 (batch_size, pomo_size, 1) 的张量 _2 ,其值为 False

  • torch.full() 创建一个所有元素都为 False 的张量,形状为 (batch_size, pomo_size, 1) ,确保将其连接到 demand_too_large 上时,可以对其进行扩展。

_2 连接到 demand_too_large 张量的最后一维。

  • demand_too_large 的形状为 (batch_size, pomo_size, problem_size + 1) ,表示每个智能体负载与节点需求的比较结果。
  • _2 的形状为 (batch_size, pomo_size, 1) ,用于将布尔值 False 填充到 demand_too_large 的最后一维。
  • 通过 torch.cat()_2 拼接到 demand_too_large 后面,得到新的形状 (batch_size, pomo_size, problem_size + 2) ,扩展了一个额外的维度。

https://i-blog.csdnimg.cn/direct/e8b2527675fa413490689eb199de20d9.png

self.ninf_mask[demand_too_large] = float(‘-inf’)

demand_too_largeTrue 的位置,更新 ninf_mask 为负无穷 -inf ,表示这些节点的需求量超过当前负载。

1.4.5 更新步骤状态,将更新后的状态同步到 self.step_state

            #更新步骤状态,将更新后的状态同步到 self.step_state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask
参数Shape含义
self.time_step一个整数时间步数
self.load(batch, pomo)每个批次所有POMO智能体的负载
self.current_node(batch, pomo)每个批次所有POMO智能体选择的节点
self.ninf_mask(batch, pomo, problem+ 1)记录了每 个智能体(POMO)在每个批次中已访问的节点的信息,标记某些节点是否已经被访问(用负无穷表示)。

二、时间步大于等于 4

2.1 动作模式分类 (action classification):

# 动作模式分类
action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
action2_bool_index = self.mode == 1
action3_bool_index = self.mode == 2
参数Shape含义
self.mode(batch_size, pomo_size)表示每个批次中每个智能体的当前状态模式(mode)。
selected(batch_size, pomo_size)表示每个批次中每个智能体选择的节点编号。
action0_bool_index(batch_size, pomo_size)表示哪些智能体当前处于模式 0 且选择的节点不是 self.problem_size + 1。
action1_bool_index(batch_size, pomo_size)表示哪些智能体当前处于模式 0 且选择了 self.problem_size + 1(即“后悔”模式)。
action2_bool_index(batch_size, pomo_size)表示哪些智能体当前处于模式 1。
action3_bool_index(batch_size, pomo_size)表示哪些智能体当前处于模式 2。

action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))

  • selected 是一个形状为 (batch_size, pomo_size) 的张量,表示每个批次中每个智能体选择的节点编号。
    • selected != self.problem_size + 1 会生成一个布尔张量,表示哪些智能体没有选择 self.problem_size + 1 这个特殊的节点(假设 self.problem_size + 1 是表示一个特定的节点,如“后悔”节点)。

2.2 动作索引与选择计数更新:

action1_index = torch.nonzero(action1_bool_index)
action2_index = torch.nonzero(action2_bool_index)

action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

# 更新选择计数
self.selected_count = self.selected_count + 1
# 后悔模式
self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2
参数Shape含义
action1_bool_index(N, 2) ,其中 N 是符合条件的元素个数, 2 表示 [batch_idx, pomo_idx] 两个维度所有满足“后悔模式”条件的智能体的批次和智能体索引
action1_bool_index(batch_size, pomo_size)表示每个批次、每个智能体是否符合 action1 条件
action2_index(N, 2) ,其中 N 是符合条件的元素个数, 2 表示 [batch_idx, pomo_idx] 两个维度所有处于模式 1 的智能体的批次和智能体索引
action2_bool_index(batch_size, pomo_size)表示每个批次、每个智能体是否符合 action2 条件
action4_index(N, 2) ,其中 N 是符合条件的元素个数, 2 表示 [batch_idx, pomo_idx] 两个维度所有处于模式 2 且当前节点不为配送中心的智能体的批次和智能体索引
action3_bool_index(batch_size, pomo_size)表示每个批次、每个智能体是否符合 action3 条件
self.selected_count(batch_size, pomo_size)表示每个批次、每个智能体选择的节点数量。

action1_index = torch.nonzero(action1_bool_index)

这行代码通过 torch.nonzero 获取所有满足 action1_bool_index 条件的位置索引,表示那些处于模式 0 且选择了 self.problem_size + 1 (可能是“后悔模式”)的智能体。

  • torch.nonzero(action1_bool_index) 会返回一个张量,包含所有为 True 的位置的索引。返回的索引张量的形状为 (N, 2) ,其中 NTrue 的数量,第一列是批次索引,第二列是智能体索引。

action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

  • action3_bool_index 是一个布尔张量,形状为 (batch_size, pomo_size) ,表示每个批次、每个智能体是否符合 action3 条件(即处于模式 2)。
  • self.current_node != 0 生成一个布尔张量,表示当前选择的节点不为配送中心(节点编号 0)。

self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

这行代码针对处于后悔模式 (action1_bool_index) 的智能体,将它们的选择计数减去 2。

  • action1_bool_index 是一个形状为 (batch_size, pomo_size) 的布尔张量,表示哪些智能体处于后悔模式。
  • self.selected_count[action1_bool_index] 会提取所有处于后悔模式的智能体的选择计数。
  • 将这些智能体的选择计数减去 2,表示它们在当前步骤的选择无效,可能需要补偿或调整。

2.3 节点更新:

# 节点更新
self.last_is_depot = (self.last_current_node == 0)

_ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
self.last_current_node = self.current_node.clone()
self.current_node = selected.clone()
self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

# 更新已选择节点列表
self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)
参数Shape含义
self.last_current_node(batch_size, pomo_size)表示每个批次中每个智能体在上一个时间步选择的节点。
self.last_is_depot(batch_size, pomo_size)表示每个智能体是否选择了节点 0(配送中心)。
action1_index(N, 2)表示处于“后悔模式”(mode == 0 且选择了特殊节点 self.problem_size + 1)的智能体的索引。
_(N,) ,其中 N 是处于“后悔模式”下的智能体数量每个位置的值是上一个时间步智能体选择的节点编号。
temp_last_current_node_action2(N,) ,其中 N 是处于模式 1 下的智能体数量每个位置的值是上一个时间步智能体选择的节点编号。是一个 一维向量 ,其具体元素的个数是 不定 的,取决于符合 动作 2 条件的智能体数量。
selected(batch_size, pomo_size)表示当前时间步每个智能体选择的节点。
self.current_node(batch_size, pomo_size)表示当前时间步每个智能体选择的节点。
self.selected_node_list(batch_size, pomo_size, num_selected_nodes)表示每个批次中每个智能体已经选择的节点列表。

_ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()

  • self.last_current_node[action1_index[:, 0], action1_index[:, 1]] 提取了那些处于“后悔模式”下的智能体在上一个时间步选择的节点。

  • _ 的形状是 (N,) ,其中 N 是处于“后悔模式”下的智能体数量。每个位置的值是上一个时间步智能体选择的节点编号。

    https://i-blog.csdnimg.cn/direct/d5099075484f4ab1a2eb49e864962027.png

temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()

  • self.last_current_node[action2_index[:, 0], action2_index[:, 1]] 提取了那些处于模式 1 下的智能体在上一个时间步选择的节点。

self.current_node = selected.clone()

self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

  • self.current_node[action1_index[:, 0], action1_index[:, 1]] 获取这些智能体在当前时间步的节点位置。
  • _.clone() 将之前保存的智能体在上一个时间步选择的节点(即 _ )赋值给这些智能体在当前时间步的节点。

self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

  • 更新 self.selected_node_list ,将当前时间步的节点选择添加到已选择的节点列表中。
  • torch.cat((self.selected_node_list, selected[:, :, None]), dim=2) 将当前时间步选择的节点添加到 self.selected_node_list 中,使得 selected_node_list 更新为包含当前时间步节点的列表。

2.3.1节点更新(后悔操作)

https://i-blog.csdnimg.cn/direct/a56a4b3b3a124f65a51a15fba4d6be5a.png

https://i-blog.csdnimg.cn/direct/ea37c4a331ac437f8b5d6136c889c69a.png

上图是 后悔模式 的节点更新的示意图,需要注意代码中三个变量在传递信息中的顺序问题。具体代码如下:

_ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()

self.last_current_node = self.current_node.clone()

self.current_node = selected.clone()

self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

self.last_current_nodeself.current_nodeselected 三者是信息传递关系。 selected 是最新的信息 self.current_node 次之 self.last_current_node 最后。

若为 后悔操作的节点更新self.current_nodeself.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone() ,若为 正常节点更新self.current_nodeselected.clone()


2.4 负载更新:

# 更新负载
self.at_the_depot = (selected == 0)
demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
_3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
demand_list = torch.cat((demand_list, _3), dim=2)
gathering_index = selected[:, :, None]
selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
_1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
self.last_load = self.load.clone()
self.load -= selected_demand
self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
self.load[self.at_the_depot] = 1  # refill loaded at the depot
参数Shape含义
self.at_the_depot(batch, pomo)标记每个智能体是否在配送中心。
depot_node_demand(batch, problem + 1)表示每个节点的需求,包括配送中心。
demand_list(batch, pomo_size, problem+1)每个智能体所对应的所有节点的需求。
selected(batch, pomo)表示当前批次和 POMO(多智能体)中选择的节点。
_3(batch, pomo_size, 1)一个临时张量,用于后续扩展 demand_list。
gathering_index(batch, pomo_size, 1)用于指定要从 demand_list 中收集哪些需求。
selected_demand(batch, pomo_size)表示每个智能体选择的节点的需求。
action1_index(N, 2)表示处于“后悔模式”(mode == 0 且选择了特殊节点 self.problem_size + 1)的智能体的索引。
self.load(batch, pomo_size)每个智能体(POMO agent)当前的负载状态。
self.last_load(batch, pomo)存储上一次决策时,每个智能体的载重情况。

self.at_the_depot = (selected == 0)

  • 该行代码根据 selected 来判断是否选择了配送中心。
    • 如果选择的节点是配送中心(节点编号为 0),则 self.at_the_depotTrue ,否则为 False

demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)

  • 这行代码扩展了 self.depot_node_demand ,使其能够适应批量和多个智能体的需求。
    • self.depot_node_demand[:, None, :]self.depot_node_demand 的维度从 (batch, problem+1) 扩展到 (batch, 1, problem+1) ,在第二维上添加一个新维度。
    • .expand(self.batch_size, self.pomo_size, -1) 将该 tensor 扩展到 (batch, pomo_size, problem+1) ,复制需求数据,使每个智能体都能访问这些数据。
      • -1expand() 方法中的作用是 保持该维度的大小不变,让 PyTorch 自动推导最后一个维度的大小。

https://i-blog.csdnimg.cn/direct/a069d22632754fca824e946b40feb5ea.png

_3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)

  • 创建一个全为 0 的张量 _3 ,它的形状与 demand_list 相同,但在最后一维有一个额外的维度。

demand_list = torch.cat((demand_list, _3), dim=2)

  • 这行代码将 _3 张量连接到 demand_list 的最后一维( dim=2 )。
    • demand_list 的原始形状是 (batch, pomo_size, problem+1) ,它表示每个智能体的需求。
    • _3 的形状是 (batch, pomo_size, 1) ,它会被附加到 demand_list 的最后一维,使其形状变为 (batch, pomo_size, problem+2)

gathering_index = selected[:, :, None]

  • gathering_index 用于指定要从 demand_list 中收集哪些需求。
    • selected 的形状为 (batch, pomo_size) ,表示每个智能体选择的节点。
    • [:, :, None] 会将 selected 的形状从 (batch, pomo_size) 转换为 (batch, pomo_size, 1) ,增加一个新的维度,使其可以用作 gather 的索引。

selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)

  • 该行代码根据 gathering_indexdemand_list 中提取所选节点的需求,并移除不必要的维度。
    • demand_list.gather(dim=2, index=gathering_index) 会根据 gathering_index 提取每个智能体选择节点的需求。 dim=2 表示从最后一维(需求)中选取。
    • .squeeze(dim=2) 去除 gather 后产生的单一维度,使 selected_demand 的形状从 (batch, pomo_size, 1) 转变为 (batch, pomo_size)

https://i-blog.csdnimg.cn/direct/12adb3c35fb148f48148e99e654b63a4.png

_1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()

self.last_load = self.load.clone()

self.load -= selected_demand

  • 该段代码先将 _1 更新为之前的 self.last_load ,在此之后 self.last_load 已更新为当前的 self.load
    • self.last_load 中,指定位置的负载克隆到 _1
      • action1_index 是一个包含批次和 POMO 索引的 tensor。
      • self.last_load 是一个形状为 (batch, pomo_size) 的 tensor,表示每个智能体的负载。
      • action1_index 提供了批次和智能体的索引,所以通过这些索引来获取 self.last_load 中的值。
    • 将当前负载 self.load 的值克隆到 self.last_load 中,以便后续使用。
    • 根据 selected_demand 更新负载。
      • selected_demand 是每个智能体选择的节点需求,形状为 (batch, pomo_size)

self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()

  • 这行代码将 action1_index 索引位置的负载值恢复为 _1 (即之前的负载值)。
    • action1_index 是一个 tensor,表示在某些情况下需要恢复负载的智能体的索引。
    • 通过该索引, 将 self.load 中的负载恢复为之前保存的 _1

self.load[self.at_the_depot] = 1

  • 这行代码将位于配送中心的智能体的负载设置为 1,表示它们已被重新加载。
    • self.at_the_depot 是一个形状为 (batch, pomo_size) 的布尔值 tensor,表示哪些智能体在配送中心。
    • self.load[self.at_the_depot] = 1 将这些智能体的负载恢复为 1 ,表示它们重新装载。

2.4.1 后悔操作的负载信息传递:

selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)

_1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()

self.last_load= self.load.clone() # shape: (batch, pomo)

self.load -= selected_demand

self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()

https://i-blog.csdnimg.cn/direct/1631d62ac5664ef39a386b1c84d2beda.png

若是后悔操作则 self.load_1.clone() ,若为正常负载信息更新则为数据 self.load -= selected_demand


2.5 访问标记更新:

# 更新访问标记
self.visited_ninf_flag[:, :, self.problem_size + 1][self.last_is_depot] = 0
self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
self.visited_ninf_flag[:, :, self.problem_size + 1][self.at_the_depot] = float('-inf')
self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0
参数Shape含义
self.visited_ninf_flag(batch, pomo_size, problem+2)用于存储每个节点的访问状态。
self.last_is_depot(batch, pomo_size)表示上一步的节点是否是配送中心(True 表示在配送中心,False 表示不在)。
self.BATCH_IDX(batch, pomo_size)批次的索引。
self.POMO_IDX(batch, pomo_size)POMO智能体的索引。
selected(batch, pomo_size)表示当前每个智能体选择的节点编号。
action2_index(N, 2)其中 N 是符合 mode=1 条件的智能体数量,每一行 (batch_idx, pomo_idx) 指定了具体的智能体索引。
temp_last_current_node_action2(N,)表示 action2_index 对应的智能体在上一步访问的节点编号。是一个 一维向量 ,其具体元素的个数是 不定 的,取决于符合 动作 2 条件的智能体数量。
action4_index(M, 2)其中 M 是满足 mode=2 (表示特定的选择模式)并且 current_node ≠ 0 的智能体数量,每一行 (batch_idx, pomo_idx) 指定了具体的智能体索引。
self.at_the_depot(batch, pomo_size)表示哪些智能体位于配送中心。

self.visited_ninf_flag[:, :, self.problem_size + 1][self.last_is_depot] = 0

  • 重置 self.problem_size + 1 位置的 visited_ninf_flag ,以允许节点( problem_size+1 )重新被访问。
    • self.visited_ninf_flag倒数第二个索引 self.problem_size + 1 位置的值设为 0 ,但仅限于 上一步在配送中心的智能体 (self.last_is_depot)
    • self.problem_size + 1 这个索引通常用于表示一个特殊状态(例如 “后悔” 或 “未选择” 状态)。

self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float(‘-inf’)

  • 当前选择的节点 的访问标记设为 -inf ,表示这些节点已经被访问,防止它们被重复选择。

self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)

  • 恢复某些节点的访问权限 ,即:对于 mode=1 (表示后悔操作)的智能体,重新允许访问它们 上次的节点
    • action2_index 的形状是 (N, 2) ,其中 N 是符合 mode=1 条件的智能体数量,每一行 (batch_idx, pomo_idx) 指定了具体的智能体索引。
    • temp_last_current_node_action2 的形状是 (N,) ,表示 action2_index 对应的智能体在上一步访问的节点编号。是一个 一维向量 ,其具体元素的个数是 不定 的,取决于符合 动作 2 条件的智能体数量。
    • self.visited_ninf_flag 的形状是 (batch, pomo_size, problem+2)

https://i-blog.csdnimg.cn/direct/ed40d3a5e625452498d33cba3cad9536.png

self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)

  • 针对 action3_bool_index & (self.current_node != 0) 的情况,重新启用 self.problem_size + 1 位置的访问权限。
    • action4_index 的形状是 (M, 2) ,其中 M 是满足 mode=2 (表示特定的选择模式)并且 current_node ≠ 0 的智能体数量,每一行 (batch_idx, pomo_idx) 指定了具体的智能体索引。
    • self.problem_size + 1 表示 后悔操作

self.visited_ninf_flag[:, :, self.problem_size + 1][self.at_the_depot] = float(‘-inf’)

  • 在配送中心的智能体,不允许选择 problem_size + 1 这个特殊状态。
    • self.problem_size + 1 表示 后悔操作
    • self.visited_ninf_flag[:, :, self.problem_size + 1] 选取 visited_ninf_flagproblem_size + 1 位置,得到形状 (batch, pomo_size)
    • [self.at_the_depot] 选择所有位于配送中心的智能体,并把它们的 problem_size+1 位置设置为 -inf ,防止它们选择这个状态。

self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0

  • 如果不在配送中心,则允许访问配送中心(节点 0)。
    • self.visited_ninf_flag[:, :, 0] 选取 visited_ninf_flag 的第 0 个索引(即配送中心的访问状态),形状为 (batch, pomo_size)
    • [~self.at_the_depot] 选取所有不在配送中心的智能体,并将它们的 visited_ninf_flag 设为 0 ,允许它们 重新访问配送中心

2.6 负无穷掩码更新:

# 更新负无穷掩码
self.ninf_mask = self.visited_ninf_flag.clone()
round_error_epsilon = 0.00001
demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
self.ninf_mask[demand_too_large] = float('-inf')
参数Shape含义
self.ninf_mask(batch, pomo_size, problem+2)visited_ninf_flag 包含已访问节点的屏蔽信息和需求超过当前负载的节点的屏蔽信息。
self.visited_ninf_flag(batch, pomo_size, problem+2)存储每个智能体对每个节点的访问状态。
demand_too_large(batch, pomo_size, problem+2)标记哪些节点的需求大于当前负载。
demand_list(batch, pomo_size, problem+2)表示每个智能体对每个节点的需求。
self.load(batch, pomo_size)表示当前每个智能体的负载。

self.ninf_mask = self.visited_ninf_flag.clone()

  • self.visited_ninf_flag 的形状是 (batch, pomo_size, problem+2) ,用于存储每个智能体对每个节点的访问状态。

round_error_epsilon = 0.00001

demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list

  • 定义一个极小的正数 round_error_epsilon ,用于浮点数计算中的舍入误差处理。避免 load 与 demand_list 直接比较时因为精度问题导致错误。
  • 计算一个布尔掩码 demand_too_large ,标记哪些节点的需求大于当前负载,这些节点应该被屏蔽。
    • self.load[:, :, None] 通过 None扩展维度,这样就可以与 demand_list 进行逐元素比较。
    • demand_list.shape == (batch, pomo_size, problem+2) ,表示每个智能体对每个节点的需求。
    • 生成 demand_too_large ,形状为 (batch, pomo_size, problem+2)
      • True 表示该节点的需求大于当前负载(不能被选择)。
      • False 表示该节点的需求在当前负载允许范围内。

self.ninf_mask[demand_too_large] = float(‘-inf’)

屏蔽负载不足的节点 ,确保它们不会被智能体选择。

  • self.ninf_mask[demand_too_large] = float('-inf')
    • 找到 demand_too_largeTrue 的位置(即负载不足的节点)。
    • self.ninf_mask 中,将这些位置的值设为 -inf ,防止它们被选中。

2.7 完成状态更新:

# 更新完成状态
newly_finished = (self.visited_ninf_flag == float('-inf'))[:,:,:self.problem_size+1].all(dim=2)
self.finished = self.finished + newly_finished
参数Shape含义
newly_finished(batch, pomo_size)标记哪些智能体已经访问了所有节点,即它们的路径已经完成。
self.visited_ninf_flag(batch, pomo_size, problem+2)记录每个智能体对各个节点的访问状态 (-inf 代表已访问)
self.finished(batch, pomo_size)表示智能体是否完成任务

newly_finished = (self.visited_ninf_flag == float(‘-inf’))[:,:,:self.problem_size+1].all(dim=2)

计算一个布尔掩码 newly_finished ,标记哪些智能体已经访问了所有节点,即它们的路径已经完成。

  • newly_finished 形状为 (batch, pomo_size) ,其中:
    • True 表示该智能体已经访问了所有任务点,旅行完成。

      False 表示该智能体仍有未访问的节点。

逻辑

  • self.visited_ninf_flag == float('-inf')
    • 生成一个布尔张量,表示每个节点是否已访问:
      • True (已访问)
      • False (未访问)
  • [:, :, :self.problem_size+1]
    • 仅保留任务节点部分,不包括 problem+2 的特殊状态。
    • 形状变为 (batch, pomo_size, problem+1)
  • .all(dim=2)
    • dim=2 (节点维度)上执行 all():
      • 若智能体已访问所有 problem+1 个节点,则返回 True
      • 若有未访问的节点,则返回 False
    • 形状变为 (batch, pomo_size) ,每个智能体对应一个布尔值。

self.finished = self.finished + newly_finished

更新 self.finished ,标记哪些智能体已经完成任务。

2.8 模式更新与掩码调整:

# 更新模式
self.mode[action1_bool_index] = 1
self.mode[action2_bool_index] = 2
self.mode[action3_bool_index] = 0
self.mode[self.finished] = 4

# 更新完成后的掩码调整
self.ninf_mask[:, :, 0][self.finished] = 0
self.ninf_mask[:, :, self.problem_size + 1][self.finished] = float('-inf')
参数Shape含义
self.mode(batch, pomo_size)智能体的行为模式
action1_bool_index(batch, pomo_size)选中的智能体执行模式 1
action2_bool_index(batch, pomo_size)选中的智能体执行模式 2
action3_bool_index(batch, pomo_size)选中的智能体执行模式 0
self.finished(batch, pomo_size)选中的智能体进入模式 4
self.ninf_mask(batch, pomo_size, problem+2)用于屏蔽不可选择的节点

self.mode[action1_bool_index] = 1

self.mode[action2_bool_index] = 2

self.mode[action3_bool_index] = 0

self.mode[self.finished] = 4

  • 更新 self.mode ,决定智能体在下一步的行为模式。
  • self.mode 形状为 (batch, pomo_size) ,每个智能体都有自己的模式。
  • 模式的作用:
    • 0 :正常选择下一步节点
    • 1 :执行“后悔”操作(回溯上一步)
    • 2 :某种特殊选择状态(如重新选择)
    • 4 :表示智能体已完成任务,不再进行选择

self.ninf_mask[:, :, 0][self.finished] = 0

self.ninf_mask[:, :, self.problem_size + 1][self.finished] = float(‘-inf’)

  • 调整 self.ninf_mask ,确保已完成任务的智能体不会继续选择新节点。
  • self.ninf_mask 用于屏蔽不可选择的节点,形状为 (batch, pomo_size, problem+2)
    • -inf :表示该节点不能被选择。
    • 0 :表示该节点可以被选择。

逻辑

  • self.ninf_mask[:, :, 0][self.finished] = 0
    • 允许已完成的智能体访问配送中心(节点 0)。
    • self.finishedTrue 的智能体,其 ninf_mask 对应的 0 号索引会被设为 0 ,表示它们可以回到配送中心。
  • self.ninf_mask[:, :, self.problem_size + 1][self.finished] = float('-inf')
    • 禁止已完成的智能体选择 problem_size+1 (特殊状态)。
    • self.finishedTrue 的智能体,其 ninf_mask 对应的 problem_size+1 号索引会被设为 -inf ,表示它们不能选择该状态。

2.9 步骤状态更新:

# 更新步骤状态
self.step_state.selected_count = self.time_step
self.step_state.load = self.load
self.step_state.current_node = self.current_node
self.step_state.ninf_mask = self.ninf_mask
参数Shape含义
self.time_step标量(int)记录当前时间步(迭代次数)
self.step_state.selected_count标量(int)记录当前的时间步数
self.load(batch, pomo_size)当前智能体的负载
self.step_state.load(batch, pomo_size)存储当前负载状态
self.current_node(batch, pomo_size)记录当前每个智能体所处的节点
self.step_state.current_node(batch, pomo_size)存储当前智能体的位置信息
self.ninf_mask(batch, pomo_size, problem+2)记录哪些节点不能被选择
self.step_state.ninf_mask(batch, pomo_size, problem+2)存储当前的掩码信息

附录

代码:

    def step(self, selected):
        # selected.shape: (batch, pomo)

        #时间步数控制
        if self.time_step<4:

            # 控制时间步的递增
            self.time_step=self.time_step+1
            self.selectex_count = self.selected_count+1

            #判断是否在配送中心
            self.at_the_depot = (selected == 0)

            #特定时间步的操作
            if self.time_step==3:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
            if self.time_step == 4:
                self.last_current_node = self.current_node.clone()
                self.last_load = self.load.clone()
                self.visited_ninf_flag[:, :, self.problem_size+1][(~self.at_the_depot)&(self.last_current_node!=0)] = 0
            
            #更新当前节点和已选择节点列表
            self.current_node = selected
            self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)

            #更新需求和负载
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            gathering_index = selected[:, :, None]
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            self.load -= selected_demand
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            #更新访问标记(防止重复选择已访问的节点)
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0  # depot is considered unvisited, unless you are AT the depot

            #更新负无穷掩码(屏蔽需求量超过当前负载的节点)
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            _2=torch.full((demand_too_large.shape[0],demand_too_large.shape[1],1),False)
            demand_too_large = torch.cat((demand_too_large, _2), dim=2)
            self.ninf_mask[demand_too_large] = float('-inf')

            #更新步骤状态,将更新后的状态同步到 self.step_state
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask


        #时间步大于等于 4 的复杂操作
        else:
            #动作模式分类
            action0_bool_index = ((self.mode == 0) & (selected != self.problem_size + 1))
            action1_bool_index = ((self.mode == 0) & (selected == self.problem_size + 1))  # regret
            action2_bool_index = self.mode == 1
            action3_bool_index = self.mode == 2
            
            action1_index = torch.nonzero(action1_bool_index)
            action2_index = torch.nonzero(action2_bool_index)

            action4_index = torch.nonzero((action3_bool_index & (self.current_node != 0)))

            #更新选择计数
            self.selected_count = self.selected_count+1
            #后悔模式
            self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 2

            #节点更新
            self.last_is_depot = (self.last_current_node == 0)

            _ = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
            temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
            self.last_current_node = self.current_node.clone()
            self.current_node = selected.clone()
            self.current_node[action1_index[:, 0], action1_index[:, 1]] = _.clone()

            #更新已选择节点列表
            self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)

            #更新负载
            self.at_the_depot = (selected == 0)
            demand_list = self.depot_node_demand[:, None, :].expand(self.batch_size, self.pomo_size, -1)
            # shape: (batch, pomo, problem+1)
            _3 = torch.full((demand_list.shape[0], demand_list.shape[1], 1), 0)
            #扩展需求列表 demand_list 
            demand_list = torch.cat((demand_list, _3), dim=2)
            gathering_index = selected[:, :, None]
            # shape: (batch, pomo, 1)
            selected_demand = demand_list.gather(dim=2, index=gathering_index).squeeze(dim=2)
            _1 = self.last_load[action1_index[:, 0], action1_index[:, 1]].clone()
            self.last_load= self.load.clone()
            # shape: (batch, pomo)
            self.load -= selected_demand
            self.load[action1_index[:, 0], action1_index[:, 1]] = _1.clone()
            self.load[self.at_the_depot] = 1  # refill loaded at the depot

            #更新访问标记
            self.visited_ninf_flag[:, :, self.problem_size+1][self.last_is_depot] = 0
            self.visited_ninf_flag[self.BATCH_IDX, self.POMO_IDX, selected] = float('-inf')
            self.visited_ninf_flag[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
            self.visited_ninf_flag[action4_index[:, 0], action4_index[:, 1], self.problem_size + 1] = float(0)
            self.visited_ninf_flag[:, :, self.problem_size+1][self.at_the_depot] = float('-inf')
            self.visited_ninf_flag[:, :, 0][~self.at_the_depot] = 0


            # 更新负无穷掩码
            self.ninf_mask = self.visited_ninf_flag.clone()
            round_error_epsilon = 0.00001
            demand_too_large = self.load[:, :, None] + round_error_epsilon < demand_list
            # shape: (batch, pomo, problem+1)
            self.ninf_mask[demand_too_large] = float('-inf')

            # 更新完成状态
            # 检查哪些智能体已经完成所有节点的访问。
            # 更新完成标记 self.finished。
            newly_finished = (self.visited_ninf_flag == float('-inf'))[:,:,:self.problem_size+1].all(dim=2)
            # shape: (batch, pomo)
            self.finished = self.finished + newly_finished
            # shape: (batch, pomo)

            #更新模式
            self.mode[action1_bool_index] = 1
            self.mode[action2_bool_index] = 2
            self.mode[action3_bool_index] = 0
            self.mode[self.finished] = 4

            # 更新完成后的掩码调整
            self.ninf_mask[:, :, 0][self.finished] = 0
            self.ninf_mask[:, :, self.problem_size+1][self.finished] = float('-inf')

            # 更新步骤状态
            self.step_state.selected_count = self.time_step
            self.step_state.load = self.load
            self.step_state.current_node = self.current_node
            self.step_state.ninf_mask = self.ninf_mask



        # returning valuesa
        done = self.finished.all()
        if done:
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None

        return self.step_state, reward, done