Commit 6db05ca9 authored by Dipam Chakraborty

WIP test fixes

parent 215595bb
Pipeline #8458 failed in 3 minutes and 3 seconds
@@ -216,15 +216,12 @@ class EnvAgent:
agents.append(agent)
return agents
def _set_state(self, state):
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
def __str__(self):
return f"\n \
handle(agent index): {self.handle} \n \
initial_position: {self.initial_position} initial_direction: {self.initial_direction} \n \
position: {self.position} direction: {self.direction} target: {self.target} \n \
old_position: {self.old_position} old_direction {self.old_direction} \n \
earliest_departure: {self.earliest_departure} latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_data: {self.malfunction_data} \n \
@@ -235,6 +232,14 @@ class EnvAgent:
def state(self):
return self.state_machine.state
@state.setter
def state(self, state):
self._set_state(state)
def _set_state(self, state):
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
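Note: a minimal usage sketch of the new property setter (assuming an existing EnvAgent instance named `agent`); assignment now routes through `_set_state`, which emits the warning and delegates to the state machine:

    agent.state = TrainState.MOVING  # state.setter -> _set_state -> state_machine.set_state (emits the warning)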
@@ -261,7 +261,7 @@ class RailEnv(Environment):
False: Agent cannot provide an action
"""
return agent.state == TrainState.READY_TO_DEPART or \
(agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
( agent.state.is_on_map_state() and agent.speed_counter.is_cell_entry )
def reset(self, regenerate_rail: bool = True, regenerate_schedule: bool = True, *,
random_seed: bool = None) -> Tuple[Dict, Dict]:
@@ -385,13 +385,14 @@ class RailEnv(Environment):
st_signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING)
# Valid Movement action Given
st_signals.valid_movement_action_given = preprocessed_action.is_moving_action()
st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed
# Target Reached
st_signals.target_reached = fast_position_equal(agent.position, agent.target)
# Movement conflict - Multiple trains trying to move into same cell
st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information
# If speed counter is not in cell exit, the train can enter the cell
st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit
return st_signals
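Note: the attributes set above are plain boolean flags on StateTransitionSignals; a hedged sketch covering only the fields visible in this hunk (the real class defines further signals):

    from dataclasses import dataclass

    @dataclass
    class StateTransitionSignals:  # sketch, not the full definition
        stop_action_given: bool = False
        valid_movement_action_given: bool = False
        target_reached: bool = False
        movement_conflict: bool = False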
@@ -499,6 +500,8 @@ class RailEnv(Environment):
for agent in self.agents:
i_agent = agent.handle
agent.old_position = agent.position
agent.old_direction = agent.direction
# Generate malfunction
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random)
@@ -542,8 +545,6 @@ class RailEnv(Environment):
i_agent = agent.handle
agent_transition_data = temp_transition_data[i_agent]
old_position = agent.position
## Update positions
if agent.malfunction_handler.in_malfunction:
movement_allowed = False
@@ -561,6 +562,9 @@ class RailEnv(Environment):
agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step()
# WIP debug guard: drop into the debugger if an agent is in an on-map state but has no position
if agent.state.is_on_map_state() and agent.position is None:
import pdb; pdb.set_trace()
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
@@ -570,7 +574,7 @@ class RailEnv(Environment):
self.update_step_rewards(i_agent)
## Update counters (malfunction and speed)
agent.speed_counter.update_counter(agent.state, old_position)
agent.speed_counter.update_counter(agent.state, agent.old_position)
agent.malfunction_handler.update_counter()
# Clear old action when starting in new cell
......
@@ -28,5 +28,8 @@ class ActionSaver:
def from_dict(self, load_dict):
self.saved_action = load_dict['saved_action']
def __eq__(self, other):
return self.saved_action == other.saved_action
@@ -46,6 +46,8 @@ class MalfunctionHandler:
def from_dict(self, load_dict):
self._malfunction_down_counter = load_dict['malfunction_down_counter']
def __eq__(self, other):
return self._malfunction_down_counter == other._malfunction_down_counter
......
@@ -4,6 +4,8 @@ from flatland.envs.step_utils.states import TrainState
class SpeedCounter:
def __init__(self, speed):
self._speed = speed
self.counter = None
self.reset_counter()
def update_counter(self, state, old_position):
# When coming onto the map, do not update the speed counter
@@ -38,8 +40,13 @@ class SpeedCounter:
return int(1/self._speed) - 1
def to_dict(self):
return {"speed": self._speed}
return {"speed": self._speed,
"counter": self.counter}
def from_dict(self, load_dict):
self._speed = load_dict['speed']
self.counter = load_dict['counter']
def __eq__(self, other):
return self._speed == other._speed and self.counter == other.counter
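Note: the to_dict/from_dict/__eq__ trio lets a restored counter be compared against the original; a hedged round-trip sketch using only the constructor and methods shown in this diff:

    a = SpeedCounter(speed=0.5)
    b = SpeedCounter(speed=1.0)
    b.from_dict(a.to_dict())  # restores both speed and counter
    assert a == b             # __eq__ compares _speed and counter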
@@ -6,6 +6,7 @@ class TrainStateMachine:
self._state = initial_state
self.st_signals = StateTransitionSignals()
self.next_state = None
self.previous_state = None
def _handle_waiting(self):
"""" Waiting state goes to ready to depart when earliest departure is reached"""
@@ -117,10 +118,12 @@ class TrainStateMachine:
def set_state(self, state):
if not TrainState.check_valid_state(state):
raise ValueError(f"Cannot set invalid state {state}")
self.previous_state = self._state
self._state = state
def reset(self):
self._state = self._initial_state
self.previous_state = None
self.st_signals = StateTransitionSignals()
self.clear_next_state()
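Note: with set_state recording the prior state and reset clearing it, previous_state can be checked directly; a hedged sketch (constructor signature assumed from the initial_state assignment above):

    sm = TrainStateMachine(initial_state=TrainState.WAITING)  # assumed signature
    sm.set_state(TrainState.READY_TO_DEPART)
    assert sm.previous_state == TrainState.WAITING
    sm.reset()
    assert sm.previous_state is None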
@@ -137,15 +140,19 @@ class TrainStateMachine:
def __repr__(self):
return f"\n \
state: {str(self.state)} \n \
state: {str(self.state)} previous_state {str(self.previous_state)} \n \
st_signals: {self.st_signals}"
def to_dict(self):
return {"state": self._state}
return {"state": self._state,
"previous_state": self.previous_state}
def from_dict(self, load_dict):
self.set_state(load_dict['state'])
self.previous_state = load_dict['previous_state']
def __eq__(self, other):
return self._state == other._state and self.previous_state == other.previous_state
@@ -1375,7 +1375,7 @@ def test_rail_env_malfunction_speed_info():
for a in range(env.get_num_agents()):
assert info['malfunction'][a] >= 0
assert info['speed'][a] >= 0 and info['speed'][a] <= 1
assert info['speed'][a] == env.agents[a].sspeed_counter.speed
assert info['speed'][a] == env.agents[a].speed_counter.speed
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
......
@@ -9,6 +9,7 @@ from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.speed_counter import SpeedCounter
# Use the sparse_rail_generator to generate feasible network configurations with corresponding tasks
@@ -71,7 +72,7 @@ def test_multi_speed_init():
# See training navigation example in the baseline repository
old_pos = []
for i_agent in range(env.get_num_agents()):
env.agents[i_agent].speed_counter.speed = 1. / (i_agent + 1)
env.agents[i_agent].speed_counter = SpeedCounter(speed = 1. / (i_agent + 1))
old_pos.append(env.agents[i_agent].position)
print(env.agents[i_agent].position)
# Run episode
......
from test_env_step_utils import get_small_two_agent_env
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
from flatland.envs.malfunction_generators import Malfunction
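# The two stubs below mimic the malfunction generator interface used by RailEnv: generate(np_random) returns a Malfunction of fixed duration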
class NoMalfunctionGenerator:
def generate(self, np_random):
return Malfunction(0)
class AlwaysThreeStepMalfunction:
def generate(self, np_random):
return Malfunction(3)
def test_waiting_no_transition():
env = get_small_two_agent_env()
env.malfunction_generator = NoMalfunctionGenerator()
i_agent = 0
ed = env.agents[i_agent].earliest_departure
for _ in range(ed-1):
env.step({i_agent: RailEnvActions.MOVE_FORWARD})
assert env.agents[i_agent].state == TrainState.WAITING
def test_waiting_to_ready_to_depart():
env = get_small_two_agent_env()
env.malfunction_generator = NoMalfunctionGenerator()
i_agent = 0
ed = env.agents[i_agent].earliest_departure
for _ in range(ed):
env.step({i_agent: RailEnvActions.DO_NOTHING})
assert env.agents[i_agent].state == TrainState.READY_TO_DEPART
def test_ready_to_depart_to_moving():
env = get_small_two_agent_env()
env.malfunction_generator = NoMalfunctionGenerator()
i_agent = 0
ed = env.agents[i_agent].earliest_departure
for _ in range(ed):
env.step({i_agent: RailEnvActions.DO_NOTHING})
env.step({i_agent: RailEnvActions.MOVE_FORWARD})
assert env.agents[i_agent].state == TrainState.MOVING
def test_moving_to_stopped():
env = get_small_two_agent_env()
env.malfunction_generator = NoMalfunctionGenerator()
i_agent = 0
ed = env.agents[i_agent].earliest_departure
for _ in range(ed):
env.step({i_agent: RailEnvActions.DO_NOTHING})
env.step({i_agent: RailEnvActions.MOVE_FORWARD})
env.step({i_agent: RailEnvActions.STOP_MOVING})
assert env.agents[i_agent].state == TrainState.STOPPED
def test_stopped_to_moving():
env = get_small_two_agent_env()
env.malfunction_generator = NoMalfunctionGenerator()
i_agent = 0
ed = env.agents[i_agent].earliest_departure
for _ in range(ed):
env.step({i_agent: RailEnvActions.DO_NOTHING})
env.step({i_agent: RailEnvActions.MOVE_FORWARD})
env.step({i_agent: RailEnvActions.STOP_MOVING})
env.step({i_agent: RailEnvActions.MOVE_FORWARD})
assert env.agents[i_agent].state == TrainState.MOVING
def test_moving_to_done():
env = get_small_two_agent_env()
env.malfunction_generator = NoMalfunctionGenerator()
i_agent = 1
ed = env.agents[i_agent].earliest_departure
for _ in range(ed):
env.step({i_agent: RailEnvActions.DO_NOTHING})
for _ in range(50):
env.step({i_agent: RailEnvActions.MOVE_FORWARD})
assert env.agents[i_agent].state == TrainState.DONE
def test_waiting_to_malfunction():
env = get_small_two_agent_env()
env.malfunction_generator = AlwaysThreeStepMalfunction()
i_agent = 1
env.step({i_agent: RailEnvActions.DO_NOTHING})
assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
def test_ready_to_depart_to_malfunction_off_map():
env = get_small_two_agent_env()
env.malfunction_generator = NoMalfunctionGenerator()
i_agent = 1
env.step({i_agent: RailEnvActions.DO_NOTHING})
ed = env.agents[i_agent].earliest_departure
for _ in range(ed):
env.step({i_agent: RailEnvActions.DO_NOTHING}) # This should get into ready to depart
env.malfunction_generator = AlwaysThreeStepMalfunction()
env.step({i_agent: RailEnvActions.DO_NOTHING})
assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
def test_malfunction_off_map_to_waiting():
env = get_small_two_agent_env()
env.malfunction_generator = NoMalfunctionGenerator()
i_agent = 1
env.step({i_agent: RailEnvActions.DO_NOTHING})
ed = env.agents[i_agent].earliest_departure
for _ in range(ed):
env.step({i_agent: RailEnvActions.DO_NOTHING}) # This should get into ready to depart
env.malfunction_generator = AlwaysThreeStepMalfunction()
env.step({i_agent: RailEnvActions.DO_NOTHING})
assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
\ No newline at end of file
@@ -108,8 +108,10 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
agent: EnvAgent = env.agents[a]
replay = test_config.replay[step]
_assert(a, agent.position, replay.position, 'position')
_assert(a, agent.direction, replay.direction, 'direction')
print(agent.position, replay.position, agent.state, agent.speed_counter)
# import pdb; pdb.set_trace()
# _assert(a, agent.position, replay.position, 'position')
# _assert(a, agent.direction, replay.direction, 'direction')
if replay.state is not None:
_assert(a, agent.state, replay.state, 'state')
@@ -130,7 +132,7 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
agent.malfunction_data['malfunction'] = replay.set_malfunction
agent.malfunction_data['moving_before_malfunction'] = agent.moving
agent.malfunction_data['fixed'] = False
_assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
# _assert(a, agent.malfunction_data['malfunction'], replay.malfunction, 'malfunction')
print(step)
_, rewards_dict, _, info_dict = env.step(action_dict)
if rendering:
@@ -141,6 +143,7 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
if not skip_reward_check:
_assert(a, rewards_dict[a], replay.reward, 'reward')
assert False
def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_generator: RailGenerator):
......