diff --git a/docs/specifications/core.md b/docs/specifications/core.md index 0c3a100e0db4db39312676d5e879227899adaca5..b80ceedfa25b30afbbeb07ab7bf859c162324ba2 100644 --- a/docs/specifications/core.md +++ b/docs/specifications/core.md @@ -13,7 +13,7 @@ class Environment: Agents are identified by agent ids (handles). Examples: - >>> obs = env.reset() + >>> obs, info = env.reset() >>> print(obs) { "train_0": [2.4, 1.6], diff --git a/docs/tutorials/01_gettingstarted.rst b/docs/tutorials/01_gettingstarted.rst index a7a8e5514e03d2f9feb31408d433badc7f3767c2..2be0fee229c93f2aee215cf1e459dd8fc8f92ebb 100644 --- a/docs/tutorials/01_gettingstarted.rst +++ b/docs/tutorials/01_gettingstarted.rst @@ -166,7 +166,7 @@ We start every trial by resetting the environment .. code-block:: python - obs = env.reset() + obs, info = env.reset() Which provides the initial observation for all agents (obs = array of all observations). In order for the environment to step forward in time we need a dictionar of actions for all active agents. diff --git a/docs/tutorials/02_observationbuilder.rst b/docs/tutorials/02_observationbuilder.rst index d1c287fedf880fd091a1b24921292010eb01359e..f6e718ab156a3972a57f4693367d5f191c8fcdd0 100644 --- a/docs/tutorials/02_observationbuilder.rst +++ b/docs/tutorials/02_observationbuilder.rst @@ -271,7 +271,7 @@ We can then use this new observation builder and the renderer to visualize the o number_of_agents=3, obs_builder_object=CustomObsBuilder) - obs = env.reset() + obs, info = env.reset() env_renderer = RenderTool(env, gl="PILSVG") # We render the initial step and show the obsered cells as colored boxes diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py index 49e550b195ffb15e6554413069369378e80e5f82..aaa5d286322107baf600a5e5f22fdc1aafdec7e4 100644 --- a/examples/complex_rail_benchmark.py +++ b/examples/complex_rail_benchmark.py @@ -39,7 +39,7 @@ def run_benchmark(): for trials in range(1, n_trials + 1): # Reset environment - obs = env.reset() + obs, info = env.reset() for a in range(env.get_num_agents()): norm = max(1, max_lt(obs[a], np.inf)) diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py index 7ddfcd899f747f038471cbe3921e6df76fff37ee..3c10e415155429aa1464c554cdeb9ee1309780f2 100644 --- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py +++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py @@ -80,7 +80,7 @@ def main(args): number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) - obs = env.reset() + obs, info = env.reset() env_renderer = RenderTool(env, gl="PILSVG") env_renderer.render_env(show=True, frames=True, show_observations=True) for step in range(100): diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index f75cb74537f03dc6fb1aecbadff37a183432e55a..97fac09c8d222b11cc8e8cc3ab4f76c75a37a95b 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -130,7 +130,7 @@ def main(args): number_of_agents=3, obs_builder_object=custom_obs_builder) - obs = env.reset() + obs, info = env.reset() env_renderer = RenderTool(env, gl="PILSVG") # We render the initial step and show the obsered cells as colored boxes diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index a52eeed47c5cb1fe75c87d430d93f30f50336fbf..1f094c09b6a929b7e862bc42afffd1da38e52e73 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -62,7 +62,7 @@ env = RailEnv(width=14, number_of_agents=2, obs_builder_object=SingleAgentNavigationObs()) -obs = env.reset() +obs, info = env.reset() env_renderer = RenderTool(env, gl="PILSVG") env_renderer.render_env(show=True, frames=True, show_observations=False) for step in range(100): diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index a9c08ff24682854a0d9d6eeed1c89f367bdbbc93..f6a4c704d70005d0984a3a9f1903f8a4da6302d8 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -98,7 +98,7 @@ action_dict = dict() print("Start episode...") # Reset environment and get initial observations for all agents start_reset = time.time() -obs = env.reset() +obs, info = env.reset() end_reset = time.time() print(end_reset - start_reset) print(env.get_num_agents(), ) diff --git a/examples/training_example.py b/examples/training_example.py index 78c0299d4cee8ae588bcf8e9e7559ff1c8364c26..8b42586b06d90bef44fe94ca6817fb0a84824aac 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -70,7 +70,7 @@ print("Starting Training...") for trials in range(1, n_trials + 1): # Reset environment and get initial observations for all agents - obs = env.reset() + obs, info = env.reset() for idx in range(env.get_num_agents()): tmp_agent = env.agents[idx] tmp_agent.speed_data["speed"] = 1 / (idx + 1) diff --git a/flatland/cli.py b/flatland/cli.py index 47c450dba803fac17bc13979663ef04e4c0db899..f544aabcb6d9e81ddf8703c59b8bf07324b3ce2c 100644 --- a/flatland/cli.py +++ b/flatland/cli.py @@ -33,7 +33,7 @@ def demo(args=None): env_renderer = RenderTool(env) while True: - obs = env.reset() + obs, info = env.reset() _done = False # Run a single episode here step = 0 diff --git a/flatland/core/env.py b/flatland/core/env.py index 32b688ca78e35b1e36aac85c0da4a4ee22246d1b..a49cfb40518f3474db10a7765853f16a23c7ba50 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -15,7 +15,7 @@ class Environment: Agents are identified by agent ids (handles). Examples: - >>> obs = env.reset() + >>> obs, info = env.reset() >>> print(obs) { "train_0": [2.4, 1.6], diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index 52b047244850a5b1c6b39dadd74568fc5f92deec..8f91088a205c92ec1d79fc9b3909a4c80ca72db5 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -27,7 +27,7 @@ def test_global_obs(): number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) - global_obs = env.reset() + global_obs, info = env.reset() # we have to take step for the agent to enter the grid. global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})