From 6afc2a09f64046755ad3ad4fdec7930aabf35447 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Thu, 9 May 2019 11:53:12 +0100
Subject: [PATCH] add_agent_static

---
 flatland/envs/agent_utils.py | 30 +++++++++++++++++++++++++++---
 flatland/envs/rail_env.py    |  6 +++++-
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 9cb7b955..c29839e6 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -1,6 +1,6 @@
 
 from attr import attrs, attrib
-from itertools import starmap
+from itertools import starmap, count
 import numpy as np
 
 @attrs
@@ -15,11 +15,13 @@ class EnvAgentStatic(object):
     target = attrib()
     handle = attrib()
 
+    next_handle = 0
+
     @classmethod
     def from_lists(positions, directions, targets):
         """ Create a list of EnvAgentStatics from lists of positions, directions and targets
         """
-        return starmap(EnvAgentStatic, zip(positions, directions, targets))
+        return starmap(EnvAgentStatic, zip(positions, directions, targets, count()))
         
 
 class EnvAgent(EnvAgentStatic):
@@ -33,6 +35,7 @@ class EnvAgent(EnvAgentStatic):
 class EnvManager(object):
     def __init__(self, env=None):
         self.env = env
+        
 
     def load_env(self, sFilename):
         pass
@@ -46,7 +49,28 @@ class EnvManager(object):
     def replace_agents(self):
         pass
 
-    def add_agent(self, rcPos=None, rcTarget=None, iDir=None):
+    def add_agent_static(self, agent_static):
+        """ Add a new agent_static
+        """
+        iAgent = self.number_of_agents
+
+        if iDir is None:
+            iDir = self.pick_agent_direction(rcPos, rcTarget)
+        if iDir is None:
+            print("Error picking agent direction at pos:", rcPos)
+            return None
+
+        self.agents_position.append(tuple(rcPos))  # ensure it's a tuple not a list
+        self.agents_handles.append(max(self.agents_handles + [-1]) + 1)  # max(handles) + 1, starting at 0
+        self.agents_direction.append(iDir)
+        self.agents_target.append(rcPos)  # set the target to the origin initially
+        self.number_of_agents += 1
+        self.check_agent_lists()
+        return iAgent
+
+
+
+    def add_agent_old(self, rcPos=None, rcTarget=None, iDir=None):
         """ Add a new agent at position rcPos with target rcTarget and
             initial direction index iDir.
             Should also store this initial position etc as environment "meta-data"
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 77ec25ec..ea8c3dca 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -116,12 +116,16 @@ class RailEnv(Environment):
         TODO: replace_agents is ignored at the moment; agents will always be replaced.
         """
         if regen_rail or self.rail is None:
-            self.rail, self.agents_position, self.agents_direction, self.agents_target = self.rail_generator(
+            self.rail, agents_position, agents_direction, agents_target = self.rail_generator(
                 self.width,
                 self.height,
                 self.agents_handles,
                 self.num_resets)
 
+        if replace_agents:
+            self.agents_static = EnvAgentStatic.from_lists(agents_position, agents_direction, agents_target)
+            self.agents = copy(agents_static)
+
         self.num_resets += 1
 
         self.dones = {"__all__": False}
-- 
GitLab