diff --git a/flatland/core/grid/grid4_utils.py b/flatland/core/grid/grid4_utils.py index 1475589ef62863fcb98a2d238b4b4c5dbe078b3c..624c14daefa06c25a576b554f349cf9e4e974805 100644 --- a/flatland/core/grid/grid4_utils.py +++ b/flatland/core/grid/grid4_utils.py @@ -25,17 +25,9 @@ def get_direction(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum: def mirror(dir): return (dir + 2) % 4 - +MOVEMENT_ARRAY = [(-1, 0), (0, 1), (1, 0), (0, -1)] def get_new_position(position, movement): - """ Utility function that converts a compass movement over a 2D grid to new positions (r, c). """ - if movement == Grid4TransitionsEnum.NORTH: - return (position[0] - 1, position[1]) - elif movement == Grid4TransitionsEnum.EAST: - return (position[0], position[1] + 1) - elif movement == Grid4TransitionsEnum.SOUTH: - return (position[0] + 1, position[1]) - elif movement == Grid4TransitionsEnum.WEST: - return (position[0], position[1] - 1) + return (position[0] + MOVEMENT_ARRAY[movement][0], position[1] + MOVEMENT_ARRAY[movement][1]) def direction_to_point(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum: diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 30ff30eb80323afa08bb73b60c781dbd2cd5d153..a7212c4c7414950a0c0560b2d73195a15b7bdad3 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -35,8 +35,6 @@ from flatland.envs import persistence from flatland.envs.observations import GlobalObsForRailEnv - - import pickle m.patch() @@ -130,11 +128,11 @@ class RailEnv(Environment): def __init__(self, width, height, - rail_generator = None, - schedule_generator = None, # : sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(), + rail_generator=None, + schedule_generator=None, # : sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(), number_of_agents=1, obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), - malfunction_generator_and_process_data=None, #mal_gen.no_malfunction_generator(), + malfunction_generator_and_process_data=None, # mal_gen.no_malfunction_generator(), remove_agents_at_target=True, random_seed=1, record_steps=False @@ -179,11 +177,11 @@ class RailEnv(Environment): if malfunction_generator_and_process_data is None: malfunction_generator_and_process_data = mal_gen.no_malfunction_generator() self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data - #self.rail_generator: RailGenerator = rail_generator + # self.rail_generator: RailGenerator = rail_generator if rail_generator is None: rail_generator = rail_gen.random_rail_generator() self.rail_generator = rail_generator - #self.schedule_generator: ScheduleGenerator = schedule_generator + # self.schedule_generator: ScheduleGenerator = schedule_generator if schedule_generator is None: schedule_generator = sched_gen.random_schedule_generator() self.schedule_generator = schedule_generator @@ -230,8 +228,8 @@ class RailEnv(Environment): # save episode timesteps ie agent positions, orientations. (not yet actions / observations) self.record_steps = record_steps # whether to save timesteps # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps] - self.cur_episode = [] - self.list_actions = [] # save actions in here + self.cur_episode = [] + self.list_actions = [] # save actions in here def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) @@ -264,8 +262,6 @@ class RailEnv(Environment): agent.reset() self.active_agents = [i for i in range(len(self.agents))] - - def action_required(self, agent): """ Check if an agent needs to provide an action @@ -281,8 +277,8 @@ class RailEnv(Environment): False: Agent cannot provide an action """ return (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, - rtol=1e-03))) + agent.status == RailAgentStatus.ACTIVE and self.my_isclose(agent.speed_data['position_fraction'], 0.0, + rtol=1e-03))) def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None) -> (Dict, Dict): @@ -329,8 +325,6 @@ class RailEnv(Environment): if optionals and 'distance_map' in optionals: self.distance_map.set(optionals['distance_map']) - - if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0: agents_hints = None if optionals and 'agents_hints' in optionals: @@ -548,7 +542,7 @@ class RailEnv(Environment): # Is the agent at the beginning of the cell? Then, it can take an action. # As long as the agent is malfunctioning or stopped at the beginning of the cell, # different actions may be taken! - if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): + if self.my_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): # No action has been supplied for this agent -> set DO_NOTHING as default if action is None: action = RailEnvActions.DO_NOTHING @@ -608,8 +602,8 @@ class RailEnv(Environment): # transition_action_on_cellexit if the cell is free. if agent.moving: agent.speed_data['position_fraction'] += agent.speed_data['speed'] - if agent.speed_data['position_fraction'] > 1.0 or np.isclose(agent.speed_data['position_fraction'], 1.0, - rtol=1e-03): + if agent.speed_data['position_fraction'] > 1.0 or self.my_isclose(agent.speed_data['position_fraction'], 1.0, + rtol=1e-03): # Perform stored action to transition to the next cell as soon as cell is free # Notice that we've already checked new_cell_valid and transition valid when we stored the action, # so we only have to check cell_free now! @@ -740,7 +734,7 @@ class RailEnv(Environment): pos = (int(agent.position[0]), int(agent.position[1])) # print("pos:", pos, type(pos[0])) list_agents_state.append( - [*pos, int(agent.direction), agent.malfunction_data["malfunction"] ]) + [*pos, int(agent.direction), agent.malfunction_data["malfunction"]]) self.cur_episode.append(list_agents_state) self.list_actions.append(dActions) @@ -809,7 +803,7 @@ class RailEnv(Environment): ------ Dict object """ - #print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}") + # print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}") self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict @@ -828,8 +822,6 @@ class RailEnv(Environment): """ return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) - - def _exp_distirbution_synced(self, rate: float) -> float: """ Generates sample from exponential distribution @@ -859,4 +851,5 @@ class RailEnv(Environment): print("deprecated call to env.save() - pls call RailEnvPersister.save()") persistence.RailEnvPersister.save(self, filename) - + def my_isclose(self, x, y, rtol=1.e-5, atol=1.e-8): + return abs(x - y) <= atol + rtol * abs(y)