Skip to content
Snippets Groups Projects
Commit 4cb884a1 authored by adrian_egli2's avatar adrian_egli2
Browse files

Fast methods improves overall performance

parent 344bd2c8
No related branches found
No related tags found
No related merge requests found
import cProfile
import pstats
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.malfunction_generators import MalfunctionParameters, ParamMalfunctionGen
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
def get_rail_env(nAgents=70, use_dummy_obs=False, width=60, height=60):
# Rail Generator:
num_cities = 5 # Number of cities to place on the map
seed = 1 # Random seed
max_rails_between_cities = 2 # Maximum number of rails connecting 2 cities
max_rail_pairs_in_cities = 2 # Maximum number of pairs of tracks within a city
# Even tracks are used as start points, odd tracks are used as endpoints)
rail_generator = sparse_rail_generator(
max_num_cities=num_cities,
seed=seed,
max_rails_between_cities=max_rails_between_cities,
max_rail_pairs_in_city=max_rail_pairs_in_cities,
)
# Line Generator
# sparse_line_generator accepts a dictionary which maps speeds to probabilities.
# Different agent types (trains) with different speeds.
speed_probability_map = {
1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25 # Slow freight train
}
line_generator = sparse_line_generator(speed_probability_map)
# Malfunction Generator:
stochastic_data = MalfunctionParameters(
malfunction_rate=1 / 10000, # Rate of malfunction occurence
min_duration=15, # Minimal duration of malfunction
max_duration=50 # Max duration of malfunction
)
malfunction_generator = ParamMalfunctionGen(stochastic_data)
# Observation Builder
# tree observation returns a tree of possible paths from the current position.
max_depth = 3 # Max depth of the tree
predictor = ShortestPathPredictorForRailEnv(
max_depth=50) # (Specific to Tree Observation - read code)
observation_builder = TreeObsForRailEnv(
max_depth=max_depth,
predictor=predictor
)
if use_dummy_obs:
observation_builder = DummyObservationBuilder()
number_of_agents = nAgents # Number of trains to create
seed = 1 # Random seed
env = RailEnv(
width=width,
height=height,
rail_generator=rail_generator,
line_generator=line_generator,
number_of_agents=number_of_agents,
random_seed=seed,
obs_builder_object=observation_builder,
malfunction_generator=malfunction_generator
)
return env
USE_PROFILER = True
PROFILE_CREATE = False
PROFILE_RESET = True
PROFILE_OBSERVATION = False
if __name__ == "__main__":
print("Start ...")
if USE_PROFILER:
profiler = cProfile.Profile()
print("Create env ... ")
if PROFILE_CREATE:
profiler.enable()
env_fast = get_rail_env(nAgents=70, use_dummy_obs=True)
if PROFILE_CREATE:
profiler.disable()
print("Reset env ... ")
if PROFILE_RESET:
profiler.enable()
env_fast.reset(random_seed=1)
if PROFILE_RESET:
profiler.disable()
print("Make actions ... ")
action_dict = {agent.handle: 0 for agent in env_fast.agents}
print("Step env ... ")
env_fast.step(action_dict)
if PROFILE_OBSERVATION:
profiler.enable()
print("get observation ... ")
env_fast._get_observations()
if PROFILE_OBSERVATION:
profiler.disable()
if USE_PROFILER:
print("---- tottime")
stats = pstats.Stats(profiler).sort_stats('tottime') # ncalls, 'cumtime'...
stats.print_stats(20)
print("---- cumtime")
stats = pstats.Stats(profiler).sort_stats('cumtime') # ncalls, 'cumtime'...
stats.print_stats(20)
print("---- ncalls")
stats = pstats.Stats(profiler).sort_stats('ncalls') # ncalls, 'cumtime'...
stats.print_stats(200)
print("... end ")
......@@ -7,8 +7,9 @@ from flatland.envs.malfunction_generators import malfunction_from_params, Malfun
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.line_generators import sparse_line_generator
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.agent_utils import TrainState
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.fast_methods import fast_count_nonzero, fast_argmax
MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
......@@ -17,12 +18,12 @@ def get_shortest_path_action(env,handle):
distance_map = env.distance_map.get()
agent = env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
if agent.status in [TrainState.WAITING, TrainState.READY_TO_DEPART,
TrainState.MALFUNCTION_OFF_MAP]:
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
elif agent.status in [TrainState.MALFUNCTION, TrainState.MOVING, TrainState.STOPPED]:
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
elif agent.status == TrainState.DONE:
agent_virtual_position = agent.target
else:
return None
......@@ -34,8 +35,8 @@ def get_shortest_path_action(env,handle):
possible_transitions = env.rail.get_transitions(
*agent.initial_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
num_transitions = fast_count_nonzero(possible_transitions)
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
......@@ -43,7 +44,7 @@ def get_shortest_path_action(env,handle):
agent_virtual_position, direction)
min_distances.append(
distance_map[handle, new_position[0],
new_position[1], direction])
new_position[1], direction])
else:
min_distances.append(np.inf)
......@@ -54,7 +55,7 @@ def get_shortest_path_action(env,handle):
idx = np.argpartition(np.array(min_distances), 2)
observation = [0, 0, 0]
observation[idx[0]] = 1
return np.argmax(observation) + 1
return fast_argmax(observation) + 1
def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
......@@ -83,7 +84,7 @@ def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
line_generator = sparse_line_generator(speed_ratio_map)
malfunction_generator = no_malfunction_generator()
while width <= max_width and height <= max_height:
try:
env = RailEnv(width=width, height=height, rail_generator=rail_generator,
......@@ -104,7 +105,7 @@ def small_v0(random_seed, observation_builder, max_width = 35, max_height = 35):
height += 5
logging.info("Try again with larger env: (w,h):", width, height)
logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_height}, max_height={max_height}")
return None
return None
def random_sparse_env_small(random_seed, observation_builder, max_width = 45, max_height = 45):
......@@ -226,11 +227,11 @@ def perc_completion(env):
tasks_finished = 0
if hasattr(env, "agents_data"):
agent_data = env.agents_data
else:
else:
agent_data = env.agents
for current_agent in agent_data:
if current_agent.status == RailAgentStatus.DONE:
if current_agent.status == TrainState.DONE:
tasks_finished += 1
return 100 * np.mean(tasks_finished / max(
1, len(agent_data)))
\ No newline at end of file
1, len(agent_data)))
import numpy as np
from collections import defaultdict
from typing import Dict, Any, Optional, Set, List, Tuple
from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from typing import Dict, Tuple
from flatland.contrib.utils.deadlock_checker import Deadlock_Checker
from flatland.core.grid.grid4_utils import get_new_position
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.envs.fast_methods import fast_count_nonzero
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.contrib.utils.deadlock_checker import Deadlock_Checker
from flatland.envs.step_utils.states import TrainState
def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
agent = env.agents[handle]
if agent.state == TrainState.READY_TO_DEPART:
agent_virtual_position = agent.initial_position
elif agent.state.is_on_map_state():
......@@ -42,8 +40,8 @@ def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4:
print("it seems that we are turning by 180 degrees. Turning in a dead end?")
action = RailEnvActions.MOVE_FORWARD
action = RailEnvActions.MOVE_FORWARD
distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
possible_steps.append((action, distance))
possible_steps = sorted(possible_steps, key=lambda step: step[1])
......@@ -68,7 +66,7 @@ class RailEnvWrapper:
# @property
# def number_of_agents(self):
# return self.env.number_of_agents
# @property
# def agents(self):
# return self.env.agents
......@@ -92,11 +90,11 @@ class RailEnvWrapper:
@property
def rail(self):
return self.env.rail
@property
def width(self):
return self.env.width
@property
def height(self):
return self.env.height
......@@ -123,7 +121,7 @@ class ShortestPathActionWrapper(RailEnvWrapper):
def __init__(self, env:RailEnv):
super().__init__(env)
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
# input: action dict with actions in [0, 1, 2].
......@@ -159,7 +157,7 @@ def find_all_cells_where_agent_can_choose(env: RailEnv):
# Check for switch: if there is more than one outgoing transition
for orientation in directions:
possible_transitions = env.rail.get_transitions(*pos, orientation)
num_transitions = np.count_nonzero(possible_transitions)
num_transitions = fast_count_nonzero(possible_transitions)
if num_transitions > 1:
switches.append(pos)
is_switch = True
......@@ -177,7 +175,7 @@ def find_all_cells_where_agent_can_choose(env: RailEnv):
class SkipNoChoiceCellsWrapper(RailEnvWrapper):
# env can be a real RailEnv, or anything that shares the same interface
# e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
......@@ -208,7 +206,7 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper):
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
o, r, d, i = {}, {}, {}, {}
# need to initialize i["..."]
# as we will access i["..."][agent_id]
i["action_required"] = dict()
......@@ -225,11 +223,11 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper):
r[agent_id] = reward[agent_id]
d[agent_id] = done[agent_id]
i["action_required"][agent_id] = info["action_required"][agent_id]
i["action_required"][agent_id] = info["action_required"][agent_id]
i["malfunction"][agent_id] = info["malfunction"][agent_id]
i["speed"][agent_id] = info["speed"][agent_id]
i["state"][agent_id] = info["state"][agent_id]
if self.accumulate_skipped_rewards:
discounted_skipped_reward = r[agent_id]
......@@ -248,7 +246,7 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper):
# end of while-loop
return o, r, d, i
def reset(self, **kwargs) -> Tuple[Dict, Dict]:
obs, info = self.env.reset(**kwargs)
......@@ -302,4 +300,4 @@ class DeadlockWrapper(RailEnvWrapper):
def reset(self, **kwargs) -> Tuple[Dict, Dict]:
self.deadlock_checker.reset() # sets all lists of deadlocked agents to empty list
obs, info = super().reset(**kwargs)
return obs, info
\ No newline at end of file
return obs, info
from typing import Tuple
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_clip(position: Tuple[int, int], min_value: Tuple[int, int], max_value: Tuple[int, int]) -> bool:
return (
max(min_value[0], min(position[0], max_value[0])),
max(min_value[1], min(position[1], max_value[1]))
)
def fast_argmax(possible_transitions: (int, int, int, int)) -> bool:
if possible_transitions[0] == 1:
return 0
if possible_transitions[1] == 1:
return 1
if possible_transitions[2] == 1:
return 2
return 3
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None and pos_2 is None:
return True
if pos_1 is None or pos_2 is None:
return False
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def fast_count_nonzero(possible_transitions: (int, int, int, int)):
return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
......@@ -12,6 +12,7 @@ from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.fast_methods import fast_argmax, fast_count_nonzero, fast_position_equal
from flatland.envs.step_utils.states import TrainState
from flatland.utils.ordered_set import OrderedSet
......@@ -207,7 +208,7 @@ class TreeObsForRailEnv(ObservationBuilder):
return None
possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
num_transitions = fast_count_nonzero(possible_transitions)
# Here information about the agent itself is stored
distance_map = self.env.distance_map.get()
......@@ -234,7 +235,7 @@ class TreeObsForRailEnv(ObservationBuilder):
orientation = agent.direction
if num_transitions == 1:
orientation = np.argmax(possible_transitions)
orientation = fast_argmax(possible_transitions)
for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
......@@ -381,7 +382,7 @@ class TreeObsForRailEnv(ObservationBuilder):
visited.add((position[0], position[1], direction))
# If the target node is encountered, pick that as node. Also, no further branching is possible.
if np.array_equal(position, self.env.agents[handle].target):
if fast_position_equal(position, self.env.agents[handle].target):
last_is_target = True
break
......@@ -389,7 +390,7 @@ class TreeObsForRailEnv(ObservationBuilder):
if crossing_found:
# Treat the crossing as a straight rail cell
total_transitions = 2
num_transitions = np.count_nonzero(cell_transitions)
num_transitions = fast_count_nonzero(cell_transitions)
exploring = False
......@@ -408,7 +409,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# Keep walking through the tree along `direction`
exploring = True
# convert one-hot encoding to 0,1,2,3
direction = np.argmax(cell_transitions)
direction = fast_argmax(cell_transitions)
position = get_new_position(position, direction)
num_steps += 1
tot_dist += 1
......
......@@ -23,6 +23,7 @@ from flatland.envs import line_generators as line_gen
from flatland.envs.timetable_generators import timetable_generator
from flatland.envs import persistence
from flatland.envs import agent_chains as ac
from flatland.envs.fast_methods import fast_position_equal
from flatland.envs.observations import GlobalObsForRailEnv
......@@ -159,7 +160,7 @@ class RailEnv(Environment):
else:
self.malfunction_generator = mal_gen.NoMalfunctionGen()
self.malfunction_process_data = self.malfunction_generator.get_process_data()
self.number_of_agents = number_of_agents
if rail_generator is None:
......@@ -222,7 +223,7 @@ class RailEnv(Environment):
# no more agent_handles
def get_agent_handles(self):
return range(self.get_num_agents())
def get_num_agents(self) -> int:
return len(self.agents)
......@@ -312,7 +313,7 @@ class RailEnv(Environment):
if optionals and 'agents_hints' in optionals:
agents_hints = optionals['agents_hints']
line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
line = self.line_generator(self.rail, self.number_of_agents, agents_hints,
self.num_resets, self.np_random)
self.agents = EnvAgent.from_line(line)
......@@ -320,17 +321,17 @@ class RailEnv(Environment):
self.distance_map.reset(self.agents, self.rail)
# NEW : Time Schedule Generation
timetable = timetable_generator(self.agents, self.distance_map,
timetable = timetable_generator(self.agents, self.distance_map,
agents_hints, self.np_random)
self._max_episode_steps = timetable.max_episode_steps
for agent_i, agent in enumerate(self.agents):
agent.earliest_departure = timetable.earliest_departures[agent_i]
agent.earliest_departure = timetable.earliest_departures[agent_i]
agent.latest_arrival = timetable.latest_arrivals[agent_i]
else:
self.distance_map.reset(self.agents, self.rail)
# Reset agents to initial states
self.reset_agents()
......@@ -365,11 +366,11 @@ class RailEnv(Environment):
self.agent_positions[agent.position] = agent.handle
if agent.old_position is not None:
self.agent_positions[agent.old_position] = -1
def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed):
""" Generate State Transitions Signals used in the state machine """
st_signals = StateTransitionSignals()
# Malfunction starts when in_malfunction is set to true
st_signals.in_malfunction = agent.malfunction_handler.in_malfunction
......@@ -386,7 +387,7 @@ class RailEnv(Environment):
st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() and movement_allowed
# Target Reached
st_signals.target_reached = env_utils.fast_position_equal(agent.position, agent.target)
st_signals.target_reached = fast_position_equal(agent.position, agent.target)
# Movement conflict - Multiple trains trying to move into same cell
# If speed counter is not in cell exit, the train can enter the cell
......@@ -419,7 +420,7 @@ class RailEnv(Environment):
# Departed but never reached
if (agent.state.is_on_map_state()):
reward = agent.get_current_delay(self._elapsed_steps, self.distance_map)
return reward
def preprocess_action(self, action, agent):
......@@ -436,7 +437,7 @@ class RailEnv(Environment):
current_position, current_direction = agent.position, agent.direction
if current_position is None: # Agent not added on map yet
current_position, current_direction = agent.initial_position, agent.initial_direction
action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
# Check transitions, bounts for executing the action in the given position and directon
......@@ -444,15 +445,15 @@ class RailEnv(Environment):
action = RailEnvActions.STOP_MOVING
return action
def clear_rewards_dict(self):
""" Reset the rewards dictionary """
self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))}
def get_info_dict(self):
"""
Returns dictionary of infos for all agents
dict_keys : action_required -
"""
Returns dictionary of infos for all agents
dict_keys : action_required -
malfunction - Counter value for malfunction > 0 means train is in malfunction
speed - Speed of the train
state - State from the trains's state machine
......@@ -466,7 +467,7 @@ class RailEnv(Environment):
'state': {i: agent.state for i, agent in enumerate(self.agents)}
}
return info_dict
def update_step_rewards(self, i_agent):
"""
Update the rewards dict for agent id i_agent for every timestep
......@@ -474,7 +475,7 @@ class RailEnv(Environment):
pass
def end_of_episode_update(self, have_all_agents_ended):
"""
"""
Updates made when episode ends
Parameters: have_all_agents_ended - Indicates if all agents have reached done state
"""
......@@ -482,10 +483,10 @@ class RailEnv(Environment):
( (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)):
for i_agent, agent in enumerate(self.agents):
reward = self._handle_end_reward(agent)
self.rewards_dict[i_agent] += reward
self.dones[i_agent] = True
self.dones["__all__"] = True
......@@ -515,7 +516,7 @@ class RailEnv(Environment):
self.motionCheck = ac.MotionCheck() # reset the motion check
temp_transition_data = {}
for agent in self.agents:
i_agent = agent.handle
agent.old_position = agent.position
......@@ -534,7 +535,7 @@ class RailEnv(Environment):
# Train's next position can change if train is at cell's exit and train is not in malfunction
position_update_allowed = agent.speed_counter.is_cell_exit and \
not agent.malfunction_handler.malfunction_down_counter > 0 and \
not preprocessed_action == RailEnvActions.STOP_MOVING
not preprocessed_action == RailEnvActions.STOP_MOVING
# Calculate new position
# Keep agent in same place if already done
......@@ -548,24 +549,24 @@ class RailEnv(Environment):
elif agent.action_saver.is_action_saved and position_update_allowed:
saved_action = agent.action_saver.saved_action
# Apply action independent of other agents and get temporary new position and direction
new_position, new_direction = env_utils.apply_action_independent(saved_action,
self.rail,
agent.position,
new_position, new_direction = env_utils.apply_action_independent(saved_action,
self.rail,
agent.position,
agent.direction)
preprocessed_action = saved_action
else:
new_position, new_direction = agent.position, agent.direction
temp_transition_data[i_agent] = env_utils.AgentTransitionData(position=new_position,
direction=new_direction,
preprocessed_action=preprocessed_action)
# This is for storing and later checking for conflicts of agents trying to occupy same cell
# This is for storing and later checking for conflicts of agents trying to occupy same cell
self.motionCheck.addAgent(i_agent, agent.position, new_position)
# Find conflicts between trains trying to occupy same cell
self.motionCheck.find_conflicts()
for agent in self.agents:
i_agent = agent.handle
......@@ -574,7 +575,7 @@ class RailEnv(Environment):
movement_allowed = False
else:
movement_allowed = self.motionCheck.check_motion(i_agent, agent.position)
movement_inside_cell = agent.state == TrainState.STOPPED and not agent.speed_counter.is_cell_exit
movement_allowed = movement_allowed or movement_inside_cell
......@@ -606,7 +607,7 @@ class RailEnv(Environment):
# Handle done state actions, optionally remove agents
self.handle_done_state(agent)
have_all_agents_ended &= (agent.state == TrainState.DONE)
## Update rewards
......@@ -620,16 +621,16 @@ class RailEnv(Environment):
# Clear old action when starting in new cell
if agent.speed_counter.is_cell_entry and agent.position is not None:
agent.action_saver.clear_saved_action()
# Check if episode has ended and update rewards and dones
self.end_of_episode_update(have_all_agents_ended)
self._update_agent_positions_map()
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
return self._get_observations(), self.rewards_dict, self.dones, self.get_info_dict()
def record_timestep(self, dActions):
"""
"""
Record the positions and orientations of all agents in memory, in the cur_episode
"""
list_agents_state = []
......@@ -643,8 +644,8 @@ class RailEnv(Environment):
pos = (int(agent.position[0]), int(agent.position[1]))
# print("pos:", pos, type(pos[0]))
list_agents_state.append([
*pos, int(agent.direction),
agent.malfunction_handler.malfunction_down_counter,
*pos, int(agent.direction),
agent.malfunction_handler.malfunction_down_counter,
int(agent.status),
int(agent.position in self.motionCheck.svDeadlocked)
])
......@@ -690,7 +691,7 @@ class RailEnv(Environment):
"""
return agent.malfunction_handler.in_malfunction
def save(self, filename):
print("DEPRECATED call to env.save() - pls call RailEnvPersister.save()")
......
......@@ -9,6 +9,7 @@ from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.step_utils.states import TrainState
from flatland.envs.distance_map import DistanceMap
from flatland.envs.fast_methods import fast_count_nonzero
from flatland.envs.rail_env_action import RailEnvActions, RailEnvNextAction
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.utils.ordered_set import OrderedSet
......@@ -38,7 +39,7 @@ def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
"""
valid_actions: Set[RailEnvNextAction] = OrderedSet()
possible_transitions = rail.get_transitions(*agent_position, agent_direction)
num_transitions = np.count_nonzero(possible_transitions)
num_transitions = fast_count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
......@@ -96,7 +97,7 @@ def get_new_position_for_action(
row, column, direction
"""
possible_transitions = rail.get_transitions(*agent_position, agent_direction)
num_transitions = np.count_nonzero(possible_transitions)
num_transitions = fast_count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
......@@ -161,7 +162,7 @@ def get_action_for_move(
the action (if direct transition possible) or None.
"""
possible_transitions = rail.get_transitions(*agent_position, agent_direction)
num_transitions = np.count_nonzero(possible_transitions)
num_transitions = fast_count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
......
......@@ -13,16 +13,6 @@ class AgentTransitionData:
direction : Grid4Transitions
preprocessed_action : RailEnvActions
# Adrian Egli performance fix (the fast methods brings more than 50%)
def fast_isclose(a, b, rtol):
return (a < (b + rtol)) or (a < (b - rtol))
def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
if pos_1 is None:
return False
else:
return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
def apply_action_independent(action, rail, position, direction):
""" Apply the action on the train regardless of locations of other trains
Checks for valid cells to move and valid rail transitions
......
from typing import Tuple
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.fast_methods import fast_argmax, fast_count_nonzero
from flatland.envs.rail_env_action import RailEnvActions
......@@ -21,7 +22,7 @@ def check_action(action, position, direction, rail):
transition_valid = None
possible_transitions = rail.get_transitions(*position, direction)
num_transitions = fast_count_nonzero(possible_transitions)
new_direction = direction
if action == RailEnvActions.MOVE_LEFT:
new_direction = direction - 1
......@@ -81,18 +82,6 @@ def check_valid_action(action, rail, position, direction):
action_is_valid = new_cell_valid and transition_valid
return action_is_valid
def fast_argmax(possible_transitions: Tuple[int, int, int, int]) -> bool:
if possible_transitions[0] == 1:
return 0
if possible_transitions[1] == 1:
return 1
if possible_transitions[2] == 1:
return 2
return 3
def fast_count_nonzero(possible_transitions: Tuple[int, int, int, int]):
return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
def check_bounds(position, height, width):
return position[0] >= 0 and position[1] >= 0 and position[0] < height and position[1] < width
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment