Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
manavsinghal157
flatland-rl
Commits
2c91e486
Commit
2c91e486
authored
Oct 18, 2019
by
spmohanty
Browse files
Add a custom observation builder
parent
5f96df8f
Changes
2
Show whitespace changes
Inline
Side-by-side
my_observation_builder.py
0 → 100644
View file @
2c91e486
#!/usr/bin/env python
import
collections
from
typing
import
Optional
,
List
,
Dict
,
Tuple
import
numpy
as
np
from
flatland.core.env
import
Environment
from
flatland.core.env_observation_builder
import
ObservationBuilder
from
flatland.core.env_prediction_builder
import
PredictionBuilder
from
flatland.core.grid.grid4_utils
import
get_new_position
from
flatland.core.grid.grid_utils
import
coordinate_to_position
from
flatland.envs.agent_utils
import
RailAgentStatus
,
EnvAgent
from
flatland.utils.ordered_set
import
OrderedSet
class CustomObservationBuilder(ObservationBuilder):
    """
    Template for building a custom observation builder for the RailEnv class.

    The observation here is composed of:
        - a transition-map array of shape (env.height, env.width), where the
          value at (y, x) is the 16-bit encoding of the transition map at
          that cell, and
        - a tuple describing the queried agent (status, position, direction,
          initial_position, target).
    """

    def __init__(self):
        # No state of our own yet; defer everything to the base class.
        super().__init__()

    def set_env(self, env: Environment):
        """Attach *env* and allocate env-sized buffers.

        NOTE: anything that depends on parameters of the Env object must be
        instantiated here, because only at this point is the updated
        ``self.env`` instance available.
        """
        super().set_env(env)
        # One float cell per grid cell; filled with transition encodings on reset().
        self.rail_obs = np.zeros((self.env.height, self.env.width))
        print("Env Width : ", self.env.width, "Env Height : ", self.env.height)

    def reset(self):
        """
        Called internally on every env.reset() call, to reset any
        observation-specific variables that are being used.
        """
        self.rail_obs[:] = 0
        # Copy the 16-bit transition encoding of every cell into rail_obs.
        for col in range(self.env.width):
            for row in range(self.env.height):
                self.rail_obs[row, col] = self.env.rail.get_full_transitions(row, col)
        print("Responding to obs_builder.reset()")

    def get(self, handle: int = 0):
        """
        Return the built observation for the single agent ``handle``.

        In this particular case we return:
            - the global transition map of the RailEnv, and
            - a tuple with the agent's state, position, direction,
              initial_position and target.

        Available information on each agent object:
            - agent.status : [RailAgentStatus.READY_TO_DEPART,
              RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
            - agent.position : current position of the agent
            - agent.direction : current direction of the agent
            - agent.initial_position : initial position of the agent
            - agent.target : target position of the agent

        You can also optionally access the states of the rest of the agents
        with something like::

            for i in range(len(self.env.agents)):
                other_agent: EnvAgent = self.env.agents[i]
                # ignore other agents no longer in the grid
                if other_agent.status == RailAgentStatus.DONE_REMOVED:
                    continue
                # gather other-agent-specific params
                other_agent_status = other_agent.status
                other_agent_position = other_agent.position
                other_agent_direction = other_agent.direction
                other_agent_initial_position = other_agent.initial_position
                other_agent_target = other_agent.target
                # do something nice here if you wish
        """
        queried = self.env.agents[handle]
        agent_summary = (
            queried.status,
            queried.position,
            queried.direction,
            queried.initial_position,
            queried.target,
        )
        return self.rail_obs, agent_summary
run.py
View file @
2c91e486
from
flatland.evaluators.client
import
FlatlandRemoteClient
from
flatland.env
s.
observation
s
import
TreeObsForRailEnv
from
flatland.envs.predictions
import
ShortestPathPredictorForRailEnv
from
flatland.
core.
env
_
observation
_builder
import
DummyObservationBuilder
from
my_observation_builder
import
CustomObservationBuilder
import
numpy
as
np
import
time
...
...
@@ -31,10 +31,14 @@ def my_controller(obs, number_of_agents):
# the example here :
# https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
#####################################################################
my_observation_builder
=
TreeObsForRailEnv
(
max_depth
=
3
,
predictor
=
ShortestPathPredictorForRailEnv
()
)
my_observation_builder
=
CustomObservationBuilder
()
# Or if you want to use your own approach to build the observation from the env_step,
# please feel free to pass a DummyObservationBuilder() object as mentioned below,
# and that will just return a placeholder True for all observation, and you
# can build your own Observation for all the agents as your please.
# my_observation_builder = DummyObservationBuilder()
#####################################################################
# Main evaluation loop
...
...
@@ -55,9 +59,11 @@ while True:
# You can also pass your custom observation_builder object
# to allow you to have as much control as you wish
# over the observation of your choice.
time_start
=
time
.
time
()
observation
,
info
=
remote_client
.
env_create
(
obs_builder_object
=
my_observation_builder
)
env_creation_time
=
time
.
time
()
-
time_start
if
not
observation
:
#
# If the remote_client returns False on a `env_create` call,
...
...
@@ -66,7 +72,7 @@ while True:
# and hence its safe to break out of the main evaluation loop
break
#
print("Evaluation Number : {}".format(evaluation_number))
print
(
"Evaluation Number : {}"
.
format
(
evaluation_number
))
#####################################################################
# Access to a local copy of the environment
...
...
@@ -95,12 +101,12 @@ while True:
# or when the number of time steps has exceed max_time_steps, which
# is defined by :
#
# max_time_steps = int(
1.5
* (env.width + env.height))
# max_time_steps = int(
4 * 2
* (env.width + env.height
+ 20
))
#
time_taken_by_controller
=
[]
time_taken_per_step
=
[]
for
k
in
range
(
10
)
:
steps
=
0
while
True
:
#####################################################################
# Evaluation of a single episode
#
...
...
@@ -119,6 +125,7 @@ while True:
# are returned by the remote copy of the env
time_start
=
time
.
time
()
observation
,
all_rewards
,
done
,
info
=
remote_client
.
env_step
(
action
)
steps
+=
1
time_taken
=
time
.
time
()
-
time_start
time_taken_per_step
.
append
(
time_taken
)
...
...
@@ -136,6 +143,8 @@ while True:
print
(
"="
*
100
)
print
(
"Evaluation Number : "
,
evaluation_number
)
print
(
"Current Env Path : "
,
remote_client
.
current_env_path
)
print
(
"Env Creation Time : "
,
env_creation_time
)
print
(
"Number of Steps : "
,
steps
)
print
(
"Mean/Std of Time taken by Controller : "
,
np_time_taken_by_controller
.
mean
(),
np_time_taken_by_controller
.
std
())
print
(
"Mean/Std of Time per Step : "
,
np_time_taken_per_step
.
mean
(),
np_time_taken_per_step
.
std
())
print
(
"="
*
100
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment