From d2c0ce1c30a5462d299c414ec3ea5351f8e34d0a Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sat, 22 Jun 2019 12:20:14 -0500
Subject: [PATCH] rail_env: zero position_fraction on reset, compare it to 0 exactly, drop cell_isFree check

---
 flatland/envs/rail_env.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index fccf6b89..d6a7cfac 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -155,6 +155,10 @@ class RailEnv(Environment):
 
         self.restart_agents()
 
+        for iAgent in range(self.get_num_agents()):
+            agent = self.agents[iAgent]
+            agent.speed_data['position_fraction'] = 0.0
+
         self.num_resets += 1
 
         # TODO perhaps dones should be part of each agent.
@@ -192,7 +196,7 @@ class RailEnv(Environment):
         for iAgent in range(self.get_num_agents()):
             agent = self.agents[iAgent]
             if iAgent % 2 == 0:
-                agent.speed_data["speed"] = 1./3.
+                agent.speed_data["speed"] = 1./10.
             if self.dones[iAgent]:  # this agent has already completed...
                 continue
 
@@ -211,7 +215,7 @@ class RailEnv(Environment):
                 # Keep moving
                 action = RailEnvActions.MOVE_FORWARD
 
-            if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] < 0.01:
+            if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.:
                 # Only allow halting an agent on entering new cells.
                 agent.moving = False
                 self.rewards_dict[iAgent] += stop_penalty
@@ -233,12 +237,12 @@ class RailEnv(Environment):
 
             # If the agent can make an action
             action_selected = False
-            if agent.speed_data['position_fraction'] < 0.01:
+            if agent.speed_data['position_fraction'] == 0.:
                 if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
                     cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                         self._check_action_on_agent(action, agent)
 
-                    if all([new_cell_isValid, transition_isValid, cell_isFree]):
+                    if all([new_cell_isValid, transition_isValid]):
                         agent.speed_data['transition_action_on_cellexit'] = action
                         action_selected = True
 
@@ -249,7 +253,7 @@ class RailEnv(Environment):
                             cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
                                 self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent)
 
-                            if all([new_cell_isValid, transition_isValid, cell_isFree]):
+                            if all([new_cell_isValid, transition_isValid]):
                                 agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD
                                 action_selected = True
 
@@ -258,15 +262,17 @@ class RailEnv(Environment):
                                 self.rewards_dict[iAgent] += invalid_action_penalty
                                 agent.moving = False
                                 self.rewards_dict[iAgent] += stop_penalty
+
                                 continue
                         else:
                             # TODO: an invalid action was chosen after entering the cell. The agent cannot move.
                             self.rewards_dict[iAgent] += invalid_action_penalty
                             agent.moving = False
                             self.rewards_dict[iAgent] += stop_penalty
+
                             continue
 
-            if agent.moving and (action_selected or agent.speed_data['position_fraction'] >= 0.01):
+            if agent.moving and (action_selected or agent.speed_data['position_fraction'] > 0.0):
                 agent.speed_data['position_fraction'] += agent.speed_data['speed']
 
             if agent.speed_data['position_fraction'] >= 1.0:
-- 
GitLab