Skip to content
Snippets Groups Projects
Commit 5064de8e authored by spiglerg's avatar spiglerg
Browse files

change in API: rail_generator

parent c794eb4e
No related branches found
No related tags found
No related merge requests found
...@@ -9,32 +9,36 @@ from flatland.utils.rendertools import * ...@@ -9,32 +9,36 @@ from flatland.utils.rendertools import *
random.seed(1) random.seed(1)
np.random.seed(1) np.random.seed(1)
# Example generate a random rail # Example generate a random rail
rail = generate_random_rail(20, 20) env = RailEnv(width=20, height=20, rail_generator=generate_random_rail, number_of_agents=10)
env = RailEnv(rail, number_of_agents=10)
env.reset() env.reset()
env_renderer = RenderTool(env) env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True)
# Example generate a rail given a manual specification, # Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation) # a map of tuples (cell_type, rotation)
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]] [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
rail = generate_rail_from_manual_specifications(specs) env = RailEnv(width=6,
env = RailEnv(rail, number_of_agents=1) height=2,
rail_generator=generate_rail_from_manual_specifications(specs),
number_of_agents=1)
handle = env.get_agent_handles() handle = env.get_agent_handles()
env.reset() obs = env.reset()
env.agents_position = [[1, 4]] env.agents_position = [[1, 4]]
env.agents_target = [[1, 1]] env.agents_target = [[1, 1]]
env.agents_direction = [1] env.agents_direction = [1]
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
# TODO: delete next line
#print(env.obs_builder.distance_map[0,:,:])
#print(env.obs_builder.max_dist)
env_renderer = RenderTool(env) env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True)
......
...@@ -6,6 +6,7 @@ The base Environment class is adapted from rllib.env.MultiAgentEnv ...@@ -6,6 +6,7 @@ The base Environment class is adapted from rllib.env.MultiAgentEnv
import random import random
from .env_observation_builder import TreeObsForRailEnv from .env_observation_builder import TreeObsForRailEnv
from flatland.utils.rail_env_generator import generate_random_rail
class Environment: class Environment:
...@@ -121,35 +122,56 @@ class RailEnv: ...@@ -121,35 +122,56 @@ class RailEnv:
""" """
def __init__(self, def __init__(self,
rail, width,
height,
rail_generator=generate_random_rail,
number_of_agents=1, number_of_agents=1,
custom_observation_builder=TreeObsForRailEnv): obs_builder_object=TreeObsForRailEnv(max_depth=2)):
""" """
Environment init. Environment init.
Parameters Parameters
------- -------
rail : numpy.ndarray of type numpy.uint16 rail_generator : function
The transition matrix that defines the environment. The rail_generator function is a function that takes the width and
height of a rail map along with the number of times the env has
been reset, and returns a GridTransitionMap object.
Implemented functions are:
generate_random_rail : generate a random rail of given size
TODO: generate_rail_from_saved_list ---
width : int
The width of the rail map. Potentially in the future,
a range of widths to sample from.
height : int
The height of the rail map. Potentially in the future,
a range of heights to sample from.
number_of_agents : int number_of_agents : int
Number of agents to spawn on the map. Number of agents to spawn on the map. Potentially in the future,
custom_observation_builder: ObservationBuilder object a range of number of agents to sample from.
ObservationBuilder-derived object that takes this env object obs_builder_object: ObservationBuilder object
as input as provides observation vectors for each agent. ObservationBuilder-derived object that takes builds observation
vectors for each agent.
""" """
self.rail = rail self.rail_generator = rail_generator
self.width = rail.width self.num_resets = 0
self.height = rail.height self.rail = None
self.width = width
self.height = height
self.number_of_agents = number_of_agents self.number_of_agents = number_of_agents
self.obs_builder = custom_observation_builder(env=self) self.obs_builder = obs_builder_object
self.obs_builder.set_env(self)
self.actions = [0]*self.number_of_agents self.actions = [0]*self.number_of_agents
self.rewards = [0]*self.number_of_agents self.rewards = [0]*self.number_of_agents
self.done = False self.done = False
self.agents_position = []
self.agents_target = []
self.agents_direction = []
self.dones = {"__all__": False} self.dones = {"__all__": False}
self.obs_dict = {} self.obs_dict = {}
self.rewards_dict = {} self.rewards_dict = {}
...@@ -160,6 +182,9 @@ class RailEnv: ...@@ -160,6 +182,9 @@ class RailEnv:
return self.agents_handles return self.agents_handles
def reset(self): def reset(self):
self.rail = self.rail_generator(self.width, self.height, self.num_resets)
self.num_resets += 1
self.dones = {"__all__": False} self.dones = {"__all__": False}
for handle in self.agents_handles: for handle in self.agents_handles:
self.dones[handle] = False self.dones[handle] = False
......
...@@ -24,28 +24,62 @@ def generate_rail_from_manual_specifications(rail_spec): ...@@ -24,28 +24,62 @@ def generate_rail_from_manual_specifications(rail_spec):
Returns Returns
------- -------
numpy.ndarray of type numpy.uint16 function
The matrix with the correct 16-bit bitmaps for each cell. Generator function that always returns a GridTransitionMap object with
the matrix of correct 16-bit bitmaps for each cell.
""" """
t_utils = RailEnvTransitions() def generator(width, height, num_resets=0):
t_utils = RailEnvTransitions()
height = len(rail_spec) height = len(rail_spec)
width = len(rail_spec[0]) width = len(rail_spec[0])
rail = GridTransitionMap(width=width, height=height, transitions=t_utils) rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
for r in range(height): for r in range(height):
for c in range(width): for c in range(width):
cell = rail_spec[r][c] cell = rail_spec[r][c]
if cell[0] < 0 or cell[0] >= len(t_utils.transitions): if cell[0] < 0 or cell[0] >= len(t_utils.transitions):
print("ERROR - invalid cell type=", cell[0]) print("ERROR - invalid cell type=", cell[0])
return [] return []
rail.set_transitions((r, c), t_utils.rotate_transition( rail.set_transitions((r, c), t_utils.rotate_transition(
t_utils.transitions[cell[0]], cell[1])) t_utils.transitions[cell[0]], cell[1]))
return rail return rail
return generator
def generate_rail_from_GridTransitionMap(rail_map):
"""
Utility to convert a rail given by a GridTransitionMap map with the correct
16-bit transitions specifications.
Parameters
-------
rail_map : GridTransitionMap object
GridTransitionMap object to return when the generator is called.
Returns
-------
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_resets=0):
return rail_map
return generator
"""
def generate_rail_from_list_of_manual_specifications(list_of_specifications)
def generator(width, height, num_resets=0):
return generate_rail_from_manual_specifications(list_of_specifications)
return generator
"""
def generate_random_rail(width, height): def generate_random_rail(width, height, num_resets=0):
""" """
Dummy random level generator: Dummy random level generator:
- fill in cells at random in [width-2, height-2] - fill in cells at random in [width-2, height-2]
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from flatland.core.env import RailEnv from flatland.core.env import RailEnv
from flatland.core.transitions import Grid4Transitions from flatland.core.transitions import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.utils.rail_env_generator import generate_rail_from_GridTransitionMap
import numpy as np import numpy as np
"""Tests for `flatland` package.""" """Tests for `flatland` package."""
...@@ -46,7 +47,7 @@ def test_rail_environment_single_agent(): ...@@ -46,7 +47,7 @@ def test_rail_environment_single_agent():
rail = GridTransitionMap(width=3, height=3, transitions=transitions) rail = GridTransitionMap(width=3, height=3, transitions=transitions)
rail.grid = rail_map rail.grid = rail_map
rail_env = RailEnv(rail, number_of_agents=1) rail_env = RailEnv(width=3, height=3, rail_generator=generate_rail_from_GridTransitionMap(rail), number_of_agents=1)
for _ in range(200): for _ in range(200):
_ = rail_env.reset() _ = rail_env.reset()
...@@ -118,7 +119,10 @@ def test_dead_end(): ...@@ -118,7 +119,10 @@ def test_dead_end():
transitions=transitions) transitions=transitions)
rail.grid = rail_map rail.grid = rail_map
rail_env = RailEnv(rail, number_of_agents=1) rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=generate_rail_from_GridTransitionMap(rail),
number_of_agents=1)
def check_consistency(rail_env): def check_consistency(rail_env):
# We run step to check that trains do not move anymore # We run step to check that trains do not move anymore
...@@ -164,7 +168,10 @@ def test_dead_end(): ...@@ -164,7 +168,10 @@ def test_dead_end():
transitions=transitions) transitions=transitions)
rail.grid = rail_map rail.grid = rail_map
rail_env = RailEnv(rail, number_of_agents=1) rail_env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=generate_rail_from_GridTransitionMap(rail),
number_of_agents=1)
rail_env.reset() rail_env.reset()
rail_env.agents_target[0] = [0, 0] rail_env.agents_target[0] = [0, 0]
...@@ -177,11 +184,3 @@ def test_dead_end(): ...@@ -177,11 +184,3 @@ def test_dead_end():
rail_env.agents_position[0] = [2, 0] rail_env.agents_position[0] = [2, 0]
rail_env.agents_direction[0] = 0 rail_env.agents_direction[0] = 0
check_consistency(rail_env) check_consistency(rail_env)
test_dead_end()
...@@ -37,8 +37,7 @@ def checkFrozenImage(sFileImage): ...@@ -37,8 +37,7 @@ def checkFrozenImage(sFileImage):
def test_render_env(): def test_render_env():
random.seed(100) random.seed(100)
oRail = rail_env_generator.generate_random_rail(10, 10) oEnv = RailEnv(width=10, height=10, rail_generator=rail_env_generator.generate_random_rail, number_of_agents=2)
oEnv = RailEnv(oRail, number_of_agents=2)
oEnv.reset() oEnv.reset()
oRT = rt.RenderTool(oEnv) oRT = rt.RenderTool(oEnv)
plt.figure(figsize=(10, 10)) plt.figure(figsize=(10, 10))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment