Commit 8be8160f authored by u214892's avatar u214892
Browse files

Trainrun and Waypoint instead of WayPoint and TrainRun for readability

parent 64c02dc1
Pipeline #2838 passed with stages
in 48 minutes and 29 seconds
......@@ -40,7 +40,8 @@ env = RailEnv(width=100,
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=100,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), # Malfunction data generator
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
# Malfunction data generator
obs_builder_object=GlobalObsForRailEnv(),
remove_agents_at_target=True,
record_steps=True
......
......@@ -264,7 +264,7 @@ class ControllerFromTrainrunsReplayer():
"""Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction."""
@staticmethod
def replay_verify(max_episode_steps: int, ctl: ControllerFromTrainruns, env: RailEnv, rendering: bool):
def replay_verify(ctl: ControllerFromTrainruns, env: RailEnv, rendering: bool):
"""Replays this deterministic `ActionPlan` and verifies whether it is feasible."""
if rendering:
renderer = RenderTool(env, gl="PILSVG",
......@@ -275,7 +275,7 @@ class ControllerFromTrainrunsReplayer():
screen_width=1000)
renderer.render_env(show=True, show_observations=False, show_predictions=False)
i = 0
while not env.dones['__all__'] and i <= max_episode_steps:
while not env.dones['__all__'] and i <= env._max_episode_steps:
for agent_id, agent in enumerate(env.agents):
way_point: Waypoint = ctl.get_way_point_before_or_at_step(agent_id, i)
assert agent.position == way_point.position, \
......
......@@ -31,7 +31,6 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
@attrs
class EnvAgent:
initial_position = attrib(type=Tuple[int, int])
initial_direction = attrib(type=Grid4TransitionsEnum)
direction = attrib(type=Grid4TransitionsEnum)
......
......@@ -63,7 +63,7 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se
:param num_resets: How often the generator has been reset.
:return: Returns the generator to the rail constructor
"""
#Todo: Remove parameters and variables not used for next version, Issue: <https://gitlab.aicrowd.com/flatland/flatland/issues/305>
# Todo: Remove parameters and variables not used for next version, Issue: <https://gitlab.aicrowd.com/flatland/flatland/issues/305>
_runtime_seed = seed + num_resets
start_goal = hints['start_goal']
......
......@@ -16,9 +16,9 @@ import redis
import timeout_decorator
import flatland
from flatland.envs.malfunction_generators import malfunction_from_file
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_file
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
......
......@@ -542,8 +542,8 @@ class PILSVG(PILGL):
if (col + row + col * row) % 3 == 0:
a = (a + (col + row + col * row)) % len(self.dBuildings)
pil_track = self.dBuildings[a]
elif ((self.background_grid[col][row] > 5 + ((col * row + col) % 3)) or
((col ** 3 + row ** 2 + col * row) % 10 == 0)):
elif ((self.background_grid[col][row] > 5 + ((col * row + col) % 3)) or
((col ** 3 + row ** 2 + col * row) % 10 == 0)):
a = int(self.background_grid[col][row]) - 4
a2 = (a + (col + row + col * row + col ** 3 + row ** 4))
if a2 % 64 > 11:
......
......@@ -564,9 +564,9 @@ class RenderTool(object):
if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
self.gl.set_agent_at(agent_idx, *position, old_direction, direction,
selected_agent == agent_idx, rail_grid=env.rail.grid,
show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
malfunction=is_malfunction)
selected_agent == agent_idx, rail_grid=env.rail.grid,
show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
malfunction=is_malfunction)
else:
position = agent.position
direction = agent.direction
......@@ -578,15 +578,15 @@ class RenderTool(object):
# set_agent_at uses the agent index for the color
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction,
selected_agent == agent_idx, rail_grid=env.rail.grid,
show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
malfunction=is_malfunction)
selected_agent == agent_idx, rail_grid=env.rail.grid,
show_debug=self.show_debug, clear_debug_text=self.clear_debug_text,
malfunction=is_malfunction)
# set_agent_at uses the agent index for the color
if self.agent_render_variant == AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
self.gl.set_agent_at(agent_idx, *position, agent.direction, direction, selected_agent == agent_idx,
rail_grid=env.rail.grid, malfunction=is_malfunction)
rail_grid=env.rail.grid, malfunction=is_malfunction)
if show_observations:
self.render_observation(range(env.get_num_agents()), env.dev_obs_dict)
......
......@@ -193,9 +193,10 @@ def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]:
rail_map = np.array(
[[empty] * 3 + [right_turn_from_south] + [horizontal_straight] * 5 + [right_turn_from_west]] +
[[empty] * 3 + [vertical_straight] + [empty] * 5 + [vertical_straight]] * 2 +
[[dead_end_from_east] + [horizontal_straight] * 2 + [simple_switch_left_east] + [horizontal_straight] * 2 + [right_turn_from_west] + [empty] * 2 + [vertical_straight]] +
[[empty] * 6 + [simple_switch_north_right] + [horizontal_straight] * 2 + [right_turn_from_north]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] +
[[dead_end_from_east] + [horizontal_straight] * 2 + [simple_switch_left_east] + [horizontal_straight] * 2 + [
right_turn_from_west] + [empty] * 2 + [vertical_straight]] +
[[empty] * 6 + [simple_switch_north_right] + [horizontal_straight] * 2 + [right_turn_from_north]] +
[[empty] * 6 + [vertical_straight] + [empty] * 3] +
[[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
rail = GridTransitionMap(width=rail_map.shape[1],
height=rail_map.shape[0], transitions=transitions)
......
......@@ -42,7 +42,6 @@ if not os.path.exists(img_dest):
for image_file in glob.glob(r'./images/*.png'):
shutil.copy(image_file, img_dest)
subprocess.call(['sphinx-apidoc', '--force', '-a', '-e', '-o', 'docs/', 'flatland', '-H', 'API Reference', '--tocfile',
'05_apidoc'])
......
......@@ -80,9 +80,7 @@ def test_action_plan(rendering: bool = False):
]]
MAX_EPISODE_STEPS = 50
deterministic_controller = ControllerFromTrainruns(env, chosen_path_dict)
deterministic_controller.print_action_plan()
ControllerFromTrainruns.assert_actions_plans_equal(expected_action_plan, deterministic_controller.action_plan)
ControllerFromTrainrunsReplayer.replay_verify(MAX_EPISODE_STEPS, deterministic_controller, env, rendering)
ControllerFromTrainrunsReplayer.replay_verify(deterministic_controller, env, rendering)
import numpy as np
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import TreeObsForRailEnv
......@@ -9,7 +6,7 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
def test_initial_status():
......
......@@ -8,7 +8,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_train_run_data_structures import Waypoint
......
import random
import unittest
import warnings
......@@ -13,7 +12,6 @@ from flatland.utils.rendertools import RenderTool
def test_sparse_rail_generator():
env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10,
max_rails_between_cities=3,
seed=5,
......@@ -594,7 +592,6 @@ def test_sparse_rail_generator():
def test_sparse_rail_generator_deterministic():
"""Check that sparse_rail_generator runs deterministic over different python versions!"""
speed_ration_map = {1.: 1., # Fast passenger train
1. / 2.: 0., # Fast freight train
1. / 3.: 0., # Slow commuter train
......@@ -1440,7 +1437,6 @@ def test_rail_env_action_required_info():
def test_rail_env_malfunction_speed_info():
env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10,
max_rails_between_cities=3,
seed=5,
......@@ -1475,7 +1471,6 @@ def test_rail_env_malfunction_speed_info():
def test_sparse_generator_with_too_man_cities_does_not_break_down():
RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
max_num_cities=100,
max_rails_between_cities=3,
......
......@@ -2,7 +2,6 @@ import random
from typing import Dict, List
import numpy as np
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
......@@ -13,6 +12,7 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail2
from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
class SingleAgentNavigationObs(ObservationBuilder):
......
import numpy as np
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.malfunction_generators import malfunction_from_params
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
......
import numpy as np
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
......@@ -8,6 +7,8 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map
from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail
from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment