Commit 8b186f8c authored by u229589

remove static agents

parent ae124063
Pipeline #2708 passed with stages in 35 minutes and 15 seconds
@@ -2,7 +2,6 @@ from enum import IntEnum
 from itertools import starmap
 from typing import Tuple, Optional
-import numpy as np
 from attr import attrs, attrib, Factory
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
@@ -17,13 +16,10 @@ class RailAgentStatus(IntEnum):
 @attrs
-class EnvAgentStatic(object):
-    """ EnvAgentStatic - Stores initial position, direction and target.
-        This is like static data for the environment - it's where an agent starts,
-        rather than where it is at the moment.
-        The target should also be stored here.
-    """
+class EnvAgent:
     initial_position = attrib(type=Tuple[int, int])
     initial_direction = attrib(type=Grid4TransitionsEnum)
     direction = attrib(type=Grid4TransitionsEnum)
     target = attrib(type=Tuple[int, int])
     moving = attrib(default=False, type=bool)
@@ -42,12 +38,31 @@ class EnvAgentStatic(object):
             lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
                           'moving_before_malfunction': False})))
+    handle = attrib(default=None)
     status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
     position = attrib(default=None, type=Optional[Tuple[int, int]])
+    # used in rendering
+    old_direction = attrib(default=None)
+    old_position = attrib(default=None)
+    def reset(self):
+        self.position = None
+        self.direction = self.initial_direction
+        self.status = RailAgentStatus.READY_TO_DEPART
+        self.old_position = None
+        self.old_direction = None
+        self.moving = False
+    def to_list(self):
+        return [self.initial_position, self.initial_direction, int(self.direction), self.target, int(self.moving),
+                self.speed_data, self.malfunction_data, self.handle, self.status, self.position, self.old_direction,
+                self.old_position]
     @classmethod
-    def from_lists(cls, schedule: Schedule):
-        """ Create a list of EnvAgentStatics from lists of positions, directions and targets
+    def from_schedule(cls, schedule: Schedule):
+        """ Create a list of EnvAgent from lists of positions, directions and targets
         """
         speed_datas = []
@@ -56,9 +71,6 @@ class EnvAgentStatic(object):
                                 'speed': schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0,
                                 'transition_action_on_cellexit': 0})
-        # TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set
-        # some as broken?
         malfunction_datas = []
         for i in range(len(schedule.agent_positions)):
             malfunction_datas.append({'malfunction': 0,
@@ -67,59 +79,11 @@ class EnvAgentStatic(object):
                                       'next_malfunction': 0,
                                       'nr_malfunctions': 0})
-        return list(starmap(EnvAgentStatic, zip(schedule.agent_positions,
-                                                schedule.agent_directions,
-                                                schedule.agent_targets,
-                                                [False] * len(schedule.agent_positions),
-                                                speed_datas,
-                                                malfunction_datas)))
-    def to_list(self):
-        # I can't find an expression which works on both tuples, lists and ndarrays
-        # which converts them all to a list of native python ints.
-        lPos = self.initial_position
-        if type(lPos) is np.ndarray:
-            lPos = lPos.tolist()
-        lTarget = self.target
-        if type(lTarget) is np.ndarray:
-            lTarget = lTarget.tolist()
-        return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.malfunction_data]
-@attrs
-class EnvAgent(EnvAgentStatic):
-    """ EnvAgent - replace separate agent_* lists with a single list
-        of agent objects. The EnvAgent represent's the environment's view
-        of the dynamic agent state.
-        We are duplicating target in the EnvAgent, which seems simpler than
-        forcing the env to refer to it in the EnvAgentStatic
-    """
-    handle = attrib(default=None)
-    old_direction = attrib(default=None)
-    old_position = attrib(default=None)
-    def to_list(self):
-        return [
-            self.position, self.direction, self.target, self.handle,
-            self.old_direction, self.old_position, self.moving, self.speed_data, self.malfunction_data]
-    @classmethod
-    def from_static(cls, oStatic):
-        """ Create an EnvAgent from the EnvAgentStatic,
-        copying all the fields, and adding handle with the default 0.
-        """
-        return EnvAgent(*oStatic.__dict__, handle=0)
-    @classmethod
-    def list_from_static(cls, lEnvAgentStatic, handles=None):
-        """ Create an EnvAgent from the EnvAgentStatic,
-        copying all the fields, and adding handle with the default 0.
-        """
-        if handles is None:
-            handles = range(len(lEnvAgentStatic))
-        return [EnvAgent(**oEAS.__dict__, handle=handle)
-                for handle, oEAS in zip(handles, lEnvAgentStatic)]
+        return list(starmap(EnvAgent, zip(schedule.agent_positions,
+                                          schedule.agent_directions,
+                                          schedule.agent_directions,
+                                          schedule.agent_targets,
+                                          [False] * len(schedule.agent_positions),
+                                          speed_datas,
+                                          malfunction_datas,
+                                          range(len(schedule.agent_positions)))))
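A minimal usage sketch of the merged EnvAgent API above (illustrative values only; the Schedule field names follow the Schedule(...) call that appears later in this diff):

from flatland.envs.agent_utils import EnvAgent
from flatland.envs.schedule_utils import Schedule

schedule = Schedule(agent_positions=[(0, 1), (3, 4)],
                    agent_directions=[1, 3],
                    agent_targets=[(5, 5), (0, 0)],
                    agent_speeds=[1.0, 0.5],
                    agent_malfunction_rates=None)

agents = EnvAgent.from_schedule(schedule)   # replaces EnvAgentStatic.from_lists; handles are 0..n-1
agents[0].reset()                           # position -> None, direction -> initial_direction,
                                            # status -> READY_TO_DEPART, moving -> False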
@@ -157,8 +157,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
         new_position = agent_virtual_position
         visited = OrderedSet()
         for index in range(1, self.max_depth + 1):
-            # if we're at the target or not moving, stop moving until max_depth is reached
-            if new_position == agent.target or not agent.moving or not shortest_path:
+            # if we're at the target, stop moving until max_depth is reached
+            if new_position == agent.target or not shortest_path:
                 prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                 visited.add((*new_position, agent.direction))
                 continue
@@ -17,7 +17,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
@@ -182,8 +182,8 @@ class RailEnv(Environment):
         self.dev_obs_dict = {}
         self.dev_pred_dict = {}
-        self.agents: List[EnvAgent] = [None] * number_of_agents  # live agents
-        self.agents_static: List[EnvAgentStatic] = [None] * number_of_agents  # static agent information
+        self.agents: List[EnvAgent] = []
+        self.number_of_agents = number_of_agents
         self.num_resets = 0
         self.distance_map = DistanceMap(self.agents, self.height, self.width)
@@ -227,18 +227,15 @@ class RailEnv(Environment):
     def get_agent_handles(self):
         return range(self.get_num_agents())
-    def get_num_agents(self, static=True):
-        if static:
-            return len(self.agents_static)
-        else:
-            return len(self.agents)
+    def get_num_agents(self) -> int:
+        return len(self.agents)
-    def add_agent_static(self, agent_static):
+    def add_agent(self, agent):
         """ Add static info for a single agent.
             Returns the index of the new agent.
         """
-        self.agents_static.append(agent_static)
-        return len(self.agents_static) - 1
+        self.agents.append(agent)
+        return len(self.agents) - 1
     def set_agent_active(self, handle: int):
         agent = self.agents[handle]
@@ -247,9 +244,10 @@
             self._set_agent_to_initial_position(agent, agent.initial_position)
     def restart_agents(self):
-        """ Reset the agents to their starting positions defined in agents_static
+        """ Reset the agents to their starting positions
         """
-        self.agents = EnvAgent.list_from_static(self.agents_static)
+        for agent in self.agents:
+            agent.reset()
         self.active_agents = [i for i in range(len(self.agents))]
     @staticmethod
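A minimal sketch of the slimmed-down agent API on RailEnv after this change; the environment size, generator and agent values below are placeholders, assuming the default schedule generator and observation builder:

from flatland.envs.agent_utils import EnvAgent
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator

# Placeholder environment; size and generator are illustrative only.
env = RailEnv(width=25, height=25, rail_generator=random_rail_generator(), number_of_agents=1)
env.reset()

extra = EnvAgent(initial_position=(0, 0), initial_direction=1, direction=1, target=(5, 5))
handle = env.add_agent(extra)                    # appends to env.agents and returns its index
assert env.get_num_agents() == len(env.agents)   # the static/live distinction is gone

env.restart_agents()                             # now simply calls agent.reset() on every agent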
@@ -327,7 +325,7 @@
         optionals = {}
         if regenerate_rail or self.rail is None:
-            rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
+            rail, optionals = self.rail_generator(self.width, self.height, self.number_of_agents, self.num_resets)
             self.rail = rail
             self.height, self.width = self.rail.grid.shape
@@ -340,17 +338,13 @@
         if optionals and 'distance_map' in optionals:
             self.distance_map.set(optionals['distance_map'])
-        # todo change self.agents_static[0] with the refactoring for agents_static -> issue nr. 185
-        # https://gitlab.aicrowd.com/flatland/flatland/issues/185
-        if regenerate_schedule or regenerate_rail or self.agents_static[0] is None:
+        if regenerate_schedule or regenerate_rail or len(self.agents) == 0:
             agents_hints = None
             if optionals and 'agents_hints' in optionals:
                 agents_hints = optionals['agents_hints']
-            # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185
-            # why do we need static agents? could we it more elegantly?
-            schedule = self.schedule_generator(self.rail, self.get_num_agents(), agents_hints, self.num_resets)
-            self.agents_static = EnvAgentStatic.from_lists(schedule)
+            schedule = self.schedule_generator(self.rail, self.number_of_agents, agents_hints, self.num_resets)
+            self.agents = EnvAgent.from_schedule(schedule)
             if agents_hints and 'city_orientations' in agents_hints:
                 ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations'])
@@ -391,9 +385,9 @@
         info_dict: Dict = {
             'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
             'malfunction': {
-                i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
+                i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
             },
-            'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())},
+            'speed': {i: agent.speed_data['speed'] for i, agent in enumerate(self.agents)},
             'status': {i: agent.status for i, agent in enumerate(self.agents)}
         }
         # Return the new observation vectors for each agent
@@ -819,14 +813,11 @@
             Returns state of environment in msgpack object
         """
         grid_data = self.rail.grid.tolist()
-        agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
-        msgpack.packb(agent_static_data, use_bin_type=True)
         msg_data = {
             "grid": grid_data,
-            "agents_static": agent_static_data,
             "agents": agent_data}
         return msgpack.packb(msg_data, use_bin_type=True)
@@ -850,8 +841,7 @@
         data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
         self.rail.grid = np.array(data["grid"])
         # agents are always reset as not moving
-        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
-        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
+        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11]) for d in data["agents"]]
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
         self.rail.height = self.height
@@ -869,8 +859,7 @@
         data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
         self.rail.grid = np.array(data["grid"])
         # agents are always reset as not moving
-        self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
-        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
+        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11]) for d in data["agents"]]
         if "distance_map" in data.keys():
             self.distance_map.set(data["distance_map"])
         # setup with loaded data
@@ -884,16 +873,13 @@
             Returns environment information with distance map information as msgpack object
         """
         grid_data = self.rail.grid.tolist()
-        agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
-        msgpack.packb(agent_static_data, use_bin_type=True)
         distance_map_data = self.distance_map.get()
         msgpack.packb(distance_map_data, use_bin_type=True)
         msg_data = {
             "grid": grid_data,
-            "agents_static": agent_static_data,
             "agents": agent_data,
             "distance_map": distance_map_data}
@@ -7,7 +7,7 @@ import numpy as np
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgentStatic
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.schedule_utils import Schedule
 AgentPosition = Tuple[int, int]
@@ -291,21 +291,15 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
         with open(filename, "rb") as file_in:
             load_data = file_in.read()
         data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
-        # agents are always reset as not moving
-        if len(data['agents_static'][0]) > 5:
-            agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data["agents_static"]]
-        else:
-            agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]]
+        agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11]) for d in data["agents"]]
         # setup with loaded data
-        agents_position = [a.initial_position for a in agents_static]
-        agents_direction = [a.direction for a in agents_static]
-        agents_target = [a.target for a in agents_static]
-        if len(data['agents_static'][0]) > 5:
-            agents_speed = [a.speed_data['speed'] for a in agents_static]
-        else:
-            agents_speed = None
+        agents_position = [a.initial_position for a in agents]
+        agents_direction = [a.direction for a in agents]
+        agents_target = [a.target for a in agents]
+        agents_speed = [a.speed_data['speed'] for a in agents]
+        agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
         return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
                         agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None)
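The indices d[0]…d[11] above follow the order of EnvAgent.to_list() from earlier in this diff; a sketch of that mapping, with a hypothetical helper name:

from flatland.envs.agent_utils import EnvAgent

# Index -> attribute, matching EnvAgent.to_list() and the d[0]..d[11] unpacking above.
FIELDS = ("initial_position", "initial_direction", "direction", "target", "moving",
          "speed_data", "malfunction_data", "handle", "status", "position",
          "old_direction", "old_position")

def agent_from_row(row):
    """Rebuild an EnvAgent from one saved 12-element row (equivalent to EnvAgent(*row))."""
    return EnvAgent(**dict(zip(FIELDS, row)))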
@@ -10,7 +10,7 @@ from numpy import array
 import flatland.utils.rendertools as rt
 from flatland.core.grid.grid4_utils import mirror
-from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
+from flatland.envs.agent_utils import EnvAgent
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, random_rail_generator
 from flatland.envs.rail_generators import complex_rail_generator, empty_rail_generator
@@ -147,7 +147,7 @@ class View(object):
     def redraw(self):
         with self.output_generator:
             self.oRT.set_new_rail()
-            self.model.env.agents = self.model.env.agents_static
+            self.model.env.restart_agents()
             for a in self.model.env.agents:
                 if hasattr(a, 'old_position') is False:
                     a.old_position = a.position
@@ -329,7 +329,7 @@ class Controller(object):
     def rotate_agent(self, event):
         self.log("Rotate Agent:", self.model.selected_agent)
         if self.model.selected_agent is not None:
-            for agent_idx, agent in enumerate(self.model.env.agents_static):
+            for agent_idx, agent in enumerate(self.model.env.agents):
                 if agent is None:
                     continue
                 if agent_idx == self.model.selected_agent:
@@ -339,13 +339,7 @@
     def restart_agents(self, event):
         self.log("Restart Agents - nAgents:", self.view.regen_n_agents.value)
-        if self.model.init_agents_static is not None:
-            self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in
-                                            self.model.init_agents_static]
-            self.model.env.agents = None
-            self.model.init_agents_static = None
-        self.model.env.restart_agents()
-        self.model.env.reset(False, False)
+        self.model.env.reset(False, False)
         self.refresh(event)
     def regenerate(self, event):
@@ -399,7 +393,6 @@ class EditorModel(object):
         self.env_filename = "temp.pkl"
         self.set_env(env)
         self.selected_agent = None
-        self.init_agents_static = None
         self.thread = None
         self.save_image_count = 0
@@ -602,7 +595,6 @@
     def clear(self):
         self.env.rail.grid[:, :] = 0
         self.env.agents = []
-        self.env.agents_static = []
         self.redraw()
@@ -616,7 +608,7 @@
         self.redraw()
     def restart_agents(self):
-        self.env.agents = EnvAgent.list_from_static(self.env.agents_static)
+        self.env.restart_agents()
         self.redraw()
     def set_filename(self, filename):
@@ -634,7 +626,6 @@
         self.env.restart_agents()
         self.env.reset(False, False)
-        self.init_agents_static = None
         self.view.oRT.update_background()
         self.fix_env()
         self.set_env(self.env)
@@ -644,12 +635,7 @@
     def save(self):
         self.log("save to ", self.env_filename, " working dir: ", os.getcwd())
-        temp_store = self.env.agents
-        # clear agents before save , because we want the "init" position of the agent to expert
-        self.env.agents = []
         self.env.save(self.env_filename)
-        # reset agents current (current position)
-        self.env.agents = temp_store
     def save_image(self):
         self.view.oRT.gl.save_image('frame_{:04d}.bmp'.format(self.save_image_count))
@@ -689,7 +675,7 @@
         self.regen_size_height = size
     def find_agent_at(self, cell_row_col):
-        for agent_idx, agent in enumerate(self.env.agents_static):
+        for agent_idx, agent in enumerate(self.env.agents):
             if tuple(agent.position) == tuple(cell_row_col):
                 return agent_idx
         return None
@@ -709,15 +695,14 @@
             # No
             if self.selected_agent is None:
                 # Create a new agent and select it.
-                agent_static = EnvAgentStatic(position=cell_row_col, direction=0, target=cell_row_col, moving=False)
-                self.selected_agent = self.env.add_agent_static(agent_static)
+                agent = EnvAgent(position=cell_row_col, direction=0, target=cell_row_col, moving=False)
+                self.selected_agent = self.env.add_agent(agent)
                 self.view.oRT.update_background()
             else:
                 # Move the selected agent to this cell
-                agent_static = self.env.agents_static[self.selected_agent]
-                agent_static.position = cell_row_col
-                agent_static.old_position = cell_row_col
-                self.env.agents = []
+                agent = self.env.agents[self.selected_agent]
+                agent.position = cell_row_col
+                agent.old_position = cell_row_col
         else:
             # Yes
             # Have they clicked on the agent already selected?
@@ -728,13 +713,11 @@
                 # No - select the agent
                 self.selected_agent = agent_idx
-                self.init_agents_static = None
         self.redraw()
     def add_target(self, rcCell):
         if self.selected_agent is not None:
-            self.env.agents_static[self.selected_agent].target = rcCell
-            self.init_agents_static = None
+            self.env.agents[self.selected_agent].target = rcCell
             self.view.oRT.update_background()
             self.redraw()
@@ -77,7 +77,7 @@ class RenderTool(object):
     def update_background(self):
         # create background map
         targets = {}
-        for agent_idx, agent in enumerate(self.env.agents_static):
+        for agent_idx, agent in enumerate(self.env.agents):
             if agent is None:
                 continue
             targets[tuple(agent.target)] = agent_idx
@@ -93,10 +93,9 @@
         self.new_rail = True
     def plot_agents(self, targets=True, selected_agent=None):
-        color_map = self.gl.get_cmap('hsv',
-                                     lut=max(len(self.env.agents), len(self.env.agents_static) + 1))
+        color_map = self.gl.get_cmap('hsv', lut=(len(self.env.agents) + 1))
-        for agent_idx, agent in enumerate(self.env.agents_static):
+        for agent_idx, agent in enumerate(self.env.agents):
             if agent is None:
                 continue
             color = color_map(agent_idx)
@@ -515,7 +514,7 @@
         # store the targets
         targets = {}
         selected = {}
-        for agent_idx, agent in enumerate(self.env.agents_static):
+        for agent_idx, agent in enumerate(self.env.agents):
             if agent is None:
                 continue
             targets[tuple(agent.target)] = agent_idx
@@ -33,13 +33,12 @@ def test_walker():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                        predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
                   )
-    # reset to initialize agents_static
     env.reset()
     # set initial position and direction for testing...
-    env.agents_static[0].position = (0, 1)
-    env.agents_static[0].direction = 1
-    env.agents_static[0].target = (0, 0)
+    env.agents[0].position = (0, 1)
+    env.agents[0].direction = 1
+    env.agents[0].target = (0, 0)
     # reset to set agents from agents_static
     env.reset(False, False)
@@ -53,13 +53,11 @@ def test_grid8_set_transitions():
 def check_path(env, rail, position, direction, target, expected, rendering=False):
-    agent = env.agents_static[0]
+    agent = env.agents[0]
     agent.position = position  # south dead-end
     agent.direction = direction  # north
     agent.target = target  # east dead-end
     agent.moving = True
-    # reset to set agents from agents_static
-    # env.reset(False, False)
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
         renderer.render_env(show=True, show_observations=False)
@@ -76,8 +74,6 @@ def test_path_exists(rendering=False):
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
-    # reset to initialize agents_static
     env.reset()
     check_path(
@@ -142,8 +138,6 @@ def test_path_not_exists(rendering=False):
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
-    # reset to initialize agents_static
     env.reset()
     check_path(
@@ -103,26 +103,37 @@ def test_reward_function_conflict(rendering=False):
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
     obs_builder: TreeObsForRailEnv = env.obs_builder
-    # initialize agents_static
     env.reset()
     # set the initial position
-    agent = env.agents_static[0]
+    agent = env.agents[0]
     agent.position = (5, 6)  # south dead-end
+    agent.initial_position = (5, 6)  # south dead-end
     agent.direction = 0  # north
+    agent.initial_direction = 0  # north
     agent.target = (3, 9)  # east dead-end
     agent.moving = True
     agent.status = RailAgentStatus.ACTIVE
-    agent = env.agents_static[1]
+    agent = env.agents[1]
     agent.position = (3, 8)  # east dead-end
+    agent.initial_position = (3, 8)  # east dead-end
     agent.direction = 3  # west
+    agent.initial_direction = 3  # west
     agent.target = (6, 6)  # south dead-end
     agent.moving = True
     agent.status = RailAgentStatus.ACTIVE
-    # reset to set agents from agents_static
     env.reset(False, False)
+    env.agents[0].moving = True
+    env.agents[1].moving = True
+    env.agents[0].status = RailAgentStatus.ACTIVE
+    env.agents[1].status = RailAgentStatus.ACTIVE
+    env.agents[0].position = (5, 6)
+    env.agents[1].position = (3, 8)
+    print("\n")
+    print(env.agents[0])
+    print(env.agents[1])
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
@@ -185,28 +196,34 @@ def test_reward_function_waiting(rendering=False):
                   remove_agents_at_target=False
                   )
     obs_builder: TreeObsForRailEnv = env.obs_builder
-    # initialize agents_static
     env.reset()