diff --git a/docs/gettingstarted.rst b/docs/gettingstarted.rst
new file mode 100644
index 0000000000000000000000000000000000000000..2d545b24f1703c153fc3f66739fb9c9d7538d1b6
--- /dev/null
+++ b/docs/gettingstarted.rst
@@ -0,0 +1,249 @@
+===============
+Getting Started
+===============
+
+Overview
+--------------
+
+The following are three short tutorials to help new users get acquainted with
+how to create RailEnvs, how to train simple DQN agents on them, and how to
+customize them.
+
+To use flatland in a project:
+
+.. code-block:: python
+
+    import flatland
+
+
+Part 1 : Basic Usage
+--------------------
+
+The basic usage of RailEnv environments consists of creating a RailEnv object
+endowed with a rail generator, which generates new rail networks on each reset,
+and an observation builder object, which is supplied with environment-specific
+information at each time step and provides a suitable observation vector to the
+agents.
+
+The simplest rail generators are envs.generators.rail_from_manual_specifications_generator
+and envs.generators.random_rail_generator.
+
+The first one accepts a list of lists in which each element is a 2-tuple whose
+entries represent the 'cell_type' (see core.transitions.RailEnvTransitions) and
+the desired clockwise rotation of the cell contents (0, 90, 180 or 270 degrees).
+For example,
+
+.. code-block:: python
+
+    specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
+             [(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
+             [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)],
+             [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
+
+    env = RailEnv(width=6,
+                  height=4,
+                  rail_generator=rail_from_manual_specifications_generator(specs),
+                  number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2))
+
+Alternatively, a random environment can be generated (optionally specifying
+weights for each cell type to increase or decrease their proportion in the
+generated rail networks).
+
+.. code-block:: python
+
+    # Relative weights of each cell type to be used by the random rail generators.
+    transition_probability = [1.0,  # empty cell - Case 0
+                              1.0,  # Case 1 - straight
+                              1.0,  # Case 2 - simple switch
+                              0.3,  # Case 3 - diamond crossing
+                              0.5,  # Case 4 - single slip
+                              0.5,  # Case 5 - double slip
+                              0.2,  # Case 6 - symmetrical
+                              0.0,  # Case 7 - dead end
+                              0.2,  # Case 8 - turn left
+                              0.2,  # Case 9 - turn right
+                              1.0]  # Case 10 - mirrored switch
+
+    # Example: generate a random rail
+    env = RailEnv(width=10,
+                  height=10,
+                  rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+                  number_of_agents=3,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2))
+
+Environments can be rendered using the utils.rendertools utilities, for example:
+
+.. code-block:: python
+
+    env_renderer = RenderTool(env, gl="QT")
+    env_renderer.renderEnv(show=True)
+
+
+Finally, the environment can be run by supplying the environment step function
+with a dictionary of actions whose keys are the agents' handles (returned by
+env.get_agent_handles()) and whose values are the selected actions.
+For example, for a 2-agent environment:
+
+.. code-block:: python
+
+    handles = env.get_agent_handles()
+    action_dict = {handles[0]: 0, handles[1]: 0}
+    obs, all_rewards, done, _ = env.step(action_dict)
+
+where 'obs', 'all_rewards', and 'done' are dictionaries indexed by the agents'
+handles, whose values correspond to the relevant observations, rewards and
+terminal status for each agent. Further, the 'done' dictionary contains an extra
+key '__all__' that is set to True after all agents have reached their goals.
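+
+As a minimal (purely illustrative) sketch, the loop below keeps issuing action 0
+to every agent until the '__all__' flag reports that the episode is over:
+
+.. code-block:: python
+
+    for _ in range(100):
+        # Same action for every agent at each time step
+        action_dict = {handle: 0 for handle in env.get_agent_handles()}
+        obs, all_rewards, done, _ = env.step(action_dict)
+        if done['__all__']:
+            break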
+
+
+In the specific case that a TreeObsForRailEnv observation builder is used, it is
+possible to print a representation of the returned observations with the
+following code. Also, tree observation data is displayed by RenderTool by default.
+
+.. code-block:: python
+
+    for i in range(env.get_num_agents()):
+        env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
+
+The complete code for this part of the Getting Started guide can be found in
+examples/simple_example_1.py, examples/simple_example_2.py and
+examples/simple_example_3.py.
+
+
+
+Part 2 : Training a Simple Agent on Flatland
+--------------------------------------------
+
+This is a brief tutorial on how to train an agent on Flatland.
+Here we use a simple random agent to illustrate how to interact with the environment.
+The corresponding code can be found in examples/training_example.py, and in the baselines
+repository you can find a tutorial on training a DQN agent to solve the navigation task.
+
+We start by importing the necessary Flatland libraries:
+
+.. code-block:: python
+
+    from flatland.envs.generators import complex_rail_generator
+    from flatland.envs.rail_env import RailEnv
+
+The complex_rail_generator is used to guarantee feasible railway network configurations for training.
+Next we configure the difficulty of our task by modifying the complex_rail_generator parameters.
+
+.. code-block:: python
+
+    env = RailEnv(width=15,
+                  height=15,
+                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
+                  number_of_agents=5)
+
+The difficulty of a railway network depends on the dimensions (width x height) and the number of agents in the network.
+By varying the number of start and goal connections (nr_start_goal) and the number of extra railway elements added (nr_extra),
+the number of alternative paths of each agent can be modified. The more possible paths an agent has to reach its target,
+the easier the task becomes. Here we don't specify any observation builder but rather use the standard tree observation.
+If you would like to use a custom observation, please follow the instructions in the next tutorial.
+Feel free to vary these parameters to see how your own agent holds up on different settings. The evaluation set of railway
+configurations will cover the whole spectrum from easy to complex tasks.
+
+Once the environment is set up, we can load our preferred agent, from either RLlib or any other resource.
+Here we use a random agent to illustrate the code.
+
+.. code-block:: python
+
+    agent = RandomAgent(env.action_space, env.observation_space)
+
+We start every trial by resetting the environment:
+
+.. code-block:: python
+
+    obs = env.reset()
+
+This provides the initial observation for all agents (obs is a dictionary containing the observations of all agents).
+In order for the environment to step forward in time, we need a dictionary of actions for all active agents.
+
+.. code-block:: python
+
+    action_dict = dict()
+    for handle in range(env.get_num_agents()):
+        action = agent.act(obs[handle])
+        action_dict.update({handle: action})
+
+This dictionary is then passed to the environment, which checks the validity of all actions and updates the environment state.
+
+.. code-block:: python
+
+    next_obs, all_rewards, done, _ = env.step(action_dict)
+
+The environment returns a dictionary of new observations and a reward dictionary for all agents, as well as flags
+indicating which agents are done. This information can be used to update the policy of your agent; the episode
+terminates once done['__all__'] == True.
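+
+Putting the pieces together, a condensed sketch of one training episode (reusing
+the env, agent and action_dict defined above) could look like this:
+
+.. code-block:: python
+
+    for trial in range(10):
+        obs = env.reset()
+        score = 0
+        for step in range(100):
+            for a in range(env.get_num_agents()):
+                action_dict[a] = agent.act(obs[a])
+            next_obs, all_rewards, done, _ = env.step(action_dict)
+            # Let the agent learn from each agent's transition and track the score
+            for a in range(env.get_num_agents()):
+                agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
+                score += all_rewards[a]
+            obs = next_obs.copy()
+            if done['__all__']:
+                break
+
+A complete version of this loop is given in examples/training_example.py.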
+
+Part 3 : Customizing Observations and Level Generators
+------------------------------------------------------
+
+Example code for generating custom observations given a RailEnv, and for generating
+random rail maps, is available in examples/custom_observation_example.py and
+examples/custom_railmap_example.py.
+
+Custom observations can be produced by deriving a new object from the
+core.env_observation_builder.ObservationBuilder base class, for example as follows:
+
+.. code-block:: python
+
+    class CustomObs(ObservationBuilder):
+        def __init__(self):
+            self.observation_space = [5]
+
+        def reset(self):
+            return
+
+        def get(self, handle):
+            observation = handle*np.ones((5,))
+            return observation
+
+It is important that an observation_space is defined with a list of the dimensions
+of the returned observation tensors. get(handle) returns the observation for the
+agent with the given handle.
+
+A RailEnv environment can then be created as usual:
+
+.. code-block:: python
+
+    env = RailEnv(width=7,
+                  height=7,
+                  rail_generator=random_rail_generator(),
+                  number_of_agents=3,
+                  obs_builder_object=CustomObs())
+
+As for generating custom rail maps, the RailEnv class accepts a rail_generator
+argument that must be a function with arguments 'width', 'height', 'num_agents',
+and 'num_resets=0', and that has to return a GridTransitionMap object (the rail
+map) along with three lists: the (row, column) coordinates of each of the
+num_agents agents, their initial orientations (0=North, 1=East, 2=South, 3=West),
+and the positions of their targets.
+
+For example, the following custom rail map generator returns an empty map of
+size (height, width), with no agents (regardless of num_agents):
+
+.. code-block:: python
+
+    def custom_rail_generator():
+        def generator(width, height, num_agents=0, num_resets=0):
+            rail_trans = RailEnvTransitions()
+            grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+            rail_array = grid_map.grid
+            rail_array.fill(0)
+
+            agents_positions = []
+            agents_direction = []
+            agents_target = []
+
+            return grid_map, agents_positions, agents_direction, agents_target
+        return generator
+
+It is worth noting that helpful utilities for managing RailEnv environments and their
+related data structures are available in 'envs.env_utils'. In particular,
+envs.env_utils.get_rnd_agents_pos_tgt_dir_on_rail is fairly handy for filling in
+random (but consistent) agents along with their targets and initial directions,
+given a rail map (GridTransitionMap object) and the desired number of agents:
+
+.. code-block:: python
+
+    agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
+        rail_map,
+        num_agents)
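+
+The two mechanisms combine naturally: a custom generator can build a rail map and
+then delegate agent placement to this utility. The following is only a sketch
+(note that on a completely empty grid there may be no valid agent placements, so
+in practice the grid should be filled with track first):
+
+.. code-block:: python
+
+    def custom_rail_generator_with_agents():
+        def generator(width, height, num_agents=0, num_resets=0):
+            rail_trans = RailEnvTransitions()
+            grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+            grid_map.grid.fill(0)  # in practice: add track here
+
+            # Delegate placement of mutually consistent agents to the utility
+            agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
+                grid_map,
+                num_agents)
+
+            return grid_map, agents_position, agents_direction, agents_target
+        return generator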
diff --git a/docs/index.rst b/docs/index.rst
index 96b9ee3bd7ecb2b6cb0d5642e00d099db9b549ef..f440aad83aa459e8f03540c56beb1ab8eef4a7d3 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -7,7 +7,7 @@ Welcome to flatland's documentation!
 
    readme
    installation
-   usage
+   gettingstarted
    modules
    FAQ
    contributing
diff --git a/docs/usage.rst b/docs/usage.rst
deleted file mode 100644
index 56518bc3ec24856cd752092e1c3543471aea919e..0000000000000000000000000000000000000000
--- a/docs/usage.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-=====
-Usage
-=====
-
-To use flatland in a project::
-
-    import flatland
diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c4fc8194aec9385619dc917bef9b3dd22492d47
--- /dev/null
+++ b/examples/custom_observation_example.py
@@ -0,0 +1,34 @@
+import random
+
+from flatland.envs.generators import random_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.core.env_observation_builder import ObservationBuilder
+
+import numpy as np
+
+random.seed(100)
+np.random.seed(100)
+
+
+class CustomObs(ObservationBuilder):
+    def __init__(self):
+        self.observation_space = [5]
+
+    def reset(self):
+        return
+
+    def get(self, handle):
+        observation = handle*np.ones((5,))
+        return observation
+
+
+env = RailEnv(width=7,
+              height=7,
+              rail_generator=random_rail_generator(),
+              number_of_agents=3,
+              obs_builder_object=CustomObs())
+
+# Print the observation vector for each agent
+obs, all_rewards, done, _ = env.step({0: 0})
+for i in range(env.get_num_agents()):
+    print("Agent ", i, "'s observation: ", obs[i])
diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d483c0c1acd8a802027e528f6418d36daf24759
--- /dev/null
+++ b/examples/custom_railmap_example.py
@@ -0,0 +1,38 @@
+import random
+
+from flatland.envs.rail_env import RailEnv
+from flatland.core.transitions import RailEnvTransitions
+from flatland.core.transition_map import GridTransitionMap
+from flatland.utils.rendertools import RenderTool
+import numpy as np
+
+random.seed(100)
+np.random.seed(100)
+
+
+def custom_rail_generator():
+    def generator(width, height, num_agents=0, num_resets=0):
+        rail_trans = RailEnvTransitions()
+        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+        rail_array = grid_map.grid
+        rail_array.fill(0)
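+        # The grid stays all zeros (i.e. no rails) and the agent lists below stay
+        # empty, so the environment created further down renders a blank 6x4 map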
+
+        agents_positions = []
+        agents_direction = []
+        agents_target = []
+
+        return grid_map, agents_positions, agents_direction, agents_target
+    return generator
+
+
+env = RailEnv(width=6,
+              height=4,
+              rail_generator=custom_rail_generator(),
+              number_of_agents=1)
+
+env.reset()
+
+env_renderer = RenderTool(env, gl="QT")
+env_renderer.renderEnv(show=True)
+
+input("Press Enter to continue...")
diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py
new file mode 100644
index 0000000000000000000000000000000000000000..7132b53339bbc09e91e72c8d0f19043942d87f67
--- /dev/null
+++ b/examples/simple_example_1.py
@@ -0,0 +1,30 @@
+from flatland.envs.generators import rail_from_manual_specifications_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.utils.rendertools import RenderTool
+
+# Example: generate a rail given a manual specification,
+# a map of tuples (cell_type, rotation)
+specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
+         [(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
+         [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)],
+         [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
+
+# CURVED RAIL + DEAD-ENDS TEST
+# specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
+#          [(7, 270), (1, 90), (1, 90), (8, 90), (0, 0), (0, 0)],
+#          [(0, 0), (7, 270), (1, 90), (8, 180), (0, 0), (0, 0)],
+#          [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
+
+env = RailEnv(width=6,
+              height=4,
+              rail_generator=rail_from_manual_specifications_generator(specs),
+              number_of_agents=1,
+              obs_builder_object=TreeObsForRailEnv(max_depth=2))
+
+env.reset()
+
+env_renderer = RenderTool(env, gl="QT")
+env_renderer.renderEnv(show=True)
+
+input("Press Enter to continue...")
diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py
new file mode 100644
index 0000000000000000000000000000000000000000..535e9c9e784305a8bf049daa67fa40979b569c79
--- /dev/null
+++ b/examples/simple_example_2.py
@@ -0,0 +1,43 @@
+import random
+
+from flatland.envs.generators import random_rail_generator, rail_from_list_of_saved_GridTransitionMap_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.utils.rendertools import RenderTool
+import numpy as np
+
+random.seed(100)
+np.random.seed(100)
+
+# Relative weights of each cell type to be used by the random rail generators.
+transition_probability = [1.0,  # empty cell - Case 0
+                          1.0,  # Case 1 - straight
+                          1.0,  # Case 2 - simple switch
+                          0.3,  # Case 3 - diamond crossing
+                          0.5,  # Case 4 - single slip
+                          0.5,  # Case 5 - double slip
+                          0.2,  # Case 6 - symmetrical
+                          0.0,  # Case 7 - dead end
+                          0.2,  # Case 8 - turn left
+                          0.2,  # Case 9 - turn right
+                          1.0]  # Case 10 - mirrored switch
+
+# Example: generate a random rail
+env = RailEnv(width=10,
+              height=10,
+              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+              number_of_agents=3,
+              obs_builder_object=TreeObsForRailEnv(max_depth=2))
+
+# env = RailEnv(width=10,
+#               height=10,
+#               rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']),
+#               number_of_agents=3,
+#               obs_builder_object=TreeObsForRailEnv(max_depth=2))
+
+env.reset()
+
+env_renderer = RenderTool(env, gl="QT")
+env_renderer.renderEnv(show=True)
+
+input("Press Enter to continue...")
diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5283df627b89294b899a0dc3f4652d5d2375152
--- /dev/null
+++ b/examples/simple_example_3.py
@@ -0,0 +1,55 @@
+import random
+
+from flatland.envs.generators import random_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+from flatland.envs.observations import TreeObsForRailEnv
+import numpy as np
+
+random.seed(100)
+np.random.seed(100)
+
+env = RailEnv(width=7,
+              height=7,
+              rail_generator=random_rail_generator(),
+              number_of_agents=2,
+              obs_builder_object=TreeObsForRailEnv(max_depth=2))
+
+# Print the distance map of each cell to the target of the first agent
+# for i in range(4):
+#     print(env.obs_builder.distance_map[0, :, :, i])
+
+# Print the observation vector for agent 0
+obs, all_rewards, done, _ = env.step({0: 0})
+for i in range(env.get_num_agents()):
+    env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
+
+env_renderer = RenderTool(env, gl="QT")
+env_renderer.renderEnv(show=True)
+
+print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
+      (turnleft+move, move to front, turnright+move)")
+for step in range(100):
+    cmd = input(">> ")
+    cmds = cmd.split(" ")
+
+    action_dict = {}
+
+    i = 0
+    while i < len(cmds):
+        if cmds[i] == 'q':
+            import sys
+
+            sys.exit()
+        elif cmds[i] == 's':
+            obs, all_rewards, done, _ = env.step(action_dict)
+            action_dict = {}
+            print("Rewards: ", all_rewards, " [done=", done, "]")
+        else:
+            agent_id = int(cmds[i])
+            action = int(cmds[i + 1])
+            action_dict[agent_id] = action
+            i = i + 1
+        i += 1
+
+    env_renderer.renderEnv(show=True)
diff --git a/examples/temporary_example.py b/examples/temporary_example.py
deleted file mode 100644
index 862369411056d87d411c3e173bd479e9a7e93e01..0000000000000000000000000000000000000000
--- a/examples/temporary_example.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import random
-
-from flatland.envs.generators import random_rail_generator
-from flatland.envs.rail_env import RailEnv
-from flatland.utils.rendertools import RenderTool
-import numpy as np
-
-random.seed(0)
-np.random.seed(0)
-
-transition_probability = [1.0,  # empty cell - Case 0
-                          1.0,  # Case 1 - straight
-                          1.0,  # Case 2 - simple switch
-                          0.3,  # Case 3 - diamond drossing
-                          0.5,  # Case 4 - single slip
-                          0.5,  # Case 5 - double slip
-                          0.2,  # Case 6 - symmetrical
-                          0.0,  # Case 7 - dead end
-                          0.2,  # Case 8 - turn left
-                          0.2,  # Case 9 - turn right
-                          1.0]  # Case 10 - mirrored switch
-
-"""
-# Example generate a random rail
-env = RailEnv(width=20,
-              height=20,
-              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-              number_of_agents=10)
-
-# env = RailEnv(width=20,
-#               height=20,
-#               rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']),
-#               number_of_agents=10)
-
-env.reset()
-
-env_renderer = RenderTool(env)
-env_renderer.renderEnv(show=True)
-"""
-"""
-# Example generate a rail given a manual specification,
-# a map of tuples (cell_type, rotation)
-specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
-         [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]]
-
-env = RailEnv(width=6,
-              height=2,
-              rail_generator=rail_from_manual_specifications_generator(specs),
-              number_of_agents=1,
-              obs_builder_object=TreeObsForRailEnv(max_depth=2))
-
-handle = env.get_agent_handles()
-env.agents_position[0] = [1, 4]
-env.agents_target[0] = [1, 1]
-env.agents_direction[0] = 1
-# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
-env.obs_builder.reset()
-"""
-"""
-# INFINITE-LOOP TEST
-specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
-         [(7, 270), (1, 90), (1, 90), (2, 270), (2, 0), (0, 0)],
-         [(0, 0), (0, 0), (0, 0), (2, 180), (2, 90), (7, 90)],
-         [(0, 0), (0, 0), (0, 0), (7, 180), (0, 0), (0, 0)]]
-
-# CURVED RAIL + DEAD-ENDS TEST
-specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
-         [(7, 270), (1, 90), (1, 90), (8, 90), (0, 0), (0, 0)],
-         [(0, 0), (7, 270),(1, 90), (8, 180), (0, 00), (0, 0)]]
-
-env = RailEnv(width=6,
-              height=4,
-              rail_generator=rail_from_manual_specifications_generator(specs),
-              number_of_agents=1,
-              obs_builder_object=TreeObsForRailEnv(max_depth=2))
-
-handle = env.get_agent_handles()
-env.agents_position[0] = [1, 3]
-env.agents_target[0] = [1, 1]
-env.agents_direction[0] = 1
-# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
-env.obs_builder.reset()
-"""
-env = RailEnv(width=7,
-              height=7,
-              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-              # rail_generator=complex_rail_generator(nr_start_goal=2),
-              number_of_agents=2)
-
-# Print the distance map of each cell to the target of the first agent
-# for i in range(4):
-#     print(env.obs_builder.distance_map[0, :, :, i])
-
-# Print the observation vector for agent 0
-obs, all_rewards, done, _ = env.step({0: 0})
-for i in range(env.get_num_agents()):
-    env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5)
-
-env_renderer = RenderTool(env, gl="QT")
-env_renderer.renderEnv(show=True)
-
-print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
-      (turnleft+move, move to front, turnright+move)")
-for step in range(100):
-    cmd = input(">> ")
-    cmds = cmd.split(" ")
-
-    action_dict = {}
-
-    i = 0
-    while i < len(cmds):
-        if cmds[i] == 'q':
-            import sys
-
-            sys.exit()
-        elif cmds[i] == 's':
-            obs, all_rewards, done, _ = env.step(action_dict)
-            action_dict = {}
-            print("Rewards: ", all_rewards, " [done=", done, "]")
-        else:
-            agent_id = int(cmds[i])
-            action = int(cmds[i + 1])
-            action_dict[agent_id] = action
-            i = i + 1
-        i += 1
-
-    env_renderer.renderEnv(show=True)
diff --git a/examples/training_example.py b/examples/training_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6f2c0268a9d7aece260b8e0f97ae1ff68d28bb6
--- /dev/null
+++ b/examples/training_example.py
@@ -0,0 +1,81 @@
+from flatland.envs.generators import complex_rail_generator
+from flatland.envs.rail_env import RailEnv
+import numpy as np
+
+np.random.seed(1)
+
+# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
+# Training on simple small tasks is the best way to get familiar with the environment
+#
+env = RailEnv(width=15,
+              height=15,
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
+              number_of_agents=5)
+
+
+# Import your own Agent or use RLlib to train agents on Flatland
+# As an example we use a random agent here
+
+
+class RandomAgent:
+
+    def __init__(self, state_size, action_size):
+        self.state_size = state_size
+        self.action_size = action_size
+
+    def act(self, state):
+        """
+        :param state: input is the observation of the agent
+        :return: returns an action
+        """
+        return np.random.choice(np.arange(self.action_size))
+
+    def step(self, memories):
+        """
+        Step function to improve agent by adjusting policy given the observations
+
+        :param memories: SARS tuple to be used for updating the policy
+        :return:
+        """
+        return
+
+    def save(self):
+        # Store the current policy
+        return
+
+
+# Initialize the agent with the parameters corresponding to the environment and observation_builder
+agent = RandomAgent(218, 4)
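+# (218 and 4 are placeholder state and action sizes here: the random agent ignores
+# the state entirely, so only the action size it samples from has an effect)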
+n_trials = 1000
+
+# Empty dictionary for all agent actions
+action_dict = dict()
+print("Starting Training...")
+for trials in range(1, n_trials + 1):
+
+    # Reset environment and get initial observations for all agents
+    obs = env.reset()
+    # Here you can also further enhance the provided observation by means of normalization
+    # See training navigation example in the baseline repository
+
+    score = 0
+    # Run episode
+    for step in range(100):
+        # Choose an action for each agent in the environment
+        for a in range(env.get_num_agents()):
+            action = agent.act(obs[a])
+            action_dict.update({a: action})
+
+        # Environment step which returns the observations for all agents, their corresponding
+        # reward and whether they are done
+        next_obs, all_rewards, done, _ = env.step(action_dict)
+
+        # Update replay buffer and train agent
+        for a in range(env.get_num_agents()):
+            agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
+            score += all_rewards[a]
+
+        obs = next_obs.copy()
+        if done['__all__']:
+            break
+    print('Episode Nr. {}'.format(trials))
diff --git a/flatland/core/env.py b/flatland/core/env.py
index 284afdffb6ce46ac481018af469c7d2e024fc792..5334b22f5762840f400743c80512bc6a118062ee 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -9,6 +9,10 @@ class Environment:
     """
     Base interface for multi-agent environments in Flatland.
 
+    Derived environments should implement the following attributes:
+    action_space: tuple with the dimensions of the actions to be passed to the step method
+    observation_space: tuple with the dimensions of the observations returned by reset and step
+
     Agents are identified by agent ids (handles).
     Examples:
     >>> obs = env.reset()
@@ -39,6 +43,8 @@ class Environment:
     """
 
     def __init__(self):
+        self.action_space = ()
+        self.observation_space = ()
         pass
 
     def reset(self):
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 09a624e872200e29ede834d272f7de506d6de076..3cef545c1658e6bfe2a292ee26c3e665ce6a5abc 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -12,9 +12,13 @@ case of multi-agent environments.
 class ObservationBuilder:
     """
     ObservationBuilder base class.
+
+    Derived objects must implement an `observation_space` attribute as a tuple with the dimensions of the returned
+    observations.
     """
 
     def __init__(self):
+        self.observation_space = ()
         pass
 
     def _set_env(self, env):
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 3e02050eb4f01889f457a72929578a8eb5c36f29..4d5fb44d98698072993cb6334e298033c9914b31 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -19,6 +19,14 @@ class TreeObsForRailEnv(ObservationBuilder):
 
     def __init__(self, max_depth):
         self.max_depth = max_depth
+        # Compute the size of the returned observation vector
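+        # (a tree of depth d has 1 + 4 + ... + 4**d nodes, since each node can
+        # branch into up to 4 directions, and every node contributes 5 features)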
+        size = 0
+        pow4 = 1
+        for i in range(self.max_depth+1):
+            size += pow4
+            pow4 *= 4
+        self.observation_space = [size * 5]
+
     def reset(self):
         agents = self.env.agents
         nAgents = len(agents)
@@ -158,10 +166,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         the transitions. The order is:
         [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
 
-
-
-
-
         Each branch data is organized as:
         [root node information] +
         [recursive branch data from 'left'] +
@@ -491,8 +495,14 @@ class GlobalObsForRailEnv(ObservationBuilder):
     """
 
     def __init__(self):
+        self.observation_space = ()
         super(GlobalObsForRailEnv, self).__init__()
 
+    def _set_env(self, env):
+        super()._set_env(env)
+
+        self.observation_space = [4, self.env.height, self.env.width]
+
     def reset(self):
         self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
         for i in range(self.rail_obs.shape[0]):
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 74e7526caa93ca8a1821eb5b2a47576231eb95c3..d4facaf5c681f4d5187be03aca69618b1f8cd466 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -90,6 +90,9 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
+        self.action_space = [1]
+        self.observation_space = self.obs_builder.observation_space  # updated on resets?
+
         self.actions = [0] * number_of_agents
         self.rewards = [0] * number_of_agents
         self.done = False
@@ -160,6 +163,7 @@ class RailEnv(Environment):
 
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
+        self.observation_space = self.obs_builder.observation_space  # <-- change on reset?
 
         # Return the new observation vectors for each agent
         return self._get_observations()