diff --git a/examples/training_navigation.py b/examples/training_navigation.py index a7920c7da1c330494f2f37e298dd9f691378f115..4a21ad11417317ab50337116cccb772ac220d7cc 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -24,7 +24,7 @@ transition_probability = [15, # empty cell - Case 0 1] # Case 2b (10) - simple switch mirrored # Example generate a random rail -""" + env = RailEnv(width=10, height=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), @@ -35,6 +35,7 @@ env = RailEnv(width=15, rail_generator=complex_rail_generator(nr_start_goal=10, min_dist=5, max_dist=99999, seed=0), number_of_agents=3) """ +""" env = RailEnv(width=20, height=20, rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( @@ -116,7 +117,8 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): for trials in range(1, n_trials + 1): # Reset environment - obs, _ = env.reset() + obs, dev_obs = env.reset() + env.dev_obs_dict = dev_obs final_obs = obs.copy() final_obs_next = obs.copy() for a in range(env.get_num_agents()): @@ -148,8 +150,8 @@ for trials in range(1, n_trials + 1): action_dict.update({a: action}) # Environment step - (next_obs,_), all_rewards, done, _ = env.step(action_dict) - + (next_obs, dev_obs), all_rewards, done, _ = env.step(action_dict) + env.dev_obs_dict = dev_obs for a in range(env.get_num_agents()): data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, current_depth=0) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 749e5e01aa94e8ddb59d33d5787e3bd55d05f9eb..640eb5c537c303d11e3e31aac9b3e098af1149ed 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -98,7 +98,7 @@ class RailEnv(Environment): self.obs_dict = {} self.rewards_dict = {} - + self.dev_obs_dict = {} # self.agents_handles = list(range(self.number_of_agents)) # self.agents_position = [] diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 4921def4218beef21e119a3979bbb88c02af984d..1bd67c22ad4ba05bd6badfc991e8839a85fe9e30 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -471,6 +471,23 @@ class RenderTool(object): xyMid, xyMid + [-dx + dy, -dx - dy]]) self.gl.plot(*xyArrow.T, color=sColor) + def renderObs(self, agent_handles, observation_list): + """ + + :param agent_handles: + :param observation_list: + :return: + """ + rt = self.__class__ + + cmap = self.gl.get_cmap('hsv',lut=max(len(self.env.agents),len(self.env.agents_static)+1)) + + for agent in agent_handles: + color = cmap(agent) + for visited_cell in observation_list[agent]: + cell_coord = array(visited_cell[:2]) + cell_coord_trans = np.matmul(cell_coord,rt.grc2xy)+rt.xyHalf + self._draw_square(cell_coord_trans,1 / 3, color) def renderEnv( self, show=False, curves=True, spacing=False, @@ -612,6 +629,7 @@ class RenderTool(object): if agents: self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent) + self.renderObs(range(env.get_num_agents()), env.dev_obs_dict) # Draw some textual information like fps yText = [-0.3, -0.6, -0.9] if frames: diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index f6defb2de400f1a007da237b2f91ec38df5db07b..245f2f327524653b3cf03bf921f6db6b0d4b51fb 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -6,10 +6,10 @@ Tests for `flatland` package. from flatland.envs.rail_env import RailEnv, random_rail_generator import numpy as np -<<<<<<< HEAD -======= +#<<<<<<< HEAD +#======= # import os ->>>>>>> dc2fa1ee0244b15c76d89ab768c5e1bbd2716147 +#>>>>>>> dc2fa1ee0244b15c76d89ab768c5e1bbd2716147 import sys import matplotlib.pyplot as plt