Commit 13d1c8c3 authored by u214892

flake8 master

parent eb50515d
import numpy as np
from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv
import numpy as np
np.random.seed(1)
@@ -12,6 +13,7 @@ env = RailEnv(width=15,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
number_of_agents=5)
# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent here
@@ -42,10 +44,11 @@ class RandomAgent:
# Store the current policy
return
def load(self,filename):
def load(self, filename):
# Load a policy
return
# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 4)
n_trials = 5
@@ -80,5 +83,4 @@ for trials in range(1, n_trials + 1):
obs = next_obs.copy()
if done['__all__']:
break
print('Episode Nr. {}\t Score = {}'.format(trials,score))
print('Episode Nr. {}\t Score = {}'.format(trials, score))
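
The hunks above touch the getting-started script that drives Flatland with a random agent. For context, here is a minimal, self-contained sketch of that episode loop; it is not part of the commit. The RandomAgent.act method, the height=15 argument, the 100-step cap, env.get_num_agents(), and the reset/step return signatures are assumptions; only the constructor arguments shown above, RandomAgent(218, 4), n_trials = 5 and the final print come from the diff.

# Sketch only: reconstructs the surrounding tutorial script for context.
import numpy as np

from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv

np.random.seed(1)

# height=15 is assumed here; the diff only shows width=15.
env = RailEnv(width=15,
              height=15,
              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999,
                                                    seed=0),
              number_of_agents=5)


class RandomAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

    def act(self, state):
        # A random policy: ignore the observation and pick any of the discrete actions.
        return np.random.choice(np.arange(self.action_size))


# State and action sizes (218, 4) are taken from the diff above.
agent = RandomAgent(218, 4)
n_trials = 5

for trials in range(1, n_trials + 1):
    obs = env.reset()  # assumed to return a dict of observations keyed by agent handle
    score = 0
    for step in range(100):  # arbitrary step cap for this sketch
        # One action per agent handle (get_num_agents() assumed to exist on RailEnv).
        action_dict = {handle: agent.act(obs[handle]) for handle in range(env.get_num_agents())}
        next_obs, all_rewards, done, _ = env.step(action_dict)
        score += sum(all_rewards.values())
        obs = next_obs.copy()
        if done['__all__']:
            break
    print('Episode Nr. {}\t Score = {}'.format(trials, score))
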
@@ -198,7 +198,6 @@ class RailEnv(Environment):
stop_penalty = 0 # penalty for stopping a moving agent
start_penalty = 0 # penalty for starting a stopped agent
# Reset the step rewards
self.rewards_dict = dict()
# for handle in self.agents_handles:
@@ -246,10 +245,9 @@ class RailEnv(Environment):
self.rewards_dict[iAgent] += stop_penalty
if not agent.moving and \
(action == RailEnvActions.MOVE_LEFT or
action == RailEnvActions.MOVE_FORWARD or
action == RailEnvActions.MOVE_RIGHT):
(action == RailEnvActions.MOVE_LEFT or
action == RailEnvActions.MOVE_FORWARD or
action == RailEnvActions.MOVE_RIGHT):
agent.moving = True
self.rewards_dict[iAgent] += start_penalty
@@ -346,46 +344,6 @@ class RailEnv(Environment):
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def _check_action_on_agent(self, action, agent):
# pos = agent.position # self.agents_position[i]
# direction = agent.direction # self.agents_direction[i]
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_isValid = self.check_action(agent, action)
new_position = get_new_position(agent.position, new_direction)
# Is it a legal move?
# 1) transition allows the new_direction in the cell,
# 2) the new cell is not empty (case 0),
# 3) the cell is free, i.e., no agent is currently in that cell
# if (
# new_position[1] >= self.width or
# new_position[0] >= self.height or
# new_position[0] < 0 or new_position[1] < 0):
# new_cell_isValid = False
# if self.rail.get_transitions(new_position) == 0:
# new_cell_isValid = False
new_cell_isValid = (
np.array_equal( # Check the new position is still in the grid
new_position,
np.clip(new_position, [0, 0], [self.height - 1, self.width - 1]))
and # check the new position has some transitions (ie is not an empty cell)
self.rail.get_transitions(new_position) > 0)
# If transition validity hasn't been checked yet.
if transition_isValid is None:
transition_isValid = self.rail.get_transition(
(*agent.position, agent.direction),
new_direction)
# cell_isFree = True
# for j in range(self.number_of_agents):
# if self.agents_position[j] == new_position:
# cell_isFree = False
# break
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
cell_isFree = not np.any(
np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid
def predict(self):
if not self.prediction_builder:
return {}
@@ -499,7 +499,7 @@ class EditorModel(object):
def mod_path(self, bAddRemove):
# disabled functionality (no longer required)
if bAddRemove == False:
if bAddRemove is False:
return
# This elif means we wait until all the mouse events have been processed (black square drawn)
# before trying to draw rails. (We could change this behaviour)
@@ -542,7 +542,6 @@ class RenderTool(object):
if type(self.gl) is PILGL:
self.gl.beginFrame()
# self.gl.clf()
# if oFigure is None:
# oFigure = self.gl.figure()
@@ -582,7 +581,6 @@ class RenderTool(object):
# TODO: for MPL, we don't want to call clf (called by endframe)
# if not show:
if show and type(self.gl) is PILGL:
self.gl.show()
@@ -2,13 +2,13 @@
# -*- coding: utf-8 -*-
import numpy as np
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.generators import complex_rail_generator
from flatland.core.transitions import Grid4Transitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.core.transitions import Grid4Transitions
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.generators import complex_rail_generator
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
"""Tests for `flatland` package."""
@@ -26,19 +26,18 @@ def test_save_load():
agent_2_tar = env.agents_static[1].target
env.save("test_save.dat")
env.load("test_save.dat")
assert(env.width == 10)
assert(env.height == 10)
assert(len(env.agents) == 2)
assert(agent_1_pos == env.agents_static[0].position)
assert(agent_1_dir == env.agents_static[0].direction)
assert(agent_1_tar == env.agents_static[0].target)
assert(agent_2_pos == env.agents_static[1].position)
assert(agent_2_dir == env.agents_static[1].direction)
assert(agent_2_tar == env.agents_static[1].target)
assert (env.width == 10)
assert (env.height == 10)
assert (len(env.agents) == 2)
assert (agent_1_pos == env.agents_static[0].position)
assert (agent_1_dir == env.agents_static[0].direction)
assert (agent_1_tar == env.agents_static[0].target)
assert (agent_2_pos == env.agents_static[1].position)
assert (agent_2_dir == env.agents_static[1].direction)
assert (agent_2_tar == env.agents_static[1].target)
def test_rail_environment_single_agent():
cells = [int('0000000000000000', 2), # empty cell - Case 0
int('1000000000100000', 2), # Case 1 - straight
int('1001001000100000', 2), # Case 2 - simple switch
@@ -91,7 +90,7 @@ def test_rail_environment_single_agent():
# Check that trains are always initialized at a consistent position
# or direction.
# They should always be able to go somewhere.
assert(transitions.get_transitions(
assert (transitions.get_transitions(
# rail_map[rail_env.agents_position[0]],
# rail_env.agents_direction[0]) != (0, 0, 0, 0))
rail_map[agent.position],
@@ -114,7 +113,7 @@ def test_rail_environment_single_agent():
# After 6 movements on this railway network, the train should be back
# to its original height on the map.
assert(initial_pos[0] == agent.position[0])
assert (initial_pos[0] == agent.position[0])
# We check that the train always attains its target after some time
for _ in range(10):
@@ -131,7 +130,6 @@ def test_rail_environment_single_agent():
def test_dead_end():
transitions = Grid4Transitions([])
straight_vertical = int('1000000000100000', 2) # Case 1 - straight
@@ -180,7 +178,7 @@ def test_dead_end():
if i < 5:
assert (not dones[0] and not dones['__all__'])
else:
assert (dones[0] and dones['__all__'])
assert (dones[0] and dones['__all__'])
# We try the configuration in the 4 directions:
rail_env.reset()