From f23796428a4f02faa5359f4878a7515806181ff9 Mon Sep 17 00:00:00 2001 From: hagrid67 <jdhwatson@gmail.com> Date: Thu, 9 May 2019 23:34:31 +0100 Subject: [PATCH] moved agent_* lists to a list of EnvAgents --- examples/play_model.py | 7 +- flatland/core/env_observation_builder.py | 81 ++++++----- flatland/envs/agent_utils.py | 176 +++++------------------ flatland/envs/rail_env.py | 109 ++++++++------ flatland/utils/rendertools.py | 17 +-- tests/test_environments.py | 52 ++++--- 6 files changed, 181 insertions(+), 261 deletions(-) diff --git a/examples/play_model.py b/examples/play_model.py index e69b312b..db4109d1 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -29,7 +29,8 @@ class Player(object): self.action_prob = [0]*4 self.agent = Agent(self.state_size, self.action_size, "FC", 0) # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth')) - self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth')) + self.agent.qnetwork_local.load_state_dict(torch.load( + '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth')) self.iFrame = 0 self.tStart = time.time() @@ -202,7 +203,7 @@ def main(render=True, delay=0.0): if trials % 100 == 0: tNow = time.time() rFps = iFrame / (tNow - tStart) - print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + + print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' + '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format( env.number_of_agents, trials, @@ -215,4 +216,4 @@ def main(render=True, delay=0.0): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index a6fbae6d..43884a40 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -61,19 +61,23 @@ class TreeObsForRailEnv(ObservationBuilder): self.max_depth = max_depth def reset(self): - self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents, + agents = self.env.agents + nAgents = len(agents) + self.distance_map = np.inf * np.ones(shape=(nAgents, # self.env.number_of_agents, self.env.height, self.env.width, 4)) - self.max_dist = np.zeros(self.env.number_of_agents) + self.max_dist = np.zeros(nAgents) - for i in range(self.env.number_of_agents): - self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i) + # for i in range(nAgents): + # self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i) + self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)] # Update local lookup table for all agents' target locations self.location_has_target = {} - for loc in self.env.agents_target: - self.location_has_target[(loc[0], loc[1])] = 1 + # for loc in self.env.agents_target: + # self.location_has_target[(loc[0], loc[1])] = 1 + self.location_has_target = {agent.position: 1 for agent in agents} def _distance_map_walker(self, position, target_nr): """ @@ -229,28 +233,33 @@ class TreeObsForRailEnv(ObservationBuilder): """ # Update local lookup table for all agents' positions - self.location_has_agent = {} - for loc in self.env.agents_position: - self.location_has_agent[(loc[0], loc[1])] = 1 - - position = self.env.agents_position[handle] - orientation = self.env.agents_direction[handle] - possible_transitions = self.env.rail.get_transitions((position[0], position[1], orientation)) + # self.location_has_agent = {} + # for loc in self.env.agents_position: + # self.location_has_agent[(loc[0], loc[1])] = 1 + self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents} + + agent = self.env.agents[handle] # TODO: handle being treated as index + # position = self.env.agents_position[handle] + # orientation = self.env.agents_direction[handle] + possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction)) num_transitions = np.count_nonzero(possible_transitions) # Root node - current position - observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] + # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] + observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]] root_observation = observation[:] # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation # If only one transition is possible, the tree is oriented with this transition as the forward branch. # TODO: Test if this works as desired! + orientation = agent.direction if num_transitions == 1: orientation == np.argmax(possible_transitions) - for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]: + # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]: + for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: - new_cell = self._new_position(position, branch_direction) + new_cell = self._new_position(agent.position, branch_direction) branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) observation = observation + branch_observation @@ -307,17 +316,18 @@ class TreeObsForRailEnv(ObservationBuilder): visited.add((position[0], position[1], direction)) # If the target node is encountered, pick that as node. Also, no further branching is possible. - if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: + # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: + if np.array_equal(position, self.env.agents[handle].target): last_isTarget = True break - cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction)) + cell_transitions = self.env.rail.get_transitions((*position, direction)) num_transitions = np.count_nonzero(cell_transitions) exploring = False if num_transitions == 1: # Check if dead-end, or if we can go forward along direction nbits = 0 - tmp = self.env.rail.get_transitions((position[0], position[1])) + tmp = self.env.rail.get_transitions(tuple(position)) while tmp > 0: nbits += (tmp & 1) tmp = tmp >> 1 @@ -380,9 +390,9 @@ class TreeObsForRailEnv(ObservationBuilder): # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation # Get the possible transitions - possible_transitions = self.env.rail.get_transitions((position[0], position[1], direction)) + possible_transitions = self.env.rail.get_transitions((*position, direction)) for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]: - if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction), + if last_isDeadEnd and self.env.rail.get_transition((*position, direction), (branch_direction + 2) % 4): # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes # it back @@ -471,20 +481,21 @@ class GlobalObsForRailEnv(ObservationBuilder): # self.targets[target_pos] += 1 def get(self, handle): - obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width)) - agent_pos = self.env.agents_position[handle] - obs_agents_targets_pos[0][agent_pos] += 1 - for i in range(len(self.env.agents_position)): - if i != handle: - obs_agents_targets_pos[3][self.env.agents_position[i]] += 1 - - agent_target_pos = self.env.agents_target[handle] - obs_agents_targets_pos[1][agent_target_pos] += 1 - for i in range(len(self.env.agents_target)): - if i != handle: - obs_agents_targets_pos[2][self.env.agents_target[i]] += 1 + obs = np.zeros((4, self.env.height, self.env.width)) + agents = self.env.agents + agent = agents[handle] + + agent_pos = agents[handle].position + obs[0][agent_pos] += 1 + obs[1][agent.target] += 1 + + for i in range(len(agents)): + if i != handle: # TODO: handle used as index...? + agent2 = agents[i] + obs[3][agent2.position] += 1 + obs[2][agent2.target] += 1 direction = np.zeros(4) - direction[self.env.agents_direction[handle]] = 1 + direction[agent.direction] = 1 - return self.rail_obs, obs_agents_targets_pos, direction + return self.rail_obs, obs, direction diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index c29839e6..da36fe73 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,7 +1,17 @@ from attr import attrs, attrib -from itertools import starmap, count -import numpy as np +from itertools import starmap +# from flatland.envs.rail_env import RailEnv + + +@attrs +class EnvDescription(object): + n_agents = attrib() + height = attrib() + width = attrib() + rail_generator = attrib() + obs_builder = attrib() + @attrs class EnvAgentStatic(object): @@ -13,157 +23,41 @@ class EnvAgentStatic(object): position = attrib() direction = attrib() target = attrib() - handle = attrib() - next_handle = 0 + next_handle = 0 # this is not properly implemented @classmethod - def from_lists(positions, directions, targets): + def from_lists(cls, positions, directions, targets): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ - return starmap(EnvAgentStatic, zip(positions, directions, targets, count())) + return list(starmap(EnvAgentStatic, zip(positions, directions, targets))) +@attrs class EnvAgent(EnvAgentStatic): - """ TODO: EnvAgent - replace separate agent lists with a single list + """ EnvAgent - replace separate agent_* lists with a single list of agent objects. The EnvAgent represent's the environment's view - of the dynamic agent state. So target is not part of it - target is - static. + of the dynamic agent state. + We are duplicating target in the EnvAgent, which seems simpler than + forcing the env to refer to it in the EnvAgentStatic """ + handle = attrib(default=None) - -class EnvManager(object): - def __init__(self, env=None): - self.env = env - - - def load_env(self, sFilename): - pass - - def save_env(self, sFilename): - pass - - def regen_rail(self): - pass - - def replace_agents(self): - pass - - def add_agent_static(self, agent_static): - """ Add a new agent_static - """ - iAgent = self.number_of_agents - - if iDir is None: - iDir = self.pick_agent_direction(rcPos, rcTarget) - if iDir is None: - print("Error picking agent direction at pos:", rcPos) - return None - - self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list - self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0 - self.agents_direction.append(iDir) - self.agents_target.append(rcPos) # set the target to the origin initially - self.number_of_agents += 1 - self.check_agent_lists() - return iAgent - - - - def add_agent_old(self, rcPos=None, rcTarget=None, iDir=None): - """ Add a new agent at position rcPos with target rcTarget and - initial direction index iDir. - Should also store this initial position etc as environment "meta-data" - but this does not yet exist. + @classmethod + def from_static(cls, oStatic): + """ Create an EnvAgent from the EnvAgentStatic, + copying all the fields, and adding handle with the default 0. """ - self.check_agent_lists() - - if rcPos is None: - rcPos = np.random.choice(len(self.valid_positions)) - - iAgent = self.number_of_agents - - if iDir is None: - iDir = self.pick_agent_direction(rcPos, rcTarget) - if iDir is None: - print("Error picking agent direction at pos:", rcPos) - return None - - self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list - self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0 - self.agents_direction.append(iDir) - self.agents_target.append(rcPos) # set the target to the origin initially - self.number_of_agents += 1 - self.check_agent_lists() - return iAgent - - def fill_valid_positions(self): - ''' Populate the valid_positions list for the current TransitionMap. - TODO: put this elsewhere - ''' - self.env.valid_positions = valid_positions = [] - for r in range(self.env.height): - for c in range(self.env.width): - if self.env.rail.get_transitions((r, c)) > 0: - valid_positions.append((r, c)) + return EnvAgent(*oStatic.__dict__, handle=0) - def check_agent_lists(self): - ''' Check that the agent_handles, position and direction lists are all of length - number_of_agents. - (Suggest this is replaced with a single list of Agent objects :) - ''' - for lAgents, name in zip( - [self.env.agents_handles, self.env.agents_position, self.env.agents_direction], - ["handles", "positions", "directions"]): - assert self.env.number_of_agents == len(lAgents), "Inconsistent agent list:" + name - - def check_agent_locdirpath(self, iAgent): - ''' Check that agent iAgent has a valid location and direction, - with a path to its target. - (Not currently used?) - ''' - valid_movements = [] - for direction in range(4): - position = self.env.agents_position[iAgent] - moves = self.env.rail.get_transitions((position[0], position[1], direction)) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = self.env._new_position(self.env.agents_position[iAgent], m[1]) - if m[0] not in valid_starting_directions and \ - self.env._path_exists(new_position, m[0], self.env.agents_target[iAgent]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - return False - else: - return True - - def pick_agent_direction(self, rcPos, rcTarget): - """ Pick and return a valid direction index (0..3) for an agent starting at - row,col rcPos with target rcTarget. - Return None if no path exists. - Picks random direction if more than one exists (uniformly). + @classmethod + def list_from_static(cls, lEnvAgentStatic, handles=None): + """ Create an EnvAgent from the EnvAgentStatic, + copying all the fields, and adding handle with the default 0. """ - valid_movements = [] - for direction in range(4): - moves = self.env.rail.get_transitions((*rcPos, direction)) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - # print("pos", rcPos, "targ", rcTarget, "valid movements", valid_movements) - - valid_starting_directions = [] - for m in valid_movements: - new_position = self.env._new_position(rcPos, m[1]) - if m[0] not in valid_starting_directions and self.env._path_exists(new_position, m[0], rcTarget): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - return None - else: - return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]] + if handles is None: + handles = range(len(lEnvAgentStatic)) + + return [EnvAgent(**oEAS.__dict__, handle=handle) + for handle, oEAS in zip(handles, lEnvAgentStatic)] diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index ea8c3dca..9767bba4 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -10,7 +10,7 @@ from flatland.core.env import Environment from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.envs.generators import random_rail_generator from flatland.envs.env_utils import get_new_position -from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, EnvManager +from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent # from flatland.core.transitions import Grid8Transitions, RailEnvTransitions # from flatland.core.transition_map import GridTransitionMap @@ -124,10 +124,11 @@ class RailEnv(Environment): if replace_agents: self.agents_static = EnvAgentStatic.from_lists(agents_position, agents_direction, agents_target) - self.agents = copy(agents_static) + self.agents = EnvAgent.list_from_static(self.agents_static[:len(self.agents_handles)]) self.num_resets += 1 + # perhaps dones should be part of each agent. self.dones = {"__all__": False} for handle in self.agents_handles: self.dones[handle] = False @@ -157,11 +158,12 @@ class RailEnv(Environment): for i in range(len(self.agents_handles)): handle = self.agents_handles[i] transition_isValid = None + agent = self.agents[i] - if handle not in action_dict: + if handle not in action_dict: # no action has been supplied for this agent continue - if self.dones[handle]: + if self.dones[handle]: # this agent has already completed... continue action = action_dict[handle] @@ -171,31 +173,28 @@ class RailEnv(Environment): return if action > 0: - pos = self.agents_position[i] - direction = self.agents_direction[i] + # pos = agent.position # self.agents_position[i] + # direction = agent.direction # self.agents_direction[i] # compute number of possible transitions in the current # cell used to check for invalid actions - possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction)) + possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) num_transitions = np.count_nonzero(possible_transitions) - movement = direction + movement = agent.direction # print(nbits,np.sum(possible_transitions)) if action == 1: - movement = direction - 1 + movement = agent.direction - 1 if num_transitions <= 1: transition_isValid = False elif action == 3: - movement = direction + 1 + movement = agent.direction + 1 if num_transitions <= 1: transition_isValid = False - if movement < 0: - movement += 4 - if movement >= 4: - movement -= 4 + movement %= 4 if action == 2: if num_transitions == 1: @@ -205,57 +204,72 @@ class RailEnv(Environment): movement = np.argmax(possible_transitions) transition_isValid = True - new_position = get_new_position(pos, movement) - # Is it a legal move? 1) transition allows the movement in the - # cell, 2) the new cell is not empty (case 0), 3) the cell is - # free, i.e., no agent is currently in that cell - if ( - new_position[1] >= self.width or - new_position[0] >= self.height or - new_position[0] < 0 or new_position[1] < 0): - new_cell_isValid = False - - elif self.rail.get_transitions((new_position[0], new_position[1])) > 0: - new_cell_isValid = True - else: - new_cell_isValid = False + new_position = get_new_position(agent.position, movement) + # Is it a legal move? + # 1) transition allows the movement in the cell, + # 2) the new cell is not empty (case 0), + # 3) the cell is free, i.e., no agent is currently in that cell + + # if ( + # new_position[1] >= self.width or + # new_position[0] >= self.height or + # new_position[0] < 0 or new_position[1] < 0): + # new_cell_isValid = False + + # if self.rail.get_transitions(new_position) == 0: + # new_cell_isValid = False + + new_cell_isValid = ( + np.array_equal( # Check the new position is still in the grid + new_position, + np.clip(new_position, [0, 0], [self.height-1, self.width-1])) + and # check the new position has some transitions (ie is not an empty cell) + self.rail.get_transitions(new_position) > 0) # If transition validity hasn't been checked yet. if transition_isValid is None: transition_isValid = self.rail.get_transition( - (pos[0], pos[1], direction), + (*agent.position, agent.direction), movement) - cell_isFree = True - for j in range(self.number_of_agents): - if self.agents_position[j] == new_position: - cell_isFree = False - break - - if new_cell_isValid and transition_isValid and cell_isFree: + # cell_isFree = True + # for j in range(self.number_of_agents): + # if self.agents_position[j] == new_position: + # cell_isFree = False + # break + # Check the new position is not the same as any of the existing agent positions + # (including itself, for simplicity, since it is moving) + cell_isFree = not np.any( + np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) + + if all([new_cell_isValid, transition_isValid, cell_isFree]): # move and change direction to face the movement that was # performed - self.agents_position[i] = new_position - self.agents_direction[i] = movement + # self.agents_position[i] = new_position + # self.agents_direction[i] = movement + agent.position = new_position + agent.direction = movement else: # the action was not valid, add penalty self.rewards_dict[handle] += invalid_action_penalty # if agent is not in target position, add step penalty - if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: + # if self.agents_position[i][0] == self.agents_target[i][0] and \ + # self.agents_position[i][1] == self.agents_target[i][1]: + # self.dones[handle] = True + if np.equal(agent.position, agent.target).all(): self.dones[handle] = True else: self.rewards_dict[handle] += step_penalty # Check for end of episode + add global reward to all rewards! - num_agents_in_target_position = 0 - for i in range(self.number_of_agents): - if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: - num_agents_in_target_position += 1 - - if num_agents_in_target_position == self.number_of_agents: + # num_agents_in_target_position = 0 + # for i in range(self.number_of_agents): + # if self.agents_position[i][0] == self.agents_target[i][0] and \ + # self.agents_position[i][1] == self.agents_target[i][1]: + # num_agents_in_target_position += 1 + # if num_agents_in_target_position == self.number_of_agents: + if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): self.dones["__all__"] = True self.rewards_dict = [r + global_reward for r in self.rewards_dict] @@ -273,3 +287,4 @@ class RailEnv(Environment): def render(self): # TODO: pass + diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 1f731c39..c2fcb73b 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -158,20 +158,9 @@ class RenderTool(object): def plotAgents(self, targets=True): cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents + 1) - for iAgent in range(self.env.number_of_agents): + for iAgent, agent in enumerate(self.env.agents): oColor = cmap(iAgent) - - rcPos = self.env.agents_position[iAgent] - iDir = self.env.agents_direction[iAgent] # agent direction index - - if targets: - target = self.env.agents_target[iAgent] - else: - target = None - self.plotAgent(rcPos, iDir, oColor, target=target) - - # gTransRCAg = self.getTransRC(rcPos, iDir) - # self.plotTrans(rcPos, gTransRCAg) + self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None) def getTransRC(self, rcPos, iDir, bgiTrans=False): """ @@ -554,7 +543,7 @@ class RenderTool(object): if not bCellValid: # print("invalid:", r, c) - self.gl.scatter(*xyCentre, color="r", s=50) + self.gl.scatter(*xyCentre, color="r", s=30) for orientation in range(4): # ori is where we're heading from_ori = (orientation + 2) % 4 # 0123=NESW -> 2301=SWNE diff --git a/tests/test_environments.py b/tests/test_environments.py index a10fb061..fe788b7c 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -7,7 +7,7 @@ from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.core.transitions import Grid4Transitions from flatland.core.transition_map import GridTransitionMap from flatland.core.env_observation_builder import GlobalObsForRailEnv - +from flatland.envs.agent_utils import EnvAgent """Tests for `flatland` package.""" @@ -58,16 +58,21 @@ def test_rail_environment_single_agent(): _ = rail_env.reset() # We do not care about target for the moment - rail_env.agents_target[0] = [-1, -1] + # rail_env.agents_target[0] = [-1, -1] + agent = rail_env.agents[0] + # rail_env.agents[0].target = [-1, -1] + agent.target = [-1, -1] # Check that trains are always initialized at a consistent position # or direction. # They should always be able to go somewhere. assert(transitions.get_transitions( - rail_map[rail_env.agents_position[0]], - rail_env.agents_direction[0]) != (0, 0, 0, 0)) + # rail_map[rail_env.agents_position[0]], + # rail_env.agents_direction[0]) != (0, 0, 0, 0)) + rail_map[agent.position], + agent.direction) != (0, 0, 0, 0)) - initial_pos = rail_env.agents_position[0] + initial_pos = agent.position valid_active_actions_done = 0 pos = initial_pos @@ -78,13 +83,13 @@ def test_rail_environment_single_agent(): _, _, _, _ = rail_env.step({0: action}) prev_pos = pos - pos = rail_env.agents_position[0] + pos = agent.position # rail_env.agents_position[0] if prev_pos != pos: valid_active_actions_done += 1 # After 6 movements on this railway network, the train should be back # to its original height on the map. - assert(initial_pos[0] == rail_env.agents_position[0][0]) + assert(initial_pos[0] == agent.position[0]) # We check that the train always attains its target after some time for _ in range(10): @@ -135,13 +140,14 @@ def test_dead_end(): # We run step to check that trains do not move anymore # after being done. for i in range(7): - prev_pos = rail_env.agents_position[0] + # prev_pos = rail_env.agents_position[0] + prev_pos = rail_env.agents[0].position # The train cannot turn, so we check that when it tries, # it stays where it is. _ = rail_env.step({0: 1}) _ = rail_env.step({0: 3}) - assert (rail_env.agents_position[0] == prev_pos) + assert (rail_env.agents[0].position == prev_pos) _, _, dones, _ = rail_env.step({0: 2}) if i < 5: @@ -151,15 +157,17 @@ def test_dead_end(): # We try the configuration in the 4 directions: rail_env.reset() - rail_env.agents_target[0] = (0, 0) - rail_env.agents_position[0] = (0, 2) - rail_env.agents_direction[0] = 1 + # rail_env.agents_target[0] = (0, 0) + # rail_env.agents_position[0] = (0, 2) + # rail_env.agents_direction[0] = 1 + rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0))] check_consistency(rail_env) rail_env.reset() - rail_env.agents_target[0] = (0, 4) - rail_env.agents_position[0] = (0, 2) - rail_env.agents_direction[0] = 3 + # rail_env.agents_target[0] = (0, 4) + # rail_env.agents_position[0] = (0, 2) + # rail_env.agents_direction[0] = 3 + rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4))] check_consistency(rail_env) # In the vertical configuration: @@ -181,13 +189,15 @@ def test_dead_end(): obs_builder_object=GlobalObsForRailEnv()) rail_env.reset() - rail_env.agents_target[0] = (0, 0) - rail_env.agents_position[0] = (2, 0) - rail_env.agents_direction[0] = 2 + # rail_env.agents_target[0] = (0, 0) + # rail_env.agents_position[0] = (2, 0) + # rail_env.agents_direction[0] = 2 + rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0))] check_consistency(rail_env) rail_env.reset() - rail_env.agents_target[0] = (4, 0) - rail_env.agents_position[0] = (2, 0) - rail_env.agents_direction[0] = 0 + # rail_env.agents_target[0] = (4, 0) + # rail_env.agents_position[0] = (2, 0) + # rail_env.agents_direction[0] = 0 + rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0))] check_consistency(rail_env) -- GitLab