diff --git a/flatland/action_plan/action_plan.py b/flatland/action_plan/action_plan.py
index b1a56d81839bff62f13a27753a935a19a8d05fe9..96a441299fd68b9d8f0e51e6d3e2b543ec15ba57 100644
--- a/flatland/action_plan/action_plan.py
+++ b/flatland/action_plan/action_plan.py
@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
     def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
         action_plan = []
         agent = self.env.agents[agent_id]
-        minimum_cell_time = agent.speed_counter.max_count
+        minimum_cell_time = agent.speed_counter.max_count + 1
         for path_loop, trainrun_waypoint in enumerate(trainrun):
             trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
diff --git a/flatland/action_plan/action_plan_player.py b/flatland/action_plan/action_plan_player.py
index 074e5590185ff601f9c038e9df4c23fd2f84c455..f9b82ba967392816319a8203b136524a1abba0fa 100644
--- a/flatland/action_plan/action_plan_player.py
+++ b/flatland/action_plan/action_plan_player.py
@@ -30,10 +30,7 @@ class ControllerFromTrainrunsReplayer():
                 assert agent.position == waypoint.position, \
                     "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position, waypoint.position)
-            if agent_id == 1:
-                print(env._elapsed_steps, agent.position, agent.state, agent.speed_counter)
             actions = ctl.act(i)
-            print("actions for {}: {}".format(i, actions))
             obs, all_rewards, done, _ = env.step(actions)
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 456d56a0c58fdbaefa5a2ff4c5e938b74618e1c1..0b5f2a845d525f36456ce3c770fe4453d2c8a0e5 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                        agent.direction)],
                                       num_agents_same_direction=0, num_agents_opposite_direction=0,
                                       num_agents_malfunctioning=agent.malfunction_data['malfunction'],
-                                      speed_min_fractional=agent.speed_counter.speed
+                                      speed_min_fractional=agent.speed_counter.speed,
                                       num_agents_ready_to_depart=0,
                                       childs={})
         #print("root node type:", type(root_node_observation))
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 2915e9be2b1d9537631f0639a6f20a9f05955d17..6a766f35be9d26f3d40623a7ba9c314f410751b3 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -366,9 +366,10 @@ class RailEnv(Environment):
             new_position = get_new_position(position, new_direction)
         else:
             new_position, new_direction = position, direction
-        return new_position, direction
+        return new_position, new_direction

     def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
+        """ Generate State Transitions Signals used in the state machine """
         st_signals = StateTransitionSignals()

         # Malfunction onset - Malfunction starts
@@ -442,9 +443,8 @@ class RailEnv(Environment):
         return action

     def clear_rewards_dict(self):
-        """ Reset the step rewards """
-
-        self.rewards_dict = dict()
+        """ Reset the rewards dictionary """
+        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}

     def get_info_dict(self): # TODO Important : Update this
         info_dict = {
@@ -456,6 +456,22 @@ class RailEnv(Environment):
             'state': {i: agent.state for i, agent in enumerate(self.agents)}
         }
         return info_dict
+
+    def update_step_rewards(self, i_agent):
+        pass
+
+    def end_of_episode_update(self, have_all_agents_ended):
+        if have_all_agents_ended or \
+           ( (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
+
+            for i_agent, agent in enumerate(self.agents):
+
+                reward = self._handle_end_reward(agent)
+                self.rewards_dict[i_agent] += reward
+
+                self.dones[i_agent] = True
+
+            self.dones["__all__"] = True

     def step(self, action_dict_: Dict[int, RailEnvActions]):
         """
@@ -520,6 +536,8 @@ class RailEnv(Environment):
             i_agent = agent.handle
             agent_transition_data = temp_transition_data[i_agent]

+            old_position = agent.position
+
             ## Update positions
             if agent.malfunction_handler.in_malfunction:
                 movement_allowed = False
@@ -544,30 +562,18 @@ class RailEnv(Environment):
             have_all_agents_ended &= (agent.state == TrainState.DONE)

             ## Update rewards
-            # self.update_rewards(i_agent, agent, rail) # TODO : Step Rewards
+            self.update_step_rewards(i_agent)

             ## Update counters (malfunction and speed)
-            agent.speed_counter.update_counter(agent.state)
+            agent.speed_counter.update_counter(agent.state, old_position)
             agent.malfunction_handler.update_counter()

             # Clear old action when starting in new cell
             if agent.speed_counter.is_cell_entry:
                 agent.action_saver.clear_saved_action()
-
-
-        self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
-        if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \
-            or have_all_agents_ended :
-
-            for i_agent, agent in enumerate(self.agents):
-
-                reward = self._handle_end_reward(agent)
-                self.rewards_dict[i_agent] += reward
-
-                self.dones[i_agent] = True
-
-            self.dones["__all__"] = True
+        # Check if episode has ended and update rewards and dones
+        self.end_of_episode_update(have_all_agents_ended)

         return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
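Note on the rail_env.py hunks above: per-step reward logic now flows through the new `update_step_rewards` hook (a no-op in the base class) and episode termination through `end_of_episode_update`, instead of being inlined in `step()`. A minimal sketch of how a subclass could use the hook; the subclass name and the flat -1 penalty are illustrative only and assume `rewards_dict` has already been reset for the current step (e.g. via `clear_rewards_dict`):

```python
from flatland.envs.rail_env import RailEnv


class StepPenaltyRailEnv(RailEnv):
    """Illustrative subclass: charge every agent a constant penalty per step."""

    def update_step_rewards(self, i_agent):
        # Called once per agent inside step(); end_of_episode_update() later
        # adds the terminal reward on top of whatever is accumulated here.
        self.rewards_dict[i_agent] += -1.0
```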
diff --git a/flatland/envs/step_utils/speed_counter.py b/flatland/envs/step_utils/speed_counter.py
index 5bde9c20f98b1b7ed26ad4a8ba3d5791786bd84f..272087817439a659298fa12f71aaa7c982b91bf5 100644
--- a/flatland/envs/step_utils/speed_counter.py
+++ b/flatland/envs/step_utils/speed_counter.py
@@ -4,12 +4,13 @@ from flatland.envs.step_utils.states import TrainState
 class SpeedCounter:
     def __init__(self, speed):
         self.speed = speed
-        self.max_count = int(1/speed)
+        self.max_count = int(1/speed) - 1

-    def update_counter(self, state):
-        if state == TrainState.MOVING:
+    def update_counter(self, state, old_position):
+        # When coming onto the map, do no update speed counter
+        if state == TrainState.MOVING and old_position is not None:
             self.counter += 1
-            self.counter = self.counter % self.max_count
+            self.counter = self.counter % (self.max_count + 1)

     def __repr__(self):
         return f"speed: {self.speed} \
@@ -27,5 +28,5 @@ class SpeedCounter:

     @property
     def is_cell_exit(self):
-        return self.counter == self.max_count - 1
+        return self.counter == self.max_count
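With the new convention, `max_count` is the largest value the per-cell counter reaches rather than the number of ticks per cell, so a full traversal still takes `max_count + 1` ticks; this is why `minimum_cell_time` in action_plan.py gains the `+ 1`. A small behavioural sketch for a speed-0.5 agent; it assumes the counter starts at 0 on cell entry and that `is_cell_entry` tests `counter == 0`, neither of which is shown in these hunks:

```python
from flatland.envs.step_utils.speed_counter import SpeedCounter
from flatland.envs.step_utils.states import TrainState

sc = SpeedCounter(speed=0.5)             # max_count = int(1 / 0.5) - 1 = 1
assert sc.max_count + 1 == int(1 / 0.5)  # traversal time unchanged: 2 ticks per cell

# Entering the map (old_position is None) must not advance the counter.
sc.update_counter(TrainState.MOVING, old_position=None)

# First tick inside a cell: counter goes 0 -> 1, which is now the cell-exit tick.
sc.update_counter(TrainState.MOVING, old_position=(3, 8))
assert sc.is_cell_exit                   # counter == max_count

# The next tick wraps the counter back to 0, i.e. the entry tick of the next cell.
sc.update_counter(TrainState.MOVING, old_position=(3, 9))
assert sc.is_cell_entry                  # assumed to test counter == 0
```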
diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py
index 2b062c4e5a892322bcf8c86e3be66e433254b346..9be4fdf6410b6f63455c6df58da8121012778b85 100644
--- a/tests/test_action_plan.py
+++ b/tests/test_action_plan.py
@@ -9,6 +9,7 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint
 from flatland.envs.line_generators import sparse_line_generator
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 from flatland.utils.simple_rail import make_simple_rail
+from flatland.envs.step_utils.speed_counter import SpeedCounter


 def test_action_plan(rendering: bool = False):
@@ -29,7 +30,7 @@ def test_action_plan(rendering: bool = False):
     env.agents[1].initial_position = (3, 8)
     env.agents[1].initial_direction = Grid4TransitionsEnum.WEST
     env.agents[1].target = (0, 3)
-    env.agents[1].speed_data['speed'] = 0.5 # two
+    env.agents[1].speed_counter = SpeedCounter(speed=0.5)
     env.reset(False, False)
     for handle, agent in enumerate(env.agents):
         print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))
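The test above now sets agent 1's speed through `SpeedCounter` instead of the removed `speed_data` dictionary. A self-contained, pytest-style check (illustrative, not part of this diff) that the new configuration preserves the two-ticks-per-cell timing the old `# two` comment referred to:

```python
from flatland.envs.step_utils.speed_counter import SpeedCounter


def test_half_speed_still_takes_two_ticks_per_cell():
    """Illustrative check: SpeedCounter(speed=0.5) means two ticks per cell."""
    speed_counter = SpeedCounter(speed=0.5)
    # max_count + 1 is the cell traversal time used by ControllerFromTrainruns.
    assert speed_counter.max_count + 1 == 2
```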