Commit 6df9e4d3 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

fix serialization of agents

parent f4dc1668
......@@ -30,12 +30,31 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('old_position', Tuple[int, int]),
('speed_counter', SpeedCounter),
('action_saver', ActionSaver),
('state', TrainState),
('state_machine', TrainStateMachine),
('malfunction_handler', MalfunctionHandler),
])
def load_env_agent(agent_tuple: Agent):
    """Rebuild an ``EnvAgent`` from a serialized ``Agent`` named tuple.

    Every serialized field except ``state`` is forwarded to the
    ``EnvAgent`` constructor; the ``state`` entry of the tuple is not
    passed through (the ``state_machine`` field carries it instead).
    """
    forwarded_fields = (
        "initial_position", "initial_direction", "direction", "target",
        "moving", "earliest_departure", "latest_arrival", "handle",
        "position", "arrival_time", "old_direction", "old_position",
        "speed_counter", "action_saver", "state_machine",
        "malfunction_handler",
    )
    kwargs = {name: getattr(agent_tuple, name) for name in forwarded_fields}
    return EnvAgent(**kwargs)
@attrs
class EnvAgent:
# INIT FROM HERE IN _from_line()
......@@ -105,13 +124,13 @@ class EnvAgent:
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
malfunction_data=self.malfunction_data,
handle=self.handle,
state=self.state,
handle=self.handle,
position=self.position,
old_direction=self.old_direction,
old_position=self.old_position,
speed_counter=self.speed_counter,
action_saver=self.action_saver,
arrival_time=self.arrival_time,
state_machine=self.state_machine,
malfunction_handler=self.malfunction_handler)
......@@ -176,13 +195,13 @@ class EnvAgent:
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
raise NotImplementedError("Not implemented for Flatland 3")
agents = []
for i, static_agent in enumerate(static_agents_data):
if len(static_agent) >= 6:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i)
speed_counter=SpeedCounter(static_agent[4]['speed']), malfunction_data=static_agent[5],
handle=i)
else:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2],
......@@ -205,7 +224,7 @@ class EnvAgent:
return f"\n \
handle(agent index): {self.handle} \n \
initial_position: {self.initial_position} initial_direction: {self.initial_direction} \n \
position: {self.position} direction: {self.position} target: {self.target} \n \
position: {self.position} direction: {self.direction} target: {self.target} \n \
earliest_departure: {self.earliest_departure} latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_data: {self.malfunction_data} \n \
......
......@@ -2,28 +2,21 @@
import pickle
import msgpack
import msgpack_numpy
import numpy as np
import msgpack_numpy
msgpack_numpy.patch()
from flatland.envs import rail_env
#from flatland.core.env import Environment
from flatland.core.env_observation_builder import DummyObservationBuilder
#from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
#from flatland.core.grid.grid4_utils import get_new_position
#from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import Agent, EnvAgent
from flatland.envs.distance_map import DistanceMap
#from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.agent_utils import EnvAgent, load_env_agent
# cannot import objects / classes directly because of circular import
from flatland.envs import malfunction_generators as mal_gen
from flatland.envs import rail_generators as rail_gen
from flatland.envs import line_generators as line_gen
msgpack_numpy.patch()
class RailEnvPersister(object):
......@@ -163,7 +156,8 @@ class RailEnvPersister(object):
# remove the legacy key
del env_dict["agents_static"]
elif "agents" in env_dict:
env_dict["agents"] = [EnvAgent(*d[0:len(d)]) for d in env_dict["agents"]]
# env_dict["agents"] = [EnvAgent(*d[0:len(d)]) for d in env_dict["agents"]]
env_dict["agents"] = [load_env_agent(d) for d in env_dict["agents"]]
return env_dict
......
......@@ -10,6 +10,7 @@ from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.utils.ordered_set import OrderedSet
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils import transition_utils
class DummyPredictorForRailEnv(PredictionBuilder):
......@@ -64,8 +65,8 @@ class DummyPredictorForRailEnv(PredictionBuilder):
continue
for action in action_priorities:
cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \
self.env._check_action_on_agent(action, agent)
new_cell_isValid, new_direction, new_position, transition_isValid = \
transition_utils.check_action_on_agent(action, self.env.rail, agent.position, agent.direction)
if all([new_cell_isValid, transition_isValid]):
# move and change direction to face the new_direction that was
# performed
......
......@@ -473,6 +473,12 @@ class RailEnv(Environment):
self.dones["__all__"] = True
def handle_done_state(self, agent):
    """Post-step bookkeeping for an agent in the DONE state.

    Stamps the agent's arrival time with the current elapsed step and,
    when ``remove_agents_at_target`` is enabled, removes the agent from
    the grid by clearing its position. No-op for any other state.
    """
    if agent.state != TrainState.DONE:
        return
    agent.arrival_time = self._elapsed_steps
    if self.remove_agents_at_target:
        agent.position = None
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
Updates rewards for the agents at a step.
......@@ -547,7 +553,7 @@ class RailEnv(Environment):
if movement_allowed:
agent.position = agent_transition_data.position
agent.direction = agent_transition_data.direction
preprocessed_action = agent_transition_data.preprocessed_action
## Update states
......@@ -555,9 +561,8 @@ class RailEnv(Environment):
agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step()
# Remove agent is required
if self.remove_agents_at_target and agent.state == TrainState.DONE:
agent.position = None
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
have_all_agents_ended &= (agent.state == TrainState.DONE)
......
......@@ -14,12 +14,19 @@ class ActionSaver:
def save_action_if_allowed(self, action, state):
if not self.is_action_saved and \
action.is_moving_action() and \
not state.is_malfunction_state():
if action.is_moving_action() and \
not self.is_action_saved and \
not state.is_malfunction_state() and \
not state == TrainState.DONE:
self.saved_action = action
def clear_saved_action(self):
    """Discard the stored action by resetting it to ``None``."""
    self.saved_action = None
def to_dict(self):
    """Serialize this object to a plain dict (inverse of ``from_dict``)."""
    return {"saved_action": self.saved_action}
def from_dict(self, load_dict):
    """Restore the saved action from a dict produced by ``to_dict``."""
    self.saved_action = load_dict['saved_action']
......@@ -40,6 +40,12 @@ class MalfunctionHandler:
if self._malfunction_down_counter > 0:
self._malfunction_down_counter -= 1
def to_dict(self):
    """Serialize: only the remaining malfunction countdown is persisted."""
    return {"malfunction_down_counter": self._malfunction_down_counter}
def from_dict(self, load_dict):
    """Restore the malfunction countdown from a ``to_dict`` dict."""
    self._malfunction_down_counter = load_dict['malfunction_down_counter']
......
......@@ -3,8 +3,7 @@ from flatland.envs.step_utils.states import TrainState
class SpeedCounter:
def __init__(self, speed):
self.speed = speed
self.max_count = int(1/speed) - 1
self._speed = speed
def update_counter(self, state, old_position):
# When coming onto the map, do no update speed counter
......@@ -30,3 +29,17 @@ class SpeedCounter:
def is_cell_exit(self):
    # True when the per-cell counter has reached its final tick
    # (counter == max_count), i.e. the agent may leave the cell.
    return self.counter == self.max_count
@property
def speed(self):
    """Read-only accessor for the configured speed (stored in ``_speed``)."""
    return self._speed
@property
def max_count(self):
    # Derived from the current speed rather than cached, so it stays
    # consistent after ``from_dict`` changes ``_speed``.
    # NOTE(review): assumes 0 < speed <= 1; speed == 0 would divide by zero.
    return int(1/self._speed) - 1
def to_dict(self):
    """Serialize: only the speed is persisted (``max_count`` is derived)."""
    return {"speed": self._speed}
def from_dict(self, load_dict):
    """Restore the speed from a dict produced by ``to_dict``."""
    self._speed = load_dict['speed']
......@@ -121,7 +121,7 @@ class TrainStateMachine:
def reset(self):
self._state = self._initial_state
self.st_signals = {}
self.st_signals = StateTransitionSignals()
self.clear_next_state()
@property
......@@ -135,5 +135,17 @@ class TrainStateMachine:
def set_transition_signals(self, state_transition_signals):
    """Store the transition signals consulted by the state machine."""
    self.st_signals = state_transition_signals
def __repr__(self):
    # Multi-line debug dump: current state plus the most recently
    # supplied transition signals.
    return f"\n \
state: {str(self.state)} \n \
st_signals: {self.st_signals}"
def to_dict(self):
    """Serialize: only the current state is persisted."""
    return {"state": self._state}
def from_dict(self, load_dict):
    """Restore the state from a ``to_dict`` dict, going through
    ``set_state`` rather than assigning ``_state`` directly."""
    self.set_state(load_dict['state'])
......@@ -50,7 +50,6 @@ def _step_along_shortest_path(env, obs_builder, rail):
actions = {}
expected_next_position = {}
for agent in env.agents:
agent: EnvAgent
shortest_distance = np.inf
for exit_direction in range(4):
......@@ -297,7 +296,6 @@ def test_reward_function_waiting(rendering=False):
print(env.dones["__all__"])
for agent in env.agents:
agent: EnvAgent
print("[{}] agent {} at {}, target {} ".format(iteration + 1, agent.handle, agent.position, agent.target))
print(np.all([np.array_equal(agent2.position, agent2.target) for agent2 in env.agents]))
for agent in env.agents:
......
......@@ -17,6 +17,7 @@ from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
"""Test predictions for `flatland` package."""
......
......@@ -22,7 +22,7 @@ import time
"""Tests for `flatland` package."""
@pytest.mark.skip("Msgpack serializing not supported")
def test_load_env():
#env = RailEnv(10, 10)
#env.reset()
......@@ -47,7 +47,7 @@ def test_save_load():
agent_2_pos = env.agents[1].position
agent_2_dir = env.agents[1].direction
agent_2_tar = env.agents[1].target
os.makedirs("tmp", exist_ok=True)
RailEnvPersister.save(env, "tmp/test_save.pkl")
......@@ -65,7 +65,7 @@ def test_save_load():
assert (agent_2_dir == env.agents[1].direction)
assert (agent_2_tar == env.agents[1].target)
@pytest.mark.skip("Msgpack serializing not supported")
def test_save_load_mpk():
env = RailEnv(width=30, height=30,
rail_generator=sparse_rail_generator(seed=1),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment