From 97a275598dc03158ecc5ebf6b850a164feb3dc87 Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Thu, 23 May 2019 20:14:30 +0200 Subject: [PATCH] part3 of getting started, new custom_obs and custom_rail examples, fixes to GlobalObs --- examples/custom_observation_example.py | 32 ++++++++++++++++++++++ examples/custom_railmap_example.py | 37 ++++++++++++++++++++++++++ examples/simple_example_3.py | 4 ++- flatland/envs/observations.py | 6 +++++ flatland/envs/rail_env.py | 7 +++-- 5 files changed, 81 insertions(+), 5 deletions(-) create mode 100644 examples/custom_observation_example.py create mode 100644 examples/custom_railmap_example.py diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py new file mode 100644 index 00000000..03bbe4b7 --- /dev/null +++ b/examples/custom_observation_example.py @@ -0,0 +1,32 @@ +import random + +from flatland.envs.generators import random_rail_generator, random_rail_generator +from flatland.envs.rail_env import RailEnv +from flatland.utils.rendertools import RenderTool +from flatland.core.env_observation_builder import ObservationBuilder +import numpy as np + +random.seed(100) +np.random.seed(100) + +class CustomObs(ObservationBuilder): + def __init__(self): + self.observation_space = [5] + + def reset(self): + return + + def get(self, handle): + observation = handle*np.ones((5,)) + return observation + +env = RailEnv(width=7, + height=7, + rail_generator=random_rail_generator(), + number_of_agents=3, + obs_builder_object=CustomObs()) + +# Print the observation vector for each agents +obs, all_rewards, done, _ = env.step({0: 0}) +for i in range(env.get_num_agents()): + print("Agent ", i,"'s observation: ", obs[i]) diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py new file mode 100644 index 00000000..71f849de --- /dev/null +++ b/examples/custom_railmap_example.py @@ -0,0 +1,37 @@ +import random + +from flatland.envs.generators import random_rail_generator, random_rail_generator +from flatland.envs.rail_env import RailEnv +from flatland.core.transitions import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap +from flatland.utils.rendertools import RenderTool +import numpy as np + +random.seed(100) +np.random.seed(100) + +def custom_rail_generator(): + def generator(width, height, num_agents=0, num_resets=0): + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + rail_array = grid_map.grid + rail_array.fill(0) + + agents_positions = [] + agents_direction = [] + agents_target = [] + + return grid_map, agents_positions, agents_direction, agents_target + return generator + +env = RailEnv(width=6, + height=4, + rail_generator=custom_rail_generator(), + number_of_agents=1) + +env.reset() + +env_renderer = RenderTool(env, gl="QT") +env_renderer.renderEnv(show=True) + +input("Press Enter to continue...") diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index e0830ff7..8aac0ccc 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -3,6 +3,7 @@ import random from flatland.envs.generators import random_rail_generator, random_rail_generator from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool +from flatland.core.env_observation_builder import ObservationBuilder import numpy as np random.seed(100) @@ -11,7 +12,8 @@ np.random.seed(100) env = RailEnv(width=7, height=7, rail_generator=random_rail_generator(), - number_of_agents=2) + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2)) # Print the distance map of each cell to the target of the first agent # for i in range(4): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index af789f5d..7de63271 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -491,8 +491,14 @@ class GlobalObsForRailEnv(ObservationBuilder): """ def __init__(self): + self.observation_space = () super(GlobalObsForRailEnv, self).__init__() + def _set_env(self, env): + super()._set_env(env) + + self.observation_space = [4, self.env.height, self.env.width] + def reset(self): self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) for i in range(self.rail_obs.shape[0]): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 4ac00f18..be3f4d6b 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -90,6 +90,9 @@ class RailEnv(Environment): self.obs_builder = obs_builder_object self.obs_builder._set_env(self) + self.action_space = [1] + self.observation_space = self.obs_builder.observation_space # updated on resets? + self.actions = [0] * number_of_agents self.rewards = [0] * number_of_agents self.done = False @@ -112,10 +115,6 @@ class RailEnv(Environment): self.valid_positions = None - self.action_space = [1] - self.observation_space = self.obs_builder.observation_space # updated on resets? - - # no more agent_handles def get_agent_handles(self): return range(self.get_num_agents()) -- GitLab