Skip to content
Snippets Groups Projects
Commit 43cf00ff authored by u214892's avatar u214892
Browse files

formatted everything with IntelliJ/PyCharm formatter, optimizing imports

parent f296c0e3
No related branches found
No related tags found
No related merge requests found
Showing with 102 additions and 104 deletions
......@@ -13,16 +13,17 @@
# All configuration values have a default; values that are commented out
# serve to show the default.
import os
import sys
# If extensions (or modules to document with autodoc) are in another
# directory, add these directories to sys.path here. If the directory is
# relative to the documentation root, use os.path.abspath to make it
# absolute, like shown here.
#
import flatland
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
sys.path.insert(0, os.path.abspath('..'))
# -- General configuration ---------------------------------------------
......@@ -78,7 +79,6 @@ pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
# -- Options for HTML output -------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
......@@ -86,7 +86,6 @@ todo_include_todos = False
#
html_theme = "sphinx_rtd_theme"
# Theme options are theme-specific and customize the look and feel of a
# theme further. For a list of options available for each theme, see the
# documentation.
......@@ -98,13 +97,11 @@ html_theme = "sphinx_rtd_theme"
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# -- Options for HTMLHelp output ---------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'flatlanddoc'
# -- Options for LaTeX output ------------------------------------------
latex_elements = {
......@@ -134,7 +131,6 @@ latex_documents = [
u'S.P. Mohanty', 'manual'),
]
# -- Options for manual page output ------------------------------------
# One entry per manual page. List of tuples
......@@ -145,7 +141,6 @@ man_pages = [
[author], 1)
]
# -- Options for Texinfo output ----------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
......
import random
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.generators import random_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.core.env_observation_builder import ObservationBuilder
import numpy as np
random.seed(100)
np.random.seed(100)
......@@ -18,7 +18,7 @@ class CustomObs(ObservationBuilder):
return
def get(self, handle):
observation = handle*np.ones((5,))
observation = handle * np.ones((5,))
return observation
......
......@@ -23,6 +23,7 @@ def custom_rail_generator():
agents_target = []
return grid_map, agents_positions, agents_direction, agents_target
return generator
......
from examples.play_model import Player
from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv
......@@ -26,7 +25,7 @@ def tkmain(n_trials=2, n_steps=50, sGL="PIL"):
env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step,
action_dict=oPlayer.action_dict)
env_renderer.close_window()
env_renderer.close_window()
if __name__ == "__main__":
......
......@@ -2,6 +2,7 @@
"""Console script for flatland."""
import sys
import click
......
......@@ -297,8 +297,8 @@ class GridTransitionMap(TransitionMap):
self.grid = np.zeros((self.height, self.width), dtype=np.uint64)
self.grid[0:min(self.height, new_height),
0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height),
0:min(self.width, new_width)]
0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height),
0:min(self.width, new_width)]
def is_cell_valid(self, rcPos):
cell_transition = self.grid[tuple(rcPos)]
......@@ -336,8 +336,8 @@ class GridTransitionMap(TransitionMap):
lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8
g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1)
# gDirIn = g2binTrans.any(axis=1) # inbound directions as boolean array (4)
gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4)
giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int
gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4)
giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int
# loop over available outbound directions (indices) for rcPos
for iDirOut in giDirOut:
......
......@@ -319,7 +319,7 @@ class Grid4Transitions(Transitions):
value = self.set_transitions(value, i, block_tuple)
# Rotate the 4-bits blocks
value = ((value & (2**(rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4))
value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4))
cell_transition = value
return cell_transition
......@@ -499,7 +499,7 @@ class Grid8Transitions(Transitions):
value = self.set_transitions(value, i, block_tuple)
# Rotate the 8bits blocks
value = ((value & (2**(rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8))
value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8))
cell_transition = value
......@@ -587,9 +587,9 @@ class RailEnvTransitions(Grid4Transitions):
sRepr = " ".join([
"{}:{}".format(sDir, sbinTrans[i:(i + 4)])
for i, sDir in
zip(
range(0, len(sbinTrans), 4),
self.lsDirs)]) # NESW
zip(
range(0, len(sbinTrans), 4),
self.lsDirs)]) # NESW
return sRepr
if version == 1:
......
from attr import attrs, attrib
from itertools import starmap
import numpy as np
from attr import attrs, attrib
# from flatland.envs.rail_env import RailEnv
......@@ -16,7 +18,7 @@ class EnvDescription(object):
height = attrib()
width = attrib()
rail_generator = attrib()
obs_builder = attrib() # not sure if this should closer to the agent than the env
obs_builder = attrib() # not sure if this should closer to the agent than the env
@attrs
......@@ -41,7 +43,7 @@ class EnvAgentStatic(object):
def from_lists(cls, positions, directions, targets):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False]*len(positions))))
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions))))
def to_list(self):
......@@ -78,7 +80,7 @@ class EnvAgent(EnvAgentStatic):
def to_list(self):
return [
self.position, self.direction, self.target, self.handle,
self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position, self.moving]
@classmethod
......
import numpy as np
# from flatland.core.env import Environment
# from flatland.envs.observations import TreeObsForRailEnv
from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror
from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail
# from flatland.core.env import Environment
# from flatland.envs.observations import TreeObsForRailEnv
def empty_rail_generator():
"""
Returns a generator which returns an empty rail mail with no agents.
Primarily used by the editor
"""
def generator(width, height, num_agents=0, num_resets=0):
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
......@@ -21,6 +23,7 @@ def empty_rail_generator():
rail_array.fill(0)
return grid_map, [], [], []
return generator
......@@ -41,7 +44,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
def generator(width, height, num_agents, num_resets=0):
if num_agents > nr_start_goal:
num_agents = nr_start_goal
num_agents = nr_start_goal
print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
......
......@@ -7,9 +7,10 @@ a GridTransitionMap object.
# TODO: _ this is a global method --> utils or remove later
# from inspect import currentframe
from enum import IntEnum
import msgpack
import numpy as np
from enum import IntEnum
from flatland.core.env import Environment
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
......
# -*- coding: utf-8 -*-
"""Main module."""
import copy
import re
import svgutils
import re
import copy
from flatland.core.transitions import RailEnvTransitions
......@@ -60,7 +60,7 @@ class SVG(object):
sNewStyles = "\n"
for sKey, sValue in self.dStyles.items():
if sKey == style_name:
sValue = "fill:#" + "".join([('{:#04x}'.format(int(255.0*col))[2:4]) for col in color[0:3]]) + ";"
sValue = "fill:#" + "".join([('{:#04x}'.format(int(255.0 * col))[2:4]) for col in color[0:3]]) + ";"
sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n"
sNewStyles += sNewStyle
......@@ -111,6 +111,7 @@ class Track(object):
The directions and images are also rotated by 90, 180 & 270 degrees.
(There is some redundancy in this process, given the images provided)
"""
def __init__(self):
dFiles = {
"": "Background_#9CCB89.svg",
......
#!/usr/bin/env python
import os
import webbrowser
import subprocess
import webbrowser
from urllib.request import pathname2url
def browser(pathname):
webbrowser.open("file:" + pathname2url(os.path.abspath(pathname)))
subprocess.call(['coverage', 'run', '--source', 'flatland', '-m', 'pytest'])
subprocess.call(['coverage', 'report', '-m'])
subprocess.call(['coverage', 'html'])
......
#!/usr/bin/env python
import os
import webbrowser
import subprocess
import webbrowser
from urllib.request import pathname2url
def browser(pathname):
webbrowser.open("file:" + pathname2url(os.path.abspath(pathname)))
def remove_exists(filename):
try:
os.remove(filename)
......
......@@ -3,13 +3,10 @@
"""The setup script."""
import os
from setuptools import setup, find_packages
import sys
import os
import platform
import sys
from setuptools import setup, find_packages
with open('README.rst') as readme_file:
readme = readme_file.read()
......@@ -17,10 +14,6 @@ with open('README.rst') as readme_file:
with open('HISTORY.rst') as history_file:
history = history_file.read()
# install pycairo on Windows
if os.name == 'nt':
p = platform.architecture()
......@@ -51,13 +44,14 @@ if os.name == 'nt':
import site
import ctypes.util
default_os_path = os.environ['PATH']
os.environ['PATH'] = ''
for s in site.getsitepackages():
os.environ['PATH'] = os.environ['PATH']+';' + s+'\\cairo'
os.environ['PATH'] = os.environ['PATH']+';' + default_os_path
os.environ['PATH'] = os.environ['PATH'] + ';' + s + '\\cairo'
os.environ['PATH'] = os.environ['PATH'] + ';' + default_os_path
print(os.environ['PATH'])
if ctypes.util.find_library('cairo')is not None:
if ctypes.util.find_library('cairo') is not None:
print("cairo installed: OK")
else:
try:
......@@ -69,7 +63,7 @@ else:
def get_all_svg_files(directory='./svg/'):
ret = []
for f in os.listdir(directory):
ret.append(directory+f)
ret.append(directory + f)
return ret
......
from flatland.envs.rail_env import RailEnv
# from flatland.envs.agent_utils import EnvAgent
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.rail_env import RailEnv
def test_load_env():
......@@ -11,5 +10,3 @@ def test_load_env():
agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
env.add_agent_static(agent_static)
assert env.get_num_agents() == 1
......@@ -3,10 +3,10 @@
import numpy as np
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
"""Tests for `flatland` package."""
......@@ -70,7 +70,7 @@ def test_global_obs():
# env_renderer.renderEnv(show=True)
# global_obs.reset()
assert(global_obs[0][0].shape == rail_map.shape + (16,))
assert (global_obs[0][0].shape == rail_map.shape + (16,))
rail_map_recons = np.zeros_like(rail_map)
for i in range(global_obs[0][0].shape[0]):
......@@ -78,11 +78,11 @@ def test_global_obs():
rail_map_recons[i, j] = int(
''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2)
assert(rail_map_recons.all() == rail_map.all())
assert (rail_map_recons.all() == rail_map.all())
# If this assertion is wrong, it means that the observation returned
# places the agent on an empty cell
assert(np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0)
def main():
......
......@@ -2,10 +2,11 @@
# -*- coding: utf-8 -*-
"""Tests for `flatland` package."""
import numpy as np
from flatland.core.transitions import RailEnvTransitions, Grid8Transitions
# from flatland.envs.rail_env import validate_new_transition
from flatland.envs.env_utils import validate_new_transition
import numpy as np
def test_is_valid_railenv_transitions():
......@@ -13,14 +14,14 @@ def test_is_valid_railenv_transitions():
transition_list = rail_env_trans.transitions
for t in transition_list:
assert(rail_env_trans.is_valid(t) is True)
assert (rail_env_trans.is_valid(t) is True)
for i in range(3):
rot_trans = rail_env_trans.rotate_transition(t, 90 * i)
assert(rail_env_trans.is_valid(rot_trans) is True)
assert (rail_env_trans.is_valid(rot_trans) is True)
assert(rail_env_trans.is_valid(int('1111111111110010', 2)) is False)
assert(rail_env_trans.is_valid(int('1001111111110010', 2)) is False)
assert(rail_env_trans.is_valid(int('1001111001110110', 2)) is False)
assert (rail_env_trans.is_valid(int('1111111111110010', 2)) is False)
assert (rail_env_trans.is_valid(int('1001111111110010', 2)) is False)
assert (rail_env_trans.is_valid(int('1001111001110110', 2)) is False)
def test_adding_new_valid_transition():
......@@ -28,32 +29,32 @@ def test_adding_new_valid_transition():
rail_array = np.zeros(shape=(15, 15), dtype=np.uint16)
# adding straight
assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True)
assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True)
# adding valid right turn
assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True)
# adding valid left turn
assert(validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
assert (validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True)
# adding invalid turn
rail_array[(5, 5)] = rail_trans.transitions[2]
assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
# should create #4 -> valid
rail_array[(5, 5)] = rail_trans.transitions[3]
assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True)
# adding invalid turn
rail_array[(5, 5)] = rail_trans.transitions[7]
assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False)
# test path start condition
rail_array[(5, 5)] = rail_trans.transitions[0]
assert(validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True)
assert (validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True)
# test path end condition
rail_array[(5, 5)] = rail_trans.transitions[0]
assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True)
assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True)
def test_valid_railenv_transitions():
......@@ -65,48 +66,48 @@ def test_valid_railenv_transitions():
# 'W': 3}
for i in range(2):
assert(rail_env_trans.get_transitions(
int('1100110000110011', 2), i) == (1, 1, 0, 0))
assert(rail_env_trans.get_transitions(
int('1100110000110011', 2), 2 + i) == (0, 0, 1, 1))
assert (rail_env_trans.get_transitions(
int('1100110000110011', 2), i) == (1, 1, 0, 0))
assert (rail_env_trans.get_transitions(
int('1100110000110011', 2), 2 + i) == (0, 0, 1, 1))
no_transition_cell = int('0000000000000000', 2)
for i in range(4):
assert(rail_env_trans.get_transitions(
no_transition_cell, i) == (0, 0, 0, 0))
assert (rail_env_trans.get_transitions(
no_transition_cell, i) == (0, 0, 0, 0))
# Facing south, going south
north_south_transition = rail_env_trans.set_transitions(no_transition_cell, 2, (0, 0, 1, 0))
assert(rail_env_trans.set_transition(
north_south_transition, 2, 2, 0) == no_transition_cell)
assert(rail_env_trans.get_transition(
north_south_transition, 2, 2))
assert (rail_env_trans.set_transition(
north_south_transition, 2, 2, 0) == no_transition_cell)
assert (rail_env_trans.get_transition(
north_south_transition, 2, 2))
# Facing north, going east
south_east_transition = \
rail_env_trans.set_transition(no_transition_cell, 0, 1, 1)
assert(rail_env_trans.get_transition(
south_east_transition, 0, 1))
assert (rail_env_trans.get_transition(
south_east_transition, 0, 1))
# The opposite transitions are not feasible
assert(not rail_env_trans.get_transition(
north_south_transition, 2, 0))
assert(not rail_env_trans.get_transition(
south_east_transition, 2, 1))
assert (not rail_env_trans.get_transition(
north_south_transition, 2, 0))
assert (not rail_env_trans.get_transition(
south_east_transition, 2, 1))
east_west_transition = rail_env_trans.rotate_transition(north_south_transition, 90)
north_west_transition = rail_env_trans.rotate_transition(south_east_transition, 180)
# Facing west, going west
assert(rail_env_trans.get_transition(
east_west_transition, 3, 3))
assert (rail_env_trans.get_transition(
east_west_transition, 3, 3))
# Facing south, going west
assert(rail_env_trans.get_transition(
north_west_transition, 2, 3))
assert (rail_env_trans.get_transition(
north_west_transition, 2, 3))
assert(south_east_transition == rail_env_trans.rotate_transition(
south_east_transition, 360))
assert (south_east_transition == rail_env_trans.rotate_transition(
south_east_transition, 360))
def test_diagonal_transitions():
......@@ -114,12 +115,12 @@ def test_diagonal_transitions():
# Facing north, going north-east
south_northeast_transition = int('01000000' + '0' * 8 * 7, 2)
assert(diagonal_trans_env.get_transitions(
south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0))
assert (diagonal_trans_env.get_transitions(
south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0))
# Allowing transition from north to southwest: Facing south, going SW
north_southwest_transition = \
diagonal_trans_env.set_transitions(int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0))
assert(diagonal_trans_env.rotate_transition(
south_northeast_transition, 180) == north_southwest_transition)
assert (diagonal_trans_env.rotate_transition(
south_northeast_transition, 180) == north_southwest_transition)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment