From ce1386648b4733a7fc6a5649b1bb53f45551d865 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Sat, 10 Aug 2019 14:54:21 -0400 Subject: [PATCH] updated poisson process for malfunction of agents --- examples/debugging_example_DELETE.py | 8 ++- examples/training_example.py | 8 +-- flatland/envs/agent_utils.py | 13 ++--- flatland/envs/rail_env.py | 74 ++++++++++++++-------------- 4 files changed, 53 insertions(+), 50 deletions(-) diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 8df84833..56148f20 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -3,11 +3,8 @@ import time import numpy as np -from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid_utils import coordinate_to_position -from flatland.envs.generators import random_rail_generator, complex_rail_generator +from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool @@ -77,7 +74,8 @@ for step in range(100): actions[0] = 4 # Halt obs, all_rewards, done, _ = env.step(actions) - print("Agent 0 broken-ness: ", env.agents[0].broken_data['broken']) + if env.agents[0].broken_data['broken'] > 0: + print("Agent 0 broken-ness: ", env.agents[0].broken_data['broken']) env_renderer.render_env(show=True, frames=True, show_observations=False) time.sleep(0.5) diff --git a/examples/training_example.py b/examples/training_example.py index cfed6c92..17484ad7 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -16,9 +16,9 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=50, height=50, - 
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=1, min_dist=8, max_dist=99999, seed=0), obs_builder_object=TreeObservation, - number_of_agents=5) + number_of_agents=20) env_renderer = RenderTool(env, gl="PILSVG", ) @@ -75,6 +75,7 @@ for trials in range(1, n_trials + 1): score = 0 # Run episode + mean_malfunction_interval = [] for step in range(100): # Chose an action for each agent in the environment for a in range(env.get_num_agents()): @@ -84,7 +85,7 @@ for trials in range(1, n_trials + 1): # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) - env_renderer.render_env(show=True, show_observations=True, show_predictions=True) + env_renderer.render_env(show=True, show_observations=False, show_predictions=True) # Update replay buffer and train agent for a in range(env.get_num_agents()): @@ -94,4 +95,5 @@ for trials in range(1, n_trials + 1): obs = next_obs.copy() if done['__all__']: break + print(np.mean(mean_malfunction_interval)) print('Episode Nr. {}\t Score = {}'.format(trials, score)) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 2017d706..27a7a380 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -25,8 +25,8 @@ class EnvAgentStatic(object): # if broken>0, the agent's actions are ignored for 'broken' steps # number of time the agent had to stop, since the last time it broke down - broken_data = attrib( - default=Factory(lambda: dict({'broken': 0, 'number_of_halts': 0}))) + malfunction_data = attrib( + default=Factory(lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0}))) @classmethod def from_lists(cls, positions, directions, targets, speeds=None): @@ -42,8 +42,9 @@ class EnvAgentStatic(object): # some as broken? 
broken_datas = [] for i in range(len(positions)): - broken_datas.append({'broken': 0, - 'number_of_halts': 0}) + broken_datas.append({'malfunction': 0, + 'malfunction_rate': 0, + 'next_malfunction': 0}) return list(starmap(EnvAgentStatic, zip(positions, directions, @@ -64,7 +65,7 @@ class EnvAgentStatic(object): if type(lTarget) is np.ndarray: lTarget = lTarget.tolist() - return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.broken_data] + return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.malfunction_data] @attrs @@ -82,7 +83,7 @@ class EnvAgent(EnvAgentStatic): def to_list(self): return [ self.position, self.direction, self.target, self.handle, - self.old_direction, self.old_position, self.moving, self.speed_data, self.broken_data] + self.old_direction, self.old_position, self.moving, self.speed_data, self.malfunction_data] @classmethod def from_static(cls, oStatic): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 0f03870b..cfd8dad4 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -75,16 +75,13 @@ class RailEnv(Environment): - stop_penalty = 0 # penalty for stopping a moving agent - start_penalty = 0 # penalty for starting a stopped agent - Stochastic breaking of trains: - Trains in RailEnv can break down if they are halted too often (either by their own choice or because an invalid + Stochastic malfunctioning of trains: + Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid action or cell is selected. - Every time an agent stops, an agent has a certain probability of breaking. The probability is the product of 2 - distributions: the first distribution selects the average number of trains that will break during an episode - (e.g., max(1, 10% of the trains) ). The second distribution is a Poisson distribution with mean set to the average - number of stops at which a train breaks. 
- If a random number in [0,1] is lower than the product of the 2 distributions, the train breaks. - A broken train samples a random number of steps it will stay broken for, during which all its actions are ignored. + Every time an agent stops, an agent has a certain probability of malfunctioning. Malfunctions of trains follow a + Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes to keep + complexity manageable. TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init(). For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility. @@ -160,20 +157,20 @@ class RailEnv(Environment): self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? + # Stochastic train malfunctioning parameters + self.proportion_malfunctioning_trains = 0.1 # percentage of malfunctioning trains + self.mean_malfunction_rate = 5 # Average malfunction in number of stops + + # Uniform distribution parameters for malfunction duration + self.min_number_of_steps_broken = 4 + self.max_number_of_steps_broken = 10 + + # Reset environment self.reset() self.num_resets = 0 # yes, set it to zero again! 
self.valid_positions = None - # Stochastic train breaking parameters - self.min_average_broken_trains = 1 - self.average_proportion_of_broken_trains = 0.1 # ~10% of the trains can be expected to break down in an episode - self.mean_number_halts_to_break = 3 - - # Uniform distribution - self.min_number_of_steps_broken = 4 - self.max_number_of_steps_broken = 8 - # no more agent_handles def get_agent_handles(self): return range(self.get_num_agents()) @@ -218,9 +215,12 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): agent = self.agents[i_agent] + + # A proportion of agents in the environment will receive a positive malfunction rate + if np.random.random() >= self.proportion_malfunctioning_trains: + agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate agent.speed_data['position_fraction'] = 0.0 - agent.broken_data['broken'] = 0 - agent.broken_data['number_of_halts'] = 0 + agent.malfunction_data['malfunction'] = 0 self.num_resets += 1 self._elapsed_steps = 0 @@ -236,24 +236,26 @@ class RailEnv(Environment): return self._get_observations() def _agent_stopped(self, i_agent): - self.agents[i_agent].broken_data['number_of_halts'] += 1 + # Make sure agent is stopped + self.agents[i_agent].moving = False + + # Only agents that have a positive rate for malfunctions are considered + if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0: - def poisson_pdf(x, mean): return np.power(mean, x) * np.exp(-mean) / np.prod(range(2, x)) + # Decrease counter for next event + self.agents[i_agent].malfunction_data['next_malfunction'] -= 1 - p1_prob_train_i_breaks = max(self.min_average_broken_trains / len(self.agents), - self.average_proportion_of_broken_trains) - p2_prob_train_breaks_at_halt_j = poisson_pdf(self.agents[i_agent].broken_data['number_of_halts'], - self.mean_number_halts_to_break) + # If counter has come to zero, set next malfunction time and duration of current malfunction - s1 = np.random.random() - s2 = 
np.random.random() + if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0: + # Next malfunction in number of stops + self.agents[i_agent].malfunction_data['next_malfunction'] = int(np.random.exponential( + scale=self.agents[i_agent].malfunction_data['malfunction_rate'])) - if s1 * s2 <= p1_prob_train_i_breaks * p2_prob_train_breaks_at_halt_j: - # +1 because the counter is decreased at the beginning of step() - num_broken_steps = np.random.randint(self.min_number_of_steps_broken, self.max_number_of_steps_broken+1) + 1 - self.agents[i_agent].broken_data['broken'] = num_broken_steps - self.agents[i_agent].broken_data['number_of_halts'] = 0 + # Duration of current malfunction + num_broken_steps = np.random.randint(self.min_number_of_steps_broken, + self.max_number_of_steps_broken + 1) + 1 + self.agents[i_agent].malfunction_data['malfunction'] = num_broken_steps def step(self, action_dict_): self._elapsed_steps += 1 @@ -284,8 +286,8 @@ class RailEnv(Environment): agent.old_direction = agent.direction agent.old_position = agent.position - if agent.broken_data['broken'] > 0: - agent.broken_data['broken'] -= 1 + if agent.malfunction_data['malfunction'] > 0: + agent.malfunction_data['malfunction'] -= 1 if self.dones[i_agent]: # this agent has already completed... continue @@ -295,7 +297,7 @@ class RailEnv(Environment): action_dict[i_agent] = RailEnvActions.DO_NOTHING # The train is broken - if agent.broken_data['broken'] > 0: + if agent.malfunction_data['malfunction'] > 0: action_dict[i_agent] = RailEnvActions.DO_NOTHING if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): -- GitLab