diff --git a/examples/training_example.py b/examples/training_example.py
index 17484ad7422e327ad5b50200f6e1726f19a43594..6910461327c778ff52824165032641ece019cf7a 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -75,13 +75,11 @@ for trials in range(1, n_trials + 1):
     score = 0
 
     # Run episode
-    mean_malfunction_interval = []
     for step in range(100):
         # Chose an action for each agent in the environment
         for a in range(env.get_num_agents()):
             action = agent.act(obs[a])
             action_dict.update({a: action})
-
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether their are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
@@ -95,5 +93,4 @@ for trials in range(1, n_trials + 1):
         obs = next_obs.copy()
         if done['__all__']:
             break
-    print(np.mean(mean_malfunction_interval))
     print('Episode Nr. {}\t Score = {}'.format(trials, score))
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 27a7a380a23d0a1bd6352e06b00e8e63deafb71a..4c4070088c59499f885c16db68976c163ec91001 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -26,7 +26,8 @@ class EnvAgentStatic(object):
     # if broken>0, the agent's actions are ignored for 'broken' steps
     # number of time the agent had to stop, since the last time it broke down
     malfunction_data = attrib(
-        default=Factory(lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0})))
+        default=Factory(
+            lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0})))
 
     @classmethod
     def from_lists(cls, positions, directions, targets, speeds=None):
@@ -40,18 +41,19 @@ class EnvAgentStatic(object):
                                 'transition_action_on_cellexit': 0})
 
         # TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set
        # some as broken?
-        broken_datas = []
+        malfunction_datas = []
         for i in range(len(positions)):
-            broken_datas.append({'malfunction': 0,
+            malfunction_datas.append({'malfunction': 0,
                                  'malfunction_rate': 0,
-                                 'next_malfunction': 0})
+                                 'next_malfunction': 0,
+                                 'nr_malfunctions': 0})
 
         return list(starmap(EnvAgentStatic, zip(positions,
                                                 directions,
                                                 targets,
                                                 [False] * len(positions),
                                                 speed_datas,
-                                                broken_datas)))
+                                                malfunction_datas)))
 
     def to_list(self):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cfd8dad4e80f1d4ce3783a594de3c890191ce0f6..d62c689aa9f6c788b54474b884d4101d98fb4ff0 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -94,7 +94,8 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
-                 max_episode_steps=None
+                 max_episode_steps=None,
+                 stochastic_data=None
                  ):
         """
         Environment init.
@@ -158,12 +159,26 @@ class RailEnv(Environment):
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
 
         # Stochastic train malfunctioning parameters
-        self.proportion_malfunctioning_trains = 0.1  # percentage of malfunctioning trains
-        self.mean_malfunction_rate = 5  # Average malfunction in number of stops
+        if stochastic_data is not None:
+            prop_malfunction = stochastic_data['prop_malfunction']
+            mean_malfunction_rate = stochastic_data['malfunction_rate']
+            malfunction_min_duration = stochastic_data['min_duration']
+            malfunction_max_duration = stochastic_data['max_duration']
+        else:
+            prop_malfunction = 0.
+            mean_malfunction_rate = 0.
+            malfunction_min_duration = 0.
+            malfunction_max_duration = 0.
+
+        # percentage of malfunctioning trains
+        self.proportion_malfunctioning_trains = prop_malfunction
+
+        # Mean malfunction in number of stops
+        self.mean_malfunction_rate = mean_malfunction_rate
 
         # Uniform distribution parameters for malfunction duration
-        self.min_number_of_steps_broken = 4
-        self.max_number_of_steps_broken = 10
+        self.min_number_of_steps_broken = malfunction_min_duration
+        self.max_number_of_steps_broken = malfunction_max_duration
 
         # Rest environment
         self.reset()
@@ -217,8 +232,9 @@ class RailEnv(Environment):
             agent = self.agents[i_agent]
 
             # A proportion of agent in the environment will receive a positive malfunction rate
-            if np.random.random() >= self.proportion_malfunctioning_trains:
+            if np.random.random() < self.proportion_malfunctioning_trains:
                 agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate
+
             agent.speed_data['position_fraction'] = 0.0
             agent.malfunction_data['malfunction'] = 0
 
@@ -236,21 +252,23 @@ class RailEnv(Environment):
 
         return self._get_observations()
 
     def _agent_stopped(self, i_agent):
-        # Make sure agent is stopped
-        self.agents[i_agent].moving = False
+        # Decrease counter for next event
+        self.agents[i_agent].malfunction_data['next_malfunction'] -= 1
 
         # Only agents that have a positive rate for malfunctions are considered
-        if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0:
-
-            # Decrease counter for next event
-            self.agents[i_agent].malfunction_data['next_malfunction'] -= 1
-
-            # If counter has come to zero, set next malfunction time and duration of current malfunction
+        if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0 >= self.agents[i_agent].malfunction_data[
+            'malfunction']:
+            # If counter has come to zero --> Agent has malfunction
+            # set next malfunction time and duration of current malfunction
             if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0:
+                # Increase number of malfunctions
+                self.agents[i_agent].malfunction_data['nr_malfunctions'] += 1
+
                 # Next malfunction in number of stops
-                self.agents[i_agent].malfunction_data['next_malfunction'] = int(np.random.exponential(
-                    scale=self.agents[i_agent].malfunction_data['malfunction_rate']))
+                next_breakdown = int(
+                    np.random.exponential(scale=self.agents[i_agent].malfunction_data['malfunction_rate']))
+                self.agents[i_agent].malfunction_data['next_malfunction'] = next_breakdown
 
                 # Duration of current malfunction
                 num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
@@ -286,9 +304,6 @@ class RailEnv(Environment):
                 agent.old_direction = agent.direction
                 agent.old_position = agent.position
 
-            if agent.malfunction_data['malfunction'] > 0:
-                agent.malfunction_data['malfunction'] -= 1
-
             if self.dones[i_agent]:  # this agent has already completed...
                 continue
 
@@ -298,8 +313,16 @@ class RailEnv(Environment):
 
             # The train is broken
             if agent.malfunction_data['malfunction'] > 0:
+                agent.malfunction_data['malfunction'] -= 1
+
+                # Broken agents are stopped
+                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
+                self.agents[i_agent].moving = False
                 action_dict[i_agent] = RailEnvActions.DO_NOTHING
 
+                # Nothing left to do with broken agent
+                continue
+
             if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
                 print('ERROR: illegal action=', action_dict[i_agent],
                       'for agent with index=', i_agent,
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
new file mode 100644
index 0000000000000000000000000000000000000000..91c551db60f9d71d7aa0774ea8b6aaf42af3e35b
--- /dev/null
+++ b/tests/test_flatland_malfunction.py
@@ -0,0 +1,88 @@
+import numpy as np
+
+from flatland.envs.generators import complex_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.rail_env import RailEnv
+
+
+class SingleAgentNavigationObs(TreeObsForRailEnv):
+    """
+    We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
+    the minimum distances from each grid node to each agent's target.
+
+    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
+    for each agent (Left, Forward, Right) lead to the shortest path to its target.
+    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
+    will be [1, 0, 0].
+    """
+
+    def __init__(self):
+        super().__init__(max_depth=0)
+        self.observation_space = [3]
+
+    def reset(self):
+        # Recompute the distance map, if the environment has changed.
+        super().reset()
+
+    def get(self, handle):
+        agent = self.env.agents[handle]
+
+        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
+        num_transitions = np.count_nonzero(possible_transitions)
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right], relative to the current orientation
+        # If only one transition is possible, the forward branch is aligned with it.
+        if num_transitions == 1:
+            observation = [0, 1, 0]
+        else:
+            min_distances = []
+            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[direction]:
+                    new_position = self._new_position(agent.position, direction)
+                    min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
+                else:
+                    min_distances.append(np.inf)
+
+            observation = [0, 0, 0]
+            observation[np.argmin(min_distances)] = 1
+
+        return observation
+
+
+def test_malfunction_process():
+    stochastic_data = {'prop_malfunction': 1.,
+                       'malfunction_rate': 5,
+                       'min_duration': 3,
+                       'max_duration': 10}
+    np.random.seed(5)
+
+    env = RailEnv(width=14,
+                  height=14,
+                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
+                                                        seed=0),
+                  number_of_agents=2,
+                  obs_builder_object=SingleAgentNavigationObs(),
+                  stochastic_data=stochastic_data)
+
+    obs = env.reset()
+    agent_halts = 0
+    for step in range(100):
+        actions = {}
+        for i in range(len(obs)):
+            actions[i] = np.argmax(obs[i]) + 1
+
+        if step % 5 == 0:
+            actions[0] = 4
+            agent_halts += 1
+
+        obs, all_rewards, done, _ = env.step(actions)
+
+        if done["__all__"]:
+            break
+
+    # Check that the agent breaks twice
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 2
+
+    # Check that 7 stops were performed
+    assert agent_halts == 7
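
As a standalone illustration of the malfunction model this patch introduces: every stop decrements next_malfunction; when it reaches zero while the agent is healthy, nr_malfunctions is incremented, the gap to the next breakdown is drawn from an exponential distribution with scale malfunction_rate, and the breakdown lasts a uniformly drawn number of steps between the configured min and max durations. The sketch below replays only that bookkeeping with plain numpy. It is a hypothetical helper, not flatland code: it folds the per-step duration countdown from RailEnv.step into the same loop as the counters from _agent_stopped, and the inclusive upper bound on the duration draw is an assumption, since the np.random.randint call is truncated in the hunk above.

import numpy as np


def simulate_stops(n_stops, malfunction_rate=5, min_duration=3, max_duration=10, seed=5):
    # Hypothetical helper, not part of the flatland API.
    rng = np.random.RandomState(seed)
    data = {'malfunction': 0, 'malfunction_rate': malfunction_rate,
            'next_malfunction': 0, 'nr_malfunctions': 0}
    for _ in range(n_stops):
        # Decrease counter for the next event (mirrors _agent_stopped)
        data['next_malfunction'] -= 1
        if data['malfunction'] > 0:
            # Currently broken: sit out one step (mirrors the step() hunk)
            data['malfunction'] -= 1
        elif data['malfunction_rate'] > 0 and data['next_malfunction'] <= 0:
            # Counter reached zero: record a breakdown and draw the next one
            data['nr_malfunctions'] += 1
            data['next_malfunction'] = int(rng.exponential(scale=data['malfunction_rate']))
            # Duration drawn uniformly; inclusive upper bound is an assumption
            data['malfunction'] = rng.randint(min_duration, max_duration + 1)
    return data['nr_malfunctions']


print(simulate_stops(100))

Running this gives a feel for how malfunction_rate trades off against episode length: with scale 5 and durations of roughly 3 to 10 steps, an always-stopping agent completes one breakdown cycle about every 8 to 12 stops, which is why the short test episode above only accumulates two malfunctions.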