diff --git a/examples/training_navigation.py b/examples/training_navigation.py index ec19ff20541e94b004b6c788d8e5df707019b2c0..cabb655e3eb2bc2d12d908559b0102b072163052 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -5,7 +5,7 @@ from flatland.utils.rendertools import * from flatland.baselines.dueling_double_dqn import Agent from collections import deque import torch, random - +import time random.seed(1) np.random.seed(1) @@ -25,15 +25,16 @@ transition_probability = [15, # empty cell - Case 0 # Example generate a random rail """ -env = RailEnv(width=10, - height=10, +env = RailEnv(width=20, + height=20, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=5) + number_of_agents=1) """ env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(nr_start_goal=3, min_dist=5, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=30, min_dist=5, max_dist=99999, seed=0), number_of_agents=3) + """ env = RailEnv(width=20, height=20, @@ -139,7 +140,8 @@ for trials in range(1, n_trials + 1): # Run episode for step in range(100): if demo: - env_renderer.renderEnv(show=True) + env_renderer.renderEnv(show=True, obsrender=True) + time.sleep(2) # print(step) # Action for a in range(env.get_num_agents()): diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 04e9a8fefac4ada79527f00200dc6fbfa3d7b924..c1578a816e2e30127fb77dda6e72ab51b2f41cb2 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -9,7 +9,7 @@ from flatland.envs.env_utils import distance_on_rail, connect_rail, get_directio from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail -def complex_rail_generator(nr_start_goal=1, nr_extra=10, min_dist=2, max_dist=99999, seed=0): +def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=99999, seed=0): """ Parameters ------- diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 09a9f1044a4913288af2e260592b73c90245a3a8..34f3e9fa6857e86f4d99d211784d983a2e2a1e75 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -472,12 +472,13 @@ class RenderTool(object): xyMid + [-dx + dy, -dx - dy]]) self.gl.plot(*xyArrow.T, color=sColor) - def renderObs(self, agent_handles, observation_list): + def renderObs(self, agent_handles, observation_dict): """ + Render the extent of the observation of each agent. All cells that appear in the agent obsrevation will be + highlighted. + :param agent_handles: List of agent indices to adapt color and get correct observation + :param observation_dict: dictionary containing sets of cells of the agent observation - :param agent_handles: - :param observation_list: - :return: """ rt = self.__class__ @@ -485,7 +486,7 @@ class RenderTool(object): for agent in agent_handles: color = cmap(agent) - for visited_cell in observation_list[agent]: + for visited_cell in observation_dict[agent]: cell_coord = array(visited_cell[:2]) cell_coord_trans = np.matmul(cell_coord, rt.grc2xy) + rt.xyHalf self._draw_square(cell_coord_trans, 1 / 3, color)