From d017b8314c32f4865cc6c111e6fb2303fd2bc22b Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 5 Oct 2019 11:23:52 -0400 Subject: [PATCH] updated code to new info on reset functionality --- docs/specifications/core.md | 2 +- docs/tutorials/01_gettingstarted.rst | 2 +- docs/tutorials/02_observationbuilder.rst | 2 +- examples/complex_rail_benchmark.py | 2 +- .../custom_observation_example_02_SingleAgentNavigationObs.py | 2 +- examples/custom_observation_example_03_ObservePredictions.py | 2 +- examples/debugging_example_DELETE.py | 2 +- examples/flatland_2_0_example.py | 2 +- examples/training_example.py | 2 +- flatland/cli.py | 2 +- flatland/core/env.py | 2 +- tests/test_flatland_envs_observations.py | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/specifications/core.md b/docs/specifications/core.md index 0c3a100e..b80ceedf 100644 --- a/docs/specifications/core.md +++ b/docs/specifications/core.md @@ -13,7 +13,7 @@ class Environment: Agents are identified by agent ids (handles). Examples: - >>> obs = env.reset() + >>> obs, info = env.reset() >>> print(obs) { "train_0": [2.4, 1.6], diff --git a/docs/tutorials/01_gettingstarted.rst b/docs/tutorials/01_gettingstarted.rst index a7a8e551..2be0fee2 100644 --- a/docs/tutorials/01_gettingstarted.rst +++ b/docs/tutorials/01_gettingstarted.rst @@ -166,7 +166,7 @@ We start every trial by resetting the environment .. code-block:: python - obs = env.reset() + obs, info = env.reset() Which provides the initial observation for all agents (obs = array of all observations). In order for the environment to step forward in time we need a dictionar of actions for all active agents. diff --git a/docs/tutorials/02_observationbuilder.rst b/docs/tutorials/02_observationbuilder.rst index d1c287fe..f6e718ab 100644 --- a/docs/tutorials/02_observationbuilder.rst +++ b/docs/tutorials/02_observationbuilder.rst @@ -271,7 +271,7 @@ We can then use this new observation builder and the renderer to visualize the o number_of_agents=3, obs_builder_object=CustomObsBuilder) - obs = env.reset() + obs, info = env.reset() env_renderer = RenderTool(env, gl="PILSVG") # We render the initial step and show the obsered cells as colored boxes diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index 49e550b1..aaa5d286 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -39,7 +39,7 @@ def run_benchmark(): for trials in range(1, n_trials + 1): # Reset environment - obs = env.reset() + obs, info = env.reset() for a in range(env.get_num_agents()): norm = max(1, max_lt(obs[a], np.inf)) diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index 7ddfcd89..3c10e415 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -80,7 +80,7 @@ def main(args): number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) - obs = env.reset() + obs, info = env.reset() env_renderer = RenderTool(env, gl="PILSVG") env_renderer.render_env(show=True, frames=True, show_observations=True) for step in range(100): diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index f75cb745..97fac09c 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -130,7 +130,7 @@ def main(args): number_of_agents=3, obs_builder_object=custom_obs_builder) - obs = env.reset() + obs, info = env.reset() env_renderer = RenderTool(env, gl="PILSVG") # We render the initial step and show the obsered cells as colored boxes diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index a52eeed4..1f094c09 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -62,7 +62,7 @@ env = RailEnv(width=14, number_of_agents=2, obs_builder_object=SingleAgentNavigationObs()) -obs = env.reset() +obs, info = env.reset() env_renderer = RenderTool(env, gl="PILSVG") env_renderer.render_env(show=True, frames=True, show_observations=False) for step in range(100): diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index a9c08ff2..f6a4c704 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -98,7 +98,7 @@ action_dict = dict() print("Start episode...") # Reset environment and get initial observations for all agents start_reset = time.time() -obs = env.reset() +obs, info = env.reset() end_reset = time.time() print(end_reset - start_reset) print(env.get_num_agents(), ) diff --git a/examples/training_example.py b/examples/training_example.py index 78c0299d..8b42586b 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -70,7 +70,7 @@ print("Starting Training...") for trials in range(1, n_trials + 1): # Reset environment and get initial observations for all agents - obs = env.reset() + obs, info = env.reset() for idx in range(env.get_num_agents()): tmp_agent = env.agents[idx] tmp_agent.speed_data["speed"] = 1 / (idx + 1) diff --git a/flatland/cli.py b/flatland/cli.py index 47c450db..f544aabc 100644 --- a/flatland/cli.py +++ b/flatland/cli.py @@ -33,7 +33,7 @@ def demo(args=None): env_renderer = RenderTool(env) while True: - obs = env.reset() + obs, info = env.reset() _done = False # Run a single episode here step = 0 diff --git a/flatland/core/env.py b/flatland/core/env.py index 32b688ca..a49cfb40 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -15,7 +15,7 @@ class Environment: Agents are identified by agent ids (handles). Examples: - >>> obs = env.reset() + >>> obs, info = env.reset() >>> print(obs) { "train_0": [2.4, 1.6], diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 52b04724..8f91088a 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -27,7 +27,7 @@ def test_global_obs(): number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) - global_obs = env.reset() + global_obs, info = env.reset() # we have to take step for the agent to enter the grid. global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD}) -- GitLab