From 25bb98a0c5e908751d9bee21e48fd3f5380f8db7 Mon Sep 17 00:00:00 2001 From: spiglerg <spiglerg@gmail.com> Date: Mon, 15 Apr 2019 14:06:15 +0200 Subject: [PATCH] Added TransitionMap objects + replaced relevant use of grid cell --- examples/temporary_example.py | 6 +- flatland/core/env.py | 40 ++--- flatland/core/transitionmap.py | 251 +++++++++++++++++++++++++++ flatland/utils/rail_env_generator.py | 12 +- flatland/utils/rendertools.py | 19 +- tests/test_rendertools.py | 41 ++--- 6 files changed, 298 insertions(+), 71 deletions(-) create mode 100644 flatland/core/transitionmap.py diff --git a/examples/temporary_example.py b/examples/temporary_example.py index cd6d42de..2ea68cfd 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -17,7 +17,7 @@ env = RailEnv(rail, number_of_agents=10) env.reset() env_renderer = RenderTool(env) -env_renderer.renderEnv() +env_renderer.renderEnv(show=True) # Example generate a rail given a manual specification, @@ -37,7 +37,7 @@ env.agents_target = [[1, 1]] env.agents_direction = [1] env_renderer = RenderTool(env) -env_renderer.renderEnv() +env_renderer.renderEnv(show=True) print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \ @@ -64,4 +64,4 @@ for step in range(100): i = i+1 i += 1 - env_renderer.renderEnv() + env_renderer.renderEnv(show=True) diff --git a/flatland/core/env.py b/flatland/core/env.py index 7cc3b768..d6493507 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -5,8 +5,6 @@ The base Environment class is adapted from rllib.env.MultiAgentEnv """ import random -from .transitions import RailEnvTransitions - class Environment: """ @@ -133,8 +131,8 @@ class RailEnv: """ self.rail = rail - self.width = len(self.rail[0]) - self.height = len(self.rail) + self.width = rail.width + self.height = rail.height self.number_of_agents = number_of_agents @@ -144,8 +142,6 @@ class RailEnv: self.agents_handles = list(range(self.number_of_agents)) - self.trans = RailEnvTransitions() - def get_agent_handles(self): return self.agents_handles @@ -159,7 +155,7 @@ class RailEnv: valid_positions = [] for r in range(self.height): for c in range(self.width): - if self.rail[r][c] > 0: + if self.rail.get_transitions((r, c)) > 0: valid_positions.append((r, c)) self.agents_position = random.sample(valid_positions, @@ -175,8 +171,8 @@ class RailEnv: valid_movements = [] for direction in range(4): position = self.agents_position[i] - moves = self.trans.get_transitions( - self.rail[position[0]][position[1]], direction) + moves = self.rail.get_transitions( + (position[0], position[1], direction)) for move_index in range(4): if moves[move_index]: valid_movements.append((direction, move_index)) @@ -251,8 +247,9 @@ class RailEnv: if action == 2: # compute number of possible transitions in the current # cell + is_deadend = False nbits = 0 - tmp = self.rail[pos[0]][pos[1]] + tmp = self.rail.get_transitions((pos[0], pos[1])) while tmp > 0: nbits += (tmp & 1) tmp = tmp >> 1 @@ -270,14 +267,13 @@ class RailEnv: elif direction == 3: reverse_direction = 1 - valid_transition = self.trans.get_transition( - self.rail[pos[0]][pos[1]], - reverse_direction, + valid_transition = self.rail.get_transition( + (pos[0], pos[1], direction), reverse_direction) - if valid_transition: direction = reverse_direction - movement = direction + movement = reverse_direction + is_deadend = True new_position = self._new_position(pos, movement) @@ -289,15 +285,14 @@ class RailEnv: new_position[0] < 0 or new_position[1] < 0: new_cell_isValid = False - elif self.rail[new_position[0]][new_position[1]] > 0: + elif self.rail.get_transitions((new_position[0], new_position[1])) > 0: new_cell_isValid = True else: new_cell_isValid = False - transition_isValid = self.trans.get_transition( - self.rail[pos[0]][pos[1]], - direction, - movement) + transition_isValid = self.rail.get_transition( + (pos[0], pos[1], direction), + movement) or is_deadend cell_isFree = True for j in range(self.number_of_agents): @@ -363,8 +358,7 @@ class RailEnv: return 1 if node not in visited: visited.add(node) - moves = self.trans.get_transitions( - self.rail[node[0][0]][node[0][1]], node[1]) + moves = self.rail.get_transitions((node[0][0], node[0][1], node[1])) for move_index in range(4): if moves[move_index]: stack.append((self._new_position(node[0], move_index), @@ -373,7 +367,7 @@ class RailEnv: # If cell is a dead-end, append previous node with reversed # orientation! nbits = 0 - tmp = self.rail[node[0][0]][node[0][1]] + tmp = self.rail.get_transitions((node[0][0], node[0][1])) while tmp > 0: nbits += (tmp & 1) tmp = tmp >> 1 diff --git a/flatland/core/transitionmap.py b/flatland/core/transitionmap.py new file mode 100644 index 00000000..d3fcf5c8 --- /dev/null +++ b/flatland/core/transitionmap.py @@ -0,0 +1,251 @@ +""" +TransitionMap and derived classes. +""" + +import numpy as np + +from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions + + +class TransitionMap: + """ + Base TransitionMap class. + + Generic class that implements a collection of transitions over a set of + cells. + """ + + def get_transitions(self, cell_id): + """ + Return a tuple of transitions available in a cell specified by + `cell_id' (e.g., a tuple of size of the maximum number of transitions, + with values 0 or 1, or potentially in between, + for stochastic transitions). + + Parameters + ---------- + cell_id : [cell identifier] + The cell_id object depends on the specific implementation. + It generally is an int (e.g., an index) or a tuple of indices. + + Returns + ------- + tuple + List of the validity of transitions in the cell. + + """ + raise NotImplementedError() + + def set_transitions(self, cell_id, new_transitions): + """ + Replaces the available transitions in cell `cell_id' with the tuple + `new_transitions'. `new_transitions' must have + one element for each possible transition. + + Parameters + ---------- + cell_id : [cell identifier] + The cell_id object depends on the specific implementation. + It generally is an int (e.g., an index) or a tuple of indices. + new_transitions : tuple + Tuple of new transitions validitiy for the cell. + + """ + raise NotImplementedError() + + def get_transition(self, cell_id, transition_index): + """ + Return the status of whether an agent in cell `cell_id' can perform a + movement along transition `transition_index (e.g., the NESW direction + of movement, for agents on a grid). + + Parameters + ---------- + cell_id : [cell identifier] + The cell_id object depends on the specific implementation. + It generally is an int (e.g., an index) or a tuple of indices. + transition_index : int + Index of the transition to probe, as index in the tuple returned by + get_transitions(). e.g., the NESW direction of movement, for agents + on a grid. + + Returns + ------- + int or float (depending on derived class) + Validity of the requested transition (e.g., + 0/1 allowed/not allowed, a probability in [0,1], etc...) + + """ + raise NotImplementedError() + + def set_transition(self, cell_id, transition_index, new_transition): + """ + Replaces the validity of transition to `transition_index' in cell + `cell_id' with the new `new_transition'. + + + Parameters + ---------- + cell_id : [cell identifier] + The cell_id object depends on the specific implementation. + It generally is an int (e.g., an index) or a tuple of indices. + transition_index : int + Index of the transition to probe, as index in the tuple returned by + get_transitions(). e.g., the NESW direction of movement, for agents + on a grid. + new_transition : int or float (depending on derived class) + Validity of the requested transition (e.g., + 0/1 allowed/not allowed, a probability in [0,1], etc...) + + """ + raise NotImplementedError() + + +class GridTransitionMap(TransitionMap): + """ + Implements a TransitionMap over a 2D grid. + + GridTransitionMap implements utility functions. + """ + + def __init__(self, width, height, transitions=Grid4Transitions([])): + """ + Builder for GridTransitionMap object. + + Parameters + ---------- + width : int + Width of the grid. + height : int + Height of the grid. + transitions_class : Transitions object + The Transitions object to use to encode/decode transitions over the + grid. + + """ + + self.width = width + self.height = height + self.transitions = transitions + + if isinstance(self.transitions, Grid4Transitions) or isinstance(self.transitions, RailEnvTransitions): + self.grid = np.ndarray((height, width), dtype=np.uint16) + elif isinstance(self.transitions, Grid8Transitions): + self.grid = np.ndarray((height, width), dtype=np.uint64) + + def get_transitions(self, cell_id): + """ + Return a tuple of transitions available in a cell specified by + `cell_id' (e.g., a tuple of size of the maximum number of transitions, + with values 0 or 1, or potentially in between, + for stochastic transitions). + + Parameters + ---------- + cell_id : tuple + The cell_id indices a cell as (column, row, orientation), + where orientation is the direction an agent is facing within a cell. + Alternatively, it can be accessed as (column, row) to return the + full cell content. + + Returns + ------- + tuple + List of the validity of transitions in the cell. + + """ + if len(cell_id) == 3: + return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2]) + elif len(cell_id) == 2: + return self.grid[cell_id[0]][cell_id[1]] + else: + print('GridTransitionMap.get_transitions() ERROR: \ + wrong cell_id tuple.') + return () + + def set_transitions(self, cell_id, new_transitions): + """ + Replaces the available transitions in cell `cell_id' with the tuple + `new_transitions'. `new_transitions' must have + one element for each possible transition. + + Parameters + ---------- + cell_id : tuple + The cell_id indices a cell as (column, row, orientation), + where orientation is the direction an agent is facing within a cell. + Alternatively, it can be accessed as (column, row) to replace the + full cell content. + new_transitions : tuple + Tuple of new transitions validitiy for the cell. + + """ + if len(cell_id) == 3: + self.transitions.set_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2], new_transitions) + elif len(cell_id) == 2: + self.grid[cell_id[0]][cell_id[1]] = new_transitions + else: + print('GridTransitionMap.get_transitions() ERROR: \ + wrong cell_id tuple.') + + def get_transition(self, cell_id, transition_index): + """ + Return the status of whether an agent in cell `cell_id' can perform a + movement along transition `transition_index (e.g., the NESW direction + of movement, for agents on a grid). + + Parameters + ---------- + cell_id : tuple + The cell_id indices a cell as (column, row, orientation), + where orientation is the direction an agent is facing within a cell. + transition_index : int + Index of the transition to probe, as index in the tuple returned by + get_transitions(). e.g., the NESW direction of movement, for agents + on a grid. + + Returns + ------- + int or float (depending on derived class) + Validity of the requested transition (e.g., + 0/1 allowed/not allowed, a probability in [0,1], etc...) + + """ + if len(cell_id) != 3: + print('GridTransitionMap.get_transition() ERROR: \ + wrong cell_id tuple.') + return () + return self.transitions.get_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index) + + def set_transition(self, cell_id, transition_index, new_transition): + """ + Replaces the validity of transition to `transition_index' in cell + `cell_id' with the new `new_transition'. + + + Parameters + ---------- + cell_id : tuple + The cell_id indices a cell as (column, row, orientation), + where orientation is the direction an agent is facing within a cell. + transition_index : int + Index of the transition to probe, as index in the tuple returned by + get_transitions(). e.g., the NESW direction of movement, for agents + on a grid. + new_transition : int or float (depending on derived class) + Validity of the requested transition (e.g., + 0/1 allowed/not allowed, a probability in [0,1], etc...) + + """ + if len(cell_id) != 3: + print('GridTransitionMap.set_transition() ERROR: \ + wrong cell_id tuple.') + return + self.transitions.set_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index, new_transition) + + +# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids +# (most general implementation) or to make Grid-class specific methods for +# slicing over the 3 dimensions? I'd say both perhaps. + +# TODO: override __getitem__ and __setitem__ (cell contents, not transitions?) diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py index 4d73d520..5b292f03 100644 --- a/flatland/utils/rail_env_generator.py +++ b/flatland/utils/rail_env_generator.py @@ -6,6 +6,7 @@ import random import numpy as np from flatland.core.transitions import RailEnvTransitions +from flatland.core.transitionmap import GridTransitionMap def generate_rail_from_manual_specifications(rail_spec): @@ -30,7 +31,7 @@ def generate_rail_from_manual_specifications(rail_spec): height = len(rail_spec) width = len(rail_spec[0]) - rail = np.zeros((height, width), dtype=np.uint16) + rail = GridTransitionMap(width=width, height=height, transitions=t_utils) for r in range(height): for c in range(width): @@ -38,8 +39,8 @@ def generate_rail_from_manual_specifications(rail_spec): if cell[0] < 0 or cell[0] >= len(t_utils.transitions): print("ERROR - invalid cell type=", cell[0]) return [] - rail[r, c] = t_utils.rotate_transition( - t_utils.transitions[cell[0]], cell[1]) + rail.set_transitions((r, c), t_utils.rotate_transition( + t_utils.transitions[cell[0]], cell[1])) return rail @@ -300,4 +301,7 @@ def generate_random_rail(width, height): if rail[r][c] is None: rail[r][c] = int('0000000000000000', 2) - return np.asarray(rail, dtype=np.uint16) + tmp_rail = np.asarray(rail, dtype=np.uint16) + return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + return_rail.grid = tmp_rail + return return_rail diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 0068928e..fc426f80 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -4,7 +4,6 @@ import numpy as np from numpy import array import xarray as xr import matplotlib.pyplot as plt -from flatland.core.transitions import RailEnvTransitions class RenderTool(object): @@ -25,7 +24,6 @@ class RenderTool(object): gCentres = xr.DataArray(gGrid, dims=["xy", "p1", "p2"], coords={"xy": ["x", "y"]}) + xyPixHalf - RETrans = RailEnvTransitions() def __init__(self, env): self.env = env @@ -56,16 +54,14 @@ class RenderTool(object): # TODO: this was `rcDir' but it was undefined rcNext = rcPos + iDir # transition for next cell - oTrans = self.env.rail[rcNext[0]][rcNext[1]] - tbTrans = RailEnvTransitions. \ - get_transitions(oTrans, iDir) + tbTrans = self.env.rail. \ + get_transitions((rcNext[0], rcNext[1], iDir)) giTrans = np.where(tbTrans)[0] # RC list of transitions gTransRCAg = self.__class__.gTransRC[giTrans] for visit in lVisits: # transition for next cell - oTrans = self.env.rail[visit.rc] - tbTrans = rt.RETrans.get_transitions(oTrans, visit.iDir) + tbTrans = self.env.rail.get_transitions((visit.rc[0], visit.rc[1], visit.iDir)) giTrans = np.where(tbTrans)[0] # RC list of transitions gTransRCAg = rt.gTransRC[giTrans] self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color) @@ -102,11 +98,9 @@ class RenderTool(object): [0, 1] # available transition indices, ie N, E ) """ - rt = self.__class__ # TODO: suggest we provide an accessor in RailEnv - oTrans = self.env.rail[rcPos] # transition for current cell - tbTrans = rt.RETrans.get_transitions(oTrans, iDir) + tbTrans = self.env.get_transitions((rcPos[0], rcPos[1], iDir)) giTrans = np.where(tbTrans)[0] # RC list of transitions # HACK: workaround dead-end transitions @@ -406,7 +400,6 @@ class RenderTool(object): ]) plt.plot(*xyArrow.T, color=sColor) - RETrans = RailEnvTransitions() env = self.env # Draw cells grid @@ -442,7 +435,7 @@ class RenderTool(object): xyCentre = array([x0, y1]) + cell_size / 2 # cell transition values - oCell = env.rail[r, c] + oCell = env.rail.get_transitions((r, c)) # Special Case 7, with a single bit; terminate at center nbits = 0 @@ -463,7 +456,7 @@ class RenderTool(object): # renderer.push() # renderer.translate(c * CELL_PIXELS, r * CELL_PIXELS) - tMoves = RETrans.get_transitions(oCell, orientation) + tMoves = env.rail.get_transitions((r, c, orientation)) # to_ori = (orientation + 2) % 4 for to_ori in range(4): diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index c68f84db..2ae69015 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -1,32 +1,20 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +""" +Tests for `flatland` package. +""" from flatland.core.env import RailEnv -#from flatland.core.transitions import GridTransitions import numpy as np import random import os -from recordtype import recordtype - -import numpy as np -from numpy import array -import xarray as xr import matplotlib.pyplot as plt -from flatland.core.transitions import RailEnvTransitions -#import flatland.core.env from flatland.utils import rail_env_generator -from flatland.core.env import RailEnv import flatland.utils.rendertools as rt - - -"""Tests for `flatland` package.""" - - - def checkFrozenImage(sFileImage): sDirRoot = "." sTmpFileImage = sDirRoot + "/images/test/" + sFileImage @@ -37,7 +25,7 @@ def checkFrozenImage(sFileImage): plt.savefig(sTmpFileImage) bytesFrozenImage = None - for sDir in [ "/images/", "/images/test/" ]: + for sDir in ["/images/", "/images/test/"]: sfPath = sDirRoot + sDir + sFileImage bytesImage = plt.imread(sfPath) if bytesFrozenImage is None: @@ -49,37 +37,34 @@ def checkFrozenImage(sFileImage): def test_render_env(): random.seed(100) - oRail = rail_env_generator.generate_random_rail(10,10) + oRail = rail_env_generator.generate_random_rail(10, 10) type(oRail), len(oRail) oEnv = RailEnv(oRail, number_of_agents=2) oEnv.reset() oRT = rt.RenderTool(oEnv) - plt.figure(figsize=(10,10)) + plt.figure(figsize=(10, 10)) oRT.renderEnv() checkFrozenImage("basic-env.png") - plt.figure(figsize=(10,10)) + plt.figure(figsize=(10, 10)) oRT.renderEnv() - + lVisits = oRT.getTreeFromRail( - oEnv.agents_position[0], - oEnv.agents_direction[0], + oEnv.agents_position[0], + oEnv.agents_direction[0], nDepth=17, bPlot=True) checkFrozenImage("env-tree-spatial.png") - - plt.figure(figsize=(8,8)) + + plt.figure(figsize=(8, 8)) xyTarg = oRT.env.agents_target[0] visitDest = oRT.plotTree(lVisits, xyTarg) checkFrozenImage("env-tree-graph.png") - - oFig = plt.figure(figsize=(10,10)) + plt.figure(figsize=(10, 10)) oRT.renderEnv() oRT.plotPath(visitDest) checkFrozenImage("env-path.png") - - -- GitLab