ppo_agent.py
import copy
import os
from collections import namedtuple
from typing import Union
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
# Hyperparameters
from policy.learning_policy.learning_policy import LearningPolicy
from policy.learning_policy.replay_buffer import ReplayBuffer
# Proximal Policy Optimization (PPO)
# https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html
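#
# This file implements an actor-critic PPO policy for multi-agent training:
# transitions are collected per agent handle in EpisodeBuffers, converted to
# discounted returns, and the clipped surrogate objective is optimized for
# K_epoch epochs in PPOPolicy.train_net(), optionally sampling mini-batches
# from a replay buffer (see use_replay_buffer).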
class EpisodeBuffers:
    def __init__(self):
        self.reset()

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)

    def reset(self):
        self.memory = {}

    def get_transitions(self, handle):
        return self.memory.get(handle, [])

    def push_transition(self, handle, transition):
        transitions = self.get_transitions(handle)
        transitions.append(transition)
        self.memory.update({handle: transitions})

class ActorCriticModel(nn.Module):
    def __init__(self, state_size, action_size, device, hidsize1=512, hidsize2=256):
        super(ActorCriticModel, self).__init__()
        self.device = device
        self.actor = nn.Sequential(
            nn.Linear(state_size, hidsize1),
            nn.Tanh(),
            nn.Linear(hidsize1, hidsize2),
            nn.Tanh(),
            nn.Linear(hidsize2, action_size),
            nn.Softmax(dim=-1)
        ).to(self.device)
        self.critic = nn.Sequential(
            nn.Linear(state_size, hidsize1),
            nn.Tanh(),
            nn.Linear(hidsize1, hidsize2),
            nn.Tanh(),
            nn.Linear(hidsize2, 1)
        ).to(self.device)

    def forward(self, x):
        raise NotImplementedError

    def get_actor_dist(self, state):
        action_probs = self.actor(state)
        dist = Categorical(action_probs)
        return dist

    def evaluate(self, states, actions):
        action_probs = self.actor(states)
        dist = Categorical(action_probs)
        action_logprobs = dist.log_prob(actions)
        dist_entropy = dist.entropy()
        state_value = self.critic(states)
        return action_logprobs, torch.squeeze(state_value), dist_entropy

    def save(self, filename):
        # print("Saving model to checkpoint:", filename)
        torch.save(self.actor.state_dict(), filename + ".actor")
        torch.save(self.critic.state_dict(), filename + ".value")

    def _load(self, obj, filename):
        if os.path.exists(filename):
            print(' >> ', filename)
            try:
                obj.load_state_dict(torch.load(filename, map_location=self.device))
            except Exception:
                print(" >> failed!")
        return obj

    def load(self, filename):
        print("load model from file", filename)
        self.actor = self._load(self.actor, filename + ".actor")
        self.critic = self._load(self.critic, filename + ".value")

PPO_Param = namedtuple('PPO_Param',
                       ['hidden_size', 'buffer_size', 'batch_size', 'learning_rate',
                        'discount', 'buffer_min_size', 'use_replay_buffer', 'use_gpu'])

class PPOPolicy(LearningPolicy):
    def __init__(self,
                 state_size: int,
                 action_size: int,
                 in_parameters: Union[PPO_Param, None] = None):
        super(PPOPolicy, self).__init__()
        self.state_size = state_size
        self.action_size = action_size

        # Parameters
        self.ppo_parameters = in_parameters
        if self.ppo_parameters is not None:
            self.hidden_size = self.ppo_parameters.hidden_size
            self.buffer_size = self.ppo_parameters.buffer_size
            self.batch_size = self.ppo_parameters.batch_size
            self.learning_rate = self.ppo_parameters.learning_rate
            self.use_replay_buffer = self.ppo_parameters.use_replay_buffer
            self.discount = self.ppo_parameters.discount
            self.buffer_min_size = self.ppo_parameters.buffer_min_size
        else:
            self.hidden_size = 128
            self.learning_rate = 1.0e-3
            self.discount = 0.95
            self.buffer_size = 32_000
            self.batch_size = 1024
            self.use_replay_buffer = True
            self.buffer_min_size = 0

        # Device (guard against in_parameters being None before reading use_gpu)
        self.device = torch.device("cpu")
        if self.ppo_parameters is not None and self.ppo_parameters.use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            print("🐇 Using GPU")
        else:
            print("🐢 Using CPU")

        self.surrogate_eps_clip = 0.1
        self.K_epoch = 10
        self.weight_loss = 0.5
        self.weight_entropy = 0.01

        self.current_episode_memory = EpisodeBuffers()
        self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device)
        self.loss = 0
        self.actor_critic_model = ActorCriticModel(state_size, action_size, self.device,
                                                   hidsize1=self.hidden_size,
                                                   hidsize2=self.hidden_size)
        self.optimizer = optim.AdamW(self.actor_critic_model.parameters(), lr=self.learning_rate)
        self.loss_function = nn.MSELoss()  # nn.SmoothL1Loss()

    def get_name(self):
        return self.__class__.__name__

    def act(self, handle, state, eps=None):
        # Sample an action from the current policy
        # (the eps argument is accepted but not used by this stochastic policy)
        torch_state = torch.tensor(state, dtype=torch.float).to(self.device)
        dist = self.actor_critic_model.get_actor_dist(torch_state)
        action = dist.sample()
        return action.item()

    def step(self, handle, state, action, reward, next_state, done):
        # Record the transition ([state] -> [action] -> [reward, next_state, done])
        torch_action = torch.tensor(action, dtype=torch.float).to(self.device)
        torch_state = torch.tensor(state, dtype=torch.float).to(self.device)
        # Evaluate the actor to store the log-probability of the taken action
        dist = self.actor_critic_model.get_actor_dist(torch_state)
        action_logprobs = dist.log_prob(torch_action)
        transition = (state, action, reward, next_state, action_logprobs.item(), done)
        self.current_episode_memory.push_transition(handle, transition)

    def _push_transitions_to_replay_buffer(self,
                                           state_list,
                                           action_list,
                                           reward_list,
                                           state_next_list,
                                           done_list,
                                           prob_a_list):
        for idx in range(len(reward_list)):
            state_i = state_list[idx]
            action_i = action_list[idx]
            reward_i = reward_list[idx]
            state_next_i = state_next_list[idx]
            done_i = done_list[idx]
            prob_action_i = prob_a_list[idx]
            self.memory.add(state_i, action_i, reward_i, state_next_i, done_i, prob_action_i)

    def _convert_transitions_to_torch_tensors(self, transitions_array):
        # Build empty lists
        state_list, action_list, reward_list, state_next_list, prob_a_list, done_list = [], [], [], [], [], []

        # Walk the transitions backwards and accumulate discounted returns
        discounted_reward = 0
        for transition in transitions_array[::-1]:
            state_i, action_i, reward_i, state_next_i, prob_action_i, done_i = transition
            state_list.insert(0, state_i)
            action_list.insert(0, action_i)
            if done_i:
                discounted_reward = 0
                done_list.insert(0, 1)
            else:
                done_list.insert(0, 0)
            discounted_reward = reward_i + self.discount * discounted_reward
            reward_list.insert(0, discounted_reward)
            state_next_list.insert(0, state_next_i)
            prob_a_list.insert(0, prob_action_i)

        if self.use_replay_buffer:
            self._push_transitions_to_replay_buffer(state_list, action_list,
                                                    reward_list, state_next_list,
                                                    done_list, prob_a_list)

        # Convert data to torch tensors
        states, actions, rewards, states_next, dones, prob_actions = \
            torch.tensor(state_list, dtype=torch.float).to(self.device), \
            torch.tensor(action_list).to(self.device), \
            torch.tensor(reward_list, dtype=torch.float).to(self.device), \
            torch.tensor(state_next_list, dtype=torch.float).to(self.device), \
            torch.tensor(done_list, dtype=torch.float).to(self.device), \
            torch.tensor(prob_a_list).to(self.device)
        return states, actions, rewards, states_next, dones, prob_actions

    def _get_transitions_from_replay_buffer(self, states, actions, rewards, states_next, dones, probs_action):
        if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
            states, actions, rewards, states_next, dones, probs_action = self.memory.sample()
            actions = torch.squeeze(actions)
            rewards = torch.squeeze(rewards)
            states_next = torch.squeeze(states_next)
            dones = torch.squeeze(dones)
            probs_action = torch.squeeze(probs_action)
        return states, actions, rewards, states_next, dones, probs_action

    def train_net(self):
        # All agents have to propagate the experience they collected during the past episode
        for handle in range(len(self.current_episode_memory)):
            # Extract the agent's episode history (list of all transitions)
            agent_episode_history = self.current_episode_memory.get_transitions(handle)
            if len(agent_episode_history) > 0:
                # Convert the episode history to torch tensors
                states, actions, rewards, states_next, dones, probs_action = \
                    self._convert_transitions_to_torch_tensors(agent_episode_history)

                # Optimize policy for K epochs:
                for k_loop in range(int(self.K_epoch)):
                    if self.use_replay_buffer:
                        states, actions, rewards, states_next, dones, probs_action = \
                            self._get_transitions_from_replay_buffer(
                                states, actions, rewards, states_next, dones, probs_action
                            )

                    # Evaluate actions (actor) and values (critic)
                    logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions)

                    # Find the probability ratios pi_theta_new / pi_theta_old
                    # (probs_action holds the stored log-probabilities of the old policy)
                    ratios = torch.exp(logprobs - probs_action.detach())

                    # Find the surrogate loss
                    advantages = rewards - state_values.detach()
                    surr1 = ratios * advantages
                    surr2 = torch.clamp(ratios, 1. - self.surrogate_eps_clip, 1. + self.surrogate_eps_clip) * advantages

                    # The loss combines the clipped surrogate objective with a value-function
                    # loss and an entropy bonus. The entropy term penalizes the objective when
                    # the policy becomes (nearly) deterministic; without it the gradient would
                    # become very flat and no longer be useful.
                    loss = \
                        -torch.min(surr1, surr2) \
                        + self.weight_loss * self.loss_function(state_values, rewards) \
                        - self.weight_entropy * dist_entropy
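
                    # For reference, a restatement of what the expression above computes
                    # (per transition, before averaging):
                    #   L = -min(r * A, clip(r, 1 - eps, 1 + eps) * A)
                    #       + weight_loss * MSE(V(s), G) - weight_entropy * H(pi(.|s))
                    # where r = exp(logprob_new - logprob_old), A = G - V(s), and G is the
                    # discounted return built in _convert_transitions_to_torch_tensors.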

                    # Make a gradient step
                    self.optimizer.zero_grad()
                    loss.mean().backward()
                    self.optimizer.step()

                    # Store the current loss on the agent for debugging purposes only
                    self.loss = loss.mean().detach().cpu().numpy()

        # Reset all collected transition data
        self.current_episode_memory.reset()

    def end_episode(self, train):
        if train:
            self.train_net()

    # Checkpointing methods
    def save(self, filename):
        # print("Saving model to checkpoint:", filename)
        self.actor_critic_model.save(filename)
        torch.save(self.optimizer.state_dict(), filename + ".optimizer")

    def _load(self, obj, filename):
        if os.path.exists(filename):
            print(' >> ', filename)
            try:
                obj.load_state_dict(torch.load(filename, map_location=self.device))
            except Exception:
                print(" >> failed!")
        else:
            print(" >> file not found!")
        return obj

    def load(self, filename):
        self.actor_critic_model.load(filename)
        self.optimizer = self._load(self.optimizer, filename + ".optimizer")
        print('{} -> load {} ok'.format(self.get_name(), filename))

    def clone(self):
        policy = PPOPolicy(self.state_size, self.action_size, self.ppo_parameters)
        policy.actor_critic_model = copy.deepcopy(self.actor_critic_model)
        policy.optimizer = copy.deepcopy(self.optimizer)
        return policy
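

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the repo's training pipeline): it
# drives act/step/end_episode with random observations for a single agent
# handle, so the wiring can be smoke-tested without an environment. The sizes
# and hyperparameter values below are arbitrary assumptions; the repo's
# ReplayBuffer and LearningPolicy modules must still be importable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    state_size, action_size = 8, 4
    param = PPO_Param(hidden_size=128, buffer_size=32_000, batch_size=64,
                      learning_rate=1.0e-3, discount=0.95, buffer_min_size=0,
                      use_replay_buffer=False, use_gpu=False)
    policy = PPOPolicy(state_size, action_size, param)

    handle = 0
    state = np.random.rand(state_size).astype(np.float32)
    for _ in range(128):
        action = policy.act(handle, state)
        next_state = np.random.rand(state_size).astype(np.float32)
        reward = float(np.random.rand())
        policy.step(handle, state, action, reward, next_state, done=False)
        state = next_state

    policy.end_episode(train=True)
    print("loss after one synthetic episode:", policy.loss)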