From 2c91e4868ca9437e5e75450e4df784feb9310866 Mon Sep 17 00:00:00 2001
From: "S.P. Mohanty" <spmohanty91@gmail.com>
Date: Fri, 18 Oct 2019 19:04:51 +0200
Subject: [PATCH] Add a custom observation builder

---
 my_observation_builder.py | 106 ++++++++++++++++++++++++++++++++++++++
 run.py                    |  29 +++++++----
 2 files changed, 125 insertions(+), 10 deletions(-)
 create mode 100644 my_observation_builder.py

diff --git a/my_observation_builder.py b/my_observation_builder.py
new file mode 100644
index 0000000..da2a209
--- /dev/null
+++ b/my_observation_builder.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python
+
+import collections
+from typing import Optional, List, Dict, Tuple
+
+import numpy as np
+
+from flatland.core.env import Environment
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.env_prediction_builder import PredictionBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
+from flatland.utils.ordered_set import OrderedSet
+
+
+class CustomObservationBuilder(ObservationBuilder):
+    """
+    Template for building a custom observation builder for the RailEnv class
+
+    The observation in this case is composed of the following elements:
+
+    - a transition map array with dimensions (env.height, env.width),
+      where the value at (x, y) is the 16-bit encoding of the transitions at that point
+
+    - the individual agent object (with position, direction and target information available)
+
+    """
+    def __init__(self):
+        super(CustomObservationBuilder, self).__init__()
+
+    def set_env(self, env: Environment):
+        super().set_env(env)
+        # Note:
+        # Any instantiation that depends on parameters of the Env object should be
+        # done here, as it is only here that the updated self.env instance is available
+        self.rail_obs = np.zeros((self.env.height, self.env.width))
+        print("Env Width : ", self.env.width, "Env Height : ", self.env.height)
+
+    def reset(self):
+        """
+        Called internally on every env.reset() call,
+        to reset any observation-specific variables that are being used
+        """
+        self.rail_obs[:] = 0
+        for _x in range(self.env.width):
+            for _y in range(self.env.height):
+                # Get the transition map value at location (_x, _y)
+                transition_value = self.env.rail.get_full_transitions(_y, _x)
+                self.rail_obs[_y, _x] = transition_value
+        print("Responding to obs_builder.reset()")
+
+    def get(self, handle: int = 0):
+        """
+        Returns the built observation for a single agent with handle: handle
+
+        In this particular case, we return
+        - the global transition_map of the RailEnv,
+        - a tuple containing the current agent's:
+            - status
+            - position
+            - direction
+            - initial_position
+            - target
+        """
+
+        agent = self.env.agents[handle]
+        """
+        Available information for each agent object:
+
+        - agent.status : [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
+        - agent.position : Current position of the agent
+        - agent.direction : Current direction of the agent
+        - agent.initial_position : Initial position of the agent
+        - agent.target : Target position of the agent
+        """
+
+        status = agent.status
+        position = agent.position
+        direction = agent.direction
+        initial_position = agent.initial_position
+        target = agent.target
+
+
+        """
+        You can also optionally access the states of the rest of the agents by
+        using something similar to
+
+        for i in range(len(self.env.agents)):
+            other_agent: EnvAgent = self.env.agents[i]
+
+            # ignore other agents not in the grid any more
+            if other_agent.status == RailAgentStatus.DONE_REMOVED:
+                continue
+
+            ## Gather other agent specific params
+            other_agent_status = other_agent.status
+            other_agent_position = other_agent.position
+            other_agent_direction = other_agent.direction
+            other_agent_initial_position = other_agent.initial_position
+            other_agent_target = other_agent.target
+
+            ## Do something nice here if you wish
+        """
+        return self.rail_obs, (status, position, direction, initial_position, target)
+
diff --git a/run.py b/run.py
index 15a97b8..5c5bb9a 100644
--- a/run.py
+++ b/run.py
@@ -1,6 +1,6 @@
 from flatland.evaluators.client import FlatlandRemoteClient
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.core.env_observation_builder import DummyObservationBuilder
+from my_observation_builder import CustomObservationBuilder
 import numpy as np
 import time
 
@@ -31,10 +31,14 @@ def my_controller(obs, number_of_agents):
 #   the example here :
 #   https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
 #####################################################################
-my_observation_builder = TreeObsForRailEnv(
-                                max_depth=3,
-                                predictor=ShortestPathPredictorForRailEnv()
-                            )
+my_observation_builder = CustomObservationBuilder()
+
+# Or, if you want to use your own approach to build the observation at each env step,
+# please feel free to pass a DummyObservationBuilder() object as mentioned below;
+# it will simply return a placeholder True for every observation, and you
+# can then build your own observations for all the agents as you please.
+# my_observation_builder = DummyObservationBuilder()
+
 
 #####################################################################
 # Main evaluation loop
@@ -55,9 +59,11 @@ while True:
     # You can also pass your custom observation_builder object
    # to allow you to have as much control as you wish
     # over the observation of your choice.
+    time_start = time.time()
     observation, info = remote_client.env_create(
                     obs_builder_object=my_observation_builder
                 )
+    env_creation_time = time.time() - time_start
     if not observation:
         #
         # If the remote_client returns False on a `env_create` call,
@@ -66,7 +72,7 @@ while True:
         # and hence its safe to break out of the main evaluation loop
         break
 
-    #print("Evaluation Number : {}".format(evaluation_number))
+    print("Evaluation Number : {}".format(evaluation_number))
 
     #####################################################################
     # Access to a local copy of the environment
@@ -95,12 +101,12 @@ while True:
     # or when the number of time steps has exceed max_time_steps, which
     # is defined by :
     #
-    # max_time_steps = int(1.5 * (env.width + env.height))
+    # max_time_steps = int(4 * 2 * (env.width + env.height + 20))
     #
     time_taken_by_controller = []
     time_taken_per_step = []
-
-    for k in range(10):
+    steps = 0
+    while True:
         #####################################################################
         # Evaluation of a single episode
         #
@@ -119,6 +125,7 @@
         # are returned by the remote copy of the env
         time_start = time.time()
         observation, all_rewards, done, info = remote_client.env_step(action)
+        steps += 1
         time_taken = time.time() - time_start
         time_taken_per_step.append(time_taken)
 
@@ -136,6 +143,8 @@ while True:
     print("="*100)
     print("Evaluation Number : ", evaluation_number)
     print("Current Env Path : ", remote_client.current_env_path)
+    print("Env Creation Time : ", env_creation_time)
+    print("Number of Steps : ", steps)
     print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
     print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
     print("="*100)
-- 
GitLab
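Note (not part of the patch above): the snippet below is a rough sketch of how the my_controller function in run.py could consume the observation produced by CustomObservationBuilder instead of ignoring it. It is illustrative only: obs is assumed to be the per-handle dict returned by remote_client.env_create / remote_client.env_step, and the integer actions follow the usual flatland RailEnvActions encoding (0 = DO_NOTHING, 1 = MOVE_LEFT, 2 = MOVE_FORWARD, 3 = MOVE_RIGHT, 4 = STOP_MOVING).

import numpy as np

def my_controller(obs, number_of_agents):
    # Each obs[handle] is (rail_obs, (status, position, direction,
    # initial_position, target)), as assembled by CustomObservationBuilder.get()
    _action = {}
    for _idx in range(number_of_agents):
        rail_obs, (status, position, direction,
                   initial_position, target) = obs[_idx]
        # rail_obs is the (env.height, env.width) array of 16-bit transition
        # values; a real policy would combine it with position/direction/target.
        # Placeholder policy: pick a random action, as in the starter controller.
        _action[_idx] = np.random.randint(0, 5)
    return _action

A small design note on the builder itself: the transition map is computed once per episode (in set_env and reset) rather than on every get() call, since the rail layout does not change during an episode; only the per-agent tuple is assembled at each step.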