Commit be854ec7 authored by nimishsantosh107's avatar nimishsantosh107

Reward scheme implemented, TreeObs fixed; changes untested

parent c764e893
import numpy as np
from enum import IntEnum
from itertools import starmap
from typing import Tuple, Optional, NamedTuple
......@@ -7,14 +9,12 @@ from attr import attr, attrs, attrib, Factory
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.schedule_utils import Schedule
class RailAgentStatus(IntEnum):
WAITING = 0
READY_TO_DEPART = 1 # not in grid yet (position is None) -> prediction as if it were at initial position
ACTIVE = 2 # in grid (position is not None), not done -> prediction is remaining path
DONE = 3 # in grid (position is not None), but done -> prediction is stay at target forever
DONE_REMOVED = 4 # removed from grid (position is None) -> prediction is None
CANCELLED = 5 # never departed before the episode ended -> was never placed on the grid
Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
......@@ -29,6 +29,7 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('handle', int),
('status', RailAgentStatus),
('position', Tuple[int, int]),
('arrival_time', int),
('old_direction', Grid4TransitionsEnum),
('old_position', Tuple[int, int])])
......@@ -66,6 +67,9 @@ class EnvAgent:
status = attrib(default=RailAgentStatus.WAITING, type=RailAgentStatus)
position = attrib(default=None, type=Optional[Tuple[int, int]])
# NEW : EnvAgent Reward Handling
arrival_time = attrib(default=None, type=int)
# used in rendering
old_direction = attrib(default=None)
old_position = attrib(default=None)
......@@ -83,6 +87,8 @@ class EnvAgent:
else:
self.status = RailAgentStatus.WAITING
self.arrival_time = None
self.old_position = None
self.old_direction = None
self.moving = False
......@@ -96,12 +102,33 @@ class EnvAgent:
self.malfunction_data['nr_malfunctions'] = 0
self.malfunction_data['moving_before_malfunction'] = False
# NEW : Callables
def get_shortest_path(self, distance_map):
from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix
return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle]
def get_travel_time_on_shortest_path(self, distance_map):
distance = len(self.get_shortest_path(distance_map))
speed = self.speed_data['speed']
return int(np.ceil(distance / speed))
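# Illustrative example (not part of the original code; numbers are made up): a shortest path of
# 20 cells at speed 0.25 gives a travel time of int(np.ceil(20 / 0.25)) = 80 env steps.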
def get_time_remaining_until_latest_arrival(self, elapsed_steps):
return self.latest_arrival - elapsed_steps
def get_current_delay(self, elapsed_steps, distance_map):
'''
Projected slack/delay in env steps:
+ve if the agent is projected to arrive before its latest arrival (slack)
-ve if it is projected to arrive after its latest arrival (delay)
'''
return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \
self.get_travel_time_on_shortest_path(distance_map)
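# Illustrative example (not part of the original code; numbers are made up): with latest_arrival = 100,
# elapsed_steps = 60 and a remaining shortest-path travel time of 30 steps, this returns
# (100 - 60) - 30 = 10 (slack); a remaining travel time of 50 steps would return -10 (delay).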
def to_agent(self) -> Agent:
return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction,
direction=self.direction, target=self.target, moving=self.moving, earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival, speed_data=self.speed_data,
malfunction_data=self.malfunction_data, handle=self.handle, status=self.status,
position=self.position, old_direction=self.old_direction, old_position=self.old_position)
latest_arrival=self.latest_arrival, speed_data=self.speed_data, malfunction_data=self.malfunction_data,
handle=self.handle, status=self.status, position=self.position, arrival_time=self.arrival_time,
old_direction=self.old_direction, old_position=self.old_position)
@classmethod
def from_schedule(cls, schedule: Schedule):
......
......@@ -101,10 +101,13 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
'malfunction']
if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
# count agents that have not yet departed (WAITING or READY_TO_DEPART) at their initial position
if _agent.status in [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.WAITING] and \
_agent.initial_position:
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
self.location_has_agent_ready_to_depart.setdefault(tuple(_agent.initial_position), 0)
self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] += 1
# self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
# self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
observations = super().get_many(handles)
......@@ -192,8 +195,10 @@ class TreeObsForRailEnv(ObservationBuilder):
if handle > len(self.env.agents):
print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
agent = self.env.agents[handle] # TODO: handle being treated as index
if agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.status == RailAgentStatus.WAITING:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
agent_virtual_position = agent.position
......@@ -564,7 +569,9 @@ class GlobalObsForRailEnv(ObservationBuilder):
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
agent = self.env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.status == RailAgentStatus.WAITING:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
agent_virtual_position = agent.position
......@@ -602,7 +609,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
# fifth channel: all ready to depart on this position
if other_agent.status == RailAgentStatus.READY_TO_DEPART:
if other_agent.status == RailAgentStatus.READY_TO_DEPART or other_agent.status == RailAgentStatus.WAITING:
obs_agents_state[other_agent.initial_position][4] += 1
return self.rail_obs, obs_agents_state, obs_targets
......
......@@ -126,8 +126,9 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
prediction_dict = {}
for agent in agents:
if agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.status == RailAgentStatus.WAITING:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
agent_virtual_position = agent.position
......
......@@ -121,15 +121,18 @@ class RailEnv(Environment):
For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
"""
alpha = 1.0
beta = 1.0
# Epsilon to avoid rounding errors
epsilon = 0.01
invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
# NEW : REW: Sparse Reward
alpha = 0
beta = 0
step_penalty = -1 * alpha
global_reward = 1 * beta
invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
stop_penalty = 0 # penalty for stopping a moving agent
start_penalty = 0 # penalty for starting a stopped agent
cancellation_factor = 1
cancellation_time_buffer = 0
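# Editorial summary of the sparse reward scheme applied at episode end in step() below
# (a restatement of that code, not additional behaviour):
#   arrived:                 min(latest_arrival - arrival_time, 0)  -> 0 if on time, negative if late
#   never departed:          -1 * cancellation_factor * (shortest-path travel time + cancellation_time_buffer)
#   departed, never arrived: get_current_delay(...)                 -> negative when arrival is projected late
# In addition, every active agent receives step_penalty * speed each step (zero here, since alpha = 0).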
def __init__(self,
width,
......@@ -367,21 +370,24 @@ class RailEnv(Environment):
# Look at the specific schedule generator used to see where this number comes from
self._max_episode_steps = schedule.max_episode_steps # NEW : updated again below by schedule_time_generator
# Reset distance map - basically initializing
self.distance_map.reset(self.agents, self.rail)
# NEW : Time Schedule Generation
# find agent speeds (needed for max_ep_steps recalculation)
if (type(self.schedule_generator.speed_ratio_map) is dict):
config_speeds = list(self.schedule_generator.speed_ratio_map.keys())
else:
config_speeds = [1.0]
self._max_episode_steps = schedule_time_generator(self.agents, config_speeds, self.distance_map,
self._max_episode_steps, self.np_random, temp_info=optionals)
# Reset distance map - again (just in case if regen_schedule = False)
self.distance_map.reset(self.agents, self.rail)
# Agent Positions Map
self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
# Reset agents to initial states
self.reset_agents()
......@@ -488,21 +494,7 @@ class RailEnv(Environment):
# If we're done, set reward and info_dict and step() is done.
if self.dones["__all__"]:
self.rewards_dict = {}
info_dict = {
"action_required": {},
"malfunction": {},
"speed": {},
"status": {},
}
for i_agent, agent in enumerate(self.agents):
self.rewards_dict[i_agent] = self.global_reward
info_dict["action_required"][i_agent] = False
info_dict["malfunction"][i_agent] = 0
info_dict["speed"][i_agent] = 0
info_dict["status"][i_agent] = agent.status
return self._get_observations(), self.rewards_dict, self.dones, info_dict
raise Exception("Episode is done, cannot call step()")
# Reset the step rewards
self.rewards_dict = dict()
......@@ -570,28 +562,40 @@ class RailEnv(Environment):
# Fix agents that finished their malfunction such that they can perform an action in the next step
self._fix_agent_after_malfunction(agent)
# Check for end of episode + set global reward to all rewards!
if have_all_agents_ended:
self.dones["__all__"] = True
self.rewards_dict = {i: self.global_reward for i in range(self.get_num_agents())}
if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
self.dones["__all__"] = True
# NEW : REW: (END)
if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
or have_all_agents_ended:
for i_agent, agent in enumerate(self.agents):
# NEW : STEP:REW: CANCELLED check / reward (never departed)
if (agent.status == RailAgentStatus.READY_TO_DEPART):
agent.status = RailAgentStatus.CANCELLED
# NEGATIVE REWARD?
# NEW : STEP:REW: Departed but never reached
if (agent.status == RailAgentStatus.ACTIVE):
pass
# NEGATIVE REWARD?
# agent done? (arrival_time is not None)
if (self.dones[i_agent]):
# if agent arrived earlier or on time = 0
# if agent arrived later = -ve reward based on how late
reward = min(agent.latest_arrival - agent.arrival_time, 0)
self.rewards_dict[i_agent] += reward
# Agents not done (arrival_time is None)
else:
# CANCELLED check (never departed)
if (agent.status == RailAgentStatus.READY_TO_DEPART):
reward = -1 * self.cancellation_factor * \
(agent.get_travel_time_on_shortest_path(self.distance_map) + self.cancellation_time_buffer) # buffer defaults to 0
self.rewards_dict[i_agent] += reward
# Departed but never reached
if (agent.status == RailAgentStatus.ACTIVE):
reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
self.rewards_dict[i_agent] += reward
self.dones[i_agent] = True
self.dones["__all__"] = True
if self.record_steps:
self.record_timestep(action_dict_)
......@@ -859,7 +863,7 @@ class RailEnv(Environment):
def _step_agent2_cf(self, i_agent):
agent = self.agents[i_agent]
# NEW : REW: no reward during WAITING...
# NEW : REW: (WAITING) no reward during WAITING...
if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED, RailAgentStatus.WAITING]:
return
......@@ -901,23 +905,16 @@ class RailEnv(Environment):
agent.direction = new_direction
agent.speed_data['position_fraction'] = 0.0
# NEW : REW: Check DONE before / after LA & Check if RUNNING before / after LA
# NEW : STEP: Check DONE before / after LA & Check if RUNNING before / after LA
# has the agent reached its target?
if np.equal(agent.position, agent.target).all():
# arrived before Latest Arrival
if (self._elapsed_steps <= agent.latest_arrival):
agent.status = RailAgentStatus.DONE
self.dones[i_agent] = True
self.active_agents.remove(i_agent)
agent.moving = False
self._remove_agent_from_scene(agent)
else: # arrived after latest arrival
agent.status = RailAgentStatus.DONE
self.dones[i_agent] = True
self.active_agents.remove(i_agent)
agent.moving = False
self._remove_agent_from_scene(agent)
# NEGATIVE REWARD?
# arrived before or after Latest Arrival
agent.status = RailAgentStatus.DONE
self.dones[i_agent] = True
self.active_agents.remove(i_agent)
agent.moving = False
agent.arrival_time = self._elapsed_steps
self._remove_agent_from_scene(agent)
else: # not reached its target and moving
# running before Latest Arrival
......@@ -930,8 +927,7 @@ class RailEnv(Environment):
if (self._elapsed_steps <= agent.latest_arrival):
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
else: # stopped (!move) after Latest Arrival
self.rewards_dict[i_agent] += self.step_penalty * \
agent.speed_data['speed'] # + # NEGATIVE REWARD? per step?
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] # + # NEGATIVE REWARD? per step?
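# Illustrative example (not part of the original code; numbers are made up): an agent with speed 0.5
# that is active for 10 steps accumulates 10 * step_penalty * 0.5 = 5 * step_penalty
# (which is 0 with the current alpha = 0).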
def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D):
"""
......