Skip to content
Snippets Groups Projects
Commit 351a7e06 authored by u214892's avatar u214892
Browse files

#175 no sleep when running examples as benchmark or for profiling

parent e9f4408a
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,7 @@ for entry in [entry for entry in importlib_resources.contents('examples') if
and entry.endswith(".py")
and '__init__' not in entry
and 'demo.py' not in entry
and 'DELETE' not in entry
]:
with path('examples', entry) as file_in:
print("")
......@@ -22,4 +23,6 @@ for entry in [entry for entry in importlib_resources.contents('examples') if
print("Running {}".format(entry))
print("*****************************************************************")
with swap_attr(sys, "stdin", StringIO("q")):
runpy.run_path(file_in, run_name="__main__")
runpy.run_path(file_in, run_name="__main__", init_globals={
'argv': ['--no-sleep']
})
import random
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator
random.seed(100)
np.random.seed(100)
class SimpleObs(ObservationBuilder):
"""
Simplest observation builder. The object returns observation vectors with 5 identical components,
all equal to the ID of the respective agent.
"""
def __init__(self):
self.observation_space = [5]
def reset(self):
return
def get(self, handle):
observation = handle * np.ones((5,))
return observation
def main():
env = RailEnv(width=7,
height=7,
rail_generator=random_rail_generator(),
number_of_agents=3,
obs_builder_object=SimpleObs())
# Print the observation vector for each agents
obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()):
print("Agent ", i, "'s observation: ", obs[i])
if __name__ == '__main__':
main()
import getopt
import random
import sys
import time
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool
random.seed(100)
np.random.seed(100)
class SingleAgentNavigationObs(TreeObsForRailEnv):
"""
We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
We then build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__(max_depth=0)
self.observation_space = [3]
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
def get(self, handle):
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if num_transitions == 1:
observation = [0, 1, 0]
else:
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = self._new_position(agent.position, direction)
min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
observation = [0, 0, 0]
observation[np.argmin(min_distances)] = 1
return observation
def main(args):
try:
opts, args = getopt.getopt(args, "", ["no-sleep", ""])
except getopt.GetoptError as err:
print(str(err)) # will print something like "option -a not recognized"
sys.exit(2)
no_sleep = False
for o, a in opts:
if o in ("--no-sleep"):
no_sleep = True
else:
assert False, "unhandled option"
env = RailEnv(width=7,
height=7,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs())
obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.render_env(show=True, frames=True, show_observations=True)
for step in range(100):
action = np.argmax(obs[0]) + 1
obs, all_rewards, done, _ = env.step({0: action})
print("Rewards: ", all_rewards, " [done=", done, "]")
env_renderer.render_env(show=True, frames=True, show_observations=True)
if not no_sleep:
time.sleep(0.1)
if done["__all__"]:
break
env_renderer.close_window()
if __name__ == '__main__':
if 'argv' in globals():
main(argv)
else:
main(sys.argv[1:])
import getopt
import random
import sys
import time
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool
......@@ -16,101 +17,6 @@ random.seed(100)
np.random.seed(100)
class SimpleObs(ObservationBuilder):
"""
Simplest observation builder. The object returns observation vectors with 5 identical components,
all equal to the ID of the respective agent.
"""
def __init__(self):
self.observation_space = [5]
def reset(self):
return
def get(self, handle):
observation = handle * np.ones((5,))
return observation
env = RailEnv(width=7,
height=7,
rail_generator=random_rail_generator(),
number_of_agents=3,
obs_builder_object=SimpleObs())
# Print the observation vector for each agents
obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()):
print("Agent ", i, "'s observation: ", obs[i])
class SingleAgentNavigationObs(TreeObsForRailEnv):
"""
We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
We then build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__(max_depth=0)
self.observation_space = [3]
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
def get(self, handle):
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if num_transitions == 1:
observation = [0, 1, 0]
else:
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = self._new_position(agent.position, direction)
min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
observation = [0, 0, 0]
observation[np.argmin(min_distances)] = 1
return observation
env = RailEnv(width=7,
height=7,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs())
obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.render_env(show=True, frames=True, show_observations=True)
for step in range(100):
action = np.argmax(obs[0]) + 1
obs, all_rewards, done, _ = env.step({0: action})
print("Rewards: ", all_rewards, " [done=", done, "]")
env_renderer.render_env(show=True, frames=True, show_observations=True)
time.sleep(0.1)
if done["__all__"]:
break
env_renderer.close_window()
class ObservePredictions(TreeObsForRailEnv):
"""
We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.
......@@ -194,32 +100,55 @@ class ObservePredictions(TreeObsForRailEnv):
return observation
# Initiate the Predictor
CustomPredictor = ShortestPathPredictorForRailEnv(10)
def main(argv):
try:
opts, args = getopt.getopt(argv, "", ["no-sleep", ""])
except getopt.GetoptError as err:
print(str(err)) # will print something like "option -a not recognized"
sys.exit(2)
no_sleep = False
for o, a in opts:
if o in ("--no-sleep"):
no_sleep = True
else:
assert False, "unhandled option"
# Pass the Predictor to the observation builder
CustomObsBuilder = ObservePredictions(CustomPredictor)
# Initiate the Predictor
custom_predictor = ShortestPathPredictorForRailEnv(10)
# Initiate Environment
env = RailEnv(width=10,
height=10,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=3,
obs_builder_object=CustomObsBuilder)
# Pass the Predictor to the observation builder
custom_obs_builder = ObservePredictions(custom_predictor)
obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
# Initiate Environment
env = RailEnv(width=10,
height=10,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999,
seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=3,
obs_builder_object=custom_obs_builder)
# We render the initial step and show the obsered cells as colored boxes
env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
action_dict = {}
for step in range(100):
for a in range(env.get_num_agents()):
action = np.random.randint(0, 5)
action_dict[a] = action
obs, all_rewards, done, _ = env.step(action_dict)
print("Rewards: ", all_rewards, " [done=", done, "]")
# We render the initial step and show the obsered cells as colored boxes
env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
time.sleep(0.5)
action_dict = {}
for step in range(100):
for a in range(env.get_num_agents()):
action = np.random.randint(0, 5)
action_dict[a] = action
obs, all_rewards, done, _ = env.step(action_dict)
print("Rewards: ", all_rewards, " [done=", done, "]")
env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
if not no_sleep:
time.sleep(0.5)
if __name__ == '__main__':
if 'argv' in globals():
main(argv)
else:
main(sys.argv[1:])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment