Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Showing
with 1983 additions and 444 deletions
import random
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
random.seed(100)
np.random.seed(100)
class SimpleObs(ObservationBuilder):
"""
Simplest observation builder. The object returns observation vectors with 5 identical components,
all equal to the ID of the respective agent.
"""
def __init__(self):
super().__init__()
def reset(self):
return
def get(self, handle: int = 0) -> np.ndarray:
observation = handle * np.ones((5,))
return observation
def create_env():
nAgents = 3
n_cities = 2
max_rails_between_cities = 2
max_rails_in_city = 4
seed = 0
env = RailEnv(
width=20,
height=30,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=True,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city
),
line_generator=sparse_line_generator(),
number_of_agents=nAgents,
obs_builder_object=SimpleObs()
)
return env
def main():
env = create_env()
env.reset()
# Print the observation vector for each agents
obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()):
print("Agent ", i, "'s observation: ", obs[i])
if __name__ == '__main__':
main()
import getopt
import random
import sys
import time
from typing import List
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.misc import str2bool
from flatland.utils.rendertools import RenderTool
random.seed(100)
np.random.seed(100)
class SingleAgentNavigationObs(ObservationBuilder):
"""
We build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__()
def reset(self):
pass
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
if agent.position:
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
else:
possible_transitions = self.env.rail.get_transitions(*agent.initial_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if num_transitions == 1:
observation = [0, 1, 0]
else:
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(agent.position, direction)
min_distances.append(
self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
observation = [0, 0, 0]
observation[np.argmin(min_distances)] = 1
return observation
def create_env():
nAgents = 1
n_cities = 2
max_rails_between_cities = 2
max_rails_in_city = 4
seed = 0
env = RailEnv(
width=30,
height=40,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=True,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city
),
line_generator=sparse_line_generator(),
number_of_agents=nAgents,
obs_builder_object=SingleAgentNavigationObs()
)
return env
def custom_observation_example_02_SingleAgentNavigationObs(sleep_for_animation, do_rendering):
env = create_env()
obs, info = env.reset()
env_renderer = None
if do_rendering:
env_renderer = RenderTool(env)
env_renderer.render_env(show=True, frames=True, show_observations=False)
for step in range(100):
action = np.argmax(obs[0]) + 1
obs, all_rewards, done, _ = env.step({0: action})
print("Rewards: ", all_rewards, " [done=", done, "]")
if env_renderer is not None:
env_renderer.render_env(show=True, frames=True, show_observations=True)
if sleep_for_animation:
time.sleep(0.1)
if done["__all__"]:
break
if env_renderer is not None:
env_renderer.close_window()
def main(args):
try:
opts, args = getopt.getopt(args, "", ["sleep-for-animation=", "do_rendering=", ""])
except getopt.GetoptError as err:
print(str(err)) # will print something like "option -a not recognized"
sys.exit(2)
sleep_for_animation = True
do_rendering = True
for o, a in opts:
if o in ("--sleep-for-animation"):
sleep_for_animation = str2bool(a)
elif o in ("--do_rendering"):
do_rendering = str2bool(a)
else:
assert False, "unhandled option"
# execute example
custom_observation_example_02_SingleAgentNavigationObs(sleep_for_animation, do_rendering)
if __name__ == '__main__':
if 'argv' in globals():
main(argv)
else:
main(sys.argv[1:])
import getopt
import random
import sys
import time
from typing import Optional, List, Dict
import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.misc import str2bool
from flatland.utils.ordered_set import OrderedSet
from flatland.utils.rendertools import RenderTool
random.seed(100)
np.random.seed(100)
class SimpleObs(ObservationBuilder):
"""
Simplest observation builder. The object returns observation vectors with 5 identical components,
all equal to the ID of the respective agent.
"""
def __init__(self):
self.observation_space = [5]
def reset(self):
return
def get(self, handle):
observation = handle * np.ones((5,))
return observation
env = RailEnv(width=7,
height=7,
rail_generator=random_rail_generator(),
number_of_agents=3,
obs_builder_object=SimpleObs())
# Print the observation vector for each agents
obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()):
print("Agent ", i, "'s observation: ", obs[i])
class SingleAgentNavigationObs(TreeObsForRailEnv):
"""
We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
We then build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__(max_depth=0)
self.observation_space = [3]
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
def get(self, handle):
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if num_transitions == 1:
observation = [0, 1, 0]
else:
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = self._new_position(agent.position, direction)
min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
observation = [0, 0, 0]
observation[np.argmin(min_distances)] = 1
return observation
env = RailEnv(width=7,
height=7,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=1,
obs_builder_object=SingleAgentNavigationObs())
obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.render_env(show=True, frames=True, show_observations=True)
for step in range(100):
action = np.argmax(obs[0]) + 1
obs, all_rewards, done, _ = env.step({0: action})
print("Rewards: ", all_rewards, " [done=", done, "]")
env_renderer.render_env(show=True, frames=True, show_observations=True)
time.sleep(0.1)
if done["__all__"]:
break
env_renderer.close_window()
class ObservePredictions(TreeObsForRailEnv):
class ObservePredictions(ObservationBuilder):
"""
We use the provided ShortestPathPredictor to illustrate the usage of predictors in your custom observation.
We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
This is necessary so that we can pass the distance map to the ShortestPathPredictor
Here we also want to highlight how you can visualize your observation
"""
def __init__(self, predictor):
super().__init__(max_depth=0)
self.observation_space = [10]
super().__init__()
self.predictor = predictor
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
pass
def get_many(self, handles=None):
def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
'''
Because we do not want to call the predictor seperately for every agent we implement the get_many function
Here we can call the predictor just ones for all the agents and use the predictions to generate our observations
......@@ -140,23 +41,25 @@ class ObservePredictions(TreeObsForRailEnv):
:return:
'''
self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
self.predictions = self.predictor.get()
self.predicted_pos = {}
if handles is None:
handles = []
for t in range(len(self.predictions[0])):
pos_list = []
for a in handles:
pos_list.append(self.predictions[a][t][1:3])
# We transform (x,y) coodrinates to a single integer number for simpler comparison
self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
observations = {}
# Collect all the different observation for all the agents
for h in handles:
observations[h] = self.get(h)
observations = super().get_many(handles)
return observations
def get(self, handle):
def get(self, handle: int = 0) -> np.ndarray:
'''
Lets write a simple observation which just indicates whether or not the own predicted path
overlaps with other predicted paths at any time. This is useless for the task of navigation but might
......@@ -176,7 +79,7 @@ class ObservePredictions(TreeObsForRailEnv):
# We are going to track what cells where considered while building the obervation and make them accesible
# For rendering
visited = set()
visited = OrderedSet()
for _idx in range(10):
# Check if any of the other prediction overlap with agents own predictions
x_coord = self.predictions[handle][_idx][1]
......@@ -193,33 +96,93 @@ class ObservePredictions(TreeObsForRailEnv):
return observation
def set_env(self, env: Environment):
super().set_env(env)
if self.predictor:
self.predictor.set_env(self.env)
def create_env(custom_obs_builder):
nAgents = 3
n_cities = 2
max_rails_between_cities = 4
max_rails_in_city = 2
seed = 0
env = RailEnv(
width=30,
height=30,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=True,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city
),
line_generator=sparse_line_generator(),
number_of_agents=nAgents,
obs_builder_object=custom_obs_builder
)
return env
def custom_observation_example_03_ObservePredictions(sleep_for_animation, do_rendering):
# Initiate the Predictor
custom_predictor = ShortestPathPredictorForRailEnv(10)
# Pass the Predictor to the observation builder
custom_obs_builder = ObservePredictions(custom_predictor)
# Initiate Environment
env = create_env(custom_obs_builder)
obs, info = env.reset()
env_renderer = None
if do_rendering:
env_renderer = RenderTool(env)
# We render the initial step and show the obsered cells as colored boxes
env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
action_dict = {}
for step in range(100):
for a in range(env.get_num_agents()):
action = np.random.randint(0, 5)
action_dict[a] = action
obs, all_rewards, done, _ = env.step(action_dict)
print("Rewards: ", all_rewards, " [done=", done, "]")
if env_renderer is not None:
env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
if sleep_for_animation:
time.sleep(0.5)
if done["__all__"]:
print("All done!")
break
if env_renderer is not None:
env_renderer.close_window()
def main(args):
try:
opts, args = getopt.getopt(args, "", ["sleep-for-animation=", "do_rendering=", ""])
except getopt.GetoptError as err:
print(str(err)) # will print something like "option -a not recognized"
sys.exit(2)
sleep_for_animation = True
do_rendering = True
for o, a in opts:
if o in ("--sleep-for-animation"):
sleep_for_animation = str2bool(a)
elif o in ("--do_rendering"):
do_rendering = str2bool(a)
else:
assert False, "unhandled option"
# Initiate the Predictor
CustomPredictor = ShortestPathPredictorForRailEnv(10)
# Pass the Predictor to the observation builder
CustomObsBuilder = ObservePredictions(CustomPredictor)
# Initiate Environment
env = RailEnv(width=10,
height=10,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=3,
obs_builder_object=CustomObsBuilder)
obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
# We render the initial step and show the obsered cells as colored boxes
env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
# execute example
custom_observation_example_03_ObservePredictions(sleep_for_animation, do_rendering)
action_dict = {}
for step in range(100):
for a in range(env.get_num_agents()):
action = np.random.randint(0, 5)
action_dict[a] = action
obs, all_rewards, done, _ = env.step(action_dict)
print("Rewards: ", all_rewards, " [done=", done, "]")
env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
time.sleep(0.5)
if __name__ == '__main__':
if 'argv' in globals():
main(argv)
else:
main(sys.argv[1:])
import getopt
import random
from typing import Any
import sys
import time
from typing import Tuple
import numpy as np
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
from flatland.envs.schedule_generators import ScheduleGenerator, ScheduleGeneratorProduct
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.utils.misc import str2bool
from flatland.utils.rendertools import RenderTool
random.seed(100)
np.random.seed(100)
def custom_rail_map() -> Tuple[GridTransitionMap, np.array]:
# We instantiate a very simple rail network on a 7x10 grid:
# 0 1 2 3 4 5 6 7 8 9 10
# 0 /-------------\
# 1 | |
# 2 | |
# 3 _ _ _ /_ _ _ |
# 4 \ ___ /
# 5 |/
# 6 |
# 7 |
transitions = RailEnvTransitions()
cells = transitions.transition_list
def custom_rail_generator() -> RailGenerator:
def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
new_tran = rail_trans.set_transition(1, 1, 1, 1)
print(new_tran)
rail_array[0, 0] = new_tran
rail_array[0, 1] = new_tran
return grid_map, None
empty = cells[0]
dead_end_from_south = cells[7]
right_turn_from_south = cells[8]
right_turn_from_west = transitions.rotate_transition(right_turn_from_south, 90)
right_turn_from_north = transitions.rotate_transition(right_turn_from_south, 180)
dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
vertical_straight = cells[1]
simple_switch_north_left = cells[2]
simple_switch_north_right = cells[10]
simple_switch_left_east = transitions.rotate_transition(simple_switch_north_left, 90)
horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
double_switch_south_horizontal_straight = horizontal_straight + cells[6]
double_switch_north_horizontal_straight = transitions.rotate_transition(
double_switch_south_horizontal_straight, 180)
rail_map = np.array(
[[empty] * 3 + [right_turn_from_south] + [horizontal_straight] * 5 + [right_turn_from_west]] +
[[empty] * 3 + [vertical_straight] + [empty] * 5 + [vertical_straight]] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 + [simple_switch_left_east] + [horizontal_straight] * 2 + [
right_turn_from_west] + [empty] * 2 + [vertical_straight]] +
[[empty] * 6 + [simple_switch_north_right] + [horizontal_straight] * 2 + [right_turn_from_north]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
rail.grid = rail_map
city_positions = [(0, 3), (6, 6)]
train_stations = [
[((0, 3), 0)],
[((6, 6), 0)],
]
city_orientations = [0, 2]
agents_hints = {'city_positions': city_positions,
'train_stations': train_stations,
'city_orientations': city_orientations
}
optionals = {'agents_hints': agents_hints}
return rail, rail_map, optionals
return generator
def create_env():
rail, rail_map, optionals = custom_rail_map()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(),
number_of_agents=2,
obs_builder_object=DummyObservationBuilder(),
)
return env
def custom_schedule_generator() -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
agents_positions = []
agents_direction = []
agents_target = []
speeds = []
return agents_positions, agents_direction, agents_target, speeds
return generator
def custom_railmap_example(sleep_for_animation, do_rendering):
random.seed(100)
np.random.seed(100)
env = create_env()
env.reset()
env = RailEnv(width=6,
height=4,
rail_generator=custom_rail_generator(),
number_of_agents=1)
if do_rendering:
env_renderer = RenderTool(env)
env_renderer.render_env(show=True, show_observations=False)
env_renderer.close_window()
env.reset()
if sleep_for_animation:
time.sleep(1)
env_renderer = RenderTool(env)
env_renderer.render_env(show=True)
# uncomment to keep the renderer open
# input("Press Enter to continue...")
input("Press Enter to continue...")
def main(args):
try:
opts, args = getopt.getopt(args, "", ["sleep-for-animation=", "do_rendering=", ""])
except getopt.GetoptError as err:
print(str(err)) # will print something like "option -a not recognized"
sys.exit(2)
sleep_for_animation = True
do_rendering = True
for o, a in opts:
if o in ("--sleep-for-animation"):
sleep_for_animation = str2bool(a)
elif o in ("--do_rendering"):
do_rendering = str2bool(a)
else:
assert False, "unhandled option"
# execute example
custom_railmap_example(sleep_for_animation, do_rendering)
if __name__ == '__main__':
if 'argv' in globals():
main(argv)
else:
main(sys.argv[1:])
import random
import time
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool
random.seed(1)
np.random.seed(1)
class SingleAgentNavigationObs(TreeObsForRailEnv):
"""
We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
We then build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__(max_depth=0)
self.observation_space = [3]
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
def get(self, handle):
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if num_transitions == 1:
observation = [0, 1, 0]
else:
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = self._new_position(agent.position, direction)
min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
observation = [0, 0, 0]
observation[np.argmin(min_distances)] = 1
return observation
env = RailEnv(width=14,
height=14,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs())
obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.render_env(show=True, frames=True, show_observations=False)
for step in range(100):
actions = {}
for i in range(len(obs)):
actions[i] = np.argmax(obs[i]) + 1
if step % 5 == 0:
print("Agent halts")
actions[0] = 4 # Halt
obs, all_rewards, done, _ = env.step(actions)
if env.agents[0].malfunction_data['malfunction'] > 0:
print("Agent 0 broken-ness: ", env.agents[0].malfunction_data['malfunction'])
env_renderer.render_env(show=True, frames=True, show_observations=False)
time.sleep(0.5)
if done["__all__"]:
break
env_renderer.close_window()
import numpy as np
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool
np.random.seed(1)
# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
# Use a the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.5, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction
'max_duration': 10 # Max duration of malfunction
}
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
env = RailEnv(width=20,
height=20,
rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map (where train stations are)
num_intersections=1, # Number of intersections (no start / target)
num_trainstations=15, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=2, # Number of connections to other cities/intersections
seed=15, # Random seed
realistic_mode=True,
enhance_intersection=True
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=5,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation)
env_renderer = RenderTool(env, gl="PILSVG", )
# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent instead
class RandomAgent:
def __init__(self, state_size, action_size):
self.state_size = state_size
self.action_size = action_size
def act(self, state):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return np.random.choice(np.arange(self.action_size))
def step(self, memories):
"""
Step function to improve agent by adjusting policy given the observations
:param memories: SARS Tuple to be
:return:
"""
return
def save(self, filename):
# Store the current policy
return
def load(self, filename):
# Load a policy
return
# Initialize the agent with the parameters corresponding to the environment and observation_builder
# Set action space to 4 to remove stop action
agent = RandomAgent(218, 4)
# Empty dictionary for all agent action
action_dict = dict()
print("Start episode...")
# Reset environment and get initial observations for all agents
obs = env.reset()
# Update/Set agent's speed
for idx in range(env.get_num_agents()):
speed = 1.0 / ((idx % 5) + 1.0)
env.agents[idx].speed_data["speed"] = speed
# Reset the rendering sytem
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
score = 0
# Run episode
frame_step = 0
for step in range(500):
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
frame_step += 1
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
break
print('Episode: Steps {}\t Score = {}'.format(step, score))
import getopt
import sys
import time
import numpy as np
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import MalfunctionParameters
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.misc import str2bool
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent instead
class RandomAgent:
def __init__(self, state_size, action_size):
self.state_size = state_size
self.action_size = action_size
def act(self, state):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return 2 # np.random.choice(np.arange(self.action_size))
def step(self, memories):
"""
Step function to improve agent by adjusting policy given the observations
:param memories: SARS Tuple to be
:return:
"""
return
def save(self, filename):
# Store the current policy
return
def load(self, filename):
# Load a policy
return
def create_env():
# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
# Use a the malfunction generator to break agents from time to time
stochastic_data = MalfunctionParameters(malfunction_rate=30, # Rate of malfunction occurence
min_duration=3, # Minimal duration of malfunction
max_duration=20 # Max duration of malfunction
)
# Custom observation builder
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
nAgents = 3
n_cities = 2
max_rails_between_cities = 2
max_rails_in_city = 4
seed = 0
env = RailEnv(
width=20,
height=30,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=True,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city
),
line_generator=sparse_line_generator(),
number_of_agents=nAgents,
obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
)
return env
def flatland_3_0_example(sleep_for_animation, do_rendering):
np.random.seed(1)
env = create_env()
env.reset()
env_renderer = None
if do_rendering:
env_renderer = RenderTool(env, gl="PILSVG",
agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
show_debug=True,
screen_height=1000,
screen_width=1000)
# Initialize the agent with the parameters corresponding to the environment and observation_builder
# Set action space to 4 to remove stop action
agent = RandomAgent(218, 4)
# Empty dictionary for all agent action
action_dict = dict()
print("Start episode...")
# Reset environment and get initial observations for all agents
start_reset = time.time()
obs, info = env.reset()
end_reset = time.time()
print(end_reset - start_reset)
print(env.get_num_agents(), )
# Reset the rendering sytem
if env_renderer is not None:
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
score = 0
# Run episode
frame_step = 0
for step in range(500):
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
if env_renderer is not None:
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
frame_step += 1
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
break
if env_renderer is not None:
env_renderer.close_window()
print('Episode: Steps {}\t Score = {}'.format(step, score))
RailEnvPersister.save(env, "saved_episode_2.pkl")
def main(args):
try:
opts, args = getopt.getopt(args, "", ["sleep-for-animation=", "do_rendering=", ""])
except getopt.GetoptError as err:
print(str(err)) # will print something like "option -a not recognized"
sys.exit(2)
sleep_for_animation = True
do_rendering = True
for o, a in opts:
if o in ("--sleep-for-animation"):
sleep_for_animation = str2bool(a)
elif o in ("--do_rendering"):
do_rendering = str2bool(a)
else:
assert False, "unhandled option"
# execute example
flatland_3_0_example(sleep_for_animation, do_rendering)
if __name__ == '__main__':
if 'argv' in globals():
main(argv)
else:
main(sys.argv[1:])
import cProfile
import pstats
import numpy as np
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import MalfunctionParameters, ParamMalfunctionGen
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
class RandomAgent:
def __init__(self, action_size):
self.action_size = action_size
def act(self, state):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return np.random.choice(np.arange(self.action_size))
def get_rail_env(nAgents=70, use_dummy_obs=False, width=300, height=300):
# Rail Generator:
num_cities = 5 # Number of cities to place on the map
seed = 1 # Random seed
max_rails_between_cities = 2 # Maximum number of rails connecting 2 cities
max_rail_pairs_in_cities = 2 # Maximum number of pairs of tracks within a city
# Even tracks are used as start points, odd tracks are used as endpoints)
rail_generator = sparse_rail_generator(
max_num_cities=num_cities,
seed=seed,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rail_pairs_in_cities,
)
# Line Generator
# sparse_line_generator accepts a dictionary which maps speeds to probabilities.
# Different agent types (trains) with different speeds.
speed_probability_map = {
1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25 # Slow freight train
}
line_generator = sparse_line_generator(speed_probability_map)
# Malfunction Generator:
stochastic_data = MalfunctionParameters(
malfunction_rate=1 / 10000, # Rate of malfunction occurence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
malfunction_generator = ParamMalfunctionGen(stochastic_data)
# Observation Builder
# tree observation returns a tree of possible paths from the current position.
max_depth = 3 # Max depth of the tree
predictor = ShortestPathPredictorForRailEnv(
max_depth=50) # (Specific to Tree Observation - read code)
observation_builder = TreeObsForRailEnv(
max_depth=max_depth,
predictor=predictor
)
if use_dummy_obs:
observation_builder = DummyObservationBuilder()
number_of_agents = nAgents # Number of trains to create
seed = 1 # Random seed
env = RailEnv(
width=width,
height=height,
rail_generator=rail_generator,
line_generator=line_generator,
number_of_agents=number_of_agents,
random_seed=seed,
obs_builder_object=observation_builder,
malfunction_generator=malfunction_generator
)
return env
def run_simulation(env_fast: RailEnv, do_rendering):
agent = RandomAgent(action_size=5)
max_steps = 200
env_renderer = None
if do_rendering:
env_renderer = RenderTool(env_fast,
gl="PGL",
show_debug=True,
agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS)
env_renderer.set_new_rail()
env_renderer.reset()
for step in range(max_steps):
# Chose an action for each agent in the environment
for handle in range(env_fast.get_num_agents()):
action = agent.act(handle)
action_dict.update({handle: action})
next_obs, all_rewards, done, _ = env_fast.step(action_dict)
if env_renderer is not None:
env_renderer.render_env(
show=True,
frames=False,
show_observations=True,
show_predictions=False
)
if env_renderer is not None:
env_renderer.close_window()
USE_PROFILER = True
PROFILE_CREATE = False
PROFILE_RESET = False
PROFILE_STEP = True
PROFILE_OBSERVATION = False
RUN_SIMULATION = False
DO_RENDERING = False
if __name__ == "__main__":
print("Start ...")
if USE_PROFILER:
profiler = cProfile.Profile()
print("Create env ... ")
if PROFILE_CREATE:
profiler.enable()
env_fast = get_rail_env(nAgents=200, use_dummy_obs=False, width=100, height=100)
if PROFILE_CREATE:
profiler.disable()
print("Reset env ... ")
if PROFILE_RESET:
profiler.enable()
env_fast.reset(random_seed=1)
if PROFILE_RESET:
profiler.disable()
print("Make actions ... ")
action_dict = {agent.handle: 0 for agent in env_fast.agents}
print("Step env ... ")
if PROFILE_STEP:
profiler.enable()
for i in range(1):
env_fast.step(action_dict)
if PROFILE_STEP:
profiler.disable()
if PROFILE_OBSERVATION:
profiler.enable()
print("get observation ... ")
obs = env_fast._get_observations()
if PROFILE_OBSERVATION:
profiler.disable()
if USE_PROFILER:
if False:
print("---- tottime")
stats = pstats.Stats(profiler).sort_stats('tottime') # ncalls, 'cumtime'...
stats.print_stats(20)
if True:
print("---- cumtime")
stats = pstats.Stats(profiler).sort_stats('cumtime') # ncalls, 'cumtime'...
stats.print_stats(200)
if False:
print("---- ncalls")
stats = pstats.Stats(profiler).sort_stats('ncalls') # ncalls, 'cumtime'...
stats.print_stats(200)
print("... end ")
if RUN_SIMULATION:
run_simulation(env_fast, DO_RENDERING)
import os
import numpy as np
from flatland.envs.line_generators import sparse_line_generator
# In Flatland you can use custom observation builders and predicitors
# Observation builders generate the observation needed by the controller
# Preditctors can be used to do short time prediction which can help in avoiding conflicts in the network
from flatland.envs.malfunction_generators import MalfunctionParameters, ParamMalfunctionGen
from flatland.envs.observations import GlobalObsForRailEnv
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
# We also include a renderer because we want to visualize what is going on in the environment
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
# This is an introduction example for the Flatland 2.1.* version.
# Changes and highlights of this version include
# - Stochastic events (malfunctions)
# - Different travel speeds for differet agents
# - Levels are generated using a novel generator to reflect more realistic railway networks
# - Agents start outside of the environment and enter at their own time
# - Agents leave the environment after they have reached their goal
# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
# We start by importing the necessary rail and schedule generators
# The rail generator will generate the railway infrastructure
# The schedule generator will assign tasks to all the agent within the railway network
# The railway infrastructure can be build using any of the provided generators in env/rail_generators.py
# Here we use the sparse_rail_generator with the following parameters
DO_RENDERING = False
width = 16 * 7 # With of map
height = 9 * 7 # Height of map
nr_trains = 50 # Number of trains that have an assigned task in the env
cities_in_map = 20 # Number of cities where agents can start or end
seed = 14 # Random seed
grid_distribution_of_cities = False # Type of city distribution, if False cities are randomly placed
max_rails_between_cities = 2 # Max number of tracks allowed between cities. This is number of entry point to a city
max_rail_in_cities = 6 # Max number of parallel tracks within a city, representing a realistic trainstation
rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
seed=seed,
grid_mode=grid_distribution_of_cities,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rail_in_cities,
)
# rail_generator = SparseRailGen(max_num_cities=cities_in_map,
# seed=seed,
# grid_mode=grid_distribution_of_cities,
# max_rails_between_cities=max_rails_between_cities,
# max_rails_in_city=max_rail_in_cities,
# )
# The schedule generator can make very basic schedules with a start point, end point and a speed profile for each agent.
# The speed profiles can be adjusted directly as well as shown later on. We start by introducing a statistical
# distribution of speed profiles
# Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
# We can now initiate the schedule generator with the given speed profiles
line_generator = sparse_line_generator(speed_ration_map)
# We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
# during an episode.
stochastic_data = MalfunctionParameters(malfunction_rate=1 / 10000, # Rate of malfunction occurence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
# Custom observation builder without predictor
observation_builder = GlobalObsForRailEnv()
# Custom observation builder with predictor, uncomment line below if you want to try this one
# observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
# Construct the enviornment with the given observation, generataors, predictors, and stochastic data
env = RailEnv(width=width,
height=height,
rail_generator=rail_generator,
line_generator=line_generator,
number_of_agents=nr_trains,
obs_builder_object=observation_builder,
malfunction_generator=ParamMalfunctionGen(stochastic_data),
remove_agents_at_target=True)
env.reset()
# Initiate the renderer
env_renderer = None
if DO_RENDERING:
env_renderer = RenderTool(env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
# The first thing we notice is that some agents don't have feasible paths to their target.
# We first look at the map we have created
# nv_renderer.render_env(show=True)
# time.sleep(2)
# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent instead
class RandomAgent:
def __init__(self, state_size, action_size):
self.state_size = state_size
self.action_size = action_size
def act(self, state):
"""
:param state: input is the observation of the agent
:return: returns an action
"""
return np.random.choice([RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_LEFT,
RailEnvActions.STOP_MOVING])
def step(self, memories):
"""
Step function to improve agent by adjusting policy given the observations
:param memories: SARS Tuple to be
:return:
"""
return
def save(self, filename):
# Store the current policy
return
def load(self, filename):
# Load a policy
return
# Initialize the agent with the parameters corresponding to the environment and observation_builder
controller = RandomAgent(218, env.action_space[0])
# We start by looking at the information of each agent
# We can see the task assigned to the agent by looking at
print("\n Agents in the environment have to solve the following tasks: \n")
for agent_idx, agent in enumerate(env.agents):
print(
"The agent with index {} has the task to go from its initial position {}, facing in the direction {} to its target at {}.".format(
agent_idx, agent.initial_position, agent.direction, agent.target))
# The agent will always have a status indicating if it is currently present in the environment or done or active
# For example we see that agent with index 0 is currently not active
print("\n Their current statuses are:")
print("============================")
for agent_idx, agent in enumerate(env.agents):
print("Agent {} status is: {} with its current position being {}".format(agent_idx, str(agent.state),
str(agent.position)))
# The agent needs to take any action [1,2,3] except do_nothing or stop to enter the level
# If the starting cell is free they will enter the level
# If multiple agents want to enter the same cell at the same time the lower index agent will enter first.
# Let's check if there are any agents with the same start location
agents_with_same_start = set()
print("\n The following agents have the same initial position:")
print("=====================================================")
for agent_idx, agent in enumerate(env.agents):
for agent_2_idx, agent2 in enumerate(env.agents):
if agent_idx != agent_2_idx and agent.initial_position == agent2.initial_position:
print("Agent {} as the same initial position as agent {}".format(agent_idx, agent_2_idx))
agents_with_same_start.add(agent_idx)
# Lets try to enter with all of these agents at the same time
action_dict = dict()
for agent_id in agents_with_same_start:
action_dict[agent_id] = 1 # Try to move with the agents
# Do a step in the environment to see what agents entered:
env.step(action_dict)
# Current state and position of the agents after all agents with same start position tried to move
print("\n This happened when all tried to enter at the same time:")
print("========================================================")
for agent_id in agents_with_same_start:
print(
"Agent {} status is: {} with the current position being {}.".format(
agent_id, str(env.agents[agent_id].state),
str(env.agents[agent_id].position)))
# As you see only the agents with lower indexes moved. As soon as the cell is free again the agents can attempt
# to start again.
# You will also notice, that the agents move at different speeds once they are on the rail.
# The agents will always move at full speed when moving, never a speed inbetween.
# The fastest an agent can go is 1, meaning that it moves to the next cell at every time step
# All slower speeds indicate the fraction of a cell that is moved at each time step
# Lets look at the current speed data of the agents:
print("\n The speed information of the agents are:")
print("=========================================")
for agent_idx, agent in enumerate(env.agents):
print(
"Agent {} speed is: {:.2f} with the current fractional position being {}/{}".format(
agent_idx, agent.speed_counter.speed, agent.speed_counter.counter, agent.speed_counter.max_count))
# New the agents can also have stochastic malfunctions happening which will lead to them being unable to move
# for a certain amount of time steps. The malfunction data of the agents can easily be accessed as follows
print("\n The malfunction data of the agents are:")
print("========================================")
for agent_idx, agent in enumerate(env.agents):
print(
"Agent {} is OK = {}".format(
agent_idx, agent.malfunction_handler.in_malfunction))
# Now that you have seen these novel concepts that were introduced you will realize that agents don't need to take
# an action at every time step as it will only change the outcome when actions are chosen at cell entry.
# Therefore the environment provides information about what agents need to provide an action in the next step.
# You can access this in the following way.
# Chose an action for each agent
for a in range(env.get_num_agents()):
action = controller.act(0)
action_dict.update({a: action})
# Do the environment step
observations, rewards, dones, information = env.step(action_dict)
print("\n The following agents can register an action:")
print("========================================")
for info in information['action_required']:
print("Agent {} needs to submit an action.".format(info))
# We recommend that you monitor the malfunction data and the action required in order to optimize your training
# and controlling code.
# Let us now look at an episode playing out with random actions performed
print("\nStart episode...")
# Reset the rendering system
if env_renderer is not None:
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
score = 0
# Run episode
frame_step = 0
os.makedirs("tmp/frames", exist_ok=True)
for step in range(200):
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = controller.act(observations[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
if env_renderer is not None:
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
env_renderer.gl.save_image('tmp/frames/flatland_frame_{:04d}.png'.format(step))
frame_step += 1
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
controller.step((observations[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
observations = next_obs.copy()
if done['__all__']:
break
print('Episode: Steps {}\t Score = {}'.format(step, score))
# close the renderer / rendering window
if env_renderer is not None:
env_renderer.close_window()
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_manual_specifications_generator
from flatland.utils.rendertools import RenderTool
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)],
[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
env = RailEnv(width=6,
height=4,
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1)
env.reset()
env_renderer = RenderTool(env)
env_renderer.render_env(show=True, show_predictions=False, show_observations=False)
input("Press Enter to continue...")
import random
import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator
from flatland.utils.rendertools import RenderTool
random.seed(100)
np.random.seed(100)
# Relative weights of each cell type to be used by the random rail generators.
transition_probability = [1.0, # empty cell - Case 0
1.0, # Case 1 - straight
1.0, # Case 2 - simple switch
0.3, # Case 3 - diamond drossing
0.5, # Case 4 - single slip
0.5, # Case 5 - double slip
0.2, # Case 6 - symmetrical
0.0, # Case 7 - dead end
0.2, # Case 8 - turn left
0.2, # Case 9 - turn right
1.0] # Case 10 - mirrored switch
# Example generate a random rail
env = RailEnv(width=10,
height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=3)
env.reset()
env_renderer = RenderTool(env, gl="PIL")
env_renderer.render_env(show=True)
input("Press Enter to continue...")
import random
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.rendertools import RenderTool
random.seed(1)
np.random.seed(1)
env = RailEnv(width=7,
height=7,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
# Print the observation vector for agent 0
obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()):
env.obs_builder.util_print_obs_subtree(tree=obs[i])
env_renderer = RenderTool(env)
env_renderer.render_env(show=True, frames=True)
print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \
(turnleft+move, move to front, turnright+move)")
for step in range(100):
cmd = input(">> ")
cmds = cmd.split(" ")
action_dict = {}
i = 0
while i < len(cmds):
if cmds[i] == 'q':
import sys
sys.exit()
elif cmds[i] == 's':
obs, all_rewards, done, _ = env.step(action_dict)
action_dict = {}
print("Rewards: ", all_rewards, " [done=", done, "]")
else:
agent_id = int(cmds[i])
action = int(cmds[i + 1])
action_dict[agent_id] = action
i = i + 1
i += 1
env_renderer.render_env(show=True, frames=True)
import getopt
import sys
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.misc import str2bool
from flatland.utils.rendertools import RenderTool
np.random.seed(1)
# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObservation,
number_of_agents=3)
env_renderer = RenderTool(env, gl="PILSVG", )
def create_env():
nAgents = 1
n_cities = 2
max_rails_between_cities = 2
max_rails_in_city = 4
seed = 0
env = RailEnv(
width=30,
height=40,
rail_generator=sparse_rail_generator(
max_num_cities=n_cities,
seed=seed,
grid_mode=True,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rails_in_city
),
line_generator=sparse_line_generator(),
number_of_agents=nAgents,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
)
return env
# Import your own Agent or use RLlib to train agents on Flatland
......@@ -59,42 +70,85 @@ class RandomAgent:
return
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 5)
n_trials = 5
# Empty dictionary for all agent action
action_dict = dict()
print("Starting Training...")
for trials in range(1, n_trials + 1):
# Reset environment and get initial observations for all agents
obs = env.reset()
for idx in range(env.get_num_agents()):
tmp_agent = env.agents[idx]
tmp_agent.speed_data["speed"] = 1 / (idx + 1)
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
score = 0
# Run episode
for step in range(500):
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
break
print('Episode Nr. {}\t Score = {}'.format(trials, score))
def training_example(sleep_for_animation, do_rendering):
np.random.seed(1)
# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
env = create_env()
env.reset()
env_renderer = None
if do_rendering:
env_renderer = RenderTool(env)
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 5)
n_trials = 5
# Empty dictionary for all agent action
action_dict = dict()
print("Starting Training...")
for trials in range(1, n_trials + 1):
# Reset environment and get initial observations for all agents
obs, info = env.reset()
if env_renderer is not None:
env_renderer.reset()
# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
score = 0
# Run episode
for step in range(500):
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
if env_renderer is not None:
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
score += all_rewards[a]
obs = next_obs.copy()
if done['__all__']:
break
print('Episode Nr. {}\t Score = {}'.format(trials, score))
if env_renderer is not None:
env_renderer.close_window()
def main(args):
try:
opts, args = getopt.getopt(args, "", ["sleep-for-animation=", "do_rendering=", ""])
except getopt.GetoptError as err:
print(str(err)) # will print something like "option -a not recognized"
sys.exit(2)
sleep_for_animation = True
do_rendering = True
for o, a in opts:
if o in ("--sleep-for-animation"):
sleep_for_animation = str2bool(a)
elif o in ("--do_rendering"):
do_rendering = str2bool(a)
else:
assert False, "unhandled option"
# execute example
training_example(sleep_for_animation, do_rendering)
if __name__ == '__main__':
if 'argv' in globals():
main(argv)
else:
main(sys.argv[1:])
......@@ -4,4 +4,4 @@
__author__ = """S.P. Mohanty"""
__email__ = 'mohanty@aicrowd.com'
__version__ = '0.3.10'
__version__ = '3.0.15'
import pprint
from typing import Dict, List, Optional, NamedTuple
import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_action_for_move
from flatland.envs.rail_trainrun_data_structures import Waypoint, Trainrun, TrainrunWaypoint
# ---- ActionPlan ---------------
# an action plan element represents the actions to be taken by an agent at the given time step
ActionPlanElement = NamedTuple('ActionPlanElement', [
('scheduled_at', int),
('action', RailEnvActions)
])
# an action plan gathers all the the actions to be taken by a single agent at the corresponding time steps
ActionPlan = List[ActionPlanElement]
# An action plan dict gathers all the actions for every agent identified by the dictionary key = agent_handle
ActionPlanDict = Dict[int, ActionPlan]
class ControllerFromTrainruns():
"""Takes train runs, derives the actions from it and re-acts them."""
pp = pprint.PrettyPrinter(indent=4)
def __init__(self,
env: RailEnv,
trainrun_dict: Dict[int, Trainrun]):
self.env: RailEnv = env
self.trainrun_dict: Dict[int, Trainrun] = trainrun_dict
self.action_plan: ActionPlanDict = [self._create_action_plan_for_agent(agent_id, chosen_path)
for agent_id, chosen_path in trainrun_dict.items()]
def get_waypoint_before_or_at_step(self, agent_id: int, step: int) -> Waypoint:
"""
Get the way point point from which the current position can be extracted.
Parameters
----------
agent_id
step
Returns
-------
WalkingElement
"""
trainrun = self.trainrun_dict[agent_id]
entry_time_step = trainrun[0].scheduled_at
# the agent has no position before and at choosing to enter the grid (one tick elapses before the agent enters the grid)
if step <= entry_time_step:
return Waypoint(position=None, direction=self.env.agents[agent_id].initial_direction)
# the agent has no position as soon as the target is reached
exit_time_step = trainrun[-1].scheduled_at
if step >= exit_time_step:
# agent loses position as soon as target cell is reached
return Waypoint(position=None, direction=trainrun[-1].waypoint.direction)
waypoint = None
for trainrun_waypoint in trainrun:
if step < trainrun_waypoint.scheduled_at:
return waypoint
if step >= trainrun_waypoint.scheduled_at:
waypoint = trainrun_waypoint.waypoint
assert waypoint is not None
return waypoint
def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]:
"""
Get the current action if any is defined in the `ActionPlan`.
ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!
Parameters
----------
agent_id
current_step
Returns
-------
WalkingElement, optional
"""
for action_plan_element in self.action_plan[agent_id]:
scheduled_at = action_plan_element.scheduled_at
if scheduled_at > current_step:
return None
elif current_step == scheduled_at:
return action_plan_element.action
return None
def act(self, current_step: int) -> Dict[int, RailEnvActions]:
"""
Get the action dictionary to be replayed at the current step.
Returns only action where required (no action for done agents or those not at the beginning of the cell).
ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!
Parameters
----------
current_step: int
Returns
-------
Dict[int, RailEnvActions]
"""
action_dict = {}
for agent_id in range(len(self.env.agents)):
action: Optional[RailEnvActions] = self.get_action_at_step(agent_id, current_step)
if action is not None:
action_dict[agent_id] = action
return action_dict
def print_action_plan(self):
"""Pretty-prints `ActionPlanDict` of this `ControllerFromTrainruns` to stdout."""
self.__class__.print_action_plan_dict(self.action_plan)
@staticmethod
def print_action_plan_dict(action_plan: ActionPlanDict):
"""Pretty-prints `ActionPlanDict` to stdout."""
for agent_id, plan in enumerate(action_plan):
print("{}: ".format(agent_id))
for step in plan:
print(" {}".format(step))
@staticmethod
def assert_actions_plans_equal(expected_action_plan: ActionPlanDict, actual_action_plan: ActionPlanDict):
assert len(expected_action_plan) == len(actual_action_plan)
for k in range(len(expected_action_plan)):
assert len(expected_action_plan[k]) == len(actual_action_plan[k]), \
"len for agent {} should be the same.\n\n expected ({}) = {}\n\n actual ({}) = {}".format(
k,
len(expected_action_plan[k]),
ControllerFromTrainruns.pp.pformat(expected_action_plan[k]),
len(actual_action_plan[k]),
ControllerFromTrainruns.pp.pformat(actual_action_plan[k]))
for i in range(len(expected_action_plan[k])):
assert expected_action_plan[k][i] == actual_action_plan[k][i], \
"not the same at agent {} at step {}\n\n expected = {}\n\n actual = {}".format(
k, i,
ControllerFromTrainruns.pp.pformat(expected_action_plan[k][i]),
ControllerFromTrainruns.pp.pformat(actual_action_plan[k][i]))
assert expected_action_plan == actual_action_plan, \
"expected {}, found {}".format(expected_action_plan, actual_action_plan)
def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
action_plan = []
agent = self.env.agents[agent_id]
minimum_cell_time = agent.speed_counter.max_count + 1
for path_loop, trainrun_waypoint in enumerate(trainrun):
trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
position = trainrun_waypoint.waypoint.position
if Vec2d.is_equal(agent.target, position):
break
next_trainrun_waypoint: TrainrunWaypoint = trainrun[path_loop + 1]
next_position = next_trainrun_waypoint.waypoint.position
if path_loop == 0:
self._add_action_plan_elements_for_first_path_element_of_agent(
action_plan,
trainrun_waypoint,
next_trainrun_waypoint,
minimum_cell_time
)
continue
just_before_target = Vec2d.is_equal(agent.target, next_position)
self._add_action_plan_elements_for_current_path_element(
action_plan,
minimum_cell_time,
trainrun_waypoint,
next_trainrun_waypoint)
# add a final element
if just_before_target:
self._add_action_plan_elements_for_target_at_path_element_just_before_target(
action_plan,
minimum_cell_time,
trainrun_waypoint,
next_trainrun_waypoint)
return action_plan
def _add_action_plan_elements_for_current_path_element(self,
action_plan: ActionPlan,
minimum_cell_time: int,
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint):
scheduled_at = trainrun_waypoint.scheduled_at
next_entry_value = next_trainrun_waypoint.scheduled_at
position = trainrun_waypoint.waypoint.position
direction = trainrun_waypoint.waypoint.direction
next_position = next_trainrun_waypoint.waypoint.position
next_direction = next_trainrun_waypoint.waypoint.direction
next_action = get_action_for_move(position,
direction,
next_position,
next_direction,
self.env.rail)
# if the next entry is later than minimum_cell_time, then stop here and
# move minimum_cell_time before the exit
# we have to do this since agents in the RailEnv are processed in the step() in the order of their handle
if next_entry_value > scheduled_at + minimum_cell_time:
action = ActionPlanElement(scheduled_at, RailEnvActions.STOP_MOVING)
action_plan.append(action)
action = ActionPlanElement(next_entry_value - minimum_cell_time, next_action)
action_plan.append(action)
else:
action = ActionPlanElement(scheduled_at, next_action)
action_plan.append(action)
def _add_action_plan_elements_for_target_at_path_element_just_before_target(self,
action_plan: ActionPlan,
minimum_cell_time: int,
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint):
scheduled_at = trainrun_waypoint.scheduled_at
action = ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING)
action_plan.append(action)
def _add_action_plan_elements_for_first_path_element_of_agent(self,
action_plan: ActionPlan,
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint,
minimum_cell_time: int):
scheduled_at = trainrun_waypoint.scheduled_at
position = trainrun_waypoint.waypoint.position
direction = trainrun_waypoint.waypoint.direction
next_position = next_trainrun_waypoint.waypoint.position
next_direction = next_trainrun_waypoint.waypoint.direction
# add intial do nothing if we do not enter immediately, actually not necessary
if scheduled_at > 0:
action = ActionPlanElement(0, RailEnvActions.DO_NOTHING)
action_plan.append(action)
# add action to enter the grid
action = ActionPlanElement(scheduled_at, RailEnvActions.MOVE_FORWARD)
action_plan.append(action)
next_action = get_action_for_move(position,
direction,
next_position,
next_direction,
self.env.rail)
# if the agent is blocked in the cell, we have to call stop upon entering!
if next_trainrun_waypoint.scheduled_at > scheduled_at + 1 + minimum_cell_time:
action = ActionPlanElement(scheduled_at + 1, RailEnvActions.STOP_MOVING)
action_plan.append(action)
# execute the action exactly minimum_cell_time before the entry into the next cell
action = ActionPlanElement(next_trainrun_waypoint.scheduled_at - minimum_cell_time, next_action)
action_plan.append(action)
from typing import Callable
from flatland.action_plan.action_plan import ControllerFromTrainruns
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_trainrun_data_structures import Waypoint
ControllerFromTrainrunsReplayerRenderCallback = Callable[[RailEnv], None]
class ControllerFromTrainrunsReplayer():
"""Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction."""
@staticmethod
def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv,
call_back: ControllerFromTrainrunsReplayerRenderCallback = lambda *a, **k: None):
"""Replays this deterministic `ActionPlan` and verifies whether it is feasible.
Parameters
----------
ctl
env
call_back
Called before/after each step() call. The env is passed to it.
"""
call_back(env)
i = 0
while not env.dones['__all__'] and i <= env._max_episode_steps:
for agent_id, agent in enumerate(env.agents):
waypoint: Waypoint = ctl.get_waypoint_before_or_at_step(agent_id, i)
assert agent.position == waypoint.position, \
"before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
waypoint.position)
actions = ctl.act(i)
obs, all_rewards, done, _ = env.step(actions)
call_back(env)
i += 1
......@@ -9,9 +9,9 @@ import numpy as np
import redis
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.evaluators.service import FlatlandRemoteEvaluationService
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.evaluators.service import FlatlandRemoteEvaluationService, FLATLAND_RL_SERVICE_ID
from flatland.utils.rendertools import RenderTool
......@@ -19,39 +19,41 @@ from flatland.utils.rendertools import RenderTool
def demo(args=None):
"""Demo script to check installation"""
env = RailEnv(
width=15,
height=15,
rail_generator=complex_rail_generator(
nr_start_goal=10,
nr_extra=1,
min_dist=8,
max_dist=99999),
schedule_generator=complex_schedule_generator(),
width=30,
height=30,
rail_generator=sparse_rail_generator(
max_num_cities=3,
grid_mode=False,
max_rails_between_cities=4,
max_rail_pairs_in_city=2,
seed=0
),
line_generator=sparse_line_generator(),
number_of_agents=5)
env._max_episode_steps = int(15 * (env.width + env.height))
env_renderer = RenderTool(env)
while True:
obs = env.reset()
_done = False
# Run a single episode here
step = 0
while not _done:
# Compute Action
_action = {}
for _idx, _ in enumerate(env.agents):
_action[_idx] = np.random.randint(0, 5)
obs, all_rewards, done, _ = env.step(_action)
_done = done['__all__']
step += 1
env_renderer.render_env(
show=True,
frames=False,
show_observations=False,
show_predictions=False
)
time.sleep(0.3)
obs, info = env.reset()
_done = False
# Run a single episode here
step = 0
while not _done:
# Compute Action
_action = {}
for _idx, _ in enumerate(env.agents):
_action[_idx] = np.random.randint(0, 5)
obs, all_rewards, done, _ = env.step(_action)
_done = done['__all__']
step += 1
env_renderer.render_env(
show=True,
frames=False,
show_observations=False,
show_predictions=False
)
time.sleep(0.1)
return 0
......@@ -62,11 +64,28 @@ def demo(args=None):
required=True
)
@click.option('--service_id',
default="FLATLAND_RL_SERVICE_ID",
default=FLATLAND_RL_SERVICE_ID,
help="Evaluation Service ID. This has to match the service id on the client.",
required=False
)
def evaluator(tests, service_id):
@click.option('--shuffle',
type=bool,
default=False,
help="Shuffle the environments before starting evaluation.",
required=False
)
@click.option('--disable_timeouts',
default=False,
help="Disable all evaluation timeouts.",
required=False
)
@click.option('--results_path',
type=click.Path(exists=False),
default=None,
help="Path where the evaluator should write the results metadata.",
required=False
)
def evaluator(tests, service_id, shuffle, disable_timeouts, results_path):
try:
redis_connection = redis.Redis()
redis_connection.ping()
......@@ -80,7 +99,10 @@ def evaluator(tests, service_id):
test_env_folder=tests,
flatland_rl_service_id=service_id,
visualize=False,
verbose=False
result_output_path=results_path,
verbose=False,
shuffle=shuffle,
disable_timeouts=disable_timeouts
)
grader.run()
......
import os
import math
import numpy as np
import gym
from gym.utils import seeding
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector
from pettingzoo.utils import wrappers
from gym.utils import EzPickle
from pettingzoo.utils.conversions import to_parallel_wrapper
from flatland.envs.rail_env import RailEnv
from mava.wrappers.flatland import infer_observation_space, normalize_observation
from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
"""Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
"""
def parallel_wrapper_fn(env_fn):
def par_fn(**kwargs):
env = env_fn(**kwargs)
env = custom_parallel_wrapper(env)
return env
return par_fn
def env(**kwargs):
env = raw_env(**kwargs)
# env = wrappers.AssertOutOfBoundsWrapper(env)
# env = wrappers.OrderEnforcingWrapper(env)
return env
parallel_env = parallel_wrapper_fn(env)
class custom_parallel_wrapper(to_parallel_wrapper):
def step(self, actions):
rewards = {a: 0 for a in self.aec_env.agents}
dones = {}
infos = {}
observations = {}
for agent in self.aec_env.agents:
try:
assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
except Exception as e:
# print(e)
print(self.aec_env.dones.values())
raise e
obs, rew, done, info = self.aec_env.last()
self.aec_env.step(actions.get(agent,0))
for agent in self.aec_env.agents:
rewards[agent] += self.aec_env.rewards[agent]
dones = dict(**self.aec_env.dones)
infos = dict(**self.aec_env.infos)
self.agents = self.aec_env.agents
observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
return observations, rewards, dones, infos
class raw_env(AECEnv, gym.Env):
metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
'video.frames_per_second': 10,
'semantics.autoreset': False }
def __init__(self, environment = False, preprocessor = False, agent_info = False, *args, **kwargs):
# EzPickle.__init__(self, *args, **kwargs)
self._environment = environment
n_agents = self.num_agents
self._agents = [get_agent_keys(i) for i in range(n_agents)]
self._possible_agents = self.agents[:]
self._reset_next_step = True
self._agent_selector = agent_selector(self.agents)
self.num_actions = 5
self.action_spaces = {
agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
}
self.seed()
# preprocessor must be for observation builders other than global obs
# treeobs builders would use the default preprocessor if none is
# supplied
self.preprocessor = self._obtain_preprocessor(preprocessor)
self._include_agent_info = agent_info
# observation space:
# flatland defines no observation space for an agent. Here we try
# to define the observation space. All agents are identical and would
# have the same observation space.
# Infer observation space based on returned observation
obs, _ = self._environment.reset(regenerate_rail = False, regenerate_schedule = False)
obs = self.preprocessor(obs)
self.observation_spaces = {
i: infer_observation_space(ob) for i, ob in obs.items()
}
@property
def environment(self) -> RailEnv:
"""Returns the wrapped environment."""
return self._environment
@property
def dones(self):
dones = self._environment.dones
# remove_all = dones.pop("__all__", None)
return {get_agent_keys(key): value for key, value in dones.items()}
@property
def obs_builder(self):
return self._environment.obs_builder
@property
def width(self):
return self._environment.width
@property
def height(self):
return self._environment.height
@property
def agents_data(self):
"""Rail Env Agents data."""
return self._environment.agents
@property
def num_agents(self) -> int:
"""Returns the number of trains/agents in the flatland environment"""
return int(self._environment.number_of_agents)
# def __getattr__(self, name):
# """Expose any other attributes of the underlying environment."""
# return getattr(self._environment, name)
@property
def agents(self):
return self._agents
@property
def possible_agents(self):
return self._possible_agents
def env_done(self):
return self._environment.dones["__all__"] or not self.agents
def observe(self,agent):
return self.obs.get(agent)
def last(self, observe=True):
'''
returns observation, reward, done, info for the current agent (specified by self.agent_selection)
'''
agent = self.agent_selection
observation = self.observe(agent) if observe else None
return observation, self.rewards.get(agent), self.dones.get(agent), self.infos.get(agent)
def seed(self, seed: int = None) -> None:
self._environment._seed(seed)
def state(self):
'''
Returns an observation of the global environment
'''
return None
def _clear_rewards(self):
'''
clears all items in .rewards
'''
# pass
for agent in self.rewards:
self.rewards[agent] = 0
def reset(self, *args, **kwargs):
self._reset_next_step = False
self._agents = self.possible_agents[:]
obs, info = self._environment.reset(*args, **kwargs)
observations = self._collate_obs_and_info(obs, info)
self._agent_selector.reinit(self.agents)
self.agent_selection = self._agent_selector.next()
self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
self.action_dict = {get_agent_handle(i):0 for i in self.possible_agents}
return observations
def step(self, action):
if self.env_done():
self._agents = []
self._reset_next_step = True
return self.last()
agent = self.agent_selection
self.action_dict[get_agent_handle(agent)] = action
if self.dones[agent]:
# Disabled.. In case we want to remove agents once done
# if self.remove_agents:
# self.agents.remove(agent)
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
else:
self._clear_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
if self._agent_selector.is_last():
observations, rewards, dones, infos = self._environment.step(self.action_dict)
self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
if observations:
observations = self._collate_obs_and_info(observations, infos)
else:
self._clear_rewards()
# self._cumulative_rewards[agent] = 0
self._accumulate_rewards()
obs, cumulative_reward, done, info = self.last()
self.agent_selection = self._agent_selector.next()
return obs, cumulative_reward, done, info
# collate agent info and observation into a tuple, making the agents obervation to
# be a tuple of the observation from the env and the agent info
def _collate_obs_and_info(self, observes, info):
observations = {}
infos = {}
observes = self.preprocessor(observes)
for agent, obs in observes.items():
all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
agent_info = np.array(
list(all_infos.values()), dtype=np.float32
)
infos[agent] = all_infos
obs = (obs, agent_info) if self._include_agent_info else obs
observations[agent] = obs
self.infos = infos
self.obs = observations
return observations
def set_probs(self, probs):
self.probs = probs
def render(self, mode='rgb_array'):
"""
This methods provides the option to render the
environment's behavior as an image or to a window.
"""
if mode == "rgb_array":
env_rgb_array = self._environment.render(mode)
if not hasattr(self, "image_shape "):
self.image_shape = env_rgb_array.shape
if not hasattr(self, "probs "):
self.probs = [[0., 0., 0., 0.]]
fig, ax = plt.subplots(figsize=(self.image_shape[1]/100, self.image_shape[0]/100),
constrained_layout=True, dpi=100)
df = pd.DataFrame(np.array(self.probs).T)
sns.barplot(x=df.index, y=0, data=df, ax=ax)
ax.set(xlabel='actions', ylabel='probs')
fig.canvas.draw()
X = np.array(fig.canvas.renderer.buffer_rgba())
Image.fromarray(X)
# Image.fromarray(X)
rgb_image = np.array(Image.fromarray(X).convert('RGB'))
plt.close(fig)
q_value_rgb_array = rgb_image
return np.append(env_rgb_array, q_value_rgb_array, axis=1)
else:
return self._environment.render(mode)
def close(self):
self._environment.close()
def _obtain_preprocessor(self, preprocessor):
"""Obtains the actual preprocessor to be used based on the supplied
preprocessor and the env's obs_builder object"""
if not isinstance(self.obs_builder, GlobalObsForRailEnv):
_preprocessor = preprocessor if preprocessor else lambda x: x
if isinstance(self.obs_builder, TreeObsForRailEnv):
_preprocessor = (
partial(
normalize_observation, tree_depth=self.obs_builder.max_depth
)
if not preprocessor
else preprocessor
)
assert _preprocessor is not None
else:
def _preprocessor(x):
return x
def returned_preprocessor(obs):
temp_obs = {}
for agent_id, ob in obs.items():
temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
return temp_obs
return returned_preprocessor
# Utility functions
def convert_np_type(dtype, value):
return np.dtype(dtype).type(value)
def get_agent_handle(id):
"""Obtain an agents handle given its id"""
return int(id)
def get_agent_keys(id):
"""Obtain an agents handle given its id"""
return str(id)
\ No newline at end of file
id-mava[flatland]
id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
seaborn
matplotlib
pandas
\ No newline at end of file