目录

基于动手学强化学习的知识点五第-18-章-离线强化学习gym版本-0.26

基于“动手学强化学习”的知识点(五):第 18 章 离线强化学习(gym版本 >= 0.26)

摘要

本系列知识点讲解基于 中的内容进行详细的疑难点分析!具体内容请阅读 !


对应


SAC 算法部分

代码讲解对应

为了生成数据集,在倒立摆环境中从零开始训练一个在线 SAC 智能体,直到算法达到收敛效果,把训练过程中智能体采集的所有轨迹保存下来作为数据集。这样,数据集中既包含训练初期较差策略的采样,又包含训练后期较好策略的采样,是一个混合数据集。下面给出生成数据集的代码,SAC 部分直接使用 14.5 节中的代码,因此不再详细解释。——


import numpy as np
import gym
from tqdm import tqdm
import random
import rl_utils
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt


# 传统的SAC算法
class PolicyNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        std = F.softplus(self.fc_std(x))
        dist = Normal(mu, std)
        normal_sample = dist.rsample()  # rsample()是重参数化采样
        log_prob = dist.log_prob(normal_sample)
        action = torch.tanh(normal_sample)
        # 计算tanh_normal分布的对数概率密度
        log_prob = log_prob - torch.log(1 - torch.tanh(action).pow(2) + 1e-7)
        action = action * self.action_bound
        return action, log_prob


class QValueNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        x = F.relu(self.fc2(x))
        return self.fc_out(x)


class SACContinuous:
    ''' 处理连续动作的SAC算法 '''
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device):
        self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim,
                                         action_bound).to(device)  # 策略网络
        self.critic_1 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # 第一个Q网络
        self.critic_2 = QValueNetContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # 第二个Q网络
        self.target_critic_1 = QValueNetContinuous(state_dim,
                                                   hidden_dim, action_dim).to(
                                                       device)  # 第一个目标Q网络
        self.target_critic_2 = QValueNetContinuous(state_dim,
                                                   hidden_dim, action_dim).to(
                                                       device)  # 第二个目标Q网络
        # 令目标Q网络的初始参数和Q网络一样
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),
                                                   lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),
                                                   lr=critic_lr)
        # 使用alpha的log值,可以使训练结果比较稳定
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True  #对alpha求梯度
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy  # 目标熵的大小
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def take_action(self, state):
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        action = self.actor(state)[0]
        return [action.item()]

    def calc_target(self, rewards, next_states, dones):  # 计算目标Q值
        next_actions, log_prob = self.actor(next_states)
        entropy = -log_prob
        q1_value = self.target_critic_1(next_states, next_actions)
        q2_value = self.target_critic_2(next_states, next_actions)
        next_value = torch.min(q1_value,
                               q2_value) + self.log_alpha.exp() * entropy
        td_target = rewards + self.gamma * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        rewards = (rewards + 8.0) / 8.0  # 对倒立摆环境的奖励进行重塑

        # 更新两个Q网络
        td_target = self.calc_target(rewards, next_states, dones)
        critic_1_loss = torch.mean(
            F.mse_loss(self.critic_1(states, actions), td_target.detach()))
        critic_2_loss = torch.mean(
            F.mse_loss(self.critic_2(states, actions), td_target.detach()))
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # 更新策略网络
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy -
                                torch.min(q1_value, q2_value))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 更新alpha值
        alpha_loss = torch.mean(
            (entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


env_name = 'Pendulum-v1'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high[0]  # 动作最大值
random.seed(0)
np.random.seed(0)
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)

actor_lr = 3e-4
critic_lr = 3e-3
alpha_lr = 3e-4
num_episodes = 100
hidden_dim = 128
gamma = 0.99
tau = 0.005  # 软更新参数
buffer_size = 100000
minimal_size = 1000
batch_size = 64
target_entropy = -env.action_space.shape[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound,
                      actor_lr, critic_lr, alpha_lr, target_entropy, tau,
                      gamma, device)

return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,
                                              replay_buffer, minimal_size,
                                              batch_size)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('SAC on {}'.format(env_name))
plt.show()

CQL 算法

CQL 总结与大函数意义

