diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py index eb7ac9ab820199f71db2d0a5af25a2dd636c614e..ddb98e9979c059f3a08840a7d229448c83b839d2 100644 --- a/examples/simple_example_3.py +++ b/examples/simple_example_3.py @@ -2,7 +2,7 @@ import random import numpy as np -from flatland.envs.generators import random_rail_generator, complex_rail_generator +from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 9216402421ad0ff0d674438bd4a7a45148008ac7..7932aacafe5c67e188afa9f52a90ec47f73ff8da 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -1,5 +1,6 @@ import time from collections import deque +from enum import IntEnum import numpy as np from numpy import array @@ -10,6 +11,13 @@ from flatland.utils.graphics_pil import PILGL, PILSVG # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! +class AgentRenderVariant(IntEnum): + BOX_ONLY = 0 + ONE_STEP_BEHIND = 1 + AGENT_SHOWS_OPTIONS = 2 + ONE_STEP_BEHIND_AND_BOX = 3 + AGENT_SHOWS_OPTIONS_AND_BOX = 4 + class RenderTool(object): """ Class to render the RailEnv and agents. @@ -30,12 +38,14 @@ class RenderTool(object): gTheta = np.linspace(0, np.pi / 2, 5) gArc = array([np.cos(gTheta), np.sin(gTheta)]).T # from [1,0] to [0,1] - def __init__(self, env, gl="PILSVG", jupyter=False): + def __init__(self, env, gl="PILSVG", jupyter=False, agentRenderVariant=AgentRenderVariant.AGENT_SHOWS_OPTIONS): self.env = env self.iFrame = 0 self.time1 = time.time() self.lTimes = deque() + self.agentRenderVariant = agentRenderVariant + if gl == "PIL": self.gl = PILGL(env.width, env.height, jupyter) elif gl == "PILSVG": @@ -664,18 +674,39 @@ class RenderTool(object): if agent is None: continue - if agent.old_position is not None: - position = agent.old_position - direction = agent.direction - old_direction = agent.old_direction + if self.agentRenderVariant == AgentRenderVariant.BOX_ONLY: + self.gl.setCellOccupied(iAgent, *(agent.position)) + elif self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND or \ + self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: + if agent.old_position is not None: + position = agent.old_position + direction = agent.direction + old_direction = agent.old_direction + else: + position = agent.position + direction = agent.direction + old_direction = agent.direction + + # setAgentAt uses the agent index for the color + if self.agentRenderVariant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: + self.gl.setCellOccupied(iAgent, *(agent.position)) + self.gl.setAgentAt(iAgent, *position, old_direction, direction, iSelectedAgent == iAgent) else: position = agent.position direction = agent.direction - old_direction = agent.direction - - # setAgentAt uses the agent index for the color - self.gl.setCellOccupied(iAgent, *(agent.position)) - self.gl.setAgentAt(iAgent, *position, old_direction, direction, iSelectedAgent == iAgent) + for possible_directions in range(4): + # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? + isValid = env.rail.get_transition((*agent.position, agent.direction), possible_directions) + if isValid: + direction = possible_directions + + # setAgentAt uses the agent index for the color + self.gl.setAgentAt(iAgent, *position, agent.direction, direction, iSelectedAgent == iAgent) + + # setAgentAt uses the agent index for the color + if self.agentRenderVariant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX: + self.gl.setCellOccupied(iAgent, *(agent.position)) + self.gl.setAgentAt(iAgent, *position, agent.direction, direction, iSelectedAgent == iAgent) if show_observations: self.renderObs(range(env.get_num_agents()), env.dev_obs_dict)