From 1169fcb628da3875ac432efe7495393579f583f4 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Mon, 13 May 2019 17:26:08 +0200
Subject: [PATCH] fixed formatting

---
 examples/training_navigation.py          | 6 ++++--
 flatland/core/env_observation_builder.py | 3 +--
 flatland/envs/generators.py              | 2 +-
 flatland/envs/rail_env.py                | 2 +-
 4 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index feb4796..ee360a1 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -115,6 +115,7 @@ for trials in range(1, n_trials + 1):
     # Reset environment
     obs = env.reset()
     final_obs = obs.copy()
+    final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
         data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
         data = norm_obs_clip(data)
@@ -150,7 +151,8 @@ for trials in range(1, n_trials + 1):
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             if done[a]:
-                final_obs[a] = obs[a]
+                final_obs[a] = obs[a].copy()
+                final_obs_next[a] = next_obs[a].copy()
                 final_action_dict.update({a: action_dict[a]})
             if not demo and not done[a]:
                 agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
@@ -159,7 +161,7 @@ for trials in range(1, n_trials + 1):
         obs = next_obs.copy()
         if done['__all__']:
             env_done = 1
-            agent.step(final_obs[a], final_action_dict[a], all_rewards[a], next_obs[a], done[a])
+            agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
             break
     # Epsilon decay
     eps = max(eps_end, eps_decay * eps)  # decrease epsilon
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 547f349..8e7f2ae 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -139,8 +139,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         for neigh_direction in possible_directions:
             new_cell = self._new_position(position, neigh_direction)
 
-            if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
-               new_cell[1] >= 0 and new_cell[1] < self.env.width:
+            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
 
                 desired_movement_from_new_cell = (neigh_direction + 2) % 4
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 35af894..e9e2d3d 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -143,7 +143,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=10, min_dist=2, max_dist=99
         if len(new_path) >= 2:
             nr_created += 1
 
-    #print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections")
+    # print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections")
     # print(start_goal)
 
     agents_position = [sg[0] for sg in start_goal]
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 40e4ab7..33b9f96 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -300,7 +300,7 @@ class RailEnv(Environment):
         # if num_agents_in_target_position == self.number_of_agents:
         if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
             self.dones["__all__"] = True
-            self.rewards_dict = [r + global_reward for r in self.rewards_dict]
+            self.rewards_dict = [0*r+global_reward for r in self.rewards_dict]
 
         # Reset the step actions (in case some agent doesn't 'register_action'
         # on the next step)
--
GitLab