From c8321d5d28c1ed8271bc09182db7a9d330e6a4b1 Mon Sep 17 00:00:00 2001
From: Mattias Ljungstrom <mattias.ljungstrom@gmail.com>
Date: Sun, 28 Apr 2019 18:23:16 +0200
Subject: [PATCH] fixes to flake8 errors

---
 examples/play_model.py       | 43 +++++++++++++++++-------------------
 flatland/core/transitions.py |  8 +++----
 flatland/envs/rail_env.py    |  3 ++-
 tests/test_transitions.py    | 36 +++++++++++++++---------------
 4 files changed, 44 insertions(+), 46 deletions(-)

diff --git a/examples/play_model.py b/examples/play_model.py
index 8f0df7cd..82458e33 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -1,4 +1,4 @@
-from flatland.envs.rail_env import RailEnv, random_rail_generator, complex_rail_generator
+from flatland.envs.rail_env import RailEnv, complex_rail_generator
 # from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import RenderTool
 from flatland.baselines.dueling_double_dqn import Agent
@@ -6,7 +6,6 @@ from collections import deque
 import torch
 import random
 import numpy as np
-#import matplotlib.pyplot as plt
 import time
 
 
@@ -34,7 +33,7 @@ class Player(object):
         self.tStart = time.time()
         
         # Reset environment
-        #self.obs = self.env.reset()
+        # self.obs = self.env.reset()
         self.env.obs_builder.reset()
         self.obs = self.env._get_observations()
         for a in range(self.env.number_of_agents):
@@ -86,7 +85,6 @@ def max_lt(seq, val):
     return None
 
 
-
 def main(render=True, delay=0.0):
 
     random.seed(1)
@@ -94,7 +92,7 @@ def main(render=True, delay=0.0):
 
     # Example generate a rail given a manual specification,
     # a map of tuples (cell_type, rotation)
-    #transition_probability = [0.5,  # empty cell - Case 0
+    # transition_probability = [0.5,  # empty cell - Case 0
     #                        1.0,  # Case 1 - straight
     #                        1.0,  # Case 2 - simple switch
     #                        0.3,  # Case 3 - diamond crossing
@@ -113,7 +111,7 @@ def main(render=True, delay=0.0):
     # plt.figure(figsize=(5,5))
     # fRedis = redis.Redis()
 
-    handle = env.get_agent_handles()
+    # handle = env.get_agent_handles()
 
     state_size = 105
     action_size = 4
@@ -151,7 +149,7 @@ def main(render=True, delay=0.0):
         obs = env.reset()
 
         for a in range(env.number_of_agents):
-            norm = max(1, max_lt(obs[a],np.inf))
+            norm = max(1, max_lt(obs[a], np.inf))
             obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
 
         # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
@@ -161,9 +159,9 @@ def main(render=True, delay=0.0):
 
         # Run episode
         for step in range(50):
-            #if trials > 114:
-            #env_renderer.renderEnv(show=True)
-            #print(step)
+            # if trials > 114:
+            # env_renderer.renderEnv(show=True)
+            # print(step)
             # Action
             for a in range(env.number_of_agents):
                 action = agent.act(np.array(obs[a]), eps=eps)
@@ -187,7 +185,6 @@ def main(render=True, delay=0.0):
 
             iFrame += 1
 
-
             obs = next_obs.copy()
             if done['__all__']:
                 env_done = 1
@@ -201,23 +198,23 @@ def main(render=True, delay=0.0):
         dones_list.append((np.mean(done_window)))
 
         print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
-                '\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format(
-                env.number_of_agents,
-                trials,
-                np.mean(scores_window),
-                100 * np.mean(done_window),
-                eps, action_prob/np.sum(action_prob)),
+               '\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format(
+               env.number_of_agents,
+               trials,
+               np.mean(scores_window),
+               100 * np.mean(done_window),
+               eps, action_prob/np.sum(action_prob)),
             end=" ")
         if trials % 100 == 0:
             tNow = time.time()
             rFps = iFrame / (tNow - tStart)
             print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + 
