diff --git a/examples/training_navigation.py b/examples/training_navigation.py index c1182074731725091630df0f3753b2c55d0280b2..9fc83242fb4880ea6dff712100f73edf2e1ec109 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -30,13 +30,13 @@ env = RailEnv(width=20, rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=10, max_dist=99999, seed=0), number_of_agents=5) -""" + env = RailEnv(width=20, height=20, rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( ['../env-data/tests/circle.npy']), number_of_agents=1) - +""" env_renderer = RenderTool(env, gl="QT") handle = env.get_agent_handles() @@ -109,7 +109,7 @@ for trials in range(1, n_trials + 1): for a in range(env.number_of_agents): if demo: eps = 0 - action = 2 #agent.act(np.array(obs[a]), eps=eps) + action = agent.act(np.array(obs[a]), eps=eps) action_prob[action] += 1 action_dict.update({a: action}) #env.obs_builder.util_print_obs_subtree(tree=obs[a], num_features_per_node=5) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index de3ee93230c921a281f007a1d5af3aab7f80d76d..6c27afdc2282f084ae6e4ffa62bc48b1d28aeaa0 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -243,6 +243,7 @@ class TreeObsForRailEnv(ObservationBuilder): # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation + # TODO: Adjust this to the novel movement dynamics --> Only Forward present when one transition is possible. for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]: if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction): new_cell = self._new_position(position, branch_direction) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index eff1c1085f6e8a14bf4563b40d66eceba2e50171..a236362d5a99c906828890e8252d14d96ea8c8f1 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -1040,18 +1040,21 @@ class RailEnv(Environment): nbits = 0 tmp = self.rail.get_transitions((pos[0], pos[1])) + possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction)) + # print(np.sum(self.rail.get_transitions((pos[0], pos[1],direction))),self.rail.get_transitions((pos[0], pos[1],direction)),self.rail.get_transitions((pos[0], pos[1])),(pos[0], pos[1],direction)) + while tmp > 0: nbits += (tmp & 1) tmp = tmp >> 1 movement = direction if action == 1: movement = direction - 1 - if nbits <= 2: + if nbits <= 2 or np.sum(possible_transitions) <= 1: transition_isValid = False elif action == 3: movement = direction + 1 - if nbits <= 2: + if nbits <= 2 or np.sum(possible_transitions) <= 1: transition_isValid = False if movement < 0: movement += 4 @@ -1081,12 +1084,14 @@ class RailEnv(Environment): direction = reverse_direction movement = reverse_direction is_deadend = True - if nbits == 2: + if np.sum(possible_transitions) == 1: # Checking for curves - - valid_transition = self.rail.get_transition( - (pos[0], pos[1], direction), - movement) + curv_dir = np.argmax(possible_transitions) + #valid_transition = self.rail.get_transition( + # (pos[0], pos[1], direction), + # movement) + movement = curv_dir + """ reverse_direction = (direction + 2) % 4 curv_dir = (movement + 1) % 4 while not valid_transition: @@ -1097,7 +1102,7 @@ class RailEnv(Environment): if valid_transition: movement = curv_dir curv_dir = (curv_dir + 1) % 4 - + """ new_position = self._new_position(pos, movement) # Is it a legal move? 1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index 1f5c317965a3101c6232709ebb311959d5a566ed..73b42c385867ba6bc93ff794ec3ecda0bf82125d 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -30,8 +30,8 @@ def checkFrozenImage(sFileImage): if bytesFrozenImage is None: bytesFrozenImage = bytesImage else: - assert(bytesFrozenImage.shape == bytesImage.shape) - assert((np.sum(np.square(bytesFrozenImage - bytesImage)) / bytesFrozenImage.size) < 1e-3) + assert (bytesFrozenImage.shape == bytesImage.shape) + assert ((np.sum(np.square(bytesFrozenImage - bytesImage)) / bytesFrozenImage.size) < 1e-3) def test_render_env():