From 4fae3ccb8ed4aa54e1a3bf4cf1ef697672263734 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Thu, 5 Sep 2019 00:47:29 +0200 Subject: [PATCH] #167 bugfix action_on_cellexit --- flatland/envs/rail_env.py | 77 ++++++++++++++++---------------- flatland/envs/rail_generators.py | 4 +- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 8b4f43fe..cc85604e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -4,6 +4,7 @@ Definition of the RailEnv environment. # TODO: _ this is a global method --> utils or remove later import warnings from enum import IntEnum +from typing import List import msgpack import msgpack_numpy as m @@ -165,8 +166,8 @@ class RailEnv(Environment): self.dev_obs_dict = {} self.dev_pred_dict = {} - self.agents = [None] * number_of_agents # live agents - self.agents_static = [None] * number_of_agents # static agent information + self.agents: List[EnvAgent] = [None] * number_of_agents # live agents + self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents # static agent information self.num_resets = 0 self.action_space = [1] @@ -239,17 +240,17 @@ class RailEnv(Environment): self.height, self.width = self.rail.grid.shape for r in range(self.height): for c in range(self.width): - rcPos = (r, c) - check = self.rail.cell_neighbours_valid(rcPos, True) + rc_pos = (r, c) + check = self.rail.cell_neighbours_valid(rc_pos, True) if not check: - warnings.warn("Invalid grid at {} -> {}".format(rcPos, check)) + warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check)) if replace_agents: agents_hints = None if optionals and 'agents_hints' in optionals: agents_hints = optionals['agents_hints'] self.agents_static = EnvAgentStatic.from_lists( - *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints)) + *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints)) self.restart_agents() for i_agent in range(self.get_num_agents()): @@ -284,25 +285,24 @@ class RailEnv(Environment): agent.malfunction_data['next_malfunction'] -= 1 # Only agents that have a positive rate for malfunctions and are not currently broken are considered - if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']: - - # If counter has come to zero --> Agent has malfunction - # set next malfunction time and duration of current malfunction - if agent.malfunction_data['next_malfunction'] <= 0: - # Increase number of malfunctions - agent.malfunction_data['nr_malfunctions'] += 1 - - # Next malfunction in number of stops - next_breakdown = int( - np.random.exponential(scale=agent.malfunction_data['malfunction_rate'])) - agent.malfunction_data['next_malfunction'] = next_breakdown - - # Duration of current malfunction - num_broken_steps = np.random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + 1 - agent.malfunction_data['malfunction'] = num_broken_steps - - return True + # If counter has come to zero --> Agent has malfunction + # set next malfunction time and duration of current malfunction + if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \ + agent.malfunction_data['next_malfunction'] <= 0: + # Increase number of malfunctions + agent.malfunction_data['nr_malfunctions'] += 1 + + # Next malfunction in number of stops + next_breakdown = int( + np.random.exponential(scale=agent.malfunction_data['malfunction_rate'])) + agent.malfunction_data['next_malfunction'] = next_breakdown + + # Duration of current malfunction + num_broken_steps = np.random.randint(self.min_number_of_steps_broken, + self.max_number_of_steps_broken + 1) + 1 + agent.malfunction_data['malfunction'] = num_broken_steps + + return True return False def step(self, action_dict_): @@ -353,6 +353,20 @@ class RailEnv(Environment): # TODO refactor!!! # If the agent can make an action if agent.speed_data['position_fraction'] == 0.0: + if action == RailEnvActions.DO_NOTHING and agent.moving: + # Keep moving + action = RailEnvActions.MOVE_FORWARD + + if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.0: + # Only allow halting an agent on entering new cells. + agent.moving = False + self.rewards_dict[i_agent] += self.stop_penalty + + if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): + # Allow agent to start with any forward or direction action + agent.moving = True + self.rewards_dict[i_agent] += self.start_penalty + if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(action, agent) @@ -408,19 +422,6 @@ class RailEnv(Environment): # Nothing left to do with broken agent continue - if action == RailEnvActions.DO_NOTHING and agent.moving: - # Keep moving - action = RailEnvActions.MOVE_FORWARD - - if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.0: - # Only allow halting an agent on entering new cells. - agent.moving = False - self.rewards_dict[i_agent] += self.stop_penalty - - if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): - # Allow agent to start with any forward or direction action - agent.moving = True - self.rewards_dict[i_agent] += self.start_penalty # Now perform a movement. # If the agent is in an initial position within a new cell (agent.speed_data['position_fraction']<eps) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 8573c25c..9d55198b 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -1,6 +1,6 @@ """Rail generators (infrastructure manager, "Infrastrukturbetreiber").""" import warnings -from typing import Callable, Tuple, Any, Optional +from typing import Callable, Tuple, Optional, Dict import msgpack import numpy as np @@ -11,7 +11,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes -RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]] +RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] -- GitLab