diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index e21afd9c7c9d51606c85db03e0f052a03d37ef1b..9fc83242fb4880ea6dff712100f73edf2e1ec109 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -112,7 +112,7 @@ for trials in range(1, n_trials + 1):
             action = agent.act(np.array(obs[a]), eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
-            env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5)
+            #env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5)
         # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.number_of_agents):
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 2d22b2439b153144f4c0223dc00f3ddf79be3899..aaf5177fc7fcc98adc071c3f6abb210bcd5b945d 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -340,7 +340,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
             elif num_transitions == 0:
                 # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
-                print("WRONG CELL TYPE detected in tree-search (0 transitions possible)")
+                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0], position[1])
                 last_isTerminal = True
                 break
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 55e0e2ed83c40c2fae651b6fbd9791034360078b..878892d799defcf0d48e07b2695889e36a08d667 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -998,7 +998,7 @@ class RailEnv(Environment):
 
         for i in range(len(self.agents_handles)):
             handle = self.agents_handles[i]
-
+            transition_isValid = None  # undecided until rejected or looked up below
             if handle not in action_dict:
                 continue
 
@@ -1027,9 +1027,13 @@ class RailEnv(Environment):
                 movement = direction
                 if action == 1:
                     movement = direction - 1
+                    if nbits <= 2:
+                        transition_isValid = False  # turn rejected here; lookup below is skipped
+
                 elif action == 3:
                     movement = direction + 1
-
+                    if nbits <= 2:
+                        transition_isValid = False  # turn rejected here; lookup below is skipped
                 if movement < 0:
                     movement += 4
                 if movement >= 4:
@@ -1089,9 +1093,10 @@ class RailEnv(Environment):
                 else:
                     new_cell_isValid = False
 
-                transition_isValid = self.rail.get_transition(
-                    (pos[0], pos[1], direction),
-                    movement) or is_deadend
+                if transition_isValid is None:  # only look up when not already rejected
+                    transition_isValid = self.rail.get_transition(
+                        (pos[0], pos[1], direction),
+                        movement) or is_deadend
 
                 cell_isFree = True
                 for j in range(self.number_of_agents):