Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pranjal_dhole
Flatland
Commits
d162772f
Commit
d162772f
authored
3 years ago
by
nilabha
Browse files
Options
Downloads
Patches
Plain Diff
update pettingzoo changes for flatland3
parent
2438febf
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
examples/env_generators.py
+12
-12
12 additions, 12 deletions
examples/env_generators.py
requirements_dev.txt
+5
-0
5 additions, 0 deletions
requirements_dev.txt
tests/test_pettingzoo_interface.py
+3
-3
3 additions, 3 deletions
tests/test_pettingzoo_interface.py
with
20 additions
and
15 deletions
examples/env_generators.py
+
12
−
12
View file @
d162772f
...
...
@@ -6,7 +6,7 @@ from typing import NamedTuple
from
flatland.envs.malfunction_generators
import
malfunction_from_params
,
MalfunctionParameters
,
ParamMalfunctionGen
,
no_malfunction_generator
from
flatland.envs.rail_env
import
RailEnv
from
flatland.envs.rail_generators
import
sparse_rail_generator
from
flatland.envs.
schedul
e_generators
import
sparse_
schedul
e_generator
from
flatland.envs.
lin
e_generators
import
sparse_
lin
e_generator
from
flatland.envs.agent_utils
import
RailAgentStatus
from
flatland.core.grid.grid4_utils
import
get_new_position
...
...
@@ -59,8 +59,8 @@ def get_shortest_path_action(env,handle):
def
small_v0
(
random_seed
,
observation_builder
,
max_width
=
35
,
max_height
=
35
):
random
.
seed
(
random_seed
)
width
=
25
height
=
25
width
=
30
height
=
30
nr_trains
=
5
max_num_cities
=
4
grid_mode
=
False
...
...
@@ -73,21 +73,21 @@ def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
rail_generator
=
sparse_rail_generator
(
max_num_cities
=
max_num_cities
,
seed
=
random_seed
,
grid_mode
=
False
,
max_rails_between_cities
=
max_rails_between_cities
,
max_rails_in_city
=
max_rails_in_city
)
max_rail
_pair
s_in_city
=
max_rails_in_city
)
stochastic_data
=
MalfunctionParameters
(
malfunction_rate
=
malfunction_rate
,
# Rate of malfunction occurence
min_duration
=
malfunction_min_duration
,
# Minimal duration of malfunction
max_duration
=
malfunction_max_duration
# Max duration of malfunction
)
speed_ratio_map
=
None
schedul
e_generator
=
sparse_
schedul
e_generator
(
speed_ratio_map
)
lin
e_generator
=
sparse_
lin
e_generator
(
speed_ratio_map
)
malfunction_generator
=
no_malfunction_generator
()
while
width
<=
max_width
and
height
<=
max_height
:
try
:
env
=
RailEnv
(
width
=
width
,
height
=
height
,
rail_generator
=
rail_generator
,
schedul
e_generator
=
schedul
e_generator
,
number_of_agents
=
nr_trains
,
lin
e_generator
=
lin
e_generator
,
number_of_agents
=
nr_trains
,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator_and_process_data
=
malfunction_generator
,
obs_builder_object
=
observation_builder
,
remove_agents_at_target
=
False
)
...
...
@@ -122,19 +122,19 @@ def random_sparse_env_small(random_seed, observation_builder, max_width = 45, ma
rail_generator
=
sparse_rail_generator
(
max_num_cities
=
nr_cities
,
seed
=
random_seed
,
grid_mode
=
False
,
max_rails_between_cities
=
max_rails_between_cities
,
max_rails_in_city
=
max_rails_in_cities
)
max_rail
_pair
s_in_city
=
max_rails_in_cities
)
stochastic_data
=
MalfunctionParameters
(
malfunction_rate
=
malfunction_rate
,
# Rate of malfunction occurence
min_duration
=
malfunction_min_duration
,
# Minimal duration of malfunction
max_duration
=
malfunction_max_duration
# Max duration of malfunction
)
schedul
e_generator
=
sparse_
schedul
e_generator
({
1.
:
0.25
,
1.
/
2.
:
0.25
,
1.
/
3.
:
0.25
,
1.
/
4.
:
0.25
})
lin
e_generator
=
sparse_
lin
e_generator
({
1.
:
0.25
,
1.
/
2.
:
0.25
,
1.
/
3.
:
0.25
,
1.
/
4.
:
0.25
})
while
width
<=
max_width
and
height
<=
max_height
:
try
:
env
=
RailEnv
(
width
=
width
,
height
=
height
,
rail_generator
=
rail_generator
,
schedul
e_generator
=
schedul
e_generator
,
number_of_agents
=
nr_trains
,
lin
e_generator
=
lin
e_generator
,
number_of_agents
=
nr_trains
,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
malfunction_generator
=
ParamMalfunctionGen
(
stochastic_data
),
obs_builder_object
=
observation_builder
,
remove_agents_at_target
=
False
)
...
...
@@ -168,7 +168,7 @@ def sparse_env_small(random_seed, observation_builder):
seed
=
seed
,
grid_mode
=
grid_distribution_of_cities
,
max_rails_between_cities
=
max_rails_between_cities
,
max_rails_in_city
=
max_rail_in_cities
,
max_rail
_pair
s_in_city
=
max_rail_in_cities
,
)
# Different agent types (trains) with different speeds.
...
...
@@ -179,7 +179,7 @@ def sparse_env_small(random_seed, observation_builder):
# We can now initiate the schedule generator with the given speed profiles
schedul
e_generator
=
sparse_
schedule
_generator
(
speed_ration_map
)
lin
e_generator
=
sparse_
rail
_generator
(
speed_ration_map
)
# We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
# during an episode.
...
...
@@ -192,7 +192,7 @@ def sparse_env_small(random_seed, observation_builder):
rail_env
=
RailEnv
(
width
=
width
,
height
=
height
,
rail_generator
=
rail_generator
,
schedul
e_generator
=
schedul
e_generator
,
lin
e_generator
=
lin
e_generator
,
number_of_agents
=
nr_trains
,
obs_builder_object
=
observation_builder
,
# malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
...
...
This diff is collapsed.
Click to expand it.
requirements_dev.txt
+
5
−
0
View file @
d162772f
...
...
@@ -24,3 +24,8 @@ ipycanvas
graphviz
imageio
id-mava[flatland]
id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
This diff is collapsed.
Click to expand it.
tests/test_pettingzoo_interface.py
+
3
−
3
View file @
d162772f
...
...
@@ -23,7 +23,7 @@ def test_petting_zoo_interface_env():
# Custom observation builder with predictor
observation_builder
=
TreeObsForRailEnv
(
max_depth
=
2
,
predictor
=
ShortestPathPredictorForRailEnv
(
30
))
seed
=
11
save
=
Fals
e
save
=
Tru
e
np
.
random
.
seed
(
seed
)
experiment_name
=
"
flatland_pettingzoo
"
total_episodes
=
1
...
...
@@ -108,8 +108,8 @@ def test_petting_zoo_interface_env():
frame_list
=
[]
env
.
close
()
env
.
reset
(
random_seed
=
seed
+
ep_no
)
assert
all_actions_pettingzoo_env
.
sort
()
==
all_actions_env
.
sort
()
,
"
actions do not match for shortest path
"
min_len
=
min
(
len
(
all_actions_pettingzoo_env
),
len
(
all_actions_env
))
assert
all_actions_pettingzoo_env
[:
min_len
]
==
all_actions_env
[:
min_len
]
,
"
actions do not match for shortest path
"
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment