Skip to content
Snippets Groups Projects
Commit 6afc2a09 authored by hagrid67's avatar hagrid67
Browse files

add_agent_static

parent 95362985
No related branches found
No related tags found
No related merge requests found
from attr import attrs, attrib from attr import attrs, attrib
from itertools import starmap from itertools import starmap, count
import numpy as np import numpy as np
@attrs @attrs
...@@ -15,11 +15,13 @@ class EnvAgentStatic(object): ...@@ -15,11 +15,13 @@ class EnvAgentStatic(object):
target = attrib() target = attrib()
handle = attrib() handle = attrib()
next_handle = 0
@classmethod @classmethod
def from_lists(positions, directions, targets): def from_lists(positions, directions, targets):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets """ Create a list of EnvAgentStatics from lists of positions, directions and targets
""" """
return starmap(EnvAgentStatic, zip(positions, directions, targets)) return starmap(EnvAgentStatic, zip(positions, directions, targets, count()))
class EnvAgent(EnvAgentStatic): class EnvAgent(EnvAgentStatic):
...@@ -33,6 +35,7 @@ class EnvAgent(EnvAgentStatic): ...@@ -33,6 +35,7 @@ class EnvAgent(EnvAgentStatic):
class EnvManager(object): class EnvManager(object):
def __init__(self, env=None): def __init__(self, env=None):
self.env = env self.env = env
def load_env(self, sFilename): def load_env(self, sFilename):
pass pass
...@@ -46,7 +49,28 @@ class EnvManager(object): ...@@ -46,7 +49,28 @@ class EnvManager(object):
def replace_agents(self): def replace_agents(self):
pass pass
def add_agent(self, rcPos=None, rcTarget=None, iDir=None): def add_agent_static(self, agent_static):
""" Add a new agent_static
"""
iAgent = self.number_of_agents
if iDir is None:
iDir = self.pick_agent_direction(rcPos, rcTarget)
if iDir is None:
print("Error picking agent direction at pos:", rcPos)
return None
self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list
self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0
self.agents_direction.append(iDir)
self.agents_target.append(rcPos) # set the target to the origin initially
self.number_of_agents += 1
self.check_agent_lists()
return iAgent
def add_agent_old(self, rcPos=None, rcTarget=None, iDir=None):
""" Add a new agent at position rcPos with target rcTarget and """ Add a new agent at position rcPos with target rcTarget and
initial direction index iDir. initial direction index iDir.
Should also store this initial position etc as environment "meta-data" Should also store this initial position etc as environment "meta-data"
......
...@@ -116,12 +116,16 @@ class RailEnv(Environment): ...@@ -116,12 +116,16 @@ class RailEnv(Environment):
TODO: replace_agents is ignored at the moment; agents will always be replaced. TODO: replace_agents is ignored at the moment; agents will always be replaced.
""" """
if regen_rail or self.rail is None: if regen_rail or self.rail is None:
self.rail, self.agents_position, self.agents_direction, self.agents_target = self.rail_generator( self.rail, agents_position, agents_direction, agents_target = self.rail_generator(
self.width, self.width,
self.height, self.height,
self.agents_handles, self.agents_handles,
self.num_resets) self.num_resets)
if replace_agents:
self.agents_static = EnvAgentStatic.from_lists(agents_position, agents_direction, agents_target)
self.agents = copy(agents_static)
self.num_resets += 1 self.num_resets += 1
self.dones = {"__all__": False} self.dones = {"__all__": False}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment