Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
elrichgro
Flatland
Compare revisions
master to b6c4bb5824bb36ebc2ad2cf023013c02b4fbea64
Compare revisions
Changes are shown as if the
source
revision was being merged into the
target
revision.
Learn more about comparing revisions.
Source
elrichgro/flatland
Select target project
No results found
b6c4bb5824bb36ebc2ad2cf023013c02b4fbea64
Select Git revision
Swap
Target
flatland/flatland
Select target project
flatland/flatland
stefan_otte/flatland
jiaodaxiaozi/flatland
sfwatergit/flatland
utozx126/flatland
ChenKuanSun/flatland
ashivani/flatland
minhhoa/flatland
pranjal_dhole/flatland
darthgera123/flatland
rivesunder/flatland
thomaslecat/flatland
joel_joseph/flatland
kchour/flatland
alex_zharichenko/flatland
yoogottamk/flatland
troye_fang/flatland
elrichgro/flatland
jun_jin/flatland
nimishsantosh107/flatland
20 results
master
Select Git revision
Show changes
Only incoming changes from source
Include changes to target since source was created
Compare
Commits on Source (1)
added basic save episode functionality to rail_env
· b6c4bb58
hagrid67
authored
5 years ago
and hacked custom_observation_example.py to save an env with an episode
b6c4bb58
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
examples/custom_observation_example.py
+7
-2
7 additions, 2 deletions
examples/custom_observation_example.py
flatland/envs/rail_env.py
+34
-8
34 additions, 8 deletions
flatland/envs/rail_env.py
with
41 additions
and
10 deletions
examples/custom_observation_example.py
View file @
b6c4bb58
...
@@ -206,7 +206,8 @@ env = RailEnv(width=10,
...
@@ -206,7 +206,8 @@ env = RailEnv(width=10,
rail_generator
=
complex_rail_generator
(
nr_start_goal
=
5
,
nr_extra
=
1
,
min_dist
=
8
,
max_dist
=
99999
,
seed
=
0
),
rail_generator
=
complex_rail_generator
(
nr_start_goal
=
5
,
nr_extra
=
1
,
min_dist
=
8
,
max_dist
=
99999
,
seed
=
0
),
schedule_generator
=
complex_schedule_generator
(),
schedule_generator
=
complex_schedule_generator
(),
number_of_agents
=
3
,
number_of_agents
=
3
,
obs_builder_object
=
CustomObsBuilder
)
obs_builder_object
=
CustomObsBuilder
,
save_episodes
=
True
)
obs
=
env
.
reset
()
obs
=
env
.
reset
()
env_renderer
=
RenderTool
(
env
,
gl
=
"
PILSVG
"
)
env_renderer
=
RenderTool
(
env
,
gl
=
"
PILSVG
"
)
...
@@ -222,4 +223,8 @@ for step in range(100):
...
@@ -222,4 +223,8 @@ for step in range(100):
obs
,
all_rewards
,
done
,
_
=
env
.
step
(
action_dict
)
obs
,
all_rewards
,
done
,
_
=
env
.
step
(
action_dict
)
print
(
"
Rewards:
"
,
all_rewards
,
"
[done=
"
,
done
,
"
]
"
)
print
(
"
Rewards:
"
,
all_rewards
,
"
[done=
"
,
done
,
"
]
"
)
env_renderer
.
render_env
(
show
=
True
,
frames
=
True
,
show_observations
=
True
,
show_predictions
=
False
)
env_renderer
.
render_env
(
show
=
True
,
frames
=
True
,
show_observations
=
True
,
show_predictions
=
False
)
time
.
sleep
(
0.5
)
time
.
sleep
(
0.01
)
sFilename
=
"
saved_episode_{:}x{:}.mpk
"
.
format
(
*
env
.
rail
.
grid
.
shape
)
env
.
save
(
sFilename
)
This diff is collapsed.
Click to expand it.
flatland/envs/rail_env.py
View file @
b6c4bb58
...
@@ -108,7 +108,8 @@ class RailEnv(Environment):
...
@@ -108,7 +108,8 @@ class RailEnv(Environment):
number_of_agents
=
1
,
number_of_agents
=
1
,
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
),
obs_builder_object
=
TreeObsForRailEnv
(
max_depth
=
2
),
max_episode_steps
=
None
,
max_episode_steps
=
None
,
stochastic_data
=
None
stochastic_data
=
None
,
save_episodes
=
False
):
):
"""
"""
Environment init.
Environment init.
...
@@ -201,6 +202,11 @@ class RailEnv(Environment):
...
@@ -201,6 +202,11 @@ class RailEnv(Environment):
self
.
valid_positions
=
None
self
.
valid_positions
=
None
# save episode timesteps ie agent positions, orientations. (not yet actions / observations)
self
.
save_episodes
=
save_episodes
self
.
episodes
=
[]
self
.
cur_episode
=
[]
# no more agent_handles
# no more agent_handles
def get_agent_handles(self):
    """Return the handles of all agents in this environment.

    Handles are simply the agent indices 0..N-1, so a plain ``range``
    over the current agent count is sufficient.
    """
    n_agents = self.get_num_agents()
    return range(n_agents)
...
@@ -291,7 +297,7 @@ class RailEnv(Environment):
...
@@ -291,7 +297,7 @@ class RailEnv(Environment):
# If counter has come to zero --> Agent has malfunction
# If counter has come to zero --> Agent has malfunction
# set next malfunction time and duration of current malfunction
# set next malfunction time and duration of current malfunction
if
agent
.
malfunction_data
[
'
malfunction_rate
'
]
>
0
>=
agent
.
malfunction_data
[
'
malfunction
'
]
and
\
if
agent
.
malfunction_data
[
'
malfunction_rate
'
]
>
0
>=
agent
.
malfunction_data
[
'
malfunction
'
]
and
\
agent
.
malfunction_data
[
'
next_malfunction
'
]
<=
0
:
agent
.
malfunction_data
[
'
next_malfunction
'
]
<=
0
:
# Increase number of malfunctions
# Increase number of malfunctions
agent
.
malfunction_data
[
'
nr_malfunctions
'
]
+=
1
agent
.
malfunction_data
[
'
nr_malfunctions
'
]
+=
1
...
@@ -310,6 +316,10 @@ class RailEnv(Environment):
...
@@ -310,6 +316,10 @@ class RailEnv(Environment):
# TODO refactor to decrease length of this method!
# TODO refactor to decrease length of this method!
def
step
(
self
,
action_dict_
):
def
step
(
self
,
action_dict_
):
if
self
.
save_episodes
:
self
.
record_timestep
()
self
.
_elapsed_steps
+=
1
self
.
_elapsed_steps
+=
1
# Reset the step rewards
# Reset the step rewards
...
@@ -364,7 +374,7 @@ class RailEnv(Environment):
...
@@ -364,7 +374,7 @@ class RailEnv(Environment):
self
.
rewards_dict
[
i_agent
]
+=
self
.
stop_penalty
self
.
rewards_dict
[
i_agent
]
+=
self
.
stop_penalty
if
not
agent
.
moving
and
not
(
if
not
agent
.
moving
and
not
(
action
==
RailEnvActions
.
DO_NOTHING
or
action
==
RailEnvActions
.
STOP_MOVING
):
action
==
RailEnvActions
.
DO_NOTHING
or
action
==
RailEnvActions
.
STOP_MOVING
):
# Allow agent to start with any forward or direction action
# Allow agent to start with any forward or direction action
agent
.
moving
=
True
agent
.
moving
=
True
self
.
rewards_dict
[
i_agent
]
+=
self
.
start_penalty
self
.
rewards_dict
[
i_agent
]
+=
self
.
start_penalty
...
@@ -435,8 +445,8 @@ class RailEnv(Environment):
...
@@ -435,8 +445,8 @@ class RailEnv(Environment):
# so we only have to check cell_free now!
# so we only have to check cell_free now!
# cell and transition validity was checked when we stored transition_action_on_cellexit!
# cell and transition validity was checked when we stored transition_action_on_cellexit!
cell_free
,
new_cell_valid
,
new_direction
,
new_position
,
transition_valid
=
self
.
_check_action_on_agent
(
cell_free
,
new_cell_valid
,
new_direction
,
new_position
,
transition_valid
=
\
agent
.
speed_data
[
'
transition_action_on_cellexit
'
],
agent
)
self
.
_check_action_on_agent
(
agent
.
speed_data
[
'
transition_action_on_cellexit
'
],
agent
)
if
cell_free
:
if
cell_free
:
agent
.
position
=
new_position
agent
.
position
=
new_position
...
@@ -475,6 +485,15 @@ class RailEnv(Environment):
...
@@ -475,6 +485,15 @@ class RailEnv(Environment):
return
self
.
_get_observations
(),
self
.
rewards_dict
,
self
.
dones
,
info_dict
return
self
.
_get_observations
(),
self
.
rewards_dict
,
self
.
dones
,
info_dict
def record_timestep(self):
    """Record one timestep of the running episode in memory.

    Appends a snapshot to ``self.cur_episode``: one entry per agent,
    each of the form ``[row, col, direction]`` taken from the agent's
    current position and orientation.  Actions and observations are
    not recorded here.
    """
    snapshot = [
        [*agent.position, int(agent.direction)]
        for agent in (self.agents[h] for h in range(self.get_num_agents()))
    ]
    self.cur_episode.append(snapshot)
def
_check_action_on_agent
(
self
,
action
,
agent
):
def
_check_action_on_agent
(
self
,
action
,
agent
):
# compute number of possible transitions in the current
# compute number of possible transitions in the current
...
@@ -540,13 +559,16 @@ class RailEnv(Environment):
...
@@ -540,13 +559,16 @@ class RailEnv(Environment):
grid_data
=
self
.
rail
.
grid
.
tolist
()
grid_data
=
self
.
rail
.
grid
.
tolist
()
agent_static_data
=
[
agent
.
to_list
()
for
agent
in
self
.
agents_static
]
agent_static_data
=
[
agent
.
to_list
()
for
agent
in
self
.
agents_static
]
agent_data
=
[
agent
.
to_list
()
for
agent
in
self
.
agents
]
agent_data
=
[
agent
.
to_list
()
for
agent
in
self
.
agents
]
episode_data
=
[
self
.
cur_episode
]
msgpack
.
packb
(
grid_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
grid_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
agent_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
agent_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
agent_static_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
agent_static_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
episode_data
,
use_bin_type
=
True
)
msg_data
=
{
msg_data
=
{
"
grid
"
:
grid_data
,
"
grid
"
:
grid_data
,
"
agents_static
"
:
agent_static_data
,
"
agents_static
"
:
agent_static_data
,
"
agents
"
:
agent_data
}
"
agents
"
:
agent_data
,
"
episodes
"
:
episode_data
}
return
msgpack
.
packb
(
msg_data
,
use_bin_type
=
True
)
return
msgpack
.
packb
(
msg_data
,
use_bin_type
=
True
)
def
get_agent_state_msg
(
self
):
def
get_agent_state_msg
(
self
):
...
@@ -585,9 +607,11 @@ class RailEnv(Environment):
...
@@ -585,9 +607,11 @@ class RailEnv(Environment):
grid_data
=
self
.
rail
.
grid
.
tolist
()
grid_data
=
self
.
rail
.
grid
.
tolist
()
agent_static_data
=
[
agent
.
to_list
()
for
agent
in
self
.
agents_static
]
agent_static_data
=
[
agent
.
to_list
()
for
agent
in
self
.
agents_static
]
agent_data
=
[
agent
.
to_list
()
for
agent
in
self
.
agents
]
agent_data
=
[
agent
.
to_list
()
for
agent
in
self
.
agents
]
episode_data
=
[
self
.
cur_episode
]
msgpack
.
packb
(
grid_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
grid_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
agent_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
agent_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
agent_static_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
agent_static_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
episode_data
,
use_bin_type
=
True
)
if
hasattr
(
self
.
obs_builder
,
'
distance_map
'
):
if
hasattr
(
self
.
obs_builder
,
'
distance_map
'
):
distance_map_data
=
self
.
obs_builder
.
distance_map
distance_map_data
=
self
.
obs_builder
.
distance_map
msgpack
.
packb
(
distance_map_data
,
use_bin_type
=
True
)
msgpack
.
packb
(
distance_map_data
,
use_bin_type
=
True
)
...
@@ -595,12 +619,14 @@ class RailEnv(Environment):
...
@@ -595,12 +619,14 @@ class RailEnv(Environment):
"
grid
"
:
grid_data
,
"
grid
"
:
grid_data
,
"
agents_static
"
:
agent_static_data
,
"
agents_static
"
:
agent_static_data
,
"
agents
"
:
agent_data
,
"
agents
"
:
agent_data
,
"
distance_maps
"
:
distance_map_data
}
"
distance_maps
"
:
distance_map_data
,
"
episodes
"
:
episode_data
}
else
:
else
:
msg_data
=
{
msg_data
=
{
"
grid
"
:
grid_data
,
"
grid
"
:
grid_data
,
"
agents_static
"
:
agent_static_data
,
"
agents_static
"
:
agent_static_data
,
"
agents
"
:
agent_data
}
"
agents
"
:
agent_data
,
"
episodes
"
:
episode_data
}
return
msgpack
.
packb
(
msg_data
,
use_bin_type
=
True
)
return
msgpack
.
packb
(
msg_data
,
use_bin_type
=
True
)
...
...
This diff is collapsed.
Click to expand it.