CQL(Conservative Q-Learning) 类的整体设计在 SAC 算法基础上增加了保守性正则项,用于减少 Q 函数过估计,改善离线 RL 表现。

  • 构造函数

    • 意义 :初始化 SAC 中的 actor、critic(两个及对应目标网络)、以及温度参数和优化器,并传入 CQL 特有的超参数 beta(正则系数)和 num_random(动作采样数)。
    • 输入 :状态、动作维度、各个学习率、目标熵、tau、gamma、设备、beta、num_random。
    • 输出 :构造好的 CQL 实例,准备进行更新。
  • take_action

    • 意义 :给定状态,利用 actor 网络采样动作,用于与环境交互。
    • 输入 :单个状态(例如 [0.1, 0.2, -0.1])。
    • 输出 :对应动作(例如 [0.8])。
  • soft_update

    • 意义 :平滑更新目标 Q 网络参数,保证训练稳定性。
    • 输入 :当前 Q 网络和对应目标网络。
    • 输出 :目标网络参数更新为旧值与当前值的线性组合。
  • update

    • 意义 :使用一批真实环境转移数据更新 actor、critic 网络和温度参数 (\alpha),同时在 SAC 基础上加入 CQL 正则项

      L CQL

      L critic + β ( log ⁡ ∑ exp ⁡ ( Q ( s , a ′ ) − log ⁡ π ref ( a ′ ) ) − E ( s , a ) ∼ D [ Q ( s , a ) ] ) L_{\text{CQL}} = L_{\text{critic}} + \beta \Big(\log\sum \exp(Q(s,a’) - \log \pi_{\text{ref}}(a’)) - \mathbb{E}_{(s,a) \sim D}[Q(s,a)]\Big)

      L

      CQL

      =

      L

      critic

β


(



lo
g



∑



exp

(

Q

(

s

,




a










′

)



−





lo
g




π











ref

​


(


a










′

))



−






E










(

s

,

a

)

∼

D

​


[

Q

(

s

,



a

)]


)
  
其中通过对随机动作、策略动作和下一动作进行采样,计算 logsumexp 的项。
  • 步骤
    1. 计算 TD 目标

      y

      r + γ ( min ⁡ ( Q 1 ′ , Q 2 ′ ) + α ⋅ entropy ) y = r + \gamma ( \min(Q_1’, Q_2’) + \alpha \cdot \text{entropy})

      y

      =

      r

   γ

   (

   min

   (


   Q









   1






   ′

   ​


   ,




   Q









   2






   ′

   ​


   )



   +





   α



   ⋅






   entropy

   )
   。
2. 分别计算 critic_1_loss 与 critic_2_loss(均方误差损失)。
3. 额外采样一批随机动作(均匀分布),以及策略生成的当前和下一动作,对各个 Q 网络的输出进行 logsumexp 操作,形成 CQL 正则项;
4. 总的 critic 损失 = SAC critic loss +

   β
   \beta





   β
   ×(CQL 正则项差值)。
5. 分别更新 critic_1 与 critic_2;
6. 更新 actor,使其最大化

   min
   ⁡
   (
   Q
   1
   ,
   Q
   2
   )
   −
   α
   log
   ⁡
   π
   (
   a
   ∣
   s
   )
   \min(Q_1, Q_2) - \alpha \log\pi(a|s)





   min

   (


   Q









   1

   ​


   ,




   Q









   2

   ​


   )



   −





   α



   lo
   g



   π

   (

   a

   ∣

   s

   )
   ;
7. 更新

   α
   \alpha





   α
   使策略熵接近目标熵;
8. 最后对目标网络进行软更新。
  • 输入
    • transition_dict 包含 ‘states’, ‘actions’, ‘rewards’, ‘next_states’, ‘dones’。
  • 输出
    • 模型参数更新后,策略和 Q 网络得到改进。

CQL 总结

CQL 类在 SAC 框架下增加了保守性正则项,以降低 Q 函数对未见动作过高的估计,从而实现更为保守的 Q 学习。主要流程包括:

  • 网络初始化 :构造策略网络、两个 Q 网络、对应目标网络,温度参数 (\alpha) 及其优化器,以及 SAC 固有参数(gamma、tau、目标熵)和 CQL 超参数(beta、num_random)。
  • 动作采样 (take_action):给定状态通过 actor 网络采样动作。
  • 软更新 (soft_update):平滑更新目标网络参数。
  • 更新过程 (update):
    1. 从 transition_dict 中读取数据并转换为张量;
    2. 计算 SAC 的 TD 目标和 critic 损失;
    3. 额外采样随机动作及策略动作,并计算 CQL 正则项(通过 logsumexp);
    4. 总的 critic 损失 = SAC critic loss + beta*(正则项差值);
    5. 更新 critic 网络、策略网络及温度参数 (\alpha);
    6. 最后软更新目标网络参数。

