Commit b639031b authored by u214892

removed stale code in examples

parent 1c1ab6b9
@@ -28,24 +28,16 @@ class Player(object):
self.action_prob = [0] * 4
# Removing refs to a real agent for now.
# self.agent = Agent(self.state_size, self.action_size, "FC", 0)
# self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
# self.agent.qnetwork_local.load_state_dict(torch.load(
# '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
self.iFrame = 0
self.tStart = time.time()
# Reset environment
# self.obs = self.env.reset()
self.env.obs_builder.reset()
self.obs = self.env._get_observations()
for envAgent in range(self.env.get_num_agents()):
norm = max(1, max_lt(self.obs[envAgent], np.inf))
self.obs[envAgent] = np.clip(np.array(self.obs[envAgent]) / norm, -1, 1)
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
self.score = 0
self.env_done = 0
@@ -58,13 +50,9 @@ class Player(object):
# Pass the (stored) observation to the agent network and retrieve the action
for handle in env.get_agent_handles():
# Real Agent
# action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
# Random actions
# action = random.randint(0, 3)
action = np.random.choice([0, 1, 2, 3], 1, p=[0.2, 0.1, 0.6, 0.1])[0]
# Numpy version uses single random sequence
# action = np.random.randint(0, 4, size=1)
self.action_prob[action] += 1
self.action_dict.update({handle: action})
@@ -125,17 +113,11 @@ def main(render=True, delay=0.0, n_trials=3, n_steps=50, sGL="PILSVG"):
oPlayer.reset()
env_renderer.set_new_rail()
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
# score = 0
# env_done = 0
# Run episode
for step in range(n_steps):
oPlayer.step()
if render:
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step)
# time.sleep(10)
if delay > 0:
time.sleep(delay)
......
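The Player.step() hunk above keeps only the weighted random-action fallback. The standalone sketch below is not part of the commit (the names ACTIONS and WEIGHTS are illustrative); it shows how the np.random.choice sampling and the action_prob tally from the hunk interact, and how the tally can be normalised into empirical action frequencies:

    import numpy as np

    ACTIONS = [0, 1, 2, 3]             # the four discrete actions sampled in Player.step()
    WEIGHTS = [0.2, 0.1, 0.6, 0.1]     # same bias as the p=[...] argument above

    action_prob = [0] * len(ACTIONS)   # running tally, as in Player.__init__
    for _ in range(1000):
        action = np.random.choice(ACTIONS, p=WEIGHTS)
        action_prob[action] += 1

    # Normalise the tally to report how often each action was actually taken
    frequencies = np.array(action_prob) / np.sum(action_prob)
    print(dict(zip(ACTIONS, frequencies.round(3))))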
@@ -10,12 +10,6 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)],
[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
# CURVED RAIL + DEAD-ENDS TEST
# specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
# [(7, 270), (1, 90), (1, 90), (8, 90), (0, 0), (0, 0)],
# [(0, 0), (7, 270),(1, 90), (8, 180), (0, 00), (0, 0)],
# [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
env = RailEnv(width=6,
height=4,
rail_generator=rail_from_manual_specifications_generator(specs),
......
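A note on the specs grid passed to rail_from_manual_specifications_generator above: each entry is read here as a (cell_type, rotation) pair with the rotation in degrees. This reading is an assumption inferred from the example (including the removed "CURVED RAIL + DEAD-ENDS TEST" grid) and should be checked against the Flatland documentation for your version; the annotated row below is illustrative only:

    # Assumed meaning of the spec tuples: (cell_type, rotation in degrees).
    # Cell-type indices inferred from the example; verify for your Flatland version.
    row = [(7, 270),   # dead end, rotated 270 degrees
           (1, 90),    # straight rail, rotated 90 degrees
           (1, 90),
           (1, 90),
           (2, 90),    # simple switch, rotated 90 degrees
           (7, 90)]    # dead end, rotated 90 degrees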
@@ -30,12 +30,6 @@ env = RailEnv(width=10,
number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
# env = RailEnv(width=10,
# height=10,
# rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']),
# number_of_agents=3,
# obs_builder_object=TreeObsForRailEnv(max_depth=2))
env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
......
@@ -16,10 +16,6 @@ env = RailEnv(width=7,
number_of_agents=2,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
# Print the distance map of each cell to the target of the first agent
# for i in range(4):
# print(env.obs_builder.distance_map[0, :, :, i])
# Print the observation vector for agent 0
obs, all_rewards, done, _ = env.step({0: 0})
for i in range(env.get_num_agents()):
......
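The last hunk above is truncated after the per-agent loop. A minimal sketch of the step-and-inspect pattern it illustrates follows; env is assumed to be the RailEnv built in that hunk (with obs_builder_object=TreeObsForRailEnv(max_depth=2)), and the printed fields are illustrative rather than taken from the example file:

    # Sketch only: 'env' is assumed to exist, built as in the hunk above.
    obs, all_rewards, done, _ = env.step({0: 0})   # agent 0 takes action 0
    for i in range(env.get_num_agents()):
        print("Agent", i, "reward:", all_rewards[i])
        print("Agent", i, "observation:", obs[i])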