from flatland.envs.fast_methods import fast_position_equal
from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
from flatland.envs.step_utils import env_utils
class TrainStateMachine:
def __init__(self, initial_state=TrainState.WAITING):
self._initial_state = initial_state
self._state = initial_state
self.st_signals = StateTransitionSignals()
self.next_state = None
self.previous_state = None
def _handle_waiting(self):
"""" Waiting state goes to ready to depart when earliest departure is reached"""
# TODO: Important - The malfunction handling is not like proper state machine
# Both transition signals can happen at the same time
# Atleast mention it in the diagram
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals.earliest_departure_reached:
self.next_state = TrainState.READY_TO_DEPART
else:
self.next_state = TrainState.WAITING
def _handle_ready_to_depart(self):
""" Can only go to MOVING if a valid action is provided """
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
elif self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
else:
self.next_state = TrainState.READY_TO_DEPART
def _handle_malfunction_off_map(self):
if self.st_signals.malfunction_counter_complete:
if self.st_signals.earliest_departure_reached:
if self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
elif self.st_signals.stop_action_given:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.READY_TO_DEPART
else:
self.next_state = TrainState.WAITING
else:
self.next_state = TrainState.MALFUNCTION_OFF_MAP
def _handle_moving(self):
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION
elif self.st_signals.target_reached:
self.next_state = TrainState.DONE
elif self.st_signals.stop_action_given or self.st_signals.movement_conflict:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.MOVING
def _handle_stopped(self):
if self.st_signals.in_malfunction:
self.next_state = TrainState.MALFUNCTION
elif self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
else:
self.next_state = TrainState.STOPPED
def _handle_malfunction(self):
if self.st_signals.malfunction_counter_complete:
if self.st_signals.valid_movement_action_given:
self.next_state = TrainState.MOVING
else:
self.next_state = TrainState.STOPPED
else:
self.next_state = TrainState.MALFUNCTION
def _handle_done(self):
"""" Done state is terminal """
self.next_state = TrainState.DONE
def calculate_next_state(self, current_state):
# Handle the current state
if current_state == TrainState.WAITING:
self._handle_waiting()
elif current_state == TrainState.READY_TO_DEPART:
self._handle_ready_to_depart()
elif current_state == TrainState.MALFUNCTION_OFF_MAP:
self._handle_malfunction_off_map()
elif current_state == TrainState.MOVING:
self._handle_moving()
elif current_state == TrainState.STOPPED:
self._handle_stopped()
elif current_state == TrainState.MALFUNCTION:
self._handle_malfunction()
elif current_state == TrainState.DONE:
self._handle_done()
else:
raise ValueError(f"Got unexpected state {current_state}")
def step(self):
""" Steps the state machine to the next state """
current_state = self._state
# Clear next state
self.clear_next_state()
# Handle current state to get next_state
self.calculate_next_state(current_state)
# Set next state
self.set_state(self.next_state)
def clear_next_state(self):
self.next_state = None
def set_state(self, state):
if not TrainState.check_valid_state(state):
raise ValueError(f"Cannot set invalid state {state}")
self.previous_state = self._state
self._state = state
def reset(self):
self._state = self._initial_state
self.previous_state = None
self.st_signals = StateTransitionSignals()
self.clear_next_state()
def update_if_reached(self, position, target):
# Hacky fix for now: the state machine would need speed-related states for proper handling
self.st_signals.target_reached = fast_position_equal(position, target)
if self.st_signals.target_reached:
self.next_state = TrainState.DONE
self.set_state(self.next_state)
@property
def state(self):
return self._state
@property
def state_transition_signals(self):
return self.st_signals
def set_transition_signals(self, state_transition_signals):
self.st_signals = state_transition_signals
def __repr__(self):
return f"\n \
state: {str(self.state)} previous_state {str(self.previous_state)} \n \
st_signals: {self.st_signals}"
def to_dict(self):
return {"state": self._state,
"previous_state": self.previous_state}
def from_dict(self, load_dict):
self.set_state(load_dict['state'])
self.previous_state = load_dict['previous_state']
def __eq__(self, other):
return self._state == other._state and self.previous_state == other.previous_state
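# Minimal usage sketch of the state machine above (illustrative only; inside
# RailEnv the transition signals are filled in by the env's step logic, not
# set by hand):
#
# sm = TrainStateMachine() # starts in WAITING
# sm.st_signals.earliest_departure_reached = True
# sm.step() # WAITING -> READY_TO_DEPART
# sm.st_signals.valid_movement_action_given = True
# sm.step() # READY_TO_DEPART -> MOVING
# assert sm.state == TrainState.MOVING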
from enum import IntEnum
from dataclasses import dataclass
class TrainState(IntEnum):
WAITING = 0
READY_TO_DEPART = 1
MALFUNCTION_OFF_MAP = 2
MOVING = 3
STOPPED = 4
MALFUNCTION = 5
DONE = 6
@classmethod
def check_valid_state(cls, state):
return state in cls._value2member_map_
def is_malfunction_state(self):
return self.value in [self.MALFUNCTION, self.MALFUNCTION_OFF_MAP]
def is_off_map_state(self):
return self.value in [self.WAITING, self.READY_TO_DEPART, self.MALFUNCTION_OFF_MAP]
def is_on_map_state(self):
return self.value in [self.MOVING, self.STOPPED, self.MALFUNCTION]
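# e.g. TrainState.MOVING.is_on_map_state() -> True, and
# TrainState.WAITING.is_off_map_state() -> True (a WAITING train has not
# entered the grid yet).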
@dataclass(repr=True)
class StateTransitionSignals:
in_malfunction : bool = False
malfunction_counter_complete : bool = False
earliest_departure_reached : bool = False
stop_action_given : bool = False
valid_movement_action_given : bool = False
target_reached : bool = False
movement_conflict : bool = False
from typing import Tuple
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.fast_methods import fast_argmax, fast_count_nonzero
from flatland.envs.rail_env_action import RailEnvActions
def check_action(action, position, direction, rail):
"""
Parameters
----------
agent : EnvAgent
action : RailEnvActions
Returns
-------
Tuple[Grid4TransitionsEnum,Tuple[int,int]]
"""
transition_valid = None
possible_transitions = rail.get_transitions(*position, direction)
num_transitions = fast_count_nonzero(possible_transitions)
new_direction = direction
if action == RailEnvActions.MOVE_LEFT:
new_direction = direction - 1
if num_transitions <= 1:
transition_valid = False
elif action == RailEnvActions.MOVE_RIGHT:
new_direction = direction + 1
if num_transitions <= 1:
transition_valid = False
new_direction %= 4 # directions are encoded 0..3; wrap around after turning left (-1) or right (+1)
if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1:
# - dead-end, straight line or curved line;
# new_direction will be the only valid transition
# - take only available transition
new_direction = fast_argmax(possible_transitions)
transition_valid = True
return new_direction, transition_valid
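# Quick illustration of the direction arithmetic above, assuming the Grid4
# encoding (0=North, 1=East, 2=South, 3=West): turning left from North gives
# (0 - 1) % 4 == 3 (West); turning right gives (0 + 1) % 4 == 1 (East).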
def check_action_on_agent(action, rail, position, direction):
"""
Parameters
----------
action : RailEnvActions
agent : EnvAgent
Returns
-------
bool
Is it a legal move?
1) transition allows the new_direction in the cell,
2) the new cell is not empty (case 0),
3) the cell is free, i.e., no agent is currently in that cell
"""
# compute number of possible transitions in the current
# cell used to check for invalid actions
new_direction, transition_valid = check_action(action, position, direction, rail)
new_position = get_new_position(position, new_direction)
new_cell_valid = check_bounds(new_position, rail.height, rail.width) and \
rail.get_full_transitions(*new_position) > 0
# If transition validity hasn't been checked yet.
if transition_valid is None:
transition_valid = rail.get_transition((*position, direction), new_direction)
return new_cell_valid, new_direction, new_position, transition_valid
def check_valid_action(action, rail, position, direction):
new_cell_valid, _, _, transition_valid = check_action_on_agent(action, rail, position, direction)
action_is_valid = new_cell_valid and transition_valid
return action_is_valid
def check_bounds(position, height, width):
return position[0] >= 0 and position[1] >= 0 and position[0] < height and position[1] < width
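# check_bounds treats position as (row, col) on a height x width grid, e.g.:
# check_bounds((0, 29), 30, 30) -> True
# check_bounds((30, 0), 30, 30) -> False (row index out of range)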
import os
import json
import itertools
import warnings
from typing import Tuple, List, Callable, Mapping, Optional, Any
from flatland.envs.timetable_utils import Timetable
import numpy as np
from numpy.random.mtrand import RandomState
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
def len_handle_none(v):
if v is not None:
return len(v)
else:
return 0
def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap,
agents_hints: dict, np_random: RandomState = None) -> Timetable:
"""
Calculates earliest departure and latest arrival times for the agents
This is the new addition in Flatland 3
Also calculates the max episodes steps based on the density of the timetable
inputs:
agents - List of all the agents rail_env.agents
distance_map - Distance map of positions to tagets of each agent in each direction
agent_hints - Uses the number of cities
np_random - RNG state for seeding
returns:
Timetable with the latest_arrivals, earliest_departures and max_episdode_steps
"""
# max_episode_steps calculation
if agents_hints:
city_positions = agents_hints['city_positions']
num_cities = len(city_positions)
else:
num_cities = 2
timedelay_factor = 4
alpha = 2
max_episode_steps = int(timedelay_factor * alpha * \
(distance_map.rail.width + distance_map.rail.height + (len(agents) / num_cities)))
# Multipliers
old_max_episode_steps_multiplier = 3.0
new_max_episode_steps_multiplier = 1.5
travel_buffer_multiplier = 1.3 # must be strictly less than new_max_episode_steps_multiplier
assert new_max_episode_steps_multiplier > travel_buffer_multiplier
end_buffer_multiplier = 0.05
mean_shortest_path_multiplier = 0.2
shortest_paths = get_shortest_paths(distance_map)
shortest_paths_lengths = [len_handle_none(v) for k,v in shortest_paths.items()]
# Find mean_shortest_path_time
agent_speeds = [agent.speed_counter.speed for agent in agents]
agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds)
mean_shortest_path_time = np.mean(agent_shortest_path_times)
# Deciding on a suitable max_episode_steps
longest_speed_normalized_time = np.max(agent_shortest_path_times)
mean_path_delay = mean_shortest_path_time * mean_shortest_path_multiplier
max_episode_steps_new = int(np.ceil(longest_speed_normalized_time * new_max_episode_steps_multiplier) + mean_path_delay)
max_episode_steps_old = int(max_episode_steps * old_max_episode_steps_multiplier)
max_episode_steps = min(max_episode_steps_new, max_episode_steps_old)
end_buffer = int(max_episode_steps * end_buffer_multiplier)
latest_arrival_max = max_episode_steps-end_buffer
# Collected only so they can be returned in the Timetable
earliest_departures = []
latest_arrivals = []
for agent in agents:
agent_shortest_path_time = agent_shortest_path_times[agent.handle]
agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) + mean_path_delay))
departure_window_max = max(latest_arrival_max - agent_travel_time_max, 1)
earliest_departure = np_random.randint(0, departure_window_max)
latest_arrival = earliest_departure + agent_travel_time_max
earliest_departures.append(earliest_departure)
latest_arrivals.append(latest_arrival)
agent.earliest_departure = earliest_departure
agent.latest_arrival = latest_arrival
return Timetable(earliest_departures=earliest_departures, latest_arrivals=latest_arrivals,
max_episode_steps=max_episode_steps)
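# Worked example of the max_episode_steps heuristic above (toy numbers, not
# from any real scenario): a 30x30 grid with 10 agents and 2 cities gives a
# base of int(4 * 2 * (30 + 30 + 10 / 2)) = 520 steps. With a longest
# speed-normalized shortest path of 80 and a mean of 60, the new bound is
# int(ceil(80 * 1.5) + 60 * 0.2) = 132, the old bound is 520 * 3 = 1560,
# and the smaller value, 132, is used.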
from typing import List, NamedTuple
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import IntVector2DArray
Line = NamedTuple('Line', [('agent_positions', IntVector2DArray),
('agent_directions', List[Grid4TransitionsEnum]),
('agent_targets', IntVector2DArray),
('agent_speeds', List[float])])
Timetable = NamedTuple('Timetable', [('earliest_departures', List[int]),
('latest_arrivals', List[int]),
('max_episode_steps', int)])
# -*- coding: utf-8 -*-
"""Top-level package for flatland."""
__author__ = """S.P. Mohanty"""
__email__ = 'mohanty@aicrowd.com'
import glob
import os
import random
import subprocess
import uuid
import pathlib
###############################################################
# Expected Env Variables
###############################################################
# Default Values to be provided :
# AICROWD_IS_GRADING : true
# CROWDAI_IS_GRADING : true
# S3_BUCKET : aicrowd-production
# S3_UPLOAD_PATH_TEMPLATE : misc/flatland-rl-Media/{}
# AWS_ACCESS_KEY_ID
# AWS_SECRET_ACCESS_KEY
# http_proxy
# https_proxy
###############################################################
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID", False)
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY", False)
S3_UPLOAD_PATH_TEMPLATE = os.getenv("S3_UPLOAD_PATH_TEMPLATE", "misc/flatland-rl-Media/{}") # the upload helpers below append the file suffix
S3_BUCKET = os.getenv("S3_BUCKET", "aicrowd-production")
S3_BUCKET_ACL = "public-read" if S3_BUCKET == "aicrowd-production" else ""
def get_boto_client():
if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
raise Exception("AWS Credentials not provided..")
try:
import boto3 # type: ignore
except ImportError:
raise Exception(
"boto3 is not installed. Please install it manually with: "
"pip install -U boto3"
)
return boto3.client(
's3',
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY
)
def is_aws_configured():
if not AWS_ACCESS_KEY_ID or not AWS_SECRET_ACCESS_KEY:
return False
else:
return True
def is_grading():
return os.getenv("CROWDAI_IS_GRADING", False) or \
os.getenv("AICROWD_IS_GRADING", False)
def get_submission_id():
submission_id = os.getenv("AICROWD_SUBMISSION_ID", "T12345")
return submission_id
def upload_random_frame_to_s3(frames_folder):
all_frames = glob.glob(os.path.join(frames_folder, "*.png"))
random_frame = random.choice(all_frames)
s3 = get_boto_client()
if not S3_UPLOAD_PATH_TEMPLATE:
raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...")
if not S3_BUCKET:
raise Exception("S3_BUCKET not provided...")
image_target_key = (S3_UPLOAD_PATH_TEMPLATE + ".png").format(str(uuid.uuid4()))
s3.put_object(
ACL=S3_BUCKET_ACL,
Bucket=S3_BUCKET,
Key=image_target_key,
Body=open(random_frame, 'rb')
)
return image_target_key
def upload_to_s3(localpath):
s3 = get_boto_client()
if not S3_UPLOAD_PATH_TEMPLATE:
raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...")
if not S3_BUCKET:
raise Exception("S3_BUCKET not provided...")
file_suffix = str(pathlib.Path(localpath).suffix)
file_target_key = (S3_UPLOAD_PATH_TEMPLATE + file_suffix).format(
str(uuid.uuid4())
)
s3.put_object(
ACL=S3_BUCKET_ACL,
Bucket=S3_BUCKET,
Key=file_target_key,
Body=open(localpath, 'rb')
)
return file_target_key
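# e.g. with the documented template "misc/flatland-rl-Media/{}" (see the
# expected env variables above), upload_to_s3("out.mp4") writes the file to
# s3://<S3_BUCKET>/misc/flatland-rl-Media/<uuid4>.mp4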
def upload_folder_to_s3(folderpath):
s3 = get_boto_client()
if not S3_BUCKET:
raise Exception("S3_BUCKET not provided...")
for path, subdirs, files in os.walk(folderpath):
if len(files) != 0:
for file in files:
file_target_key = f'analysis_logs/{get_submission_id()}/{path[path.find(next(filter(str.isalpha, path))):]}/{file}'
localpath = os.path.join(path, file)
print(f"[INFO] SAVING: {localpath}")
s3.put_object(
ACL=S3_BUCKET_ACL,
Bucket=S3_BUCKET,
Key=file_target_key,
Body=open(localpath, 'rb')
)
def make_subprocess_call(command, shell=False):
result = subprocess.run(
command.split(),
shell=shell,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)
stdout = result.stdout.decode('utf-8')
stderr = result.stderr.decode('utf-8')
return result.returncode, stdout, stderr
def generate_movie_from_frames(frames_folder):
"""
Expects the frames in frames_folder
and then uses ffmpeg to generate the videos,
writing the output to frames_folder
"""
# Generate Thumbnail Video
print("Generating Thumbnail...")
frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png")
thumb_output_path = os.path.join(frames_folder, "out_thumb.mp4")
return_code, output, output_err = make_subprocess_call(
"ffmpeg -r 7 -start_number 0 -i " +
frames_path +
" -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 320x320 " +
thumb_output_path
)
if return_code != 0:
raise Exception(output_err)
# Generate Normal Sized Video
print("Generating Normal Video...")
frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png")
output_path = os.path.join(frames_folder, "out.mp4")
return_code, output, output_err = make_subprocess_call(
"ffmpeg -r 7 -start_number 0 -i " +
frames_path +
" -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 600x600 " +
output_path
)
if return_code != 0:
raise Exception(output_err)
return output_path, thumb_output_path
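# Usage sketch (assumes frames named flatland_frame_0000.png, ... and an
# ffmpeg binary on the PATH):
# output_path, thumb_output_path = generate_movie_from_frames("/tmp/frames")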
import hashlib
import json
import logging
import os
import random
import time
import msgpack
import msgpack_numpy as m
import pickle
import numpy as np
import redis
import flatland
from flatland.envs.malfunction_generators import FileMalfunctionGen
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.line_generators import line_from_file
from flatland.evaluators import messages
from flatland.core.env_observation_builder import DummyObservationBuilder
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
m.patch()
# CONSTANTS
FLATLAND_RL_SERVICE_ID = os.getenv(
'AICROWD_SUBMISSION_ID',
'T12345')
class TimeoutException(StopAsyncIteration):
""" Custom exception for evaluation timeouts. """
pass
class FlatlandRemoteClient(object):
"""
Redis client to interface with the flatland-rl remote-evaluation-service.
The Docker container hosts a redis-server inside the container.
This client connects to the same redis-server,
and communicates with the service.
The service will eventually reside outside the docker container,
and will communicate
with the client only via the redis-server of the docker container.
On instantiation of the docker container, one service will be
instantiated in parallel.
The service accepts commands at "`service_id`::commands",
where `service_id` is either provided as an `env` variable or is
instantiated to "flatland_rl_redis_service_id".
"""
def __init__(self,
test_env_folder=None,
flatland_rl_service_id=FLATLAND_RL_SERVICE_ID,
remote_host=os.getenv("redis_ip", '127.0.0.1'),
remote_port=6379,
remote_db=0,
remote_password=None,
verbose=False,
use_pickle=False):
self.use_pickle = use_pickle
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_db = remote_db
self.remote_password = remote_password
self.redis_pool = redis.ConnectionPool(
host=remote_host,
port=remote_port,
db=remote_db,
password=remote_password)
self.redis_conn = redis.Redis(connection_pool=self.redis_pool)
self.namespace = "flatland-rl"
self.service_id = flatland_rl_service_id
self.command_channel = "{}::{}::commands".format(
self.namespace,
self.service_id
)
# for timeout messages sent out-of-band
self.error_channel = "{}::{}::errors".format(
self.namespace,
self.service_id
)
if test_env_folder:
self.test_envs_root = test_env_folder
else:
self.test_envs_root = os.getenv(
'AICROWD_TESTS_FOLDER',
'/tmp/flatland_envs'
)
self.current_env_path = None
self.verbose = verbose
self.env = None
self.ping_pong()
self.env_step_times = []
self.stats = {}
def update_running_stats(self, key, scalar):
"""
Computes the running min/mean/max/counter for the given key
"""
mean_key = "{}_mean".format(key)
counter_key = "{}_counter".format(key)
min_key = "{}_min".format(key)
max_key = "{}_max".format(key)
try:
# Update Mean
self.stats[mean_key] = \
((self.stats[mean_key] * self.stats[counter_key]) + scalar) / (self.stats[counter_key] + 1)
# Update min
if scalar < self.stats[min_key]:
self.stats[min_key] = scalar
# Update max
if scalar > self.stats[max_key]:
self.stats[max_key] = scalar
self.stats[counter_key] += 1
except KeyError:
self.stats[mean_key] = scalar
self.stats[min_key] = scalar
self.stats[max_key] = scalar
self.stats[counter_key] = 1
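# Sketch of the running-stat bookkeeping above, assuming a fresh client:
# two samples 0.2 and 0.4 for key "inference_time(approx)" yield
# stats["inference_time(approx)_mean"] == 0.3, "..._min" == 0.2,
# "..._max" == 0.4 and "..._counter" == 2.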
def get_redis_connection(self):
return self.redis_conn
def _generate_response_channel(self):
random_hash = hashlib.md5(
"{}".format(
random.randint(0, 10 ** 10)
).encode('utf-8')).hexdigest()
response_channel = "{}::{}::response::{}".format(self.namespace,
self.service_id,
random_hash)
return response_channel
def _remote_request(self, _request, blocking=True):
"""
request:
- command_type
- payload
- response_channel
response: (on response_channel)
- RESULT
* Send the payload on command_channel (self.namespace+"::command")
** redis-left-push (LPUSH)
* Keep listening on response_channel (BLPOP)
"""
assert isinstance(_request, dict)
_request['response_channel'] = self._generate_response_channel()
_request['timestamp'] = time.time()
_redis = self.get_redis_connection()
"""
The client always pushes on the left (LPUSH)
and the service always pushes on the right (RPUSH)
"""
if self.verbose:
print("Request : ", _request)
# check for errors (essentially just timeouts, for now.)
error_bytes = _redis.rpop(self.error_channel)
if error_bytes is not None:
if self.use_pickle:
error_dict = pickle.loads(error_bytes)
else:
error_dict = msgpack.unpackb(
error_bytes,
object_hook=m.decode,
strict_map_key=False, # new for msgpack 1.0?
)
print("Error received: ", error_dict)
raise TimeoutException(error_dict["type"])
# Push request in command_channels
# Note: The patched msgpack supports numpy arrays
if self.use_pickle:
payload = pickle.dumps(_request)
else:
payload = msgpack.packb(_request, default=m.encode, use_bin_type=True)
_redis.lpush(self.command_channel, payload)
if blocking:
# Wait with a blocking pop for the response
_response = _redis.blpop(_request['response_channel'])[1]
if self.verbose:
print("Response : ", _response)
if self.use_pickle:
_response = pickle.loads(_response)
else:
_response = msgpack.unpackb(
_response,
object_hook=m.decode,
strict_map_key=False, # new for msgpack 1.0?
)
if _response['type'] == messages.FLATLAND_RL.ERROR:
raise Exception(str(_response["payload"]))
else:
return _response
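# Minimal sketch of the round-trip implemented above, assuming a local
# redis-server and the default "flatland-rl" namespace and "T12345" service
# id (both assumptions):
#
# r = redis.Redis()
# request = {"type": messages.FLATLAND_RL.PING, "payload": {},
# "timestamp": time.time(),
# "response_channel": "flatland-rl::T12345::response::demo"}
# r.lpush("flatland-rl::T12345::commands",
# msgpack.packb(request, default=m.encode, use_bin_type=True))
# # the service BRPOPs the command and RPUSHes its response:
# _, raw = r.blpop(request["response_channel"])
# response = msgpack.unpackb(raw, object_hook=m.decode, strict_map_key=False)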
def ping_pong(self):
"""
Official handshake with the evaluation service:
send a PING
and wait for a PONG.
If no PONG is received, raise an error.
"""
_request = {}
_request['type'] = messages.FLATLAND_RL.PING
_request['payload'] = {
"version": flatland.__version__
}
_response = self._remote_request(_request)
if _response['type'] != messages.FLATLAND_RL.PONG:
raise Exception(
"Unable to perform handshake with the evaluation service. \
Expected PONG; received {}".format(json.dumps(_response)))
else:
return True
def env_create(self, obs_builder_object):
"""
Create a local env and remote env on which the
local agent can operate.
The observation builder is only used in the local env
and the remote env uses a DummyObservationBuilder
"""
time_start = time.time()
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_CREATE
_request['payload'] = {}
_response = self._remote_request(_request)
observation = _response['payload']['observation']
info = _response['payload']['info']
random_seed = _response['payload']['random_seed']
test_env_file_path = _response['payload']['env_file_path']
time_diff = time.time() - time_start
self.update_running_stats("env_creation_wait_time", time_diff)
if not observation:
# If the observation is False,
# then the evaluations are complete
# hence return false
return observation, info
if self.verbose:
print("Received Env : ", test_env_file_path)
test_env_file_path = os.path.join(
self.test_envs_root,
test_env_file_path
)
if not os.path.exists(test_env_file_path):
raise Exception(
"\nWe cannot seem to find the env file paths at the required location.\n"
"Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
"to point to the location of the Tests folder ? \n"
"We are currently looking at `{}` for the tests".format(self.test_envs_root)
)
if self.verbose:
print("Current env path : ", test_env_file_path)
self.current_env_path = test_env_file_path
self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
line_generator=line_from_file(test_env_file_path),
malfunction_generator=FileMalfunctionGen(filename=test_env_file_path),
obs_builder_object=obs_builder_object)
time_start = time.time()
# Use the local observation
# as the remote server uses a dummy observation builder
local_observation, info = self.env.reset(
regenerate_rail=True,
regenerate_schedule=True,
random_seed=random_seed
)
time_diff = time.time() - time_start
self.update_running_stats("internal_env_reset_time", time_diff)
# We use the last_env_step_time as an approximate measure of the inference time
self.last_env_step_time = time.time()
return local_observation, info
def env_step(self, action, render=False):
"""
Respond with [observation, reward, done, info]
"""
# We use the last_env_step_time as an approximate measure of the inference time
approximate_inference_time = time.time() - self.last_env_step_time
self.update_running_stats("inference_time(approx)", approximate_inference_time)
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_STEP
_request['payload'] = {}
_request['payload']['action'] = action
_request['payload']['inference_time'] = approximate_inference_time
# Relay the action in a non-blocking way to the server
# so that it can start doing an env.step on it in ~ parallel
# Note - this can throw a Timeout
self._remote_request(_request, blocking=False)
# Apply the action in the local env
time_start = time.time()
local_observation, local_reward, local_done, local_info = \
self.env.step(action)
time_diff = time.time() - time_start
# Compute a running mean of env step times
self.update_running_stats("internal_env_step_time", time_diff)
# We use the last_env_step_time as an approximate measure of the inference time
self.last_env_step_time = time.time()
return [local_observation, local_reward, local_done, local_info]
def submit(self):
_request = {}
_request['type'] = messages.FLATLAND_RL.ENV_SUBMIT
_request['payload'] = {}
_response = self._remote_request(_request)
######################################################################
# Print Local Stats
######################################################################
print("=" * 100)
print("=" * 100)
print("## Client Performance Stats")
print("=" * 100)
for _key in self.stats:
if _key.endswith("_mean"):
metric_name = _key.replace("_mean", "")
mean_key = "{}_mean".format(metric_name)
min_key = "{}_min".format(metric_name)
max_key = "{}_max".format(metric_name)
print("\t - {}\t => min: {} || mean: {} || max: {}".format(
metric_name,
self.stats[min_key],
self.stats[mean_key],
self.stats[max_key]))
print("=" * 100)
if os.getenv("AICROWD_BLOCKING_SUBMIT"):
"""
If the submission is supposed to happen as a blocking submit,
then wait indefinitely for the evaluator to decide what to
do with the container.
"""
while True:
time.sleep(10)
return _response['payload']
if __name__ == "__main__":
remote_client = FlatlandRemoteClient()
def my_controller(obs, _env):
_action = {}
for _idx, _ in enumerate(_env.agents):
_action[_idx] = np.random.randint(0, 5)
return _action
my_observation_builder = DummyObservationBuilder()
episode = 0
obs = True
while obs:
obs, info = remote_client.env_create(
obs_builder_object=my_observation_builder
)
if not obs:
"""
The remote env returns False as the first obs
when it is done evaluating all the individual episodes
"""
break
print("Episode : {}".format(episode))
episode += 1
print(remote_client.env.dones['__all__'])
while True:
action = my_controller(obs, remote_client.env)
time_start = time.time()
try:
observation, all_rewards, done, info = remote_client.env_step(action)
time_diff = time.time() - time_start
print("Step Time : ", time_diff)
if done['__all__']:
print("Current Episode : ", episode)
print("Episode Done")
print("Reward : ", sum(list(all_rewards.values())))
break
except TimeoutException as err:
print("Timeout: ", err)
break
print("Evaluation Complete...")
print(remote_client.submit())
class FLATLAND_RL:
PING = "FLATLAND_RL.PING"
PONG = "FLATLAND_RL.PONG"
ENV_CREATE = "FLATLAND_RL.ENV_CREATE"
ENV_CREATE_RESPONSE = "FLATLAND_RL.ENV_CREATE_RESPONSE"
ENV_RESET = "FLATLAND_RL.ENV_RESET"
ENV_RESET_RESPONSE = "FLATLAND_RL.ENV_RESET_RESPONSE"
ENV_RESET_TIMEOUT = "FLATLAND_RL.ENV_RESET_TIMEOUT"
ENV_STEP = "FLATLAND_RL.ENV_STEP"
ENV_STEP_RESPONSE = "FLATLAND_RL.ENV_STEP_RESPONSE"
ENV_STEP_TIMEOUT = "FLATLAND_RL.ENV_STEP_TIMEOUT"
ENV_SUBMIT = "FLATLAND_RL.ENV_SUBMIT"
ENV_SUBMIT_RESPONSE = "FLATLAND_RL.ENV_SUBMIT_RESPONSE"
ERROR = "FLATLAND_RL.ERROR"
#!/usr/bin/env python
from __future__ import print_function
import glob
import os
import random
import shutil
import time
import traceback
import json
import yaml
import itertools
import re
import crowdai_api
import msgpack
import msgpack_numpy as m
import pickle
import numpy as np
import pandas as pd
import redis
import timeout_decorator
import flatland
from flatland.envs.step_utils.states import TrainState
from flatland.evaluators import aicrowd_helpers
from flatland.evaluators import messages
from flatland.utils.rendertools import RenderTool
from flatland.envs.persistence import RailEnvPersister
use_signals_in_timeout = True
if os.name == 'nt':
"""
Windows doesn't support signals, so
timeout_decorator usually falls apart.
Hence we force it not to use signals
when using the timeout decorator.
"""
use_signals_in_timeout = False
m.patch()
########################################################
# CONSTANTS
########################################################
FLATLAND_RL_SERVICE_ID = os.getenv(
'AICROWD_SUBMISSION_ID',
'T12345')
# Don't proceed to next Test if the previous one didn't reach this mean completion percentage
TEST_MIN_PERCENTAGE_COMPLETE_MEAN = float(os.getenv("TEST_MIN_PERCENTAGE_COMPLETE_MEAN", 0.25))
# After this number of consecutive timeouts, kill the submission:
# this probably means the submission has crashed
MAX_SUCCESSIVE_TIMEOUTS = int(os.getenv("FLATLAND_MAX_SUCCESSIVE_TIMEOUTS", 10))
debug_mode = (int(os.getenv("AICROWD_DEBUG_SUBMISSION", 0)) == 1)
if debug_mode:
print("=" * 20)
print("Submission in DEBUG MODE! will get limited time")
print("=" * 20)
# 2 hours (will get debug timeout from env variable if applicable)
OVERALL_TIMEOUT = int(os.getenv(
"FLATLAND_OVERALL_TIMEOUT",
2 * 60 * 60))
# 10 mins
INITIAL_PLANNING_TIMEOUT = int(os.getenv(
"FLATLAND_INITIAL_PLANNING_TIMEOUT",
10 * 60))
# 10 seconds
PER_STEP_TIMEOUT = int(os.getenv(
"FLATLAND_PER_STEP_TIMEOUT",
10))
# 5 min - applies to the rest of the commands
DEFAULT_COMMAND_TIMEOUT = int(os.getenv(
"FLATLAND_DEFAULT_COMMAND_TIMEOUT",
5 * 60))
RANDOM_SEED = int(os.getenv("FLATLAND_EVALUATION_RANDOM_SEED", 1001))
SUPPORTED_CLIENT_VERSIONS = \
[
flatland.__version__
]
class FlatlandRemoteEvaluationService:
"""
A remote evaluation service which exposes the following interfaces
of a RailEnv :
- env_create
- env_step
and an additional `env_submit` to cater to score computation and on-episode-complete post-processing.
This service is designed to be used in conjunction with
`FlatlandRemoteClient` and both the service and client maintain a
local instance of the RailEnv instance, and in case of any unexpected
divergences in the state of both the instances, the local RailEnv
instance of the `FlatlandRemoteEvaluationService` is supposed to act
as the single source of truth.
Both the client and remote service communicate with each other
via Redis as a message broker. The individual messages are packed and
unpacked with `msgpack` (a patched version of msgpack which also supports
numpy arrays).
"""
def __init__(
self,
test_env_folder="/tmp",
flatland_rl_service_id=FLATLAND_RL_SERVICE_ID,
remote_host=os.getenv("redis_ip", '127.0.0.1'),
remote_port=6379,
remote_db=0,
remote_password=None,
visualize=False,
video_generation_envs=[],
report=None,
verbose=False,
action_dir=None,
episode_dir=None,
analysis_data_dir=None,
merge_dir=None,
use_pickle=False,
shuffle=False,
missing_only=False,
result_output_path=None,
disable_timeouts=False
):
# Episode recording properties
self.action_dir = action_dir
if action_dir and not os.path.exists(self.action_dir):
os.makedirs(self.action_dir)
with open(os.path.join(self.action_dir, 'seed.yml'), 'w') as outfile:
yaml.dump({"RANDOM_SEED": RANDOM_SEED}, outfile, default_flow_style=False)
self.episode_dir = episode_dir
if episode_dir and not os.path.exists(self.episode_dir):
os.makedirs(self.episode_dir)
self.analysis_data_dir = analysis_data_dir
if analysis_data_dir and not os.path.exists(self.analysis_data_dir):
os.makedirs(self.analysis_data_dir)
self.merge_dir = merge_dir
if merge_dir and not os.path.exists(self.merge_dir):
os.makedirs(self.merge_dir)
self.use_pickle = use_pickle
self.missing_only = missing_only
self.episode_actions = []
self.disable_timeouts = disable_timeouts
if self.disable_timeouts:
print("=" * 20)
print("Timeout are DISABLED!")
print("=" * 20)
if shuffle:
print("=" * 20)
print("Env shuffling is ENABLED! not suitable for infinite wave")
print("=" * 20)
print("=" * 20)
print("Max pre-planning time:", INTIAL_PLANNING_TIMEOUT)
print("Max step time:", PER_STEP_TIMEOUT)
print("Max overall time:", OVERALL_TIMEOUT)
print("Max submission startup time:", DEFAULT_COMMAND_TIMEOUT)
print("Max consecutive timeouts:", MAX_SUCCESSIVE_TIMEOUTS)
print("=" * 20)
# Test Env folder Paths
self.test_env_folder = test_env_folder
self.video_generation_envs = video_generation_envs
self.env_file_paths = self.get_env_filepaths()
print(self.env_file_paths)
# Shuffle all the env_file_paths for more exciting videos
# and for more uniform time progression
if shuffle:
random.shuffle(self.env_file_paths)
print(self.env_file_paths)
self.instantiate_evaluation_metadata()
# Logging and Reporting related vars
self.verbose = verbose
self.report = report
# Use a state to swallow and ignore any steps after an env times out.
self.state_env_timed_out = False
# Count the number of successive timeouts (will kill after MAX_SUCCESSIVE_TIMEOUTS)
# This prevents a crashed submission to keep running forever
self.timeout_counter = 0
# Results are the metrics: percent done, rewards, timing...
self.result_output_path = result_output_path
# Communication Protocol Related vars
self.namespace = "flatland-rl"
self.service_id = flatland_rl_service_id
self.command_channel = "{}::{}::commands".format(
self.namespace,
self.service_id
)
self.error_channel = "{}::{}::errors".format(
self.namespace,
self.service_id
)
# Message Broker related vars
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_db = remote_db
self.remote_password = remote_password
self.instantiate_redis_connection_pool()
# AIcrowd evaluation specific vars
self.oracle_events = crowdai_api.events.CrowdAIEvents(with_oracle=True)
self.evaluation_state = {
"state": "PENDING",
"progress": 0.0,
"simulation_count": 0,
"total_simulation_count": len(self.env_file_paths),
"score": {
"score": 0.0,
"score_secondary": 0.0
},
"meta": {
"normalized_reward": 0.0
}
}
self.stats = {}
self.previous_command = {
"type": None
}
# RailEnv specific variables
self.env = False
self.env_renderer = False
self.reward = 0
self.simulation_done = True
self.simulation_count = -1
self.simulation_env_file_paths = []
self.simulation_rewards = []
self.simulation_rewards_normalized = []
self.simulation_percentage_complete = []
self.simulation_percentage_complete_per_test = {}
self.simulation_steps = []
self.simulation_times = []
self.env_step_times = []
self.nb_malfunctioning_trains = []
self.nb_deadlocked_trains = []
self.overall_start_time = 0
self.termination_cause = "No reported termination cause."
self.evaluation_done = False
self.begin_simulation = False
self.current_step = 0
self.current_test = -1
self.current_level = -1
self.visualize = visualize
self.vizualization_folder_name = "./.visualizations"
self.record_frame_step = 0
if self.visualize:
if os.path.exists(self.vizualization_folder_name):
print("[WARNING] Deleting already existing visualizations folder at : {}".format(
self.vizualization_folder_name
))
shutil.rmtree(self.vizualization_folder_name)
os.mkdir(self.vizualization_folder_name)
def update_running_stats(self, key, scalar):
"""
Computes the running min/mean/max for a given param
"""
mean_key = "{}_mean".format(key)
counter_key = "{}_counter".format(key)
min_key = "{}_min".format(key)
max_key = "{}_max".format(key)
try:
# Update Mean
self.stats[mean_key] = \
((self.stats[mean_key] * self.stats[counter_key]) + scalar) / (self.stats[counter_key] + 1)
# Update min
if scalar < self.stats[min_key]:
self.stats[min_key] = scalar
# Update max
if scalar > self.stats[max_key]:
self.stats[max_key] = scalar
self.stats[counter_key] += 1
except KeyError:
self.stats[mean_key] = scalar
self.stats[min_key] = scalar
self.stats[max_key] = scalar
self.stats[counter_key] = 1
def delete_key_in_running_stats(self, key):
"""
This deletes a particular key in the running stats
dictionary, if it exists
"""
mean_key = "{}_mean".format(key)
counter_key = "{}_counter".format(key)
min_key = "{}_min".format(key)
max_key = "{}_max".format(key)
try:
del self.stats[mean_key]
del self.stats[counter_key]
del self.stats[min_key]
del self.stats[max_key]
except KeyError:
pass
def get_env_filepaths(self):
"""
Gathers a list of all available rail env files to be used
for evaluation. The folder structure expected at the `test_env_folder`
is similar to :
.
├── Test_0
│ ├── Level_1.pkl
│ ├── .......
│ ├── .......
│ └── Level_99.pkl
└── Test_1
├── Level_1.pkl
├── .......
├── .......
└── Level_99.pkl
"""
env_paths = glob.glob(
os.path.join(
self.test_env_folder,
"*/*.pkl"
)
)
# Remove the root folder name from the individual
# lists, so that we only have the path relative
# to the test root folder
env_paths = [os.path.relpath(x, self.test_env_folder) for x in env_paths]
# Sort in proper numerical order
def get_file_order(filename):
test_id, level_id = self.get_env_test_and_level(filename)
value = test_id * 1000 + level_id
return value
env_paths.sort(key=get_file_order)
# if requested, only generate actions for those envs which don't already have them
if self.merge_dir and self.missing_only:
existing_paths = (itertools.chain.from_iterable(
[glob.glob(os.path.join(self.merge_dir, f"envs/*.{ext}"))
for ext in ["pkl", "mpk"]]))
existing_paths = [os.path.relpath(sPath, self.merge_dir) for sPath in existing_paths]
env_paths = set(env_paths) - set(existing_paths)
return env_paths
def get_env_test_and_level(self, filename):
numbers = re.findall(r'\d+', os.path.relpath(filename))
if len(numbers) == 2:
test_id = int(numbers[0])
level_id = int(numbers[1])
else:
print(numbers)
raise ValueError("Unexpected file path, expects 'Test_<N>/Level_<M>.pkl', found", filename)
return test_id, level_id
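# Tiny illustration of the sort key used in get_env_filepaths above:
# "Test_0/Level_3.pkl" -> (0, 3) -> key 0 * 1000 + 3 = 3
# "Test_1/Level_2.pkl" -> (1, 2) -> key 1 * 1000 + 2 = 1002
# so all levels of Test_0 come before any level of Test_1.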
def instantiate_evaluation_metadata(self):
"""
This instantiates a pandas dataframe to record
information specific to each of the individual env evaluations.
This loads the template CSV with pre-filled information from the
provided metadata.csv file, and fills it up with
evaluation runtime information.
"""
self.evaluation_metadata_df = None
metadata_file_path = os.path.join(
self.test_env_folder,
"metadata.csv"
)
if os.path.exists(metadata_file_path):
self.evaluation_metadata_df = pd.read_csv(metadata_file_path)
self.evaluation_metadata_df["filename"] = \
self.evaluation_metadata_df["test_id"] + \
"/" + self.evaluation_metadata_df["env_id"] + ".pkl"
self.evaluation_metadata_df = self.evaluation_metadata_df.set_index("filename")
# Add custom columns for evaluation specific metrics
self.evaluation_metadata_df["reward"] = np.nan
self.evaluation_metadata_df["normalized_reward"] = np.nan
self.evaluation_metadata_df["percentage_complete"] = np.nan
self.evaluation_metadata_df["steps"] = np.nan
self.evaluation_metadata_df["simulation_time"] = np.nan
self.evaluation_metadata_df["nb_malfunctioning_trains"] = np.nan
self.evaluation_metadata_df["nb_deadlocked_trains"] = np.nan
# Add client specific columns
# TODO: This needs refactoring
self.evaluation_metadata_df["controller_inference_time_min"] = np.nan
self.evaluation_metadata_df["controller_inference_time_mean"] = np.nan
self.evaluation_metadata_df["controller_inference_time_max"] = np.nan
else:
raise Exception("metadata.csv not found in tests folder ({}). Please use an updated version of the test set.".format(metadata_file_path))
def update_evaluation_metadata(self):
"""
This function is called when we move from one simulation to another
and it simply tries to update the simulation specific information
for the **previous** episode in the metadata_df if it exists.
"""
if self.evaluation_metadata_df is not None and len(self.simulation_env_file_paths) > 0:
last_simulation_env_file_path = self.simulation_env_file_paths[-1]
_row = self.evaluation_metadata_df.loc[
last_simulation_env_file_path
]
# Add controller_inference_time_metrics
# These metrics may be missing if no step was done before the episode finished
# generate the lists of names for the stats (input names and output names)
sPrefixIn = "current_episode_controller_inference_time_"
sPrefixOut = "controller_inference_time_"
lsStatIn = [sPrefixIn + sStat for sStat in ["min", "mean", "max"]]
lsStatOut = [sPrefixOut + sStat for sStat in ["min", "mean", "max"]]
if lsStatIn[0] in self.stats:
lrStats = [self.stats[sStat] for sStat in lsStatIn]
else:
lrStats = [0.0] * len(lsStatIn)
lsFields = ("reward, normalized_reward, percentage_complete, " + \
"steps, simulation_time, nb_malfunctioning_trains, nb_deadlocked_trains").split(", ") + \
lsStatOut
loValues = [self.simulation_rewards[-1],
self.simulation_rewards_normalized[-1],
self.simulation_percentage_complete[-1],
self.simulation_steps[-1],
self.simulation_times[-1],
self.nb_malfunctioning_trains[-1],
self.nb_deadlocked_trains[-1]
] + lrStats
# update the dataframe without the updating-a-copy warning
df = self.evaluation_metadata_df
df.loc[last_simulation_env_file_path, lsFields] = loValues
# Delete this key from the stats to ensure that it
# gets computed again from scratch in the next episode
self.delete_key_in_running_stats(
"current_episode_controller_inference_time")
if self.verbose:
print(self.evaluation_metadata_df)
def instantiate_redis_connection_pool(self):
"""
Instantiates a Redis connection pool which can be used to
communicate with the message broker
"""
if self.verbose or self.report:
print("Attempting to connect to redis server at {}:{}/{}".format(
self.remote_host,
self.remote_port,
self.remote_db))
self.redis_pool = redis.ConnectionPool(
host=self.remote_host,
port=self.remote_port,
db=self.remote_db,
password=self.remote_password
)
self.redis_conn = redis.Redis(connection_pool=self.redis_pool)
def get_redis_connection(self):
"""
Obtains a new redis connection from a previously instantiated
redis connection pool
"""
return self.redis_conn
def _error_template(self, payload):
"""
Simple helper function to pass a payload as a part of a
flatland comms error template.
"""
_response = {}
_response['type'] = messages.FLATLAND_RL.ERROR
_response['payload'] = payload
return _response
def get_next_command(self):
"""
A helper function to obtain the next command, which transparently
also deals with things like unpacking of the command from the
packed message, and consider the timeouts, etc when trying to
fetch a new command.
"""
COMMAND_TIMEOUT = DEFAULT_COMMAND_TIMEOUT
"""
Handle case specific timeouts :
- INITIAL_PLANNING_TIMEOUT
The timeout between an env_create call and the first env_step call
- PER_STEP_TIMEOUT
The timeout between two consecutive env_step calls
"""
if self.previous_command['type'] == messages.FLATLAND_RL.ENV_CREATE:
"""
In case the previous command is an env_create, then leave
a bit more time for the initial planning
"""
COMMAND_TIMEOUT = INITIAL_PLANNING_TIMEOUT
elif self.previous_command['type'] == messages.FLATLAND_RL.ENV_STEP:
"""
Use the per_step_time for all timesteps between two env_step calls
# Corner Case :
- Are there any reasons why a call between the last env_step call
and the subsequent env_create call will take an excessively large
amount of time (>5s in this case)
"""
COMMAND_TIMEOUT = PER_STEP_TIMEOUT
elif self.previous_command['type'] == messages.FLATLAND_RL.ENV_SUBMIT:
"""
If the user has already done an env_submit call, then the timeout
can be an arbitrarily large number.
"""
COMMAND_TIMEOUT = 10 ** 6
if self.disable_timeouts:
COMMAND_TIMEOUT = None
@timeout_decorator.timeout(COMMAND_TIMEOUT, use_signals=use_signals_in_timeout) # timeout for each command
def _get_next_command(command_channel, _redis):
"""
A low level wrapper for obtaining the next command from a
pre-agreed command channel.
At the moment, the communication protocol uses lpush for pushing
in commands, and brpop for reading out commands.
"""
command = _redis.brpop(command_channel)[1]
return command
_redis = self.get_redis_connection()
command = _get_next_command(self.command_channel, _redis)
if self.verbose or self.report:
print("Command Service: ", command)
if self.use_pickle:
command = pickle.loads(command)
else:
command = msgpack.unpackb(
command,
object_hook=m.decode,
strict_map_key=False, # msgpack 1.0
)
if self.verbose:
print("Received Request : ", command)
message_queue_latency = time.time() - command["timestamp"]
self.update_running_stats("message_queue_latency", message_queue_latency)
return command
def send_response(self, _command_response, command, suppress_logs=False):
_redis = self.get_redis_connection()
command_response_channel = command['response_channel']
if self.verbose and not suppress_logs:
print("Responding with : ", _command_response)
if self.use_pickle:
sResponse = pickle.dumps(_command_response)
else:
sResponse = msgpack.packb(
_command_response,
default=m.encode,
use_bin_type=True)
_redis.rpush(command_response_channel, sResponse)
def send_error(self, error_dict, suppress_logs=False):
""" For out-of-band errors like timeouts,
where we do not have a command, so we have no response channel!
"""
_redis = self.get_redis_connection()
print("Sending error : ", error_dict)
if self.use_pickle:
sResponse = pickle.dumps(error_dict)
else:
sResponse = msgpack.packb(
error_dict,
default=m.encode,
use_bin_type=True)
_redis.rpush(self.error_channel, sResponse)
def handle_ping(self, command):
"""
Handles PING command from the client.
"""
service_version = flatland.__version__
if "version" in command["payload"].keys():
client_version = command["payload"]["version"]
else:
# 2.1.4 -> when the version mismatch check was added
client_version = "2.1.4"
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.PONG
_command_response['payload'] = {}
if client_version not in SUPPORTED_CLIENT_VERSIONS:
_command_response['type'] = messages.FLATLAND_RL.ERROR
_command_response['payload']['message'] = \
"Client-Server Version Mismatch => " + \
"[ Client Version : {} ] ".format(client_version) + \
"[ Server Version : {} ] ".format(service_version)
self.send_response(_command_response, command)
raise Exception(_command_response['payload']['message'])
self.send_response(_command_response, command)
def handle_env_create(self, command):
"""
Handles a ENV_CREATE command from the client
"""
print(" -- [DEBUG] [env_create] EVAL DONE: ",self.evaluation_done)
# Check if the previous episode was finished
if not self.simulation_done and not self.evaluation_done:
_command_response = self._error_template("CAN'T CREATE NEW ENV BEFORE PREVIOUS IS DONE")
self.send_response(_command_response, command)
raise Exception(_command_response['payload'])
self.simulation_count += 1
self.simulation_done = False
if self.simulation_count == 0:
# Very first episode: start the overall timer
self.overall_start_time = time.time()
# reset the timeout flag / state.
self.state_env_timed_out = False
# Check if we have finished all the available envs
print(" -- [DEBUG] [env_create] SIM COUNT: ", self.simulation_count + 1, len(self.env_file_paths))
if self.simulation_count >= len(self.env_file_paths):
self.evaluation_done = True
# Hack - just ensure these are set
test_env_file_path = self.env_file_paths[self.simulation_count - 1]
env_test, env_level = self.get_env_test_and_level(test_env_file_path)
else:
test_env_file_path = self.env_file_paths[self.simulation_count]
env_test, env_level = self.get_env_test_and_level(test_env_file_path)
# Did we just finish a test, and if yes did it reach high enough mean percentage done?
if self.current_test != env_test and env_test != 0:
if self.current_test not in self.simulation_percentage_complete_per_test:
print("No environment was finished at all during test {}!".format(self.current_test))
mean_test_complete_percentage = 0.0
else:
mean_test_complete_percentage = np.mean(self.simulation_percentage_complete_per_test[self.current_test])
if mean_test_complete_percentage < TEST_MIN_PERCENTAGE_COMPLETE_MEAN:
print("=" * 15)
msg = "The mean percentage of done agents during the last Test ({} environments) was too low: {:.3f} < {}".format(
len(self.simulation_percentage_complete_per_test[self.current_test]),
mean_test_complete_percentage,
TEST_MIN_PERCENTAGE_COMPLETE_MEAN
)
print(msg, "Evaluation will stop.")
self.termination_cause = msg
self.evaluation_done = True
if self.simulation_count < len(self.env_file_paths) and not self.evaluation_done:
"""
There are still test envs left that are yet to be evaluated
"""
print("=" * 15)
print("Evaluating {} ({}/{})".format(test_env_file_path, self.simulation_count+1, len(self.env_file_paths)))
test_env_file_path = os.path.join(
self.test_env_folder,
test_env_file_path
)
self.current_test = env_test
self.current_level = env_level
del self.env
self.env, _env_dict = RailEnvPersister.load_new(test_env_file_path)
# distance map here?
self.begin_simulation = time.time()
# Update evaluation metadata for the previous episode
self.update_evaluation_metadata()
# Start adding placeholders for the new episode
self.simulation_env_file_paths.append(
os.path.relpath(
test_env_file_path,
self.test_env_folder
)) # relative path
self.simulation_rewards.append(0)
self.simulation_rewards_normalized.append(0)
self.simulation_percentage_complete.append(0)
self.simulation_times.append(0)
self.simulation_steps.append(0)
self.nb_malfunctioning_trains.append(0)
self.current_step = 0
_observation, _info = self.env.reset(
regenerate_rail=True,
regenerate_schedule=True,
random_seed=RANDOM_SEED
)
if self.visualize:
current_env_path = self.env_file_paths[self.simulation_count]
if current_env_path in self.video_generation_envs:
self.env_renderer = RenderTool(self.env, gl="PILSVG", )
elif self.env_renderer:
self.env_renderer = False
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = _observation
_command_response['payload']['env_file_path'] = self.env_file_paths[self.simulation_count]
_command_response['payload']['info'] = _info
_command_response['payload']['random_seed'] = RANDOM_SEED
else:
print(" -- [DEBUG] [env_create] return obs = False (END)")
"""
All test env evaluations are complete
"""
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_CREATE_RESPONSE
_command_response['payload'] = {}
_command_response['payload']['observation'] = False
_command_response['payload']['env_file_path'] = False
_command_response['payload']['info'] = False
_command_response['payload']['random_seed'] = False
self.send_response(_command_response, command)
#####################################################################
# Update evaluation state
#####################################################################
elapsed = time.time() - self.overall_start_time
progress = np.clip(
elapsed / OVERALL_TIMEOUT,
0, 1)
mean_reward, mean_normalized_reward, sum_normalized_reward, mean_percentage_complete = self.compute_mean_scores()
self.evaluation_state["state"] = "IN_PROGRESS"
self.evaluation_state["progress"] = progress
self.evaluation_state["simulation_count"] = self.simulation_count
self.evaluation_state["score"]["score"] = sum_normalized_reward
self.evaluation_state["score"]["score_secondary"] = mean_percentage_complete
self.evaluation_state["meta"]["normalized_reward"] = mean_normalized_reward
self.evaluation_state["meta"]["termination_cause"] = self.termination_cause
self.handle_aicrowd_info_event(self.evaluation_state)
self.episode_actions = []
def handle_env_step(self, command):
"""
Handles a ENV_STEP command from the client
TODO: Add a high level summary of everything that's happening here.
"""
if self.state_env_timed_out or self.evaluation_done:
print("Ignoring step command after timeout.")
return
_payload = command['payload']
if not self.env:
raise Exception("env_client.step called before env_client.env_create() call")
if self.env.dones['__all__']:
raise Exception(
"Client attempted to perform an action on an Env which \
has done['__all__']==True")
overall_elapsed = (time.time() - self.overall_start_time)
if overall_elapsed > OVERALL_TIMEOUT:
msg = "Reached overall time limit: took {:.2f}s, limit is {:.2f}s.".format(
overall_elapsed, OVERALL_TIMEOUT
)
self.termination_cause = msg
self.evaluation_done = True
print("=" * 15)
print(msg, "Evaluation will stop.")
return
# else:
# print("="*15)
# print("{}s left!".format(OVERALL_TIMEOUT - overall_elapsed))
action = _payload['action']
inference_time = _payload['inference_time']
# We record this metric in two keys:
# - One for the current episode
# - One global
self.update_running_stats("current_episode_controller_inference_time", inference_time)
self.update_running_stats("controller_inference_time", inference_time)
# Perform the step
time_start = time.time()
_observation, all_rewards, done, info = self.env.step(action)
time_diff = time.time() - time_start
self.update_running_stats("internal_env_step_time", time_diff)
self.current_step += 1
cumulative_reward = sum(all_rewards.values())
self.simulation_rewards[-1] += cumulative_reward
self.simulation_steps[-1] += 1
"""
The normalized rewards normalize the reward for an
episode by dividing the whole reward by max-time-steps
allowed in that episode, and the number of agents present in
that episode
"""
self.simulation_rewards_normalized[-1] += \
(cumulative_reward / (
self.env._max_episode_steps *
self.env.get_num_agents()
))
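# e.g. (toy numbers, an illustration only): with cumulative_reward = -50,
# _max_episode_steps = 500 and 10 agents, this step contributes
# -50 / (500 * 10) = -0.01 to the episode's normalized reward.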
# We count the number of agents that malfunctioned by checking how many have 1 more step left before recovery
num_malfunctioning = sum(agent.malfunction_handler._malfunction_down_counter == 1 for agent in self.env.agents)
if self.verbose and num_malfunctioning > 0:
print("Step {}: {} agents have malfunctioned and will recover next step".format(self.current_step, num_malfunctioning))
self.nb_malfunctioning_trains[-1] += num_malfunctioning
# record the actions before checking for done
if self.action_dir is not None:
action = {key: int(val) for key, val in action.items()}
self.episode_actions.append(action)
# Is the episode over?
if done["__all__"]:
self.simulation_done = True
if self.begin_simulation:
# If begin simulation has already been initialized at least once
# This adds the simulation time for the previous episode
self.simulation_times[-1] = time.time() - self.begin_simulation
# Compute percentage complete
complete = 0
for i_agent in range(self.env.get_num_agents()):
agent = self.env.agents[i_agent]
if agent.state == TrainState.DONE:
complete += 1
percentage_complete = complete * 1.0 / self.env.get_num_agents()
self.simulation_percentage_complete[-1] = percentage_complete
# Add 1.0 to shift the per-episode normalized reward from [-1, 0] into
# [0, 1], so that episode scores can be summed into the primary score.
self.simulation_rewards_normalized[-1] += 1.0
if self.current_test not in self.simulation_percentage_complete_per_test:
self.simulation_percentage_complete_per_test[self.current_test] = []
self.simulation_percentage_complete_per_test[self.current_test].append(percentage_complete)
print("Percentage for test {}, level {}: {}".format(self.current_test, self.current_level, percentage_complete))
if len(self.env.cur_episode) > 0:
g3Ep = np.array(self.env.cur_episode)
self.nb_deadlocked_trains.append(np.sum(g3Ep[-1, :, 5]))
else:
self.nb_deadlocked_trains.append(np.nan)
print(
"Evaluation finished in {} timesteps, {:.3f} seconds. Percentage agents done: {:.3f}. Normalized reward: {:.3f}. Number of malfunctions: {}. Number of deadlocked trains: {}.".format(
self.simulation_steps[-1],
self.simulation_times[-1],
self.simulation_percentage_complete[-1],
self.simulation_rewards_normalized[-1],
self.nb_malfunctioning_trains[-1],
self.nb_deadlocked_trains[-1]
))
print("Total normalized reward so far: {:.3f}".format(sum(self.simulation_rewards_normalized)))
# Write intermediate results
if self.result_output_path:
self.evaluation_metadata_df.to_csv(self.result_output_path)
print("Wrote intermediate output results to : {}".format(self.result_output_path))
if self.action_dir is not None:
self.save_actions()
if self.episode_dir is not None:
self.save_episode()
if self.analysis_data_dir is not None:
self.collect_analysis_data()
self.save_analysis_data()
if self.merge_dir is not None:
self.save_merged_env()
# Record Frame
if self.visualize:
"""
Only generate and save the frames for environments which are separately provided
in video_generation_indices param
"""
current_env_path = self.env_file_paths[self.simulation_count]
if current_env_path in self.video_generation_envs:
self.env_renderer.render_env(
show=False,
show_observations=False,
show_predictions=False,
show_rowcols=False
)
self.env_renderer.gl.save_image(
os.path.join(
self.vizualization_folder_name,
"flatland_frame_{:04d}.png".format(self.record_frame_step)
))
self.record_frame_step += 1
def save_actions(self):
sfEnv = self.env_file_paths[self.simulation_count]
sfActions = os.path.join(self.action_dir, sfEnv.replace(".pkl", ".json"))
print("env path: ", sfEnv, " actions path:", sfActions)
if not os.path.exists(os.path.dirname(sfActions)):
os.makedirs(os.path.dirname(sfActions))
with open(sfActions, "w") as fOut:
json.dump(self.episode_actions, fOut)
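# The file written above is a JSON list with one action dict per step,
# e.g. for a hypothetical 2-agent episode (json.dump stringifies the
# integer agent ids):
# [{"0": 2, "1": 4}, {"0": 2, "1": 0}, ...]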
self.episode_actions = []
def save_episode(self):
sfEnv = self.env_file_paths[self.simulation_count]
sfEpisode = os.path.join(self.episode_dir, sfEnv)
print("env path: ", sfEnv, " sfEpisode:", sfEpisode)
RailEnvPersister.save_episode(self.env, sfEpisode)
# self.env.save_episode(sfEpisode)
def collect_analysis_data(self):
'''
Collect data at the END of an episode.
Data to be saved in a json file corresponding to the episode.
'''
self.analysis_data = {}
agent_speeds = []
agent_states = []
agent_earliest_departures = []
agent_latest_arrivals = []
agent_arrival_times = []
agent_shortest_paths = []  # only for trains that have not arrived
agent_current_delays = []  # only for trains that have not arrived
agent_rewards = list(self.env.rewards_dict.values())
for i_agent in range(self.env.get_num_agents()):
agent = self.env.agents[i_agent]
agent_speeds.append(agent.speed_counter.speed)
agent_states.append(agent.state)
agent_earliest_departures.append(agent.earliest_departure)
agent_latest_arrivals.append(agent.latest_arrival)
agent_arrival_times.append(agent.arrival_time)
if agent.state != TrainState.DONE:
sp = agent.get_shortest_path(self.env.distance_map)
len_sp = len(sp) if sp is not None else -1
agent_shortest_paths.append(len_sp)
agent_current_delays.append(agent.get_current_delay(self.env._elapsed_steps, self.env.distance_map))
else:
agent_shortest_paths.append(None)
agent_current_delays.append(None)
self.analysis_data['agent_speeds'] = agent_speeds
self.analysis_data['agent_states'] = agent_states
self.analysis_data['agent_earliest_departures'] = agent_earliest_departures
self.analysis_data['agent_latest_arrivals'] = agent_latest_arrivals
self.analysis_data['agent_arrival_times'] = agent_arrival_times
self.analysis_data['agent_shortest_paths'] = agent_shortest_paths
self.analysis_data['agent_current_delays'] = agent_current_delays
self.analysis_data['agent_rewards'] = agent_rewards
def save_analysis_data(self):
sfEnv = self.env_file_paths[self.simulation_count]
sfData = os.path.join(self.analysis_data_dir, sfEnv.replace(".pkl", ".json"))
print("env path: ", sfEnv, " data path:", sfData)
if not os.path.exists(os.path.dirname(sfData)):
os.makedirs(os.path.dirname(sfData))
with open(sfData, "w") as fOut:
json.dump(self.analysis_data, fOut)
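# The file written above is a flat JSON dict of per-agent lists, e.g. for
# a hypothetical 2-agent episode where only agent 0 has arrived:
# {"agent_speeds": [1.0, 0.5], "agent_states": [...],
#  "agent_earliest_departures": [2, 5], "agent_latest_arrivals": [48, 60],
#  "agent_arrival_times": [40, null], "agent_shortest_paths": [null, 12],
#  "agent_current_delays": [null, -3], "agent_rewards": [0, -5]}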
self.analysis_data = {}
def save_merged_env(self):
sfEnv = self.env_file_paths[self.simulation_count]
sfMergeEnv = os.path.join(self.merge_dir, sfEnv)
if not os.path.exists(os.path.dirname(sfMergeEnv)):
os.makedirs(os.path.dirname(sfMergeEnv))
print("Input env path: ", sfEnv, " Merge File:", sfMergeEnv)
RailEnvPersister.save_episode(self.env, sfMergeEnv)
# self.env.save_episode(sfMergeEnv)
def handle_env_submit(self, command):
"""
Handles a ENV_SUBMIT command from the client
TODO: Add a high level summary of everything thats happening here.
"""
_payload = command['payload']
######################################################################
# Print Local Stats
######################################################################
print("=" * 100)
print("=" * 100)
print("## Server Performance Stats")
print("=" * 100)
for _key in self.stats:
if _key.endswith("_mean"):
metric_name = _key.replace("_mean", "")
mean_key = "{}_mean".format(metric_name)
min_key = "{}_min".format(metric_name)
max_key = "{}_max".format(metric_name)
print("\t - {}\t => min: {} || mean: {} || max: {}".format(
metric_name,
self.stats[min_key],
self.stats[mean_key],
self.stats[max_key]))
print("=" * 100)
# Register simulation time of the last episode
self.simulation_times.append(time.time() - self.begin_simulation)
# Compute the evaluation metadata for the last episode
self.update_evaluation_metadata()
if len(self.simulation_rewards) != len(self.env_file_paths) and not self.evaluation_done:
raise Exception(
"env.submit called before the agent had the chance "
"to operate on all the test environments.")
mean_reward, mean_normalized_reward, sum_normalized_reward, mean_percentage_complete = self.compute_mean_scores()
if self.visualize and len(os.listdir(self.vizualization_folder_name)) > 0:
# Generate the video
#
# Note: if you hit dependency issues with ffmpeg, you can
# install it with:
#
# conda install -c conda-forge x264 ffmpeg
print("Generating Video from thumbnails...")
video_output_path, video_thumb_output_path = \
aicrowd_helpers.generate_movie_from_frames(
self.vizualization_folder_name
)
print("Videos : ", video_output_path, video_thumb_output_path)
# Upload to S3 if configuration is available
if aicrowd_helpers.is_grading() and aicrowd_helpers.is_aws_configured() and self.visualize:
video_s3_key = aicrowd_helpers.upload_to_s3(
video_output_path
)
video_thumb_s3_key = aicrowd_helpers.upload_to_s3(
video_thumb_output_path
)
static_thumbnail_s3_key = aicrowd_helpers.upload_random_frame_to_s3(
self.vizualization_folder_name
)
self.evaluation_state["score"]["media_content_type"] = "video/mp4"
self.evaluation_state["score"]["media_large"] = video_s3_key
self.evaluation_state["score"]["media_thumbnail"] = video_thumb_s3_key
self.evaluation_state["meta"]["static_media_frame"] = static_thumbnail_s3_key
else:
print("[WARNING] Ignoring uploading of video to S3")
#####################################################################
# Save `data` and `action` directories
#####################################################################
if self.action_dir:
if aicrowd_helpers.is_grading() and aicrowd_helpers.is_aws_configured():
aicrowd_helpers.upload_folder_to_s3(self.action_dir)
else:
print("[WARNING] Ignoring uploading action_dir to S3")
if self.analysis_data_dir:
if aicrowd_helpers.is_grading() and aicrowd_helpers.is_aws_configured():
aicrowd_helpers.upload_folder_to_s3(self.analysis_data_dir)
else:
print("[WARNING] Ignoring uploading analysis_data_dir to S3")
#####################################################################
# Write Results to a file (if applicable)
#####################################################################
if self.result_output_path:
self.evaluation_metadata_df.to_csv(self.result_output_path)
print("Wrote output results to : {}".format(self.result_output_path))
# Upload the metadata file to S3
if aicrowd_helpers.is_grading() and aicrowd_helpers.is_aws_configured():
metadata_s3_key = aicrowd_helpers.upload_to_s3(
self.result_output_path
)
self.evaluation_state["meta"]["private_metadata_s3_key"] = metadata_s3_key
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE
_payload = {}
_payload['mean_reward'] = mean_reward
_payload['sum_normalized_reward'] = sum_normalized_reward
_payload['mean_percentage_complete'] = mean_percentage_complete
_payload['mean_normalized_reward'] = mean_normalized_reward
_command_response['payload'] = _payload
self.send_response(_command_response, command)
#####################################################################
# Update evaluation state
#####################################################################
self.evaluation_state["state"] = "FINISHED"
self.evaluation_state["progress"] = 1.0
self.evaluation_state["simulation_count"] = self.simulation_count
self.evaluation_state["score"]["score"] = sum_normalized_reward
self.evaluation_state["score"]["score_secondary"] = mean_percentage_complete
self.evaluation_state["meta"]["normalized_reward"] = mean_normalized_reward
self.evaluation_state["meta"]["reward"] = mean_reward
self.evaluation_state["meta"]["percentage_complete"] = mean_percentage_complete
self.evaluation_state["meta"]["termination_cause"] = self.termination_cause
self.handle_aicrowd_success_event(self.evaluation_state)
print("#" * 100)
print("EVALUATION COMPLETE !!")
print("#" * 100)
print("# Mean Reward : {}".format(mean_reward))
print("# Sum Normalized Reward : {} (primary score)".format(sum_normalized_reward))
print("# Mean Percentage Complete : {} (secondary score)".format(mean_percentage_complete))
print("# Mean Normalized Reward : {}".format(mean_normalized_reward))
print("#" * 100)
print("#" * 100)
return _command_response
def compute_mean_scores(self):
#################################################################################
# Compute the mean reward, mean normalized reward, sum of normalized
# rewards, and mean percentage complete across all evaluated episodes.
# (An earlier version grouped results by test_id and averaged per group
# before averaging across groups; that grouping is currently disabled.)
#################################################################################
source_df = self.evaluation_metadata_df
# grouped_df = source_df.groupby(['test_id']).mean()
mean_reward = source_df["reward"].mean()
mean_normalized_reward = source_df["normalized_reward"].mean()
sum_normalized_reward = source_df["normalized_reward"].sum()
mean_percentage_complete = source_df["percentage_complete"].mean()
# Round off the reward values
mean_reward = round(mean_reward, 2)
mean_normalized_reward = round(mean_normalized_reward, 5)
mean_percentage_complete = round(mean_percentage_complete, 3)
return mean_reward, mean_normalized_reward, sum_normalized_reward, mean_percentage_complete
def report_error(self, error_message, command_response_channel):
"""
A helper function used to report error back to the client
"""
_redis = self.get_redis_connection()
_command_response = {}
_command_response['type'] = messages.FLATLAND_RL.ERROR
_command_response['payload'] = error_message
if self.use_pickle:
bytes_error = pickle.dumps(_command_response)
else:
bytes_error = msgpack.packb(
_command_response,
default=m.encode,
use_bin_type=True)
_redis.rpush(command_response_channel, bytes_error)
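# Round-trip sketch of the default (msgpack) serialization used above,
# assuming msgpack >= 1.0 defaults:
# >>> import msgpack
# >>> msgpack.unpackb(msgpack.packb({"type": "ERROR"}), raw=False)
# {'type': 'ERROR'}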
self.evaluation_state["state"] = "ERROR"
self.evaluation_state["error"] = error_message
self.evaluation_state["meta"]["termination_cause"] = "An error occured."
self.handle_aicrowd_error_event(self.evaluation_state)
def handle_aicrowd_info_event(self, payload):
self.oracle_events.register_event(
event_type=self.oracle_events.CROWDAI_EVENT_INFO,
payload=payload
)
def handle_aicrowd_success_event(self, payload):
self.oracle_events.register_event(
event_type=self.oracle_events.CROWDAI_EVENT_SUCCESS,
payload=payload
)
def handle_aicrowd_error_event(self, payload):
self.oracle_events.register_event(
event_type=self.oracle_events.CROWDAI_EVENT_ERROR,
payload=payload
)
def run(self):
"""
Main runner function which waits for commands from the client
and acts accordingly.
"""
print("Listening at : ", self.command_channel)
MESSAGE_QUEUE_LATENCY = []
while True:
try:
command = self.get_next_command()
except timeout_decorator.timeout_decorator.TimeoutError:
# A timeout occurred: send an error and assign the worst possible
# normalized score for this episode (0.0 after the +1.0 completion
# offset, i.e. -1.0 before it).
if self.previous_command['type'] == messages.FLATLAND_RL.ENV_STEP:
self.send_error({"type": messages.FLATLAND_RL.ENV_STEP_TIMEOUT})
timeout_details = "step time limit of {}s".format(PER_STEP_TIMEOUT)
elif self.previous_command['type'] == messages.FLATLAND_RL.ENV_CREATE:
self.send_error({"type": messages.FLATLAND_RL.ENV_RESET_TIMEOUT})
timeout_details = "pre-planning time limit of {}s".format(INTIAL_PLANNING_TIMEOUT)
else:
# Defensive default so timeout_details is always bound for the
# print below, even after an unexpected command type.
timeout_details = "a time limit"
self.simulation_steps[-1] += 1
self.simulation_rewards[-1] = self.env._max_episode_steps * self.env.get_num_agents()
self.simulation_rewards_normalized[-1] = 0.0
print(
"Evaluation of this episode TIMED OUT after {} timesteps (exceeded {}), won't get any reward. {} consecutive timeouts. "
"Percentage agents done: {:.3f}. Normalized reward: {:.3f}. Number of malfunctions: {}.".format(
self.simulation_steps[-1],
timeout_details,
self.timeout_counter,
self.simulation_percentage_complete[-1],
self.simulation_rewards_normalized[-1],
self.nb_malfunctioning_trains[-1],
))
self.timeout_counter += 1
self.state_env_timed_out = True
self.simulation_done = True
if self.timeout_counter >= MAX_SUCCESSIVE_TIMEOUTS:
print("=" * 15)
msg = "Submissions had {} consecutive timeouts.".format(self.timeout_counter)
print(msg, "Evaluation will stop.")
self.termination_cause = msg
self.evaluation_done = True
# JW - change the command to a submit
print("Creating fake submit message after excessive timeouts.")
command = {
"type": messages.FLATLAND_RL.ENV_SUBMIT,
"payload": {},
"response_channel": self.previous_command.get("response_channel")}
return self.handle_env_submit(command)
continue
self.timeout_counter = 0
if "timestamp" in command.keys():
latency = time.time() - command["timestamp"]
MESSAGE_QUEUE_LATENCY.append(latency)
if self.verbose:
print("Self.Reward : ", self.reward)
print("Current Simulation : ", self.simulation_count)
if self.env_file_paths and \
self.simulation_count < len(self.env_file_paths):
print("Current Env Path : ",
self.env_file_paths[self.simulation_count])
try:
if command['type'] == messages.FLATLAND_RL.PING:
"""
INITIAL HANDSHAKE : Respond with PONG
"""
self.handle_ping(command)
elif command['type'] == messages.FLATLAND_RL.ENV_CREATE:
"""
ENV_CREATE
Respond with an internal _env object
"""
self.handle_env_create(command)
elif command['type'] == messages.FLATLAND_RL.ENV_STEP:
"""
ENV_STEP
Request : Action dict
Respond with updated [observation,reward,done,info] after step
"""
self.handle_env_step(command)
elif command['type'] == messages.FLATLAND_RL.ENV_SUBMIT:
"""
ENV_SUBMIT
Submit the final cumulative reward
"""
print("Overall Message Queue Latency : ", np.array(MESSAGE_QUEUE_LATENCY).mean())
return self.handle_env_submit(command)
else:
_error = self._error_template(
"UNKNOWN_REQUEST:{}".format(
str(command)))
if self.verbose:
print("Responding with : ", _error)
if "response_channel" in command:
self.report_error(
_error,
command['response_channel'])
return _error
###########################################
# We keep a record of the previous command
# to be able to behave differently across
# "command transitions".
#
# An example use case is allowing a longer
# timeout for the first step in every
# environment, to account for some initial
# planning time.
self.previous_command = command
except Exception as e:
print("Error : ", str(e))
print(traceback.format_exc())
if ("response_channel" in command):
self.report_error(
self._error_template(str(e)),
command['response_channel'])
return self._error_template(str(e))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Submit the result to AIcrowd')
parser.add_argument('--service_id',
dest='service_id',
default=FLATLAND_RL_SERVICE_ID,
required=False)
parser.add_argument('--test_folder',
dest='test_folder',
default="../../../submission-scoring/Envs-Small",
help="Folder containing the files for the test envs",
required=False)
parser.add_argument('--actionDir',
dest='actionDir',
default=None,
help="deprecated - use mergeDir. Folder in which to save the actions taken in each test env",
required=False)
parser.add_argument('--episodeDir',
dest='episodeDir',
default=None,
help="deprecated - use mergeDir. Folder in which to save the recorded episodes",
required=False)
parser.add_argument('--mergeDir',
dest='mergeDir',
default=None,
help="Folder to store merged envs, actions, episodes.",
required=False)
parser.add_argument('--pickle',
default=False,
action="store_true",
help="use pickle instead of msgpack",
required=False)
parser.add_argument('--shuffle',
default=False,
action="store_true",
help="Shuffle the environments",
required=False)
parser.add_argument('--disableTimeouts',
default=False,
action="store_true",
help="Disable all timeouts.",
required=False)
parser.add_argument('--missingOnly',
default=False,
action="store_true",
help="only request the envs/actions which are missing",
required=False)
parser.add_argument('--resultsDir',
default="/tmp/output.csv",
help="Results CSV path",
required=False)
parser.add_argument('--verbose',
default=False,
action="store_true",
help="verbose debug messages",
required=False)
args = parser.parse_args()
test_folder = args.test_folder
grader = FlatlandRemoteEvaluationService(
test_env_folder=test_folder,
flatland_rl_service_id=args.service_id,
verbose=args.verbose,
visualize=True,
video_generation_envs=["Test_0/Level_100.pkl"],
result_output_path=args.resultsDir,
action_dir=args.actionDir,
episode_dir=args.episodeDir,
merge_dir=args.mergeDir,
use_pickle=args.pickle,
shuffle=args.shuffle,
missing_only=args.missingOnly,
disable_timeouts=args.disableTimeouts
)
result = grader.run()
if result['type'] == messages.FLATLAND_RL.ENV_SUBMIT_RESPONSE:
cumulative_results = result['payload']
elif result['type'] == messages.FLATLAND_RL.ERROR:
error = result['payload']
raise Exception("Evaluation Failed : {}".format(str(error)))
else:
# Evaluation failed
print("Evaluation Failed : ", result['payload'])
# -*- coding: utf-8 -*-
"""Main module."""