diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 9cb7b95567987bc9c7e986b1db895b99b0c223ec..c29839e6c20972260c4ae9536a36d0ff8f58b0af 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -1,6 +1,6 @@ from attr import attrs, attrib -from itertools import starmap +from itertools import starmap, count import numpy as np @attrs @@ -15,11 +15,13 @@ class EnvAgentStatic(object): target = attrib() handle = attrib() + next_handle = 0 + @classmethod def from_lists(positions, directions, targets): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ - return starmap(EnvAgentStatic, zip(positions, directions, targets)) + return starmap(EnvAgentStatic, zip(positions, directions, targets, count())) class EnvAgent(EnvAgentStatic): @@ -33,6 +35,7 @@ class EnvAgent(EnvAgentStatic): class EnvManager(object): def __init__(self, env=None): self.env = env + def load_env(self, sFilename): pass @@ -46,7 +49,28 @@ class EnvManager(object): def replace_agents(self): pass - def add_agent(self, rcPos=None, rcTarget=None, iDir=None): + def add_agent_static(self, agent_static): + """ Add a new agent_static + """ + iAgent = self.number_of_agents + + if iDir is None: + iDir = self.pick_agent_direction(rcPos, rcTarget) + if iDir is None: + print("Error picking agent direction at pos:", rcPos) + return None + + self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list + self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0 + self.agents_direction.append(iDir) + self.agents_target.append(rcPos) # set the target to the origin initially + self.number_of_agents += 1 + self.check_agent_lists() + return iAgent + + + + def add_agent_old(self, rcPos=None, rcTarget=None, iDir=None): """ Add a new agent at position rcPos with target rcTarget and initial direction index iDir. Should also store this initial position etc as environment "meta-data" diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 77ec25ec592b9db7da4ff19a158420617d83a74a..ea8c3dcaf4ac58f0f241643bfebb4da43afe0c09 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -116,12 +116,16 @@ class RailEnv(Environment): TODO: replace_agents is ignored at the moment; agents will always be replaced. """ if regen_rail or self.rail is None: - self.rail, self.agents_position, self.agents_direction, self.agents_target = self.rail_generator( + self.rail, agents_position, agents_direction, agents_target = self.rail_generator( self.width, self.height, self.agents_handles, self.num_resets) + if replace_agents: + self.agents_static = EnvAgentStatic.from_lists(agents_position, agents_direction, agents_target) + self.agents = copy(agents_static) + self.num_resets += 1 self.dones = {"__all__": False}