-                    '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format(
-                    env.number_of_agents,
-                    trials,
-                    np.mean(scores_window),
-                    100 * np.mean(done_window),
-                    eps, rFps, action_prob / np.sum(action_prob)))
+                   '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format(
+                   env.number_of_agents,
+                   trials,
+                   np.mean(scores_window),
+                   100 * np.mean(done_window),
+                   eps, rFps, action_prob / np.sum(action_prob)))
             torch.save(agent.qnetwork_local.state_dict(),
                     '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
             action_prob = [1]*4
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index ec9586ed..92620898 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -542,10 +542,10 @@ class RailEnvTransitions(Grid4Transitions):
 
     def print(self, cell_transition):
         print("  NESW")
-        print("N", format(cell_transition>>(3*4) & 0xF, '04b'))
-        print("E", format(cell_transition>>(2*4) & 0xF, '04b'))
-        print("S", format(cell_transition>>(1*4) & 0xF, '04b'))
-        print("W", format(cell_transition>>(0*4) & 0xF, '04b'))
+        print("N", format(cell_transition >> (3*4) & 0xF, '04b'))
+        print("E", format(cell_transition >> (2*4) & 0xF, '04b'))
+        print("S", format(cell_transition >> (1*4) & 0xF, '04b'))
+        print("W", format(cell_transition >> (0*4) & 0xF, '04b'))
 
     def is_valid(self, cell_transition):
         """
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 3672caf4..a7af40e4 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -281,7 +281,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
         #         - on failure goto step1 and retry with seed+1
         #     - [avoid crossing other start,goal positions] (optional)
         #
-        #   - [after X pairs] 
+        #   - [after X pairs]
         #     - find closest rail from start (Pa)
         #       - iterating outwards in a "circle" from start until an existing rail cell is hit
         #     - connect [start, Pa]
@@ -314,6 +314,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
                     continue
                 # check distance to existing points
                 sg_new = [start, goal]
+
                 def check_all_dist(sg_new):
                     for sg in start_goal:
                         for i in range(2):
diff --git a/tests/test_transitions.py b/tests/test_transitions.py
index 1d6ea966..2ebfc462 100644
--- a/tests/test_transitions.py
+++ b/tests/test_transitions.py
@@ -12,14 +12,14 @@ def test_is_valid_railenv_transitions():
     transition_list = rail_env_trans.transitions
 
     for t in transition_list:
-        assert(rail_env_trans.is_valid(t) == True)
+        assert(rail_env_trans.is_valid(t) is True)
         for i in range(3):
             rot_trans = rail_env_trans.rotate_transition(t, 90 * i)
-            assert(rail_env_trans.is_valid(rot_trans) == True)
+            assert(rail_env_trans.is_valid(rot_trans) is True)
 
-    assert(rail_env_trans.is_valid(int('1111111111110010', 2)) == False)
-    assert(rail_env_trans.is_valid(int('1001111111110010', 2)) == False)
-    assert(rail_env_trans.is_valid(int('1001111001110110', 2)) == False)
+    assert(rail_env_trans.is_valid(int('1111111111110010', 2)) is False)
+    assert(rail_env_trans.is_valid(int('1001111111110010', 2)) is False)
+    assert(rail_env_trans.is_valid(int('1001111001110110', 2)) is False)
 
 
 def test_adding_new_valid_transition():
@@ -27,32 +27,32 @@ def test_adding_new_valid_transition():
     rail_array = np.zeros(shape=(15, 15), dtype=np.uint16)
 
     # adding straight
-    assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (6,5), (10,10)) == True)
+    assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True)
 
     # adding valid right turn
-    assert(validate_new_transition(rail_trans, rail_array, (5,4), (5,5), (5,6), (10,10)) == True)
+    assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
     # adding valid left turn
-    assert(validate_new_transition(rail_trans, rail_array, (5,6), (5,5), (5,6), (10,10)) == True)
+    assert(validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
 
     # adding invalid turn
-    rail_array[(5,5)] = rail_trans.transitions[2]
-    assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == False)
+    rail_array[(5, 5)] = rail_trans.transitions[2]
+    assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
 
     # should create #4 -> valid
-    rail_array[(5,5)] = rail_trans.transitions[3]
-    assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == True)
+    rail_array[(5, 5)] = rail_trans.transitions[3]
+    assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
 
     # adding invalid turn
-    rail_array[(5,5)] = rail_trans.transitions[7]
-    assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == False)
+    rail_array[(5, 5)] = rail_trans.transitions[7]
+    assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
 
     # test path start condition
-    rail_array[(5,5)] = rail_trans.transitions[0]
-    assert(validate_new_transition(rail_trans, rail_array, None, (5,5), (5,6), (10,10)) == True)
+    rail_array[(5, 5)] = rail_trans.transitions[0]
+    assert(validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True)
 
     # test path end condition
-    rail_array[(5,5)] = rail_trans.transitions[0]
-    assert(validate_new_transition(rail_trans, rail_array, (5,4), (5,5), (6,5), (6,5)) == True)
+    rail_array[(5, 5)] = rail_trans.transitions[0]
+    assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True)
 
 
 def test_valid_railenv_transitions():
-- 
GitLab