diff --git a/changelog.md b/changelog.md index dfdb89fcf51a079dadf65582f241a8c56e233ee7..5b582d99a369e78c7ebf70fce47bb4a5152e1444 100644 --- a/changelog.md +++ b/changelog.md @@ -3,6 +3,9 @@ Changelog Changes since Flatland 2.0.0 -------------------------- +### Changes in `EnvAgent` +- class `EnvAgentStatic` was removed, so there is only class `EnvAgent` left which should simplify the handling of agents. The member `self.agents_static` of `RailEnv` was therefore also removed. Old Scence saved as pickle files cannot be loaded anymore. + ### Changes in malfunction behavior - agent attribute `next_malfunction`is not used anymore, it will be removed fully in future versions. - `break_agent()` function is introduced which induces malfunctions in agent according to poisson process diff --git a/docs/tutorials/01_gettingstarted.rst b/docs/tutorials/01_gettingstarted.rst index c818742144b6ccdc547e96855794a4ab40066394..c5d73178384109758cf0e8bb5c2f25802d8700ca 100644 --- a/docs/tutorials/01_gettingstarted.rst +++ b/docs/tutorials/01_gettingstarted.rst @@ -109,15 +109,12 @@ following code. Also, tree observation data is displayed by RenderTool by defaul for i in range(env.get_num_agents()): env.obs_builder.util_print_obs_subtree( tree=obs[i], - num_features_per_node=5 ) The complete code for this part of the Getting Started guide can be found in * `examples/simple_example_1.py <https://gitlab.aicrowd.com/flatland/flatland/blob/master/examples/simple_example_1.py>`_ * `examples/simple_example_2.py <https://gitlab.aicrowd.com/flatland/flatland/blob/master/examples/simple_example_2.py>`_ -* `examples/simple_example_3.py <https://gitlab.aicrowd.com/flatland/flatland/blob/master/examples/simple_example_3.py>`_ - Part 2 : Training a Simple an Agent on Flatland diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 2bb9677aab020560aeb28aad97edfa23efebe9bf..dd639997f2b759e86b0879a84e4ab91f7ffc824b 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,8 +1,7 @@ from enum import IntEnum from itertools import starmap -from typing import Tuple, Optional +from typing import Tuple, Optional, NamedTuple -import numpy as np from attr import attrs, attrib, Factory from flatland.core.grid.grid4 import Grid4TransitionsEnum @@ -16,14 +15,25 @@ class RailAgentStatus(IntEnum): DONE_REMOVED = 3 # removed from grid (position is None) -> prediction is None +Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), + ('initial_direction', Grid4TransitionsEnum), + ('direction', Grid4TransitionsEnum), + ('target', Tuple[int, int]), + ('moving', bool), + ('speed_data', dict), + ('malfunction_data', dict), + ('handle', int), + ('status', RailAgentStatus), + ('position', Tuple[int, int]), + ('old_direction', Grid4TransitionsEnum), + ('old_position', Tuple[int, int])]) + + @attrs -class EnvAgentStatic(object): - """ EnvAgentStatic - Stores initial position, direction and target. - This is like static data for the environment - it's where an agent starts, - rather than where it is at the moment. - The target should also be stored here. - """ +class EnvAgent: + initial_position = attrib(type=Tuple[int, int]) + initial_direction = attrib(type=Grid4TransitionsEnum) direction = attrib(type=Grid4TransitionsEnum) target = attrib(type=Tuple[int, int]) moving = attrib(default=False, type=bool) @@ -42,12 +52,33 @@ class EnvAgentStatic(object): lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0, 'moving_before_malfunction': False}))) + handle = attrib(default=None) + status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus) position = attrib(default=None, type=Optional[Tuple[int, int]]) + # used in rendering + old_direction = attrib(default=None) + old_position = attrib(default=None) + + def reset(self): + self.position = None + # TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280 + self.direction = self.initial_direction + self.status = RailAgentStatus.READY_TO_DEPART + self.old_position = None + self.old_direction = None + self.moving = False + + def to_agent(self) -> Agent: + return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction, + direction=self.direction, target=self.target, moving=self.moving, speed_data=self.speed_data, + malfunction_data=self.malfunction_data, handle=self.handle, status=self.status, + position=self.position, old_direction=self.old_direction, old_position=self.old_position) + @classmethod - def from_lists(cls, schedule: Schedule): - """ Create a list of EnvAgentStatics from lists of positions, directions and targets + def from_schedule(cls, schedule: Schedule): + """ Create a list of EnvAgent from lists of positions, directions and targets """ speed_datas = [] @@ -56,9 +87,6 @@ class EnvAgentStatic(object): 'speed': schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0, 'transition_action_on_cellexit': 0}) - # TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set - # some as broken? - malfunction_datas = [] for i in range(len(schedule.agent_positions)): malfunction_datas.append({'malfunction': 0, @@ -67,59 +95,11 @@ class EnvAgentStatic(object): 'next_malfunction': 0, 'nr_malfunctions': 0}) - return list(starmap(EnvAgentStatic, zip(schedule.agent_positions, - schedule.agent_directions, - schedule.agent_targets, - [False] * len(schedule.agent_positions), - speed_datas, - malfunction_datas))) - - def to_list(self): - - # I can't find an expression which works on both tuples, lists and ndarrays - # which converts them all to a list of native python ints. - lPos = self.initial_position - if type(lPos) is np.ndarray: - lPos = lPos.tolist() - - lTarget = self.target - if type(lTarget) is np.ndarray: - lTarget = lTarget.tolist() - - return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.malfunction_data] - - -@attrs -class EnvAgent(EnvAgentStatic): - """ EnvAgent - replace separate agent_* lists with a single list - of agent objects. The EnvAgent represent's the environment's view - of the dynamic agent state. - We are duplicating target in the EnvAgent, which seems simpler than - forcing the env to refer to it in the EnvAgentStatic - """ - handle = attrib(default=None) - old_direction = attrib(default=None) - old_position = attrib(default=None) - - def to_list(self): - return [ - self.position, self.direction, self.target, self.handle, - self.old_direction, self.old_position, self.moving, self.speed_data, self.malfunction_data] - - @classmethod - def from_static(cls, oStatic): - """ Create an EnvAgent from the EnvAgentStatic, - copying all the fields, and adding handle with the default 0. - """ - return EnvAgent(*oStatic.__dict__, handle=0) - - @classmethod - def list_from_static(cls, lEnvAgentStatic, handles=None): - """ Create an EnvAgent from the EnvAgentStatic, - copying all the fields, and adding handle with the default 0. - """ - if handles is None: - handles = range(len(lEnvAgentStatic)) - - return [EnvAgent(**oEAS.__dict__, handle=handle) - for handle, oEAS in zip(handles, lEnvAgentStatic)] + return list(starmap(EnvAgent, zip(schedule.agent_positions, + schedule.agent_directions, + schedule.agent_directions, + schedule.agent_targets, + [False] * len(schedule.agent_positions), + speed_datas, + malfunction_datas, + range(len(schedule.agent_positions))))) diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 6a4899995ac84257b1845265a5402db9048bd654..c2d342d6b43a445c1deb93dc59476478875bc786 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -157,8 +157,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder): new_position = agent_virtual_position visited = OrderedSet() for index in range(1, self.max_depth + 1): - # if we're at the target or not moving, stop moving until max_depth is reached - if new_position == agent.target or not agent.moving or not shortest_path: + # if we're at the target, stop moving until max_depth is reached + if new_position == agent.target or not shortest_path: prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING] visited.add((*new_position, agent.direction)) continue diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 81c3a3569642c680b5b0c7246fb6d0b94685d9ed..092feb10ca28a715b9e9cfdaad0f7de99355a952 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -17,7 +17,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus +from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData from flatland.envs.observations import GlobalObsForRailEnv @@ -183,8 +183,8 @@ class RailEnv(Environment): self.dev_obs_dict = {} self.dev_pred_dict = {} - self.agents: List[EnvAgent] = [None] * number_of_agents # live agents - self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents # static agent information + self.agents: List[EnvAgent] = [] + self.number_of_agents = number_of_agents self.num_resets = 0 self.distance_map = DistanceMap(self.agents, self.height, self.width) @@ -210,18 +210,15 @@ class RailEnv(Environment): def get_agent_handles(self): return range(self.get_num_agents()) - def get_num_agents(self, static=True): - if static: - return len(self.agents_static) - else: - return len(self.agents) + def get_num_agents(self) -> int: + return len(self.agents) - def add_agent_static(self, agent_static): + def add_agent(self, agent): """ Add static info for a single agent. Returns the index of the new agent. """ - self.agents_static.append(agent_static) - return len(self.agents_static) - 1 + self.agents.append(agent) + return len(self.agents) - 1 def set_agent_active(self, agent: EnvAgent): if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position): @@ -229,9 +226,10 @@ class RailEnv(Environment): self._set_agent_to_initial_position(agent, agent.initial_position) def restart_agents(self): - """ Reset the agents to their starting positions defined in agents_static + """ Reset the agents to their starting positions """ - self.agents = EnvAgent.list_from_static(self.agents_static) + for agent in self.agents: + agent.reset() self.active_agents = [i for i in range(len(self.agents))] @staticmethod @@ -309,7 +307,7 @@ class RailEnv(Environment): optionals = {} if regenerate_rail or self.rail is None: - rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) + rail, optionals = self.rail_generator(self.width, self.height, self.number_of_agents, self.num_resets) self.rail = rail self.height, self.width = self.rail.grid.shape @@ -322,17 +320,13 @@ class RailEnv(Environment): if optionals and 'distance_map' in optionals: self.distance_map.set(optionals['distance_map']) - # todo change self.agents_static[0] with the refactoring for agents_static -> issue nr. 185 - # https://gitlab.aicrowd.com/flatland/flatland/issues/185 - if regenerate_schedule or regenerate_rail or self.agents_static[0] is None: + if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0: agents_hints = None if optionals and 'agents_hints' in optionals: agents_hints = optionals['agents_hints'] - # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185 - # why do we need static agents? could we it more elegantly? - schedule = self.schedule_generator(self.rail, self.get_num_agents(), agents_hints, self.num_resets) - self.agents_static = EnvAgentStatic.from_lists(schedule) + schedule = self.schedule_generator(self.rail, self.number_of_agents, agents_hints, self.num_resets) + self.agents = EnvAgent.from_schedule(schedule) if agents_hints and 'city_orientations' in agents_hints: ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations']) @@ -372,9 +366,9 @@ class RailEnv(Environment): info_dict: Dict = { 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, 'malfunction': { - i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents()) + i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents) }, - 'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}, + 'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)}, 'status': {i: agent.status for i, agent in enumerate(self.agents)} } # Return the new observation vectors for each agent @@ -800,35 +794,38 @@ class RailEnv(Environment): Returns state of environment in msgpack object """ grid_data = self.rail.grid.tolist() - agent_static_data = [agent.to_list() for agent in self.agents_static] - agent_data = [agent.to_list() for agent in self.agents] + agent_data = [agent.to_agent() for agent in self.agents] malfunction_data: MalfunctionProcessData = self.malfunction_process_data msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True) - msgpack.packb(agent_static_data, use_bin_type=True) msg_data = { "grid": grid_data, - "agents_static": agent_static_data, "agents": agent_data, "malfunction": malfunction_data} return msgpack.packb(msg_data, use_bin_type=True) + def get_agent_state_msg(self): + """ + Returns agents information in msgpack object + """ + agent_data = [agent.to_agent() for agent in self.agents] + msg_data = { + "agents": agent_data} + return msgpack.packb(msg_data, use_bin_type=True) + def get_full_state_dist_msg(self): """ Returns environment information with distance map information as msgpack object """ grid_data = self.rail.grid.tolist() - agent_static_data = [agent.to_list() for agent in self.agents_static] - agent_data = [agent.to_list() for agent in self.agents] + agent_data = [agent.to_agent() for agent in self.agents] msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True) - msgpack.packb(agent_static_data, use_bin_type=True) distance_map_data = self.distance_map.get() malfunction_data: MalfunctionProcessData = self.malfunction_process_data msgpack.packb(distance_map_data, use_bin_type=True) msg_data = { "grid": grid_data, - "agents_static": agent_static_data, "agents": agent_data, "distance_map": distance_map_data, "malfunction": malfunction_data} @@ -845,8 +842,7 @@ class RailEnv(Environment): data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving - self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] + self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -864,8 +860,7 @@ class RailEnv(Environment): data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving - self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] + self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] if "distance_map" in data.keys(): self.distance_map.set(data["distance_map"]) # setup with loaded data @@ -874,6 +869,26 @@ class RailEnv(Environment): self.rail.width = self.width self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + def get_full_state_dist_msg(self): + """ + Returns environment information with distance map information as msgpack object + """ + grid_data = self.rail.grid.tolist() + agent_static_data = [agent.to_list() for agent in self.agents_static] + agent_data = [agent.to_list() for agent in self.agents] + msgpack.packb(grid_data, use_bin_type=True) + msgpack.packb(agent_data, use_bin_type=True) + msgpack.packb(agent_static_data, use_bin_type=True) + distance_map_data = self.distance_map.get() + msgpack.packb(distance_map_data, use_bin_type=True) + msg_data = { + "grid": grid_data, + "agents_static": agent_static_data, + "agents": agent_data, + "distance_map": distance_map_data} + + return msgpack.packb(msg_data, use_bin_type=True) + def save(self, filename, save_distance_maps=False): """ Saves environment and distance map information in a file diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index f48264d52e754ca464a0cdc83cde9972492eaf6d..a19501df988bb62741d70df00b2711ea21ec85b2 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -7,7 +7,7 @@ import numpy as np from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import EnvAgentStatic +from flatland.envs.agent_utils import EnvAgent from flatland.envs.schedule_utils import Schedule AgentPosition = Tuple[int, int] @@ -291,21 +291,15 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: with open(filename, "rb") as file_in: load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') - - # agents are always reset as not moving - if len(data['agents_static'][0]) > 5: - agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data["agents_static"]] - else: - agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]] + agents = [EnvAgent(*d[0:12]) for d in data["agents"]] # setup with loaded data - agents_position = [a.initial_position for a in agents_static] - agents_direction = [a.direction for a in agents_static] - agents_target = [a.target for a in agents_static] - if len(data['agents_static'][0]) > 5: - agents_speed = [a.speed_data['speed'] for a in agents_static] - else: - agents_speed = None + agents_position = [a.initial_position for a in agents] + agents_direction = [a.direction for a in agents] + agents_target = [a.target for a in agents] + agents_speed = [a.speed_data['speed'] for a in agents] + agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents] + return Schedule(agent_positions=agents_position, agent_directions=agents_direction, agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None) diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index c309f9eb3b56c82b872b1842f30eace25a70026a..2024704e1ee5f5d077947a965c6a6c5e1dc675a7 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -10,7 +10,7 @@ from numpy import array import flatland.utils.rendertools as rt from flatland.core.grid.grid4_utils import mirror -from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic +from flatland.envs.agent_utils import EnvAgent from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv, random_rail_generator from flatland.envs.rail_generators import complex_rail_generator, empty_rail_generator @@ -144,7 +144,7 @@ class View(object): def redraw(self): with self.output_generator: self.oRT.set_new_rail() - self.model.env.agents = self.model.env.agents_static + self.model.env.restart_agents() for a in self.model.env.agents: if hasattr(a, 'old_position') is False: a.old_position = a.position @@ -175,12 +175,12 @@ class View(object): self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :3] = 0 def xy_to_rc(self, x, y): - rcCell = ((array([y, x]) - self.yxBase)) + rc_cell = ((array([y, x]) - self.yxBase)) nX = np.floor((self.yxSize[0] - self.yxBase[0]) / self.model.env.height) nY = np.floor((self.yxSize[1] - self.yxBase[1]) / self.model.env.width) - rcCell[0] = max(0, min(np.floor(rcCell[0] / nY), self.model.env.height - 1)) - rcCell[1] = max(0, min(np.floor(rcCell[1] / nX), self.model.env.width - 1)) - return rcCell + rc_cell[0] = max(0, min(np.floor(rc_cell[0] / nY), self.model.env.height - 1)) + rc_cell[1] = max(0, min(np.floor(rc_cell[1] / nX), self.model.env.width - 1)) + return rc_cell def log(self, *args, **kwargs): if self.output_generator: @@ -212,23 +212,23 @@ class Controller(object): y = event['canvasY'] self.debug("debug:", x, y) - rcCell = self.view.xy_to_rc(x, y) + rc_cell = self.view.xy_to_rc(x, y) bShift = event["shiftKey"] bCtrl = event["ctrlKey"] bAlt = event["altKey"] if bCtrl and not bShift and not bAlt: - self.model.click_agent(rcCell) + self.model.click_agent(rc_cell) self.lrcStroke = [] elif bShift and bCtrl: - self.model.add_target(rcCell) + self.model.add_target(rc_cell) self.lrcStroke = [] elif bAlt and not bShift and not bCtrl: - self.model.clear_cell(rcCell) + self.model.clear_cell(rc_cell) self.lrcStroke = [] - self.debug("click in cell", rcCell) - self.model.debug_cell(rcCell) + self.debug("click in cell", rc_cell) + self.model.debug_cell(rc_cell) if self.model.selected_agent is not None: self.lrcStroke = [] @@ -301,8 +301,8 @@ class Controller(object): self.view.drag_path_element(x, y) # Translate and scale from x,y to integer row,col (note order change) - rcCell = self.view.xy_to_rc(x, y) - self.editor.drag_path_element(rcCell) + rc_cell = self.view.xy_to_rc(x, y) + self.editor.drag_path_element(rc_cell) self.view.redisplay_image() @@ -326,7 +326,7 @@ class Controller(object): def rotate_agent(self, event): self.log("Rotate Agent:", self.model.selected_agent) if self.model.selected_agent is not None: - for agent_idx, agent in enumerate(self.model.env.agents_static): + for agent_idx, agent in enumerate(self.model.env.agents): if agent is None: continue if agent_idx == self.model.selected_agent: @@ -336,13 +336,7 @@ class Controller(object): def restart_agents(self, event): self.log("Restart Agents - nAgents:", self.view.regen_n_agents.value) - if self.model.init_agents_static is not None: - self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in - self.model.init_agents_static] - self.model.env.agents = None - self.model.init_agents_static = None - self.model.env.restart_agents() - self.model.env.reset(False, False) + self.model.env.reset(False, False) self.refresh(event) def regenerate(self, event): @@ -396,7 +390,6 @@ class EditorModel(object): self.env_filename = "temp.pkl" self.set_env(env) self.selected_agent = None - self.init_agents_static = None self.thread = None self.save_image_count = 0 @@ -417,12 +410,12 @@ class EditorModel(object): def set_draw_mode(self, draw_mode): self.draw_mode = draw_mode - def interpolate_path(self, rcLast, rcCell): - if np.array_equal(rcLast, rcCell): + def interpolate_path(self, rcLast, rc_cell): + if np.array_equal(rcLast, rc_cell): return [] rcLast = array(rcLast) - rcCell = array(rcCell) - rcDelta = rcCell - rcLast + rc_cell = array(rc_cell) + rcDelta = rc_cell - rcLast lrcInterp = [] # extra row,col points @@ -454,7 +447,7 @@ class EditorModel(object): lrcInterp = list(map(tuple, g2Interp)) return lrcInterp - def drag_path_element(self, rcCell): + def drag_path_element(self, rc_cell): """Mouse motion event handler for drawing. """ lrcStroke = self.lrcStroke @@ -462,15 +455,15 @@ class EditorModel(object): # Store the row,col location of the click, if we have entered a new cell if len(lrcStroke) > 0: rcLast = lrcStroke[-1] - if not np.array_equal(rcLast, rcCell): # only save at transition - lrcInterp = self.interpolate_path(rcLast, rcCell) + if not np.array_equal(rcLast, rc_cell): # only save at transition + lrcInterp = self.interpolate_path(rcLast, rc_cell) lrcStroke.extend(lrcInterp) - self.debug("lrcStroke ", len(lrcStroke), rcCell, "interp:", lrcInterp) + self.debug("lrcStroke ", len(lrcStroke), rc_cell, "interp:", lrcInterp) else: # This is the first cell in a mouse stroke - lrcStroke.append(rcCell) - self.debug("lrcStroke ", len(lrcStroke), rcCell) + lrcStroke.append(rc_cell) + self.debug("lrcStroke ", len(lrcStroke), rc_cell) def mod_path(self, bAddRemove): # disabled functionality (no longer required) @@ -599,7 +592,6 @@ class EditorModel(object): def clear(self): self.env.rail.grid[:, :] = 0 self.env.agents = [] - self.env.agents_static = [] self.redraw() @@ -613,7 +605,7 @@ class EditorModel(object): self.redraw() def restart_agents(self): - self.env.agents = EnvAgent.list_from_static(self.env.agents_static) + self.env.restart_agents() self.redraw() def set_filename(self, filename): @@ -631,7 +623,6 @@ class EditorModel(object): self.env.restart_agents() self.env.reset(False, False) - self.init_agents_static = None self.view.oRT.update_background() self.fix_env() self.set_env(self.env) @@ -641,12 +632,7 @@ class EditorModel(object): def save(self): self.log("save to ", self.env_filename, " working dir: ", os.getcwd()) - temp_store = self.env.agents - # clear agents before save , because we want the "init" position of the agent to expert - self.env.agents = [] self.env.save(self.env_filename) - # reset agents current (current position) - self.env.agents = temp_store def save_image(self): self.view.oRT.gl.save_image('frame_{:04d}.bmp'.format(self.save_image_count)) @@ -683,7 +669,7 @@ class EditorModel(object): self.regen_size_height = size def find_agent_at(self, cell_row_col): - for agent_idx, agent in enumerate(self.env.agents_static): + for agent_idx, agent in enumerate(self.env.agents): if tuple(agent.position) == tuple(cell_row_col): return agent_idx return None @@ -703,15 +689,14 @@ class EditorModel(object): # No if self.selected_agent is None: # Create a new agent and select it. - agent_static = EnvAgentStatic(position=cell_row_col, direction=0, target=cell_row_col, moving=False) - self.selected_agent = self.env.add_agent_static(agent_static) + agent = EnvAgent(position=cell_row_col, direction=0, target=cell_row_col, moving=False) + self.selected_agent = self.env.add_agent(agent) self.view.oRT.update_background() else: # Move the selected agent to this cell - agent_static = self.env.agents_static[self.selected_agent] - agent_static.position = cell_row_col - agent_static.old_position = cell_row_col - self.env.agents = [] + agent = self.env.agents[self.selected_agent] + agent.position = cell_row_col + agent.old_position = cell_row_col else: # Yes # Have they clicked on the agent already selected? @@ -722,13 +707,11 @@ class EditorModel(object): # No - select the agent self.selected_agent = agent_idx - self.init_agents_static = None self.redraw() - def add_target(self, rcCell): + def add_target(self, rc_cell): if self.selected_agent is not None: - self.env.agents_static[self.selected_agent].target = rcCell - self.init_agents_static = None + self.env.agents[self.selected_agent].target = rc_cell self.view.oRT.update_background() self.redraw() @@ -746,11 +729,11 @@ class EditorModel(object): if self.debug_bool: self.log(*args, **kwargs) - def debug_cell(self, rcCell): - binTrans = self.env.rail.get_full_transitions(*rcCell) + def debug_cell(self, rc_cell): + binTrans = self.env.rail.get_full_transitions(*rc_cell) sbinTrans = format(binTrans, "#018b")[2:] self.debug("cell ", - rcCell, + rc_cell, "Transitions: ", binTrans, sbinTrans, diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index fc96b22d737917e98ec8c5617151f4f30ae22d10..cc496cb94cd2ba0d927749bf813cd449bd70e236 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -77,7 +77,7 @@ class RenderTool(object): def update_background(self): # create background map targets = {} - for agent_idx, agent in enumerate(self.env.agents_static): + for agent_idx, agent in enumerate(self.env.agents): if agent is None: continue targets[tuple(agent.target)] = agent_idx @@ -93,10 +93,9 @@ class RenderTool(object): self.new_rail = True def plot_agents(self, targets=True, selected_agent=None): - color_map = self.gl.get_cmap('hsv', - lut=max(len(self.env.agents), len(self.env.agents_static) + 1)) + color_map = self.gl.get_cmap('hsv', lut=(len(self.env.agents) + 1)) - for agent_idx, agent in enumerate(self.env.agents_static): + for agent_idx, agent in enumerate(self.env.agents): if agent is None: continue color = color_map(agent_idx) @@ -515,7 +514,7 @@ class RenderTool(object): # store the targets targets = {} selected = {} - for agent_idx, agent in enumerate(self.env.agents_static): + for agent_idx, agent in enumerate(self.env.agents): if agent is None: continue targets[tuple(agent.target)] = agent_idx diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 22cea8280d377b7b7b8a118a4ba1fe3d35e972a7..c4ab4e4cb561da64c3f5375b9f603b03795ba063 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -25,17 +25,21 @@ def test_walker(): rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, - predictor=ShortestPathPredictorForRailEnv(max_depth=10))) + predictor=ShortestPathPredictorForRailEnv(max_depth=10)), + ) # reset to initialize agents_static env.reset() # set initial position and direction for testing... - env.agents_static[0].position = (0, 1) - env.agents_static[0].direction = 1 - env.agents_static[0].target = (0, 0) + env.agents[0].position = (0, 1) + env.agents[0].direction = 1 + env.agents[0].target = (0, 0) # reset to set agents from agents_static env.reset(False, False) diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 8bc7235edbbed51d65818b1c4de5197b7455ddbe..a569aa35534385698369980566c426cf72b7bb4b 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -53,13 +53,11 @@ def test_grid8_set_transitions(): def check_path(env, rail, position, direction, target, expected, rendering=False): - agent = env.agents_static[0] + agent = env.agents[0] agent.position = position # south dead-end agent.direction = direction # north agent.target = target # east dead-end agent.moving = True - # reset to set agents from agents_static - # env.reset(False, False) if rendering: renderer = RenderTool(env, gl="PILSVG") renderer.render_env(show=True, show_observations=False) @@ -69,11 +67,13 @@ def check_path(env, rail, position, direction, target, expected, rendering=False def test_path_exists(rendering=False): rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) - - # reset to initialize agents_static + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) env.reset() check_path( @@ -131,11 +131,13 @@ def test_path_exists(rendering=False): def test_path_not_exists(rendering=False): rail, rail_map = make_simple_rail_unconnected() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) - - # reset to initialize agents_static + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) env.reset() check_path( diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 5543f3912aa4aec750ad63024c7371b02de30309..6e5a374d6606800b311205bd86d99600db447b91 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -96,26 +96,37 @@ def test_reward_function_conflict(rendering=False): schedule_generator=random_schedule_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) obs_builder: TreeObsForRailEnv = env.obs_builder - # initialize agents_static env.reset() # set the initial position - agent = env.agents_static[0] + agent = env.agents[0] agent.position = (5, 6) # south dead-end + agent.initial_position = (5, 6) # south dead-end agent.direction = 0 # north + agent.initial_direction = 0 # north agent.target = (3, 9) # east dead-end agent.moving = True agent.status = RailAgentStatus.ACTIVE - agent = env.agents_static[1] + agent = env.agents[1] agent.position = (3, 8) # east dead-end + agent.initial_position = (3, 8) # east dead-end agent.direction = 3 # west + agent.initial_direction = 3 # west agent.target = (6, 6) # south dead-end agent.moving = True agent.status = RailAgentStatus.ACTIVE - # reset to set agents from agents_static env.reset(False, False) + env.agents[0].moving = True + env.agents[1].moving = True + env.agents[0].status = RailAgentStatus.ACTIVE + env.agents[1].status = RailAgentStatus.ACTIVE + env.agents[0].position = (5, 6) + env.agents[1].position = (3, 8) + print("\n") + print(env.agents[0]) + print(env.agents[1]) if rendering: renderer = RenderTool(env, gl="PILSVG") @@ -174,28 +185,34 @@ def test_reward_function_waiting(rendering=False): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), remove_agents_at_target=False) obs_builder: TreeObsForRailEnv = env.obs_builder - # initialize agents_static env.reset() # set the initial position - agent = env.agents_static[0] + agent = env.agents[0] agent.initial_position = (3, 8) # east dead-end agent.position = (3, 8) # east dead-end agent.direction = 3 # west + agent.initial_direction = 3 # west agent.target = (3, 1) # west dead-end agent.moving = True agent.status = RailAgentStatus.ACTIVE - agent = env.agents_static[1] + agent = env.agents[1] agent.initial_position = (5, 6) # south dead-end agent.position = (5, 6) # south dead-end agent.direction = 0 # north + agent.initial_direction = 0 # north agent.target = (3, 8) # east dead-end agent.moving = True agent.status = RailAgentStatus.ACTIVE - # reset to set agents from agents_static env.reset(False, False) + env.agents[0].moving = True + env.agents[1].moving = True + env.agents[0].status = RailAgentStatus.ACTIVE + env.agents[1].status = RailAgentStatus.ACTIVE + env.agents[0].position = (3, 8) + env.agents[1].position = (5, 6) if rendering: renderer = RenderTool(env, gl="PILSVG") diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index cfb3e9826d3d2d262a3e28a637f5012eabeed2b8..d865bea715a525b2538f6430120c32a8bd139666 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -21,18 +21,21 @@ from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make def test_dummy_predictor(rendering=False): rail, rail_map = make_simple_rail2() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10))) - # reset to initialize agents_static + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), + ) env.reset() # set initial position and direction for testing... - env.agents_static[0].initial_position = (5, 6) - env.agents_static[0].direction = 0 - env.agents_static[0].target = (3, 0) + env.agents[0].initial_position = (5, 6) + env.agents[0].initial_direction = 0 + env.agents[0].direction = 0 + env.agents[0].target = (3, 0) - # reset to set agents from agents_static env.reset(False, False) env.set_agent_active(env.agents[0]) @@ -109,23 +112,25 @@ def test_dummy_predictor(rendering=False): def test_shortest_path_predictor(rendering=False): rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) - - # reset to initialize agents_static + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) env.reset() # set the initial position - agent = env.agents_static[0] + agent = env.agents[0] agent.initial_position = (5, 6) # south dead-end agent.position = (5, 6) # south dead-end agent.direction = 0 # north + agent.initial_direction = 0 # north agent.target = (3, 9) # east dead-end agent.moving = True agent.status = RailAgentStatus.ACTIVE - # reset to set agents from agents_static env.reset(False, False) if rendering: @@ -243,30 +248,34 @@ def test_shortest_path_predictor(rendering=False): def test_shortest_path_predictor_conflicts(rendering=False): rail, rail_map = make_invalid_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=2, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) - # initialize agents_static + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) env.reset() # set the initial position - agent = env.agents_static[0] + agent = env.agents[0] agent.initial_position = (5, 6) # south dead-end agent.position = (5, 6) # south dead-end agent.direction = 0 # north + agent.initial_direction = 0 # north agent.target = (3, 9) # east dead-end agent.moving = True agent.status = RailAgentStatus.ACTIVE - agent = env.agents_static[1] + agent = env.agents[1] agent.initial_position = (3, 8) # east dead-end agent.position = (3, 8) # east dead-end agent.direction = 3 # west + agent.initial_direction = 3 # west agent.target = (6, 6) # south dead-end agent.moving = True agent.status = RailAgentStatus.ACTIVE - # reset to set agents from agents_static observations, info = env.reset(False, False, True) if rendering: diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index a7fd93d0e1fc84e274eae4e5628d3e0b43fadaac..5c518a27160f28a8a3518a4fe013f97324e4d41e 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -5,7 +5,6 @@ import numpy as np from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent -from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -22,8 +21,8 @@ def test_load_env(): env.reset() env.load_resource('env_data.tests', 'test-10x10.mpk') - agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) - env.add_agent_static(agent_static) + agent_static = EnvAgent((0, 0), 2, (5, 5), False) + env.add_agent(agent_static) assert env.get_num_agents() == 1 @@ -32,23 +31,23 @@ def test_save_load(): rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=2) env.reset() - agent_1_pos = env.agents_static[0].position - agent_1_dir = env.agents_static[0].direction - agent_1_tar = env.agents_static[0].target - agent_2_pos = env.agents_static[1].position - agent_2_dir = env.agents_static[1].direction - agent_2_tar = env.agents_static[1].target + agent_1_pos = env.agents[0].position + agent_1_dir = env.agents[0].direction + agent_1_tar = env.agents[0].target + agent_2_pos = env.agents[1].position + agent_2_dir = env.agents[1].direction + agent_2_tar = env.agents[1].target env.save("test_save.dat") env.load("test_save.dat") assert (env.width == 10) assert (env.height == 10) assert (len(env.agents) == 2) - assert (agent_1_pos == env.agents_static[0].position) - assert (agent_1_dir == env.agents_static[0].direction) - assert (agent_1_tar == env.agents_static[0].target) - assert (agent_2_pos == env.agents_static[1].position) - assert (agent_2_dir == env.agents_static[1].direction) - assert (agent_2_tar == env.agents_static[1].target) + assert (agent_1_pos == env.agents[0].position) + assert (agent_1_dir == env.agents[0].direction) + assert (agent_1_tar == env.agents[0].target) + assert (agent_2_pos == env.agents[1].position) + assert (agent_2_dir == env.agents[1].direction) + assert (agent_2_tar == env.agents[1].target) def test_rail_environment_single_agent(): @@ -158,10 +157,10 @@ def test_dead_end(): # We try the configuration in the 4 directions: rail_env.reset() - rail_env.agents = [EnvAgent(initial_position=(0, 2), direction=1, target=(0, 0), moving=False)] + rail_env.agents = [EnvAgent(initial_position=(0, 2), initial_direction=1, direction=1, target=(0, 0), moving=False)] rail_env.reset() - rail_env.agents = [EnvAgent(initial_position=(0, 2), direction=3, target=(0, 4), moving=False)] + rail_env.agents = [EnvAgent(initial_position=(0, 2), initial_direction=3, direction=3, target=(0, 4), moving=False)] # In the vertical configuration: rail_map = np.array( @@ -180,10 +179,10 @@ def test_dead_end(): obs_builder_object=GlobalObsForRailEnv()) rail_env.reset() - rail_env.agents = [EnvAgent(initial_position=(2, 0), direction=2, target=(0, 0), moving=False)] + rail_env.agents = [EnvAgent(initial_position=(2, 0), initial_direction=2, direction=2, target=(0, 0), moving=False)] rail_env.reset() - rail_env.agents = [EnvAgent(initial_position=(2, 0), direction=0, target=(4, 0), moving=False)] + rail_env.agents = [EnvAgent(initial_position=(2, 0), initial_direction=0, direction=0, target=(4, 0), moving=False)] # TODO make assertions @@ -231,7 +230,6 @@ def test_rail_env_reset(): env.reset() env.save(file_name) dist_map_shape = np.shape(env.distance_map.get()) - # initialize agents_static rails_initial = env.rail.grid agents_initial = env.agents diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index 344739b798dd2d36136ff0c35698ce0025fc781d..bb3e5e2a8d4fd9bde5255627ed4f1ac99bc261dd 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -1,6 +1,7 @@ import sys import numpy as np +import pytest from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv @@ -22,14 +23,13 @@ def test_get_shortest_paths_unreachable(): env.reset() # set the initial position - agent = env.agents_static[0] + agent = env.agents[0] agent.position = (3, 1) # west dead-end agent.initial_position = (3, 1) # west dead-end agent.direction = Grid4TransitionsEnum.WEST agent.target = (3, 9) # east dead-end agent.moving = True - # reset to set agents from agents_static env.reset(False, False) actual = get_shortest_paths(env.distance_map) @@ -38,6 +38,9 @@ def test_get_shortest_paths_unreachable(): assert actual == expected, "actual={},expected={}".format(actual, expected) +# todo file test_002.pkl has to be generated automatically +# see https://gitlab.aicrowd.com/flatland/flatland/issues/279 +@pytest.mark.skip def test_get_shortest_paths(): env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') env.reset() @@ -167,6 +170,9 @@ def test_get_shortest_paths(): "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle]) +# todo file test_002.pkl has to be generated automatically +# see https://gitlab.aicrowd.com/flatland/flatland/issues/279 +@pytest.mark.skip def test_get_shortest_paths_max_depth(): env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') env.reset() @@ -196,6 +202,9 @@ def test_get_shortest_paths_max_depth(): "[{}] actual={},expected={}".format(agent_handle, actual[agent_handle], expected[agent_handle]) +# todo file Level_distance_map_shortest_path.pkl has to be generated automatically +# see https://gitlab.aicrowd.com/flatland/flatland/issues/279 +@pytest.mark.skip def test_get_shortest_paths_agent_handle(): env = load_flatland_environment_from_file('Level_distance_map_shortest_path.pkl', 'env_data.tests') env.reset() diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index f1090d9181f01a2dc31ac3a901f76ff48b75920c..df398a21583c7c24bc4cb5e1a08e7a517ea3483c 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -81,7 +81,6 @@ def test_malfunction_process(): malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=SingleAgentNavigationObs() ) - # reset to initialize agents_static obs, info = env.reset(False, False, True, random_seed=10) agent_halts = 0 @@ -136,7 +135,6 @@ def test_malfunction_process_statistically(): obs_builder_object=SingleAgentNavigationObs() ) - # reset to initialize agents_static env.reset(True, True, False, random_seed=10) env.agents[0].target = (0, 0) @@ -182,7 +180,6 @@ def test_malfunction_before_entry(): malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=SingleAgentNavigationObs() ) - # reset to initialize agents_static env.reset(False, False, False, random_seed=10) env.agents[0].target = (0, 0) @@ -227,7 +224,6 @@ def test_malfunction_values_and_behavior(): obs_builder_object=SingleAgentNavigationObs() ) - # reset to initialize agents_static env.reset(False, False, activate_agents=True, random_seed=10) # Assertions @@ -248,9 +244,14 @@ def test_initial_malfunction(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=10), number_of_agents=1, - obs_builder_object=SingleAgentNavigationObs()) + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=10), + number_of_agents=1, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), # Malfunction data generator + obs_builder_object=SingleAgentNavigationObs() + ) # reset to initialize agents_static env.reset(False, False, True, random_seed=10) print(env.agents[0].malfunction_data) @@ -401,9 +402,13 @@ def test_initial_malfunction_do_nothing(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(), number_of_agents=1) - # reset to initialize agents_static + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), # Malfunction data generator + ) env.reset() set_penalties_for_replay(env) replay_config = ReplayConfig( @@ -482,7 +487,6 @@ def tests_random_interference_from_outside(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() - # reset to initialize agents_static env.agents[0].speed_data['speed'] = 0.33 env.reset(False, False, False, random_seed=10) env_data = [] @@ -507,7 +511,6 @@ def tests_random_interference_from_outside(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() - # reset to initialize agents_static env.agents[0].speed_data['speed'] = 0.33 env.reset(False, False, False, random_seed=10) @@ -543,9 +546,8 @@ def test_last_malfunction_step(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() - # reset to initialize agents_static env.agents[0].speed_data['speed'] = 1. / 3. - env.agents_static[0].target = (0, 0) + env.agents[0].target = (0, 0) env.reset(False, False, True) # Force malfunction to be off at beginning and next malfunction to happen in 2 steps diff --git a/tests/test_generators.py b/tests/test_generators.py index 83ef0d76360d1c662c9e181c6f88f41adb62ebbf..d8f2ed118d15192887c3161ce0edb2cbe6d1653e 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -112,7 +112,6 @@ def tests_rail_from_file(): env.reset() env.save(file_name) dist_map_shape = np.shape(env.distance_map.get()) - # initialize agents_static rails_initial = env.rail.grid agents_initial = env.agents @@ -140,7 +139,6 @@ def tests_rail_from_file(): env2.reset() env2.save(file_name_2) - # initialize agents_static rails_initial_2 = env2.rail.grid agents_initial_2 = env2.agents @@ -170,10 +168,13 @@ def tests_rail_from_file(): # Test to save without distance map and load with generating distance map - # initialize agents_static - env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2), - schedule_generator=schedule_from_file(file_name_2), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2)) + env4 = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name_2), + schedule_generator=schedule_from_file(file_name_2), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2), + ) env4.reset() rails_loaded_4 = env4.rail.grid agents_loaded_4 = env4.agents diff --git a/tests/test_utils.py b/tests/test_utils.py index 6dfc6239ed191d06c16feeca5e8d68dbd6654952..e4fba2aebd795462971e8d1e8f16992c2affbac8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -77,9 +77,10 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: for step in range(len(test_configs[0].replay)): if step == 0: for a, test_config in enumerate(test_configs): - agent: EnvAgent = env.agents_static[a] + agent: EnvAgent = env.agents[a] # set the initial position agent.initial_position = test_config.initial_position + agent.initial_direction = test_config.initial_direction agent.direction = test_config.initial_direction agent.target = test_config.target agent.speed_data['speed'] = test_config.speed