Skip to content
Snippets Groups Projects
Commit ef60e4ce authored by maljx's avatar maljx
Browse files

refactor GridTransition into Grid4, Grid8Transition.

parent 89aa1a38
No related branches found
No related tags found
No related merge requests found
......@@ -144,9 +144,7 @@ class RailEnv:
self.agents_handles = list(range(self.number_of_agents))
self.t_utils = RailEnvTransitions()
# TODO : bad hack for pylint 80 characters per line; shortened function
self.gtfotd = self.t_utils.get_transition_from_orientation_to_direction
self.trans = RailEnvTransitions()
def get_agent_handles(self):
return self.agents_handles
......@@ -177,7 +175,7 @@ class RailEnv:
valid_movements = []
for direction in range(4):
position = self.agents_position[i]
moves = self.t_utils.get_transitions_from_orientation(
moves = self.trans.get_transitions(
self.rail[position[0]][position[1]], direction)
for move_index in range(4):
if moves[move_index]:
......@@ -272,7 +270,7 @@ class RailEnv:
elif direction == 3:
reverse_direction = 1
valid_transition = self.gtfotd(
valid_transition = self.trans.get_transition(
self.rail[pos[0]][pos[1]],
reverse_direction,
reverse_direction)
......@@ -295,7 +293,7 @@ class RailEnv:
else:
new_cell_isValid = False
transition_isValid = self.gtfotd(
transition_isValid = self.trans.get_transition(
self.rail[pos[0]][pos[1]],
direction,
movement)
......@@ -364,7 +362,7 @@ class RailEnv:
return 1
if node not in visited:
visited.add(node)
moves = self.t_utils.get_transitions_from_orientation(
moves = self.trans.get_transitions(
self.rail[node[0][0]][node[0][1]], node[1])
for move_index in range(4):
if moves[move_index]:
......
This diff is collapsed.
......@@ -2,8 +2,8 @@
The rail_env_generator module defines provides utilities to generate env
bitmaps for the RailEnv environment.
"""
import numpy as np
import random
import numpy as np
from flatland.core.transitions import RailEnvTransitions
......@@ -82,8 +82,7 @@ def generate_random_rail(width, height):
for i in range(len(t_utils.transitions)-1): # don't include dead-ends
all_transitions = 0
for dir_ in range(4):
trans = t_utils.get_transitions_from_orientation(
t_utils.transitions[i], dir_)
trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
all_transitions |= (trans[0] << 3) | \
(trans[1] << 2) | \
(trans[2] << 1) | \
......@@ -148,8 +147,7 @@ def generate_random_rail(width, height):
max_bit = 0
for k in range(4):
max_bit |= \
t_utils.get_transition_from_orientation_to_direction(
neigh_trans, k, el[1])
t_utils.get_transition(neigh_trans, k, el[1])
if max_bit:
valid_template[el[0]] = 1
......
......@@ -58,7 +58,7 @@ class RenderTool(object):
# transition for next cell
oTrans = self.env.rail[rcNext[0]][rcNext[1]]
tbTrans = RailEnvTransitions. \
get_transitions_from_orientation(oTrans, iDir)
get_transitions(oTrans, iDir)
giTrans = np.where(tbTrans)[0] # RC list of transitions
gTransRCAg = self.__class__.gTransRC[giTrans]
......@@ -106,7 +106,7 @@ class RenderTool(object):
# TODO: suggest we provide an accessor in RailEnv
oTrans = self.env.rail[rcPos] # transition for current cell
tbTrans = rt.RETrans.get_transitions_from_orientation(oTrans, iDir)
tbTrans = rt.RETrans.get_transitions(oTrans, iDir)
giTrans = np.where(tbTrans)[0] # RC list of transitions
# HACK: workaround dead-end transitions
......@@ -363,8 +363,7 @@ class RenderTool(object):
# renderer.translate(c * CELL_PIXELS, r * CELL_PIXELS)
if True:
tMoves = RETrans.get_transitions_from_orientation(
oCell, orientation)
tMoves = RETrans.get_transitions(oCell, orientation)
# to_ori = (orientation + 2) % 4
for to_ori in range(4):
......
......@@ -2,24 +2,22 @@
# -*- coding: utf-8 -*-
from flatland.core.env import RailEnv
from flatland.core.transitions import GridTransitions
from flatland.core.transitions import Grid4Transitions
import numpy as np
import random
"""Tests for `flatland` package."""
def test_rail_environment_single_agent():
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
int('1000010000100001', 2), # Case 3 - diamond drossing
int('1001011000100001', 2), # Case 4 - single slip switch
int('1100110000110011', 2), # Case 5 - double slip switch
int('0101001000000010', 2), # Case 6 - symmetrical switch
int('0010000000000000', 2)] # Case 7 - dead end
# We instantiate the following map on a 3x3 grid
# _ _
......@@ -27,7 +25,7 @@ def test_rail_environment_single_agent():
# | | |
# \_/\_/
transitions = GridTransitions([], False)
transitions = Grid4Transitions([])
vertical_line = cells[1]
south_symmetrical_switch = cells[6]
north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
......@@ -51,7 +49,7 @@ def test_rail_environment_single_agent():
# Check that trains are always initialized at a consistent position / direction.
# They should always be able to go somewhere.
assert(transitions.get_transitions_from_orientation(
assert(transitions.get_transitions(
rail_map[rail_env.agents_position[0]],
rail_env.agents_direction[0]) != (0, 0, 0, 0))
......
......@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
"""Tests for `flatland` package."""
from flatland.core.transitions import RailEnvTransitions, GridTransitions
from flatland.core.transitions import RailEnvTransitions, Grid8Transitions
def test_valid_railenv_transitions():
......@@ -14,36 +14,36 @@ def test_valid_railenv_transitions():
# 'W': 3}
for i in range(2):
assert(rail_env_trans.get_transitions_from_orientation(
assert(rail_env_trans.get_transitions(
int('1100110000110011', 2), i) == (1, 1, 0, 0))
assert(rail_env_trans.get_transitions_from_orientation(
assert(rail_env_trans.get_transitions(
int('1100110000110011', 2), 2+i) == (0, 0, 1, 1))
no_transition_cell = int('0000000000000000', 2)
for i in range(4):
assert(rail_env_trans.get_transitions_from_orientation(
assert(rail_env_trans.get_transitions(
no_transition_cell, i) == (0, 0, 0, 0))
# Facing south, going south
north_south_transition = rail_env_trans.set_transitions_from_orientation(
north_south_transition = rail_env_trans.set_transitions(
no_transition_cell, 2, (0, 0, 1, 0))
assert(rail_env_trans.set_transition_from_orientation_to_direction(
assert(rail_env_trans.set_transition(
north_south_transition, 2, 2, 0) == no_transition_cell)
assert(rail_env_trans.get_transition_from_orientation_to_direction(
assert(rail_env_trans.get_transition(
north_south_transition, 2, 2))
# Facing north, going east
south_east_transition = \
rail_env_trans.set_transition_from_orientation_to_direction(
rail_env_trans.set_transition(
no_transition_cell, 0, 1, 1)
assert(rail_env_trans.get_transition_from_orientation_to_direction(
assert(rail_env_trans.get_transition(
south_east_transition, 0, 1))
# The opposite transitions are not feasible
assert(not rail_env_trans.get_transition_from_orientation_to_direction(
assert(not rail_env_trans.get_transition(
north_south_transition, 2, 0))
assert(not rail_env_trans.get_transition_from_orientation_to_direction(
assert(not rail_env_trans.get_transition(
south_east_transition, 2, 1))
east_west_transition = rail_env_trans.rotate_transition(
......@@ -52,10 +52,10 @@ def test_valid_railenv_transitions():
south_east_transition, 180)
# Facing west, going west
assert(rail_env_trans.get_transition_from_orientation_to_direction(
assert(rail_env_trans.get_transition(
east_west_transition, 3, 3))
# Facing south, going west
assert(rail_env_trans.get_transition_from_orientation_to_direction(
assert(rail_env_trans.get_transition(
north_west_transition, 2, 3))
assert(south_east_transition == rail_env_trans.rotate_transition(
......@@ -63,16 +63,16 @@ def test_valid_railenv_transitions():
def test_diagonal_transitions():
diagonal_trans_env = GridTransitions([], True)
diagonal_trans_env = Grid8Transitions([])
# Facing north, going north-east
south_northeast_transition = int('01000000' + '0'*8*7, 2)
assert(diagonal_trans_env.get_transitions_from_orientation(
assert(diagonal_trans_env.get_transitions(
south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0))
# Allowing transition from north to southwest: Facing south, going SW
north_southwest_transition = \
diagonal_trans_env.set_transitions_from_orientation(
diagonal_trans_env.set_transitions(
int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0))
assert(diagonal_trans_env.rotate_transition(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment