diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 20dc0325cd9d1e786cd179b56df9b5527ba68b66..7350544271548b480585cb38d0e4579ab14a09f9 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -22,7 +22,6 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), ('moving', bool), ('earliest_departure', int), ('latest_arrival', int), - ('malfunction_data', dict), ('handle', int), ('position', Tuple[int, int]), ('arrival_time', int), @@ -68,13 +67,6 @@ class EnvAgent: earliest_departure = attrib(default=None, type=int) # default None during _from_line() latest_arrival = attrib(default=None, type=int) # default None during _from_line() - # if broken>0, the agent's actions are ignored for 'broken' steps - # number of time the agent had to stop, since the last time it broke down - malfunction_data = attrib( - default=Factory( - lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0, - 'moving_before_malfunction': False}))) - handle = attrib(default=None) # INIT TILL HERE IN _from_line() @@ -106,10 +98,7 @@ class EnvAgent: self.old_direction = None self.moving = False - # Reset agent malfunction values - self.malfunction_data['malfunction'] = 0 - self.malfunction_data['nr_malfunctions'] = 0 - self.malfunction_data['moving_before_malfunction'] = False + self.malfunction_handler.reset() self.action_saver.clear_saved_action() self.speed_counter.reset_counter() @@ -123,7 +112,6 @@ class EnvAgent: moving=self.moving, earliest_departure=self.earliest_departure, latest_arrival=self.latest_arrival, - malfunction_data=self.malfunction_data, handle=self.handle, position=self.position, old_direction=self.old_direction, @@ -169,16 +157,6 @@ class EnvAgent: for i_agent in range(num_agents): speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0 - if line.agent_malfunction_rates is not None: - malfunction_rate = line.agent_malfunction_rates[i_agent] - else: - 
malfunction_rate = 0. - - malfunction_data = {'malfunction': 0, - 'malfunction_rate': malfunction_rate, - 'next_malfunction': 0, - 'nr_malfunctions': 0 - } agent = EnvAgent(initial_position = line.agent_positions[i_agent], initial_direction = line.agent_directions[i_agent], direction = line.agent_directions[i_agent], @@ -186,7 +164,6 @@ class EnvAgent: moving = False, earliest_departure = None, latest_arrival = None, - malfunction_data = malfunction_data, handle = i_agent, speed_counter = SpeedCounter(speed=speed)) agent_list.append(agent) @@ -200,17 +177,11 @@ class EnvAgent: if len(static_agent) >= 6: agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1], direction=static_agent[1], target=static_agent[2], moving=static_agent[3], - speed_counter=SpeedCounter(static_agent[4]['speed']), malfunction_data=static_agent[5], - handle=i) + speed_counter=SpeedCounter(static_agent[4]['speed']), handle=i) else: agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1], direction=static_agent[1], target=static_agent[2], moving=False, - malfunction_data={ - 'malfunction': 0, - 'nr_malfunctions': 0, - 'moving_before_malfunction': False - }, speed_counter=SpeedCounter(1.0), handle=i) agents.append(agent) @@ -219,10 +190,15 @@ class EnvAgent: def __str__(self): return f"\n \ handle(agent index): {self.handle} \n \ - initial_position: {self.initial_position} initial_direction: {self.initial_direction} \n \ - position: {self.position} direction: {self.direction} target: {self.target} \n \ - old_position: {self.old_position} old_direction {self.old_direction} \n \ - earliest_departure: {self.earliest_departure} latest_arrival: {self.latest_arrival} \n \ + initial_position: {self.initial_position} \n \ + initial_direction: {self.initial_direction} \n \ + position: {self.position} \n \ + direction: {self.direction} \n \ + target: {self.target} \n \ + old_position: {self.old_position} \n \ + old_direction {self.old_direction} \n 
\ + earliest_departure: {self.earliest_departure} \n \ + latest_arrival: {self.latest_arrival} \n \ state: {str(self.state)} \n \ malfunction_handler: {self.malfunction_handler} \n \ action_saver: {self.action_saver} \n \ @@ -240,6 +216,14 @@ class EnvAgent: warnings.warn("Not recommended to set the state with this function unless completely required") self.state_machine.set_state(state) + @property + def malfunction_data(self): + raise ValueError("agent.malfunction_data is deprecated, please use agent.malfunction_handler instead") + + @property + def speed_data(self): + raise ValueError("agent.speed_data is deprecated, please use agent.speed_counter instead") + diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py index 79241f2489b4a7b3ab3008f269d2c03fbafd27c8..154a65bb5c26da7cd2e294b7aeb5a0e59ffb072d 100644 --- a/flatland/envs/line_generators.py +++ b/flatland/envs/line_generators.py @@ -148,7 +148,7 @@ class SparseLineGen(BaseLineGen): timedelay_factor * alpha * (rail.width + rail.height + num_agents / len(city_positions))) return Line(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None) + agent_targets=agents_target, agent_speeds=speeds) def line_from_file(filename, load_from_package=None) -> LineGenerator: @@ -186,11 +186,7 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator: agents_target = [a.target for a in agents] agents_speed = [a.speed_counter.speed for a in agents] - # Malfunctions from here are not used. They have their own generator. 
- #agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents] - return Line(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=agents_speed, - agent_malfunction_rates=None) + agent_targets=agents_target, agent_speeds=agents_speed) return generator diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 086fd9cef348ff8de6b6c358876926804dc673e5..5a00364bd9fe91e619974f826fd4b5f8d1a4b900 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -219,7 +219,7 @@ def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[Malfun if reset: return Malfunction(0) - if agent.malfunction_data['malfunction'] < 1: + if agent.malfunction_handler.malfunction_down_counter < 1: if np_random.rand() < _malfunction_prob(mean_malfunction_rate): num_broken_steps = np_random.randint(min_number_of_steps_broken, max_number_of_steps_broken + 1) + 1 diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 0b5f2a845d525f36456ce3c770fe4453d2c8a0e5..ab73908633a8ac170467e952d35c89f31e088758 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -99,8 +99,8 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent[tuple(_agent.position)] = 1 self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_counter.speed - self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[ - 'malfunction'] + self.location_has_agent_malfunction[tuple(_agent.position)] = \ + _agent.malfunction_handler.malfunction_down_counter # [NIMISH] WHAT IS THIS if _agent.state.is_off_map_state() and \ @@ -220,7 +220,7 @@ class TreeObsForRailEnv(ObservationBuilder): (handle, *agent_virtual_position, agent.direction)], num_agents_same_direction=0, 
num_agents_opposite_direction=0, - num_agents_malfunctioning=agent.malfunction_data['malfunction'], + num_agents_malfunctioning=agent.malfunction_handler.malfunction_down_counter, speed_min_fractional=agent.speed_counter.speed, num_agents_ready_to_depart=0, childs={}) @@ -603,7 +603,7 @@ class GlobalObsForRailEnv(ObservationBuilder): # second channel only for other agents if i != handle: obs_agents_state[other_agent.position][1] = other_agent.direction - obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] + obs_agents_state[other_agent.position][2] = other_agent.malfunction_handler.malfunction_down_counter obs_agents_state[other_agent.position][3] = other_agent.speed_counter.speed # fifth channel: all ready to depart on this position if other_agent.state.is_off_map_state(): diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py index 29ad4760001b4d94394ffc848a7b778d36d4c7a3..a05a8cb40bd515db3a92e59e9bf42a343a0ed338 100644 --- a/flatland/envs/persistence.py +++ b/flatland/envs/persistence.py @@ -253,7 +253,7 @@ class RailEnvPersister(object): #msgpack.packb(agent_data, use_bin_type=True) distance_map_data = self.distance_map.get() - malfunction_data: MalfunctionProcessData = self.malfunction_process_data + malfunction_data: mal_gen.MalfunctionProcessData = self.malfunction_process_data #msgpack.packb(distance_map_data, use_bin_type=True) # does nothing msg_data = { "grid": grid_data, diff --git a/flatland/envs/step_utils/malfunction_handler.py b/flatland/envs/step_utils/malfunction_handler.py index bf1f188fe850272968af0a8e11c87fdf92fd5d88..02976e9d98c5665e39bcb6887a0397e37ce0c3d0 100644 --- a/flatland/envs/step_utils/malfunction_handler.py +++ b/flatland/envs/step_utils/malfunction_handler.py @@ -11,6 +11,10 @@ class MalfunctionHandler: def __init__(self): self._malfunction_down_counter = 0 self.num_malfunctions = 0 + + def reset(self): + self._malfunction_down_counter = 0 + self.num_malfunctions = 0 @property 
def in_malfunction(self): diff --git a/flatland/envs/timetable_utils.py b/flatland/envs/timetable_utils.py index 548624f2c08879ce0e507224e61b6fe43ffb955b..a8ccc706daa2b400ad343800baeacf58d4a1cd14 100644 --- a/flatland/envs/timetable_utils.py +++ b/flatland/envs/timetable_utils.py @@ -6,8 +6,7 @@ from flatland.core.grid.grid_utils import IntVector2DArray Line = NamedTuple('Line', [('agent_positions', IntVector2DArray), ('agent_directions', List[Grid4TransitionsEnum]), ('agent_targets', IntVector2DArray), - ('agent_speeds', List[float]), - ('agent_malfunction_rates', List[int])]) + ('agent_speeds', List[float])]) Timetable = NamedTuple('Timetable', [('earliest_departures', List[int]), ('latest_arrivals', List[int]), diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index cd765cd19ba0c9510d301ac77a0782bccd6bd6b4..4b28529f8d1a04000061f1a82dba0aa1fe4b07c1 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -688,7 +688,7 @@ class RenderLocal(RenderBase): malfunction=False) continue - is_malfunction = agent.malfunction_data["malfunction"] > 0 + is_malfunction = agent.malfunction_handler.malfunction_down_counter > 0 if self.agent_render_variant == AgentRenderVariant.BOX_ONLY: self.gl.set_cell_occupied(agent_idx, *(agent.position)) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 7ebf73f0c8acc98f9690c219032550a4afead3e3..16eba37049aeac83562c630de67c7e5f7c61441a 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -254,7 +254,7 @@ def test_initial_malfunction(): # reset to initialize agents_static env.reset(False, False, random_seed=10) env._max_episode_steps = 1000 - print(env.agents[0].malfunction_data) + print(env.agents[0].malfunction_handler) env.agents[0].target = (0, 5) set_penalties_for_replay(env) replay_config = ReplayConfig( @@ -568,8 +568,8 @@ def test_last_malfunction_step(): env.agents[a_idx].position = 
env.agents[a_idx].initial_position env.agents[a_idx].state = TrainState.MOVING # Force malfunction to be off at beginning and next malfunction to happen in 2 steps - env.agents[0].malfunction_data['next_malfunction'] = 2 - env.agents[0].malfunction_data['malfunction'] = 0 + # NOTE: 'next_malfunction' is no longer forced here; agent.malfunction_data is deprecated and raises on access + env.agents[0].malfunction_handler.malfunction_down_counter = 0 env_data = [] # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART @@ -582,14 +582,14 @@ def test_last_malfunction_step(): # Go forward all the time action_dict[agent.handle] = RailEnvActions(2) - if env.agents[0].malfunction_data['malfunction'] < 1: + if env.agents[0].malfunction_handler.malfunction_down_counter < 1: agent_can_move = True # Store the position before and after the step pre_position = env.agents[0].speed_counter.counter _, reward, _, _ = env.step(action_dict) # Check if the agent is still allowed to move in this step - if env.agents[0].malfunction_data['malfunction'] > 0: + if env.agents[0].malfunction_handler.malfunction_down_counter > 0: agent_can_move = False post_position = env.agents[0].speed_counter.counter # Assert that the agent moved while it was still allowed diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index 1ea959a251e9dd672db4a71a11e3bd76bfced433..c5bf8317cbf964756d9eb99797ab0c73cbd1e447 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -105,9 +105,9 @@ def test_get_global_observation(): for other_i, other_agent in enumerate(env.agents): if other_agent.state in [TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED, TrainState.DONE] and other_agent.position == (r, c): - assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_data['malfunction']), \ + assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_handler.malfunction_down_counter), \ "agent {} in state {} at {} should see agent malfunction {}, found = {}" \ - 
.format(i, agent.state, (r, c), other_agent.malfunction_data['malfunction'], + .format(i, agent.state, (r, c), other_agent.malfunction_handler.malfunction_down_counter, obs_agents_state[(r, c)][2]) assert np.isclose(obs_agents_state[(r, c)][3], other_agent.speed_counter.speed) has_agent = True diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py index 08acd85bc5ca9e962ef877310b7bc384b7be77bd..b972b089caf421b799890795c0553f21397da86e 100644 --- a/tests/test_malfunction_generators.py +++ b/tests/test_malfunction_generators.py @@ -109,5 +109,5 @@ def test_single_malfunction_generator(): break for agent in env.agents: # Go forward all the time - tot_malfunctions += agent.malfunction_data['nr_malfunctions'] + tot_malfunctions += agent.malfunction_handler.num_malfunctions assert tot_malfunctions == 1 diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 7e1c4b17324101afdfde90925193971c9bd490a0..f42978dcfb9a015a28906075a443cfc0053a2658 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -133,8 +133,8 @@ def test_seeding_and_malfunction(): action = np.random.randint(4) action_dict[a] = action # print("----------------------") - # print(env.agents[a].malfunction_data, env.agents[a].status) - # print(env2.agents[a].malfunction_data, env2.agents[a].status) + # print(env.agents[a].malfunction_handler, env.agents[a].status) + # print(env2.agents[a].malfunction_handler, env2.agents[a].status) _, reward1, done1, _ = env.step(action_dict) _, reward2, done2, _ = env2.step(action_dict)