manavsinghal157 / marl-flatland · Commits

Commit 4c1427fe
Authored 4 years ago by Egli Adrian (IT-SCI-API-PFI)

FastTreeObs working

Parent: d6103087
No related branches, tags, or merge requests contain this commit.
Showing 3 changed files with 41 additions and 42 deletions:

- checkpoints/201014015722-1500.pth  (+0 −0)
- reinforcement_learning/dddqn_policy.py  (+19 −11)
- run.py  (+22 −31)
checkpoints/201014015722-1500.pth  deleted  100644 → 0  (+0 −0)
File deleted.
reinforcement_learning/dddqn_policy.py  (+19 −11)
@@ -22,7 +22,7 @@ class DDDQNPolicy(Policy):
         self.state_size = state_size
         self.action_size = action_size
         self.double_dqn = True
-        self.hidsize = 1
+        self.hidsize = 128
         if not evaluation_mode:
             self.hidsize = parameters.hidden_size
@@ -34,7 +34,7 @@ class DDDQNPolicy(Policy):
         self.gamma = parameters.gamma
         self.buffer_min_size = parameters.buffer_min_size

         # Device
         if parameters.use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda:0")
             # print("🐇 Using GPU")
@@ -43,7 +43,8 @@ class DDDQNPolicy(Policy):
             # print("🐢 Using CPU")

         # Q-Network
-        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(self.device)
+        self.qnetwork_local = DuelingQNetwork(state_size, action_size,
+                                              hidsize1=self.hidsize, hidsize2=self.hidsize).to(self.device)

         if not evaluation_mode:
             self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
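Note: DuelingQNetwork is defined elsewhere in the repository and is untouched by this commit. For orientation only, a generic dueling Q-network head with two hidden sizes looks roughly like the sketch below; the class name, layer sizes and layout are assumptions, not the project's implementation.

# Illustrative sketch only (assumed shapes and signature); not the repository's DuelingQNetwork.
import torch
import torch.nn as nn

class DuelingQNetworkSketch(nn.Module):
    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
        super().__init__()
        # shared feature extractor
        self.feature = nn.Sequential(nn.Linear(state_size, hidsize1), nn.ReLU())
        # state-value stream V(s)
        self.value = nn.Sequential(nn.Linear(hidsize1, hidsize2), nn.ReLU(), nn.Linear(hidsize2, 1))
        # advantage stream A(s, a)
        self.advantage = nn.Sequential(nn.Linear(hidsize1, hidsize2), nn.ReLU(), nn.Linear(hidsize2, action_size))

    def forward(self, x):
        h = self.feature(x)
        value = self.value(h)
        advantage = self.advantage(h)
        # dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
        return value + advantage - advantage.mean(dim=-1, keepdim=True)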
@@ -119,15 +120,22 @@ class DDDQNPolicy(Policy):
             torch.save(self.qnetwork_target.state_dict(), filename + ".target")

     def load(self, filename):
-        if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
-            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
-            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
-        else:
-            if os.path.exists(filename):
-                self.qnetwork_local.load_state_dict(torch.load(filename))
-                self.qnetwork_target.load_state_dict(torch.load(filename))
+        try:
+            if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
+                self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+                print("qnetwork_local loaded ('{}')".format(filename + ".local"))
+                if self.evaluation_mode:
+                    self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+                else:
+                    self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+                    print("qnetwork_target loaded ('{}')".format(filename + ".target"))
+            else:
+                raise FileNotFoundError("Couldn't load policy from: '{}', '{}'".format(filename + ".local", filename + ".target"))
+                print(" >> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local", filename + ".target"))
+        except Exception as exc:
+            print(exc)
+            print("Couldn't load policy from, using untrained policy! ('{}', '{}')".format(filename + ".local", filename + ".target"))

     def save_replay_buffer(self, filename):
         memory = self.memory.memory
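Note: the reworked load() wraps checkpoint loading in a try/except: if the "<file>.local"/"<file>.target" pair is missing, a FileNotFoundError is raised and immediately caught, a warning is printed, and the policy keeps its untrained weights; in evaluation mode the target network is recreated as a deep copy of the local network. This lets run.py call policy.load(checkpoint) unconditionally. A minimal usage sketch mirroring that call (the state size and checkpoint path are illustrative):

# Minimal usage sketch mirroring run.py; state_size value and checkpoint path are illustrative.
from argparse import Namespace

from reinforcement_learning.dddqn_policy import DDDQNPolicy

state_size, action_size = 26, 5  # 26 assumes FastTreeObs.observation_dim; not verified here
policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)

# load() looks for "<path>.local" and "<path>.target"; if they are missing or loading fails,
# it prints a warning and keeps the untrained weights instead of raising.
policy.load("./checkpoints/201103160541-1800.pth")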
run.py  (+22 −31)
import os
import sys
import time
from argparse import Namespace
from pathlib import Path

import numpy as np
import torch
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.observations import TreeObsForRailEnv
from flatland.evaluators.client import FlatlandRemoteClient
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.evaluators.client import TimeoutException

from utils.deadlock_check import check_if_all_blocked
from utils.fast_tree_obs import FastTreeObs

base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))

from reinforcement_learning.dddqn_policy import DDDQNPolicy
from utils.observation_utils import normalize_observation

####################################################
# EVALUATION PARAMETERS
@@ -28,7 +24,7 @@ from utils.observation_utils import normalize_observation
 VERBOSE = True

 # Checkpoint to use (remember to push it!)
-checkpoint = "checkpoints/201103150429-2500.pth"
+checkpoint = "./checkpoints/201103160541-1800.pth"

 # Use last action cache
 USE_ACTION_CACHE = True
@@ -44,20 +40,15 @@ remote_client = FlatlandRemoteClient()
 # Observation builder
 predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+tree_observation = FastTreeObs(max_depth=observation_tree_depth)

 # Calculates state and action sizes
-n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
-state_size = tree_observation.observation_dim * n_nodes
+state_size = tree_observation.observation_dim
 action_size = 5

 # Creates the policy. No GPU on evaluation server.
 policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
-if os.path.isfile(checkpoint):
-    policy.load(checkpoint)
-else:
-    print("Checkpoint not found, using untrained policy! (path: {})".format(checkpoint))
+policy.load(checkpoint)

 #####################################################################
 # Main evaluation loop
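Note: the state_size change is the core of the switch to FastTreeObs. TreeObsForRailEnv observations are flattened over every node of a depth-limited tree with branching factor 4, so the input size is observation_dim * sum of 4^i, whereas FastTreeObs already emits one fixed-length feature vector per agent. A rough comparison, with assumed values for the tree depth and per-node feature count (the real values live in the evaluation-parameter block not shown in this diff):

# Rough size comparison; tree depth and feature counts are assumed, illustrative values.
import numpy as np

observation_tree_depth = 2
n_nodes = int(sum(np.power(4, i) for i in range(observation_tree_depth + 1)))  # 1 + 4 + 16 = 21

features_per_node = 11   # typical TreeObsForRailEnv observation_dim (assumed)
fast_tree_obs_dim = 26   # FastTreeObs.observation_dim (assumed; read it from the builder in practice)

state_size_old = features_per_node * n_nodes  # 231 floats per agent, then normalized
state_size_new = fast_tree_obs_dim            # 26 floats per agent, fed to the network directly
print(state_size_old, state_size_new)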
@@ -124,15 +115,13 @@ while True:
                 time_start = time.time()
                 action_dict = {}
                 for agent in range(nb_agents):
-                    if observation[agent] and info['action_required'][agent]:
+                    if info['action_required'][agent]:
                         if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
                             # cache hit
                             action = agent_last_action[agent]
                             nb_hit += 1
                         else:
-                            # otherwise, run normalization and inference
-                            norm_obs = normalize_observation(observation[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
-                            action = policy.act(norm_obs, eps=0.0)
+                            action = policy.act(observation[agent], eps=0.0)
                         action_dict[agent] = action
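Note: with FastTreeObs the raw observation vector is fed straight to policy.act(); the normalize_observation() step and its tree_depth/observation_radius arguments are gone. The per-agent action cache is unchanged: an agent whose observation is identical to the previous step reuses its last action instead of re-running inference. A condensed sketch of that pattern (cached_act is a hypothetical helper; run.py inlines this logic):

# Condensed sketch of the action-cache pattern used in the loop above.
# cached_act is a hypothetical helper, not part of the repository.
import numpy as np

agent_last_obs, agent_last_action = {}, {}
nb_hit = 0

def cached_act(policy, agent, obs):
    global nb_hit
    if agent in agent_last_obs and np.all(agent_last_obs[agent] == obs):
        nb_hit += 1                        # cache hit: reuse the previous action
        return agent_last_action[agent]
    action = policy.act(obs, eps=0.0)      # greedy inference on the FastTreeObs vector
    agent_last_obs[agent] = obs
    agent_last_action[agent] = action
    return action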
@@ -163,16 +152,17 @@ while True:
                 nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())

                 if VERBOSE or done['__all__']:
                     print("Step {}/{}\tAgents done: {}\tObs time {:.3f}s\tInference time {:.5f}s\tStep time {:.3f}s\tCache hits {}\tNo-ops? {}".format(str(steps).zfill(4), max_nb_steps, nb_agents_done, obs_time, agent_time, step_time, nb_hit, no_ops_mode), end="\r")

                 if done['__all__']:
                     # When done['__all__'] == True, then the evaluation of this
@@ -190,7 +180,8 @@ while True:
         np_time_taken_by_controller = np.array(time_taken_by_controller)
         np_time_taken_per_step = np.array(time_taken_per_step)

         print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
         print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
         print("=" * 100)