From 4169a0f1e6f52cefd5574f31e31e5fb9ace9d4cc Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Fri, 10 Sep 2021 01:34:21 +0530
Subject: [PATCH] fixes to env.step() direction update

---
 flatland/action_plan/action_plan.py        |  2 +-
 flatland/action_plan/action_plan_player.py |  3 --
 flatland/envs/observations.py              |  2 +-
 flatland/envs/rail_env.py                  | 46 ++++++++++++----------
 flatland/envs/step_utils/speed_counter.py  | 11 +++---
 tests/test_action_plan.py                  |  3 +-
 6 files changed, 36 insertions(+), 31 deletions(-)

diff --git a/flatland/action_plan/action_plan.py b/flatland/action_plan/action_plan.py
index b1a56d81..96a44129 100644
--- a/flatland/action_plan/action_plan.py
+++ b/flatland/action_plan/action_plan.py
@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
     def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
         action_plan = []
         agent = self.env.agents[agent_id]
-        minimum_cell_time = agent.speed_counter.max_count
+        minimum_cell_time = agent.speed_counter.max_count + 1
         for path_loop, trainrun_waypoint in enumerate(trainrun):
             trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
diff --git a/flatland/action_plan/action_plan_player.py b/flatland/action_plan/action_plan_player.py
index 074e5590..f9b82ba9 100644
--- a/flatland/action_plan/action_plan_player.py
+++ b/flatland/action_plan/action_plan_player.py
@@ -30,10 +30,7 @@ class ControllerFromTrainrunsReplayer():
                 assert agent.position == waypoint.position, \
                     "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
                                                                     waypoint.position)
-            if agent_id == 1:
-                print(env._elapsed_steps, agent.position, agent.state, agent.speed_counter)
             actions = ctl.act(i)
-            print("actions for {}: {}".format(i, actions))
 
             obs, all_rewards, done, _ = env.step(actions)
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 456d56a0..0b5f2a84 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                          agent.direction)],
                                     num_agents_same_direction=0, num_agents_opposite_direction=0,
                                     num_agents_malfunctioning=agent.malfunction_data['malfunction'],
-                                    speed_min_fractional=agent.speed_counter.speed
+                                    speed_min_fractional=agent.speed_counter.speed,
                                     num_agents_ready_to_depart=0,
                                     childs={})
         #print("root node type:", type(root_node_observation))
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 2915e9be..6a766f35 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -366,9 +366,10 @@ class RailEnv(Environment):
             new_position = get_new_position(position, new_direction)
         else:
             new_position, new_direction = position, direction
-        return new_position, direction
+        return new_position, new_direction
 
     def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
+        """ Generate State Transition Signals used in the state machine """
        st_signals = StateTransitionSignals()
 
         # Malfunction onset - Malfunction starts
@@ -442,9 +443,8 @@ class RailEnv(Environment):
         return action
 
     def clear_rewards_dict(self):
-        """ Reset the step rewards """
-
-        self.rewards_dict = dict()
+        """ Reset the rewards dictionary """
+        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
 
     def get_info_dict(self): # TODO Important : Update this
         info_dict = {
@@ -456,6 +456,22 @@ class RailEnv(Environment):
             'state': {i: agent.state for i, agent in enumerate(self.agents)}
         }
         return info_dict
+
+    def update_step_rewards(self, i_agent):
+        pass
+
+    def end_of_episode_update(self, have_all_agents_ended):
+        if have_all_agents_ended or \
+           ( (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
+
+            for i_agent, agent in enumerate(self.agents):
+
+                reward = self._handle_end_reward(agent)
+                self.rewards_dict[i_agent] += reward
+
+                self.dones[i_agent] = True
+
+            self.dones["__all__"] = True
 
     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """
@@ -520,6 +536,8 @@ class RailEnv(Environment):
             i_agent = agent.handle
             agent_transition_data = temp_transition_data[i_agent]
 
+            old_position = agent.position
+
             ## Update positions
             if agent.malfunction_handler.in_malfunction:
                 movement_allowed = False
@@ -544,30 +562,18 @@ class RailEnv(Environment):
             have_all_agents_ended &= (agent.state == TrainState.DONE)
 
             ## Update rewards
-            # self.update_rewards(i_agent, agent, rail) # TODO : Step Rewards
+            self.update_step_rewards(i_agent)
 
             ## Update counters (malfunction and speed)
-            agent.speed_counter.update_counter(agent.state)
+            agent.speed_counter.update_counter(agent.state, old_position)
             agent.malfunction_handler.update_counter()
 
             # Clear old action when starting in new cell
             if agent.speed_counter.is_cell_entry:
                 agent.action_saver.clear_saved_action()
-
-
-        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
-        if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
-            or have_all_agents_ended :
-
-            for i_agent, agent in enumerate(self.agents):
-
-                reward = self._handle_end_reward(agent)
-                self.rewards_dict[i_agent] += reward
-
-                self.dones[i_agent] = True
-
-            self.dones["__all__"] = True
+        # Check if episode has ended and update rewards and dones
+        self.end_of_episode_update(have_all_agents_ended)
 
         return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
diff --git a/flatland/envs/step_utils/speed_counter.py b/flatland/envs/step_utils/speed_counter.py
index 5bde9c20..27208781 100644
--- a/flatland/envs/step_utils/speed_counter.py
+++ b/flatland/envs/step_utils/speed_counter.py
@@ -4,12 +4,13 @@ from flatland.envs.step_utils.states import TrainState
 class SpeedCounter:
     def __init__(self, speed):
         self.speed = speed
-        self.max_count = int(1/speed)
+        self.max_count = int(1/speed) - 1
 
-    def update_counter(self, state):
-        if state == TrainState.MOVING:
+    def update_counter(self, state, old_position):
+        # When coming onto the map, do not update the speed counter
+        if state == TrainState.MOVING and old_position is not None:
             self.counter += 1
-            self.counter = self.counter % self.max_count
+            self.counter = self.counter % (self.max_count + 1)
 
     def __repr__(self):
         return f"speed: {self.speed} \
@@ -27,5 +28,5 @@ class SpeedCounter:
 
     @property
     def is_cell_exit(self):
-        return self.counter == self.max_count - 1
+        return self.counter == self.max_count
 
diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py
index 2b062c4e..9be4fdf6 100644
--- a/tests/test_action_plan.py
+++ b/tests/test_action_plan.py
@@ -9,6 +9,7 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 from flatland.utils.simple_rail import make_simple_rail
+from flatland.envs.step_utils.speed_counter import SpeedCounter
 
 
 def test_action_plan(rendering: bool = False):
@@ -29,7 +30,7 @@ def test_action_plan(rendering: bool = False):
     env.agents[1].initial_position = (3, 8)
     env.agents[1].initial_direction = Grid4TransitionsEnum.WEST
     env.agents[1].target = (0, 3)
-    env.agents[1].speed_data['speed'] = 0.5  # two
+    env.agents[1].speed_counter = SpeedCounter(speed=0.5)
     env.reset(False, False)
     for handle, agent in enumerate(env.agents):
         print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))
-- 
GitLab
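
A minimal standalone sketch of the SpeedCounter behaviour this patch introduces, for readers who want to see the counter arithmetic in isolation. It is not the flatland class itself: the name SpeedCounterSketch is made up for this illustration, it assumes the counter starts at 0 and that is_cell_entry means counter == 0 (neither detail appears in the hunks above), and it stands in for TrainState.MOVING with the plain string "MOVING".

# Sketch of the patched SpeedCounter semantics (assumptions noted above).
class SpeedCounterSketch:
    def __init__(self, speed):
        self.speed = speed
        self.max_count = int(1 / speed) - 1   # speed 0.5 -> max_count 1, i.e. two steps per cell
        self.counter = 0                      # assumed starting value

    def update_counter(self, state, old_position):
        # Agents still off the map (old_position is None) keep their counter unchanged,
        # so their first on-map step still counts as a cell entry.
        if state == "MOVING" and old_position is not None:
            self.counter = (self.counter + 1) % (self.max_count + 1)

    @property
    def is_cell_entry(self):
        return self.counter == 0

    @property
    def is_cell_exit(self):
        return self.counter == self.max_count


sc = SpeedCounterSketch(speed=0.5)
print(sc.is_cell_entry, sc.is_cell_exit)          # True False  -> first step inside the cell
sc.update_counter("MOVING", old_position=(3, 8))
print(sc.is_cell_entry, sc.is_cell_exit)          # False True  -> next step leaves the cell
sc.update_counter("MOVING", old_position=(3, 7))
print(sc.is_cell_entry, sc.is_cell_exit)          # True False  -> entered the next cell

Under these semantics the action_plan.py change follows directly: max_count is now one less than the number of steps spent in a cell, so the minimum cell time used by ControllerFromTrainruns becomes max_count + 1.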