Commit 7b49965c authored by u214892's avatar u214892
Browse files

observations only for active agents

parent ab22a0c6
Pipeline #2329 failed with stages
in 12 minutes and 38 seconds
......@@ -54,11 +54,9 @@ class ObservePredictions(ObservationBuilder):
pos_list.append(self.predictions[a][t][1:3])
# We transform (x,y) coodrinates to a single integer number for simpler comparison
self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
observations = {}
# Collect all the different observation for all the agents
for h in handles:
observations[h] = self.get(h)
observations = super().get_many(handles)
return observations
def get(self, handle: int = 0) -> np.ndarray:
......
from enum import IntEnum
from itertools import starmap
from typing import Tuple
from typing import Tuple, Optional
import numpy as np
from attr import attrs, attrib, Factory
......@@ -9,9 +9,10 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
class RailAgentStatus(IntEnum):
READY_TO_DEPART = 0
ACTIVE = 1
DONE = 2
READY_TO_DEPART = 0 # -> observation
ACTIVE = 1 # -> observation
DONE = 2 # -> observation
DONE_REMOVED = 3 # -> no observation
@attrs
......@@ -21,11 +22,10 @@ class EnvAgentStatic(object):
rather than where it is at the moment.
The target should also be stored here.
"""
position = attrib(type=Tuple[int, int])
initial_position = attrib(type=Tuple[int, int])
direction = attrib(type=Grid4TransitionsEnum)
target = attrib(type=Tuple[int, int])
moving = attrib(default=False, type=bool)
# position = attrib(default=None,type=Optional[Tuple[int, int]])
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
......@@ -42,6 +42,7 @@ class EnvAgentStatic(object):
'moving_before_malfunction': False})))
status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
position = attrib(default=None, type=Optional[Tuple[int, int]])
@classmethod
def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None):
......@@ -75,7 +76,7 @@ class EnvAgentStatic(object):
# I can't find an expression which works on both tuples, lists and ndarrays
# which converts them all to a list of native python ints.
lPos = self.position
lPos = self.initial_position
if type(lPos) is np.ndarray:
lPos = lPos.tolist()
......
......@@ -72,14 +72,16 @@ class TreeObsForRailEnv(ObservationBuilder):
pos_list = []
dir_list = []
for a in handles:
if self.env.agents[a].status != RailAgentStatus.ACTIVE:
continue
pos_list.append(self.predictions[a][t][1:3])
dir_list.append(self.predictions[a][t][3])
self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
self.predicted_dir.update({t: dir_list})
self.max_prediction_depth = len(self.predicted_pos)
observations = {}
for h in handles:
observations[h] = self.get(h)
observations = super().get_many(handles)
return observations
def get(self, handle: int = 0) -> Node:
......@@ -628,12 +630,7 @@ class LocalObsForRailEnv(ObservationBuilder):
in the `handles` list.
"""
observations = {}
if handles is None:
handles = []
for h in handles:
observations[h] = self.get(h)
return observations
return super().get_many(handles)
def field_of_view(self, position, direction, state=None):
# Compute the local field of view for an agent in the environment
......
......@@ -5,6 +5,7 @@ Collection of environment-specific PredictionBuilder.
import numpy as np
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
......@@ -47,6 +48,9 @@ class DummyPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if agent.status != RailAgentStatus.ACTIVE:
# TODO make this generic
continue
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
......@@ -122,6 +126,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if agent.status != RailAgentStatus.ACTIVE:
continue
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
agent_speed = agent.speed_data["speed"]
......
......@@ -224,12 +224,18 @@ class RailEnv(Environment):
self.agents_static.append(agent_static)
return len(self.agents_static) - 1
def set_agent_active(self, handle: int):
agent = self.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART and self.cell_free(agent.initial_position):
agent.status = RailAgentStatus.ACTIVE
agent.position = agent.initial_position
def restart_agents(self):
""" Reset the agents to their starting positions defined in agents_static
"""
self.agents = EnvAgent.list_from_static(self.agents_static)
def reset(self, regen_rail=True, replace_agents=True):
def reset(self, regen_rail=True, replace_agents=True, activate_agents=False):
""" if regen_rail then regenerate the rails.
if replace_agents then regenerate the agents static.
Relies on the rail_generator returning agent_static lists (pos, dir, target)
......@@ -265,8 +271,13 @@ class RailEnv(Environment):
*self.schedule_generator(self.rail, self.get_num_agents(), agents_hints))
self.restart_agents()
for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent]
if activate_agents:
for i_agent in range(self.get_num_agents()):
self.set_agent_active(i_agent)
for i_agent, agent in enumerate(self.agents):
if agent.status != RailAgentStatus.ACTIVE:
continue
# A proportion of agent in the environment will receive a positive malfunction rate
if np.random.random() < self.proportion_malfunctioning_trains:
......@@ -375,8 +386,9 @@ class RailEnv(Environment):
self.dones[i] = True
info_dict = {
'action_required': {i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in
range(self.get_num_agents())},
'action_required': {
i: (agent.status == RailAgentStatus.ACTIVE and agent.speed_data['position_fraction'] == 0.0)
for i, agent in enumerate(self.agents)},
'malfunction': {
i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
},
......@@ -406,8 +418,9 @@ class RailEnv(Environment):
# agent gets active by a MOVE_* action and if c
if agent.status == RailAgentStatus.READY_TO_DEPART:
if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
RailEnvActions.MOVE_FORWARD]: # and self.cell_free(agent.position):
RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position):
agent.status = RailAgentStatus.ACTIVE
agent.position = agent.initial_position
else:
return
......@@ -557,7 +570,10 @@ class RailEnv(Environment):
return cell_free, new_cell_valid, new_direction, new_position, transition_valid
def cell_free(self, position):
return not np.any(np.equal(position, [agent.position for agent in self.agents]).all(1))
agent_positions = [agent.position for agent in self.agents if agent.position is not None]
ret = len(agent_positions) == 0 or not np.any(np.equal(position, agent_positions).all(1))
return ret
def check_action(self, agent: EnvAgent, action: RailEnvActions):
"""
......
......@@ -7,6 +7,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
from flatland.utils.ordered_set import OrderedSet
......@@ -92,7 +93,17 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
shortest_paths = dict()
def _shortest_path_for_agent(agent):
position = agent.position
if agent.status == RailAgentStatus.READY_TO_DEPART:
position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
position = agent.position
elif agent.status == RailAgentStatus.DONE:
if agent.position is not None:
position = agent.target
else:
shortest_paths[agent.handle] = None
return
# todo is this correct? current position?
direction = agent.direction
shortest_paths[agent.handle] = []
distance = math.inf
......
......@@ -235,7 +235,7 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]]
# setup with loaded data
agents_position = [a.position for a in agents_static]
agents_position = [a.initial_position for a in agents_static]
agents_direction = [a.direction for a in agents_static]
agents_target = [a.target for a in agents_static]
if len(data['agents_static'][0]) > 5:
......
......@@ -28,46 +28,48 @@ def test_initial_status():
test_config = ReplayConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
position=None, # not entered grid yet
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.READY_TO_DEPART,
action=RailEnvActions.DO_NOTHING,
reward=0,
status=RailAgentStatus.READY_TO_DEPART
),
Replay(
position=(3, 9), # east dead-end
position=None, # not entered grid yet before step
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.READY_TO_DEPART,
action=RailEnvActions.MOVE_LEFT,
reward=env.start_penalty + env.step_penalty * 0.5, # auto-correction left to forward without penalty!
status=RailAgentStatus.READY_TO_DEPART
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.ACTIVE,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
status=RailAgentStatus.ACTIVE,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
status=RailAgentStatus.ACTIVE,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
status=RailAgentStatus.ACTIVE,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 7),
......@@ -93,7 +95,7 @@ def test_initial_status():
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
action=None,
reward=env.global_reward, # already done
status=RailAgentStatus.DONE
),
......@@ -113,8 +115,10 @@ def test_initial_status():
)
],
initial_position=(3, 9), # east dead-end
initial_direction=Grid4TransitionsEnum.EAST,
target=(3, 5),
speed=0.5
)
run_replay_config(env, [test_config])
run_replay_config(env, [test_config], activate_agents=False)
......@@ -28,274 +28,274 @@ def test_city_generator():
expected_grid_map = np.zeros((50, 50), dtype=env.rail.transitions.get_type())
expected_grid_map[8][16]=4
expected_grid_map[8][17]=5633
expected_grid_map[8][18]=1025
expected_grid_map[8][19]=1025
expected_grid_map[8][20]=17411
expected_grid_map[8][21]=1025
expected_grid_map[8][22]=1025
expected_grid_map[8][23]=1025
expected_grid_map[8][24]=1025
expected_grid_map[8][25]=1025
expected_grid_map[8][26]=4608
expected_grid_map[9][16]=16386
expected_grid_map[9][17]=50211
expected_grid_map[9][18]=1025
expected_grid_map[9][19]=1025
expected_grid_map[9][20]=3089
expected_grid_map[9][21]=1025
expected_grid_map[9][22]=256
expected_grid_map[9][26]=32800
expected_grid_map[10][6]=16386
expected_grid_map[10][7]=1025
expected_grid_map[10][8]=1025
expected_grid_map[10][9]=1025
expected_grid_map[10][10]=1025
expected_grid_map[10][11]=1025
expected_grid_map[10][12]=1025
expected_grid_map[10][13]=1025
expected_grid_map[10][14]=1025
expected_grid_map[10][15]=1025
expected_grid_map[10][16]=33825
expected_grid_map[10][17]=34864
expected_grid_map[10][26]=32800
expected_grid_map[11][6]=32800
expected_grid_map[11][16]=32800
expected_grid_map[11][17]=32800
expected_grid_map[11][26]=32800
expected_grid_map[12][6]=32800
expected_grid_map[12][16]=32800
expected_grid_map[12][17]=32800
expected_grid_map[12][26]=32800
expected_grid_map[13][6]=32800
expected_grid_map[13][16]=32800
expected_grid_map[13][17]=32800
expected_grid_map[13][26]=32800
expected_grid_map[14][6]=32800
expected_grid_map[14][16]=32800
expected_grid_map[14][17]=32800
expected_grid_map[14][26]=32800
expected_grid_map[15][6]=32800
expected_grid_map[15][16]=32800
expected_grid_map[15][17]=32800
expected_grid_map[15][26]=32800
expected_grid_map[16][6]=32800
expected_grid_map[16][16]=32800
expected_grid_map[16][17]=32800
expected_grid_map[16][26]=32800
expected_grid_map[17][6]=32800
expected_grid_map[17][16]=72
expected_grid_map[17][17]=1097
expected_grid_map[17][18]=1025
expected_grid_map[17][19]=1025
expected_grid_map[17][20]=1025
expected_grid_map[17][21]=1025
expected_grid_map[17][22]=1025
expected_grid_map[17][23]=1025
expected_grid_map[17][24]=1025
expected_grid_map[17][25]=1025
expected_grid_map[17][26]=33825
expected_grid_map[17][27]=4608
expected_grid_map[18][6]=32800
expected_grid_map[18][26]=72
expected_grid_map[18][27]=52275
expected_grid_map[18][28]=5633
expected_grid_map[18][29]=17411
expected_grid_map[18][30]=1025
expected_grid_map[18][31]=1025
expected_grid_map[18][32]=256
expected_grid_map[19][6]=32800
expected_grid_map[19][25]=16386
expected_grid_map[19][26]=1025
expected_grid_map[19][27]=2136
expected_grid_map[19][28]=1097
expected_grid_map[19][29]=1097
expected_grid_map[19][30]=5633
expected_grid_map[19][31]=1025
expected_grid_map[19][32]=256
expected_grid_map[20][6]=32800
expected_grid_map[20][25]=32800
expected_grid_map[20][26]=16386
expected_grid_map[20][27]=17411
expected_grid_map[20][28]=1025
expected_grid_map[20][29]=1025
expected_grid_map[20][30]=3089
expected_grid_map[20][31]=1025
expected_grid_map[20][32]=256
expected_grid_map[21][6]=32800
expected_grid_map[21][16]=16386
expected_grid_map[21][17]=1025
expected_grid_map[21][18]=1025
expected_grid_map[21][19]=1025
expected_grid_map[21][20]=1025
expected_grid_map[21][21]=1025
expected_grid_map[21][22]=1025
expected_grid_map[21][23]=1025
expected_grid_map[21][24]=1025
expected_grid_map[21][25]=33825
expected_grid_map[21][26]=33825
expected_grid_map[21][27]=2064
expected_grid_map[22][6]=32800
expected_grid_map[22][16]=32800
expected_grid_map[22][25]=32800
expected_grid_map[22][26]=32800
expected_grid_map[23][6]=32800
expected_grid_map[23][16]=32800
expected_grid_map[23][25]=32800
expected_grid_map[23][26]=32800
expected_grid_map[24][6]=32800
expected_grid_map[24][16]=32800
expected_grid_map[24][25]=32800
expected_grid_map[24][26]=32800
expected_grid_map[25][6]=32800
expected_grid_map[25][16]=32800
expected_grid_map[25][25]=32800
expected_grid_map[25][26]=32800
expected_grid_map[26][6]=32800
expected_grid_map[26][16]=32800
expected_grid_map[26][25]=32800
expected_grid_map[26][26]=32800
expected_grid_map[27][6]=72
expected_grid_map[27][7]=1025
expected_grid_map[27][8]=1025
expected_grid_map[27][9]=17411
expected_grid_map[27][10]=1025
expected_grid_map[27][11]=1025
expected_grid_map[27][12]=1025
expected_grid_map[27][13]=1025
expected_grid_map[27][14]=1025
expected_grid_map[27][15]=4608
expected_grid_map[27][16]=72
expected_grid_map[27][17]=17411
expected_grid_map[27][18]=5633
expected_grid_map[27][19]=1025
expected_grid_map[27][20]=1025
expected_grid_map[27][21]=1025
expected_grid_map[27][22]=1025
expected_grid_map[27][23]=1025
expected_grid_map[27][24]=1025
expected_grid_map[27][25]=33825
expected_grid_map[27][26]=2064
expected_grid_map[28][6]=4
expected_grid_map[28][7]=1025
expected_grid_map[28][8]=1025
expected_grid_map[28][9]=3089
expected_grid_map[28][10]=1025
expected_grid_map[28][11]=1025
expected_grid_map[28][12]=1025
expected_grid_map[28][13]=1025
expected_grid_map[28][14]=4608
expected_grid_map[28][15]=72
expected_grid_map[28][16]=1025
expected_grid_map[28][17]=2136
expected_grid_map[28][18]=1097
expected_grid_map[28][19]=5633
expected_grid_map[28][20]=5633
expected_grid_map[28][21]=1025
expected_grid_map[28][22]=256
expected_grid_map[28][25]=32800
expected_grid_map[29][6]=4
expected_grid_map[29][7]=5633
expected_grid_map[29][8]=20994
expected_grid_map[29][9]=5633
expected_grid_map[29][10]=1025
expected_grid_map[29][11]=1025
expected_grid_map[29][12]=1025
expected_grid_map[29][13]=1025
expected_grid_map[29][14]=1097
expected_grid_map[29][15]=5633
expected_grid_map[29][16]=1025
expected_grid_map[29][17]=17411
expected_grid_map[29][18]=5633
expected_grid_map[29][19]=1097
expected_grid_map[29][20]=3089
expected_grid_map[29][21]=20994
expected_grid_map[29][22]=1025
expected_grid_map[29][23]=1025
expected_grid_map[29][24]=1025
expected_grid_map[29][25]=2064
expected_grid_map[30][6]=16386
expected_grid_map[30][7]=38505
expected_grid_map[30][8]=3089
expected_grid_map[30][9]=1097
expected_grid_map[30][10]=1025
expected_grid_map[30][11]=1025
expected_grid_map[30][12]=256
expected_grid_map[30][15]=32800
expected_grid_map[30][16]=16386
expected_grid_map[30][17]=52275
expected_grid_map[30][18]=1097
expected_grid_map[30][19]=1025
expected_grid_map[30][20]=1025
expected_grid_map[30][21]=3089
expected_grid_map[30][22]=256
expected_grid_map[31][6]=32800
expected_grid_map[31][7]=32800
expected_grid_map[31][15]=72
expected_grid_map[31][16]=37408
expected_grid_map[31][17]=32800
expected_grid_map[32][6]=32800
expected_grid_map[32][7]=32800
expected_grid_map[32][16]=32800
expected_grid_map[32][17]=32800
expected_grid_map[33][6]=32800
expected_grid_map[33][7]=32800
expected_grid_map[33][16]=32800
expected_grid_map[33][17]=32800
expected_grid_map[34][6]=32800
expected_grid_map[34][7]=32800
expected_grid_map[34][16]=32800
expected_grid_map[34][17]=32800
expected_grid_map[35][6]=32800
expected_grid_map[35][7]=32800
expected_grid_map[35][16]=32800
expected_grid_map[35][17]=32800
expected_grid_map[36][6]=32800
expected_grid_map[36][7]=32800
expected_grid_map[36][16]=32800
expected_grid_map[36][17]=32800
expected_grid_map[37][6]=72
expected_grid_map[37][7]=1097
expected_grid_map[37][8]=1025
expected_grid_map[37][9]=1025
expected_grid_map[37][10]=1025
expected_grid_map[37][11]=1025
expected_grid_map[37][12]=1025
expected_grid_map[37][13]=1025
expected_grid_map[37][14]=1025
expected_grid_map[37][15]=1025
expected_grid_map[37][16]=33897
expected_grid_map[37][17]=37408
expected_grid_map[38][16]=72
expected_grid_map[38][17]=52275
expected_grid_map[38][18]=5633
expected_grid_map[38][19]=17411
expected_grid_map[38][20]=1025
expected_grid_map[38][21]=1025
expected_grid_map[38][22]=256
expected_grid_map[39][16]=4
expected_grid_map[39][17]=52275
expected_grid_map[39][18]=3089
expected_grid_map[39][19]=1097
expected_grid_map[39][20]=5633
expected_grid_map[39][21]=1025
expected_grid_map[39][22]=256
expected_grid_map[40][16]=4
expected_grid_map[40][17]=1097
expected_grid_map[40][18]=1025
expected_grid_map[40][19]=1025
expected_grid_map[40][20]=3089
expected_grid_map[40][21]=1025