From d5b16a5210992b38360021ba9d5194535d9b75a3 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Tue, 29 Oct 2019 18:37:24 -0400 Subject: [PATCH] updating tests to new malfunction generation --- flatland/envs/agent_utils.py | 1 + flatland/envs/rail_env.py | 49 ++++++++++++---------------- flatland/envs/schedule_generators.py | 15 ++++----- flatland/envs/schedule_utils.py | 3 +- tests/test_flatland_malfunction.py | 19 ++--------- 5 files changed, 31 insertions(+), 56 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index ef2d4855..01ce2908 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -63,6 +63,7 @@ class EnvAgentStatic(object): for i in range(len(schedule.agent_positions)): malfunction_datas.append({'malfunction': 0, 'nr_malfunctions': 0, + 'moving_before_malfunction': False, 'fixed': True}) return list(starmap(EnvAgentStatic, zip(schedule.agent_positions, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 3e17a4a4..53784a67 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -1,6 +1,7 @@ """ Definition of the RailEnv environment. """ +import random # TODO: _ this is a global method --> utils or remove later from enum import IntEnum from typing import List, NamedTuple, Optional, Dict @@ -8,7 +9,6 @@ from typing import List, NamedTuple, Optional, Dict import msgpack import msgpack_numpy as m import numpy as np -import random from gym.utils import seeding from flatland.core.env import Environment @@ -209,7 +209,6 @@ class RailEnv(Environment): # Uniform distribution parameters for malfunction duration self.min_number_of_steps_broken = malfunction_min_duration self.max_number_of_steps_broken = malfunction_max_duration - # Reset environment self.valid_positions = None @@ -331,8 +330,8 @@ class RailEnv(Environment): if agents_hints and 'city_orientations' in agents_hints: ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations']) self._max_episode_steps = self.compute_max_episode_steps( - width=self.width, height=self.height, - ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities) + width=self.width, height=self.height, + ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities) else: self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height) @@ -394,9 +393,6 @@ class RailEnv(Environment): self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction'] return False - - - def _malfunction(self, rate) -> bool: """ Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run @@ -404,16 +400,13 @@ class RailEnv(Environment): """ if np.random.random() < self._malfunction_prob(rate): breaking_agent = random.choice(self.agents) - while breaking_agent.status == RailAgentStatus.DONE_REMOVED: - breaking_agent = random.choice(self.agents) - - num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) - breaking_agent.malfunction_data['malfunction'] = num_broken_steps - breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving - breaking_agent.malfunction_data['fixed'] = False - - + if breaking_agent.malfunction_data['malfunction'] < 1: + num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, + self.max_number_of_steps_broken + 1) + breaking_agent.malfunction_data['malfunction'] = num_broken_steps + breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving + breaking_agent.malfunction_data['fixed'] = False + breaking_agent.malfunction_data['nr_malfunctions'] += 1 def step(self, action_dict_: Dict[int, RailEnvActions]): @@ -423,10 +416,10 @@ class RailEnv(Environment): if self.dones["__all__"]: self.rewards_dict = {} info_dict = { - "action_required" : {}, - "malfunction" : {}, - "speed" : {}, - "status" : {}, + "action_required": {}, + "malfunction": {}, + "speed": {}, + "status": {}, } for i_agent, agent in enumerate(self.agents): self.rewards_dict[i_agent] = self.global_reward @@ -440,12 +433,12 @@ class RailEnv(Environment): # Reset the step rewards self.rewards_dict = dict() info_dict = { - "action_required" : {}, - "malfunction" : {}, - "speed" : {}, - "status" : {}, + "action_required": {}, + "malfunction": {}, + "speed": {}, + "status": {}, } - have_all_agents_ended = True # boolean flag to check if all agents are done + have_all_agents_ended = True # boolean flag to check if all agents are done # Evoke the malfunction generator self._malfunction(self.mean_malfunction_rate) @@ -462,8 +455,8 @@ class RailEnv(Environment): # Build info dict info_dict["action_required"][i_agent] = \ (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, - rtol=1e-03))) + agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, + rtol=1e-03))) info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction'] info_dict["speed"][i_agent] = agent.speed_data['speed'] info_dict["status"][i_agent] = agent.status diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index cb8b1537..58a7be34 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -79,7 +79,7 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se speeds = [1.0] * len(agents_position) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None) + agent_targets=agents_target, agent_speeds=speeds) return generator @@ -165,7 +165,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see speeds = [1.0] * len(agents_position) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None) + agent_targets=agents_target, agent_speeds=speeds) return generator @@ -199,12 +199,12 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = valid_positions.append((r, c)) if len(valid_positions) == 0: return Schedule(agent_positions=[], agent_directions=[], - agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) + agent_targets=[], agent_speeds=[]) if len(valid_positions) < num_agents: warnings.warn("schedule_generators: len(valid_positions) < num_agents") return Schedule(agent_positions=[], agent_directions=[], - agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) + agent_targets=[], agent_speeds=[]) agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)] @@ -263,7 +263,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None) + agent_targets=agents_target, agent_speeds=agents_speed) return generator @@ -304,12 +304,9 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: agents_target = [a.target for a in agents_static] if len(data['agents_static'][0]) > 5: agents_speed = [a.speed_data['speed'] for a in agents_static] - agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents_static] else: agents_speed = None - agents_malfunction = None return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=agents_speed, - agent_malfunction_rates=agents_malfunction) + agent_targets=agents_target, agent_speeds=agents_speed) return generator diff --git a/flatland/envs/schedule_utils.py b/flatland/envs/schedule_utils.py index e89f170d..c61d2f6b 100644 --- a/flatland/envs/schedule_utils.py +++ b/flatland/envs/schedule_utils.py @@ -6,5 +6,4 @@ from flatland.core.grid.grid_utils import IntVector2DArray Schedule = NamedTuple('Schedule', [('agent_positions', IntVector2DArray), ('agent_directions', List[Grid4TransitionsEnum]), ('agent_targets', IntVector2DArray), - ('agent_speeds', List[float]), - ('agent_malfunction_rates', List[int])]) + ('agent_speeds', List[float])]) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index a46467ed..7eac117f 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -66,8 +66,7 @@ class SingleAgentNavigationObs(ObservationBuilder): def test_malfunction_process(): # Set fixed malfunction duration for this test - stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 1000, + stochastic_data = {'malfunction_rate': 1, 'min_duration': 3, 'max_duration': 3} @@ -84,11 +83,6 @@ def test_malfunction_process(): # reset to initialize agents_static obs, info = env.reset(False, False, True, random_seed=10) - # Check that a initial duration for malfunction was assigned - assert env.agents[0].malfunction_data['next_malfunction'] > 0 - for agent in env.agents: - agent.status = RailAgentStatus.ACTIVE - agent_halts = 0 total_down_time = 0 agent_old_position = env.agents[0].position @@ -101,12 +95,6 @@ def test_malfunction_process(): for i in range(len(obs)): actions[i] = np.argmax(obs[i]) + 1 - if step % 5 == 0: - # Stop the agent and set it to be malfunctioning - env.agents[0].malfunction_data['malfunction'] = -1 - env.agents[0].malfunction_data['next_malfunction'] = 0 - agent_halts += 1 - obs, all_rewards, done, _ = env.step(actions) if env.agents[0].malfunction_data['malfunction'] > 0: @@ -122,12 +110,9 @@ def test_malfunction_process(): total_down_time += env.agents[0].malfunction_data['malfunction'] # Check that the appropriate number of malfunctions is achieved - assert env.agents[0].malfunction_data['nr_malfunctions'] == 20, "Actual {}".format( + assert env.agents[0].malfunction_data['nr_malfunctions'] == 30, "Actual {}".format( env.agents[0].malfunction_data['nr_malfunctions']) - # Check that 20 stops where performed - assert agent_halts == 20 - # Check that malfunctioning data was standing around assert total_down_time > 0 -- GitLab