详解 DDPG 模型及其 Pytorch 完整代码
DDPG(Deep Deterministic Policy Gradient)是一种用于解决连续动作空间的深度强化学习算法。它结合了Actor-Critic框架、策略梯度和深度学习技术。DDPG适用于模型自由环境,其中传统的Q-learning方法难以应用。
DDPG由以下四个关键角色组成:
1. Actor 网络:负责选择动作。输入状态,输出特定策略下选择的动作。
2. Critic 网络:对Actor的动作进行评价。输入状态和动作,输出Q值。
3. 目标 Actor-Critic 网络:用于稳定学习过程,它们是Actor和Critic网络的延迟复制品。
4. 经验回放(Replay Buffer):用于存储经验,以打破数据相关性并提高学习效果。
以下是DDPG的关键步骤:
- 初始化Actor和Critic网络以及它们的对应目标网络。
- 初始化Replay Buffer。
- 在每个时间步:
- 使用Actor网络选择动作,并在环境中执行。
- 将转移(状态、动作、奖励、新状态)存入Replay Buffer。
- 从Replay Buffer中随机采样一个小批量转移。
- 使用Critic网络最小化关于Q值的损失。
- 使用策略梯度方法最小化关于Actor的策略损失。
- 软更新目标网络参数。
下面是使用PyTorch实现DDPG的一个简化的代码示例。为了代码清晰,许多实际应用中的细节被简化或省略:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
from collections import deque
import random
# Actor Network
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action):
super(Actor, self).__init__()
self.layer1 = nn.Linear(state_dim, 400)
self.layer2 = nn.Linear(400, 300)
self.layer3 = nn.Linear(300, action_dim)
self.max_action = max_action
def forward(self, state):
a = torch.relu(self.layer1(state))
a = torch.relu(self.layer2(a))
return self.max_action * torch.tanh(self.layer3(a))
# Critic Network
class Critic(nn.Module):
def __init__(self, state_dim, action_dim):
super(Critic, self).__init__()
self.layer1 = nn.Linear(state_dim + action_dim, 400)
self.layer2 = nn.Linear(400, 300)
self.layer3 = nn.Linear(300, 1)
def forward(self, state, action):
q = torch.relu(self.layer1(torch.cat([state, action], 1)))
q = torch.relu(self.layer2(q))
return self.layer3(q)
# Replay Buffer
class ReplayBuffer:
def __init__(self, max_size=1000000):
self.buffer = deque(maxlen=max_size)
def add(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
return np.array(state), np.array(action), np.array(reward), np.array(next_state), np.array(done)
def size(self):
return len(self.buffer)
# DDPG Class
class DDPG:
def __init__(self, state_dim, action_dim, max_action):
self.actor = Actor(state_dim, action_dim, max_action).to(device)
self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
self.actor_target.load_state_dict(self.actor.state_dict())
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)
self.critic = Critic(state_dim, action_dim).to(device)
self.critic_target = Critic(state_dim, action_dim).to(device)
self.critic_target.load_state_dict(self.critic.state_dict())
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)
self.max_action = max_action
self.gamma = 0.99
self.tau = 0.005
self.replay_buffer = ReplayBuffer()
def select_action(self, state):
state = torch.FloatTensor(state.reshape(1, -1)).to(device)
return self.actor(state).cpu().data.numpy().flatten()
def train(self, iterations, batch_size=64):
for _ in range(iterations):
# Sample a batch of transitions from the buffer.
state, action, reward, next_state, done = \
self.replay_buffer.sample(batch_size)
state = torch.FloatTensor(state).to(device)
action = torch.FloatTensor(action).to(device)
reward = torch.FloatTensor(reward).reshape(-1, 1).to(device)
next_state = torch.FloatTensor(next_state).to(device)
done = torch.FloatTensor(done).reshape(-1, 1).to(device)
# Compute the target Q value
target_Q = self.critic_target(next_state, self.actor_target(next_state))
target_Q = reward + ((1 - done) * self.gamma * target_Q).detach()
# Optimize the Critic
current_Q = self.critic(state, action)
critic_loss = nn.MSELoss()(current_Q, target_Q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
# Compute actor loss
actor_loss = -self.critic(state, self.actor(state)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
# Soft update the target networks
for target_param, param in zip(self.actor_target.parameters(), self.actor.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("Pendulum-v0")
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = env.action_space.high[0]
ddpg = DDPG(state_dim, action_dim, max_action)
total_episodes = 100
for episode in range(total_episodes):
state = env.reset()
episode_reward = 0
done = False
while not done:
action = ddpg.select_action(np.array(state))
next_state, reward, done, _ = env.step(action)
ddpg.replay_buffer.add(state, action, reward, next_state, done)
state = next_state
episode_reward += reward
if ddpg.replay_buffer.size() > 1000:
ddpg.train(50)
print(f"Episode: {episode}, Reward: {episode_reward}")
注意事项:
- 该示例使用gym
库的Pendulum-v0
环境,需要根据具体问题选择合适的环境。
- 在实际应用中,为了获得更稳定的结果,可能需要对网络参数初始化、超参数选择、探索策略等进行调整。
- 由于由于代码需用到GPU,因此需要检查是否可用。
希望这段代码和解释能帮助你理解DDPG算法及其实现。