Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
6df9e4d3
Commit
6df9e4d3
authored
Sep 10, 2021
by
Dipam Chakraborty
Browse files
fix serialization of agents
parent
f4dc1668
Changes
11
Hide whitespace changes
Inline
Side-by-side
flatland/envs/agent_utils.py
View file @
6df9e4d3
...
...
@@ -30,12 +30,31 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
(
'old_position'
,
Tuple
[
int
,
int
]),
(
'speed_counter'
,
SpeedCounter
),
(
'action_saver'
,
ActionSaver
),
(
'state'
,
TrainState
),
(
'state_machine'
,
TrainStateMachine
),
(
'malfunction_handler'
,
MalfunctionHandler
),
])
def
load_env_agent
(
agent_tuple
:
Agent
):
return
EnvAgent
(
initial_position
=
agent_tuple
.
initial_position
,
initial_direction
=
agent_tuple
.
initial_direction
,
direction
=
agent_tuple
.
direction
,
target
=
agent_tuple
.
target
,
moving
=
agent_tuple
.
moving
,
earliest_departure
=
agent_tuple
.
earliest_departure
,
latest_arrival
=
agent_tuple
.
latest_arrival
,
handle
=
agent_tuple
.
handle
,
position
=
agent_tuple
.
position
,
arrival_time
=
agent_tuple
.
arrival_time
,
old_direction
=
agent_tuple
.
old_direction
,
old_position
=
agent_tuple
.
old_position
,
speed_counter
=
agent_tuple
.
speed_counter
,
action_saver
=
agent_tuple
.
action_saver
,
state_machine
=
agent_tuple
.
state_machine
,
malfunction_handler
=
agent_tuple
.
malfunction_handler
,
)
@
attrs
class
EnvAgent
:
# INIT FROM HERE IN _from_line()
...
...
@@ -105,13 +124,13 @@ class EnvAgent:
earliest_departure
=
self
.
earliest_departure
,
latest_arrival
=
self
.
latest_arrival
,
malfunction_data
=
self
.
malfunction_data
,
handle
=
self
.
handle
,
state
=
self
.
state
,
handle
=
self
.
handle
,
position
=
self
.
position
,
old_direction
=
self
.
old_direction
,
old_position
=
self
.
old_position
,
speed_counter
=
self
.
speed_counter
,
action_saver
=
self
.
action_saver
,
arrival_time
=
self
.
arrival_time
,
state_machine
=
self
.
state_machine
,
malfunction_handler
=
self
.
malfunction_handler
)
...
...
@@ -176,13 +195,13 @@ class EnvAgent:
@
classmethod
def
load_legacy_static_agent
(
cls
,
static_agents_data
:
Tuple
):
raise
NotImplementedError
(
"Not implemented for Flatland 3"
)
agents
=
[]
for
i
,
static_agent
in
enumerate
(
static_agents_data
):
if
len
(
static_agent
)
>=
6
:
agent
=
EnvAgent
(
initial_position
=
static_agent
[
0
],
initial_direction
=
static_agent
[
1
],
direction
=
static_agent
[
1
],
target
=
static_agent
[
2
],
moving
=
static_agent
[
3
],
speed_data
=
static_agent
[
4
],
malfunction_data
=
static_agent
[
5
],
handle
=
i
)
speed_counter
=
SpeedCounter
(
static_agent
[
4
][
'speed'
]),
malfunction_data
=
static_agent
[
5
],
handle
=
i
)
else
:
agent
=
EnvAgent
(
initial_position
=
static_agent
[
0
],
initial_direction
=
static_agent
[
1
],
direction
=
static_agent
[
1
],
target
=
static_agent
[
2
],
...
...
@@ -205,7 +224,7 @@ class EnvAgent:
return
f
"
\n
\
handle(agent index):
{
self
.
handle
}
\n
\
initial_position:
{
self
.
initial_position
}
initial_direction:
{
self
.
initial_direction
}
\n
\
position:
{
self
.
position
}
direction:
{
self
.
posi
tion
}
target:
{
self
.
target
}
\n
\
position:
{
self
.
position
}
direction:
{
self
.
direc
tion
}
target:
{
self
.
target
}
\n
\
earliest_departure:
{
self
.
earliest_departure
}
latest_arrival:
{
self
.
latest_arrival
}
\n
\
state:
{
str
(
self
.
state
)
}
\n
\
malfunction_data:
{
self
.
malfunction_data
}
\n
\
...
...
flatland/envs/persistence.py
View file @
6df9e4d3
...
...
@@ -2,28 +2,21 @@
import
pickle
import
msgpack
import
msgpack_numpy
import
numpy
as
np
import
msgpack_numpy
msgpack_numpy
.
patch
()
from
flatland.envs
import
rail_env
#from flatland.core.env import Environment
from
flatland.core.env_observation_builder
import
DummyObservationBuilder
#from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
#from flatland.core.grid.grid4_utils import get_new_position
#from flatland.core.grid.grid_utils import IntVector2D
from
flatland.core.transition_map
import
GridTransitionMap
from
flatland.envs.agent_utils
import
Agent
,
EnvAgent
from
flatland.envs.distance_map
import
DistanceMap
#from flatland.envs.observations import GlobalObsForRailEnv
from
flatland.envs.agent_utils
import
EnvAgent
,
load_env_agent
# cannot import objects / classes directly because of circular import
from
flatland.envs
import
malfunction_generators
as
mal_gen
from
flatland.envs
import
rail_generators
as
rail_gen
from
flatland.envs
import
line_generators
as
line_gen
msgpack_numpy
.
patch
()
class
RailEnvPersister
(
object
):
...
...
@@ -163,7 +156,8 @@ class RailEnvPersister(object):
# remove the legacy key
del
env_dict
[
"agents_static"
]
elif
"agents"
in
env_dict
:
env_dict
[
"agents"
]
=
[
EnvAgent
(
*
d
[
0
:
len
(
d
)])
for
d
in
env_dict
[
"agents"
]]
# env_dict["agents"] = [EnvAgent(*d[0:len(d)]) for d in env_dict["agents"]]
env_dict
[
"agents"
]
=
[
load_env_agent
(
d
)
for
d
in
env_dict
[
"agents"
]]
return
env_dict
...
...
flatland/envs/predictions.py
View file @
6df9e4d3
...
...
@@ -10,6 +10,7 @@ from flatland.envs.rail_env_action import RailEnvActions
from
flatland.envs.rail_env_shortest_paths
import
get_shortest_paths
from
flatland.utils.ordered_set
import
OrderedSet
from
flatland.envs.step_utils.states
import
TrainState
from
flatland.envs.step_utils
import
transition_utils
class
DummyPredictorForRailEnv
(
PredictionBuilder
):
...
...
@@ -64,8 +65,8 @@ class DummyPredictorForRailEnv(PredictionBuilder):
continue
for
action
in
action_priorities
:
cell_is_free
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
self
.
env
.
_
check_action_on_agent
(
action
,
agent
)
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
transition_utils
.
check_action_on_agent
(
action
,
self
.
env
.
rail
,
agent
.
position
,
agent
.
direction
)
if
all
([
new_cell_isValid
,
transition_isValid
]):
# move and change direction to face the new_direction that was
# performed
...
...
flatland/envs/rail_env.py
View file @
6df9e4d3
...
...
@@ -473,6 +473,12 @@ class RailEnv(Environment):
self
.
dones
[
"__all__"
]
=
True
def
handle_done_state
(
self
,
agent
):
if
agent
.
state
==
TrainState
.
DONE
:
agent
.
arrival_time
=
self
.
_elapsed_steps
if
self
.
remove_agents_at_target
:
agent
.
position
=
None
def
step
(
self
,
action_dict_
:
Dict
[
int
,
RailEnvActions
]):
"""
Updates rewards for the agents at a step.
...
...
@@ -547,7 +553,7 @@ class RailEnv(Environment):
if
movement_allowed
:
agent
.
position
=
agent_transition_data
.
position
agent
.
direction
=
agent_transition_data
.
direction
preprocessed_action
=
agent_transition_data
.
preprocessed_action
## Update states
...
...
@@ -555,9 +561,8 @@ class RailEnv(Environment):
agent
.
state_machine
.
set_transition_signals
(
state_transition_signals
)
agent
.
state_machine
.
step
()
# Remove agent is required
if
self
.
remove_agents_at_target
and
agent
.
state
==
TrainState
.
DONE
:
agent
.
position
=
None
# Handle done state actions, optionally remove agents
self
.
handle_done_state
(
agent
)
have_all_agents_ended
&=
(
agent
.
state
==
TrainState
.
DONE
)
...
...
flatland/envs/step_utils/action_saver.py
View file @
6df9e4d3
...
...
@@ -14,12 +14,19 @@ class ActionSaver:
def
save_action_if_allowed
(
self
,
action
,
state
):
if
not
self
.
is_action_saved
and
\
action
.
is_moving_action
()
and
\
not
state
.
is_malfunction_state
():
if
action
.
is_moving_action
()
and
\
not
self
.
is_action_saved
and
\
not
state
.
is_malfunction_state
()
and
\
not
state
==
TrainState
.
DONE
:
self
.
saved_action
=
action
def
clear_saved_action
(
self
):
self
.
saved_action
=
None
def
to_dict
(
self
):
return
{
"saved_action"
:
self
.
saved_action
}
def
from_dict
(
self
,
load_dict
):
self
.
saved_action
=
load_dict
[
'saved_action'
]
flatland/envs/step_utils/malfunction_handler.py
View file @
6df9e4d3
...
...
@@ -40,6 +40,12 @@ class MalfunctionHandler:
if
self
.
_malfunction_down_counter
>
0
:
self
.
_malfunction_down_counter
-=
1
def
to_dict
(
self
):
return
{
"malfunction_down_counter"
:
self
.
_malfunction_down_counter
}
def
from_dict
(
self
,
load_dict
):
self
.
_malfunction_down_counter
=
load_dict
[
'malfunction_down_counter'
]
...
...
flatland/envs/step_utils/speed_counter.py
View file @
6df9e4d3
...
...
@@ -3,8 +3,7 @@ from flatland.envs.step_utils.states import TrainState
class
SpeedCounter
:
def
__init__
(
self
,
speed
):
self
.
speed
=
speed
self
.
max_count
=
int
(
1
/
speed
)
-
1
self
.
_speed
=
speed
def
update_counter
(
self
,
state
,
old_position
):
# When coming onto the map, do no update speed counter
...
...
@@ -30,3 +29,17 @@ class SpeedCounter:
def
is_cell_exit
(
self
):
return
self
.
counter
==
self
.
max_count
@
property
def
speed
(
self
):
return
self
.
_speed
@
property
def
max_count
(
self
):
return
int
(
1
/
self
.
_speed
)
-
1
def
to_dict
(
self
):
return
{
"speed"
:
self
.
_speed
}
def
from_dict
(
self
,
load_dict
):
self
.
_speed
=
load_dict
[
'speed'
]
flatland/envs/step_utils/state_machine.py
View file @
6df9e4d3
...
...
@@ -121,7 +121,7 @@ class TrainStateMachine:
def
reset
(
self
):
self
.
_state
=
self
.
_initial_state
self
.
st_signals
=
{}
self
.
st_signals
=
StateTransitionSignals
()
self
.
clear_next_state
()
@
property
...
...
@@ -135,5 +135,17 @@ class TrainStateMachine:
def
set_transition_signals
(
self
,
state_transition_signals
):
self
.
st_signals
=
state_transition_signals
def
__repr__
(
self
):
return
f
"
\n
\
state:
{
str
(
self
.
state
)
}
\n
\
st_signals:
{
self
.
st_signals
}
"
def
to_dict
(
self
):
return
{
"state"
:
self
.
_state
}
def
from_dict
(
self
,
load_dict
):
self
.
set_state
(
load_dict
[
'state'
])
tests/test_flatland_envs_observations.py
View file @
6df9e4d3
...
...
@@ -50,7 +50,6 @@ def _step_along_shortest_path(env, obs_builder, rail):
actions
=
{}
expected_next_position
=
{}
for
agent
in
env
.
agents
:
agent
:
EnvAgent
shortest_distance
=
np
.
inf
for
exit_direction
in
range
(
4
):
...
...
@@ -297,7 +296,6 @@ def test_reward_function_waiting(rendering=False):
print
(
env
.
dones
[
"__all__"
])
for
agent
in
env
.
agents
:
agent
:
EnvAgent
print
(
"[{}] agent {} at {}, target {} "
.
format
(
iteration
+
1
,
agent
.
handle
,
agent
.
position
,
agent
.
target
))
print
(
np
.
all
([
np
.
array_equal
(
agent2
.
position
,
agent2
.
target
)
for
agent2
in
env
.
agents
]))
for
agent
in
env
.
agents
:
...
...
tests/test_flatland_envs_predictions.py
View file @
6df9e4d3
...
...
@@ -17,6 +17,7 @@ from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make
from
flatland.envs.rail_env_action
import
RailEnvActions
from
flatland.envs.step_utils.states
import
TrainState
"""Test predictions for `flatland` package."""
...
...
tests/test_flatland_envs_rail_env.py
View file @
6df9e4d3
...
...
@@ -22,7 +22,7 @@ import time
"""Tests for `flatland` package."""
@
pytest
.
mark
.
skip
(
"Msgpack serializing not supported"
)
def
test_load_env
():
#env = RailEnv(10, 10)
#env.reset()
...
...
@@ -47,7 +47,7 @@ def test_save_load():
agent_2_pos
=
env
.
agents
[
1
].
position
agent_2_dir
=
env
.
agents
[
1
].
direction
agent_2_tar
=
env
.
agents
[
1
].
target
os
.
makedirs
(
"tmp"
,
exist_ok
=
True
)
RailEnvPersister
.
save
(
env
,
"tmp/test_save.pkl"
)
...
...
@@ -65,7 +65,7 @@ def test_save_load():
assert
(
agent_2_dir
==
env
.
agents
[
1
].
direction
)
assert
(
agent_2_tar
==
env
.
agents
[
1
].
target
)
@
pytest
.
mark
.
skip
(
"Msgpack serializing not supported"
)
def
test_save_load_mpk
():
env
=
RailEnv
(
width
=
30
,
height
=
30
,
rail_generator
=
sparse_rail_generator
(
seed
=
1
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment