diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 47487086a3858dc1fd71006b934dc353300498bc..efc89c566f7779e011e5426b83a761626286222a 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -35,13 +35,13 @@ env = RailEnv(width=20, rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=10, max_dist=99999, seed=0), number_of_agents=5) - +""" env = RailEnv(width=20, height=20, rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( - ['../env-data/tests/circle.npy']), + ['../notebooks/testing_11.npy']), number_of_agents=1) -""" + env_renderer = RenderTool(env, gl="QT") handle = env.get_agent_handles() @@ -61,7 +61,7 @@ action_prob = [0] * 4 agent = Agent(state_size, action_size, "FC", 0) agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth')) -demo = False +demo = True def max_lt(seq, val): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index d8f1af4724c33772dd2b12b28f4d9e392a87b728..9a48915bce8386f0ffe0237c6bf9512b0609d82a 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -366,13 +366,9 @@ class RailEnv(Environment): is_deadend = True if np.sum(possible_transitions) == 1: - # Checking for curves - curv_dir = np.argmax(possible_transitions) - # valid_transition = self.rail.get_transition( - # (pos[0], pos[1], direction), - # movement) + # Take only available transition + movement = np.argmax(possible_transitions) - movement = curv_dir new_position = self._new_position(pos, movement) # Is it a legal move? 1) transition allows the movement in the # cell, 2) the new cell is not empty (case 0), 3) the cell is diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index e6cc4cccc5d0cb219709cb7b805e440dfd954b52..09ac48997bcb1c420be146ee011042c45b6b597f 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -208,7 +208,7 @@ class RenderTool(object): xyDir = np.matmul(rcDir, rt.grc2xy) # agent direction in xy xyPos = np.matmul(rcPos - rcDir / 2, rt.grc2xy) + rt.xyHalf - print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos) + #print("Agent:", rcPos, iDir, rcDir, xyDir, xyPos) self.gl.scatter(*xyPos, color=color, marker="o", s=100) # agent location xyDirLine = array([xyPos, xyPos + xyDir/2]).T # line for agent orient.