From c528e652f2fa570b12d4ed9c3fed2f72f4f803bb Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Tue, 23 Apr 2019 08:54:19 +0200 Subject: [PATCH] minor bugfixes in training script --- examples/training_navigation.py | 40 +++++++++------------------------ 1 file changed, 11 insertions(+), 29 deletions(-) diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 3797d1cf..7b35cb21 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -3,7 +3,7 @@ from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import * from flatland.baselines.dueling_double_dqn import Agent from collections import deque -import torch +import torch,random random.seed(1) np.random.seed(1) @@ -18,22 +18,22 @@ transition_probability = [1.0, # empty cell - Case 0 1.0, # Case 6 - symmetrical 1.0] # Case 7 - dead end """ +# Example generate a rail given a manual specification, +# a map of tuples (cell_type, rotation) transition_probability = [1.0, # empty cell - Case 0 1.0, # Case 1 - straight - 0.5, # Case 2 - simple switch - 0.2, # Case 3 - diamond drossing + 1.0, # Case 2 - simple switch + 0.3, # Case 3 - diamond drossing 0.5, # Case 4 - single slip - 0.1, # Case 5 - double slip + 0.5, # Case 5 - double slip 0.2, # Case 6 - symmetrical - 0.01] # Case 7 - dead end + 0.0] # Case 7 - dead end # Example generate a random rail -env = RailEnv(width=20, - height=20, +env = RailEnv(width=7, + height=7, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=10) -env.reset() - + number_of_agents=1) env_renderer = RenderTool(env) handle = env.get_agent_handles() @@ -51,28 +51,10 @@ dones_list = [] agent = Agent(state_size, action_size, "FC", 0) -# Example generate a rail given a manual specification, -# a map of tuples (cell_type, rotation) -specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], - [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]] - -env = RailEnv(width=6, - height=2, - rail_generator=rail_from_manual_specifications_generator(specs), - number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2)) - -env.agents_position[0] = [1, 4] -env.agents_target[0] = [1, 1] -env.agents_direction[0] = 1 -# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! -env.obs_builder.reset() - - for trials in range(1, n_trials + 1): # Reset environment - obs, all_rewards, done, _ = env.step({0: 0}) + obs = env.reset() # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) -- GitLab