Commit e3c821e5 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

agent_positions and docstrings and other cleanups

parent 08a30dbb
......@@ -7,13 +7,11 @@ from typing import List, Optional, Dict, Tuple
import numpy as np
from gym.utils import seeding
from dataclasses import dataclass
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
......@@ -30,8 +28,8 @@ from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils import transition_utils
from flatland.envs.step_utils import action_preprocessing
from flatland.envs.step_utils import env_utils
class RailEnv(Environment):
"""
......@@ -110,7 +108,6 @@ class RailEnv(Environment):
remove_agents_at_target=True,
random_seed=1,
record_steps=False,
close_following=True
):
"""
Environment init.
......@@ -178,16 +175,12 @@ class RailEnv(Environment):
self.remove_agents_at_target = remove_agents_at_target
self.rewards = [0] * number_of_agents
self.done = False
self.obs_builder = obs_builder_object
self.obs_builder.set_env(self)
self._max_episode_steps: Optional[int] = None
self._elapsed_steps = 0
self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
self.obs_dict = {}
self.rewards_dict = {}
self.dev_obs_dict = {}
......@@ -205,10 +198,7 @@ class RailEnv(Environment):
if self.random_seed:
self._seed(seed=random_seed)
self.valid_positions = None
# global numpy array of agents position, True means that there is an agent at that cell
self.agent_positions: np.ndarray = np.full((height, width), False)
self.agent_positions = None
# save episode timesteps ie agent positions, orientations. (not yet actions / observations)
self.record_steps = record_steps # whether to save timesteps
......@@ -216,11 +206,8 @@ class RailEnv(Environment):
self.cur_episode = []
self.list_actions = [] # save actions in here
self.close_following = close_following # use close following logic
self.motionCheck = ac.MotionCheck()
self.agent_helpers = {}
def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
random.seed(seed)
......@@ -229,7 +216,7 @@ class RailEnv(Environment):
# no more agent_handles
def get_agent_handles(self):
return range(self.get_num_agents())
def get_num_agents(self) -> int:
return len(self.agents)
......@@ -337,9 +324,6 @@ class RailEnv(Environment):
agent.latest_arrival = timetable.latest_arrivals[agent_i]
else:
self.distance_map.reset(self.agents, self.rail)
# Agent Positions Map
self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
# Reset agents to initial states
self.reset_agents()
......@@ -347,7 +331,10 @@ class RailEnv(Environment):
self.num_resets += 1
self._elapsed_steps = 0
# TODO perhaps dones should be part of each agent.
# Agent positions map
self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
self._update_agent_positions_map(ignore_old_positions=False)
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
# Reset the state of the observation builder with the new environment
......@@ -362,14 +349,16 @@ class RailEnv(Environment):
if hasattr(self, "renderer") and self.renderer is not None:
self.renderer = None
return observation_dict, info_dict
def apply_action_independent(self, action, rail, position, direction):
if action.is_moving_action():
new_direction, _ = transition_utils.check_action(action, position, direction, rail)
new_position = get_new_position(position, new_direction)
else:
new_position, new_direction = position, direction
return new_position, new_direction
def _update_agent_positions_map(self, ignore_old_positions=True):
""" Update the agent_positions array for agents that changed positions """
for agent in self.agents:
if not ignore_old_positions or agent.old_position != agent.position:
self.agent_positions[agent.position] = agent.handle
if agent.old_position is not None:
self.agent_positions[agent.old_position] = -1
def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
""" Generate State Transitions Signals used in the state machine """
......@@ -391,7 +380,7 @@ class RailEnv(Environment):
st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed
# Target Reached
st_signals.target_reached = fast_position_equal(agent.position, agent.target)
st_signals.target_reached = env_utils.fast_position_equal(agent.position, agent.target)
# Movement conflict - Multiple trains trying to move into same cell
# If speed counter is not in cell exit, the train can enter the cell
......@@ -449,11 +438,18 @@ class RailEnv(Environment):
""" Reset the rewards dictionary """
self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
def get_info_dict(self): # TODO Important : Update this
def get_info_dict(self):
"""
Returns dictionary of infos for all agents
dict_keys : action_required -
malfunction - Counter value for malfunction > 0 means train is in malfunction
speed - Speed of the train
state - State from the trains's state machine
"""
info_dict = {
'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)},
'malfunction': {
i: agent.malfunction_data['malfunction'] for i, agent in enumerate(self.agents)
i: agent.malfunction_handler.malfunction_down_counter for i, agent in enumerate(self.agents)
},
'speed': {i: agent.speed_counter.speed for i, agent in enumerate(self.agents)},
'state': {i: agent.state for i, agent in enumerate(self.agents)}
......@@ -461,9 +457,16 @@ class RailEnv(Environment):
return info_dict
def update_step_rewards(self, i_agent):
"""
Update the rewards dict for agent id i_agent for every timestep
"""
pass
def end_of_episode_update(self, have_all_agents_ended):
"""
Updates made when episode ends
Parameters: have_all_agents_ended - Indicates if all agents have reached done state
"""
if have_all_agents_ended or \
( (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
......@@ -477,6 +480,7 @@ class RailEnv(Environment):
self.dones["__all__"] = True
def handle_done_state(self, agent):
""" Any updates to agent to be made in Done state """
if agent.state == TrainState.DONE:
agent.arrival_time = self._elapsed_steps
if self.remove_agents_at_target:
......@@ -528,7 +532,7 @@ class RailEnv(Environment):
elif agent.action_saver.is_action_saved and position_update_allowed:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
new_position, new_direction = self.apply_action_independent(saved_action,
new_position, new_direction = env_utils.apply_action_independent(saved_action,
self.rail,
agent.position,
agent.direction)
......@@ -536,7 +540,7 @@ class RailEnv(Environment):
else:
new_position, new_direction = agent.position, agent.direction
temp_transition_data[i_agent] = AgentTransitionData(position=new_position,
temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
direction=new_direction,
preprocessed_action=preprocessed_action)
......@@ -571,7 +575,7 @@ class RailEnv(Environment):
agent.state_machine.step()
# Off map or on map state and position should match
state_position_sync_check(agent.state, agent.position, agent.handle)
env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
......@@ -593,11 +597,14 @@ class RailEnv(Environment):
# Check if episode has ended and update rewards and dones
self.end_of_episode_update(have_all_agents_ended)
self._update_agent_positions_map
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
def record_timestep(self, dActions):
''' Record the positions and orientations of all agents in memory, in the cur_episode
'''
"""
Record the positions and orientations of all agents in memory, in the cur_episode
"""
list_agents_state = []
for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent]
......@@ -610,7 +617,7 @@ class RailEnv(Environment):
# print("pos:", pos, type(pos[0]))
list_agents_state.append([
*pos, int(agent.direction),
agent.malfunction_data["malfunction"],
agent.malfunction_handler.malfunction_down_counter,
int(agent.status),
int(agent.position in self.motionCheck.svDeadlocked)
])
......@@ -620,11 +627,7 @@ class RailEnv(Environment):
def _get_observations(self):
"""
Utility which returns the observations for an agent with respect to environment
Returns
------
Dict object
Utility which returns the dictionary of observations for an agent with respect to environment
"""
# print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}")
self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents())))
......@@ -633,15 +636,6 @@ class RailEnv(Environment):
def get_valid_directions_on_grid(self, row: int, col: int) -> List[int]:
"""
Returns directions in which the agent can move
Parameters:
---------
row : int
col : int
Returns:
-------
List[int]
"""
return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col))
......@@ -669,9 +663,10 @@ class RailEnv(Environment):
"""
return agent.malfunction_handler.in_malfunction
def save(self, filename):
print("deprecated call to env.save() - pls call RailEnvPersister.save()")
print("DEPRECATED call to env.save() - pls call RailEnvPersister.save()")
persistence.RailEnvPersister.save(self, filename)
def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
......@@ -747,31 +742,4 @@ class RailEnv(Environment):
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
@dataclass(repr=True)
class AgentTransitionData:
""" Class for keeping track of temporary agent data for position update """
position : Tuple[int, int]
direction : Grid4Transitions
preprocessed_action : RailEnvActions
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None: # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None
return False
else:
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def state_position_sync_check(state, position, i_agent):
if state.is_on_map_state() and position is None:
raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
i_agent, str(state), str(position) ))
elif state.is_off_map_state() and position is not None:
raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
i_agent, str(state), str(position) ))
self.renderer = None
\ No newline at end of file
from dataclasses import dataclass
from typing import Tuple
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.step_utils import transition_utils
from flatland.envs.rail_env_action import RailEnvActions
from flatland.core.grid.grid4 import Grid4Transitions
@dataclass(repr=True)
class AgentTransitionData:
""" Class for keeping track of temporary agent data for position update """
position : Tuple[int, int]
direction : Grid4Transitions
preprocessed_action : RailEnvActions
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None:
return False
else:
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def apply_action_independent(action, rail, position, direction):
""" Apply the action on the train regardless of locations of other trains
Checks for valid cells to move and valid rail transitions
---------------------------------------------------------------------
Parameters: action - Action to execute
rail - Flatland env.rail object
position - current position of the train
direction - current direction of the train
---------------------------------------------------------------------
Returns: new_position - New position after applying the action
new_direction - New direction after applying the action
"""
if action.is_moving_action():
new_direction, _ = transition_utils.check_action(action, position, direction, rail)
new_position = get_new_position(position, new_direction)
else:
new_position, new_direction = position, direction
return new_position, new_direction
def state_position_sync_check(state, position, i_agent):
""" Check for whether on map and off map states are matching with position """
if state.is_on_map_state() and position is None:
raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format(
i_agent, str(state), str(position) ))
elif state.is_off_map_state() and position is not None:
raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format(
i_agent, str(state), str(position) ))
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment