manavsinghal157 / marl-flatland · Commits

Commit de66e9ba, authored 4 years ago by Adrian Egli
Commit message: "."
Parent: 98d00d0b
Branches containing commit: none found
Tags containing commit: submission-v0.6
Related merge requests: none found
Changes: 63 (showing 3 changed files on this page, with 179 additions and 0 deletions)

  src/ppo/agent.py           +106 −0
  src/ppo/model.py            +20 −0
  src/ppo/replay_memory.py    +53 −0
src/ppo/agent.py · new file (mode 100644) · +106 −0
import pickle

import torch
# from model import PolicyNetwork
# from replay_memory import Episode, ReplayBuffer
from torch.distributions.categorical import Categorical

from src.ppo.model import PolicyNetwork
from src.ppo.replay_memory import Episode, ReplayBuffer

BUFFER_SIZE = 32_000
BATCH_SIZE = 4096
GAMMA = 0.98
LR = 0.5e-4
CLIP_FACTOR = .005
UPDATE_EVERY = 30

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device)


class Agent:
    def __init__(self, state_size, action_size, num_agents):
        self.policy = PolicyNetwork(state_size, action_size).to(device)
        self.old_policy = PolicyNetwork(state_size, action_size).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR)

        self.episodes = [Episode() for _ in range(num_agents)]
        self.memory = ReplayBuffer(BUFFER_SIZE)
        self.t_step = 0

    def reset(self):
        self.finished = [False] * len(self.episodes)

    # Decide on an action to take in the environment
    def act(self, state, eps=None):
        self.policy.eval()
        with torch.no_grad():
            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
            return Categorical(output).sample().item()

    # Record the results of the agent's action and update the model
    def step(self, handle, state, action, next_state, agent_done, episode_done, collision):
        if not self.finished[handle]:
            if agent_done:
                reward = 1
            elif collision:
                reward = -.5
            else:
                reward = 0

            # Push experience into Episode memory
            self.episodes[handle].push(state, action, reward, next_state, agent_done or episode_done)

            # When we finish the episode, discount rewards and push the experience into replay memory
            if agent_done or episode_done:
                self.episodes[handle].discount_rewards(GAMMA)
                self.memory.push_episode(self.episodes[handle])
                self.episodes[handle].reset()
                self.finished[handle] = True

        # Perform a gradient update every UPDATE_EVERY time steps
        self.t_step = (self.t_step + 1) % UPDATE_EVERY
        if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4:
            self.learn(*self.memory.sample(BATCH_SIZE, device))

    def learn(self, states, actions, rewards, next_state, done):
        self.policy.train()

        responsible_outputs = torch.gather(self.policy(states), 1, actions)
        old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach()

        # rewards = rewards - rewards.mean()
        ratio = responsible_outputs / (old_responsible_outputs + 1e-5)
        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
        loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean()

        # Compute loss and perform a gradient step
        self.old_policy.load_state_dict(self.policy.state_dict())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    # Checkpointing methods
    def save(self, path, *data):
        torch.save(self.policy.state_dict(), path / 'ppo/model_checkpoint.policy')
        torch.save(self.optimizer.state_dict(), path / 'ppo/model_checkpoint.optimizer')
        with open(path / 'ppo/model_checkpoint.meta', 'wb') as file:
            pickle.dump(data, file)

    def load(self, path, *defaults):
        try:
            print("Loading model from checkpoint...")
            print(path + 'ppo/model_checkpoint.policy')
            self.policy.load_state_dict(
                torch.load(path + 'ppo/model_checkpoint.policy', map_location=torch.device('cpu')))
            self.optimizer.load_state_dict(
                torch.load(path + 'ppo/model_checkpoint.optimizer', map_location=torch.device('cpu')))
            with open(path + 'ppo/model_checkpoint.meta', 'rb') as file:
                return pickle.load(file)
        except:
            print("No checkpoint file was found")
            return defaults
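
For context, a minimal sketch of how this Agent could be driven from a Flatland-style training loop. The environment interface assumed below (an observation dict keyed by agent handle containing numpy feature vectors, a done dict with an '__all__' key, and a collision flag supplied by the caller) is illustrative and not part of this commit.

from src.ppo.agent import Agent

# e.g. agent = Agent(state_size=231, action_size=5, num_agents=n_agents)  # sizes are assumptions

def run_episode(env, agent, n_agents, max_steps=200):
    obs = env.reset()                        # assumed: dict of per-handle numpy feature vectors
    agent.reset()                            # clears the per-agent 'finished' flags
    for _ in range(max_steps):
        actions = {h: agent.act(obs[h]) for h in range(n_agents)}
        next_obs, rewards, done, info = env.step(actions)   # assumed Flatland-style step()
        for h in range(n_agents):
            agent.step(
                handle=h,
                state=obs[h],
                action=actions[h],
                next_state=next_obs[h],
                agent_done=done[h],
                episode_done=done['__all__'],
                collision=False,             # collision detection is environment-specific
            )
        obs = next_obs
        if done['__all__']:
            break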
src/ppo/model.py · new file (mode 100644) · +20 −0
import torch.nn as nn
import torch.nn.functional as F


class PolicyNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
        super().__init__()
        self.fc1 = nn.Linear(state_size, hidsize1)
        self.fc2 = nn.Linear(hidsize1, hidsize2)
        # self.fc3 = nn.Linear(hidsize2, hidsize3)
        self.output = nn.Linear(hidsize2, action_size)
        self.softmax = nn.Softmax(dim=1)
        self.bn0 = nn.BatchNorm1d(state_size, affine=False)

    def forward(self, inputs):
        x = self.bn0(inputs.float())
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # x = F.relu(self.fc3(x))
        return self.softmax(self.output(x))
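
As a quick sanity check, the network can be exercised standalone; the sizes below are arbitrary assumptions for illustration. Because bn0 is a BatchNorm1d layer, single-observation inference (as Agent.act does) requires eval() mode.

import torch
from src.ppo.model import PolicyNetwork

net = PolicyNetwork(state_size=231, action_size=5)   # sizes are illustrative assumptions
net.eval()                          # eval() so BatchNorm1d accepts a batch of one
with torch.no_grad():
    probs = net(torch.randn(1, 231))
print(probs.shape)                  # torch.Size([1, 5]); each row sums to 1 (softmax output)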
src/ppo/replay_memory.py · new file (mode 100644) · +53 −0
import torch
import random
import numpy as np
from collections import namedtuple, deque, Iterable

Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done"))


class Episode:
    memory = []

    def reset(self):
        self.memory = []

    def push(self, *args):
        self.memory.append(tuple(args))

    def discount_rewards(self, gamma):
        running_add = 0.
        for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]:
            running_add = running_add * gamma + reward
            self.memory[i] = (state, action, running_add, *rest)


class ReplayBuffer:
    def __init__(self, buffer_size):
        self.memory = deque(maxlen=buffer_size)

    def push(self, state, action, reward, next_state, done):
        self.memory.append(Transition(
            np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done))

    def push_episode(self, episode):
        for step in episode.memory:
            self.push(*step)

    def sample(self, batch_size, device):
        experiences = random.sample(self.memory, k=batch_size)

        states = torch.from_numpy(self.stack([e.state for e in experiences])).float().to(device)
        actions = torch.from_numpy(self.stack([e.action for e in experiences])).long().to(device)
        rewards = torch.from_numpy(self.stack([e.reward for e in experiences])).float().to(device)
        next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device)
        dones = torch.from_numpy(self.stack([e.done for e in experiences]).astype(np.uint8)).float().to(device)

        return states, actions, rewards, next_states, dones

    def stack(self, states):
        sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1]
        return np.reshape(np.array(states), (len(states), *sub_dims))

    def __len__(self):
        return len(self.memory)
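
A small usage sketch for Episode and ReplayBuffer; the values are illustrative. Two notes: Episode declares memory at class level, so calling reset() on a fresh instance before the first push() gives it its own list, and Iterable is imported from collections rather than collections.abc, which only works on Python versions before 3.10.

import numpy as np
import torch
from src.ppo.replay_memory import Episode, ReplayBuffer

episode = Episode()
episode.reset()                              # give this instance its own list
for t in range(5):
    state = np.zeros(4, dtype=np.float32)
    next_state = np.ones(4, dtype=np.float32)
    reward = 1 if t == 4 else 0              # sparse terminal reward, as in agent.py
    episode.push(state, t % 2, reward, next_state, t == 4)
episode.discount_rewards(0.98)               # propagate the final reward backwards in time

buffer = ReplayBuffer(buffer_size=100)
buffer.push_episode(episode)
states, actions, rewards, next_states, dones = buffer.sample(
    batch_size=4, device=torch.device("cpu"))
print(states.shape, actions.shape)           # (4, 4) and (4, 1) after stack()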