diff --git a/examples/play_model.py b/examples/play_model.py
index 1cd125b12192236cb9b57ebbdef75e645068357c..6997b2c002582d35520362ac50555aae0641c9ec 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -28,24 +28,16 @@ class Player(object):
         self.action_prob = [0] * 4
 
         # Removing refs to a real agent for now.
-        # self.agent = Agent(self.state_size, self.action_size, "FC", 0)
-        # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
-        # self.agent.qnetwork_local.load_state_dict(torch.load(
-        #    '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
-
         self.iFrame = 0
         self.tStart = time.time()
 
         # Reset environment
-        # self.obs = self.env.reset()
         self.env.obs_builder.reset()
         self.obs = self.env._get_observations()
         for envAgent in range(self.env.get_num_agents()):
             norm = max(1, max_lt(self.obs[envAgent], np.inf))
             self.obs[envAgent] = np.clip(np.array(self.obs[envAgent]) / norm, -1, 1)
 
-        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
-
         self.score = 0
         self.env_done = 0
 
@@ -58,13 +50,9 @@ class Player(object):
 
         # Pass the (stored) observation to the agent network and retrieve the action
         for handle in env.get_agent_handles():
-            # Real Agent
-            # action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
             # Random actions
-            # action = random.randint(0, 3)
             action = np.random.choice([0, 1, 2, 3], 1, p=[0.2, 0.1, 0.6, 0.1])[0]
             # Numpy version uses single random sequence
-            # action = np.random.randint(0, 4, size=1)
             self.action_prob[action] += 1
             self.action_dict.update({handle: action})
 
@@ -125,17 +113,11 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PILSVG"):
         oPlayer.reset()
         env_renderer.set_new_rail()
 
-        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
-
-        # score = 0
-        # env_done = 0
-
         # Run episode
         for step in range(n_steps):
             oPlayer.step()
             if render:
                 env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
-                # time.sleep(10)
                 if delay > 0:
                     time.sleep(delay)
 
diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py
index ca442873be6a6254cd169b354e088d9a9e94d3e7..3055e3195247c987e1a7c9f9dbc0c24aee6ca2b8 100644
--- a/examples/simple_example_1.py
+++ b/examples/simple_example_1.py
@@ -10,12 +10,6 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
          [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)],
          [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
 
-# CURVED RAIL + DEAD-ENDS TEST
-# specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
-#          [(7, 270), (1, 90), (1, 90), (8, 90), (0, 0), (0, 0)],
-#          [(0, 0), (7, 270),(1, 90), (8, 180), (0, 00), (0, 0)],
-#          [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
-
 env = RailEnv(width=6,
               height=4,
               rail_generator=rail_from_manual_specifications_generator(specs),
diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py
index f1a6a7c75a500e8b8da76ab95f01cf55111826ab..8c40a34978764830accafcec12c1be750dd1457d 100644
--- a/examples/simple_example_2.py
+++ b/examples/simple_example_2.py
@@ -30,12 +30,6 @@ env = RailEnv(width=10,
               number_of_agents=3,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
-# env = RailEnv(width=10,
-#               height=10,
-#               rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']),
-#               number_of_agents=3,
-#               obs_builder_object=TreeObsForRailEnv(max_depth=2))
-
 env.reset()
 
 env_renderer = RenderTool(env, gl="PILSVG")
diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index 089e72b99e1cb73deef37f15038716089c13e85a..e20b81ea7d1f4b13275e2d22cc01b1485a2b0f56 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -16,10 +16,6 @@ env = RailEnv(width=7,
               number_of_agents=2,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
-# Print the distance map of each cell to the target of the first agent
-# for i in range(4):
-#     print(env.obs_builder.distance_map[0, :, :, i])
-
 # Print the observation vector for agent 0
 obs, all_rewards, done, _ = env.step({0: 0})
 for i in range(env.get_num_agents()):
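
Note on the action selection kept in Player.step() above: with the real agent commented out, each agent draws a weighted random action. Below is a minimal standalone sketch of that pattern; the handle range and the final print are illustrative stand-ins, not code from the repository. As the retained comment says, np.random.choice draws from NumPy's single global random sequence.

import numpy as np

action_prob = [0] * 4  # running count of how often each action was chosen
action_dict = {}       # maps agent handle -> chosen action, as in Player.step()

for handle in range(3):  # stand-in for env.get_agent_handles()
    # Weighted draw over the 4 actions, favouring action 2,
    # with the same weights as the diff: p=[0.2, 0.1, 0.6, 0.1]
    action = np.random.choice([0, 1, 2, 3], 1, p=[0.2, 0.1, 0.6, 0.1])[0]
    action_prob[action] += 1
    action_dict[handle] = action

print(action_dict, action_prob)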