From fcc1ee6afcca9fbfc5c430fc40b2124621703222 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 5 Oct 2019 09:13:16 -0400 Subject: [PATCH] minor bugfixes --- torch_training/training_navigation.py | 19 ++++++------------- utils/observation_utils.py | 2 +- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index bd221ae..ad512c6 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -59,19 +59,12 @@ def main(argv): env = RailEnv(width=x_dim, height=y_dim, - rail_generator=sparse_rail_generator(num_cities=5, + rail_generator=sparse_rail_generator(max_num_cities=3, # Number of cities in map (where train stations are) - num_intersections=4, - # Number of intersections (no start / target) - num_trainstations=10, # Number of possible start/targets on map - min_node_dist=3, # Minimal distance of nodes - node_radius=2, # Proximity of stations to city center - num_neighb=3, - # Number of connections to other cities/intersections - seed=15, # Random seed - grid_mode=True, - enhance_intersection=False - ), + seed=1, # Random seed + grid_mode=False, + max_rails_between_cities=2, + max_rails_in_city=2), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=n_agents, stochastic_data=stochastic_data, # Malfunction data generator @@ -129,7 +122,7 @@ def main(argv): # Build agent specific observations for a in range(env.get_num_agents()): - agent_obs[a] = normalize_observation(obs[a], observation_radius=10) + agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10) agent_obs_buffer[a] = agent_obs[a].copy() # Reset score and done diff --git a/utils/observation_utils.py b/utils/observation_utils.py index f9661f5..e9eb3ed 100644 --- a/utils/observation_utils.py +++ b/utils/observation_utils.py @@ -104,7 +104,7 @@ def split_tree_into_feature_groups(tree: TreeObsForRailEnv.Node, max_tree_depth: """ data, distance, agent_data = _split_node_into_feature_groups(tree) - for direction in TreeObsForRailEnv.tree_explorted_actions_char: + for direction in TreeObsForRailEnv.tree_explored_actions_char: sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(tree.childs[direction], 1, max_tree_depth) data = np.concatenate((data, sub_data)) distance = np.concatenate((distance, sub_distance)) -- GitLab