Skip to content
Snippets Groups Projects
Commit d2c0ce1c authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

minor adjustments

parent a8d5dc4f
No related branches found
No related tags found
No related merge requests found
...@@ -155,6 +155,10 @@ class RailEnv(Environment): ...@@ -155,6 +155,10 @@ class RailEnv(Environment):
self.restart_agents() self.restart_agents()
for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent]
agent.speed_data['position_fraction'] = 0.0
self.num_resets += 1 self.num_resets += 1
# TODO perhaps dones should be part of each agent. # TODO perhaps dones should be part of each agent.
...@@ -192,7 +196,7 @@ class RailEnv(Environment): ...@@ -192,7 +196,7 @@ class RailEnv(Environment):
for iAgent in range(self.get_num_agents()): for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent] agent = self.agents[iAgent]
if iAgent % 2 == 0: if iAgent % 2 == 0:
agent.speed_data["speed"] = 1./3. agent.speed_data["speed"] = 1./10.
if self.dones[iAgent]: # this agent has already completed... if self.dones[iAgent]: # this agent has already completed...
continue continue
...@@ -211,7 +215,7 @@ class RailEnv(Environment): ...@@ -211,7 +215,7 @@ class RailEnv(Environment):
# Keep moving # Keep moving
action = RailEnvActions.MOVE_FORWARD action = RailEnvActions.MOVE_FORWARD
if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] < 0.01: if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
# Only allow halting an agent on entering new cells. # Only allow halting an agent on entering new cells.
agent.moving = False agent.moving = False
self.rewards_dict[iAgent] += stop_penalty self.rewards_dict[iAgent] += stop_penalty
...@@ -233,12 +237,12 @@ class RailEnv(Environment): ...@@ -233,12 +237,12 @@ class RailEnv(Environment):
# If the agent can make an action # If the agent can make an action
action_selected = False action_selected = False
if agent.speed_data['position_fraction'] < 0.01: if agent.speed_data['position_fraction'] == 0.:
if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self._check_action_on_agent(action, agent) self._check_action_on_agent(action, agent)
if all([new_cell_isValid, transition_isValid, cell_isFree]): if all([new_cell_isValid, transition_isValid]):
agent.speed_data['transition_action_on_cellexit'] = action agent.speed_data['transition_action_on_cellexit'] = action
action_selected = True action_selected = True
...@@ -249,7 +253,7 @@ class RailEnv(Environment): ...@@ -249,7 +253,7 @@ class RailEnv(Environment):
cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
if all([new_cell_isValid, transition_isValid, cell_isFree]): if all([new_cell_isValid, transition_isValid]):
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
action_selected = True action_selected = True
...@@ -258,15 +262,17 @@ class RailEnv(Environment): ...@@ -258,15 +262,17 @@ class RailEnv(Environment):
self.rewards_dict[iAgent] += invalid_action_penalty self.rewards_dict[iAgent] += invalid_action_penalty
agent.moving = False agent.moving = False
self.rewards_dict[iAgent] += stop_penalty self.rewards_dict[iAgent] += stop_penalty
continue continue
else: else:
# TODO: an invalid action was chosen after entering the cell. The agent cannot move. # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
self.rewards_dict[iAgent] += invalid_action_penalty self.rewards_dict[iAgent] += invalid_action_penalty
agent.moving = False agent.moving = False
self.rewards_dict[iAgent] += stop_penalty self.rewards_dict[iAgent] += stop_penalty
continue continue
if agent.moving and (action_selected or agent.speed_data['position_fraction'] >= 0.01): if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0):
agent.speed_data['position_fraction'] += agent.speed_data['speed'] agent.speed_data['position_fraction'] += agent.speed_data['speed']
if agent.speed_data['position_fraction'] >= 1.0: if agent.speed_data['position_fraction'] >= 1.0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment