diff --git a/examples/play_model.py b/examples/play_model.py
index e69b312b1ceb2f450256d247f4b63c14a728acb5..9c67b0bce315ecf028fe898c510e4503e67a8cf4 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -29,7 +29,8 @@ class Player(object):
         self.action_prob = [0]*4
         self.agent = Agent(self.state_size, self.action_size, "FC", 0)
         # self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint9900.pth'))
-        self.agent.qnetwork_local.load_state_dict(torch.load('../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
+        self.agent.qnetwork_local.load_state_dict(torch.load(
+            '../flatland/flatland/baselines/Nets/avoid_checkpoint15000.pth'))
 
         self.iFrame = 0
         self.tStart = time.time()
@@ -97,7 +98,7 @@ def main(render=True, delay=0.0):
 
     # Example generate a random rail
     env = RailEnv(width=15, height=15,
-                  rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=12),
+                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
                   number_of_agents=5)
 
     if render:
@@ -202,7 +203,7 @@ def main(render=True, delay=0.0):
         if trials % 100 == 0:
             tNow = time.time()
             rFps = iFrame / (tNow - tStart)
-            print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
+            print(('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%' +
                    '\tEpsilon: {:.2f} fps: {:.2f} \t Action Probabilities: \t {}').format(
                 env.number_of_agents,
                 trials,
@@ -215,4 +216,4 @@ def main(render=True, delay=0.0):
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index a6fbae6d0d271f47e98d08262c7fbc2801b7142d..43884a40e1e6cbd0cbfb10d69e569900bfffa72e 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -61,19 +61,23 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.max_depth = max_depth
 
     def reset(self):
-        self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
+        agents = self.env.agents
+        nAgents = len(agents)
+        self.distance_map = np.inf * np.ones(shape=(nAgents,  # self.env.number_of_agents,
                                                     self.env.height,
                                                     self.env.width,
                                                     4))
-        self.max_dist = np.zeros(self.env.number_of_agents)
+        self.max_dist = np.zeros(nAgents)
 
-        for i in range(self.env.number_of_agents):
-            self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
+        # for i in range(nAgents):
+        #     self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)
+        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
 
         # Update local lookup table for all agents' target locations
        self.location_has_target = {}
-        for loc in self.env.agents_target:
-            self.location_has_target[(loc[0], loc[1])] = 1
+        # for loc in self.env.agents_target:
+        #     self.location_has_target[(loc[0], loc[1])] = 1
+        self.location_has_target = {agent.target: 1 for agent in agents}
 
     def _distance_map_walker(self, position, target_nr):
         """
@@ -229,28 +233,33 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
 
         # Update local lookup table for all agents' positions
-        self.location_has_agent = {}
-        for loc in self.env.agents_position:
-            self.location_has_agent[(loc[0], loc[1])] = 1
-
-        position = self.env.agents_position[handle]
-        orientation = self.env.agents_direction[handle]
-        possible_transitions = self.env.rail.get_transitions((position[0], position[1], orientation))
+        # self.location_has_agent = {}
+        # for loc in self.env.agents_position:
+        #     self.location_has_agent[(loc[0], loc[1])] = 1
+        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
+
+        agent = self.env.agents[handle]  # TODO: handle being treated as index
+        # position = self.env.agents_position[handle]
+        # orientation = self.env.agents_direction[handle]
+        possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
         num_transitions = np.count_nonzero(possible_transitions)
 
         # Root node - current position
-        observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+        # observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
+        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)]]
         root_observation = observation[:]
 
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # If only one transition is possible, the tree is oriented with this transition as the forward branch.
         # TODO: Test if this works as desired!
+        orientation = agent.direction
         if num_transitions == 1:
             orientation == np.argmax(possible_transitions)
-        for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
+        # for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
+        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
             if possible_transitions[branch_direction]:
-                new_cell = self._new_position(position, branch_direction)
+                new_cell = self._new_position(agent.position, branch_direction)
 
                 branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
                 observation = observation + branch_observation
@@ -307,17 +316,18 @@ class TreeObsForRailEnv(ObservationBuilder):
             visited.add((position[0], position[1], direction))
 
             # If the target node is encountered, pick that as node. Also, no further branching is possible.
-            if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
+            # if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
+            if np.array_equal(position, self.env.agents[handle].target):
                 last_isTarget = True
                 break
-            cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
+            cell_transitions = self.env.rail.get_transitions((*position, direction))
             num_transitions = np.count_nonzero(cell_transitions)
 
             exploring = False
             if num_transitions == 1:
                 # Check if dead-end, or if we can go forward along direction
                 nbits = 0
-                tmp = self.env.rail.get_transitions((position[0], position[1]))
+                tmp = self.env.rail.get_transitions(tuple(position))
                 while tmp > 0:
                     nbits += (tmp & 1)
                     tmp = tmp >> 1
@@ -380,9 +390,9 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # Get the possible transitions
-        possible_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
+        possible_transitions = self.env.rail.get_transitions((*position, direction))
         for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
-            if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
+            if last_isDeadEnd and self.env.rail.get_transition((*position, direction),
                                                                (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
@@ -471,20 +481,21 @@ class GlobalObsForRailEnv(ObservationBuilder):
         #     self.targets[target_pos] += 1
 
     def get(self, handle):
-        obs_agents_targets_pos = np.zeros((4, self.env.height, self.env.width))
-        agent_pos = self.env.agents_position[handle]
-        obs_agents_targets_pos[0][agent_pos] += 1
-        for i in range(len(self.env.agents_position)):
-            if i != handle:
-                obs_agents_targets_pos[3][self.env.agents_position[i]] += 1
-
-        agent_target_pos = self.env.agents_target[handle]
-        obs_agents_targets_pos[1][agent_target_pos] += 1
-        for i in range(len(self.env.agents_target)):
-            if i != handle:
-                obs_agents_targets_pos[2][self.env.agents_target[i]] += 1
+        obs = np.zeros((4, self.env.height, self.env.width))
+        agents = self.env.agents
+        agent = agents[handle]
+
+        agent_pos = agents[handle].position
+        obs[0][agent_pos] += 1
+        obs[1][agent.target] += 1
+
+        for i in range(len(agents)):
+            if i != handle:  # TODO: handle used as index...?
+                agent2 = agents[i]
+                obs[3][agent2.position] += 1
+                obs[2][agent2.target] += 1
 
         direction = np.zeros(4)
-        direction[self.env.agents_direction[handle]] = 1
+        direction[agent.direction] = 1
 
-        return self.rail_obs, obs_agents_targets_pos, direction
+        return self.rail_obs, obs, direction
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..da36fe73e867b7e2ec7f04a5564c73dd3e23a9a5
--- /dev/null
+++ b/flatland/envs/agent_utils.py
@@ -0,0 +1,63 @@
+
+from attr import attrs, attrib
+from itertools import starmap
+# from flatland.envs.rail_env import RailEnv
+
+
+@attrs
+class EnvDescription(object):
+    n_agents = attrib()
+    height = attrib()
+    width = attrib()
+    rail_generator = attrib()
+    obs_builder = attrib()
+
+
+@attrs
+class EnvAgentStatic(object):
+    """ TODO: EnvAgentStatic - To store initial position, direction and target.
+        This is like static data for the environment - it's where an agent starts,
+        rather than where it is at the moment.
+        The target should also be stored here.
+    """
+    position = attrib()
+    direction = attrib()
+    target = attrib()
+
+    next_handle = 0  # this is not properly implemented
+
+    @classmethod
+    def from_lists(cls, positions, directions, targets):
+        """ Create a list of EnvAgentStatics from lists of positions, directions and targets
+        """
+        return list(starmap(EnvAgentStatic, zip(positions, directions, targets)))
+
+
+@attrs
+class EnvAgent(EnvAgentStatic):
+    """ EnvAgent - replace separate agent_* lists with a single list
+        of agent objects.  The EnvAgent represents the environment's view
+        of the dynamic agent state.
+        We are duplicating target in the EnvAgent, which seems simpler than
+        forcing the env to refer to it in the EnvAgentStatic
+    """
+    handle = attrib(default=None)
+
+    @classmethod
+    def from_static(cls, oStatic):
+        """ Create an EnvAgent from the EnvAgentStatic,
+        copying all the fields, and adding handle with the default 0.
+        """
+        return EnvAgent(**oStatic.__dict__, handle=0)
+
+    @classmethod
+    def list_from_static(cls, lEnvAgentStatic, handles=None):
+        """ Create a list of EnvAgents from a list of EnvAgentStatics,
+        copying all the fields, and adding the given (or default 0..n-1) handles.
+        """
+        if handles is None:
+            handles = range(len(lEnvAgentStatic))
+
+        return [EnvAgent(**oEAS.__dict__, handle=handle)
+                for handle, oEAS in zip(handles, lEnvAgentStatic)]
+
diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index fe971e6b24b90e31dadd797359247537078ad5f6..7452d325530bb189084182f0bbc4bf26369e5881 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -9,7 +9,7 @@ from flatland.envs.env_utils import distance_on_rail, connect_rail, get_directio
 from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail
 
 
-def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
+def complex_rail_generator(nr_start_goal=1, nr_extra=10, min_dist=2, max_dist=99999, seed=0):
     """
     Parameters
     -------
@@ -123,7 +123,29 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
                 # print("failed...")
                 created_sanity += 1
 
-        #print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs")
+        # add extra connections between existing rail
+        created_sanity = 0
+        nr_created = 0
+        while nr_created < nr_extra and created_sanity < sanity_max:
+            all_ok = False
+            for _ in range(sanity_max):
+                start = (np.random.randint(0, width), np.random.randint(0, height))
+                goal = (np.random.randint(0, width), np.random.randint(0, height))
+                # check to make sure start,goal pos are not empty
+                if rail_array[goal] == 0 or rail_array[start] == 0:
+                    continue
+                else:
+                    all_ok = True
+                    break
+            if not all_ok:
+                break
+            new_path = connect_rail(rail_trans, rail_array, start, goal)
+            if len(new_path) >= 2:
+                nr_created += 1
+            else:
+                created_sanity += 1
+
+        print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections")
 
         # print(start_goal)
         agents_position = [sg[0] for sg in start_goal]
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 98abf81f469e1ea329db39e86c3cfe0a7756df28..9767bba42e2b3acf3ef4aa34b14154077dac77bd 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -10,34 +10,12 @@ from flatland.core.env import Environment
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.envs.generators import random_rail_generator
 from flatland.envs.env_utils import get_new_position
+from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
 
 # from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
 # from flatland.core.transition_map import GridTransitionMap
 
 
-class EnvAgentStatic(object):
-    """ TODO: EnvAgentStatic - To store initial position, direction and target.
-        This is like static data for the environment - it's where an agent starts,
-        rather than where it is at the moment.
-        The target should also be stored here.
-    """
-    def __init__(self, rcPos, iDir, rcTarget):
-        self.rcPos = rcPos
-        self.iDir = iDir
-        self.rcTarget = rcTarget
-
-
-class EnvAgent(object):
-    """ TODO: EnvAgent - replace separate agent lists with a single list
-        of agent objects.  The EnvAgent represent's the environment's view
-        of the dynamic agent state.  So target is not part of it - target is
-        static.
-    """
-    def __init__(self, rcPos, iDir):
-        self.rcPos = rcPos
-        self.iDir = iDir
-
-
 class RailEnv(Environment):
     """
     RailEnv environment class.
@@ -123,6 +101,7 @@ class RailEnv(Environment):
         # self.agents_position = []
         # self.agents_target = []
         # self.agents_direction = []
+        self.agents = []
         self.num_resets = 0
         self.reset()
         self.num_resets = 0
@@ -137,14 +116,19 @@ class RailEnv(Environment):
        TODO: replace_agents is ignored at the moment; agents will always be replaced.
        """
        if regen_rail or self.rail is None:
-            self.rail, self.agents_position, self.agents_direction, self.agents_target = self.rail_generator(
+            self.rail, agents_position, agents_direction, agents_target = self.rail_generator(
                self.width, self.height, self.agents_handles, self.num_resets)
 
+        if replace_agents:
+            self.agents_static = EnvAgentStatic.from_lists(agents_position, agents_direction, agents_target)
+            self.agents = EnvAgent.list_from_static(self.agents_static[:len(self.agents_handles)])
+
        self.num_resets += 1
 
+        # perhaps dones should be part of each agent.
        self.dones = {"__all__": False}
        for handle in self.agents_handles:
            self.dones[handle] = False
@@ -174,11 +158,12 @@ class RailEnv(Environment):
         for i in range(len(self.agents_handles)):
             handle = self.agents_handles[i]
             transition_isValid = None
+            agent = self.agents[i]
 
-            if handle not in action_dict:
+            if handle not in action_dict:  # no action has been supplied for this agent
                 continue
 
-            if self.dones[handle]:
+            if self.dones[handle]:  # this agent has already completed...
                 continue
 
             action = action_dict[handle]
@@ -188,31 +173,28 @@ class RailEnv(Environment):
                 return
 
             if action > 0:
-                pos = self.agents_position[i]
-                direction = self.agents_direction[i]
+                # pos = agent.position  # self.agents_position[i]
+                # direction = agent.direction  # self.agents_direction[i]
 
                 # compute number of possible transitions in the current
                 # cell used to check for invalid actions
-                possible_transitions = self.rail.get_transitions((pos[0], pos[1], direction))
+                possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
                 num_transitions = np.count_nonzero(possible_transitions)
 
-                movement = direction
+                movement = agent.direction
                 # print(nbits,np.sum(possible_transitions))
                 if action == 1:
-                    movement = direction - 1
+                    movement = agent.direction - 1
                     if num_transitions <= 1:
                         transition_isValid = False
 
                 elif action == 3:
-                    movement = direction + 1
+                    movement = agent.direction + 1
                     if num_transitions <= 1:
                         transition_isValid = False
 
-                if movement < 0:
-                    movement += 4
-                if movement >= 4:
-                    movement -= 4
+                movement %= 4
 
                 if action == 2:
                     if num_transitions == 1:
@@ -222,57 +204,72 @@ class RailEnv(Environment):
                         movement = np.argmax(possible_transitions)
                         transition_isValid = True
 
-                new_position = get_new_position(pos, movement)
-                # Is it a legal move?  1) transition allows the movement in the
-                # cell,  2) the new cell is not empty (case 0),  3) the cell is
-                # free, i.e., no agent is currently in that cell
-                if (
-                        new_position[1] >= self.width or
-                        new_position[0] >= self.height or
-                        new_position[0] < 0 or new_position[1] < 0):
-                    new_cell_isValid = False
-
-                elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
-                    new_cell_isValid = True
-                else:
-                    new_cell_isValid = False
+                new_position = get_new_position(agent.position, movement)
+                # Is it a legal move?
+                # 1) transition allows the movement in the cell,
+                # 2) the new cell is not empty (case 0),
+                # 3) the cell is free, i.e., no agent is currently in that cell
+
+                # if (
+                #         new_position[1] >= self.width or
+                #         new_position[0] >= self.height or
+                #         new_position[0] < 0 or new_position[1] < 0):
+                #     new_cell_isValid = False
+
+                # if self.rail.get_transitions(new_position) == 0:
+                #     new_cell_isValid = False
+
+                new_cell_isValid = (
+                    np.array_equal(  # Check the new position is still in the grid
+                        new_position,
+                        np.clip(new_position, [0, 0], [self.height-1, self.width-1]))
+                    and  # check the new position has some transitions (ie is not an empty cell)
+                    self.rail.get_transitions(new_position) > 0)
 
                 # If transition validity hasn't been checked yet.
                 if transition_isValid is None:
                     transition_isValid = self.rail.get_transition(
-                        (pos[0], pos[1], direction),
+                        (*agent.position, agent.direction),
                         movement)
 
-                cell_isFree = True
-                for j in range(self.number_of_agents):
-                    if self.agents_position[j] == new_position:
-                        cell_isFree = False
-                        break
-
-                if new_cell_isValid and transition_isValid and cell_isFree:
+                # cell_isFree = True
+                # for j in range(self.number_of_agents):
+                #     if self.agents_position[j] == new_position:
+                #         cell_isFree = False
+                #         break
+                # Check the new position is not the same as any of the existing agent positions
+                # (including itself, for simplicity, since it is moving)
+                cell_isFree = not np.any(
+                    np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+
+                if all([new_cell_isValid, transition_isValid, cell_isFree]):
                     # move and change direction to face the movement that was
                     # performed
-                    self.agents_position[i] = new_position
-                    self.agents_direction[i] = movement
+                    # self.agents_position[i] = new_position
+                    # self.agents_direction[i] = movement
+                    agent.position = new_position
+                    agent.direction = movement
 
                 else:
                     # the action was not valid, add penalty
                     self.rewards_dict[handle] += invalid_action_penalty
 
             # if agent is not in target position, add step penalty
-            if self.agents_position[i][0] == self.agents_target[i][0] and \
-               self.agents_position[i][1] == self.agents_target[i][1]:
+            # if self.agents_position[i][0] == self.agents_target[i][0] and \
+            #    self.agents_position[i][1] == self.agents_target[i][1]:
+            #     self.dones[handle] = True
+            if np.equal(agent.position, agent.target).all():
                 self.dones[handle] = True
             else:
                 self.rewards_dict[handle] += step_penalty
 
         # Check for end of episode + add global reward to all rewards!
-        num_agents_in_target_position = 0
-        for i in range(self.number_of_agents):
-            if self.agents_position[i][0] == self.agents_target[i][0] and \
-               self.agents_position[i][1] == self.agents_target[i][1]:
-                num_agents_in_target_position += 1
-
-        if num_agents_in_target_position == self.number_of_agents:
+        # num_agents_in_target_position = 0
+        # for i in range(self.number_of_agents):
+        #     if self.agents_position[i][0] == self.agents_target[i][0] and \
+        #        self.agents_position[i][1] == self.agents_target[i][1]:
+        #         num_agents_in_target_position += 1
+        # if num_agents_in_target_position == self.number_of_agents:
+        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
             self.dones["__all__"] = True
             self.rewards_dict = [r + global_reward for r in self.rewards_dict]
@@ -290,3 +287,4 @@ class RailEnv(Environment):
     def render(self):
         # TODO:
         pass
+
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 1f731c39f024f4cf8d10a5ad70171ba0b60b260d..c2fcb73b06fb2a2470187c10c90bee4ffc468148 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -158,20 +158,9 @@ class RenderTool(object):
     def plotAgents(self, targets=True):
         cmap = self.gl.get_cmap('hsv', lut=self.env.number_of_agents + 1)
 
-        for iAgent in range(self.env.number_of_agents):
+        for iAgent, agent in enumerate(self.env.agents):
             oColor = cmap(iAgent)
-
-            rcPos = self.env.agents_position[iAgent]
-            iDir = self.env.agents_direction[iAgent]  # agent direction index
-
-            if targets:
-                target = self.env.agents_target[iAgent]
-            else:
-                target = None
-            self.plotAgent(rcPos, iDir, oColor, target=target)
-
-            # gTransRCAg = self.getTransRC(rcPos, iDir)
-            # self.plotTrans(rcPos, gTransRCAg)
+            self.plotAgent(agent.position, agent.direction, oColor, target=agent.target if targets else None)
 
     def getTransRC(self, rcPos, iDir, bgiTrans=False):
         """
@@ -554,7 +543,7 @@ class RenderTool(object):
 
                 if not bCellValid:
                     # print("invalid:", r, c)
-                    self.gl.scatter(*xyCentre, color="r", s=50)
+                    self.gl.scatter(*xyCentre, color="r", s=30)
 
                 for orientation in range(4):  # ori is where we're heading
                     from_ori = (orientation + 2) % 4  # 0123=NESW -> 2301=SWNE
diff --git a/tests/test_environments.py b/tests/test_environments.py
index a10fb0619eae4d27867c9008c27618fe059d52d2..fe788b7c72fbcab358ac2120b0069c7cf64b1801 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -7,7 +7,7 @@ from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.core.transitions import Grid4Transitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.core.env_observation_builder import GlobalObsForRailEnv
-
+from flatland.envs.agent_utils import EnvAgent
 
 """Tests for `flatland` package."""
 
@@ -58,16 +58,21 @@ def test_rail_environment_single_agent():
     _ = rail_env.reset()
 
     # We do not care about target for the moment
-    rail_env.agents_target[0] = [-1, -1]
+    # rail_env.agents_target[0] = [-1, -1]
+    agent = rail_env.agents[0]
+    # rail_env.agents[0].target = [-1, -1]
+    agent.target = [-1, -1]
 
     # Check that trains are always initialized at a consistent position
     # or direction.
     # They should always be able to go somewhere.
     assert(transitions.get_transitions(
-        rail_map[rail_env.agents_position[0]],
-        rail_env.agents_direction[0]) != (0, 0, 0, 0))
+        # rail_map[rail_env.agents_position[0]],
+        # rail_env.agents_direction[0]) != (0, 0, 0, 0))
+        rail_map[agent.position],
+        agent.direction) != (0, 0, 0, 0))
 
-    initial_pos = rail_env.agents_position[0]
+    initial_pos = agent.position
 
     valid_active_actions_done = 0
     pos = initial_pos
@@ -78,13 +83,13 @@ def test_rail_environment_single_agent():
             _, _, _, _ = rail_env.step({0: action})
 
             prev_pos = pos
-            pos = rail_env.agents_position[0]
+            pos = agent.position  # rail_env.agents_position[0]
             if prev_pos != pos:
                 valid_active_actions_done += 1
 
     # After 6 movements on this railway network, the train should be back
    # to its original height on the map.
-    assert(initial_pos[0] == rail_env.agents_position[0][0])
+    assert(initial_pos[0] == agent.position[0])
 
    # We check that the train always attains its target after some time
    for _ in range(10):
@@ -135,13 +140,14 @@ def test_dead_end():
    # We run step to check that trains do not move anymore
    # after being done.
    for i in range(7):
-        prev_pos = rail_env.agents_position[0]
+        # prev_pos = rail_env.agents_position[0]
+        prev_pos = rail_env.agents[0].position
         # The train cannot turn, so we check that when it tries,
         # it stays where it is.
         _ = rail_env.step({0: 1})
         _ = rail_env.step({0: 3})
-        assert (rail_env.agents_position[0] == prev_pos)
+        assert (rail_env.agents[0].position == prev_pos)
 
         _, _, dones, _ = rail_env.step({0: 2})
 
         if i < 5:
@@ -151,15 +157,17 @@ def test_dead_end():
 
     # We try the configuration in the 4 directions:
     rail_env.reset()
-    rail_env.agents_target[0] = (0, 0)
-    rail_env.agents_position[0] = (0, 2)
-    rail_env.agents_direction[0] = 1
+    # rail_env.agents_target[0] = (0, 0)
+    # rail_env.agents_position[0] = (0, 2)
+    # rail_env.agents_direction[0] = 1
+    rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0))]
     check_consistency(rail_env)
 
     rail_env.reset()
-    rail_env.agents_target[0] = (0, 4)
-    rail_env.agents_position[0] = (0, 2)
-    rail_env.agents_direction[0] = 3
+    # rail_env.agents_target[0] = (0, 4)
+    # rail_env.agents_position[0] = (0, 2)
+    # rail_env.agents_direction[0] = 3
+    rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4))]
     check_consistency(rail_env)
 
     # In the vertical configuration:
@@ -181,13 +189,15 @@ def test_dead_end():
                       obs_builder_object=GlobalObsForRailEnv())
 
     rail_env.reset()
-    rail_env.agents_target[0] = (0, 0)
-    rail_env.agents_position[0] = (2, 0)
-    rail_env.agents_direction[0] = 2
+    # rail_env.agents_target[0] = (0, 0)
+    # rail_env.agents_position[0] = (2, 0)
+    # rail_env.agents_direction[0] = 2
+    rail_env.agents = [EnvAgent(position=(2, 0), direction=2, target=(0, 0))]
     check_consistency(rail_env)
 
     rail_env.reset()
-    rail_env.agents_target[0] = (4, 0)
-    rail_env.agents_position[0] = (2, 0)
-    rail_env.agents_direction[0] = 0
+    # rail_env.agents_target[0] = (4, 0)
+    # rail_env.agents_position[0] = (2, 0)
+    # rail_env.agents_direction[0] = 0
+    rail_env.agents = [EnvAgent(position=(2, 0), direction=0, target=(4, 0))]
    check_consistency(rail_env)
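
Below is a minimal usage sketch of the refactored agent API introduced by this patch. It is a sketch, not part of the diff: the class and function names are taken from the changes above, the RailEnv parameters are copied from examples/play_model.py, and the concrete positions, directions and targets are illustrative values only.

# Usage sketch (Python). Assumes the modules as laid out in this patch.
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import complex_rail_generator
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent

# complex_rail_generator now takes nr_extra: the number of extra connections
# drawn between already-existing rails (parameters as in examples/play_model.py).
env = RailEnv(width=15, height=15,
              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
              number_of_agents=5)
env.reset()

# The separate agents_position / agents_direction / agents_target lists are
# replaced by a single list of EnvAgent objects:
for agent in env.agents:
    print(agent.handle, agent.position, agent.direction, agent.target)

# EnvAgentStatic stores where an agent starts; EnvAgent adds a handle and is
# mutated by the env as the agent moves. The values below are illustrative.
static_agents = EnvAgentStatic.from_lists(
    positions=[(0, 0), (1, 1)], directions=[0, 2], targets=[(3, 3), (4, 4)])
agents = EnvAgent.list_from_static(static_agents)  # handles default to 0, 1, ...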