Commit bc0ee393 authored by u214892

introducing agent status

parent b82c5362
Pipeline #2323 passed with stages in 42 minutes and 8 seconds
+from enum import IntEnum
from itertools import starmap
from typing import Tuple
@@ -7,6 +8,12 @@ from attr import attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
+class RailAgentStatus(IntEnum):
+    READY_TO_DEPART = 0
+    ACTIVE = 1
+    DONE = 2
@attrs
class EnvAgentStatic(object):
""" EnvAgentStatic - Stores initial position, direction and target.
@@ -18,6 +25,7 @@ class EnvAgentStatic(object):
direction = attrib(type=Grid4TransitionsEnum)
target = attrib(type=Tuple[int, int])
moving = attrib(default=False, type=bool)
# position = attrib(default=None,type=Optional[Tuple[int, int]])
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
@@ -33,6 +41,8 @@ class EnvAgentStatic(object):
lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
'moving_before_malfunction': False})))
+status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
@classmethod
def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
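The enum gives each agent an explicit lifecycle instead of overloading the env-level `dones` dict. A minimal standalone sketch of the intended state machine (the `_ALLOWED` table and `advance` helper below are illustrative only, not part of this commit):

```python
from enum import IntEnum

class RailAgentStatus(IntEnum):
    READY_TO_DEPART = 0  # scheduled, but not yet admitted onto the grid
    ACTIVE = 1           # on the grid: moving, stopped, or malfunctioning
    DONE = 2             # reached its target, or the episode hit the step limit

# Transitions the env performs: READY_TO_DEPART -> ACTIVE -> DONE,
# plus READY_TO_DEPART -> DONE when the step limit expires first.
_ALLOWED = {
    RailAgentStatus.READY_TO_DEPART: {RailAgentStatus.ACTIVE, RailAgentStatus.DONE},
    RailAgentStatus.ACTIVE: {RailAgentStatus.DONE},
    RailAgentStatus.DONE: set(),
}

def advance(current, new):
    """Illustrative guard: reject transitions the lifecycle does not allow."""
    if new not in _ALLOWED[current]:
        raise ValueError("illegal transition {} -> {}".format(current.name, new.name))
    return new
```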
@@ -11,11 +11,11 @@ from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_utils import RailAgentStatus
from flatland.utils.ordered_set import OrderedSet
class TreeObsForRailEnv(ObservationBuilder):
Node = collections.namedtuple('Node', 'dist_own_target_encountered '
'dist_other_target_encountered '
'dist_other_agent_encountered '
@@ -296,7 +296,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self._reverse_dir(
self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
potential_conflict = tot_dist
-if self.env.dones[ca] and tot_dist < potential_conflict:
+if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for conflicting paths at distance num_step-1
@@ -307,7 +307,7 @@ class TreeObsForRailEnv(ObservationBuilder):
and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
-if self.env.dones[ca] and tot_dist < potential_conflict:
+if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
# Look for conflicting paths at distance num_step+1
@@ -318,7 +318,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.predicted_dir[post_step][ca])] == 1 \
and tot_dist < potential_conflict: # noqa: E125
potential_conflict = tot_dist
-if self.env.dones[ca] and tot_dist < potential_conflict:
+if self.env.agents[ca].status == RailAgentStatus.DONE and tot_dist < potential_conflict:
potential_conflict = tot_dist
if position in self.location_has_target and position != agent.target:
@@ -621,7 +621,8 @@ class LocalObsForRailEnv(ObservationBuilder):
direction = np.identity(4)[agent.direction]
return local_rail_obs, obs_map_state, obs_other_agents_state, direction
-def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
+def get_many(self, handles: Optional[List[int]] = None) -> Dict[
+    int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
"""
Called whenever an observation has to be computed for the `env` environment, for each agent with handle
in the `handles` list.
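In the three conflict checks above, the observation builder now asks the conflicting agent itself whether it has finished instead of consulting the env-wide `dones` dict. Reduced to a predicate, the change is (illustrative sketch, not library API):

```python
# before: done-ness lived only in the env's dones dict
def is_done_old(env, handle):
    return env.dones[handle]

# after: each agent carries its own lifecycle status
def is_done_new(env, handle):
    return env.agents[handle].status == RailAgentStatus.DONE
```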
@@ -15,7 +15,7 @@ from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
+from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
@@ -354,7 +354,8 @@ class RailEnv(Environment):
info_dict = {
'action_required': {i: False for i in range(self.get_num_agents())},
'malfunction': {i: 0 for i in range(self.get_num_agents())},
-'speed': {i: 0 for i in range(self.get_num_agents())}
+'speed': {i: 0 for i in range(self.get_num_agents())},
+'status': {i: agent.status for i, agent in enumerate(self.agents)}
}
return self._get_observations(), self.rewards_dict, self.dones, info_dict
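With `status` reported per agent in the info dict, a caller can branch on the lifecycle directly after each step. A minimal consumer sketch (env construction omitted; assumes the `env`, `RailEnvActions`, and `RailAgentStatus` names from this diff):

```python
action_dict = {h: RailEnvActions.MOVE_FORWARD for h in range(env.get_num_agents())}
obs, rewards, dones, info = env.step(action_dict)

for handle, status in info['status'].items():
    if status == RailAgentStatus.READY_TO_DEPART:
        # only a MOVE_* action admits the agent onto the grid
        action_dict[handle] = RailEnvActions.MOVE_FORWARD
    elif status == RailAgentStatus.DONE:
        # completed agents ignore further actions
        action_dict[handle] = RailEnvActions.DO_NOTHING
```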
@@ -369,21 +370,18 @@
if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
self.dones["__all__"] = True
-for k in self.dones.keys():
-    self.dones[k] = True
-action_required_agents = {
-    i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in range(self.get_num_agents())
-}
-malfunction_agents = {
-    i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
-}
-speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}
+for i in range(self.get_num_agents()):
+    self.agents[i].status = RailAgentStatus.DONE
+    self.dones[i] = True
info_dict = {
-'action_required': action_required_agents,
-'malfunction': malfunction_agents,
-'speed': speed_agents
+'action_required': {i: self.agents[i].speed_data['position_fraction'] == 0.0 for i in
+                    range(self.get_num_agents())},
+'malfunction': {
+    i: self.agents[i].malfunction_data['malfunction'] for i in range(self.get_num_agents())
+},
+'speed': {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())},
+'status': {i: agent.status for i, agent in enumerate(self.agents)}
}
return self._get_observations(), self.rewards_dict, self.dones, info_dict
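Note that the step-limit branch forces every agent to `DONE` whether or not it reached its target, so `DONE` by itself does not mean success. A hedged sketch of how a caller might tell the two cases apart (`reached_target` is a hypothetical helper, not part of this commit):

```python
def reached_target(env, handle):
    # hypothetical helper: did the agent actually arrive?
    agent = env.agents[handle]
    return tuple(agent.position) == tuple(agent.target)

obs, rewards, dones, info = env.step(action_dict)
if dones["__all__"]:
    timed_out = [h for h, s in info['status'].items()
                 if s == RailAgentStatus.DONE and not reached_target(env, h)]
```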
@@ -401,10 +399,18 @@
action_dict_ : Dict[int,RailEnvActions]
"""
-if self.dones[i_agent]:  # this agent has already completed...
+agent = self.agents[i_agent]
+if agent.status == RailAgentStatus.DONE:  # this agent has already completed...
    return
-agent = self.agents[i_agent]
+# agent becomes active through a MOVE_* action (the cell-free condition is still commented out below)
+if agent.status == RailAgentStatus.READY_TO_DEPART:
+    if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT,
+                  RailEnvActions.MOVE_FORWARD]:  # and self.cell_free(agent.position):
+        agent.status = RailAgentStatus.ACTIVE
+    else:
+        return
agent.old_direction = agent.direction
agent.old_position = agent.position
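The step prologue above is a small gate: `DONE` agents return immediately, and a `READY_TO_DEPART` agent is admitted only by a movement action. Factored into a standalone function, the logic reads as follows (`try_activate` is an illustrative helper, not a real `RailEnv` method):

```python
MOVE_ACTIONS = (RailEnvActions.MOVE_LEFT,
                RailEnvActions.MOVE_RIGHT,
                RailEnvActions.MOVE_FORWARD)

def try_activate(agent, action):
    """Return True if this agent should be stepped further on this tick."""
    if agent.status == RailAgentStatus.DONE:
        return False  # completed agents ignore all actions
    if agent.status == RailAgentStatus.READY_TO_DEPART:
        if action in MOVE_ACTIONS:
            agent.status = RailAgentStatus.ACTIVE  # enters the grid now
            return True
        return False  # DO_NOTHING / STOP_MOVING keep it off the grid
    return True  # ACTIVE agents proceed as before
```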
@@ -497,6 +503,7 @@
# has the agent reached its target?
if np.equal(agent.position, agent.target).all():
+agent.status = RailAgentStatus.DONE
self.dones[i_agent] = True
agent.moving = False
else:
@@ -543,9 +550,12 @@
# Check the new position is not the same as any of the existing agent positions
# (including itself, for simplicity, since it is moving)
-cell_free = not np.any(np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))
+cell_free = self.cell_free(new_position)
return cell_free, new_cell_valid, new_direction, new_position, transition_valid
+def cell_free(self, position):
+    return not np.any(np.equal(position, [agent.position for agent in self.agents]).all(1))
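The occupancy test is factored into `cell_free` so the still-commented-out departure check can reuse it. As the comment above notes, it compares the candidate cell against every agent's position, including the moving agent's own. A usage sketch (fragment only; `get_new_position` is the grid helper imported earlier in this diff):

```python
# would the agent's next cell collide with any agent currently on the grid?
new_position = get_new_position(agent.position, new_direction)
if env.cell_free(new_position):
    pass  # the move is admissible with respect to occupancy
```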
def check_action(self, agent: EnvAgent, action: RailEnvActions):
"""
@@ -591,7 +601,7 @@
return self.obs_dict
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
-return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row,col))
+return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
def get_full_state_msg(self):
grid_data = self.rail.grid.tolist()
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
np.random.seed(1)
def test_initial_status():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
set_penalties_for_replay(env)
test_config = ReplayConfig(
replay=[
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
reward=0,
status=RailAgentStatus.READY_TO_DEPART
),
Replay(
position=(3, 9), # east dead-end
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_LEFT,
reward=env.start_penalty + env.step_penalty * 0.5, # auto-correction left to forward without penalty!
status=RailAgentStatus.READY_TO_DEPART
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_RIGHT,
reward=env.step_penalty * 0.5, # wrong action is corrected to forward without penalty!
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.global_reward, # done
status=RailAgentStatus.ACTIVE
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.global_reward, # already done
status=RailAgentStatus.DONE
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.global_reward, # already done
status=RailAgentStatus.DONE
),
Replay(
position=(3, 5),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.global_reward, # already done
status=RailAgentStatus.DONE
)
],
target=(3, 5),
speed=0.5
)
run_replay_config(env, [test_config])
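The expected rewards in the replay follow the penalty model: the first admitted movement is charged `start_penalty + step_penalty * speed`, every subsequent tick en route costs `step_penalty * speed`, and a finished agent collects `global_reward`. A worked check with hypothetical penalty values (the real values are configured by `set_penalties_for_replay`):

```python
start_penalty = -1.0  # hypothetical values, for illustration only
step_penalty = -1.0
speed = 0.5

first_move = start_penalty + step_penalty * speed  # -1.5, the MOVE_LEFT row above
cruising = step_penalty * speed                    # -0.5, every later moving row
```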
@@ -2,14 +2,15 @@ import random
from typing import Dict, List
import numpy as np
-from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
class SingleAgentNavigationObs(ObservationBuilder):
@@ -271,17 +272,19 @@ def test_initial_malfunction_stop_moving():
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
-action=RailEnvActions.DO_NOTHING,
+action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3,
malfunction=3,
-reward=env.step_penalty  # full step penalty when stopped
+reward=env.step_penalty,  # full step penalty when stopped
+status=RailAgentStatus.READY_TO_DEPART
),
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2,
-reward=env.step_penalty  # full step penalty when stopped
+reward=env.step_penalty,  # full step penalty when stopped
+status=RailAgentStatus.ACTIVE
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action STOP_MOVING, agent should restart without moving
@@ -291,7 +294,8 @@ def test_initial_malfunction_stop_moving():
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.STOP_MOVING,
malfunction=1,
-reward=env.step_penalty  # full step penalty while stopped
+reward=env.step_penalty,  # full step penalty while stopped
+status=RailAgentStatus.ACTIVE
),
# we have stopped and do nothing --> should stand still
Replay(
@@ -299,7 +303,8 @@ def test_initial_malfunction_stop_moving():
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0,
-reward=env.step_penalty  # full step penalty while stopped
+reward=env.step_penalty,  # full step penalty while stopped
+status=RailAgentStatus.ACTIVE
),
# we start to move forward --> should go to next cell now
Replay(
@@ -307,14 +312,16 @@ def test_initial_malfunction_stop_moving():
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
-reward=env.start_penalty + env.step_penalty * 1.0  # full step penalty while stopped
+reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
+status=RailAgentStatus.ACTIVE
),
Replay(
position=(28, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
-reward=env.step_penalty * 1.0  # full step penalty while stopped
+reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
+status=RailAgentStatus.ACTIVE
)
],
speed=env.agents[0].speed_data['speed'],
@@ -363,17 +370,19 @@ def test_initial_malfunction_do_nothing():
replay=[Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
-action=RailEnvActions.DO_NOTHING,
+action=RailEnvActions.MOVE_FORWARD,
set_malfunction=3,
malfunction=3,
-reward=env.step_penalty  # full step penalty while malfunctioning
+reward=env.step_penalty,  # full step penalty while malfunctioning
+status=RailAgentStatus.READY_TO_DEPART
),
Replay(
position=(28, 5),
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=2,
-reward=env.step_penalty  # full step penalty while malfunctioning
+reward=env.step_penalty,  # full step penalty while malfunctioning
+status=RailAgentStatus.ACTIVE
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving
@@ -383,7 +392,8 @@ def test_initial_malfunction_do_nothing():
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=1,
-reward=env.step_penalty  # full step penalty while stopped
+reward=env.step_penalty,  # full step penalty while stopped
+status=RailAgentStatus.ACTIVE
),
# we haven't started moving yet --> stay here
Replay(
@@ -391,7 +401,8 @@ def test_initial_malfunction_do_nothing():
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.DO_NOTHING,
malfunction=0,
-reward=env.step_penalty  # full step penalty while stopped
+reward=env.step_penalty,  # full step penalty while stopped
+status=RailAgentStatus.ACTIVE
),
# we start to move forward --> should go to next cell now
Replay(
@@ -399,14 +410,16 @@ def test_initial_malfunction_do_nothing():
direction=Grid4TransitionsEnum.EAST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
-reward=env.start_penalty + env.step_penalty * 1.0  # start penalty + step penalty for speed 1.0
+reward=env.start_penalty + env.step_penalty * 1.0,  # start penalty + step penalty for speed 1.0
+status=RailAgentStatus.ACTIVE
),
Replay(
position=(28, 4),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
-reward=env.step_penalty * 1.0  # step penalty for speed 1.0
+reward=env.step_penalty * 1.0,  # step penalty for speed 1.0
+status=RailAgentStatus.ACTIVE
)
],
speed=env.agents[0].speed_data['speed'],
@@ -5,7 +5,7 @@ import numpy as np
from attr import attrs, attrib
from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.rail_env import RailEnvActions, RailEnv
from flatland.utils.rendertools import RenderTool
@@ -18,6 +18,7 @@ class Replay(object):
malfunction = attrib(default=0, type=int)
set_malfunction = attrib(default=None, type=Optional[int])
reward = attrib(default=None, type=Optional[float])
+status = attrib(default=None, type=Optional[RailAgentStatus])
@attrs
@@ -47,10 +48,12 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
- position, direction before step are verified
- optionally, set_malfunction is applied
- malfunction is verified
- status is verified (optionally)
*After each step*
- reward is verified after step
Parameters
----------
env
@@ -77,8 +80,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
def _assert(a, actual, expected, msg):
    assert np.allclose(actual, expected), "[{}] agent {} {}: actual={}, expected={}".format(step, a, msg,
                                                                                            actual,
                                                                                            expected)
action_dict = {}
@@ -88,6 +91,8 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
_assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction')
+if replay.status is not None:
+    _assert(a, agent.status, replay.status, 'status')
if replay.action is not None:
assert info_dict['action_required'][a] == True, "[{}] agent {} expecting action_required={}".format(
@@ -109,5 +114,3 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
for a, test_config in enumerate(test_configs):
replay = test_config.replay[step]
_assert(a, rewards_dict[a], replay.reward, 'reward')
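Putting the new field to work: a replay entry can now pin the expected lifecycle state, which `run_replay_config` checks before applying the action. A minimal sketch using only names from this diff (the surrounding `ReplayConfig` plumbing is omitted):

```python
entry = Replay(
    position=(3, 9),
    direction=Grid4TransitionsEnum.EAST,
    action=RailEnvActions.DO_NOTHING,
    reward=0,
    status=RailAgentStatus.READY_TO_DEPART,  # skipped by the check if left at None
)
```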