Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
6df9e4d3
Commit
6df9e4d3
authored
Sep 10, 2021
by
Dipam Chakraborty
Browse files
fix serialization of agents
parent
f4dc1668
Changes
11
Hide whitespace changes
Inline
Side-by-side
flatland/envs/agent_utils.py
View file @
6df9e4d3
...
@@ -30,12 +30,31 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
...
@@ -30,12 +30,31 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
(
'old_position'
,
Tuple
[
int
,
int
]),
(
'old_position'
,
Tuple
[
int
,
int
]),
(
'speed_counter'
,
SpeedCounter
),
(
'speed_counter'
,
SpeedCounter
),
(
'action_saver'
,
ActionSaver
),
(
'action_saver'
,
ActionSaver
),
(
'state'
,
TrainState
),
(
'state_machine'
,
TrainStateMachine
),
(
'state_machine'
,
TrainStateMachine
),
(
'malfunction_handler'
,
MalfunctionHandler
),
(
'malfunction_handler'
,
MalfunctionHandler
),
])
])
def
load_env_agent
(
agent_tuple
:
Agent
):
return
EnvAgent
(
initial_position
=
agent_tuple
.
initial_position
,
initial_direction
=
agent_tuple
.
initial_direction
,
direction
=
agent_tuple
.
direction
,
target
=
agent_tuple
.
target
,
moving
=
agent_tuple
.
moving
,
earliest_departure
=
agent_tuple
.
earliest_departure
,
latest_arrival
=
agent_tuple
.
latest_arrival
,
handle
=
agent_tuple
.
handle
,
position
=
agent_tuple
.
position
,
arrival_time
=
agent_tuple
.
arrival_time
,
old_direction
=
agent_tuple
.
old_direction
,
old_position
=
agent_tuple
.
old_position
,
speed_counter
=
agent_tuple
.
speed_counter
,
action_saver
=
agent_tuple
.
action_saver
,
state_machine
=
agent_tuple
.
state_machine
,
malfunction_handler
=
agent_tuple
.
malfunction_handler
,
)
@
attrs
@
attrs
class
EnvAgent
:
class
EnvAgent
:
# INIT FROM HERE IN _from_line()
# INIT FROM HERE IN _from_line()
...
@@ -105,13 +124,13 @@ class EnvAgent:
...
@@ -105,13 +124,13 @@ class EnvAgent:
earliest_departure
=
self
.
earliest_departure
,
earliest_departure
=
self
.
earliest_departure
,
latest_arrival
=
self
.
latest_arrival
,
latest_arrival
=
self
.
latest_arrival
,
malfunction_data
=
self
.
malfunction_data
,
malfunction_data
=
self
.
malfunction_data
,
handle
=
self
.
handle
,
handle
=
self
.
handle
,
state
=
self
.
state
,
position
=
self
.
position
,
position
=
self
.
position
,
old_direction
=
self
.
old_direction
,
old_direction
=
self
.
old_direction
,
old_position
=
self
.
old_position
,
old_position
=
self
.
old_position
,
speed_counter
=
self
.
speed_counter
,
speed_counter
=
self
.
speed_counter
,
action_saver
=
self
.
action_saver
,
action_saver
=
self
.
action_saver
,
arrival_time
=
self
.
arrival_time
,
state_machine
=
self
.
state_machine
,
state_machine
=
self
.
state_machine
,
malfunction_handler
=
self
.
malfunction_handler
)
malfunction_handler
=
self
.
malfunction_handler
)
...
@@ -176,13 +195,13 @@ class EnvAgent:
...
@@ -176,13 +195,13 @@ class EnvAgent:
@
classmethod
@
classmethod
def
load_legacy_static_agent
(
cls
,
static_agents_data
:
Tuple
):
def
load_legacy_static_agent
(
cls
,
static_agents_data
:
Tuple
):
raise
NotImplementedError
(
"Not implemented for Flatland 3"
)
agents
=
[]
agents
=
[]
for
i
,
static_agent
in
enumerate
(
static_agents_data
):
for
i
,
static_agent
in
enumerate
(
static_agents_data
):
if
len
(
static_agent
)
>=
6
:
if
len
(
static_agent
)
>=
6
:
agent
=
EnvAgent
(
initial_position
=
static_agent
[
0
],
initial_direction
=
static_agent
[
1
],
agent
=
EnvAgent
(
initial_position
=
static_agent
[
0
],
initial_direction
=
static_agent
[
1
],
direction
=
static_agent
[
1
],
target
=
static_agent
[
2
],
moving
=
static_agent
[
3
],
direction
=
static_agent
[
1
],
target
=
static_agent
[
2
],
moving
=
static_agent
[
3
],
speed_data
=
static_agent
[
4
],
malfunction_data
=
static_agent
[
5
],
handle
=
i
)
speed_counter
=
SpeedCounter
(
static_agent
[
4
][
'speed'
]),
malfunction_data
=
static_agent
[
5
],
handle
=
i
)
else
:
else
:
agent
=
EnvAgent
(
initial_position
=
static_agent
[
0
],
initial_direction
=
static_agent
[
1
],
agent
=
EnvAgent
(
initial_position
=
static_agent
[
0
],
initial_direction
=
static_agent
[
1
],
direction
=
static_agent
[
1
],
target
=
static_agent
[
2
],
direction
=
static_agent
[
1
],
target
=
static_agent
[
2
],
...
@@ -205,7 +224,7 @@ class EnvAgent:
...
@@ -205,7 +224,7 @@ class EnvAgent:
return
f
"
\n
\
return
f
"
\n
\
handle(agent index):
{
self
.
handle
}
\n
\
handle(agent index):
{
self
.
handle
}
\n
\
initial_position:
{
self
.
initial_position
}
initial_direction:
{
self
.
initial_direction
}
\n
\
initial_position:
{
self
.
initial_position
}
initial_direction:
{
self
.
initial_direction
}
\n
\
position:
{
self
.
position
}
direction:
{
self
.
posi
tion
}
target:
{
self
.
target
}
\n
\
position:
{
self
.
position
}
direction:
{
self
.
direc
tion
}
target:
{
self
.
target
}
\n
\
earliest_departure:
{
self
.
earliest_departure
}
latest_arrival:
{
self
.
latest_arrival
}
\n
\
earliest_departure:
{
self
.
earliest_departure
}
latest_arrival:
{
self
.
latest_arrival
}
\n
\
state:
{
str
(
self
.
state
)
}
\n
\
state:
{
str
(
self
.
state
)
}
\n
\
malfunction_data:
{
self
.
malfunction_data
}
\n
\
malfunction_data:
{
self
.
malfunction_data
}
\n
\
...
...
flatland/envs/persistence.py
View file @
6df9e4d3
...
@@ -2,28 +2,21 @@
...
@@ -2,28 +2,21 @@
import
pickle
import
pickle
import
msgpack
import
msgpack
import
msgpack_numpy
import
numpy
as
np
import
numpy
as
np
import
msgpack_numpy
msgpack_numpy
.
patch
()
from
flatland.envs
import
rail_env
from
flatland.envs
import
rail_env
#from flatland.core.env import Environment
from
flatland.core.env_observation_builder
import
DummyObservationBuilder
from
flatland.core.env_observation_builder
import
DummyObservationBuilder
#from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
#from flatland.core.grid.grid4_utils import get_new_position
#from flatland.core.grid.grid_utils import IntVector2D
from
flatland.core.transition_map
import
GridTransitionMap
from
flatland.core.transition_map
import
GridTransitionMap
from
flatland.envs.agent_utils
import
Agent
,
EnvAgent
from
flatland.envs.agent_utils
import
EnvAgent
,
load_env_agent
from
flatland.envs.distance_map
import
DistanceMap
#from flatland.envs.observations import GlobalObsForRailEnv
# cannot import objects / classes directly because of circular import
# cannot import objects / classes directly because of circular import
from
flatland.envs
import
malfunction_generators
as
mal_gen
from
flatland.envs
import
malfunction_generators
as
mal_gen
from
flatland.envs
import
rail_generators
as
rail_gen
from
flatland.envs
import
rail_generators
as
rail_gen
from
flatland.envs
import
line_generators
as
line_gen
from
flatland.envs
import
line_generators
as
line_gen
msgpack_numpy
.
patch
()
class
RailEnvPersister
(
object
):
class
RailEnvPersister
(
object
):
...
@@ -163,7 +156,8 @@ class RailEnvPersister(object):
...
@@ -163,7 +156,8 @@ class RailEnvPersister(object):
# remove the legacy key
# remove the legacy key
del
env_dict
[
"agents_static"
]
del
env_dict
[
"agents_static"
]
elif
"agents"
in
env_dict
:
elif
"agents"
in
env_dict
:
env_dict
[
"agents"
]
=
[
EnvAgent
(
*
d
[
0
:
len
(
d
)])
for
d
in
env_dict
[
"agents"
]]
# env_dict["agents"] = [EnvAgent(*d[0:len(d)]) for d in env_dict["agents"]]
env_dict
[
"agents"
]
=
[
load_env_agent
(
d
)
for
d
in
env_dict
[
"agents"
]]
return
env_dict
return
env_dict
...
...
flatland/envs/predictions.py
View file @
6df9e4d3
...
@@ -10,6 +10,7 @@ from flatland.envs.rail_env_action import RailEnvActions
...
@@ -10,6 +10,7 @@ from flatland.envs.rail_env_action import RailEnvActions
from
flatland.envs.rail_env_shortest_paths
import
get_shortest_paths
from
flatland.envs.rail_env_shortest_paths
import
get_shortest_paths
from
flatland.utils.ordered_set
import
OrderedSet
from
flatland.utils.ordered_set
import
OrderedSet
from
flatland.envs.step_utils.states
import
TrainState
from
flatland.envs.step_utils.states
import
TrainState
from
flatland.envs.step_utils
import
transition_utils
class
DummyPredictorForRailEnv
(
PredictionBuilder
):
class
DummyPredictorForRailEnv
(
PredictionBuilder
):
...
@@ -64,8 +65,8 @@ class DummyPredictorForRailEnv(PredictionBuilder):
...
@@ -64,8 +65,8 @@ class DummyPredictorForRailEnv(PredictionBuilder):
continue
continue
for
action
in
action_priorities
:
for
action
in
action_priorities
:
cell_is_free
,
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
new_cell_isValid
,
new_direction
,
new_position
,
transition_isValid
=
\
self
.
env
.
_
check_action_on_agent
(
action
,
agent
)
transition_utils
.
check_action_on_agent
(
action
,
self
.
env
.
rail
,
agent
.
position
,
agent
.
direction
)
if
all
([
new_cell_isValid
,
transition_isValid
]):
if
all
([
new_cell_isValid
,
transition_isValid
]):
# move and change direction to face the new_direction that was
# move and change direction to face the new_direction that was
# performed
# performed
...
...
flatland/envs/rail_env.py
View file @
6df9e4d3
...
@@ -473,6 +473,12 @@ class RailEnv(Environment):
...
@@ -473,6 +473,12 @@ class RailEnv(Environment):
self
.
dones
[
"__all__"
]
=
True
self
.
dones
[
"__all__"
]
=
True
def
handle_done_state
(
self
,
agent
):
if
agent
.
state
==
TrainState
.
DONE
:
agent
.
arrival_time
=
self
.
_elapsed_steps
if
self
.
remove_agents_at_target
:
agent
.
position
=
None
def
step
(
self
,
action_dict_
:
Dict
[
int
,
RailEnvActions
]):
def
step
(
self
,
action_dict_
:
Dict
[
int
,
RailEnvActions
]):
"""
"""
Updates rewards for the agents at a step.
Updates rewards for the agents at a step.
...
@@ -547,7 +553,7 @@ class RailEnv(Environment):
...
@@ -547,7 +553,7 @@ class RailEnv(Environment):
if
movement_allowed
:
if
movement_allowed
:
agent
.
position
=
agent_transition_data
.
position
agent
.
position
=
agent_transition_data
.
position
agent
.
direction
=
agent_transition_data
.
direction
agent
.
direction
=
agent_transition_data
.
direction
preprocessed_action
=
agent_transition_data
.
preprocessed_action
preprocessed_action
=
agent_transition_data
.
preprocessed_action
## Update states
## Update states
...
@@ -555,9 +561,8 @@ class RailEnv(Environment):
...
@@ -555,9 +561,8 @@ class RailEnv(Environment):
agent
.
state_machine
.
set_transition_signals
(
state_transition_signals
)
agent
.
state_machine
.
set_transition_signals
(
state_transition_signals
)
agent
.
state_machine
.
step
()
agent
.
state_machine
.
step
()
# Remove agent is required
# Handle done state actions, optionally remove agents
if
self
.
remove_agents_at_target
and
agent
.
state
==
TrainState
.
DONE
:
self
.
handle_done_state
(
agent
)
agent
.
position
=
None
have_all_agents_ended
&=
(
agent
.
state
==
TrainState
.
DONE
)
have_all_agents_ended
&=
(
agent
.
state
==
TrainState
.
DONE
)
...
...
flatland/envs/step_utils/action_saver.py
View file @
6df9e4d3
...
@@ -14,12 +14,19 @@ class ActionSaver:
...
@@ -14,12 +14,19 @@ class ActionSaver:
def
save_action_if_allowed
(
self
,
action
,
state
):
def
save_action_if_allowed
(
self
,
action
,
state
):
if
not
self
.
is_action_saved
and
\
if
action
.
is_moving_action
()
and
\
action
.
is_moving_action
()
and
\
not
self
.
is_action_saved
and
\
not
state
.
is_malfunction_state
():
not
state
.
is_malfunction_state
()
and
\
not
state
==
TrainState
.
DONE
:
self
.
saved_action
=
action
self
.
saved_action
=
action
def
clear_saved_action
(
self
):
def
clear_saved_action
(
self
):
self
.
saved_action
=
None
self
.
saved_action
=
None
def
to_dict
(
self
):
return
{
"saved_action"
:
self
.
saved_action
}
def
from_dict
(
self
,
load_dict
):
self
.
saved_action
=
load_dict
[
'saved_action'
]
flatland/envs/step_utils/malfunction_handler.py
View file @
6df9e4d3
...
@@ -40,6 +40,12 @@ class MalfunctionHandler:
...
@@ -40,6 +40,12 @@ class MalfunctionHandler:
if
self
.
_malfunction_down_counter
>
0
:
if
self
.
_malfunction_down_counter
>
0
:
self
.
_malfunction_down_counter
-=
1
self
.
_malfunction_down_counter
-=
1
def
to_dict
(
self
):
return
{
"malfunction_down_counter"
:
self
.
_malfunction_down_counter
}
def
from_dict
(
self
,
load_dict
):
self
.
_malfunction_down_counter
=
load_dict
[
'malfunction_down_counter'
]
...
...
flatland/envs/step_utils/speed_counter.py
View file @
6df9e4d3
...
@@ -3,8 +3,7 @@ from flatland.envs.step_utils.states import TrainState
...
@@ -3,8 +3,7 @@ from flatland.envs.step_utils.states import TrainState
class
SpeedCounter
:
class
SpeedCounter
:
def
__init__
(
self
,
speed
):
def
__init__
(
self
,
speed
):
self
.
speed
=
speed
self
.
_speed
=
speed
self
.
max_count
=
int
(
1
/
speed
)
-
1
def
update_counter
(
self
,
state
,
old_position
):
def
update_counter
(
self
,
state
,
old_position
):
# When coming onto the map, do no update speed counter
# When coming onto the map, do no update speed counter
...
@@ -30,3 +29,17 @@ class SpeedCounter:
...
@@ -30,3 +29,17 @@ class SpeedCounter:
def
is_cell_exit
(
self
):
def
is_cell_exit
(
self
):
return
self
.
counter
==
self
.
max_count
return
self
.
counter
==
self
.
max_count
@
property
def
speed
(
self
):
return
self
.
_speed
@
property
def
max_count
(
self
):
return
int
(
1
/
self
.
_speed
)
-
1
def
to_dict
(
self
):
return
{
"speed"
:
self
.
_speed
}
def
from_dict
(
self
,
load_dict
):
self
.
_speed
=
load_dict
[
'speed'
]
flatland/envs/step_utils/state_machine.py
View file @
6df9e4d3
...
@@ -121,7 +121,7 @@ class TrainStateMachine:
...
@@ -121,7 +121,7 @@ class TrainStateMachine:
def
reset
(
self
):
def
reset
(
self
):
self
.
_state
=
self
.
_initial_state
self
.
_state
=
self
.
_initial_state
self
.
st_signals
=
{}
self
.
st_signals
=
StateTransitionSignals
()
self
.
clear_next_state
()
self
.
clear_next_state
()
@
property
@
property
...
@@ -135,5 +135,17 @@ class TrainStateMachine:
...
@@ -135,5 +135,17 @@ class TrainStateMachine:
def
set_transition_signals
(
self
,
state_transition_signals
):
def
set_transition_signals
(
self
,
state_transition_signals
):
self
.
st_signals
=
state_transition_signals
self
.
st_signals
=
state_transition_signals
def
__repr__
(
self
):
return
f
"
\n
\
state:
{
str
(
self
.
state
)
}
\n
\
st_signals:
{
self
.
st_signals
}
"
def
to_dict
(
self
):
return
{
"state"
:
self
.
_state
}
def
from_dict
(
self
,
load_dict
):
self
.
set_state
(
load_dict
[
'state'
])
tests/test_flatland_envs_observations.py
View file @
6df9e4d3
...
@@ -50,7 +50,6 @@ def _step_along_shortest_path(env, obs_builder, rail):
...
@@ -50,7 +50,6 @@ def _step_along_shortest_path(env, obs_builder, rail):
actions
=
{}
actions
=
{}
expected_next_position
=
{}
expected_next_position
=
{}
for
agent
in
env
.
agents
:
for
agent
in
env
.
agents
:
agent
:
EnvAgent
shortest_distance
=
np
.
inf
shortest_distance
=
np
.
inf
for
exit_direction
in
range
(
4
):
for
exit_direction
in
range
(
4
):
...
@@ -297,7 +296,6 @@ def test_reward_function_waiting(rendering=False):
...
@@ -297,7 +296,6 @@ def test_reward_function_waiting(rendering=False):
print
(
env
.
dones
[
"__all__"
])
print
(
env
.
dones
[
"__all__"
])
for
agent
in
env
.
agents
:
for
agent
in
env
.
agents
:
agent
:
EnvAgent
print
(
"[{}] agent {} at {}, target {} "
.
format
(
iteration
+
1
,
agent
.
handle
,
agent
.
position
,
agent
.
target
))
print
(
"[{}] agent {} at {}, target {} "
.
format
(
iteration
+
1
,
agent
.
handle
,
agent
.
position
,
agent
.
target
))
print
(
np
.
all
([
np
.
array_equal
(
agent2
.
position
,
agent2
.
target
)
for
agent2
in
env
.
agents
]))
print
(
np
.
all
([
np
.
array_equal
(
agent2
.
position
,
agent2
.
target
)
for
agent2
in
env
.
agents
]))
for
agent
in
env
.
agents
:
for
agent
in
env
.
agents
:
...
...
tests/test_flatland_envs_predictions.py
View file @
6df9e4d3
...
@@ -17,6 +17,7 @@ from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make
...
@@ -17,6 +17,7 @@ from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make
from
flatland.envs.rail_env_action
import
RailEnvActions
from
flatland.envs.rail_env_action
import
RailEnvActions
from
flatland.envs.step_utils.states
import
TrainState
from
flatland.envs.step_utils.states
import
TrainState
"""Test predictions for `flatland` package."""
"""Test predictions for `flatland` package."""
...
...
tests/test_flatland_envs_rail_env.py
View file @
6df9e4d3
...
@@ -22,7 +22,7 @@ import time
...
@@ -22,7 +22,7 @@ import time
"""Tests for `flatland` package."""
"""Tests for `flatland` package."""
@
pytest
.
mark
.
skip
(
"Msgpack serializing not supported"
)
def
test_load_env
():
def
test_load_env
():
#env = RailEnv(10, 10)
#env = RailEnv(10, 10)
#env.reset()
#env.reset()
...
@@ -47,7 +47,7 @@ def test_save_load():
...
@@ -47,7 +47,7 @@ def test_save_load():
agent_2_pos
=
env
.
agents
[
1
].
position
agent_2_pos
=
env
.
agents
[
1
].
position
agent_2_dir
=
env
.
agents
[
1
].
direction
agent_2_dir
=
env
.
agents
[
1
].
direction
agent_2_tar
=
env
.
agents
[
1
].
target
agent_2_tar
=
env
.
agents
[
1
].
target
os
.
makedirs
(
"tmp"
,
exist_ok
=
True
)
os
.
makedirs
(
"tmp"
,
exist_ok
=
True
)
RailEnvPersister
.
save
(
env
,
"tmp/test_save.pkl"
)
RailEnvPersister
.
save
(
env
,
"tmp/test_save.pkl"
)
...
@@ -65,7 +65,7 @@ def test_save_load():
...
@@ -65,7 +65,7 @@ def test_save_load():
assert
(
agent_2_dir
==
env
.
agents
[
1
].
direction
)
assert
(
agent_2_dir
==
env
.
agents
[
1
].
direction
)
assert
(
agent_2_tar
==
env
.
agents
[
1
].
target
)
assert
(
agent_2_tar
==
env
.
agents
[
1
].
target
)
@
pytest
.
mark
.
skip
(
"Msgpack serializing not supported"
)
def
test_save_load_mpk
():
def
test_save_load_mpk
():
env
=
RailEnv
(
width
=
30
,
height
=
30
,
env
=
RailEnv
(
width
=
30
,
height
=
30
,
rail_generator
=
sparse_rail_generator
(
seed
=
1
),
rail_generator
=
sparse_rail_generator
(
seed
=
1
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment