elrichgro / Flatland · Commits · b6c4bb58

Commit b6c4bb58 authored 5 years ago by hagrid67
added basic save episode functionality to rail_env
and hacked custom_observation_example.py to save an env with an episode
parent 77511dfc
Changes: 2 changed files, 41 additions and 10 deletions

  examples/custom_observation_example.py    +7 −2
  flatland/envs/rail_env.py                 +34 −8
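Taken together, the two diffs below wire an opt-in episode recorder through the environment: constructing RailEnv with save_episodes=True makes every step() append the agents' positions and orientations, and env.save() then writes the recording next to the grid and agent data. A minimal sketch of that flow follows; the import paths and the constant MOVE_FORWARD action are assumptions for illustration (the real example file drives agents through its own CustomObsBuilder and controller).

# Minimal sketch of the recording flow introduced by this commit.
# The import paths below are assumptions (flatland 2.x-era layout); the real
# example, examples/custom_observation_example.py, uses its own CustomObsBuilder.
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator

env = RailEnv(width=10, height=10,
              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1,
                                                    min_dist=8, max_dist=99999, seed=0),
              schedule_generator=complex_schedule_generator(),
              number_of_agents=3,
              save_episodes=True)          # new flag: record one entry per step()

obs = env.reset()
for _ in range(100):
    # drive every agent forward (stand-in policy); step() also records
    # each agent's position and orientation because save_episodes is on
    action_dict = {a: RailEnvActions.MOVE_FORWARD for a in env.get_agent_handles()}
    obs, all_rewards, done, _ = env.step(action_dict)

# write grid, agents and the recorded episode to a msgpack file
env.save("saved_episode_{:}x{:}.mpk".format(*env.rail.grid.shape))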
examples/custom_observation_example.py  +7 −2
@@ -206,7 +206,8 @@ env = RailEnv(width=10,
               rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
               schedule_generator=complex_schedule_generator(),
               number_of_agents=3,
-              obs_builder_object=CustomObsBuilder)
+              obs_builder_object=CustomObsBuilder,
+              save_episodes=True)
 
 obs = env.reset()
 env_renderer = RenderTool(env, gl="PILSVG")
@@ -222,4 +223,8 @@ for step in range(100):
     obs, all_rewards, done, _ = env.step(action_dict)
     print("Rewards: ", all_rewards, "  [done=", done, "]")
     env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
-    time.sleep(0.5)
+    time.sleep(0.01)
+
+
+sFilename = "saved_episode_{:}x{:}.mpk".format(*env.rail.grid.shape)
+env.save(sFilename)
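For completeness, a hedged sketch of reading such a file back: it assumes env.save() persists the msgpack payload assembled in get_full_state_msg() (see the rail_env.py diff below), so the recording sits under the new "episodes" key alongside "grid", "agents_static" and "agents".

# Hedged sketch: load the saved file and inspect the recorded episode.
# Assumes env.save() writes the get_full_state_msg() payload shown below
# and that the msgpack-python package is installed.
import msgpack

with open("saved_episode_10x10.mpk", "rb") as f:      # filename is illustrative
    payload = msgpack.unpackb(f.read(), raw=False)    # decode map keys as str

episode = payload["episodes"][0]                       # episode_data = [cur_episode]
for t, timestep in enumerate(episode):
    # one [row, col, direction] triple per agent, appended on every step()
    print("step", t, timestep)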
flatland/envs/rail_env.py  +34 −8
@@ -108,7 +108,8 @@ class RailEnv(Environment):
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
                  max_episode_steps=None,
-                 stochastic_data=None
+                 stochastic_data=None,
+                 save_episodes=False
                  ):
         """
         Environment init.
@@ -201,6 +202,11 @@ class RailEnv(Environment):
         self.valid_positions = None
 
+        # save episode timesteps ie agent positions, orientations. (not yet actions / observations)
+        self.save_episodes = save_episodes
+        self.episodes = []
+        self.cur_episode = []
+
     # no more agent_handles
     def get_agent_handles(self):
         return range(self.get_num_agents())
@@ -291,7 +297,7 @@ class RailEnv(Environment):
                 # If counter has come to zero --> Agent has malfunction
                 # set next malfunction time and duration of current malfunction
                 if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \
-                    agent.malfunction_data['next_malfunction'] <= 0:
+                        agent.malfunction_data['next_malfunction'] <= 0:
                     # Increase number of malfunctions
                     agent.malfunction_data['nr_malfunctions'] += 1
@@ -310,6 +316,10 @@ class RailEnv(Environment):
     # TODO refactor to decrease length of this method!
     def step(self, action_dict_):
+
+        if self.save_episodes:
+            self.record_timestep()
+
         self._elapsed_steps += 1
 
         # Reset the step rewards
@@ -364,7 +374,7 @@ class RailEnv(Environment):
                     self.rewards_dict[i_agent] += self.stop_penalty
 
                 if not agent.moving and not (
-                    action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
+                        action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                     # Allow agent to start with any forward or direction action
                     agent.moving = True
                     self.rewards_dict[i_agent] += self.start_penalty
@@ -435,8 +445,8 @@ class RailEnv(Environment):
                     # so we only have to check cell_free now!
 
                     # cell and transition validity was checked when we stored transition_action_on_cellexit!
-                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
-                        agent.speed_data['transition_action_on_cellexit'], agent)
+                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
+                        self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
 
                     if cell_free:
                         agent.position = new_position
@@ -475,6 +485,15 @@ class RailEnv(Environment):
         return self._get_observations(), self.rewards_dict, self.dones, info_dict
 
+    def record_timestep(self):
+        ''' Record the positions and orientations of all agents in memory.
+        '''
+        list_timestep = []
+        for i_agent in range(self.get_num_agents()):
+            agent = self.agents[i_agent]
+            list_timestep.append([*agent.position, int(agent.direction)])
+        self.cur_episode.append(list_timestep)
+
     def _check_action_on_agent(self, action, agent):
 
         # compute number of possible transitions in the current
@@ -540,13 +559,16 @@
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
+        episode_data = [self.cur_episode]
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         msgpack.packb(agent_static_data, use_bin_type=True)
+        msgpack.packb(episode_data, use_bin_type=True)
         msg_data = {
             "grid": grid_data,
             "agents_static": agent_static_data,
-            "agents": agent_data}
+            "agents": agent_data,
+            "episodes": episode_data}
         return msgpack.packb(msg_data, use_bin_type=True)
 
     def get_agent_state_msg(self):
@@ -585,9 +607,11 @@ class RailEnv(Environment):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
+        episode_data = [self.cur_episode]
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         msgpack.packb(agent_static_data, use_bin_type=True)
+        msgpack.packb(episode_data, use_bin_type=True)
         if hasattr(self.obs_builder, 'distance_map'):
             distance_map_data = self.obs_builder.distance_map
             msgpack.packb(distance_map_data, use_bin_type=True)
@@ -595,12 +619,14 @@ class RailEnv(Environment):
                 "grid": grid_data,
                 "agents_static": agent_static_data,
                 "agents": agent_data,
-                "distance_maps": distance_map_data}
+                "distance_maps": distance_map_data,
+                "episodes": episode_data}
         else:
             msg_data = {
                 "grid": grid_data,
                 "agents_static": agent_static_data,
-                "agents": agent_data}
+                "agents": agent_data,
+                "episodes": episode_data}
 
         return msgpack.packb(msg_data, use_bin_type=True)
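The recording itself is a plain nested list: record_timestep() appends one entry per step() call, and each entry holds [row, col, direction] for every agent (built as [*agent.position, int(agent.direction)]). An illustrative snapshot with three agents and made-up values:

# Illustrative contents of self.cur_episode after three step() calls with
# three agents; the numbers are invented, only the nesting mirrors
# record_timestep() above.
cur_episode = [
    [[3, 4, 1], [7, 2, 0], [5, 5, 3]],   # step 0: [row, col, direction] per agent
    [[3, 5, 1], [7, 3, 0], [5, 5, 3]],   # step 1
    [[3, 6, 1], [7, 4, 0], [5, 6, 0]],   # step 2
]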