diff --git a/examples/training_example.py b/examples/training_example.py index ee97c7e4c90dc15018d6d87c5b2293455075c5f0..1342107767a599cb1440fd08f506903a5059492e 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -1,6 +1,7 @@ +import numpy as np + from flatland.envs.generators import complex_rail_generator from flatland.envs.rail_env import RailEnv -import numpy as np np.random.seed(1) @@ -12,6 +13,7 @@ env = RailEnv(width=15, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0), number_of_agents=5) + # Import your own Agent or use RLlib to train agents on Flatland # As an example we use a random agent here @@ -42,10 +44,11 @@ class RandomAgent: # Store the current policy return - def load(self,filename): + def load(self, filename): # Load a policy return + # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) n_trials = 5 @@ -80,5 +83,4 @@ for trials in range(1, n_trials + 1): obs = next_obs.copy() if done['__all__']: break - print('Episode Nr. {}\t Score = {}'.format(trials,score)) - + print('Episode Nr. {}\t Score = {}'.format(trials, score)) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 77ac6c06f6baaf8cce29444eda554ff9fca19842..dc5802765c39f337341b775e6c4ffc2591125709 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -198,7 +198,6 @@ class RailEnv(Environment): stop_penalty = 0 # penalty for stopping a moving agent start_penalty = 0 # penalty for starting a stopped agent - # Reset the step rewards self.rewards_dict = dict() # for handle in self.agents_handles: @@ -246,10 +245,9 @@ class RailEnv(Environment): self.rewards_dict[iAgent] += stop_penalty if not agent.moving and \ - (action == RailEnvActions.MOVE_LEFT or - action == RailEnvActions.MOVE_FORWARD or - action == RailEnvActions.MOVE_RIGHT): - + (action == RailEnvActions.MOVE_LEFT or + action == RailEnvActions.MOVE_FORWARD or + action == RailEnvActions.MOVE_RIGHT): agent.moving = True self.rewards_dict[iAgent] += start_penalty @@ -346,46 +344,6 @@ class RailEnv(Environment): np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid - def _check_action_on_agent(self, action, agent): - # pos = agent.position # self.agents_position[i] - # direction = agent.direction # self.agents_direction[i] - # compute number of possible transitions in the current - # cell used to check for invalid actions - new_direction, transition_isValid = self.check_action(agent, action) - new_position = get_new_position(agent.position, new_direction) - # Is it a legal move? - # 1) transition allows the new_direction in the cell, - # 2) the new cell is not empty (case 0), - # 3) the cell is free, i.e., no agent is currently in that cell - # if ( - # new_position[1] >= self.width or - # new_position[0] >= self.height or - # new_position[0] < 0 or new_position[1] < 0): - # new_cell_isValid = False - # if self.rail.get_transitions(new_position) == 0: - # new_cell_isValid = False - new_cell_isValid = ( - np.array_equal( # Check the new position is still in the grid - new_position, - np.clip(new_position, [0, 0], [self.height - 1, self.width - 1])) - and # check the new position has some transitions (ie is not an empty cell) - self.rail.get_transitions(new_position) > 0) - # If transition validity hasn't been checked yet. - if transition_isValid is None: - transition_isValid = self.rail.get_transition( - (*agent.position, agent.direction), - new_direction) - # cell_isFree = True - # for j in range(self.number_of_agents): - # if self.agents_position[j] == new_position: - # cell_isFree = False - # break - # Check the new position is not the same as any of the existing agent positions - # (including itself, for simplicity, since it is moving) - cell_isFree = not np.any( - np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) - return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid - def predict(self): if not self.prediction_builder: return {} diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 52561369204a9dfc35209ff8fdb6c4805a8c6197..a23bf855a867949f2d15d8bff58d19283a71438f 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -499,7 +499,7 @@ class EditorModel(object): def mod_path(self, bAddRemove): # disabled functionality (no longer required) - if bAddRemove == False: + if bAddRemove is False: return # This elif means we wait until all the mouse events have been processed (black square drawn) # before trying to draw rails. (We could change this behaviour) diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index c73278a9d19d717125503514162658730a357951..6e6043b511ff6d4054b66c2a9bb8008ae544715e 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -542,7 +542,6 @@ class RenderTool(object): if type(self.gl) is PILGL: self.gl.beginFrame() - # self.gl.clf() # if oFigure is None: # oFigure = self.gl.figure() @@ -582,7 +581,6 @@ class RenderTool(object): # TODO: for MPL, we don't want to call clf (called by endframe) # if not show: - if show and type(self.gl) is PILGL: self.gl.show() diff --git a/tests/test_environments.py b/tests/test_environments.py index 57e8b7b045eda1ace355102b1140a71c466c8633..3feb04b159d2a6950012fddd32902c8cdb9aae4e 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -2,13 +2,13 @@ # -*- coding: utf-8 -*- import numpy as np -from flatland.envs.rail_env import RailEnv -from flatland.envs.generators import rail_from_GridTransitionMap_generator -from flatland.envs.generators import complex_rail_generator -from flatland.core.transitions import Grid4Transitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.observations import GlobalObsForRailEnv +from flatland.core.transitions import Grid4Transitions from flatland.envs.agent_utils import EnvAgent +from flatland.envs.generators import complex_rail_generator +from flatland.envs.generators import rail_from_GridTransitionMap_generator +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv """Tests for `flatland` package.""" @@ -26,19 +26,18 @@ def test_save_load(): agent_2_tar = env.agents_static[1].target env.save("test_save.dat") env.load("test_save.dat") - assert(env.width == 10) - assert(env.height == 10) - assert(len(env.agents) == 2) - assert(agent_1_pos == env.agents_static[0].position) - assert(agent_1_dir == env.agents_static[0].direction) - assert(agent_1_tar == env.agents_static[0].target) - assert(agent_2_pos == env.agents_static[1].position) - assert(agent_2_dir == env.agents_static[1].direction) - assert(agent_2_tar == env.agents_static[1].target) + assert (env.width == 10) + assert (env.height == 10) + assert (len(env.agents) == 2) + assert (agent_1_pos == env.agents_static[0].position) + assert (agent_1_dir == env.agents_static[0].direction) + assert (agent_1_tar == env.agents_static[0].target) + assert (agent_2_pos == env.agents_static[1].position) + assert (agent_2_dir == env.agents_static[1].direction) + assert (agent_2_tar == env.agents_static[1].target) def test_rail_environment_single_agent(): - cells = [int('0000000000000000', 2), # empty cell - Case 0 int('1000000000100000', 2), # Case 1 - straight int('1001001000100000', 2), # Case 2 - simple switch @@ -91,7 +90,7 @@ def test_rail_environment_single_agent(): # Check that trains are always initialized at a consistent position # or direction. # They should always be able to go somewhere. - assert(transitions.get_transitions( + assert (transitions.get_transitions( # rail_map[rail_env.agents_position[0]], # rail_env.agents_direction[0]) != (0, 0, 0, 0)) rail_map[agent.position], @@ -114,7 +113,7 @@ def test_rail_environment_single_agent(): # After 6 movements on this railway network, the train should be back # to its original height on the map. - assert(initial_pos[0] == agent.position[0]) + assert (initial_pos[0] == agent.position[0]) # We check that the train always attains its target after some time for _ in range(10): @@ -131,7 +130,6 @@ def test_rail_environment_single_agent(): def test_dead_end(): - transitions = Grid4Transitions([]) straight_vertical = int('1000000000100000', 2) # Case 1 - straight @@ -180,7 +178,7 @@ def test_dead_end(): if i < 5: assert (not dones[0] and not dones['__all__']) else: - assert (dones[0] and dones['__all__']) + assert (dones[0] and dones['__all__']) # We try the configuration in the 4 directions: rail_env.reset()