diff --git a/examples-not-running/simple_example_1.py b/examples/simple_example_1.py similarity index 96% rename from examples-not-running/simple_example_1.py rename to examples/simple_example_1.py index 3055e3195247c987e1a7c9f9dbc0c24aee6ca2b8..89536edde93ba5d593d710290657c555bc542c5b 100644 --- a/examples-not-running/simple_example_1.py +++ b/examples/simple_example_1.py @@ -20,5 +20,6 @@ env.reset() env_renderer = RenderTool(env, gl="PILSVG") env_renderer.renderEnv(show=True) +env_renderer.renderEnv(show=True) input("Press Enter to continue...") diff --git a/examples-not-running/simple_example_2.py b/examples/simple_example_2.py similarity index 97% rename from examples-not-running/simple_example_2.py rename to examples/simple_example_2.py index 8c40a34978764830accafcec12c1be750dd1457d..05290f15d5f5ca672321e38560d9871be170610d 100644 --- a/examples-not-running/simple_example_2.py +++ b/examples/simple_example_2.py @@ -34,5 +34,6 @@ env.reset() env_renderer = RenderTool(env, gl="PILSVG") env_renderer.renderEnv(show=True) +env_renderer.renderEnv(show=True) input("Press Enter to continue...") diff --git a/examples-not-running/simple_example_3.py b/examples/simple_example_3.py similarity index 93% rename from examples-not-running/simple_example_3.py rename to examples/simple_example_3.py index e20b81ea7d1f4b13275e2d22cc01b1485a2b0f56..1661ef65a9a33f3b44a098caaf83317919722398 100644 --- a/examples-not-running/simple_example_3.py +++ b/examples/simple_example_3.py @@ -7,8 +7,8 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -random.seed(10) -np.random.seed(10) +random.seed(1) +np.random.seed(1) env = RailEnv(width=7, height=7, @@ -19,10 +19,11 @@ env = RailEnv(width=7, # Print the observation vector for agent 0 obs, all_rewards, done, _ = env.step({0: 0}) for i in range(env.get_num_agents()): - env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5) + env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7) env_renderer = RenderTool(env, gl="PIL") env_renderer.renderEnv(show=True, frames=True) +env_renderer.renderEnv(show=True, frames=True) print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \ (turnleft+move, move to front, turnright+move)") diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 5cc3f26d0374cb07a673d05771cbbf64a66f0fc0..2214544adb9620ab1eefe6067f88c6bd3be8205d 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -31,19 +31,32 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent = {} self.location_has_agent_direction = {} + self.agents_previous_reset = None + def reset(self): agents = self.env.agents nAgents = len(agents) - self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, - self.env.height, - self.env.width, - 4)) - self.max_dist = np.zeros(nAgents) - - self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] - # Update local lookup table for all agents' target locations - self.location_has_target = {tuple(agent.target): 1 for agent in agents} + compute_distance_map = True + if self.agents_previous_reset is not None: + if nAgents == len(self.agents_previous_reset): + compute_distance_map = False + for i in range(nAgents): + if agents[i].target != self.agents_previous_reset[i].target: + compute_distance_map = True + self.agents_previous_reset = agents + + if compute_distance_map: + self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, + self.env.height, + self.env.width, + 4)) + self.max_dist = np.zeros(nAgents) + + self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] + + # Update local lookup table for all agents' target locations + self.location_has_target = {tuple(agent.target): 1 for agent in agents} def _distance_map_walker(self, position, target_nr): """ diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index c3ed6e02f80d9d7a6038e03d3984bda0e9afcbeb..6cd645147a1a870c760c3bc60dd364903a2065d2 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -173,7 +173,9 @@ class RailEnv(Environment): # Return the new observation vectors for each agent return self._get_observations() - def step(self, action_dict): + def step(self, action_dict_): + action_dict = action_dict_.copy() + alpha = 1.0 beta = 1.0 diff --git a/tox.ini b/tox.ini index c5a1ce1bff9934563a9ffec0573d8a66b2c91cc5..a347bce1168da56a03ce5f09a292df7983653bbe 100644 --- a/tox.ini +++ b/tox.ini @@ -79,7 +79,7 @@ commands = sh -c 'echo DISPLAY=$DISPLAY' sh -c 'echo XAUTHORITY=$XAUTHORITY' ; pipe echo into python since some examples expect input to close the window after the example is run - sh -c 'ls examples/*.py | xargs -I{} -n 1 sh -c "echo -e \"\n====== Running {} ========\n\"; echo | python {}"' + sh -c 'ls examples/*.py | xargs -I{} -n 1 sh -c "echo -e \"\n====== Running {} ========\n\"; echo "q" | python {}"' [testenv:notebooks] basepython = python