# Adrian Egli performance fix (the fast methods bring more than 50%).
# Lightweight pure-Python helpers for scalars and short indexable sequences.


def fast_isclose(a, b, rtol):
    """Return True when ``a`` lies strictly within ``rtol`` of ``b``.

    Despite its name, ``rtol`` is applied as an absolute tolerance: the check
    is ``b - rtol < a < b + rtol``.

    :param a: value under test
    :param b: reference value
    :param rtol: half-width of the accepted interval around ``b``
    :return: True if ``a`` is inside the open interval, False otherwise
    """
    # Both bounds must hold. Combining them with ``or`` would accept every
    # value below ``b + rtol``, no matter how far it is from ``b``.
    return (b - rtol) < a < (b + rtol)


def fast_clip(position: Tuple[int, int], min_value: Tuple[int, int],
              max_value: Tuple[int, int]) -> Tuple[int, int]:
    """Clamp a 2D position component-wise into ``[min_value, max_value]``.

    Any indexable pair (tuple or list) is accepted for each argument.

    :param position: pair to clamp
    :param min_value: lower bound per component
    :param max_value: upper bound per component
    :return: a new tuple holding the clamped components
    """
    clipped_first = max(min_value[0], min(position[0], max_value[0]))
    clipped_second = max(min_value[1], min(position[1], max_value[1]))
    return clipped_first, clipped_second


def fast_argmax(possible_transitions: Tuple[int, int, int, int]) -> int:
    """Return the index of the first entry equal to 1, defaulting to 3.

    Only the first three entries are inspected; when none of them equals 1
    the result is 3, whatever the fourth entry holds.

    :param possible_transitions: indexable with at least three entries
    :return: 0, 1, 2 or 3
    """
    for index in range(3):
        if possible_transitions[index] == 1:
            return index
    return 3


def fast_position_equal(pos_1: Tuple[int, int], pos_2: Tuple[int, int]) -> bool:
    """Return True when both components of the two positions are equal.

    :param pos_1: first position (indexable pair)
    :param pos_2: second position (indexable pair)
    """
    return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]


def fast_count_nonzero(possible_transitions: Tuple[int, int, int, int]):
    """Return the sum of the four entries of ``possible_transitions``.

    The sum equals the number of non-zero entries only when every entry is
    0 or 1 -- presumably the case for transition flags; TODO confirm against
    the caller.

    :param possible_transitions: indexable with four numeric entries
    """
    return (possible_transitions[0] + possible_transitions[1]
            + possible_transitions[2] + possible_transitions[3])
RailEnvActions(IntEnum): DO_NOTHING = 0 # implies change of direction in a dead-end! MOVE_LEFT = 1 @@ -289,8 +317,8 @@ class RailEnv(Environment): False: Agent cannot provide an action """ return (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and self.my_isclose(agent.speed_data['position_fraction'], 0.0, - rtol=1e-03))) + agent.status == RailAgentStatus.ACTIVE and fast_isclose(agent.speed_data['position_fraction'], 0.0, + rtol=1e-03))) def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None) -> Tuple[Dict, Dict]: @@ -441,7 +469,6 @@ class RailEnv(Environment): """ - #malfunction: Malfunction = self.malfunction_generator(agent, self.np_random) if "generate" in dir(self.malfunction_generator): malfunction: mal_gen.Malfunction = self.malfunction_generator.generate(agent, self.np_random) else: @@ -567,8 +594,6 @@ class RailEnv(Environment): return self._get_observations(), self.rewards_dict, self.dones, info_dict - - def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None): """ Performs a step and step, start and stop penalty on a single agent in the following sub steps: @@ -615,7 +640,7 @@ class RailEnv(Environment): # Is the agent at the beginning of the cell? Then, it can take an action. # As long as the agent is malfunctioning or stopped at the beginning of the cell, # different actions may be taken! - if self.my_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): + if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): # No action has been supplied for this agent -> set DO_NOTHING as default if action is None: action = RailEnvActions.DO_NOTHING @@ -675,8 +700,8 @@ class RailEnv(Environment): # transition_action_on_cellexit if the cell is free. 
if agent.moving: agent.speed_data['position_fraction'] += agent.speed_data['speed'] - if agent.speed_data['position_fraction'] > 1.0 or self.my_isclose(agent.speed_data['position_fraction'], 1.0, - rtol=1e-03): + if agent.speed_data['position_fraction'] > 1.0 or fast_isclose(agent.speed_data['position_fraction'], 1.0, + rtol=1e-03): # Perform stored action to transition to the next cell as soon as cell is free # Notice that we've already checked new_cell_valid and transition valid when we stored the action, # so we only have to check cell_free now! @@ -694,7 +719,6 @@ class RailEnv(Environment): agent.direction = new_direction agent.speed_data['position_fraction'] = 0.0 - # has the agent reached its target? if np.equal(agent.position, agent.target).all(): agent.status = RailAgentStatus.DONE @@ -926,7 +950,6 @@ class RailEnv(Environment): self.agent_positions[agent.position] = -1 if self.remove_agents_at_target: agent.position = None - agent.old_position = None agent.status = RailAgentStatus.DONE_REMOVED def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): @@ -953,11 +976,11 @@ class RailEnv(Environment): new_position = get_new_position(agent.position, new_direction) new_cell_valid = ( - np.array_equal( # Check the new position is still in the grid - new_position, - np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) - and # check the new position has some transitions (ie is not an empty cell) - self.rail.get_full_transitions(*new_position) > 0) + fast_position_equal( # Check the new position is still in the grid + new_position, + fast_clip(new_position, [0, 0], [self.height - 1, self.width - 1])) + and # check the new position has some transitions (ie is not an empty cell) + self.rail.get_full_transitions(*new_position) > 0) # If transition validity hasn't been checked yet. 
if transition_valid is None: @@ -1027,7 +1050,7 @@ class RailEnv(Environment): """ transition_valid = None possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) - num_transitions = np.count_nonzero(possible_transitions) + num_transitions = fast_count_nonzero(possible_transitions) new_direction = agent.direction if action == RailEnvActions.MOVE_LEFT: @@ -1046,7 +1069,7 @@ class RailEnv(Environment): # - dead-end, straight line or curved line; # new_direction will be the only valid transition # - take only available transition - new_direction = np.argmax(possible_transitions) + new_direction = fast_argmax(possible_transitions) transition_valid = True return new_direction, transition_valid @@ -1105,6 +1128,3 @@ class RailEnv(Environment): def save(self, filename): print("deprecated call to env.save() - pls call RailEnvPersister.save()") persistence.RailEnvPersister.save(self, filename) - - def my_isclose(self, x, y, rtol=1.e-5, atol=1.e-8): - return abs(x - y) <= atol + rtol * abs(y)