diff --git a/flatland/envs/agent_chains.py b/flatland/envs/agent_chains.py index 7706f6d41cbd59ce333682f2420a306608ba4fe0..d13b734b1c44805983ba2ff5b9616321867788ec 100644 --- a/flatland/envs/agent_chains.py +++ b/flatland/envs/agent_chains.py @@ -2,7 +2,6 @@ import networkx as nx import numpy as np -import matplotlib.pyplot as plt from typing import List, Tuple import graphviz as gv @@ -372,18 +371,11 @@ def test_agent_following(): for v in lvCells ] dPos = dict(zip(lvCells, lvCells)) - #plt.ion() nx.draw(omc.G, with_labels=True, arrowsize=20, pos=dPos, node_color = lColours) - - #plt.pause(20) - #plt.show() - - - def main(): test_agent_following() diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 11c471ff6fb50466eaa459310d57d73a3b5aac8a..d2f8f4d245a316b4e666f00156cdbee250a62ef9 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -4,13 +4,10 @@ Definition of the RailEnv environment. import random # TODO: _ this is a global method --> utils or remove later from enum import IntEnum -from typing import List, NamedTuple, Optional, Dict +from typing import List, NamedTuple, Optional, Dict, Tuple -import msgpack -import msgpack_numpy as m import numpy as np -from gym.utils import seeding -from msgpack import Packer + from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder @@ -28,21 +25,50 @@ from flatland.envs import schedule_generators as sched_gen from flatland.envs import persistence from flatland.envs import agent_chains as ac +from flatland.envs.observations import GlobalObsForRailEnv +from gym.utils import seeding + # Direct import of objects / classes does not work with circular imports. 
# Adrian Egli performance fix (the fast methods bring more than 50%).
# Pure-Python stand-ins for numpy calls that RailEnv makes on scalars and on
# 2-/4-element sequences.
def fast_isclose(a, b, rtol):
    """Return True when ``a`` lies strictly within ``rtol`` of ``b``.

    ``rtol`` is applied as an absolute tolerance on both sides of ``b``
    (the parameter name is kept for call-site compatibility).

    The previous form ``(a < b + rtol) or (a < b - rtol)`` reduced to
    ``a < b + rtol`` and therefore reported every value below ``b`` as close.
    """
    return (b - rtol) < a < (b + rtol)


def fast_clip(position: Tuple[int, int], min_value: Tuple[int, int],
              max_value: Tuple[int, int]) -> Tuple[int, int]:
    """Clamp a (row, col) position component-wise into [min_value, max_value].

    Any two-element indexable works for the arguments; the result is always
    a tuple.
    """
    return (
        max(min_value[0], min(position[0], max_value[0])),
        max(min_value[1], min(position[1], max_value[1])),
    )


def fast_argmax(possible_transitions: Tuple[int, int, int, int]) -> int:
    """Return the index of the first entry equal to 1, defaulting to 3.

    Only indices 0..2 are inspected; when none of them is 1 the result is 3.
    NOTE: for an all-zero input this yields 3, whereas ``np.argmax`` would
    yield 0.
    """
    if possible_transitions[0] == 1:
        return 0
    if possible_transitions[1] == 1:
        return 1
    if possible_transitions[2] == 1:
        return 2
    return 3


def fast_position_equal(pos_1: Tuple[int, int], pos_2: Tuple[int, int]) -> bool:
    """Return True when both positions have equal first and second components."""
    return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]


def fast_count_nonzero(possible_transitions: Tuple[int, int, int, int]) -> int:
    """Return the sum of the four entries.

    This equals the number of non-zero entries only when every entry is
    0 or 1 -- presumably the case for transition tuples; verify against caller.
    """
    return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