CQL 类详细分析

class CQL:
    ''' CQL算法 '''
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device, beta, num_random):
        """
        定义 CQL 类构造函数,接收状态、隐藏、动作维度、动作范围,
        以及各优化器学习率、目标熵、tau、gamma、
        设备、CQL正则系数 beta 和随机采样数 num_random。
        """
        '''创建策略网络(actor),用于输出动作分布并采样连续动作。'''
        self.actor = PolicyNetContinuous(state_dim, hidden_dim, action_dim, action_bound).to(device)
        '''创建两个 Q 网络,分别用于评估 (state,action) 对的价值,帮助降低过估计偏差。'''
        self.critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)
        self.critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)
        '''创建目标 Q 网络,用于计算 TD 目标,使得训练更稳定。'''
        self.target_critic_1 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)
        self.target_critic_2 = QValueNetContinuous(state_dim, hidden_dim, action_dim).to(device)
        '''复制 critic 网络参数到目标网络,保证初始时一致。'''
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        '''为策略网络分配 Adam 优化器,学习率 actor_lr。'''
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        '''分别为两个 Q 网络创建 Adam 优化器。'''
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=critic_lr)
        '''初始化温度参数的对数,log_alpha = log(0.01) ≈ -4.6052。'''
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        '''使 log_alpha 可求梯度,从而在训练过程中自动调整 alpha=exp({log_alpha})。'''
        self.log_alpha.requires_grad = True  #对alpha求梯度
        '''为 log_alpha 创建 Adam 优化器,学习率 alpha_lr。'''
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)
        '''保存目标熵,用于温度调整。'''
        self.target_entropy = target_entropy  # 目标熵的大小
        '''保存折扣因子 gamma。'''
        self.gamma = gamma
        '''保存软更新系数 tau,用于目标网络更新。'''
        self.tau = tau
        '''保存 CQL 损失函数中的系数 beta,用于权衡 Q 网络额外正则项。'''
        self.beta = beta  # CQL损失函数中的系数
        '''保存 CQL 中在计算正则项时所采样的随机动作数,用于估计动作分布上界。'''
        self.num_random = num_random  # CQL中的动作采样数

    def take_action(self, state):
        """定义策略执行接口,给定单个状态,输出对应动作。"""
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(device)
        # action = self.actor(state)[0]
        # log_prob = self.actor(state)[1]
        action, log_prob = self.actor(state)
        return [action.item()]

    def soft_update(self, net, target_net):
        """对目标网络进行软更新,即用当前网络参数更新目标网络参数。"""
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)

    def update(self, transition_dict):
        """
        定义策略与 Q 网络的更新过程,使用从环境采集的经验数据更新所有网络参数,同时计算 CQL 的额外正则项。
        """
        '''从transition_dict中提取数据'''
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(device)
        actions = torch.tensor(transition_dict['actions'], dtype=torch.float).view(-1, 1).to(device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(device)
        '''对奖励进行归一化处理,这里将倒立摆环境的奖励平移与缩放,使得奖励范围更加稳定。'''
        rewards = (rewards + 8.0) / 8.0  # 对倒立摆环境的奖励进行重塑
        '''对所有下一状态通过 actor 得到下一动作及其对数概率。'''
        next_actions, log_prob = self.actor(next_states)
        '''计算熵项,即熵 = - log_prob。'''
        entropy = -log_prob
        '''用目标 Q 网络计算下一状态下的 Q 值估计。'''
        q1_value = self.target_critic_1(next_states, next_actions)
        q2_value = self.target_critic_2(next_states, next_actions)
        '''计算下一时刻的价值估计,选择较小的 Q 值(双 Q 机制),并加入熵正则项 α⋅entropy。'''
        next_value = torch.min(q1_value, q2_value) + self.log_alpha.exp() * entropy
        '''
        计算 TD 目标,若 done 为 1(终止状态)则不加折扣。
        td_target = 当前的即时奖励 + gamma*下一阶段的奖励
        '''
        td_target = rewards + self.gamma * next_value * (1 - dones)
        '''
        分别计算两个 Q 网络的均方误差(MSE)损失,相对于 td_target(detach防止梯度流向目标)。
        '''
        critic_1_loss = torch.mean(F.mse_loss(self.critic_1(states, actions), td_target.detach()))
        critic_2_loss = torch.mean(F.mse_loss(self.critic_2(states, actions), td_target.detach()))

        # 以上与SAC相同,以下Q网络更新是CQL的额外部分
        '''取当前批次样本数量。'''
        batch_size = states.shape[0]
        '''
        作用:
        - 生成一组随机均匀分布的动作,形状为 (batch_size*num_random, action_dim),范围在 [-1,1]。
        - 用于 CQL 正则项计算,作为额外的动作样本。
        数值例子:
        - 若 batch_size=64, num_random=10, action_dim=1,则生成形状为 (640,1) 的随机动作,
          如 [[0.23], [-0.45], …]。
        '''
        random_unif_actions = torch.rand([batch_size * self.num_random, actions.shape[-1]], dtype=torch.float).uniform_(-1, 1).to(device)
        '''
        作用:
        计算均匀分布的对数概率密度。
        - 由于在连续区间 [−1,1] 中,均匀密度为 1/2 对于每个动作维度,因此对数概率为 log(0.5);
        - 对于 action_dim 个维度,总和为 log(0.5^action_dim)。
        '''
        random_unif_log_pi = np.log(0.5**next_actions.shape[-1])
        # random_unif_log_pi = np.log(0.5) * next_actions.shape[-1]
        '''
        作用:增加数据集
        - 将 states 增加一个新维度,重复 num_random 次,使得每个样本重复 num_random 次,
          最后 reshape 成 (batch_size*num_random, state_dim)。
        数值例子:
        - 若 states shape=(64,3), 
          unsqueeze后变为 (64,1,3), 
          repeat 后变为 (64,10,3), view 后为 (640,3).
        '''
        tmp_states = states.unsqueeze(1).repeat(1, self.num_random, 1).view(-1, states.shape[-1])
        tmp_next_states = next_states.unsqueeze(1).repeat(1, self.num_random, 1).view(-1, next_states.shape[-1])
        '''用策略网络对 tmp_states(重复后的真实状态)采样动作,得到随机采样的当前动作及其对数概率。'''
        random_curr_actions, random_curr_log_pi = self.actor(tmp_states)
        '''同理,对 tmp_next_states 采样下一步动作和对应对数概率。'''
        random_next_actions, random_next_log_pi = self.actor(tmp_next_states)
        '''用 critic 网络计算对于 tmp_states 与随机均匀动作 random_unif_actions 的 Q 值,之后 reshape 成 (batch_size, num_random, 1)。'''
        q1_unif = self.critic_1(tmp_states, random_unif_actions).view(-1, self.num_random, 1)
        q2_unif = self.critic_2(tmp_states, random_unif_actions).view(-1, self.num_random, 1)
        '''用 critic 网络计算对于 tmp_states 与随机当前动作的 Q 值,并 reshape。'''
        q1_curr = self.critic_1(tmp_states, random_curr_actions).view(-1, self.num_random, 1)
        q2_curr = self.critic_2(tmp_states, random_curr_actions).view(-1, self.num_random, 1)
        '''用 critic 网络计算对于 tmp_states 与随机下一动作的 Q 值,并 reshape。'''
        q1_next = self.critic_1(tmp_states, random_next_actions).view(-1, self.num_random, 1)
        q2_next = self.critic_2(tmp_states, random_next_actions).view(-1, self.num_random, 1)
        '''
        作用:
        - 将三部分 Q 值进行拼接:
          1. 对于随机均匀采样动作:减去其对数概率(固定值);
          2. 对于策略采样的当前动作:减去对应对数概率(detach 后避免梯度传递);
          3. 对于策略采样的下一动作:同理。
        - 拼接维度为第 1 维(即动作样本维度)。
        数值例子:
        - 假设每个部分形状为 (64,10,1),拼接后 q1_cat 形状为 (64,30,1)。
        在 Conservative Q-Learning (CQL) 中,为了防止 Q 函数在离线数据之外的动作上过高估计,
        会在 critic 损失中增加一个正则项。正则项的思想是“保守地”估计 Q 值,使得对于未见过的动作,
        Q 值不至于过高。具体来说,CQL 的正则项大致形式为:
                            Penalty=logE_{a∼μ}[exp(Q(s,a)−logμ(a))]−E_{(s,a)∼D}[Q(s,a)]
        '''
        q1_cat = torch.cat([
            q1_unif - random_unif_log_pi,
            q1_curr - random_curr_log_pi.detach().view(-1, self.num_random, 1),
            q1_next - random_next_log_pi.detach().view(-1, self.num_random, 1)
        ], dim=1)
        q2_cat = torch.cat([
            q2_unif - random_unif_log_pi,
            q2_curr - random_curr_log_pi.detach().view(-1, self.num_random, 1),
            q2_next - random_next_log_pi.detach().view(-1, self.num_random, 1)
        ], dim=1)
        '''
        作用:
        - 对 q1_cat 和 q2_cat 在动作维度上做 logsumexp 操作,再取平均,得到一个标量,
          表示对所有随机动作样本的软最大值。
        - logsumexp 是一种平滑的最大值函数,计算公式:
                            log∑_iexp(xi)
        '''
        qf1_loss_1 = torch.logsumexp(q1_cat, dim=1).mean()
        qf2_loss_1 = torch.logsumexp(q2_cat, dim=1).mean()
        '''分别计算 critic 网络对于当前真实 (states, actions) 的 Q 值的均值。'''
        qf1_loss_2 = self.critic_1(states, actions).mean()
        qf2_loss_2 = self.critic_2(states, actions).mean()
        '''将原本 SAC 中的 critic 损失(均方误差)与 CQL 正则项相加,构成最终 critic 损失。'''
        qf1_loss = critic_1_loss + self.beta * (qf1_loss_1 - qf1_loss_2)
        qf2_loss = critic_2_loss + self.beta * (qf2_loss_1 - qf2_loss_2)
        '''
        对 critic_1 的损失进行反向传播更新。
        对 critic_2 的损失进行反向传播更新。
        '''
        self.critic_1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        qf2_loss.backward(retain_graph=True)
        self.critic_2_optimizer.step()

        # 更新策略网络
        '''使用当前策略对真实状态采样动作,并获得对应对数概率。
        self.actor未更新'''
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        '''
        评估当前策略生成的动作的 Q 值,使用 critic_1 和 critic_2 分别计算,并取最小值以降低过估计。
        self.critic_1和self.critic_2是刚更新过的
        '''
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        '''
        作用:
        - 计算策略损失,目标是最大化 min(𝑄1,𝑄2)−𝛼log𝜋;这里取负号后作为损失。
        '''
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy - torch.min(q1_value, q2_value))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 更新alpha值
        '''
        作用:
        - 计算温度参数 α 的损失,目标是使策略熵接近目标熵。
        - detach() 表示不对 entropy 反向传播梯度,只更新 alpha。
        SAC原始论文中:
                            J(α)=E_{a∼π}[−α(logπ(a∣s)+Htarget)],其中H=−logπ(a∣s)
        '''
        alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


