Skip to content
Snippets Groups Projects
Commit 25bb98a0 authored by spiglerg's avatar spiglerg
Browse files

Added TransitionMap objects + replaced relevant use of grid cell

parent 2374d1b9
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,7 @@ env = RailEnv(rail, number_of_agents=10)
env.reset()
env_renderer = RenderTool(env)
env_renderer.renderEnv()
env_renderer.renderEnv(show=True)
# Example generate a rail given a manual specification,
......@@ -37,7 +37,7 @@ env.agents_target = [[1, 1]]
env.agents_direction = [1]
env_renderer = RenderTool(env)
env_renderer.renderEnv()
env_renderer.renderEnv(show=True)
print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
......@@ -64,4 +64,4 @@ for step in range(100):
i = i+1
i += 1
env_renderer.renderEnv()
env_renderer.renderEnv(show=True)
......@@ -5,8 +5,6 @@ The base Environment class is adapted from rllib.env.MultiAgentEnv
"""
import random
from .transitions import RailEnvTransitions
class Environment:
"""
......@@ -133,8 +131,8 @@ class RailEnv:
"""
self.rail = rail
self.width = len(self.rail[0])
self.height = len(self.rail)
self.width = rail.width
self.height = rail.height
self.number_of_agents = number_of_agents
......@@ -144,8 +142,6 @@ class RailEnv:
self.agents_handles = list(range(self.number_of_agents))
self.trans = RailEnvTransitions()
def get_agent_handles(self):
return self.agents_handles
......@@ -159,7 +155,7 @@ class RailEnv:
valid_positions = []
for r in range(self.height):
for c in range(self.width):
if self.rail[r][c] > 0:
if self.rail.get_transitions((r, c)) > 0:
valid_positions.append((r, c))
self.agents_position = random.sample(valid_positions,
......@@ -175,8 +171,8 @@ class RailEnv:
valid_movements = []
for direction in range(4):
position = self.agents_position[i]
moves = self.trans.get_transitions(
self.rail[position[0]][position[1]], direction)
moves = self.rail.get_transitions(
(position[0], position[1], direction))
for move_index in range(4):
if moves[move_index]:
valid_movements.append((direction, move_index))
......@@ -251,8 +247,9 @@ class RailEnv:
if action == 2:
# compute number of possible transitions in the current
# cell
is_deadend = False
nbits = 0
tmp = self.rail[pos[0]][pos[1]]
tmp = self.rail.get_transitions((pos[0], pos[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
......@@ -270,14 +267,13 @@ class RailEnv:
elif direction == 3:
reverse_direction = 1
valid_transition = self.trans.get_transition(
self.rail[pos[0]][pos[1]],
reverse_direction,
valid_transition = self.rail.get_transition(
(pos[0], pos[1], direction),
reverse_direction)
if valid_transition:
direction = reverse_direction
movement = direction
movement = reverse_direction
is_deadend = True
new_position = self._new_position(pos, movement)
......@@ -289,15 +285,14 @@ class RailEnv:
new_position[0] < 0 or new_position[1] < 0:
new_cell_isValid = False
elif self.rail[new_position[0]][new_position[1]] > 0:
elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
new_cell_isValid = True
else:
new_cell_isValid = False
transition_isValid = self.trans.get_transition(
self.rail[pos[0]][pos[1]],
direction,
movement)
transition_isValid = self.rail.get_transition(
(pos[0], pos[1], direction),
movement) or is_deadend
cell_isFree = True
for j in range(self.number_of_agents):
......@@ -363,8 +358,7 @@ class RailEnv:
return 1
if node not in visited:
visited.add(node)
moves = self.trans.get_transitions(
self.rail[node[0][0]][node[0][1]], node[1])
moves = self.rail.get_transitions((node[0][0], node[0][1], node[1]))
for move_index in range(4):
if moves[move_index]:
stack.append((self._new_position(node[0], move_index),
......@@ -373,7 +367,7 @@ class RailEnv:
# If cell is a dead-end, append previous node with reversed
# orientation!
nbits = 0
tmp = self.rail[node[0][0]][node[0][1]]
tmp = self.rail.get_transitions((node[0][0], node[0][1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
......
"""
TransitionMap and derived classes.
"""
import numpy as np
from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions
class TransitionMap:
"""
Base TransitionMap class.
Generic class that implements a collection of transitions over a set of
cells.
"""
def get_transitions(self, cell_id):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
Parameters
----------
cell_id : [cell identifier]
The cell_id object depends on the specific implementation.
It generally is an int (e.g., an index) or a tuple of indices.
Returns
-------
tuple
List of the validity of transitions in the cell.
"""
raise NotImplementedError()
def set_transitions(self, cell_id, new_transitions):
"""
Replaces the available transitions in cell `cell_id' with the tuple
`new_transitions'. `new_transitions' must have
one element for each possible transition.
Parameters
----------
cell_id : [cell identifier]
The cell_id object depends on the specific implementation.
It generally is an int (e.g., an index) or a tuple of indices.
new_transitions : tuple
Tuple of new transitions validitiy for the cell.
"""
raise NotImplementedError()
def get_transition(self, cell_id, transition_index):
"""
Return the status of whether an agent in cell `cell_id' can perform a
movement along transition `transition_index (e.g., the NESW direction
of movement, for agents on a grid).
Parameters
----------
cell_id : [cell identifier]
The cell_id object depends on the specific implementation.
It generally is an int (e.g., an index) or a tuple of indices.
transition_index : int
Index of the transition to probe, as index in the tuple returned by
get_transitions(). e.g., the NESW direction of movement, for agents
on a grid.
Returns
-------
int or float (depending on derived class)
Validity of the requested transition (e.g.,
0/1 allowed/not allowed, a probability in [0,1], etc...)
"""
raise NotImplementedError()
def set_transition(self, cell_id, transition_index, new_transition):
"""
Replaces the validity of transition to `transition_index' in cell
`cell_id' with the new `new_transition'.
Parameters
----------
cell_id : [cell identifier]
The cell_id object depends on the specific implementation.
It generally is an int (e.g., an index) or a tuple of indices.
transition_index : int
Index of the transition to probe, as index in the tuple returned by
get_transitions(). e.g., the NESW direction of movement, for agents
on a grid.
new_transition : int or float (depending on derived class)
Validity of the requested transition (e.g.,
0/1 allowed/not allowed, a probability in [0,1], etc...)
"""
raise NotImplementedError()
class GridTransitionMap(TransitionMap):
"""
Implements a TransitionMap over a 2D grid.
GridTransitionMap implements utility functions.
"""
def __init__(self, width, height, transitions=Grid4Transitions([])):
"""
Builder for GridTransitionMap object.
Parameters
----------
width : int
Width of the grid.
height : int
Height of the grid.
transitions_class : Transitions object
The Transitions object to use to encode/decode transitions over the
grid.
"""
self.width = width
self.height = height
self.transitions = transitions
if isinstance(self.transitions, Grid4Transitions) or isinstance(self.transitions, RailEnvTransitions):
self.grid = np.ndarray((height, width), dtype=np.uint16)
elif isinstance(self.transitions, Grid8Transitions):
self.grid = np.ndarray((height, width), dtype=np.uint64)
def get_transitions(self, cell_id):
"""
Return a tuple of transitions available in a cell specified by
`cell_id' (e.g., a tuple of size of the maximum number of transitions,
with values 0 or 1, or potentially in between,
for stochastic transitions).
Parameters
----------
cell_id : tuple
The cell_id indices a cell as (column, row, orientation),
where orientation is the direction an agent is facing within a cell.
Alternatively, it can be accessed as (column, row) to return the
full cell content.
Returns
-------
tuple
List of the validity of transitions in the cell.
"""
if len(cell_id) == 3:
return self.transitions.get_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2])
elif len(cell_id) == 2:
return self.grid[cell_id[0]][cell_id[1]]
else:
print('GridTransitionMap.get_transitions() ERROR: \
wrong cell_id tuple.')
return ()
def set_transitions(self, cell_id, new_transitions):
"""
Replaces the available transitions in cell `cell_id' with the tuple
`new_transitions'. `new_transitions' must have
one element for each possible transition.
Parameters
----------
cell_id : tuple
The cell_id indices a cell as (column, row, orientation),
where orientation is the direction an agent is facing within a cell.
Alternatively, it can be accessed as (column, row) to replace the
full cell content.
new_transitions : tuple
Tuple of new transitions validitiy for the cell.
"""
if len(cell_id) == 3:
self.transitions.set_transitions(self.grid[cell_id[0]][cell_id[1]], cell_id[2], new_transitions)
elif len(cell_id) == 2:
self.grid[cell_id[0]][cell_id[1]] = new_transitions
else:
print('GridTransitionMap.get_transitions() ERROR: \
wrong cell_id tuple.')
def get_transition(self, cell_id, transition_index):
"""
Return the status of whether an agent in cell `cell_id' can perform a
movement along transition `transition_index (e.g., the NESW direction
of movement, for agents on a grid).
Parameters
----------
cell_id : tuple
The cell_id indices a cell as (column, row, orientation),
where orientation is the direction an agent is facing within a cell.
transition_index : int
Index of the transition to probe, as index in the tuple returned by
get_transitions(). e.g., the NESW direction of movement, for agents
on a grid.
Returns
-------
int or float (depending on derived class)
Validity of the requested transition (e.g.,
0/1 allowed/not allowed, a probability in [0,1], etc...)
"""
if len(cell_id) != 3:
print('GridTransitionMap.get_transition() ERROR: \
wrong cell_id tuple.')
return ()
return self.transitions.get_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index)
def set_transition(self, cell_id, transition_index, new_transition):
"""
Replaces the validity of transition to `transition_index' in cell
`cell_id' with the new `new_transition'.
Parameters
----------
cell_id : tuple
The cell_id indices a cell as (column, row, orientation),
where orientation is the direction an agent is facing within a cell.
transition_index : int
Index of the transition to probe, as index in the tuple returned by
get_transitions(). e.g., the NESW direction of movement, for agents
on a grid.
new_transition : int or float (depending on derived class)
Validity of the requested transition (e.g.,
0/1 allowed/not allowed, a probability in [0,1], etc...)
"""
if len(cell_id) != 3:
print('GridTransitionMap.set_transition() ERROR: \
wrong cell_id tuple.')
return
self.transitions.set_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index, new_transition)
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
# (most general implementation) or to make Grid-class specific methods for
# slicing over the 3 dimensions? I'd say both perhaps.
# TODO: override __getitem__ and __setitem__ (cell contents, not transitions?)
......@@ -6,6 +6,7 @@ import random
import numpy as np
from flatland.core.transitions import RailEnvTransitions
from flatland.core.transitionmap import GridTransitionMap
def generate_rail_from_manual_specifications(rail_spec):
......@@ -30,7 +31,7 @@ def generate_rail_from_manual_specifications(rail_spec):
height = len(rail_spec)
width = len(rail_spec[0])
rail = np.zeros((height, width), dtype=np.uint16)
rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
for r in range(height):
for c in range(width):
......@@ -38,8 +39,8 @@ def generate_rail_from_manual_specifications(rail_spec):
if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
print("ERROR - invalid cell type=", cell[0])
return []
rail[r, c] = t_utils.rotate_transition(
t_utils.transitions[cell[0]], cell[1])
rail.set_transitions((r, c), t_utils.rotate_transition(
t_utils.transitions[cell[0]], cell[1]))
return rail
......@@ -300,4 +301,7 @@ def generate_random_rail(width, height):
if rail[r][c] is None:
rail[r][c] = int('0000000000000000', 2)
return np.asarray(rail, dtype=np.uint16)
tmp_rail = np.asarray(rail, dtype=np.uint16)
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail
return return_rail
......@@ -4,7 +4,6 @@ import numpy as np
from numpy import array
import xarray as xr
import matplotlib.pyplot as plt
from flatland.core.transitions import RailEnvTransitions
class RenderTool(object):
......@@ -25,7 +24,6 @@ class RenderTool(object):
gCentres = xr.DataArray(gGrid,
dims=["xy", "p1", "p2"],
coords={"xy": ["x", "y"]}) + xyPixHalf
RETrans = RailEnvTransitions()
def __init__(self, env):
self.env = env
......@@ -56,16 +54,14 @@ class RenderTool(object):
# TODO: this was `rcDir' but it was undefined
rcNext = rcPos + iDir
# transition for next cell
oTrans = self.env.rail[rcNext[0]][rcNext[1]]
tbTrans = RailEnvTransitions. \
get_transitions(oTrans, iDir)
tbTrans = self.env.rail. \
get_transitions((rcNext[0], rcNext[1], iDir))
giTrans = np.where(tbTrans)[0] # RC list of transitions
gTransRCAg = self.__class__.gTransRC[giTrans]
for visit in lVisits:
# transition for next cell
oTrans = self.env.rail[visit.rc]
tbTrans = rt.RETrans.get_transitions(oTrans, visit.iDir)
tbTrans = self.env.rail.get_transitions((visit.rc[0], visit.rc[1], visit.iDir))
giTrans = np.where(tbTrans)[0] # RC list of transitions
gTransRCAg = rt.gTransRC[giTrans]
self.plotTrans(visit.rc, gTransRCAg, depth=str(visit.iDepth), color=color)
......@@ -102,11 +98,9 @@ class RenderTool(object):
[0, 1] # available transition indices, ie N, E
)
"""
rt = self.__class__
# TODO: suggest we provide an accessor in RailEnv
oTrans = self.env.rail[rcPos] # transition for current cell
tbTrans = rt.RETrans.get_transitions(oTrans, iDir)
tbTrans = self.env.get_transitions((rcPos[0], rcPos[1], iDir))
giTrans = np.where(tbTrans)[0] # RC list of transitions
# HACK: workaround dead-end transitions
......@@ -406,7 +400,6 @@ class RenderTool(object):
])
plt.plot(*xyArrow.T, color=sColor)
RETrans = RailEnvTransitions()
env = self.env
# Draw cells grid
......@@ -442,7 +435,7 @@ class RenderTool(object):
xyCentre = array([x0, y1]) + cell_size / 2
# cell transition values
oCell = env.rail[r, c]
oCell = env.rail.get_transitions((r, c))
# Special Case 7, with a single bit; terminate at center
nbits = 0
......@@ -463,7 +456,7 @@ class RenderTool(object):
# renderer.push()
# renderer.translate(c * CELL_PIXELS, r * CELL_PIXELS)
tMoves = RETrans.get_transitions(oCell, orientation)
tMoves = env.rail.get_transitions((r, c, orientation))
# to_ori = (orientation + 2) % 4
for to_ori in range(4):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Tests for `flatland` package.
"""
from flatland.core.env import RailEnv
#from flatland.core.transitions import GridTransitions
import numpy as np
import random
import os
from recordtype import recordtype
import numpy as np
from numpy import array
import xarray as xr
import matplotlib.pyplot as plt
from flatland.core.transitions import RailEnvTransitions
#import flatland.core.env
from flatland.utils import rail_env_generator
from flatland.core.env import RailEnv
import flatland.utils.rendertools as rt
"""Tests for `flatland` package."""
def checkFrozenImage(sFileImage):
sDirRoot = "."
sTmpFileImage = sDirRoot + "/images/test/" + sFileImage
......@@ -37,7 +25,7 @@ def checkFrozenImage(sFileImage):
plt.savefig(sTmpFileImage)
bytesFrozenImage = None
for sDir in [ "/images/", "/images/test/" ]:
for sDir in ["/images/", "/images/test/"]:
sfPath = sDirRoot + sDir + sFileImage
bytesImage = plt.imread(sfPath)
if bytesFrozenImage is None:
......@@ -49,37 +37,34 @@ def checkFrozenImage(sFileImage):
def test_render_env():
random.seed(100)
oRail = rail_env_generator.generate_random_rail(10,10)
oRail = rail_env_generator.generate_random_rail(10, 10)
type(oRail), len(oRail)
oEnv = RailEnv(oRail, number_of_agents=2)
oEnv.reset()
oRT = rt.RenderTool(oEnv)
plt.figure(figsize=(10,10))
plt.figure(figsize=(10, 10))
oRT.renderEnv()
checkFrozenImage("basic-env.png")
plt.figure(figsize=(10,10))
plt.figure(figsize=(10, 10))
oRT.renderEnv()
lVisits = oRT.getTreeFromRail(
oEnv.agents_position[0],
oEnv.agents_direction[0],
oEnv.agents_position[0],
oEnv.agents_direction[0],
nDepth=17, bPlot=True)
checkFrozenImage("env-tree-spatial.png")
plt.figure(figsize=(8,8))
plt.figure(figsize=(8, 8))
xyTarg = oRT.env.agents_target[0]
visitDest = oRT.plotTree(lVisits, xyTarg)
checkFrozenImage("env-tree-graph.png")
oFig = plt.figure(figsize=(10,10))
plt.figure(figsize=(10, 10))
oRT.renderEnv()
oRT.plotPath(visitDest)
checkFrozenImage("env-path.png")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment