From c8321d5d28c1ed8271bc09182db7a9d330e6a4b1 Mon Sep 17 00:00:00 2001 From: Mattias Ljungstrom <mattias.ljungstrom@gmail.com> Date: Sun, 28 Apr 2019 18:23:16 +0200 Subject: [PATCH] fixes to flake8 errors --- examples/play_model.py | 43 +++++++++++++++++------------------- flatland/core/transitions.py | 8 +++---- flatland/envs/rail_env.py | 3 ++- tests/test_transitions.py | 36 +++++++++++++++--------------- 4 files changed, 44 insertions(+), 46 deletions(-) diff --git a/examples/play_model.py b/examples/play_model.py index 8f0df7cd..82458e33 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -1,4 +1,4 @@ -from flatland.envs.rail_env import RailEnv, random_rail_generator, complex_rail_generator +from flatland.envs.rail_env import RailEnv, complex_rail_generator # from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import RenderTool from flatland.baselines.dueling_double_dqn import Agent @@ -6,7 +6,6 @@ from collections import deque import torch import random import numpy as np -#import matplotlib.pyplot as plt import time @@ -34,7 +33,7 @@ class Player(object): self.tStart = time.time() # Reset environment - #self.obs = self.env.reset() + # self.obs = self.env.reset() self.env.obs_builder.reset() self.obs = self.env._get_observations() for a in range(self.env.number_of_agents): @@ -86,7 +85,6 @@ def max_lt(seq, val): return None - def main(render=True, delay=0.0): random.seed(1) @@ -94,7 +92,7 @@ def main(render=True, delay=0.0): # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) - #transition_probability = [0.5, # empty cell - Case 0 + # transition_probability = [0.5, # empty cell - Case 0 # 1.0, # Case 1 - straight # 1.0, # Case 2 - simple switch # 0.3, # Case 3 - diamond crossing @@ -113,7 +111,7 @@ def main(render=True, delay=0.0): # plt.figure(figsize=(5,5)) # fRedis = redis.Redis() - handle = env.get_agent_handles() + # handle = env.get_agent_handles() state_size = 105 action_size = 4 @@ -151,7 +149,7 @@ def main(render=True, delay=0.0): obs = env.reset() for a in range(env.number_of_agents): - norm = max(1, max_lt(obs[a],np.inf)) + norm = max(1, max_lt(obs[a], np.inf)) obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1) # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) @@ -161,9 +159,9 @@ def main(render=True, delay=0.0): # Run episode for step in range(50): - #if trials > 114: - #env_renderer.renderEnv(show=True) - #print(step) + # if trials > 114: + # env_renderer.renderEnv(show=True) + # print(step) # Action for a in range(env.number_of_agents): action = agent.act(np.array(obs[a]), eps=eps) @@ -187,7 +185,6 @@ def main(render=True, delay=0.0): iFrame += 1 - obs = next_obs.copy() if done['__all__']: env_done = 1 @@ -201,23 +198,23 @@ def main(render=True, delay=0.0): dones_list.append((np.mean(done_window))) print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + - '\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format( - env.number_of_agents, - trials, - np.mean(scores_window), - 100 * np.mean(done_window), - eps, action_prob/np.sum(action_prob)), + '\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format( + env.number_of_agents, + trials, + np.mean(scores_window), + 100 * np.mean(done_window), + eps, action_prob/np.sum(action_prob)), end=" ") if trials % 100 == 0: tNow = time.time() rFps = iFrame / (tNow - tStart) print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + - '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format( - env.number_of_agents, - trials, - np.mean(scores_window), - 100 * np.mean(done_window), - eps, rFps, action_prob / np.sum(action_prob))) + '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format( + env.number_of_agents, + trials, + np.mean(scores_window), + 100 * np.mean(done_window), + eps, rFps, action_prob / np.sum(action_prob))) torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth') action_prob = [1]*4 diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index ec9586ed..92620898 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -542,10 +542,10 @@ class RailEnvTransitions(Grid4Transitions): def print(self, cell_transition): print(" NESW") - print("N", format(cell_transition>>(3*4) & 0xF, '04b')) - print("E", format(cell_transition>>(2*4) & 0xF, '04b')) - print("S", format(cell_transition>>(1*4) & 0xF, '04b')) - print("W", format(cell_transition>>(0*4) & 0xF, '04b')) + print("N", format(cell_transition >> (3*4) & 0xF, '04b')) + print("E", format(cell_transition >> (2*4) & 0xF, '04b')) + print("S", format(cell_transition >> (1*4) & 0xF, '04b')) + print("W", format(cell_transition >> (0*4) & 0xF, '04b')) def is_valid(self, cell_transition): """ diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 3672caf4..a7af40e4 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -281,7 +281,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): # - on failure goto step1 and retry with seed+1 # - [avoid crossing other start,goal positions] (optional) # - # - [after X pairs] + # - [after X pairs] # - find closest rail from start (Pa) # - iterating outwards in a "circle" from start until an existing rail cell is hit # - connect [start, Pa] @@ -314,6 +314,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): continue # check distance to existing points sg_new = [start, goal] + def check_all_dist(sg_new): for sg in start_goal: for i in range(2): diff --git a/tests/test_transitions.py b/tests/test_transitions.py index 1d6ea966..2ebfc462 100644 --- a/tests/test_transitions.py +++ b/tests/test_transitions.py @@ -12,14 +12,14 @@ def test_is_valid_railenv_transitions(): transition_list = rail_env_trans.transitions for t in transition_list: - assert(rail_env_trans.is_valid(t) == True) + assert(rail_env_trans.is_valid(t) is True) for i in range(3): rot_trans = rail_env_trans.rotate_transition(t, 90 * i) - assert(rail_env_trans.is_valid(rot_trans) == True) + assert(rail_env_trans.is_valid(rot_trans) is True) - assert(rail_env_trans.is_valid(int('1111111111110010', 2)) == False) - assert(rail_env_trans.is_valid(int('1001111111110010', 2)) == False) - assert(rail_env_trans.is_valid(int('1001111001110110', 2)) == False) + assert(rail_env_trans.is_valid(int('1111111111110010', 2)) is False) + assert(rail_env_trans.is_valid(int('1001111111110010', 2)) is False) + assert(rail_env_trans.is_valid(int('1001111001110110', 2)) is False) def test_adding_new_valid_transition(): @@ -27,32 +27,32 @@ def test_adding_new_valid_transition(): rail_array = np.zeros(shape=(15, 15), dtype=np.uint16) # adding straight - assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (6,5), (10,10)) == True) + assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True) # adding valid right turn - assert(validate_new_transition(rail_trans, rail_array, (5,4), (5,5), (5,6), (10,10)) == True) + assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True) # adding valid left turn - assert(validate_new_transition(rail_trans, rail_array, (5,6), (5,5), (5,6), (10,10)) == True) + assert(validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn - rail_array[(5,5)] = rail_trans.transitions[2] - assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == False) + rail_array[(5, 5)] = rail_trans.transitions[2] + assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) # should create #4 -> valid - rail_array[(5,5)] = rail_trans.transitions[3] - assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == True) + rail_array[(5, 5)] = rail_trans.transitions[3] + assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn - rail_array[(5,5)] = rail_trans.transitions[7] - assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == False) + rail_array[(5, 5)] = rail_trans.transitions[7] + assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) # test path start condition - rail_array[(5,5)] = rail_trans.transitions[0] - assert(validate_new_transition(rail_trans, rail_array, None, (5,5), (5,6), (10,10)) == True) + rail_array[(5, 5)] = rail_trans.transitions[0] + assert(validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True) # test path end condition - rail_array[(5,5)] = rail_trans.transitions[0] - assert(validate_new_transition(rail_trans, rail_array, (5,4), (5,5), (6,5), (6,5)) == True) + rail_array[(5, 5)] = rail_trans.transitions[0] + assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True) def test_valid_railenv_transitions(): -- GitLab