Skip to content
Snippets Groups Projects
Commit 47a23b3c authored by spiglerg's avatar spiglerg
Browse files

fixed bad tests bug

parent 340e965b
No related branches found
No related tags found
No related merge requests found
...@@ -191,16 +191,10 @@ class RailEnv(Environment): ...@@ -191,16 +191,10 @@ class RailEnv(Environment):
# for i in range(len(self.agents_handles)): # for i in range(len(self.agents_handles)):
for iAgent in range(self.get_num_agents()): for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent] agent = self.agents[iAgent]
print(agent.speed_data['speed'])
if self.dones[iAgent]: # this agent has already completed... if self.dones[iAgent]: # this agent has already completed...
continue continue
if np.equal(agent.position, agent.target).all():
self.dones[iAgent] = True
else:
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
if iAgent not in action_dict: # no action has been supplied for this agent if iAgent not in action_dict: # no action has been supplied for this agent
action_dict[iAgent] = RailEnvActions.DO_NOTHING action_dict[iAgent] = RailEnvActions.DO_NOTHING
...@@ -288,6 +282,11 @@ class RailEnv(Environment): ...@@ -288,6 +282,11 @@ class RailEnv(Environment):
agent.position = new_position agent.position = new_position
agent.direction = new_direction agent.direction = new_direction
if np.equal(agent.position, agent.target).all():
self.dones[iAgent] = True
else:
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
# Check for end of episode + add global reward to all rewards! # Check for end of episode + add global reward to all rewards!
if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
self.dones["__all__"] = True self.dones["__all__"] = True
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import numpy as np import numpy as np
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import Grid4Transitions from flatland.core.transitions import Grid4Transitions, RailEnvTransitions
from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgent
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.generators import rail_from_GridTransitionMap_generator
...@@ -53,7 +53,7 @@ def test_rail_environment_single_agent(): ...@@ -53,7 +53,7 @@ def test_rail_environment_single_agent():
# | | | # | | |
# \_/\_/ # \_/\_/
transitions = Grid4Transitions([]) transitions = RailEnvTransitions()
vertical_line = cells[1] vertical_line = cells[1]
south_symmetrical_switch = cells[6] south_symmetrical_switch = cells[6]
north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
...@@ -107,6 +107,7 @@ def test_rail_environment_single_agent(): ...@@ -107,6 +107,7 @@ def test_rail_environment_single_agent():
if prev_pos != pos: if prev_pos != pos:
valid_active_actions_done += 1 valid_active_actions_done += 1
# After 6 movements on this railway network, the train should be back # After 6 movements on this railway network, the train should be back
# to its original height on the map. # to its original height on the map.
assert (initial_pos[0] == agent.position[0]) assert (initial_pos[0] == agent.position[0])
...@@ -121,9 +122,9 @@ def test_rail_environment_single_agent(): ...@@ -121,9 +122,9 @@ def test_rail_environment_single_agent():
action = np.random.randint(4) action = np.random.randint(4)
_, _, dones, _ = rail_env.step({0: action}) _, _, dones, _ = rail_env.step({0: action})
done = dones['__all__'] done = dones['__all__']
test_rail_environment_single_agent()
def test_dead_end(): def test_dead_end():
transitions = Grid4Transitions([]) transitions = Grid4Transitions([])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment