Skip to content
Snippets Groups Projects
Commit da1b4bad authored by Erik Nygren's avatar Erik Nygren
Browse files

added initial tree rendering

updated tests after failing
parent 4ab8dbe4
No related branches found
No related tags found
No related merge requests found
...@@ -24,7 +24,7 @@ transition_probability = [15, # empty cell - Case 0 ...@@ -24,7 +24,7 @@ transition_probability = [15, # empty cell - Case 0
1] # Case 2b (10) - simple switch mirrored 1] # Case 2b (10) - simple switch mirrored
# Example generate a random rail # Example generate a random rail
"""
env = RailEnv(width=10, env = RailEnv(width=10,
height=10, height=10,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
...@@ -35,6 +35,7 @@ env = RailEnv(width=15, ...@@ -35,6 +35,7 @@ env = RailEnv(width=15,
rail_generator=complex_rail_generator(nr_start_goal=10, min_dist=5, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=10, min_dist=5, max_dist=99999, seed=0),
number_of_agents=3) number_of_agents=3)
""" """
"""
env = RailEnv(width=20, env = RailEnv(width=20,
height=20, height=20,
rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
...@@ -116,7 +117,8 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): ...@@ -116,7 +117,8 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
for trials in range(1, n_trials + 1): for trials in range(1, n_trials + 1):
# Reset environment # Reset environment
obs, _ = env.reset() obs, dev_obs = env.reset()
env.dev_obs_dict = dev_obs
final_obs = obs.copy() final_obs = obs.copy()
final_obs_next = obs.copy() final_obs_next = obs.copy()
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
...@@ -148,8 +150,8 @@ for trials in range(1, n_trials + 1): ...@@ -148,8 +150,8 @@ for trials in range(1, n_trials + 1):
action_dict.update({a: action}) action_dict.update({a: action})
# Environment step # Environment step
(next_obs,_), all_rewards, done, _ = env.step(action_dict) (next_obs, dev_obs), all_rewards, done, _ = env.step(action_dict)
env.dev_obs_dict = dev_obs
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
current_depth=0) current_depth=0)
......
...@@ -98,7 +98,7 @@ class RailEnv(Environment): ...@@ -98,7 +98,7 @@ class RailEnv(Environment):
self.obs_dict = {} self.obs_dict = {}
self.rewards_dict = {} self.rewards_dict = {}
self.dev_obs_dict = {}
# self.agents_handles = list(range(self.number_of_agents)) # self.agents_handles = list(range(self.number_of_agents))
# self.agents_position = [] # self.agents_position = []
......
...@@ -471,6 +471,23 @@ class RenderTool(object): ...@@ -471,6 +471,23 @@ class RenderTool(object):
xyMid, xyMid,
xyMid + [-dx + dy, -dx - dy]]) xyMid + [-dx + dy, -dx - dy]])
self.gl.plot(*xyArrow.T, color=sColor) self.gl.plot(*xyArrow.T, color=sColor)
def renderObs(self, agent_handles, observation_list):
"""
:param agent_handles:
:param observation_list:
:return:
"""
rt = self.__class__
cmap = self.gl.get_cmap('hsv',lut=max(len(self.env.agents),len(self.env.agents_static)+1))
for agent in agent_handles:
color = cmap(agent)
for visited_cell in observation_list[agent]:
cell_coord = array(visited_cell[:2])
cell_coord_trans = np.matmul(cell_coord,rt.grc2xy)+rt.xyHalf
self._draw_square(cell_coord_trans,1 / 3, color)
def renderEnv( def renderEnv(
self, show=False, curves=True, spacing=False, self, show=False, curves=True, spacing=False,
...@@ -612,6 +629,7 @@ class RenderTool(object): ...@@ -612,6 +629,7 @@ class RenderTool(object):
if agents: if agents:
self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent) self.plotAgents(targets=True, iSelectedAgent=iSelectedAgent)
self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)
# Draw some textual information like fps # Draw some textual information like fps
yText = [-0.3, -0.6, -0.9] yText = [-0.3, -0.6, -0.9]
if frames: if frames:
......
...@@ -6,10 +6,10 @@ Tests for `flatland` package. ...@@ -6,10 +6,10 @@ Tests for `flatland` package.
from flatland.envs.rail_env import RailEnv, random_rail_generator from flatland.envs.rail_env import RailEnv, random_rail_generator
import numpy as np import numpy as np
<<<<<<< HEAD #<<<<<<< HEAD
======= #=======
# import os # import os
>>>>>>> dc2fa1ee0244b15c76d89ab768c5e1bbd2716147 #>>>>>>> dc2fa1ee0244b15c76d89ab768c5e1bbd2716147
import sys import sys
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment