diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py new file mode 100644 index 0000000000000000000000000000000000000000..03bbe4b71330a0d2ac2fb19e0a2afb4ae2363301 --- /dev/null +++ b/examples/custom_observation_example.py @@ -0,0 +1,32 @@ +import random + +from flatland.envs.generators import random_rail_generator, random_rail_generator +from flatland.envs.rail_env import RailEnv +from flatland.utils.rendertools import RenderTool +from flatland.core.env_observation_builder import ObservationBuilder +import numpy as np + +random.seed(100) +np.random.seed(100) + +class CustomObs(ObservationBuilder): + def __init__(self): + self.observation_space = [5] + + def reset(self): + return + + def get(self, handle): + observation = handle*np.ones((5,)) + return observation + +env = RailEnv(width=7, + height=7, + rail_generator=random_rail_generator(), + number_of_agents=3, + obs_builder_object=CustomObs()) + +# Print the observation vector for each agents +obs, all_rewards, done, _ = env.step({0: 0}) +for i in range(env.get_num_agents()): + print("Agent ", i,"'s observation: ", obs[i]) diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py new file mode 100644 index 0000000000000000000000000000000000000000..71f849de24ac62656897e065341df068b5c0f6f4 --- /dev/null +++ b/examples/custom_railmap_example.py @@ -0,0 +1,37 @@ +import random + +from flatland.envs.generators import random_rail_generator, random_rail_generator +from flatland.envs.rail_env import RailEnv +from flatland.core.transitions import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap +from flatland.utils.rendertools import RenderTool +import numpy as np + +random.seed(100) +np.random.seed(100) + +def custom_rail_generator(): + def generator(width, height, num_agents=0, num_resets=0): + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + rail_array = grid_map.grid + rail_array.fill(0) + + agents_positions = [] + agents_direction = [] + agents_target = [] + + return grid_map, agents_positions, agents_direction, agents_target + return generator + +env = RailEnv(width=6, + height=4, + rail_generator=custom_rail_generator(), + number_of_agents=1) + +env.reset() + +env_renderer = RenderTool(env, gl="QT") +env_renderer.renderEnv(show=True) + +input("Press Enter to continue...") diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index e0830ff751d15d425d22fcdc6c38b5ffc68197d5..8aac0ccc97fdb619fd2feaac76fed6607dc74867 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -3,6 +3,7 @@ import random from flatland.envs.generators import random_rail_generator, random_rail_generator from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool +from flatland.core.env_observation_builder import ObservationBuilder import numpy as np random.seed(100) @@ -11,7 +12,8 @@ np.random.seed(100) env = RailEnv(width=7, height=7, rail_generator=random_rail_generator(), - number_of_agents=2) + number_of_agents=2, + obs_builder_object=TreeObsForRailEnv(max_depth=2)) # Print the distance map of each cell to the target of the first agent # for i in range(4): diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index af789f5d064f0e2b46b2328ebbfa55057bffdc14..7de6327121cf7a650bd36fad6faf54755a86b2d5 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -491,8 +491,14 @@ class GlobalObsForRailEnv(ObservationBuilder): """ def __init__(self): + self.observation_space = () super(GlobalObsForRailEnv, self).__init__() + def _set_env(self, env): + super()._set_env(env) + + self.observation_space = [4, self.env.height, self.env.width] + def reset(self): self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) for i in range(self.rail_obs.shape[0]): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 4ac00f182ba75303e4a103a9bd16ee40d26ce9d0..be3f4d6bd3fb485508b8cf16823879dfafff1ae9 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -90,6 +90,9 @@ class RailEnv(Environment): self.obs_builder = obs_builder_object self.obs_builder._set_env(self) + self.action_space = [1] + self.observation_space = self.obs_builder.observation_space # updated on resets? + self.actions = [0] * number_of_agents self.rewards = [0] * number_of_agents self.done = False @@ -112,10 +115,6 @@ class RailEnv(Environment): self.valid_positions = None - self.action_space = [1] - self.observation_space = self.obs_builder.observation_space # updated on resets? - - # no more agent_handles def get_agent_handles(self): return range(self.get_num_agents())