Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
B
baselines
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Flatland
baselines
Commits
a54b734a
Commit
a54b734a
authored
5 years ago
by
gmollard
Browse files
Options
Downloads
Patches
Plain Diff
observation benchmark script
parent
bc400346
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
CustomPreprocessor.py
+8
-0
8 additions, 0 deletions
CustomPreprocessor.py
experiment_configs/observation_benchmark/config.gin
+16
-0
16 additions, 0 deletions
experiment_configs/observation_benchmark/config.gin
train_experiment.py
+24
-6
24 additions, 6 deletions
train_experiment.py
with
48 additions
and
6 deletions
CustomPreprocessor.py
+
8
−
0
View file @
a54b734a
...
@@ -54,3 +54,11 @@ class CustomPreprocessor(Preprocessor):
...
@@ -54,3 +54,11 @@ class CustomPreprocessor(Preprocessor):
def
transform
(
self
,
observation
):
def
transform
(
self
,
observation
):
return
norm_obs_clip
(
observation
)
# return the preprocessed observation
return
norm_obs_clip
(
observation
)
# return the preprocessed observation
# class NoPreprocessor:
# def _init_shape(self, obs_space, options):
# num_features = 0
# for space in obs_space:
This diff is collapsed.
Click to expand it.
experiment_configs/observation_benchmark/config.gin
0 → 100644
+
16
−
0
View file @
a54b734a
run_experiment.name = "n_agents_results"
run_experiment.num_iterations = 1002
run_experiment.save_every = 200
run_experiment.hidden_sizes = [32, 32]
run_experiment.map_width = 20
run_experiment.map_height = 20
run_experiment.n_agents = {"grid_search": [2, 5]}
run_experiment.policy_folder_name = "ppo_policy_{config[n_agents]}_agents"
run_experiment.horizon = 50
run_experiment.seed = 123
run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv, @GlobalObsForRailEnv]}
TreeObsForRailEnv.max_depth = 2
This diff is collapsed.
Click to expand it.
train_experiment.py
+
24
−
6
View file @
a54b734a
...
@@ -24,9 +24,12 @@ import tempfile
...
@@ -24,9 +24,12 @@ import tempfile
import
gin
import
gin
from
ray
import
tune
from
ray
import
tune
from
ray.rllib.utils.seed
import
seed
as
set_seed
from
ray.rllib.utils.seed
import
seed
as
set_seed
from
flatland.envs.observations
import
TreeObsForRailEnv
,
GlobalObsForRailEnv
from
ray.rllib.models.preprocessors
import
TupleFlatteningPreprocessor
ModelCatalog
.
register_custom_preprocessor
(
"
my
_prep
"
,
CustomPreprocessor
)
ModelCatalog
.
register_custom_preprocessor
(
"
tree_obs
_prep
"
,
CustomPreprocessor
)
ray
.
init
()
ray
.
init
()
...
@@ -55,7 +58,22 @@ def train(config, reporter):
...
@@ -55,7 +58,22 @@ def train(config, reporter):
"
seed
"
:
config
[
'
seed
'
]}
"
seed
"
:
config
[
'
seed
'
]}
# Observation space and action space definitions
# Observation space and action space definitions
obs_space
=
gym
.
spaces
.
Box
(
low
=-
float
(
'
inf
'
),
high
=
float
(
'
inf
'
),
shape
=
(
105
,))
if
type
(
config
[
"
obs_builder
"
])
==
TreeObsForRailEnv
:
obs_space
=
gym
.
spaces
.
Box
(
low
=-
float
(
'
inf
'
),
high
=
float
(
'
inf
'
),
shape
=
(
105
,))
preprocessor
=
"
tree_obs_prep
"
elif
type
(
config
[
"
obs_builder
"
])
==
GlobalObsForRailEnv
:
obs_space
=
gym
.
spaces
.
Tuple
((
gym
.
spaces
.
Box
(
low
=
0
,
high
=
1
,
shape
=
(
config
[
'
map_height
'
],
config
[
'
map_width
'
],
16
)),
gym
.
spaces
.
Box
(
low
=
0
,
high
=
1
,
shape
=
(
4
,
config
[
'
map_height
'
],
config
[
'
map_width
'
])),
gym
.
spaces
.
Space
(
4
)))
preprocessor
=
TupleFlatteningPreprocessor
else
:
raise
ValueError
(
"
Undefined observation space
"
)
act_space
=
gym
.
spaces
.
Discrete
(
4
)
act_space
=
gym
.
spaces
.
Discrete
(
4
)
# Dict with the different policies to train
# Dict with the different policies to train
...
@@ -69,7 +87,7 @@ def train(config, reporter):
...
@@ -69,7 +87,7 @@ def train(config, reporter):
# Trainer configuration
# Trainer configuration
trainer_config
=
DEFAULT_CONFIG
.
copy
()
trainer_config
=
DEFAULT_CONFIG
.
copy
()
trainer_config
[
'
model
'
]
=
{
"
fcnet_hiddens
"
:
config
[
'
hidden_sizes
'
],
"
custom_preprocessor
"
:
"
my_prep
"
}
trainer_config
[
'
model
'
]
=
{
"
fcnet_hiddens
"
:
config
[
'
hidden_sizes
'
],
"
custom_preprocessor
"
:
preprocessor
}
trainer_config
[
'
multiagent
'
]
=
{
"
policy_graphs
"
:
policy_graphs
,
trainer_config
[
'
multiagent
'
]
=
{
"
policy_graphs
"
:
policy_graphs
,
"
policy_mapping_fn
"
:
policy_mapping_fn
,
"
policy_mapping_fn
"
:
policy_mapping_fn
,
"
policies_to_train
"
:
list
(
policy_graphs
.
keys
())}
"
policies_to_train
"
:
list
(
policy_graphs
.
keys
())}
...
@@ -129,8 +147,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
...
@@ -129,8 +147,8 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
"
seed
"
:
seed
"
seed
"
:
seed
},
},
resources_per_trial
=
{
resources_per_trial
=
{
"
cpu
"
:
1
,
"
cpu
"
:
1
2
,
"
gpu
"
:
0.
0
"
gpu
"
:
0.
5
},
},
local_dir
=
local_dir
local_dir
=
local_dir
)
)
...
@@ -138,6 +156,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
...
@@ -138,6 +156,6 @@ def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
if
__name__
==
'
__main__
'
:
if
__name__
==
'
__main__
'
:
gin
.
external_configurable
(
tune
.
grid_search
)
gin
.
external_configurable
(
tune
.
grid_search
)
dir
=
'
/
home/guillaume/Desktop/distMAgent
/baselines/experiment_configs/
n_agents_experiment
'
# To Modify
dir
=
'
/
mount/SDC/flatland
/baselines/experiment_configs/
observation_benchmark
'
# To Modify
gin
.
parse_config_file
(
dir
+
'
/config.gin
'
)
gin
.
parse_config_file
(
dir
+
'
/config.gin
'
)
run_experiment
(
local_dir
=
dir
)
run_experiment
(
local_dir
=
dir
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment