Skip to content
Snippets Groups Projects
Commit 6afc2a09 authored by hagrid67's avatar hagrid67
Browse files

add_agent_static

parent 95362985
No related branches found
No related tags found
No related merge requests found
from attr import attrs, attrib
from itertools import starmap
from itertools import starmap, count
import numpy as np
@attrs
......@@ -15,11 +15,13 @@ class EnvAgentStatic(object):
target = attrib()
handle = attrib()
next_handle = 0
@classmethod
def from_lists(positions, directions, targets):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
return starmap(EnvAgentStatic, zip(positions, directions, targets))
return starmap(EnvAgentStatic, zip(positions, directions, targets, count()))
class EnvAgent(EnvAgentStatic):
......@@ -33,6 +35,7 @@ class EnvAgent(EnvAgentStatic):
class EnvManager(object):
def __init__(self, env=None):
self.env = env
def load_env(self, sFilename):
pass
......@@ -46,7 +49,28 @@ class EnvManager(object):
def replace_agents(self):
pass
def add_agent(self, rcPos=None, rcTarget=None, iDir=None):
def add_agent_static(self, agent_static):
""" Add a new agent_static
"""
iAgent = self.number_of_agents
if iDir is None:
iDir = self.pick_agent_direction(rcPos, rcTarget)
if iDir is None:
print("Error picking agent direction at pos:", rcPos)
return None
self.agents_position.append(tuple(rcPos)) # ensure it's a tuple not a list
self.agents_handles.append(max(self.agents_handles + [-1]) + 1) # max(handles) + 1, starting at 0
self.agents_direction.append(iDir)
self.agents_target.append(rcPos) # set the target to the origin initially
self.number_of_agents += 1
self.check_agent_lists()
return iAgent
def add_agent_old(self, rcPos=None, rcTarget=None, iDir=None):
""" Add a new agent at position rcPos with target rcTarget and
initial direction index iDir.
Should also store this initial position etc as environment "meta-data"
......
......@@ -116,12 +116,16 @@ class RailEnv(Environment):
TODO: replace_agents is ignored at the moment; agents will always be replaced.
"""
if regen_rail or self.rail is None:
self.rail, self.agents_position, self.agents_direction, self.agents_target = self.rail_generator(
self.rail, agents_position, agents_direction, agents_target = self.rail_generator(
self.width,
self.height,
self.agents_handles,
self.num_resets)
if replace_agents:
self.agents_static = EnvAgentStatic.from_lists(agents_position, agents_direction, agents_target)
self.agents = copy(agents_static)
self.num_resets += 1
self.dones = {"__all__": False}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment