From 1169fcb628da3875ac432efe7495393579f583f4 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Mon, 13 May 2019 17:26:08 +0200
Subject: [PATCH] fixed formatting

---
 examples/training_navigation.py          | 6 ++++--
 flatland/core/env_observation_builder.py | 3 +--
 flatland/envs/generators.py              | 2 +-
 flatland/envs/rail_env.py                | 2 +-
 4 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index feb4796..ee360a1 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -115,6 +115,7 @@ for trials in range(1, n_trials + 1):
     # Reset environment
     obs = env.reset()
     final_obs = obs.copy()
+    final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
         data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
         data = norm_obs_clip(data)
@@ -150,7 +151,8 @@ for trials in range(1, n_trials + 1):
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             if done[a]:
-                final_obs[a] = obs[a]
+                final_obs[a] = obs[a].copy()
+                final_obs_next[a] = next_obs[a].copy()
                 final_action_dict.update({a: action_dict[a]})
             if not demo and not done[a]:
                 agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
@@ -159,7 +161,7 @@ for trials in range(1, n_trials + 1):
         obs = next_obs.copy()
         if done['__all__']:
             env_done = 1
-            agent.step(final_obs[a], final_action_dict[a], all_rewards[a], next_obs[a], done[a])
+            agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
             break
     # Epsilon decay
     eps = max(eps_end, eps_decay * eps)  # decrease epsilon
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 547f349..8e7f2ae 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -139,8 +139,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         for neigh_direction in possible_directions:
             new_cell = self._new_position(position, neigh_direction)
 
-            if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
-               new_cell[1] >= 0 and new_cell[1] < self.env.width:
+            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
 
                 desired_movement_from_new_cell = (neigh_direction + 2) % 4
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 35af894..e9e2d3d 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -143,7 +143,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=10, min_dist=2, max_dist=99
         if len(new_path) >= 2:
             nr_created += 1
 
-    #print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections")
+    # print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections")
     # print(start_goal)
 
     agents_position = [sg[0] for sg in start_goal]
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 40e4ab7..33b9f96 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -300,7 +300,7 @@ class RailEnv(Environment):
         # if num_agents_in_target_position == self.number_of_agents:
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
             self.dones["__all__"] = True
-            self.rewards_dict = [r + global_reward for r in self.rewards_dict]
+            self.rewards_dict = [0*r+global_reward for r in self.rewards_dict]
 
         # Reset the step actions (in case some agent doesn't 'register_action'
         # on the next step)
--
GitLab