From 1b4de9444f3caa4415b77ea0a93da0a31287b511 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Tue, 5 Nov 2019 09:28:35 +0100 Subject: [PATCH] include improvements suggested by Ch. Eichenberger --- changelog.md | 3 + flatland/envs/agent_utils.py | 32 +++++++++-- flatland/envs/rail_env.py | 10 ++-- flatland/envs/schedule_generators.py | 2 +- flatland/utils/editor.py | 57 +++++++++---------- ...t_flatland_envs_rail_env_shortest_paths.py | 3 + 6 files changed, 67 insertions(+), 40 deletions(-) diff --git a/changelog.md b/changelog.md index 5eba35db..543fc87a 100644 --- a/changelog.md +++ b/changelog.md @@ -3,6 +3,9 @@ Changelog Changes since Flatland 2.0.0 -------------------------- +### Changes in `EnvAgent` +- class `EnvAgentStatic` was removed, so there is only class `EnvAgent` left which should simplify the handling of agents. The member `self.agents_static` of `RailEnv` was therefore also removed. Old Scence saved as pickle files cannot be loaded anymore. + ### Changes in malfunction behavior - agent attribute `next_malfunction`is not used anymore, it will be removed fully in future versions. - `break_agent()` function is introduced which induces malfunctions in agent according to poisson process diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 9f86f41b..bcc0fcc6 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,6 +1,6 @@ from enum import IntEnum from itertools import starmap -from typing import Tuple, Optional +from typing import Tuple, Optional, NamedTuple from attr import attrs, attrib, Factory @@ -15,6 +15,20 @@ class RailAgentStatus(IntEnum): DONE_REMOVED = 3 # removed from grid (position is None) -> prediction is None +Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), + ('initial_direction', Grid4TransitionsEnum), + ('direction', Grid4TransitionsEnum), + ('target', Tuple[int, int]), + ('moving', bool), + ('speed_data', dict), + ('malfunction_data', dict), + ('handle', int), + ('status', RailAgentStatus), + ('position', Tuple[int, int]), + ('old_direction', Grid4TransitionsEnum), + ('old_position', Tuple[int, int])]) + + @attrs class EnvAgent: @@ -55,10 +69,18 @@ class EnvAgent: self.old_direction = None self.moving = False - def to_list(self): - return [self.initial_position, self.initial_direction, int(self.direction), self.target, int(self.moving), - self.speed_data, self.malfunction_data, self.handle, self.status, self.position, self.old_direction, - self.old_position] + def move(self, new_pos: Tuple[int, int], new_dir: Tuple[int, int] = None): + self.old_position = self.position + self.position = new_pos + if new_dir is not None: + self.old_direction = self.direction + self.direction = new_dir + + def to_agent(self) -> Agent: + return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction, + direction=self.direction, target=self.target, moving=self.moving, speed_data=self.speed_data, + malfunction_data=self.malfunction_data, handle=self.handle, status=self.status, + position=self.position, old_direction=self.old_direction, old_position=self.old_position) @classmethod def from_schedule(cls, schedule: Schedule): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 7eacb528..f6537274 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -813,7 +813,7 @@ class RailEnv(Environment): Returns state of environment in msgpack object """ grid_data = self.rail.grid.tolist() - agent_data = [agent.to_list() for agent in self.agents] + agent_data = [agent.to_agent() for agent in self.agents] msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True) msg_data = { @@ -825,7 +825,7 @@ class RailEnv(Environment): """ Returns agents information in msgpack object """ - agent_data = [agent.to_list() for agent in self.agents] + agent_data = [agent.to_agent() for agent in self.agents] msg_data = { "agents": agent_data} return msgpack.packb(msg_data, use_bin_type=True) @@ -841,7 +841,7 @@ class RailEnv(Environment): data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11]) for d in data["agents"]] + self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -859,7 +859,7 @@ class RailEnv(Environment): data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11]) for d in data["agents"]] + self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] if "distance_map" in data.keys(): self.distance_map.set(data["distance_map"]) # setup with loaded data @@ -873,7 +873,7 @@ class RailEnv(Environment): Returns environment information with distance map information as msgpack object """ grid_data = self.rail.grid.tolist() - agent_data = [agent.to_list() for agent in self.agents] + agent_data = [agent.to_agent() for agent in self.agents] msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True) distance_map_data = self.distance_map.get() diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index be19fda5..656bf70c 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -291,7 +291,7 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: with open(filename, "rb") as file_in: load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') - agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11]) for d in data["agents"]] + agents = [EnvAgent(*d[0:12]) for d in data["agents"]] # setup with loaded data agents_position = [a.initial_position for a in agents] diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index ffc0e30f..eadea813 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -178,12 +178,12 @@ class View(object): self.writableData[(y - 2):(y + 2), (x - 2):(x + 2), :3] = 0 def xy_to_rc(self, x, y): - rcCell = ((array([y, x]) - self.yxBase)) + rc_cell = ((array([y, x]) - self.yxBase)) nX = np.floor((self.yxSize[0] - self.yxBase[0]) / self.model.env.height) nY = np.floor((self.yxSize[1] - self.yxBase[1]) / self.model.env.width) - rcCell[0] = max(0, min(np.floor(rcCell[0] / nY), self.model.env.height - 1)) - rcCell[1] = max(0, min(np.floor(rcCell[1] / nX), self.model.env.width - 1)) - return rcCell + rc_cell[0] = max(0, min(np.floor(rc_cell[0] / nY), self.model.env.height - 1)) + rc_cell[1] = max(0, min(np.floor(rc_cell[1] / nX), self.model.env.width - 1)) + return rc_cell def log(self, *args, **kwargs): if self.output_generator: @@ -215,23 +215,23 @@ class Controller(object): y = event['canvasY'] self.debug("debug:", x, y) - rcCell = self.view.xy_to_rc(x, y) + rc_cell = self.view.xy_to_rc(x, y) bShift = event["shiftKey"] bCtrl = event["ctrlKey"] bAlt = event["altKey"] if bCtrl and not bShift and not bAlt: - self.model.click_agent(rcCell) + self.model.click_agent(rc_cell) self.lrcStroke = [] elif bShift and bCtrl: - self.model.add_target(rcCell) + self.model.add_target(rc_cell) self.lrcStroke = [] elif bAlt and not bShift and not bCtrl: - self.model.clear_cell(rcCell) + self.model.clear_cell(rc_cell) self.lrcStroke = [] - self.debug("click in cell", rcCell) - self.model.debug_cell(rcCell) + self.debug("click in cell", rc_cell) + self.model.debug_cell(rc_cell) if self.model.selected_agent is not None: self.lrcStroke = [] @@ -304,8 +304,8 @@ class Controller(object): self.view.drag_path_element(x, y) # Translate and scale from x,y to integer row,col (note order change) - rcCell = self.view.xy_to_rc(x, y) - self.editor.drag_path_element(rcCell) + rc_cell = self.view.xy_to_rc(x, y) + self.editor.drag_path_element(rc_cell) self.view.redisplay_image() @@ -413,12 +413,12 @@ class EditorModel(object): def set_draw_mode(self, draw_mode): self.draw_mode = draw_mode - def interpolate_path(self, rcLast, rcCell): - if np.array_equal(rcLast, rcCell): + def interpolate_path(self, rcLast, rc_cell): + if np.array_equal(rcLast, rc_cell): return [] rcLast = array(rcLast) - rcCell = array(rcCell) - rcDelta = rcCell - rcLast + rc_cell = array(rc_cell) + rcDelta = rc_cell - rcLast lrcInterp = [] # extra row,col points @@ -450,7 +450,7 @@ class EditorModel(object): lrcInterp = list(map(tuple, g2Interp)) return lrcInterp - def drag_path_element(self, rcCell): + def drag_path_element(self, rc_cell): """Mouse motion event handler for drawing. """ lrcStroke = self.lrcStroke @@ -458,15 +458,15 @@ class EditorModel(object): # Store the row,col location of the click, if we have entered a new cell if len(lrcStroke) > 0: rcLast = lrcStroke[-1] - if not np.array_equal(rcLast, rcCell): # only save at transition - lrcInterp = self.interpolate_path(rcLast, rcCell) + if not np.array_equal(rcLast, rc_cell): # only save at transition + lrcInterp = self.interpolate_path(rcLast, rc_cell) lrcStroke.extend(lrcInterp) - self.debug("lrcStroke ", len(lrcStroke), rcCell, "interp:", lrcInterp) + self.debug("lrcStroke ", len(lrcStroke), rc_cell, "interp:", lrcInterp) else: # This is the first cell in a mouse stroke - lrcStroke.append(rcCell) - self.debug("lrcStroke ", len(lrcStroke), rcCell) + lrcStroke.append(rc_cell) + self.debug("lrcStroke ", len(lrcStroke), rc_cell) def mod_path(self, bAddRemove): # disabled functionality (no longer required) @@ -701,8 +701,7 @@ class EditorModel(object): else: # Move the selected agent to this cell agent = self.env.agents[self.selected_agent] - agent.position = cell_row_col - agent.old_position = cell_row_col + agent.move(cell_row_col) else: # Yes # Have they clicked on the agent already selected? @@ -715,9 +714,9 @@ class EditorModel(object): self.redraw() - def add_target(self, rcCell): + def add_target(self, rc_cell): if self.selected_agent is not None: - self.env.agents[self.selected_agent].target = rcCell + self.env.agents[self.selected_agent].target = rc_cell self.view.oRT.update_background() self.redraw() @@ -735,11 +734,11 @@ class EditorModel(object): if self.debug_bool: self.log(*args, **kwargs) - def debug_cell(self, rcCell): - binTrans = self.env.rail.get_full_transitions(*rcCell) + def debug_cell(self, rc_cell): + binTrans = self.env.rail.get_full_transitions(*rc_cell) sbinTrans = format(binTrans, "#018b")[2:] self.debug("cell ", - rcCell, + rc_cell, "Transitions: ", binTrans, sbinTrans, diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index 5a0c35df..8b066028 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -43,6 +43,7 @@ def test_get_shortest_paths_unreachable(): # todo file test_002.pkl has to be generated automatically +# see https://gitlab.aicrowd.com/flatland/flatland/issues/279 @pytest.mark.skip def test_get_shortest_paths(): env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') @@ -174,6 +175,7 @@ def test_get_shortest_paths(): # todo file test_002.pkl has to be generated automatically +# see https://gitlab.aicrowd.com/flatland/flatland/issues/279 @pytest.mark.skip def test_get_shortest_paths_max_depth(): env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') @@ -205,6 +207,7 @@ def test_get_shortest_paths_max_depth(): # todo file Level_distance_map_shortest_path.pkl has to be generated automatically +# see https://gitlab.aicrowd.com/flatland/flatland/issues/279 @pytest.mark.skip def test_get_shortest_paths_agent_handle(): env = load_flatland_environment_from_file('Level_distance_map_shortest_path.pkl', 'env_data.tests') -- GitLab