From 0744a856ceb882e85a39213c8080b7fcb15b4fea Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Mon, 1 Jul 2019 22:47:06 +0200 Subject: [PATCH] added variants of rendering: see AgentRenderVariant --- examples/simple_example_3.py | 2 +- flatland/utils/rendertools.py | 51 ++++++++++++++++++++++++++++------- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index eb7ac9a..ddb98e9 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -2,7 +2,7 @@ import random import numpy as np -from flatland.envs.generators import random_rail_generator, complex_rail_generator +from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 9216402..7932aac 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -1,5 +1,6 @@ import time from collections import deque +from enum import IntEnum import numpy as np from numpy import array @@ -10,6 +11,13 @@ from flatland.utils.graphics_pil import PILGL, PILSVG # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! +class AgentRenderVariant(IntEnum): + BOX_ONLY = 0 + ONE_STEP_BEHIND = 1 + AGENT_SHOWS_OPTIONS = 2 + ONE_STEP_BEHIND_AND_BOX = 3 + AGENT_SHOWS_OPTIONS_AND_BOX = 4 + class RenderTool(object): """ Class to render the RailEnv and agents. @@ -30,12 +38,14 @@ class RenderTool(object): gTheta = np.linspace(0, np.pi / 2, 5) gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1] - def __init__(self, env, gl="PILSVG", jupyter=False): + def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.AGENT_SHOWS_OPTIONS): self.env = env self.iFrame = 0 self.time1 = time.time() self.lTimes = deque() + self.agentRenderVariant = agentRenderVariant + if gl == "PIL": self.gl = PILGL(env.width, env.height, jupyter) elif gl == "PILSVG": @@ -664,18 +674,39 @@ class RenderTool(object): if agent is None: continue - if agent.old_position is not None: - position = agent.old_position - direction = agent.direction - old_direction = agent.old_direction + if self.agentRenderVariant == AgentRenderVariant.BOX_ONLY: + self.gl.setCellOccupied(iAgent, *(agent.position)) + elif self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND or \ + self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: + if agent.old_position is not None: + position = agent.old_position + direction = agent.direction + old_direction = agent.old_direction + else: + position = agent.position + direction = agent.direction + old_direction = agent.direction + + # setAgentAt uses the agent index for the color + if self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: + self.gl.setCellOccupied(iAgent, *(agent.position)) + self.gl.setAgentAt(iAgent, *position, old_direction, direction, iSelectedAgent == iAgent) else: position = agent.position direction = agent.direction - old_direction = agent.direction - - # setAgentAt uses the agent index for the color - self.gl.setCellOccupied(iAgent, *(agent.position)) - self.gl.setAgentAt(iAgent, *position, old_direction, direction, iSelectedAgent == iAgent) + for possible_directions in range(4): + # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? + isValid = env.rail.get_transition((*agent.position, agent.direction), possible_directions) + if isValid: + direction = possible_directions + + # setAgentAt uses the agent index for the color + self.gl.setAgentAt(iAgent, *position, agent.direction, direction, iSelectedAgent == iAgent) + + # setAgentAt uses the agent index for the color + if self.agentRenderVariant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX: + self.gl.setCellOccupied(iAgent, *(agent.position)) + self.gl.setAgentAt(iAgent, *position, agent.direction, direction, iSelectedAgent == iAgent) if show_observations: self.renderObs(range(env.get_num_agents()), env.dev_obs_dict) -- GitLab