MOVE_LEFT = 1 @@ -298,11 +324,11 @@ class RailEnv(Environment): False: Agent cannot provide an action """ return (agent.status == RailAgentStatus.READY_TO_DEPART or ( - agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0, - rtol=1e-03))) + agent.status == RailAgentStatus.ACTIVE and fast_isclose(agent.speed_data['position_fraction'], 0.0, + rtol=1e-03))) def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, - random_seed: bool = None) -> (Dict, Dict): + random_seed: bool = None) -> Tuple[Dict, Dict]: """ reset(regenerate_rail, regenerate_schedule, activate_agents, random_seed) @@ -604,7 +630,7 @@ class RailEnv(Environment): RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD] if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, - RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position): + RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position): agent.status = RailAgentStatus.ACTIVE self._set_agent_to_initial_position(agent, agent.initial_position) self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] @@ -626,7 +652,7 @@ class RailEnv(Environment): # Is the agent at the beginning of the cell? Then, it can take an action. # As long as the agent is malfunctioning or stopped at the beginning of the cell, # different actions may be taken! - if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): + if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): # No action has been supplied for this agent -> set DO_NOTHING as default if action is None: action = RailEnvActions.DO_NOTHING @@ -686,8 +712,8 @@ class RailEnv(Environment): # transition_action_on_cellexit if the cell is free. 
if agent.moving: agent.speed_data['position_fraction'] += agent.speed_data['speed'] - if agent.speed_data['position_fraction'] > 1.0 or np.isclose(agent.speed_data['position_fraction'], 1.0, - rtol=1e-03): + if agent.speed_data['position_fraction'] > 1.0 or fast_isclose(agent.speed_data['position_fraction'], 1.0, + rtol=1e-03): # Perform stored action to transition to the next cell as soon as cell is free # Notice that we've already checked new_cell_valid and transition valid when we stored the action, # so we only have to check cell_free now! @@ -695,7 +721,7 @@ class RailEnv(Environment): # Traditional check that next cell is free # cell and transition validity was checked when we stored transition_action_on_cellexit! cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( - agent.speed_data['transition_action_on_cellexit'], agent) + agent.speed_data['transition_action_on_cellexit'], agent) # N.B. validity of new_cell and transition should have been verified before the action was stored! 
assert new_cell_valid @@ -845,7 +871,6 @@ class RailEnv(Environment): trans_block = sbTrans[agent.direction*4 : agent.direction * 4 + 4] if (trans_block == "0000"): print (i_agent, agent.position, agent.direction, sbTrans, trans_block) - # debugpy.breakpoint() # if agent cannot enter env, then we should have move=False @@ -862,20 +887,16 @@ class RailEnv(Environment): if not all([transition_valid, new_cell_valid]): print(f"ERRROR: step_agent2 invalid transition ag {i_agent} dir {new_direction} pos {agent.position} next {rc_next}") - # debugpy.breakpoint() if new_position != rc_next: print(f"ERROR: agent {i_agent} new_pos {new_position} != rc_next {rc_next} " + - f"pos {agent.position} dir {agent.direction} new_dir {new_direction}" + - f"stored action: {agent.speed_data['transition_action_on_cellexit']}") - # debugpy.breakpoint() - + f"pos {agent.position} dir {agent.direction} new_dir {new_direction}" + + f"stored action: {agent.speed_data['transition_action_on_cellexit']}") sbTrans = format(self.rail.grid[agent.position], "016b") trans_block = sbTrans[agent.direction*4 : agent.direction * 4 + 4] if (trans_block == "0000"): print ("ERROR: ", i_agent, agent.position, agent.direction, sbTrans, trans_block) - # debugpy.breakpoint() agent.position = rc_next agent.direction = new_direction @@ -937,6 +958,7 @@ class RailEnv(Environment): self.agent_positions[agent.position] = -1 if self.remove_agents_at_target: agent.position = None + # setting old_position to None here stops the DONE agents from appearing in the rendered image agent.old_position = None agent.status = RailAgentStatus.DONE_REMOVED @@ -964,9 +986,9 @@ class RailEnv(Environment): new_position = get_new_position(agent.position, new_direction) new_cell_valid = ( - np.array_equal( # Check the new position is still in the grid + fast_position_equal( # Check the new position is still in the grid new_position, - np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) + fast_clip(new_position, [0, 0], 
def getAgentState(env):
    """Return a snapshot of every agent's position and direction.

    Args:
        env: object exposing an ``agents`` sequence whose items have
            ``position`` (a (row, col) pair or None) and ``direction``.

    Returns:
        dict mapping the agent's index in ``env.agents`` to
        ``(row, col, direction)``. Agents without a position are reported as
        ``(None, None, direction)`` instead of raising on the unpack.
    """
    dAgState = {}
    for iAg, ag in enumerate(env.agents):
        # position can be None (e.g. RailEnv clears it when an agent is
        # removed at its target); unpacking None would raise TypeError.
        row_col = ag.position if ag.position is not None else (None, None)
        dAgState[iAg] = (*row_col, ag.direction)
    return dAgState