From 5064de8e8ecfe004f4e3a0bcb472bc62d7ad5aff Mon Sep 17 00:00:00 2001 From: spiglerg <spiglerg@gmail.com> Date: Thu, 18 Apr 2019 21:30:44 +0200 Subject: [PATCH] change in API: rail_generator --- examples/temporary_example.py | 20 +++++---- flatland/core/env.py | 49 ++++++++++++++++----- flatland/utils/rail_env_generator.py | 66 +++++++++++++++++++++------- tests/test_environments.py | 21 +++++---- tests/test_rendertools.py | 3 +- 5 files changed, 110 insertions(+), 49 deletions(-) diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 2ea68cfd..8b53b1dd 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -9,32 +9,36 @@ from flatland.utils.rendertools import * random.seed(1) np.random.seed(1) - # Example generate a random rail -rail = generate_random_rail(20, 20) - -env = RailEnv(rail, number_of_agents=10) +env = RailEnv(width=20, height=20, rail_generator=generate_random_rail, number_of_agents=10) env.reset() env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) - # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]] -rail = generate_rail_from_manual_specifications(specs) -env = RailEnv(rail, number_of_agents=1) +env = RailEnv(width=6, + height=2, + rail_generator=generate_rail_from_manual_specifications(specs), + number_of_agents=1) handle = env.get_agent_handles() -env.reset() +obs = env.reset() env.agents_position = [[1, 4]] env.agents_target = [[1, 1]] env.agents_direction = [1] +# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! +env.obs_builder.reset() + +# TODO: delete next line +#print(env.obs_builder.distance_map[0,:,:]) +#print(env.obs_builder.max_dist) env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) diff --git a/flatland/core/env.py b/flatland/core/env.py index a7e63fd4..02d912a3 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -6,6 +6,7 @@ The base Environment class is adapted from rllib.env.MultiAgentEnv import random from .env_observation_builder import TreeObsForRailEnv +from flatland.utils.rail_env_generator import generate_random_rail class Environment: @@ -121,35 +122,56 @@ class RailEnv: """ def __init__(self, - rail, + width, + height, + rail_generator=generate_random_rail, number_of_agents=1, - custom_observation_builder=TreeObsForRailEnv): + obs_builder_object=TreeObsForRailEnv(max_depth=2)): """ Environment init. Parameters ------- - rail : numpy.ndarray of type numpy.uint16 - The transition matrix that defines the environment. + rail_generator : function + The rail_generator function is a function that takes the width and + height of a rail map along with the number of times the env has + been reset, and returns a GridTransitionMap object. + Implemented functions are: + generate_random_rail : generate a random rail of given size + TODO: generate_rail_from_saved_list --- + width : int + The width of the rail map. Potentially in the future, + a range of widths to sample from. + height : int + The height of the rail map. Potentially in the future, + a range of heights to sample from. number_of_agents : int - Number of agents to spawn on the map. - custom_observation_builder: ObservationBuilder object - ObservationBuilder-derived object that takes this env object - as input as provides observation vectors for each agent. + Number of agents to spawn on the map. Potentially in the future, + a range of number of agents to sample from. + obs_builder_object: ObservationBuilder object + ObservationBuilder-derived object that takes builds observation + vectors for each agent. """ - self.rail = rail - self.width = rail.width - self.height = rail.height + self.rail_generator = rail_generator + self.num_resets = 0 + self.rail = None + self.width = width + self.height = height self.number_of_agents = number_of_agents - self.obs_builder = custom_observation_builder(env=self) + self.obs_builder = obs_builder_object + self.obs_builder.set_env(self) self.actions = [0]*self.number_of_agents self.rewards = [0]*self.number_of_agents self.done = False + self.agents_position = [] + self.agents_target = [] + self.agents_direction = [] + self.dones = {"__all__": False} self.obs_dict = {} self.rewards_dict = {} @@ -160,6 +182,9 @@ class RailEnv: return self.agents_handles def reset(self): + self.rail = self.rail_generator(self.width, self.height, self.num_resets) + self.num_resets += 1 + self.dones = {"__all__": False} for handle in self.agents_handles: self.dones[handle] = False diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py index 69e5b831..ab11ac5e 100644 --- a/flatland/utils/rail_env_generator.py +++ b/flatland/utils/rail_env_generator.py @@ -24,28 +24,62 @@ def generate_rail_from_manual_specifications(rail_spec): Returns ------- - numpy.ndarray of type numpy.uint16 - The matrix with the correct 16-bit bitmaps for each cell. + function + Generator function that always returns a GridTransitionMap object with + the matrix of correct 16-bit bitmaps for each cell. """ - t_utils = RailEnvTransitions() + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() - height = len(rail_spec) - width = len(rail_spec[0]) - rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + height = len(rail_spec) + width = len(rail_spec[0]) + rail = GridTransitionMap(width=width, height=height, transitions=t_utils) - for r in range(height): - for c in range(width): - cell = rail_spec[r][c] - if cell[0] < 0 or cell[0] >= len(t_utils.transitions): - print("ERROR - invalid cell type=", cell[0]) - return [] - rail.set_transitions((r, c), t_utils.rotate_transition( - t_utils.transitions[cell[0]], cell[1])) + for r in range(height): + for c in range(width): + cell = rail_spec[r][c] + if cell[0] < 0 or cell[0] >= len(t_utils.transitions): + print("ERROR - invalid cell type=", cell[0]) + return [] + rail.set_transitions((r, c), t_utils.rotate_transition( + t_utils.transitions[cell[0]], cell[1])) - return rail + return rail + + return generator + + +def generate_rail_from_GridTransitionMap(rail_map): + """ + Utility to convert a rail given by a GridTransitionMap map with the correct + 16-bit transitions specifications. + + Parameters + ------- + rail_map : GridTransitionMap object + GridTransitionMap object to return when the generator is called. + + Returns + ------- + function + Generator function that always returns the given `rail_map' object. + """ + def generator(width, height, num_resets=0): + return rail_map + + return generator + + +""" +def generate_rail_from_list_of_manual_specifications(list_of_specifications) + def generator(width, height, num_resets=0): + return generate_rail_from_manual_specifications(list_of_specifications) + + return generator +""" -def generate_random_rail(width, height): +def generate_random_rail(width, height, num_resets=0): """ Dummy random level generator: - fill in cells at random in [width-2, height-2] diff --git a/tests/test_environments.py b/tests/test_environments.py index 03544b08..ce9fbd4f 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -4,6 +4,7 @@ from flatland.core.env import RailEnv from flatland.core.transitions import Grid4Transitions from flatland.core.transition_map import GridTransitionMap +from flatland.utils.rail_env_generator import generate_rail_from_GridTransitionMap import numpy as np """Tests for `flatland` package.""" @@ -46,7 +47,7 @@ def test_rail_environment_single_agent(): rail = GridTransitionMap(width=3, height=3, transitions=transitions) rail.grid = rail_map - rail_env = RailEnv(rail, number_of_agents=1) + rail_env = RailEnv(width=3, height=3, rail_generator=generate_rail_from_GridTransitionMap(rail), number_of_agents=1) for _ in range(200): _ = rail_env.reset() @@ -118,7 +119,10 @@ def test_dead_end(): transitions=transitions) rail.grid = rail_map - rail_env = RailEnv(rail, number_of_agents=1) + rail_env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=generate_rail_from_GridTransitionMap(rail), + number_of_agents=1) def check_consistency(rail_env): # We run step to check that trains do not move anymore @@ -164,7 +168,10 @@ def test_dead_end(): transitions=transitions) rail.grid = rail_map - rail_env = RailEnv(rail, number_of_agents=1) + rail_env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=generate_rail_from_GridTransitionMap(rail), + number_of_agents=1) rail_env.reset() rail_env.agents_target[0] = [0, 0] @@ -177,11 +184,3 @@ def test_dead_end(): rail_env.agents_position[0] = [2, 0] rail_env.agents_direction[0] = 0 check_consistency(rail_env) - - - - - - -test_dead_end() - diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index ae9a9e18..5fecd085 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -37,8 +37,7 @@ def checkFrozenImage(sFileImage): def test_render_env(): random.seed(100) - oRail = rail_env_generator.generate_random_rail(10, 10) - oEnv = RailEnv(oRail, number_of_agents=2) + oEnv = RailEnv(width=10, height=10, rail_generator=rail_env_generator.generate_random_rail, number_of_agents=2) oEnv.reset() oRT = rt.RenderTool(oEnv) plt.figure(figsize=(10, 10)) -- GitLab