diff --git a/examples/play_model.py b/examples/play_model.py index 4f6ba79b3eb58317b54c28344688b20753b0cdf7..f80d9a14c54c1fa411a8736e4891ba5b95c188d8 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -106,10 +106,7 @@ def main(render=True, delay=0.0): if render: env_renderer = RenderTool(env, gl="QTSVG") - # plt.figure(figsize=(5,5)) - # fRedis = redis.Redis() - - # handle = env.get_agent_handles() + # env_renderer = RenderTool(env, gl="QT") state_size = 105 action_size = 4 @@ -167,6 +164,14 @@ def main(render=True, delay=0.0): action_prob[action] += 1 action_dict.update({a: action}) + if render: + env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, action_dict=action_dict) + #time.sleep(10) + if delay > 0: + time.sleep(delay) + + iFrame += 1 + # Environment step next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): @@ -177,13 +182,6 @@ def main(render=True, delay=0.0): agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) score += all_rewards[a] - if render: - env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step) - #time.sleep(10) - if delay > 0: - time.sleep(delay) - - iFrame += 1 obs = next_obs.copy() if done['__all__']: diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 6ef8aa4148ef0ffb313d72c612b2a5471c47b975..49209460dd191f1628aabac49d3ce5b72df7e817 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -23,8 +23,7 @@ class EnvAgentStatic(object): position = attrib() direction = attrib() target = attrib() - - next_handle = 0 # this is not properly implemented + old_direction = attrib(default=None) @classmethod def from_lists(cls, positions, directions, targets): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 33b9f969d24424535a4630ab0edca7f9b932c383..adb40f03b2ee100d8e12cfead660f9600dd9968b 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ 
-208,34 +208,11 @@ class RailEnv(Environment): # compute number of possible transitions in the current # cell used to check for invalid actions - possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) - num_transitions = np.count_nonzero(possible_transitions) - - movement = agent.direction - # print(nbits,np.sum(possible_transitions)) - if action == 1: - movement = agent.direction - 1 - if num_transitions <= 1: - transition_isValid = False - - elif action == 3: - movement = agent.direction + 1 - if num_transitions <= 1: - transition_isValid = False - - movement %= 4 - - if action == 2: - if num_transitions == 1: - # - dead-end, straight line or curved line; - # movement will be the only valid transition - # - take only available transition - movement = np.argmax(possible_transitions) - transition_isValid = True - - new_position = get_new_position(agent.position, movement) + new_direction, transition_isValid = self.check_action(agent, action) + + new_position = get_new_position(agent.position, new_direction) # Is it a legal move? 
- # 1) transition allows the movement in the cell, + # 1) transition allows the new_direction in the cell, # 2) the new cell is not empty (case 0), # 3) the cell is free, i.e., no agent is currently in that cell @@ -259,7 +236,7 @@ class RailEnv(Environment): if transition_isValid is None: transition_isValid = self.rail.get_transition( (*agent.position, agent.direction), - movement) + new_direction) # cell_isFree = True # for j in range(self.number_of_agents): @@ -272,12 +249,13 @@ class RailEnv(Environment): np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) if all([new_cell_isValid, transition_isValid, cell_isFree]): - # move and change direction to face the movement that was + # move and change direction to face the new_direction that was # performed # self.agents_position[i] = new_position - # self.agents_direction[i] = movement + # self.agents_direction[i] = new_direction agent.position = new_position - agent.direction = movement + agent.old_direction = agent.direction + agent.direction = new_direction else: # the action was not valid, add penalty self.rewards_dict[iAgent] += invalid_action_penalty @@ -307,6 +285,34 @@ class RailEnv(Environment): self.actions = [0] * self.get_num_agents() return self._get_observations(), self.rewards_dict, self.dones, {} + def check_action(self, agent, action): + transition_isValid = None + possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) + num_transitions = np.count_nonzero(possible_transitions) + + new_direction = agent.direction + # print(nbits,np.sum(possible_transitions)) + if action == 1: + new_direction = agent.direction - 1 + if num_transitions <= 1: + transition_isValid = False + + elif action == 3: + new_direction = agent.direction + 1 + if num_transitions <= 1: + transition_isValid = False + + new_direction %= 4 + + if action == 2: + if num_transitions == 1: + # - dead-end, straight line or curved line; + # new_direction will be the only valid transition + 
# - take only available transition + new_direction = np.argmax(possible_transitions) + transition_isValid = True + return new_direction, transition_isValid + def _get_observations(self): self.obs_dict = {} # for handle in self.agents_handles: diff --git a/flatland/utils/render_qt.py b/flatland/utils/render_qt.py index 56088068a217dd7050c973050f0cff0055a50b68..f70afaed2c82254f821b12deefc1a99abf445afd 100644 --- a/flatland/utils/render_qt.py +++ b/flatland/utils/render_qt.py @@ -175,38 +175,46 @@ class QTSVG(GraphicsLayer): svgWidget.renderer().load(bySVG) self.layout.addWidget(svgWidget, row, col) self.lwTrack.append(svgWidget) + else: + print("Illegal rail:", row, col, format(binTrans, "#018b")[2:]) def setAgentAt(self, iAgent, row, col, iDirIn, iDirOut): if iAgent < len(self.lwAgents): wAgent = self.lwAgents[iAgent] agentPrev = self.agents_prev[iAgent] + + # If we have an existing agent widget, we can just move it if wAgent is not None: self.layout.removeWidget(wAgent) self.layout.addWidget(wAgent, row, col) - if agentPrev.direction == iDirIn: - # print("moved ", iAgent, row, col, iDirIn) + # We can only reuse the image if both new and old are straight and the same: + if iDirIn == iDirOut and \ + agentPrev.direction == iDirIn and \ + agentPrev.old_direction == agentPrev.direction: return else: + # need to load new image # print("new dir:", iAgent, row, col, agentPrev.direction, iDirIn) - agentPrev.direction = iDirIn + agentPrev.direction = iDirOut + agentPrev.old_direction = iDirIn sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut).to_string() bySVG = bytearray(sSVG, encoding='utf-8') wAgent.renderer().load(bySVG) return - else: - # Ensure we have adequate slots in the list lwAgents - for i in range(len(self.lwAgents), iAgent+1): - self.lwAgents.append(None) - self.agents_prev.append(None) + # Ensure we have adequate slots in the list lwAgents + for i in range(len(self.lwAgents), iAgent+1): + self.lwAgents.append(None) + self.agents_prev.append(None) + # Create a new 
widget for the agent sSVG = self.zug.getSvg(iAgent, iDirIn, iDirOut).to_string() bySVG = bytearray(sSVG, encoding='utf-8') svgWidget = QtSvg.QSvgWidget() svgWidget.renderer().load(bySVG) self.lwAgents[iAgent] = svgWidget - self.agents_prev[iAgent] = EnvAgent((row, col), iDirIn, (0, 0)) + self.agents_prev[iAgent] = EnvAgent((row, col), iDirOut, (0, 0), old_direction=iDirIn) self.layout.addWidget(svgWidget, row, col) # print("Created ", iAgent, row, col) diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 16fcd0545bf8610328a8881c20d5ab6f3be30bda..18bc3c8d1854a0d74b9c8a9c08e473d70604cb26 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -476,7 +476,8 @@ class RenderTool(object): self, show=False, curves=True, spacing=False, arrows=False, agents=True, sRailColor="gray", frames=False, iEpisode=None, iStep=None, - iSelectedAgent=None): + iSelectedAgent=None, + action_dict=None): """ Draw the environment using matplotlib. Draw into the figure if provided. @@ -489,7 +490,7 @@ class RenderTool(object): self.renderEnv2(show, curves, spacing, arrows, agents, sRailColor, frames, iEpisode, iStep, - iSelectedAgent) + iSelectedAgent, action_dict) return # cell_size is a bit pointless with matplotlib - it does not relate to pixels, @@ -691,7 +692,8 @@ class RenderTool(object): self, show=False, curves=True, spacing=False, arrows=False, agents=True, sRailColor="gray", frames=False, iEpisode=None, iStep=None, - iSelectedAgent=None): + iSelectedAgent=None, + action_dict=dict()): """ Draw the environment using matplotlib. Draw into the figure if provided. 
@@ -715,10 +717,24 @@ class RenderTool(object): for iAgent, agent in enumerate(self.env.agents): if agent is None: continue - self.gl.setAgentAt(iAgent, *agent.position, agent.direction, agent.direction) + new_direction = agent.direction + action_isValid = False + + if iAgent in action_dict: + iAction = action_dict[iAgent] + new_direction, action_isValid = self.env.check_action(agent, iAction) + + if action_isValid: + self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction) + else: + pass + # print("invalid action - agent ", iAgent, " bend ", agent.direction, new_direction) + # self.gl.setAgentAt(iAgent, *agent.position, agent.direction, new_direction) + self.gl.show() - self.gl.processEvents() + for i in range(3): + self.gl.processEvents() self.iFrame += 1 return diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py index f9daf19c832e4fd779ea856c12a2c6ab623bbd00..bd6ab4e24f4b0441073a08aad0924b25c151bdbd 100644 --- a/flatland/utils/svg.py +++ b/flatland/utils/svg.py @@ -65,19 +65,21 @@ class Zug(object): def getSvg(self, iAgent, iDirIn, iDirOut): delta_dir = (iDirOut - iDirIn) % 4 + # if delta_dir != 0: + # print("Bend:", iAgent, iDirIn, iDirOut) if delta_dir in (0, 2): svg = self.svg_straight.copy() svg.set_rotate(iDirIn * 90) return svg - if delta_dir == 1: + if delta_dir == 1: # bend to right, eg N->E, E->S svg = self.svg_curve1.copy() svg.set_rotate((iDirIn - 1) * 90) return svg - elif delta_dir == 3: - svg = self.svg_curve1.copy() + elif delta_dir == 3: # bend to left, eg N->W + svg = self.svg_curve2.copy() svg.set_rotate(iDirIn * 90) return svg @@ -94,7 +96,7 @@ class Track(object): "ES NW": "Gleis_Kurve_unten_links.svg", "NE WS": "Gleis_Kurve_unten_rechts.svg", "NN SS": "Gleis_vertikal.svg", - "NN SS ES NW SE WN": "Weiche_Double_Slip.svg", + "NN SS EE WW ES NW SE WN": "Weiche_Double_Slip.svg", "EE WW EN SW": "Weiche_horizontal_oben_links.svg", "EE WW SE WN": "Weiche_horizontal_oben_rechts.svg", "EE WW ES NW": 
"Weiche_horizontal_unten_links.svg",