From 8cec2335efd6f3f26434c05eda71a3d87ab2b13a Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Wed, 1 May 2019 08:34:49 +0200 Subject: [PATCH] updated action behavior at curve. deviation left and right not allowed at curve --- examples/training_navigation.py | 2 +- flatland/core/env_observation_builder.py | 2 +- flatland/envs/rail_env.py | 15 ++++++++++----- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/training_navigation.py b/examples/training_navigation.py index e21afd9c..9fc83242 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -112,7 +112,7 @@ for trials in range(1, n_trials + 1): action = agent.act(np.array(obs[a]), eps=eps) action_prob[action] += 1 action_dict.update({a: action}) - env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5) + #env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5) # Environment step next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.number_of_agents): diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 2d22b243..aaf5177f 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -340,7 +340,7 @@ class TreeObsForRailEnv(ObservationBuilder): elif num_transitions == 0: # Wrong cell type, but let's cover it and treat it as a dead-end, just in case - print("WRONG CELL TYPE detected in tree-search (0 transitions possible)") + print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell",position[0], position[1] ) last_isTerminal = True break diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 55e0e2ed..878892d7 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -998,7 +998,7 @@ class RailEnv(Environment): for i in range(len(self.agents_handles)): handle = self.agents_handles[i] - + transition_isValid = None if handle not in action_dict: continue @@ -1027,9 +1027,13 @@ class RailEnv(Environment): movement = direction if action == 1: movement = direction - 1 + if nbits <= 2: + transition_isValid == False + elif action == 3: movement = direction + 1 - + if nbits <= 2: + transition_isValid == False if movement < 0: movement += 4 if movement >= 4: @@ -1089,9 +1093,10 @@ class RailEnv(Environment): else: new_cell_isValid = False - transition_isValid = self.rail.get_transition( - (pos[0], pos[1], direction), - movement) or is_deadend + if transition_isValid == None: + transition_isValid = self.rail.get_transition( + (pos[0], pos[1], direction), + movement) or is_deadend cell_isFree = True for j in range(self.number_of_agents): -- GitLab