diff --git a/flatland/core/env.py b/flatland/core/env.py index 3c25beea236945b1728959e02ea07f6c0ba7a6ac..32691f507f4cb5586f10b5645cc22ece718edc21 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -84,6 +84,21 @@ class Environment: """ raise NotImplementedError() + def predict(self): + """ + Predictions step. + + Returns predictions for the agents. + The returns are dicts mapping from agent_id strings to values. + + Returns + ------- + predictions : dict + New predictions for each ready agent. + + """ + raise NotImplementedError() + def render(self): """ Perform rendering of the environment. diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 3cef545c1658e6bfe2a292ee26c3e665ce6a5abc..f85afee4b625e59374c6cce266bf55b21e7fdb84 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -19,7 +19,6 @@ class ObservationBuilder: def __init__(self): self.observation_space = () - pass def _set_env(self, env): self.env = env diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5e4dc5033ba3a789313b47d94b033251cb8276 --- /dev/null +++ b/flatland/core/env_prediction_builder.py @@ -0,0 +1,44 @@ +""" +PredictionBuilder objects are objects that can be passed to environments designed for customizability. +The PredictionBuilder-derived custom classes implement 2 functions, reset() and get([handle]). +If predictions are not required in every step or not for all agents, then + ++ Reset() is called after each environment reset, to allow for pre-computing relevant data. + ++ Get() is called whenever an step has to be computed, potentially for each agent independently in +case of multi-agent environments. +""" + + +class PredictionBuilder: + """ + PredictionBuilder base class. + """ + + def __init__(self, max_depth: int = 20): + self.max_depth = max_depth + + def _set_env(self, env): + self.env = env + + def reset(self): + """ + Called after each environment reset. + """ + pass + + def get(self, handle=0): + """ + Called whenever step_prediction is called on the environment. + + Parameters + ------- + handle : int (optional) + Handle of the agent for which to compute the observation vector. + + Returns + ------- + function + A prediction structure, specific to the corresponding environment. + """ + raise NotImplementedError() diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py new file mode 100644 index 0000000000000000000000000000000000000000..95c1a984c4151a9a873deeeb29438b290bb4f77e --- /dev/null +++ b/flatland/envs/predictions.py @@ -0,0 +1,72 @@ +""" +Collection of environment-specific PredictionBuilder. +""" + +import numpy as np + +from flatland.core.env_prediction_builder import PredictionBuilder + + +class DummyPredictorForRailEnv(PredictionBuilder): + """ + DummyPredictorForRailEnv object. + + This object returns predictions for agents in the RailEnv environment. + The prediction acts as if no other agent is in the environment and always takes the forward action. + """ + + def get(self, handle=None): + """ + Called whenever step_prediction is called on the environment. + + Parameters + ------- + handle : int (optional) + Handle of the agent for which to compute the observation vector. + + Returns + ------- + function + Returns a dictionary index by the agent handle and for each agent a vector of 5 elements: + - time_offset + - position axis 0 + - position axis 1 + - direction + - action taken to come here + """ + agents = self.env.agents + if handle: + agents = [self.env.agents[handle]] + + prediction_dict = {} + + for agent in agents: + + # 0: do nothing + # 1: turn left and move to the next cell + # 2: move to the next cell in front of the agent + # 3: turn right and move to the next cell + action_priorities = [2, 1, 3] + _agent_initial_position = agent.position + _agent_initial_direction = agent.direction + prediction = np.zeros(shape=(self.max_depth, 5)) + prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0] + for index in range(1, self.max_depth): + action_done = False + for action in action_priorities: + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self.env._check_action_on_agent(action, + agent) + if all([new_cell_isValid, transition_isValid]): + # move and change direction to face the new_direction that was + # performed + agent.position = new_position + agent.direction = new_direction + prediction[index] = [index, new_position[0], new_position[1], new_direction, action] + action_done = True + break + if not action_done: + print("Cannot move further.") + prediction_dict[agent.handle] = prediction + agent.position = _agent_initial_position + agent.direction = _agent_initial_direction + return prediction_dict diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a23b6ac71651e8e424bb90a23cbcfd472ce89a12..82d694cccb695da512960f3a69c918917ba73d22 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -51,7 +51,9 @@ class RailEnv(Environment): height, rail_generator=random_rail_generator(), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2)): + obs_builder_object=TreeObsForRailEnv(max_depth=2), + prediction_builder_object=None + ): """ Environment init. @@ -94,6 +96,11 @@ class RailEnv(Environment): self.obs_builder = obs_builder_object self.obs_builder._set_env(self) + self.prediction_builder = prediction_builder_object + if self.prediction_builder: + self.prediction_builder._set_env(self) + + self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? @@ -212,52 +219,8 @@ class RailEnv(Environment): return if action > 0: - # pos = agent.position # self.agents_position[i] - # direction = agent.direction # self.agents_direction[i] - - # compute number of possible transitions in the current - # cell used to check for invalid actions - - new_direction, transition_isValid = self.check_action(agent, action) - - new_position = get_new_position(agent.position, new_direction) - # Is it a legal move? - # 1) transition allows the new_direction in the cell, - # 2) the new cell is not empty (case 0), - # 3) the cell is free, i.e., no agent is currently in that cell - - # if ( - # new_position[1] >= self.width or - # new_position[0] >= self.height or - # new_position[0] < 0 or new_position[1] < 0): - # new_cell_isValid = False - - # if self.rail.get_transitions(new_position) == 0: - # new_cell_isValid = False - - new_cell_isValid = ( - np.array_equal( # Check the new position is still in the grid - new_position, - np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) - and # check the new position has some transitions (ie is not an empty cell) - self.rail.get_transitions(new_position) > 0) - - # If transition validity hasn't been checked yet. - if transition_isValid is None: - transition_isValid = self.rail.get_transition( - (*agent.position, agent.direction), - new_direction) - - # cell_isFree = True - # for j in range(self.number_of_agents): - # if self.agents_position[j] == new_position: - # cell_isFree = False - # break - # Check the new position is not the same as any of the existing agent positions - # (including itself, for simplicity, since it is moving) - cell_isFree = not np.any( - np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) - + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self._check_action_on_agent(action, + agent) if all([new_cell_isValid, transition_isValid, cell_isFree]): # move and change direction to face the new_direction that was # performed @@ -296,6 +259,52 @@ class RailEnv(Environment): self.actions = [0] * self.get_num_agents() return self._get_observations(), self.rewards_dict, self.dones, {} + def _check_action_on_agent(self, action, agent): + # pos = agent.position # self.agents_position[i] + # direction = agent.direction # self.agents_direction[i] + # compute number of possible transitions in the current + # cell used to check for invalid actions + new_direction, transition_isValid = self.check_action(agent, action) + new_position = get_new_position(agent.position, new_direction) + # Is it a legal move? + # 1) transition allows the new_direction in the cell, + # 2) the new cell is not empty (case 0), + # 3) the cell is free, i.e., no agent is currently in that cell + # if ( + # new_position[1] >= self.width or + # new_position[0] >= self.height or + # new_position[0] < 0 or new_position[1] < 0): + # new_cell_isValid = False + # if self.rail.get_transitions(new_position) == 0: + # new_cell_isValid = False + new_cell_isValid = ( + np.array_equal( # Check the new position is still in the grid + new_position, + np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) + and # check the new position has some transitions (ie is not an empty cell) + self.rail.get_transitions(new_position) > 0) + # If transition validity hasn't been checked yet. + if transition_isValid is None: + transition_isValid = self.rail.get_transition( + (*agent.position, agent.direction), + new_direction) + # cell_isFree = True + # for j in range(self.number_of_agents): + # if self.agents_position[j] == new_position: + # cell_isFree = False + # break + # Check the new position is not the same as any of the existing agent positions + # (including itself, for simplicity, since it is moving) + cell_isFree = not np.any( + np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) + return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid + + def predict(self): + if not self.prediction_builder: + return {} + return self.prediction_builder.get() + + def check_action(self, agent, action): transition_isValid = None possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) @@ -332,6 +341,11 @@ class RailEnv(Environment): self.obs_dict[iAgent] = self.obs_builder.get(iAgent) return self.obs_dict + def _get_predictions(self): + if not self.prediction_builder: + return {} + return {} + def render(self): # TODO: pass diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 8951095fdbfff3d62986ddc0d83a471f064e8bc4..b738e5e1a920f30a65389522b971900b0fbcf17e 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -394,9 +394,9 @@ class PILSVG(PILGL): print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:]) if isSelected: - svgBG = self.pilFromSvgFile("./svg/Selected_Agent.svg") - self.clear_layer(3, 0) - self.drawImageRC(svgBG, (row, col), layer=3) + svgBG = self.pilFromSvgFile("./svg/Selected_Target.svg") + self.clear_layer(3,0) + self.drawImageRC(svgBG,(row,col),layer=3) def recolorImage(self, pil, a3BaseColor, ltColors): rgbaImg = array(pil) diff --git a/notebooks/Editor2.ipynb b/notebooks/Editor2.ipynb index 0645909b911be07a5cf53d75225350ad01723def..ddfea41efa25d4c789522e051fb5bb8a1e8e3ef2 100644 --- a/notebooks/Editor2.ipynb +++ b/notebooks/Editor2.ipynb @@ -9,9 +9,18 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -19,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -32,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -54,31 +63,23 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "cairo installed: OK\n" - ] - } - ], + "outputs": [], "source": [ "from flatland.utils.editor import EditorMVC, EditorModel, View, Controller" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "<flatland.utils.graphics_pil.PILSVG object at 0x0000022C5FB44198> <class 'flatland.utils.graphics_pil.PILSVG'>\n", + "<flatland.utils.graphics_pil.PILSVG object at 0x000001FC6FB9E198> <class 'flatland.utils.graphics_pil.PILSVG'>\n", "<super: <class 'PILSVG'>, <PILSVG object>> <class 'super'>\n", "Clear rails\n" ] @@ -115,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 14, "metadata": { "scrolled": false }, @@ -123,7 +124,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "094a0a3e4351403d8d119b0696abaee4", + "model_id": "df04f776b29f456eabb20b9587ea1f16", "version_major": 2, "version_minor": 0 }, @@ -138,7 +139,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "<flatland.utils.graphics_pil.PILSVG object at 0x0000022C6066EC50> <class 'flatland.utils.graphics_pil.PILSVG'>\n", + "<flatland.utils.graphics_pil.PILSVG object at 0x000001FC6FBB7FD0> <class 'flatland.utils.graphics_pil.PILSVG'>\n", + "<super: <class 'PILSVG'>, <PILSVG object>> <class 'super'>\n", + "<flatland.utils.graphics_pil.PILSVG object at 0x000001FC6FA8C5C0> <class 'flatland.utils.graphics_pil.PILSVG'>\n", + "<super: <class 'PILSVG'>, <PILSVG object>> <class 'super'>\n", + "<flatland.utils.graphics_pil.PILSVG object at 0x000001FC73AF2908> <class 'flatland.utils.graphics_pil.PILSVG'>\n", "<super: <class 'PILSVG'>, <PILSVG object>> <class 'super'>\n" ] } @@ -149,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "metadata": { "scrolled": false }, @@ -157,7 +162,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "98b1504729884d8a9362dbf246d81f78", + "model_id": "6c0846dadce244ed877d53410dcfe0a7", "version_major": 2, "version_minor": 0 }, @@ -176,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -185,7 +190,7 @@ "(0, 0)" ] }, - "execution_count": 8, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } diff --git a/svg/Selected_Agent.svg b/svg/Selected_Agent.svg index 15761df861dd79c22a7099e04e4d76aefefb58f7..ca6071534193c689f99988d1168c0011fc74e27d 100644 --- a/svg/Selected_Agent.svg +++ b/svg/Selected_Agent.svg @@ -2,45 +2,45 @@ <!-- Generator: Adobe Illustrator 23.0.3, SVG Export Plug-In . SVG Version: 6.00 Build 0) --> <svg - xmlns:dc="http://purl.org/dc/elements/1.1/" - xmlns:cc="http://creativecommons.org/ns#" - xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" - xmlns:svg="http://www.w3.org/2000/svg" - xmlns="http://www.w3.org/2000/svg" - xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" - xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" - version="1.1" - id="Ebene_1" - x="0px" - y="0px" - viewBox="0 0 240 240" - style="enable-background:new 0 0 240 240;" - xml:space="preserve" - sodipodi:docname="Selected_Agent.svg" - inkscape:version="0.92.4 (5da689c313, 2019-01-14)"><metadata + xmlns:dc="http://purl.org/dc/elements/1.1/" + xmlns:cc="http://creativecommons.org/ns#" + xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" + xmlns="http://www.w3.org/2000/svg" + xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" + xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" + version="1.1" + id="Ebene_1" + x="0px" + y="0px" + viewBox="0 0 240 240" + style="enable-background:new 0 0 240 240;" + xml:space="preserve" + sodipodi:docname="Selected_Agent.svg" + inkscape:version="0.92.4 (5da689c313, 2019-01-14)"><metadata id="metadata11"><rdf:RDF><cc:Work rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type - rdf:resource="http://purl.org/dc/dcmitype/StillImage" /><dc:title></dc:title></cc:Work></rdf:RDF></metadata><defs + rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata> + <defs id="defs9" /><sodipodi:namedview - pagecolor="#ffffff" - bordercolor="#666666" - borderopacity="1" - objecttolerance="10" - gridtolerance="10" - guidetolerance="10" - inkscape:pageopacity="0" - inkscape:pageshadow="2" - inkscape:window-width="1920" - inkscape:window-height="1137" - id="namedview7" - showgrid="false" - inkscape:zoom="2.7812867" - inkscape:cx="205.50339" - inkscape:cy="161.549" - inkscape:window-x="-8" - inkscape:window-y="-8" - inkscape:window-maximized="1" - inkscape:current-layer="Ebene_1" /> + pagecolor="#ffffff" + bordercolor="#666666" + borderopacity="1" + objecttolerance="10" + gridtolerance="10" + guidetolerance="10" + inkscape:pageopacity="0" + inkscape:pageshadow="2" + inkscape:window-width="1920" + inkscape:window-height="1137" + id="namedview7" + showgrid="false" + inkscape:zoom="2.7812867" + inkscape:cx="126.94263" + inkscape:cy="161.549" + inkscape:window-x="-8" + inkscape:window-y="-8" + inkscape:window-maximized="1" + inkscape:current-layer="Ebene_1" /> <style type="text/css" id="style2"> @@ -48,45 +48,60 @@ </style> <rect - id="rect13" - width="23.389822" - height="23.38983" - x="1.697217e-07" - y="-0.23616901" /><rect - id="rect13-0" - width="23.389822" - height="23.38983" - x="216.82077" - y="0.26172119" /><rect - id="rect13-0-0" - width="23.389822" - height="23.38983" - x="216.75911" - y="216.39955" /><rect - id="rect13-0-0-4" - width="23.389822" - height="23.38983" - x="0.50847793" - y="216.6974" /><rect - id="rect60" - width="2.5168207" - height="198.4693" - x="10.067283" - y="22.474777" /><rect - id="rect60-8" - width="2.5168207" - height="198.4693" - x="228.49136" - y="22.115229" /><rect - id="rect60-8-5" - width="2.5168207" - height="198.4693" - x="-11.868174" - y="19.775019" - transform="rotate(-90)" /><rect - id="rect60-8-5-1" - width="2.5168207" - height="198.4693" - x="-230.11249" - y="20.853657" - transform="rotate(-90)" /></svg> \ No newline at end of file + id="rect13" + width="23.389822" + height="23.38983" + x="1.697217e-07" + y="-0.23616901" + style="fill:#ff0000"/> + <rect + id="rect13-0" + width="23.389822" + height="23.38983" + x="216.82077" + y="0.26172119" + style="fill:#ff0000"/> + <rect + id="rect13-0-0" + width="23.389822" + height="23.38983" + x="216.75911" + y="216.39955" + style="fill:#ff0000"/> + <rect + id="rect13-0-0-4" + width="23.389822" + height="23.38983" + x="0.50847793" + y="216.6974" + style="fill:#ff0000"/> + <rect + id="rect60" + width="2.5741608" + height="240.53616" + x="1.697217e-07" + y="-0.23616901" + style="stroke-width:1.11335897;fill:#ff0000"/> + <rect + id="rect60-8" + width="2.5741608" + height="239.45752" + x="237.63643" + y="0.26172119" + style="stroke-width:1.11085987;fill:#ff0000"/> + <rect + id="rect60-8-5" + width="2.5168207" + height="237.86693" + x="-2.778542" + y="2.3436601" + transform="rotate(-90)" + style="stroke-width:1.09476364;fill:#ff0000"/> + <rect + id="rect60-8-5-1" + width="2.5168207" + height="240.0242" + x="-240.29999" + y="1.6972172e-07" + transform="rotate(-90)" + style="stroke-width:1.09971678;fill:#ff0000"/></svg> diff --git a/svg/Selected_Target.svg b/svg/Selected_Target.svg new file mode 100644 index 0000000000000000000000000000000000000000..d5834efcf6756139aa422f867f681e0e2d21b1df --- /dev/null +++ b/svg/Selected_Target.svg @@ -0,0 +1,79 @@ +<?xml version="1.0" encoding="UTF-8" standalone="no"?> +<!-- Generator: Adobe Illustrator 23.0.3, SVG Export Plug-In . SVG Version: 6.00 Build 0) --> + +<svg + xmlns:dc="http://purl.org/dc/elements/1.1/" + xmlns:cc="http://creativecommons.org/ns#" + xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" + xmlns="http://www.w3.org/2000/svg" + xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" + xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" + version="1.1" + id="Ebene_1" + x="0px" + y="0px" + viewBox="0 0 240 240" + style="enable-background:new 0 0 240 240;" + xml:space="preserve" + sodipodi:docname="Selected_Target.svg" + inkscape:version="0.92.4 (5da689c313, 2019-01-14)"><metadata + id="metadata11"><rdf:RDF><cc:Work + rdf:about=""><dc:format>image/svg+xml</dc:format><dc:type + rdf:resource="http://purl.org/dc/dcmitype/StillImage"/><dc:title/></cc:Work></rdf:RDF></metadata> + <defs + id="defs9" /><sodipodi:namedview + pagecolor="#ffffff" + bordercolor="#666666" + borderopacity="1" + objecttolerance="10" + gridtolerance="10" + guidetolerance="10" + inkscape:pageopacity="0" + inkscape:pageshadow="2" + inkscape:window-width="1920" + inkscape:window-height="1137" + id="namedview7" + showgrid="false" + inkscape:zoom="2.7812867" + inkscape:cx="48.381869" + inkscape:cy="161.549" + inkscape:window-x="-8" + inkscape:window-y="-8" + inkscape:window-maximized="1" + inkscape:current-layer="Ebene_1" /> +<style + type="text/css" + id="style2"> + .st0{fill:#9CCB89;} +</style> + +<rect + id="rect60" + width="13.303195" + height="220.04205" + x="10.067283" + y="9.8906736" + style="fill:#ff0000;stroke-width:2.42079496"/> + <rect + id="rect60-8" + width="14.741379" + height="220.40161" + x="215.72748" + y="9.531127" + style="fill:#ff0000;stroke-width:2.5503726"/> + <rect + id="rect60-8-5" + width="12.584104" + height="220.04205" + x="-22.115231" + y="10.4268" + transform="rotate(-90)" + style="fill:#ff0000;stroke-width:2.35445929"/> + <rect + id="rect60-8-5-1" + width="13.662741" + height="220.40158" + x="216.26999" + y="10.067283" + transform="matrix(0,1,1,0,0,0)" + style="fill:#ff0000;stroke-width:2.45529389"/></svg> diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..35a6a27b970ce54e1cabd3cf8c80d30a34800a25 --- /dev/null +++ b/tests/test_env_prediction_builder.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import numpy as np + +from flatland.core.transition_map import GridTransitionMap, Grid4Transitions +from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.predictions import DummyPredictorForRailEnv +from flatland.envs.rail_env import RailEnv + +"""Test predictions for `flatland` package.""" + + +def test_predictions(): + # We instantiate a very simple rail network on a 7x10 grid: + # | + # | + # | + # _ _ _ /_\ _ _ _ _ _ _ + # \ / + # | + # | + # | + + cells = [int('0000000000000000', 2), # empty cell - Case 0 + int('1000000000100000', 2), # Case 1 - straight + int('1001001000100000', 2), # Case 2 - simple switch + int('1000010000100001', 2), # Case 3 - diamond drossing + int('1001011000100001', 2), # Case 4 - single slip switch + int('1100110000110011', 2), # Case 5 - double slip switch + int('0101001000000010', 2), # Case 6 - symmetrical switch + int('0010000000000000', 2)] # Case 7 - dead end + + transitions = Grid4Transitions([]) + empty = cells[0] + + dead_end_from_south = cells[7] + dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90) + dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180) + dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270) + + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + + double_switch_south_horizontal_straight = horizontal_straight + cells[6] + double_switch_north_horizontal_straight = transitions.rotate_transition( + double_switch_south_horizontal_straight, 180) + + rail_map = np.array( + [[empty] * 3 + [dead_end_from_south] + [empty] * 6] + + [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 + + [[dead_end_from_east] + [horizontal_straight] * 2 + + [double_switch_north_horizontal_straight] + + [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] + + [horizontal_straight] * 2 + [dead_end_from_west]] + + [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 + + [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16) + + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv(), + prediction_builder_object=DummyPredictorForRailEnv(max_depth=20) + ) + + env.reset() + + # set initial position and direction for testing... + env.agents[0].position = (5, 6) + env.agents[0].direction = 0 + + predictions = env.predict() + positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) + directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) + time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) + actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0]))) + + # compare against expected values + expected_positions = np.array([[5., 6.], + [4., 6.], + [3., 6.], + [3., 5.], + [3., 4.], + [3., 3.], + [3., 2.], + [3., 1.], + [3., 0.], + [3., 1.], + [3., 2.], + [3., 3.], + [3., 4.], + [3., 5.], + [3., 6.], + [3., 7.], + [3., 8.], + [3., 9.], + [3., 8.], + [3., 7.]]) + expected_directions = np.array([[0.], + [0.], + [0.], + [3.], + [3.], + [3.], + [3.], + [3.], + [3.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [1.], + [3.], + [3.]]) + expected_time_offsets = np.array([[0.], + [1.], + [2.], + [3.], + [4.], + [5.], + [6.], + [7.], + [8.], + [9.], + [10.], + [11.], + [12.], + [13.], + [14.], + [15.], + [16.], + [17.], + [18.], + [19.]]) + expected_actions = np.array([[0.], + [2.], + [2.], + [1.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.], + [2.]]) + assert np.array_equal(positions, expected_positions) + assert np.array_equal(directions, expected_directions) + assert np.array_equal(time_offsets, expected_time_offsets) + assert np.array_equal(actions, expected_actions) + + +def main(): + test_predictions() + + +if __name__ == "__main__": + main()