Skip to content
Snippets Groups Projects
Commit c8321d5d authored by Mattias Ljungstrom's avatar Mattias Ljungstrom
Browse files

fixes to flake8 errors

parent 9b13b14f
No related branches found
No related tags found
No related merge requests found
from flatland.envs.rail_env import RailEnv, random_rail_generator, complex_rail_generator
from flatland.envs.rail_env import RailEnv, complex_rail_generator
# from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import RenderTool
from flatland.baselines.dueling_double_dqn import Agent
......@@ -6,7 +6,6 @@ from collections import deque
import torch
import random
import numpy as np
#import matplotlib.pyplot as plt
import time
......@@ -34,7 +33,7 @@ class Player(object):
self.tStart = time.time()
# Reset environment
#self.obs = self.env.reset()
# self.obs = self.env.reset()
self.env.obs_builder.reset()
self.obs = self.env._get_observations()
for a in range(self.env.number_of_agents):
......@@ -86,7 +85,6 @@ def max_lt(seq, val):
return None
def main(render=True, delay=0.0):
random.seed(1)
......@@ -94,7 +92,7 @@ def main(render=True, delay=0.0):
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
#transition_probability = [0.5, # empty cell - Case 0
# transition_probability = [0.5, # empty cell - Case 0
# 1.0, # Case 1 - straight
# 1.0, # Case 2 - simple switch
# 0.3, # Case 3 - diamond crossing
......@@ -113,7 +111,7 @@ def main(render=True, delay=0.0):
# plt.figure(figsize=(5,5))
# fRedis = redis.Redis()
handle = env.get_agent_handles()
# handle = env.get_agent_handles()
state_size = 105
action_size = 4
......@@ -151,7 +149,7 @@ def main(render=True, delay=0.0):
obs = env.reset()
for a in range(env.number_of_agents):
norm = max(1, max_lt(obs[a],np.inf))
norm = max(1, max_lt(obs[a], np.inf))
obs[a] = np.clip(np.array(obs[a]) / norm, -1, 1)
# env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
......@@ -161,9 +159,9 @@ def main(render=True, delay=0.0):
# Run episode
for step in range(50):
#if trials > 114:
#env_renderer.renderEnv(show=True)
#print(step)
# if trials > 114:
# env_renderer.renderEnv(show=True)
# print(step)
# Action
for a in range(env.number_of_agents):
action = agent.act(np.array(obs[a]), eps=eps)
......@@ -187,7 +185,6 @@ def main(render=True, delay=0.0):
iFrame += 1
obs = next_obs.copy()
if done['__all__']:
env_done = 1
......@@ -201,23 +198,23 @@ def main(render=True, delay=0.0):
dones_list.append((np.mean(done_window)))
print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format(
env.number_of_agents,
trials,
np.mean(scores_window),
100 * np.mean(done_window),
eps, action_prob/np.sum(action_prob)),
'\tEpsilon: {:.2f} \t Action Probabilities: \t {}').format(
env.number_of_agents,
trials,
np.mean(scores_window),
100 * np.mean(done_window),
eps, action_prob/np.sum(action_prob)),
end=" ")
if trials % 100 == 0:
tNow = time.time()
rFps = iFrame / (tNow - tStart)
print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
'\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format(
env.number_of_agents,
trials,
np.mean(scores_window),
100 * np.mean(done_window),
eps, rFps, action_prob / np.sum(action_prob)))
'\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format(
env.number_of_agents,
trials,
np.mean(scores_window),
100 * np.mean(done_window),
eps, rFps, action_prob / np.sum(action_prob)))
torch.save(agent.qnetwork_local.state_dict(),
'../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
action_prob = [1]*4
......
......@@ -542,10 +542,10 @@ class RailEnvTransitions(Grid4Transitions):
def print(self, cell_transition):
print(" NESW")
print("N", format(cell_transition>>(3*4) & 0xF, '04b'))
print("E", format(cell_transition>>(2*4) & 0xF, '04b'))
print("S", format(cell_transition>>(1*4) & 0xF, '04b'))
print("W", format(cell_transition>>(0*4) & 0xF, '04b'))
print("N", format(cell_transition >> (3*4) & 0xF, '04b'))
print("E", format(cell_transition >> (2*4) & 0xF, '04b'))
print("S", format(cell_transition >> (1*4) & 0xF, '04b'))
print("W", format(cell_transition >> (0*4) & 0xF, '04b'))
def is_valid(self, cell_transition):
"""
......
......@@ -281,7 +281,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
# - on failure goto step1 and retry with seed+1
# - [avoid crossing other start,goal positions] (optional)
#
# - [after X pairs]
# - [after X pairs]
# - find closest rail from start (Pa)
# - iterating outwards in a "circle" from start until an existing rail cell is hit
# - connect [start, Pa]
......@@ -314,6 +314,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
continue
# check distance to existing points
sg_new = [start, goal]
def check_all_dist(sg_new):
for sg in start_goal:
for i in range(2):
......
......@@ -12,14 +12,14 @@ def test_is_valid_railenv_transitions():
transition_list = rail_env_trans.transitions
for t in transition_list:
assert(rail_env_trans.is_valid(t) == True)
assert(rail_env_trans.is_valid(t) is True)
for i in range(3):
rot_trans = rail_env_trans.rotate_transition(t, 90 * i)
assert(rail_env_trans.is_valid(rot_trans) == True)
assert(rail_env_trans.is_valid(rot_trans) is True)
assert(rail_env_trans.is_valid(int('1111111111110010', 2)) == False)
assert(rail_env_trans.is_valid(int('1001111111110010', 2)) == False)
assert(rail_env_trans.is_valid(int('1001111001110110', 2)) == False)
assert(rail_env_trans.is_valid(int('1111111111110010', 2)) is False)
assert(rail_env_trans.is_valid(int('1001111111110010', 2)) is False)
assert(rail_env_trans.is_valid(int('1001111001110110', 2)) is False)
def test_adding_new_valid_transition():
......@@ -27,32 +27,32 @@ def test_adding_new_valid_transition():
rail_array = np.zeros(shape=(15, 15), dtype=np.uint16)
# adding straight
assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (6,5), (10,10)) == True)
assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True)
# adding valid right turn
assert(validate_new_transition(rail_trans, rail_array, (5,4), (5,5), (5,6), (10,10)) == True)
assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
# adding valid left turn
assert(validate_new_transition(rail_trans, rail_array, (5,6), (5,5), (5,6), (10,10)) == True)
assert(validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
# adding invalid turn
rail_array[(5,5)] = rail_trans.transitions[2]
assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == False)
rail_array[(5, 5)] = rail_trans.transitions[2]
assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
# should create #4 -> valid
rail_array[(5,5)] = rail_trans.transitions[3]
assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == True)
rail_array[(5, 5)] = rail_trans.transitions[3]
assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
# adding invalid turn
rail_array[(5,5)] = rail_trans.transitions[7]
assert(validate_new_transition(rail_trans, rail_array, (4,5), (5,5), (5,6), (10,10)) == False)
rail_array[(5, 5)] = rail_trans.transitions[7]
assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
# test path start condition
rail_array[(5,5)] = rail_trans.transitions[0]
assert(validate_new_transition(rail_trans, rail_array, None, (5,5), (5,6), (10,10)) == True)
rail_array[(5, 5)] = rail_trans.transitions[0]
assert(validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True)
# test path end condition
rail_array[(5,5)] = rail_trans.transitions[0]
assert(validate_new_transition(rail_trans, rail_array, (5,4), (5,5), (6,5), (6,5)) == True)
rail_array[(5, 5)] = rail_trans.transitions[0]
assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True)
def test_valid_railenv_transitions():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment