Commit 4169a0f1 authored by Dipam Chakraborty

fixes to env.step() direction update

parent e4399082
Pipeline #8454 failed with stages in 5 minutes and 5 seconds
@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
     def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
         action_plan = []
         agent = self.env.agents[agent_id]
-        minimum_cell_time = agent.speed_counter.max_count
+        minimum_cell_time = agent.speed_counter.max_count + 1
         for path_loop, trainrun_waypoint in enumerate(trainrun):
             trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
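The `+ 1` here tracks the new `SpeedCounter` semantics introduced later in this commit: `max_count` becomes `int(1/speed) - 1`, so a train occupies a cell for `max_count + 1` steps. A quick standalone check of that arithmetic (plain-Python sketch, not library code):

for speed in (1.0, 0.5, 0.25):
    max_count = int(1 / speed) - 1        # new SpeedCounter.max_count
    minimum_cell_time = max_count + 1     # steps the train spends in one cell
    assert minimum_cell_time == int(1 / speed)   # same cell time as before the change
    print(f"speed={speed}: {minimum_cell_time} steps per cell")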
@@ -30,10 +30,7 @@ class ControllerFromTrainrunsReplayer():
             assert agent.position == waypoint.position, \
                 "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
                                                                 waypoint.position)
-            if agent_id == 1:
-                print(env._elapsed_steps, agent.position, agent.state, agent.speed_counter)
             actions = ctl.act(i)
-            print("actions for {}: {}".format(i, actions))
             obs, all_rewards, done, _ = env.step(actions)
@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                 agent.direction)],
                                num_agents_same_direction=0, num_agents_opposite_direction=0,
                                num_agents_malfunctioning=agent.malfunction_data['malfunction'],
-                               speed_min_fractional=agent.speed_counter.speed
+                               speed_min_fractional=agent.speed_counter.speed,
+                               num_agents_ready_to_depart=0,
                                childs={})
         #print("root node type:", type(root_node_observation))
@@ -366,9 +366,10 @@ class RailEnv(Environment):
             new_position = get_new_position(position, new_direction)
         else:
             new_position, new_direction = position, direction
-        return new_position, direction
+        return new_position, new_direction

     def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
         """ Generate State Transitions Signals used in the state machine """
         st_signals = StateTransitionSignals()

         # Malfunction onset - Malfunction starts
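This one-line return fix is the change named in the commit message: the branch computes `new_direction`, but the function returned the stale `direction`, so a turning train kept reporting its old heading. A minimal sketch of the consequence, assuming flatland's Grid4 convention (0=N, 1=E, 2=S, 3=W) and its `get_new_position` helper:

from flatland.core.grid.grid4_utils import get_new_position

position, direction = (3, 4), 0            # train heading North
new_direction = 1                          # the transition turns it East
new_position = get_new_position(position, new_direction)

# Before the fix, callers received ((3, 5), 0): the train occupies the
# eastern cell but still claims a North heading, so the next step's
# transition lookup would use the wrong direction.
assert (new_position, new_direction) == ((3, 5), 1)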
@@ -442,9 +443,8 @@ class RailEnv(Environment):
         return action

     def clear_rewards_dict(self):
-        """ Reset the step rewards """
-        self.rewards_dict = dict()
+        """ Reset the rewards dictionary """
+        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}

     def get_info_dict(self): # TODO Important : Update this
         info_dict = {
@@ -456,6 +456,22 @@ class RailEnv(Environment):
             'state': {i: agent.state for i, agent in enumerate(self.agents)}
         }
         return info_dict

+    def update_step_rewards(self, i_agent):
+        pass
+
+    def end_of_episode_update(self, have_all_agents_ended):
+        if have_all_agents_ended or \
+           ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
+            for i_agent, agent in enumerate(self.agents):
+                reward = self._handle_end_reward(agent)
+                self.rewards_dict[i_agent] += reward
+                self.dones[i_agent] = True
+            self.dones["__all__"] = True

     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """
@@ -520,6 +536,8 @@ class RailEnv(Environment):
             i_agent = agent.handle
             agent_transition_data = temp_transition_data[i_agent]
+            old_position = agent.position

             ## Update positions
             if agent.malfunction_handler.in_malfunction:
                 movement_allowed = False
@@ -544,30 +562,18 @@ class RailEnv(Environment):
             have_all_agents_ended &= (agent.state == TrainState.DONE)

             ## Update rewards
-            # self.update_rewards(i_agent, agent, rail) # TODO : Step Rewards
+            self.update_step_rewards(i_agent)

             ## Update counters (malfunction and speed)
-            agent.speed_counter.update_counter(agent.state)
+            agent.speed_counter.update_counter(agent.state, old_position)
             agent.malfunction_handler.update_counter()

             # Clear old action when starting in new cell
             if agent.speed_counter.is_cell_entry:
                 agent.action_saver.clear_saved_action()

-        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
-        if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
-            or have_all_agents_ended :
-            for i_agent, agent in enumerate(self.agents):
-                reward = self._handle_end_reward(agent)
-                self.rewards_dict[i_agent] += reward
-                self.dones[i_agent] = True
-            self.dones["__all__"] = True
+        # Check if the episode has ended and update rewards and dones
+        self.end_of_episode_update(have_all_agents_ended)

         return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
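Extracting the inline end-of-episode block into `end_of_episode_update` (and pre-seeding zeros in `clear_rewards_dict`) makes the termination rule testable on its own and keeps the `+=` accumulation free of missing-key handling. A toy illustration of the extracted logic (hypothetical `MiniEnv`, not flatland code; the end-of-episode reward is stubbed to 0):

class MiniEnv:
    def __init__(self, n_agents, max_episode_steps):
        self.n_agents = n_agents
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = 0
        self.rewards_dict = {i: 0 for i in range(n_agents)}   # pre-seeded zeros
        self.dones = {i: False for i in range(n_agents)}
        self.dones["__all__"] = False

    def end_of_episode_update(self, have_all_agents_ended):
        # Same rule as the diff: all agents DONE, or the step budget is spent.
        if have_all_agents_ended or (
                self._max_episode_steps is not None
                and self._elapsed_steps >= self._max_episode_steps):
            for i in range(self.n_agents):
                self.rewards_dict[i] += 0   # stand-in for _handle_end_reward(agent)
                self.dones[i] = True
            self.dones["__all__"] = True

env = MiniEnv(n_agents=2, max_episode_steps=5)
env._elapsed_steps = 5
env.end_of_episode_update(have_all_agents_ended=False)
assert env.dones["__all__"]   # a timeout alone ends the episode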
@@ -4,12 +4,13 @@ from flatland.envs.step_utils.states import TrainState

 class SpeedCounter:
     def __init__(self, speed):
         self.speed = speed
-        self.max_count = int(1/speed)
+        self.max_count = int(1/speed) - 1

-    def update_counter(self, state):
-        if state == TrainState.MOVING:
+    def update_counter(self, state, old_position):
+        # When coming onto the map, do not update the speed counter
+        if state == TrainState.MOVING and old_position is not None:
             self.counter += 1
-            self.counter = self.counter % self.max_count
+            self.counter = self.counter % (self.max_count + 1)

     def __repr__(self):
         return f"speed: {self.speed} \
@@ -27,5 +28,5 @@ class SpeedCounter:
     @property
     def is_cell_exit(self):
-        return self.counter == self.max_count - 1
+        return self.counter == self.max_count
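The two `SpeedCounter` hunks are a single off-by-one rework: `max_count` shrinks by one, the modulus grows by one, and `is_cell_exit` moves from `max_count - 1` to `max_count`, so the number of steps per cell is unchanged. The new `old_position` argument skips the increment on the step a train first comes onto the map (its previous position is `None`). A standalone mirror of the new arithmetic (plain-Python illustration, not the class itself):

speed = 0.5
max_count = int(1 / speed) - 1          # new value: 1 (was 2 before this commit)

counter = 0                             # fresh cell entry
counter = (counter + 1) % (max_count + 1)
assert counter == max_count             # is_cell_exit: counter == max_count

counter = (counter + 1) % (max_count + 1)
assert counter == 0                     # wraps back to the next cell entry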
@@ -9,6 +9,7 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 from flatland.utils.simple_rail import make_simple_rail
+from flatland.envs.step_utils.speed_counter import SpeedCounter

 def test_action_plan(rendering: bool = False):
@@ -29,7 +30,7 @@ def test_action_plan(rendering: bool = False):
     env.agents[1].initial_position = (3, 8)
     env.agents[1].initial_direction = Grid4TransitionsEnum.WEST
     env.agents[1].target = (0, 3)
-    env.agents[1].speed_data['speed'] = 0.5  # two
+    env.agents[1].speed_counter = SpeedCounter(speed=0.5)
     env.reset(False, False)
     for handle, agent in enumerate(env.agents):
         print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))