diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0f814576caef84471d20c91dd92d23d4db02ac --- /dev/null +++ b/examples/debugging_example_DELETE.py @@ -0,0 +1,85 @@ +import random +import time + +import numpy as np + +from flatland.envs.generators import complex_rail_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.utils.rendertools import RenderTool + +random.seed(1) +np.random.seed(1) + +class SingleAgentNavigationObs(TreeObsForRailEnv): + """ + We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute + the minimum distances from each grid node to each agent's target. + + We then build a representation vector with 3 binary components, indicating which of the 3 available directions + for each agent (Left, Forward, Right) lead to the shortest path to its target. + E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector + will be [1, 0, 0]. + """ + def __init__(self): + super().__init__(max_depth=0) + self.observation_space = [3] + + def reset(self): + # Recompute the distance map, if the environment has changed. + super().reset() + + def get(self, handle): + agent = self.env.agents[handle] + + possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) + num_transitions = np.count_nonzero(possible_transitions) + + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right], relative to the current orientation + # If only one transition is possible, the forward branch is aligned with it. + if num_transitions == 1: + observation = [0, 1, 0] + else: + min_distances = [] + for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[direction]: + new_position = self._new_position(agent.position, direction) + min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction]) + else: + min_distances.append(np.inf) + + observation = [0, 0, 0] + observation[np.argmin(min_distances)] = 1 + + return observation + + +env = RailEnv(width=14, + height=14, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), + number_of_agents=2, + obs_builder_object=SingleAgentNavigationObs()) + +obs = env.reset() +env_renderer = RenderTool(env, gl="PILSVG") +env_renderer.render_env(show=True, frames=True, show_observations=False) +for step in range(100): + actions = {} + for i in range(len(obs)): + actions[i] = np.argmax(obs[i])+1 + + if step%5 == 0: + print("Agent halts") + actions[0] = 4 # Halt + + obs, all_rewards, done, _ = env.step(actions) + if env.agents[0].malfunction_data['malfunction'] > 0: + print("Agent 0 broken-ness: ", env.agents[0].malfunction_data['malfunction']) + + env_renderer.render_env(show=True, frames=True, show_observations=False) + time.sleep(0.5) + if done["__all__"]: + break +env_renderer.close_window() + diff --git a/examples/training_example.py b/examples/training_example.py index c038e7b477069957efdec622b2c56e9e84cb7ac0..d125be1587a56025ba1cd3f78b28ba3976f01fbf 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -57,7 +57,7 @@ class RandomAgent: # Initialize the agent with the parameters corresponding to the environment and observation_builder -agent = RandomAgent(218, 4) +agent = RandomAgent(218, 5) n_trials = 5 # Empty dictionary for all agent action @@ -77,12 +77,11 @@ for trials in range(1, n_trials + 1): score = 0 # Run episode - for step in range(100): + for step in range(500): # Chose an action for each agent in the environment for a in range(env.get_num_agents()): action = agent.act(obs[a]) action_dict.update({a: action}) - # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) @@ -92,7 +91,6 @@ for trials in range(1, n_trials + 1): for a in range(env.get_num_agents()): agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) score += all_rewards[a] - obs = next_obs.copy() if done['__all__']: break diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index e353af29ddbee16c208e2059767c18fa7880cb64..4c4070088c59499f885c16db68976c163ec91001 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -15,6 +15,7 @@ class EnvAgentStatic(object): direction = attrib() target = attrib() moving = attrib(default=False) + # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0, # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous # cell if speed=1, as default) @@ -22,6 +23,12 @@ class EnvAgentStatic(object): speed_data = attrib( default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}))) + # if broken>0, the agent's actions are ignored for 'broken' steps + # number of time the agent had to stop, since the last time it broke down + malfunction_data = attrib( + default=Factory( + lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0}))) + @classmethod def from_lists(cls, positions, directions, targets, speeds=None): """ Create a list of EnvAgentStatics from lists of positions, directions and targets @@ -31,7 +38,22 @@ class EnvAgentStatic(object): speed_datas.append({'position_fraction': 0.0, 'speed': speeds[i] if speeds is not None else 1.0, 'transition_action_on_cellexit': 0}) - return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas))) + + # TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set + # some as broken? + malfunction_datas = [] + for i in range(len(positions)): + malfunction_datas.append({'malfunction': 0, + 'malfunction_rate': 0, + 'next_malfunction': 0, + 'nr_malfunctions': 0}) + + return list(starmap(EnvAgentStatic, zip(positions, + directions, + targets, + [False] * len(positions), + speed_datas, + malfunction_datas))) def to_list(self): @@ -45,7 +67,7 @@ class EnvAgentStatic(object): if type(lTarget) is np.ndarray: lTarget = lTarget.tolist() - return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data] + return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.malfunction_data] @attrs @@ -63,7 +85,7 @@ class EnvAgent(EnvAgentStatic): def to_list(self): return [ self.position, self.direction, self.target, self.handle, - self.old_direction, self.old_position, self.moving, self.speed_data] + self.old_direction, self.old_position, self.moving, self.speed_data, self.malfunction_data] @classmethod def from_static(cls, oStatic): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index abe623ae173a593e265cff7d4d88eb323e16b08e..2281282977d8c9d972f13526efa9d96abaf84a52 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -75,6 +75,17 @@ class RailEnv(Environment): - stop_penalty = 0 # penalty for stopping a moving agent - start_penalty = 0 # penalty for starting a stopped agent + Stochastic malfunctioning of trains: + Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid + action or cell is selected. + + Every time an agent stops, an agent has a certain probability of malfunctioning. Malfunctions of trains follow a + poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes to keep + complexity managable. + + TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init(). + For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility. + """ def __init__(self, @@ -83,7 +94,8 @@ class RailEnv(Environment): rail_generator=random_rail_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), - max_episode_steps=None + max_episode_steps=None, + stochastic_data=None ): """ Environment init. @@ -146,6 +158,29 @@ class RailEnv(Environment): self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? + # Stochastic train malfunctioning parameters + if stochastic_data is not None: + prop_malfunction = stochastic_data['prop_malfunction'] + mean_malfunction_rate = stochastic_data['malfunction_rate'] + malfunction_min_duration = stochastic_data['min_duration'] + malfunction_max_duration = stochastic_data['max_duration'] + else: + prop_malfunction = 0. + mean_malfunction_rate = 0. + malfunction_min_duration = 0. + malfunction_max_duration = 0. + + # percentage of malfunctioning trains + self.proportion_malfunctioning_trains = prop_malfunction + + # Mean malfunction in number of stops + self.mean_malfunction_rate = mean_malfunction_rate + + # Uniform distribution parameters for malfunction duration + self.min_number_of_steps_broken = malfunction_min_duration + self.max_number_of_steps_broken = malfunction_max_duration + + # Rest environment self.reset() self.num_resets = 0 # yes, set it to zero again! @@ -195,7 +230,15 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): agent = self.agents[i_agent] + + # A proportion of agent in the environment will receive a positive malfunction rate + if np.random.random() < self.proportion_malfunctioning_trains: + agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate + agent.speed_data['position_fraction'] = 0.0 + agent.malfunction_data['malfunction'] = 0 + + self._agent_malfunction(agent) self.num_resets += 1 self._elapsed_steps = 0 @@ -210,6 +253,30 @@ class RailEnv(Environment): # Return the new observation vectors for each agent return self._get_observations() + def _agent_malfunction(self, agent): + # Decrease counter for next event + agent.malfunction_data['next_malfunction'] -= 1 + + # Only agents that have a positive rate for malfunctions and are not currently broken are considered + if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data[ + 'malfunction']: + + # If counter has come to zero --> Agent has malfunction + # set next malfunction time and duration of current malfunction + if agent.malfunction_data['next_malfunction'] <= 0: + # Increase number of malfunctions + agent.malfunction_data['nr_malfunctions'] += 1 + + # Next malfunction in number of stops + next_breakdown = int( + np.random.exponential(scale=agent.malfunction_data['malfunction_rate'])) + agent.malfunction_data['next_malfunction'] = next_breakdown + + # Duration of current malfunction + num_broken_steps = np.random.randint(self.min_number_of_steps_broken, + self.max_number_of_steps_broken + 1) + 1 + agent.malfunction_data['malfunction'] = num_broken_steps + def step(self, action_dict_): self._elapsed_steps += 1 @@ -238,12 +305,29 @@ class RailEnv(Environment): agent = self.agents[i_agent] agent.old_direction = agent.direction agent.old_position = agent.position + + # Check if agent breaks at this step + self._agent_malfunction(agent) + if self.dones[i_agent]: # this agent has already completed... continue - if i_agent not in action_dict: # no action has been supplied for this agent + # No action has been supplied for this agent + if i_agent not in action_dict: action_dict[i_agent] = RailEnvActions.DO_NOTHING + # The train is broken + if agent.malfunction_data['malfunction'] > 0: + agent.malfunction_data['malfunction'] -= 1 + + # Broken agents are stopped + self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] + self.agents[i_agent].moving = False + action_dict[i_agent] = RailEnvActions.DO_NOTHING + + # Nothing left to do with broken agent + continue + if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): print('ERROR: illegal action=', action_dict[i_agent], 'for agent with index=', i_agent, @@ -329,7 +413,7 @@ class RailEnv(Environment): agent.direction = new_direction agent.speed_data['position_fraction'] = 0.0 else: - # If the agent cannot move due to any reason, we set its state to not moving. + # If the agent cannot move due to any reason, we set its state to not moving agent.moving = False if np.equal(agent.position, agent.target).all(): diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py new file mode 100644 index 0000000000000000000000000000000000000000..67dcd25c0769e542fd9a03502c2a8c1b29333b2b --- /dev/null +++ b/tests/test_flatland_malfunction.py @@ -0,0 +1,110 @@ +import numpy as np + +from flatland.envs.generators import complex_rail_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.rail_env import RailEnv + + +class SingleAgentNavigationObs(TreeObsForRailEnv): + """ + We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute + the minimum distances from each grid node to each agent's target. + + We then build a representation vector with 3 binary components, indicating which of the 3 available directions + for each agent (Left, Forward, Right) lead to the shortest path to its target. + E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector + will be [1, 0, 0]. + """ + + def __init__(self): + super().__init__(max_depth=0) + self.observation_space = [3] + + def reset(self): + # Recompute the distance map, if the environment has changed. + super().reset() + + def get(self, handle): + agent = self.env.agents[handle] + + possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction) + num_transitions = np.count_nonzero(possible_transitions) + + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right], relative to the current orientation + # If only one transition is possible, the forward branch is aligned with it. + if num_transitions == 1: + observation = [0, 1, 0] + else: + min_distances = [] + for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]: + if possible_transitions[direction]: + new_position = self._new_position(agent.position, direction) + min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction]) + else: + min_distances.append(np.inf) + + observation = [0, 0, 0] + observation[np.argmin(min_distances)] = 1 + + return observation + + +def test_malfunction_process(): + # Set fixed malfunction duration for this test + stochastic_data = {'prop_malfunction': 1., + 'malfunction_rate': 1000, + 'min_duration': 3, + 'max_duration': 3} + np.random.seed(5) + + env = RailEnv(width=20, + height=20, + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, + seed=0), + number_of_agents=2, + obs_builder_object=SingleAgentNavigationObs(), + stochastic_data=stochastic_data) + + obs = env.reset() + + # Check that a initial duration for malfunction was assigned + assert env.agents[0].malfunction_data['next_malfunction'] > 0 + + agent_halts = 0 + total_down_time = 0 + agent_malfunctioning = False + agent_old_position = env.agents[0].position + for step in range(100): + actions = {} + for i in range(len(obs)): + actions[i] = np.argmax(obs[i]) + 1 + + if step % 5 == 0: + # Stop the agent and set it to be malfunctioning + env.agents[0].malfunction_data['malfunction'] = -1 + env.agents[0].malfunction_data['next_malfunction'] = 0 + agent_halts += 1 + + obs, all_rewards, done, _ = env.step(actions) + + if env.agents[0].malfunction_data['malfunction'] > 0: + agent_malfunctioning = True + else: + agent_malfunctioning = False + + if agent_malfunctioning: + # Check that agent is not moving while malfunctioning + assert agent_old_position == env.agents[0].position + + agent_old_position = env.agents[0].position + total_down_time += env.agents[0].malfunction_data['malfunction'] + + # Check that the appropriate number of malfunctions is achieved + assert env.agents[0].malfunction_data['nr_malfunctions'] == 21 + + # Check that 20 stops where performed + assert agent_halts == 20 + + # Check that malfunctioning data was standing around + assert total_down_time > 0