Commit f4dc1668 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

Change RailAgentStatus to TrainState and speed_data to speed_counter in tests

parent 4169a0f1
Pipeline #8455 failed with stages
in 6 minutes and 27 seconds
......@@ -8,8 +8,6 @@ import time
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
class CustomObservationBuilder(ObservationBuilder):
......
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
......@@ -7,7 +6,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
from flatland.envs.step_utils.states import TrainState
def test_initial_status():
"""Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
......@@ -30,7 +29,7 @@ def test_initial_status():
Replay(
position=None, # not entered grid yet
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.READY_TO_DEPART,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.DO_NOTHING,
reward=env.step_penalty * 0.5,
......@@ -38,35 +37,35 @@ def test_initial_status():
Replay(
position=None, # not entered grid yet before step
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.READY_TO_DEPART,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.MOVE_LEFT,
reward=env.step_penalty * 0.5, # auto-correction left to forward without penalty!
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.ACTIVE,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_LEFT,
reward=env.start_penalty + env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.ACTIVE,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
status=RailAgentStatus.ACTIVE,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
status=RailAgentStatus.ACTIVE,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
......@@ -76,28 +75,28 @@ def test_initial_status():
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
state=TrainState.MOVING
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5, # wrong action is corrected to forward without penalty!
status=RailAgentStatus.ACTIVE
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_RIGHT,
reward=env.step_penalty * 0.5, #
status=RailAgentStatus.ACTIVE
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.global_reward, #
status=RailAgentStatus.ACTIVE
state=TrainState.MOVING
),
# Replay(
# position=(3, 5),
......@@ -122,7 +121,7 @@ def test_initial_status():
)
run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True)
assert env.agents[0].status == RailAgentStatus.DONE
assert env.agents[0].state == TrainState.DONE
def test_status_done_remove():
......@@ -146,7 +145,7 @@ def test_status_done_remove():
Replay(
position=None, # not entered grid yet
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.READY_TO_DEPART,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.DO_NOTHING,
reward=env.step_penalty * 0.5,
......@@ -154,35 +153,35 @@ def test_status_done_remove():
Replay(
position=None, # not entered grid yet before step
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.READY_TO_DEPART,
state=TrainState.READY_TO_DEPART,
action=RailEnvActions.MOVE_LEFT,
reward=env.step_penalty * 0.5, # auto-correction left to forward without penalty!
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.ACTIVE,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_FORWARD,
reward=env.start_penalty + env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 9),
direction=Grid4TransitionsEnum.EAST,
status=RailAgentStatus.ACTIVE,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
status=RailAgentStatus.ACTIVE,
state=TrainState.MOVING,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # running at speed 0.5
),
Replay(
position=(3, 8),
direction=Grid4TransitionsEnum.WEST,
status=RailAgentStatus.ACTIVE,
state=TrainState.MOVING,
action=None,
reward=env.step_penalty * 0.5, # running at speed 0.5
......@@ -192,28 +191,28 @@ def test_status_done_remove():
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_RIGHT,
reward=env.step_penalty * 0.5, # running at speed 0.5
status=RailAgentStatus.ACTIVE
state=TrainState.MOVING
),
Replay(
position=(3, 7),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.step_penalty * 0.5, # wrong action is corrected to forward without penalty!
status=RailAgentStatus.ACTIVE
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=RailEnvActions.MOVE_FORWARD,
reward=env.step_penalty * 0.5, # done
status=RailAgentStatus.ACTIVE
state=TrainState.MOVING
),
Replay(
position=(3, 6),
direction=Grid4TransitionsEnum.WEST,
action=None,
reward=env.global_reward, # already done
status=RailAgentStatus.ACTIVE
state=TrainState.MOVING
),
# Replay(
# position=None,
......@@ -238,4 +237,4 @@ def test_status_done_remove():
)
run_replay_config(env, [test_config], activate_agents=False, skip_reward_check=True)
assert env.agents[0].status == RailAgentStatus.DONE_REMOVED
assert env.agents[0].state == TrainState.DONE
......@@ -5,7 +5,6 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
......@@ -13,6 +12,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.step_utils.states import TrainState
"""Tests for `flatland` package."""
......@@ -106,7 +106,7 @@ def test_reward_function_conflict(rendering=False):
agent.initial_direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
agent.status = RailAgentStatus.ACTIVE
agent._set_state(TrainState.MOVING)
agent = env.agents[1]
agent.position = (3, 8) # east dead-end
......@@ -115,13 +115,13 @@ def test_reward_function_conflict(rendering=False):
agent.initial_direction = 3 # west
agent.target = (6, 6) # south dead-end
agent.moving = True
agent.status = RailAgentStatus.ACTIVE
agent._set_state(TrainState.MOVING)
env.reset(False, False)
env.agents[0].moving = True
env.agents[1].moving = True
env.agents[0].status = RailAgentStatus.ACTIVE
env.agents[1].status = RailAgentStatus.ACTIVE
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1]._set_state(TrainState.MOVING)
env.agents[0].position = (5, 6)
env.agents[1].position = (3, 8)
print("\n")
......@@ -195,7 +195,7 @@ def test_reward_function_waiting(rendering=False):
agent.initial_direction = 3 # west
agent.target = (3, 1) # west dead-end
agent.moving = True
agent.status = RailAgentStatus.ACTIVE
agent._set_state(TrainState.MOVING)
agent = env.agents[1]
agent.initial_position = (5, 6) # south dead-end
......@@ -204,13 +204,13 @@ def test_reward_function_waiting(rendering=False):
agent.initial_direction = 0 # north
agent.target = (3, 8) # east dead-end
agent.moving = True
agent.status = RailAgentStatus.ACTIVE
agent._set_state(TrainState.MOVING)
env.reset(False, False)
env.agents[0].moving = True
env.agents[1].moving = True
env.agents[0].status = RailAgentStatus.ACTIVE
env.agents[1].status = RailAgentStatus.ACTIVE
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1]._set_state(TrainState.MOVING)
env.agents[0].position = (3, 8)
env.agents[1].position = (5, 6)
......
......@@ -5,7 +5,6 @@ import pprint
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import TreeObsForRailEnv, Node
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
......@@ -16,6 +15,7 @@ from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.states import TrainState
"""Test predictions for `flatland` package."""
......@@ -135,7 +135,7 @@ def test_shortest_path_predictor(rendering=False):
agent.initial_direction = 0 # north
agent.target = (3, 9) # east dead-end
agent.moving = True
agent.status = RailAgentStatus.ACTIVE
agent._set_state(TrainState.MOVING)
env.reset(False, False)
env.distance_map._compute(env.agents, env.rail)
......@@ -269,7 +269,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env.agents[0].initial_direction = 0 # north
env.agents[0].target = (3, 9) # east dead-end
env.agents[0].moving = True
env.agents[0].status = RailAgentStatus.ACTIVE
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1].initial_position = (3, 8) # east dead-end
env.agents[1].position = (3, 8) # east dead-end
......@@ -277,7 +277,7 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env.agents[1].initial_direction = 3 # west
env.agents[1].target = (6, 6) # south dead-end
env.agents[1].moving = True
env.agents[1].status = RailAgentStatus.ACTIVE
env.agents[1]._set_state(TrainState.MOVING)
observations, info = env.reset(False, False)
......@@ -285,8 +285,8 @@ def test_shortest_path_predictor_conflicts(rendering=False):
env.agent_positions[env.agents[0].position] = 0
env.agents[1].position = (3, 8) # east dead-end
env.agent_positions[env.agents[1].position] = 1
env.agents[0].status = RailAgentStatus.ACTIVE
env.agents[1].status = RailAgentStatus.ACTIVE
env.agents[0]._set_state(TrainState.MOVING)
env.agents[1]._set_state(TrainState.MOVING)
observations = env._get_observations()
......
......@@ -1315,8 +1315,8 @@ def test_rail_env_action_required_info():
if step == 0 or info_only_if_action_required['action_required'][a]:
action_dict_only_if_action_required.update({a: action})
else:
print("[{}] not action_required {}, speed_data={}".format(step, a,
env_always_action.agents[a].speed_data))
print("[{}] not action_required {}, speed_counter={}".format(step, a,
env_always_action.agents[a].speed_counter))
obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
action_dict_always_action)
......@@ -1375,7 +1375,7 @@ def test_rail_env_malfunction_speed_info():
for a in range(env.get_num_agents()):
assert info['malfunction'][a] >= 0
assert info['speed'][a] >= 0 and info['speed'][a] <= 1
assert info['speed'][a] == env.agents[a].speed_data['speed']
        assert info['speed'][a] == env.agents[a].speed_counter.speed
env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
......
......@@ -6,14 +6,14 @@ import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.line_generators import sparse_line_generator
from flatland.utils.simple_rail import make_simple_rail2
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
from flatland.envs.step_utils.states import TrainState
from flatland.envs.step_utils.speed_counter import SpeedCounter
class SingleAgentNavigationObs(ObservationBuilder):
"""
......@@ -32,11 +32,11 @@ class SingleAgentNavigationObs(ObservationBuilder):
def get(self, handle: int = 0) -> List[int]:
agent = self.env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.state.is_off_map_state():
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif agent.state.is_on_map_state():
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.state == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
......@@ -85,7 +85,7 @@ def test_malfunction_process():
obs, info = env.reset(False, False, random_seed=10)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].status = RailAgentStatus.ACTIVE
env.agents[a_idx].state = TrainState.MOVING
agent_halts = 0
total_down_time = 0
......@@ -297,7 +297,7 @@ def test_initial_malfunction():
reward=env.step_penalty # running at speed 1.0
)
],
speed=env.agents[0].speed_data['speed'],
speed=env.agents[0].speed_counter.speed,
target=env.agents[0].target,
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
......@@ -315,7 +315,7 @@ def test_initial_malfunction_stop_moving():
env._max_episode_steps = 1000
print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)
print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].state)
set_penalties_for_replay(env)
replay_config = ReplayConfig(
......@@ -327,7 +327,7 @@ def test_initial_malfunction_stop_moving():
set_malfunction=3,
malfunction=3,
reward=env.step_penalty, # full step penalty when stopped
status=RailAgentStatus.READY_TO_DEPART
state=TrainState.READY_TO_DEPART
),
Replay(
position=(3, 2),
......@@ -335,7 +335,7 @@ def test_initial_malfunction_stop_moving():
action=RailEnvActions.DO_NOTHING,
malfunction=2,
reward=env.step_penalty, # full step penalty when stopped
status=RailAgentStatus.ACTIVE
state=TrainState.READY_TO_DEPART
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action STOP_MOVING, agent should restart without moving
......@@ -346,7 +346,7 @@ def test_initial_malfunction_stop_moving():
action=RailEnvActions.STOP_MOVING,
malfunction=1,
reward=env.step_penalty, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
state=TrainState.STOPPED
),
# we have stopped and do nothing --> should stand still
Replay(
......@@ -355,7 +355,7 @@ def test_initial_malfunction_stop_moving():
action=RailEnvActions.DO_NOTHING,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
state=TrainState.STOPPED
),
# we start to move forward --> should go to next cell now
Replay(
......@@ -364,7 +364,7 @@ def test_initial_malfunction_stop_moving():
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
state=TrainState.STOPPED
),
Replay(
position=(3, 3),
......@@ -372,10 +372,10 @@ def test_initial_malfunction_stop_moving():
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
state=TrainState.STOPPED
)
],
speed=env.agents[0].speed_data['speed'],
speed=env.agents[0].speed_counter.speed,
target=env.agents[0].target,
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
......@@ -412,7 +412,7 @@ def test_initial_malfunction_do_nothing():
set_malfunction=3,
malfunction=3,
reward=env.step_penalty, # full step penalty while malfunctioning
status=RailAgentStatus.READY_TO_DEPART
state=TrainState.READY_TO_DEPART
),
Replay(
position=(3, 2),
......@@ -420,7 +420,7 @@ def test_initial_malfunction_do_nothing():
action=RailEnvActions.DO_NOTHING,
malfunction=2,
reward=env.step_penalty, # full step penalty while malfunctioning
status=RailAgentStatus.ACTIVE
                state=TrainState.READY_TO_DEPART
),
# malfunction stops in the next step and we're still at the beginning of the cell
# --> if we take action DO_NOTHING, agent should restart without moving
......@@ -431,7 +431,7 @@ def test_initial_malfunction_do_nothing():
action=RailEnvActions.DO_NOTHING,
malfunction=1,
reward=env.step_penalty, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
                state=TrainState.STOPPED
),
# we haven't started moving yet --> stay here
Replay(
......@@ -440,7 +440,7 @@ def test_initial_malfunction_do_nothing():
action=RailEnvActions.DO_NOTHING,
malfunction=0,
reward=env.step_penalty, # full step penalty while stopped
status=RailAgentStatus.ACTIVE
                state=TrainState.STOPPED
),
Replay(
......@@ -449,7 +449,7 @@ def test_initial_malfunction_do_nothing():
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.start_penalty + env.step_penalty * 1.0, # start penalty + step penalty for speed 1.0
status=RailAgentStatus.ACTIVE
                state=TrainState.STOPPED
), # we start to move forward --> should go to next cell now
Replay(
position=(3, 3),
......@@ -457,10 +457,10 @@ def test_initial_malfunction_do_nothing():
action=RailEnvActions.MOVE_FORWARD,
malfunction=0,
reward=env.step_penalty * 1.0, # step penalty for speed 1.0
status=RailAgentStatus.ACTIVE
                state=TrainState.STOPPED
)
],
speed=env.agents[0].speed_data['speed'],
speed=env.agents[0].speed_counter.speed,
target=env.agents[0].target,
initial_position=(3, 2),
initial_direction=Grid4TransitionsEnum.EAST,
......@@ -475,7 +475,7 @@ def tests_random_interference_from_outside():
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 0.33
env.agents[0].speed_counter = SpeedCounter(speed=0.33)
env.reset(False, False, random_seed=10)
env_data = []
......@@ -501,7 +501,7 @@ def tests_random_interference_from_outside():
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 0.33
env.agents[0].speed_counter = SpeedCounter(speed=0.33)
env.reset(False, False, random_seed=10)
dummy_list = [1, 2, 6, 7, 8, 9, 4, 5, 4]
......@@ -536,7 +536,7 @@ def test_last_malfunction_step():
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1)
env.reset()
env.agents[0].speed_data['speed'] = 1. / 3.
env.agents[0].speed_counter = SpeedCounter(speed=1./3.)
env.agents[0].initial_position = (6, 6)
env.agents[0].initial_direction = 2
env.agents[0].target = (0, 3)
......@@ -546,7 +546,7 @@ def test_last_malfunction_step():
env.reset(False, False)
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].status = RailAgentStatus.ACTIVE
        env.agents[a_idx].state = TrainState.MOVING
# Force malfunction to be off at beginning and next malfunction to happen in 2 steps
env.agents[0].malfunction_data['next_malfunction'] = 2
env.agents[0].malfunction_data['malfunction'] = 0
......@@ -565,13 +565,13 @@ def test_last_malfunction_step():
if env.agents[0].malfunction_data['malfunction'] < 1:
agent_can_move = True
# Store the position before and after the step
pre_position = env.agents[0].speed_data['position_fraction']
pre_position = env.agents[0].speed_counter.counter
_, reward, _, _ = env.step(action_dict)
# Check if the agent is still allowed to move in this step
if env.agents[0].malfunction_data['malfunction'] > 0:
agent_can_move = False
post_position = env.agents[0].speed_data['position_fraction']
post_position = env.agents[0].speed_counter.counter
# Assert that the agent moved while it was still allowed
if agent_can_move:
assert pre_position != post_position
......
......@@ -10,7 +10,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_fr
from flatland.envs.line_generators import sparse_line_generator, line_from_file
from flatland.utils.simple_rail import make_simple_rail
from flatland.envs.persistence import RailEnvPersister
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.step_utils.states import TrainState
def test_empty_rail_generator():
......@@ -35,7 +35,7 @@ def test_rail_from_grid_transition_map():
for a_idx in range(len(env.agents)):
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].status = RailAgentStatus.ACTIVE
env.agents[a_idx]._set_state(TrainState.MOVING)
nr_rail_elements = np.count_nonzero(env.rail.grid)
......
import numpy as np
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.step_utils.states import TrainState