random.seed(0)
np.random.seed(0)

if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)

beta = 5.0
num_random = 5
num_epochs = 100
num_trains_per_epoch = 500

agent = CQL(state_dim, hidden_dim, action_dim, action_bound, actor_lr,
            critic_lr, alpha_lr, target_entropy, tau, gamma, device, beta,
            num_random)

return_list = []
for i in range(10):
    with tqdm(total=int(num_epochs / 10), desc='Iteration %d' % i) as pbar:
        for i_epoch in range(int(num_epochs / 10)):
            # 此处与环境交互只是为了评估策略,最后作图用,不会用于训练
            epoch_return = 0
            state = env.reset()
            done = False
            while not done:
                action = agent.take_action(state)
                result = env.step(action)
                if len(result) == 5:
                    next_state, reward, done, truncated, info = result
                    done = done or truncated  # 可合并 terminated 和 truncated 标志
                else:
                    next_state, reward, done, info = result
                # next_state, reward, done, _ = env.step(action)
                state = next_state
                epoch_return += reward
            return_list.append(epoch_return)

            for _ in range(num_trains_per_epoch):
                b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                transition_dict = {
                    'states': b_s,
                    'actions': b_a,
                    'next_states': b_ns,
                    'rewards': b_r,
                    'dones': b_d
                }
                agent.update(transition_dict)

            if (i_epoch + 1) % 10 == 0:
                pbar.set_postfix({
                    'epoch':
                    '%d' % (num_epochs / 10 * i + i_epoch + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)


epochs_list = list(range(len(return_list)))
plt.plot(epochs_list, return_list)
plt.xlabel('Epochs')
plt.ylabel('Returns')
plt.title('CQL on {}'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('CQL on {}'.format(env_name))
plt.show()