diff --git a/docs/conf.py b/docs/conf.py index 73a78330059a08e8f7118193b1181a6540524d24..f63d090bf3a9ecf7370bcf2fa5edc37afb019b79 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,16 +13,17 @@ # All configuration values have a default; values that are commented out # serve to show the default. +import os +import sys + # If extensions (or modules to document with autodoc) are in another # directory, add these directories to sys.path here. If the directory is # relative to the documentation root, use os.path.abspath to make it # absolute, like shown here. # import flatland -import os -import sys -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath('..')) # -- General configuration --------------------------------------------- @@ -78,7 +79,6 @@ pygments_style = 'sphinx' # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False - # -- Options for HTML output ------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for @@ -86,7 +86,6 @@ todo_include_todos = False # html_theme = "sphinx_rtd_theme" - # Theme options are theme-specific and customize the look and feel of a # theme further. For a list of options available for each theme, see the # documentation. @@ -98,13 +97,11 @@ html_theme = "sphinx_rtd_theme" # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] - # -- Options for HTMLHelp output --------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'flatlanddoc' - # -- Options for LaTeX output ------------------------------------------ latex_elements = { @@ -134,7 +131,6 @@ latex_documents = [ u'S.P. Mohanty', 'manual'), ] - # -- Options for manual page output ------------------------------------ # One entry per manual page. List of tuples @@ -145,7 +141,6 @@ man_pages = [ [author], 1) ] - # -- Options for Texinfo output ---------------------------------------- # Grouping the document tree into Texinfo files. List of tuples diff --git a/examples/custom_observation_example.py b/examples/custom_observation_example.py index 3c4fc8194aec9385619dc917bef9b3dd22492d47..a491c4a4bae68d1d105a495d24c5b7d16daf72b3 100644 --- a/examples/custom_observation_example.py +++ b/examples/custom_observation_example.py @@ -1,10 +1,10 @@ import random +import numpy as np + +from flatland.core.env_observation_builder import ObservationBuilder from flatland.envs.generators import random_rail_generator from flatland.envs.rail_env import RailEnv -from flatland.core.env_observation_builder import ObservationBuilder - -import numpy as np random.seed(100) np.random.seed(100) @@ -18,7 +18,7 @@ class CustomObs(ObservationBuilder): return def get(self, handle): - observation = handle*np.ones((5,)) + observation = handle * np.ones((5,)) return observation diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py index 26beb61dbcff9820bd7979ae07218340a6aaacde..16ec480f4f97ca8d200dd5e80b7e5d7c7ece2218 100644 --- a/examples/custom_railmap_example.py +++ b/examples/custom_railmap_example.py @@ -23,6 +23,7 @@ def custom_rail_generator(): agents_target = [] return grid_map, agents_positions, agents_direction, agents_target + return generator diff --git a/examples/tkplay.py b/examples/tkplay.py index 9e37a26fd096e41a640c9daca2e40179f5f90cc9..7ab1c3234d31d5c083b068fec5aa71b47cbaa4e9 100644 --- a/examples/tkplay.py +++ b/examples/tkplay.py @@ -1,4 +1,3 @@ - from examples.play_model import Player from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv @@ -26,7 +25,7 @@ def tkmain(n_trials=2, n_steps=50, sGL="PIL"): env_renderer.renderEnv(show=True, frames=True, iEpisode=trials, iStep=step, action_dict=oPlayer.action_dict) - env_renderer.close_window() + env_renderer.close_window() if __name__ == "__main__": diff --git a/flatland/cli.py b/flatland/cli.py index 48c009e757b7c136616d520376774ca8cbf86260..f4d4317676d2bf3de0e2f6fe522d5a6e106741f0 100644 --- a/flatland/cli.py +++ b/flatland/cli.py @@ -2,6 +2,7 @@ """Console script for flatland.""" import sys + import click diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 2e34b4236943b36b167268828e3010e727108760..d5c68b47fcf1888a153d251331176de2b0f186e7 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -297,8 +297,8 @@ class GridTransitionMap(TransitionMap): self.grid = np.zeros((self.height, self.width), dtype=np.uint64) self.grid[0:min(self.height, new_height), - 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), - 0:min(self.width, new_width)] + 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), + 0:min(self.width, new_width)] def is_cell_valid(self, rcPos): cell_transition = self.grid[tuple(rcPos)] @@ -336,8 +336,8 @@ class GridTransitionMap(TransitionMap): lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8 g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1) # gDirIn = g2binTrans.any(axis=1) # inbound directions as boolean array (4) - gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4) - giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int + gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4) + giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int # loop over available outbound directions (indices) for rcPos for iDirOut in giDirOut: diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py index add047b6c7895e391211258bb10561110e0f1a19..c3be1e76baaf54c40d0dfae369e373e2e978c41b 100644 --- a/flatland/core/transitions.py +++ b/flatland/core/transitions.py @@ -319,7 +319,7 @@ class Grid4Transitions(Transitions): value = self.set_transitions(value, i, block_tuple) # Rotate the 4-bits blocks - value = ((value & (2**(rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4)) + value = ((value & (2 ** (rotation * 4) - 1)) << ((4 - rotation) * 4)) | (value >> (rotation * 4)) cell_transition = value return cell_transition @@ -499,7 +499,7 @@ class Grid8Transitions(Transitions): value = self.set_transitions(value, i, block_tuple) # Rotate the 8bits blocks - value = ((value & (2**(rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8)) + value = ((value & (2 ** (rotation * 8) - 1)) << ((8 - rotation) * 8)) | (value >> (rotation * 8)) cell_transition = value @@ -587,9 +587,9 @@ class RailEnvTransitions(Grid4Transitions): sRepr = " ".join([ "{}:{}".format(sDir, sbinTrans[i:(i + 4)]) for i, sDir in - zip( - range(0, len(sbinTrans), 4), - self.lsDirs)]) # NESW + zip( + range(0, len(sbinTrans), 4), + self.lsDirs)]) # NESW return sRepr if version == 1: diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index a66a27bf0e6dce55dfa878687ac1328aee63a6ea..87b7955fb4a5ddd18f5d15ba3d16c5e910e77fdf 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,7 +1,9 @@ - -from attr import attrs, attrib from itertools import starmap + import numpy as np +from attr import attrs, attrib + + # from flatland.envs.rail_env import RailEnv @@ -16,7 +18,7 @@ class EnvDescription(object): height = attrib() width = attrib() rail_generator = attrib() - obs_builder = attrib() # not sure if this should closer to the agent than the env + obs_builder = attrib() # not sure if this should closer to the agent than the env @attrs @@ -41,7 +43,7 @@ class EnvAgentStatic(object): def from_lists(cls, positions, directions, targets): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ - return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False]*len(positions)))) + return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions)))) def to_list(self): @@ -78,7 +80,7 @@ class EnvAgent(EnvAgentStatic): def to_list(self): return [ - self.position, self.direction, self.target, self.handle, + self.position, self.direction, self.target, self.handle, self.old_direction, self.old_position, self.moving] @classmethod diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 9f2dfee3e89b88009d8489faaa6fb0870e01204b..eec9a7668ecb190d6389f917532360f0c34b6eb5 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,19 +1,21 @@ import numpy as np -# from flatland.core.env import Environment -# from flatland.envs.observations import TreeObsForRailEnv - -from flatland.core.transitions import Grid8Transitions, RailEnvTransitions from flatland.core.transition_map import GridTransitionMap +from flatland.core.transitions import Grid8Transitions, RailEnvTransitions from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail +# from flatland.core.env import Environment +# from flatland.envs.observations import TreeObsForRailEnv + + def empty_rail_generator(): """ Returns a generator which returns an empty rail mail with no agents. Primarily used by the editor """ + def generator(width, height, num_agents=0, num_resets=0): rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) @@ -21,6 +23,7 @@ def empty_rail_generator(): rail_array.fill(0) return grid_map, [], [], [] + return generator @@ -41,7 +44,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= def generator(width, height, num_agents, num_resets=0): if num_agents > nr_start_goal: - num_agents = nr_start_goal + num_agents = nr_start_goal print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index dc5802765c39f337341b775e6c4ffc2591125709..e5a86ca3318f0dbba90cf1998264f9609dee2566 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -7,9 +7,10 @@ a GridTransitionMap object. # TODO: _ this is a global method --> utils or remove later # from inspect import currentframe +from enum import IntEnum + import msgpack import numpy as np -from enum import IntEnum from flatland.core.env import Environment from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent diff --git a/flatland/flatland.py b/flatland/flatland.py index 6d61ba81dcba4b66584cfc43175e65a6485345e6..7fbbae4f9c58882c3754a89675312f3c1430ffd8 100644 --- a/flatland/flatland.py +++ b/flatland/flatland.py @@ -1,4 +1,3 @@ # -*- coding: utf-8 -*- """Main module.""" - diff --git a/flatland/utils/svg.py b/flatland/utils/svg.py index e7b2cebb606089f5937d37c8e4d6381ada9fc211..c3d6805566a98fa364f49d3c4372bcd74f89c58d 100644 --- a/flatland/utils/svg.py +++ b/flatland/utils/svg.py @@ -1,8 +1,8 @@ - +import copy +import re import svgutils -import re -import copy + from flatland.core.transitions import RailEnvTransitions @@ -60,7 +60,7 @@ class SVG(object): sNewStyles = "\n" for sKey, sValue in self.dStyles.items(): if sKey == style_name: - sValue = "fill:#" + "".join([('{:#04x}'.format(int(255.0*col))[2:4]) for col in color[0:3]]) + ";" + sValue = "fill:#" + "".join([('{:#04x}'.format(int(255.0 * col))[2:4]) for col in color[0:3]]) + ";" sNewStyle = "\t.st" + sKey + "{" + sValue + "}\n" sNewStyles += sNewStyle @@ -111,6 +111,7 @@ class Track(object): The directions and images are also rotated by 90, 180 & 270 degrees. (There is some redundancy in this process, given the images provided) """ + def __init__(self): dFiles = { "": "Background_#9CCB89.svg", diff --git a/make_coverage.py b/make_coverage.py index 3d93f39ebed5a50c877a9ef72ab4b1a4602f2f53..27da5058263c8a10d61f73c9949bb99aa0ea13fa 100644 --- a/make_coverage.py +++ b/make_coverage.py @@ -1,13 +1,15 @@ #!/usr/bin/env python import os -import webbrowser import subprocess +import webbrowser from urllib.request import pathname2url + def browser(pathname): webbrowser.open("file:" + pathname2url(os.path.abspath(pathname))) + subprocess.call(['coverage', 'run', '--source', 'flatland', '-m', 'pytest']) subprocess.call(['coverage', 'report', '-m']) subprocess.call(['coverage', 'html']) diff --git a/make_docs.py b/make_docs.py index 5d27230fc82a5eb82550d5bef44a7f6f0942d368..be36a7bb7caed714445399b73f958ab42172242c 100644 --- a/make_docs.py +++ b/make_docs.py @@ -1,13 +1,15 @@ #!/usr/bin/env python import os -import webbrowser import subprocess +import webbrowser from urllib.request import pathname2url + def browser(pathname): webbrowser.open("file:" + pathname2url(os.path.abspath(pathname))) + def remove_exists(filename): try: os.remove(filename) diff --git a/setup.py b/setup.py index 649a4ac69bdf11c8800a571bb067e1ed16141129..e1a84fd454a57a70f0db43f1445e6db5ea6dd934 100644 --- a/setup.py +++ b/setup.py @@ -3,13 +3,10 @@ """The setup script.""" import os -from setuptools import setup, find_packages -import sys -import os - import platform +import sys - +from setuptools import setup, find_packages with open('README.rst') as readme_file: readme = readme_file.read() @@ -17,10 +14,6 @@ with open('README.rst') as readme_file: with open('HISTORY.rst') as history_file: history = history_file.read() - - - - # install pycairo on Windows if os.name == 'nt': p = platform.architecture() @@ -51,13 +44,14 @@ if os.name == 'nt': import site import ctypes.util + default_os_path = os.environ['PATH'] os.environ['PATH'] = '' for s in site.getsitepackages(): - os.environ['PATH'] = os.environ['PATH']+';' + s+'\\cairo' - os.environ['PATH'] = os.environ['PATH']+';' + default_os_path + os.environ['PATH'] = os.environ['PATH'] + ';' + s + '\\cairo' + os.environ['PATH'] = os.environ['PATH'] + ';' + default_os_path print(os.environ['PATH']) - if ctypes.util.find_library('cairo')is not None: + if ctypes.util.find_library('cairo') is not None: print("cairo installed: OK") else: try: @@ -69,7 +63,7 @@ else: def get_all_svg_files(directory='./svg/'): ret = [] for f in os.listdir(directory): - ret.append(directory+f) + ret.append(directory + f) return ret diff --git a/tests/test_env_edit.py b/tests/test_env_edit.py index 57dad857546d21009b881f4bf26e085873eaa655..84531bfa5f96f114415c94d28b9c83c04b41d598 100644 --- a/tests/test_env_edit.py +++ b/tests/test_env_edit.py @@ -1,7 +1,6 @@ - -from flatland.envs.rail_env import RailEnv # from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgentStatic +from flatland.envs.rail_env import RailEnv def test_load_env(): @@ -11,5 +10,3 @@ def test_load_env(): agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) env.add_agent_static(agent_static) assert env.get_num_agents() == 1 - - diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py index ad393634aa4ab156e3a305dd8654d1d588127805..2f0438c9702f631fdf82a828d9c00b73cf6b1469 100644 --- a/tests/test_env_observation_builder.py +++ b/tests/test_env_observation_builder.py @@ -3,10 +3,10 @@ import numpy as np -from flatland.envs.observations import GlobalObsForRailEnv from flatland.core.transition_map import GridTransitionMap, Grid4Transitions -from flatland.envs.rail_env import RailEnv from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv """Tests for `flatland` package.""" @@ -70,7 +70,7 @@ def test_global_obs(): # env_renderer.renderEnv(show=True) # global_obs.reset() - assert(global_obs[0][0].shape == rail_map.shape + (16,)) + assert (global_obs[0][0].shape == rail_map.shape + (16,)) rail_map_recons = np.zeros_like(rail_map) for i in range(global_obs[0][0].shape[0]): @@ -78,11 +78,11 @@ def test_global_obs(): rail_map_recons[i, j] = int( ''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2) - assert(rail_map_recons.all() == rail_map.all()) + assert (rail_map_recons.all() == rail_map.all()) # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell - assert(np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0) + assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0) def main(): diff --git a/tests/test_transitions.py b/tests/test_transitions.py index 69a7953c5b50bc5411c52b58c33448228c91620f..86f015f8c800c6d16b5ce4a0827562ad601996cd 100644 --- a/tests/test_transitions.py +++ b/tests/test_transitions.py @@ -2,10 +2,11 @@ # -*- coding: utf-8 -*- """Tests for `flatland` package.""" +import numpy as np + from flatland.core.transitions import RailEnvTransitions, Grid8Transitions # from flatland.envs.rail_env import validate_new_transition from flatland.envs.env_utils import validate_new_transition -import numpy as np def test_is_valid_railenv_transitions(): @@ -13,14 +14,14 @@ def test_is_valid_railenv_transitions(): transition_list = rail_env_trans.transitions for t in transition_list: - assert(rail_env_trans.is_valid(t) is True) + assert (rail_env_trans.is_valid(t) is True) for i in range(3): rot_trans = rail_env_trans.rotate_transition(t, 90 * i) - assert(rail_env_trans.is_valid(rot_trans) is True) + assert (rail_env_trans.is_valid(rot_trans) is True) - assert(rail_env_trans.is_valid(int('1111111111110010', 2)) is False) - assert(rail_env_trans.is_valid(int('1001111111110010', 2)) is False) - assert(rail_env_trans.is_valid(int('1001111001110110', 2)) is False) + assert (rail_env_trans.is_valid(int('1111111111110010', 2)) is False) + assert (rail_env_trans.is_valid(int('1001111111110010', 2)) is False) + assert (rail_env_trans.is_valid(int('1001111001110110', 2)) is False) def test_adding_new_valid_transition(): @@ -28,32 +29,32 @@ def test_adding_new_valid_transition(): rail_array = np.zeros(shape=(15, 15), dtype=np.uint16) # adding straight - assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True) + assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (6, 5), (10, 10)) is True) # adding valid right turn - assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True) + assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (5, 6), (10, 10)) is True) # adding valid left turn - assert(validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True) + assert (validate_new_transition(rail_trans, rail_array, (5, 6), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn rail_array[(5, 5)] = rail_trans.transitions[2] - assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) + assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) # should create #4 -> valid rail_array[(5, 5)] = rail_trans.transitions[3] - assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True) + assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is True) # adding invalid turn rail_array[(5, 5)] = rail_trans.transitions[7] - assert(validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) + assert (validate_new_transition(rail_trans, rail_array, (4, 5), (5, 5), (5, 6), (10, 10)) is False) # test path start condition rail_array[(5, 5)] = rail_trans.transitions[0] - assert(validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True) + assert (validate_new_transition(rail_trans, rail_array, None, (5, 5), (5, 6), (10, 10)) is True) # test path end condition rail_array[(5, 5)] = rail_trans.transitions[0] - assert(validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True) + assert (validate_new_transition(rail_trans, rail_array, (5, 4), (5, 5), (6, 5), (6, 5)) is True) def test_valid_railenv_transitions(): @@ -65,48 +66,48 @@ def test_valid_railenv_transitions(): # 'W': 3} for i in range(2): - assert(rail_env_trans.get_transitions( - int('1100110000110011', 2), i) == (1, 1, 0, 0)) - assert(rail_env_trans.get_transitions( - int('1100110000110011', 2), 2 + i) == (0, 0, 1, 1)) + assert (rail_env_trans.get_transitions( + int('1100110000110011', 2), i) == (1, 1, 0, 0)) + assert (rail_env_trans.get_transitions( + int('1100110000110011', 2), 2 + i) == (0, 0, 1, 1)) no_transition_cell = int('0000000000000000', 2) for i in range(4): - assert(rail_env_trans.get_transitions( - no_transition_cell, i) == (0, 0, 0, 0)) + assert (rail_env_trans.get_transitions( + no_transition_cell, i) == (0, 0, 0, 0)) # Facing south, going south north_south_transition = rail_env_trans.set_transitions(no_transition_cell, 2, (0, 0, 1, 0)) - assert(rail_env_trans.set_transition( - north_south_transition, 2, 2, 0) == no_transition_cell) - assert(rail_env_trans.get_transition( - north_south_transition, 2, 2)) + assert (rail_env_trans.set_transition( + north_south_transition, 2, 2, 0) == no_transition_cell) + assert (rail_env_trans.get_transition( + north_south_transition, 2, 2)) # Facing north, going east south_east_transition = \ rail_env_trans.set_transition(no_transition_cell, 0, 1, 1) - assert(rail_env_trans.get_transition( - south_east_transition, 0, 1)) + assert (rail_env_trans.get_transition( + south_east_transition, 0, 1)) # The opposite transitions are not feasible - assert(not rail_env_trans.get_transition( - north_south_transition, 2, 0)) - assert(not rail_env_trans.get_transition( - south_east_transition, 2, 1)) + assert (not rail_env_trans.get_transition( + north_south_transition, 2, 0)) + assert (not rail_env_trans.get_transition( + south_east_transition, 2, 1)) east_west_transition = rail_env_trans.rotate_transition(north_south_transition, 90) north_west_transition = rail_env_trans.rotate_transition(south_east_transition, 180) # Facing west, going west - assert(rail_env_trans.get_transition( - east_west_transition, 3, 3)) + assert (rail_env_trans.get_transition( + east_west_transition, 3, 3)) # Facing south, going west - assert(rail_env_trans.get_transition( - north_west_transition, 2, 3)) + assert (rail_env_trans.get_transition( + north_west_transition, 2, 3)) - assert(south_east_transition == rail_env_trans.rotate_transition( - south_east_transition, 360)) + assert (south_east_transition == rail_env_trans.rotate_transition( + south_east_transition, 360)) def test_diagonal_transitions(): @@ -114,12 +115,12 @@ def test_diagonal_transitions(): # Facing north, going north-east south_northeast_transition = int('01000000' + '0' * 8 * 7, 2) - assert(diagonal_trans_env.get_transitions( - south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0)) + assert (diagonal_trans_env.get_transitions( + south_northeast_transition, 0) == (0, 1, 0, 0, 0, 0, 0, 0)) # Allowing transition from north to southwest: Facing south, going SW north_southwest_transition = \ diagonal_trans_env.set_transitions(int('0' * 64, 2), 4, (0, 0, 0, 0, 0, 1, 0, 0)) - assert(diagonal_trans_env.rotate_transition( - south_northeast_transition, 180) == north_southwest_transition) + assert (diagonal_trans_env.rotate_transition( + south_northeast_transition, 180) == north_southwest_transition)