diff --git a/examples/play_model.py b/examples/play_model.py
index 62726c24c96be0e5dae2f4840e18da452163b7ac..e502bd28c419cbd8ac473b8c736e7a424000c229 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -28,8 +28,8 @@ class Player(object):
         self.action_prob = [0]*4
         self.agent = Agent(self.state_size, self.action_size, "FC", 0)
         # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
-        self.agent.qnetwork_local.load_state_dict(torch.load(
-            '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+        #self.agent.qnetwork_local.load_state_dict(torch.load(
+        #    '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
         self.iFrame = 0
         self.tStart = time.time()
@@ -47,20 +47,26 @@ class Player(object):
         self.score = 0
         self.env_done = 0
 
+    def reset(self):
+        self.obs = self.env.reset()
+        return self.obs
+
     def step(self):
         env = self.env
 
         # Pass the (stored) observation to the agent network and retrieve the action
-        #for handle in env.get_agent_handles():
         for handle in env.get_agent_handles():
-            action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
+            # action = self.agent.act(np.array(self.obs[handle]), eps=self.eps)
+            action = random.randint(0, 3)
+            # action = np.random.randint(0, 4, size=1)
             self.action_prob[action] += 1
             self.action_dict.update({handle: action})
 
         # Environment step - pass the agent actions to the environment,
         # retrieve the response - observations, rewards, dones
         next_obs, all_rewards, done, _ = self.env.step(self.action_dict)
-
+        next_obs = next_obs
+
         for handle in env.get_agent_handles():
             norm = max(1, max_lt(next_obs[handle], np.inf))
             next_obs[handle] = np.clip(np.array(next_obs[handle]) / norm, -1, 1)
@@ -93,7 +99,49 @@ def max_lt(seq, val):
     return None
 
 
-def main(render=True, delay=0.0):
+def main(render=True, delay=0.0, n_trials=3, n_steps=50):
+    random.seed(1)
+    np.random.seed(1)
+
+    # Example generate a random rail
+    env = RailEnv(width=15, height=15,
+                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
+                  number_of_agents=5)
+
+    if render:
+        # env_renderer = RenderTool(env, gl="QTSVG")
+        env_renderer = RenderTool(env, gl="QT")
+
+    oPlayer = Player(env)
+
+    for trials in range(1, n_trials + 1):
+
+        # Reset environment8
+        oPlayer.reset()
+        env_renderer.set_new_rail()
+
+        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+
+        # score = 0
+        # env_done = 0
+
+        # Run episode
+        for step in range(n_steps):
+            oPlayer.step()
+            if render:
+                env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
+                                       action_dict=oPlayer.action_dict)
+                # time.sleep(10)
+                if delay > 0:
+                    time.sleep(delay)
+
+
+def main_old(render=True, delay=0.0):
+    ''' DEPRECATED main which drives agent directly
+        Please use the new main() which creates a Player object which is also used by the Editor.
+        Please fix any bugs in main() and Player rather than here.
+        Will delete this one shortly.
+    '''
     random.seed(1)
     np.random.seed(1)
 
@@ -139,10 +187,12 @@ def main(render=True, delay=0.0):
     tStart = time.time()
     for trials in range(1, n_trials + 1):
 
-        # Reset environment
+        # Reset environment8
         obs = env.reset()
         env_renderer.set_new_rail()
 
+        #obs = obs[0]
+
        for a in range(env.get_num_agents()):
             norm = max(1, max_lt(obs[a], np.inf))
             obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
@@ -159,7 +209,7 @@ def main(render=True, delay=0.0):
             # print(step)
             # Action
             for a in range(env.get_num_agents()):
-                action = agent.act(np.array(obs[a]), eps=eps)
+                action = random.randint(0,3)  # agent.act(np.array(obs[a]), eps=eps)
                 action_prob[action] += 1
                 action_dict.update({a: action})
 
@@ -173,6 +223,8 @@ def main(render=True, delay=0.0):
 
             # Environment step
             next_obs, all_rewards, done, _ = env.step(action_dict)
+            #next_obs = next_obs[0]
+
             for a in range(env.get_num_agents()):
                 norm = max(1, max_lt(next_obs[a], np.inf))
                 next_obs[a] = np.clip(np.array(next_obs[a]) / norm, -1, 1)
@@ -181,7 +233,6 @@ def main(render=True, delay=0.0):
                 agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
                 score += all_rewards[a]
 
-            obs = next_obs.copy()
             if done['__all__']:
                 env_done = 1
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 0fdec364ecc3d8415b27e2101c339d6f49e89022..651d83520ee1f492ea71ba6ddf82cfa5f9093964 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -219,8 +219,8 @@ class TreeObsForRailEnv(ObservationBuilder):
 
             if possible_transitions[branch_direction]:
                 new_cell = self._new_position(agent.position, branch_direction)
-                branch_observation, branch_visited = self._explore_branch(handle, new_cell, branch_direction, root_observation,
-                                                                          1)
+                branch_observation, branch_visited = \
+                    self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                 observation = observation + branch_observation
                 visited = visited.union(branch_visited)
             else:
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 34f3e9fa6857e86f4d99d211784d983a2e2a1e75..e9a94ec1fd18635b69d5f4bc47ce1d4f43daffcd 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -474,8 +474,8 @@ class RenderTool(object):
 
     def renderObs(self, agent_handles, observation_dict):
         """
-        Render the extent of the observation of each agent. All cells that appear in the agent obsrevation will be
-        highlighted.
+        Render the extent of the observation of each agent. All cells that appear in the agent
+        observation will be highlighted.
 
         :param agent_handles: List of agent indices to adapt color and get correct observation
         :param observation_dict: dictionary containing sets of cells of the agent observation
diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py
index 245f2f327524653b3cf03bf921f6db6b0d4b51fb..3259ed387f1a09b7a6a5e73fe7976cc78f02f32f 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_rendertools.py
@@ -6,10 +6,6 @@ Tests for `flatland` package.
 
 from flatland.envs.rail_env import RailEnv, random_rail_generator
 import numpy as np
-#<<<<<<< HEAD
-#=======
-# import os
-#>>>>>>> dc2fa1ee0244b15c76d89ab768c5e1bbd2716147
 import sys
 import matplotlib.pyplot as plt
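
Reviewer note, not part of the patch: a minimal usage sketch of the refactored Player loop introduced above, mirroring the new main(). The import locations of RailEnv, complex_rail_generator and RenderTool, and the examples.play_model module path, are assumptions inferred from the rest of this diff rather than confirmed APIs; the example's Agent dependency must also be installed, since Player.__init__ still constructs it.

    # Usage sketch (assumed import paths; not part of the patch).
    import random

    import numpy as np

    from flatland.envs.rail_env import RailEnv, complex_rail_generator
    from flatland.utils.rendertools import RenderTool
    from examples.play_model import Player  # assumed module path for the patched example

    random.seed(1)
    np.random.seed(1)

    # Same kind of random environment as the new main() builds.
    env = RailEnv(width=15, height=15,
                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
                  number_of_agents=5)
    renderer = RenderTool(env, gl="QT")
    player = Player(env)

    for episode in range(1, 4):
        player.reset()  # Player.reset() is the method added by this patch
        renderer.set_new_rail()
        for step in range(50):
            player.step()  # random actions, env.step, then observation clipping
            renderer.renderEnv(show=True, frames=True, iEpisode=episode, iStep=step,
                               action_dict=player.action_dict)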