Commit 2c91e486 authored by spmohanty's avatar spmohanty

Add a custom observation builder

parent 5f96df8f
#!/usr/bin/env python
import collections
from typing import Optional, List, Dict, Tuple
import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
from flatland.utils.ordered_set import OrderedSet
class CustomObservationBuilder(ObservationBuilder):
"""
Template for building a custom observation builder for the RailEnv class
The observation in this case composed of the following elements:
- transition map array with dimensions (env.height, env.width),\
where the value at X,Y will represent the 16 bits encoding of transition-map at that point.
- the individual agent object (with position, direction, target information available)
"""
def __init__(self):
super(CustomObservationBuilder, self).__init__()
def set_env(self, env: Environment):
super().set_env(env)
# Note :
# The instantiations which depend on parameters of the Env object should be
# done here, as it is only here that the updated self.env instance is available
self.rail_obs = np.zeros((self.env.height, self.env.width))
print("Env Width : ", self.env.width, "Env Height : ", self.env.height)
def reset(self):
"""
Called internally on every env.reset() call,
to reset any observation specific variables that are being used
"""
self.rail_obs[:] = 0
for _x in range(self.env.width):
for _y in range(self.env.height):
# Get the transition map value at location _x, _y
transition_value = self.env.rail.get_full_transitions(_y, _x)
self.rail_obs[_y, _x] = transition_value
print("Responding to obs_builder.reset()")
def get(self, handle: int = 0):
"""
Returns the built observation for a single agent with handle : handle
In this particular case, we return
- the global transition_map of the RailEnv,
- a tuple containing, the current agent's:
- state
- position
- direction
- initial_position
- target
"""
agent = self.env.agents[handle]
"""
Available information for each agent object :
- agent.status : [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
- agent.position : Current position of the agent
- agent.direction : Current direction of the agent
- agent.initial_position : Initial Position of the agent
- agent.target : Target position of the agent
"""
status = agent.status
position = agent.position
direction = agent.direction
initial_position = agent.initial_position
target = agent.target
"""
You can also optionally access the states of the rest of the agents by
using something similar to
for i in range(len(self.env.agents)):
other_agent: EnvAgent = self.env.agents[i]
# ignore other agents not in the grid any more
if other_agent.status == RailAgentStatus.DONE_REMOVED:
continue
## Gather other agent specific params
other_agent_status = other_agent.status
other_agent_position = other_agent.position
other_agent_direction = other_agent.direction
other_agent_initial_position = other_agent.initial_position
other_agent_target = other_agent.target
## Do something nice here if you wish
"""
return self.rail_obs, (status, position, direction, initial_position, target)
from flatland.evaluators.client import FlatlandRemoteClient
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.core.env_observation_builder import DummyObservationBuilder
from my_observation_builder import CustomObservationBuilder
import numpy as np
import time
......@@ -31,10 +31,14 @@ def my_controller(obs, number_of_agents):
# the example here :
# https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
#####################################################################
my_observation_builder = TreeObsForRailEnv(
max_depth=3,
predictor=ShortestPathPredictorForRailEnv()
)
my_observation_builder = CustomObservationBuilder()
# Or if you want to use your own approach to build the observation from the env_step,
# please feel free to pass a DummyObservationBuilder() object as mentioned below,
# and that will just return a placeholder True for all observation, and you
# can build your own Observation for all the agents as your please.
# my_observation_builder = DummyObservationBuilder()
#####################################################################
# Main evaluation loop
......@@ -55,9 +59,11 @@ while True:
# You can also pass your custom observation_builder object
# to allow you to have as much control as you wish
# over the observation of your choice.
time_start = time.time()
observation, info = remote_client.env_create(
obs_builder_object=my_observation_builder
)
env_creation_time = time.time() - time_start
if not observation:
#
# If the remote_client returns False on a `env_create` call,
......@@ -66,7 +72,7 @@ while True:
# and hence its safe to break out of the main evaluation loop
break
#print("Evaluation Number : {}".format(evaluation_number))
print("Evaluation Number : {}".format(evaluation_number))
#####################################################################
# Access to a local copy of the environment
......@@ -95,12 +101,12 @@ while True:
# or when the number of time steps has exceed max_time_steps, which
# is defined by :
#
# max_time_steps = int(1.5 * (env.width + env.height))
# max_time_steps = int(4 * 2 * (env.width + env.height + 20))
#
time_taken_by_controller = []
time_taken_per_step = []
for k in range(10):
steps = 0
while True:
#####################################################################
# Evaluation of a single episode
#
......@@ -119,6 +125,7 @@ while True:
# are returned by the remote copy of the env
time_start = time.time()
observation, all_rewards, done, info = remote_client.env_step(action)
steps += 1
time_taken = time.time() - time_start
time_taken_per_step.append(time_taken)
......@@ -136,6 +143,8 @@ while True:
print("="*100)
print("Evaluation Number : ", evaluation_number)
print("Current Env Path : ", remote_client.current_env_path)
print("Env Creation Time : ", env_creation_time)
print("Number of Steps : ", steps)
print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
print("="*100)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment