Commit aa40c923 authored by u214892's avatar u214892
Browse files

#141 different agent classes

parent e5606f1e
Pipeline #1831 failed with stages
in 9 minutes and 6 seconds
......@@ -3,9 +3,9 @@ import random
import numpy as np
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
def run_benchmark():
......
......@@ -5,11 +5,11 @@ import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import random_rail_generator, complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator, complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.utils.rendertools import RenderTool
random.seed(100)
......
......@@ -5,9 +5,9 @@ import numpy as np
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_generators import AgentGenerator, AgentGeneratorProduct
from flatland.envs.generators import RailGenerator, RailGeneratorProduct
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
from flatland.envs.schedule_generators import ScheduleGenerator, ScheduleGeneratorProduct
from flatland.utils.rendertools import RenderTool
random.seed(100)
......@@ -29,8 +29,8 @@ def custom_rail_generator() -> RailGenerator:
return generator
def custom_agent_generator() -> AgentGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct:
def custom_agent_generator() -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
agents_positions = []
agents_direction = []
agents_target = []
......
......@@ -3,10 +3,10 @@ import time
import numpy as np
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.utils.rendertools import RenderTool
random.seed(1)
......
from flatland.envs.generators import rail_from_manual_specifications_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_manual_specifications_generator
from flatland.utils.rendertools import RenderTool
# Example generate a rail given a manual specification,
......
......@@ -2,8 +2,8 @@ import random
import numpy as np
from flatland.envs.generators import random_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import random_rail_generator
from flatland.utils.rendertools import RenderTool
random.seed(100)
......
......@@ -2,10 +2,10 @@ import random
import numpy as np
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.utils.rendertools import RenderTool
random.seed(1)
......
import numpy as np
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.utils.rendertools import RenderTool
np.random.seed(1)
......
......@@ -8,9 +8,9 @@ import click
import numpy as np
import redis
from flatland.envs.agent_generators import complex_rail_generator_agents_placer
from flatland.envs.generators import complex_rail_generator
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.schedule_generators import complex_rail_generator_agents_placer
from flatland.evaluators.service import FlatlandRemoteEvaluationService
from flatland.utils.rendertools import RenderTool
......
......@@ -11,10 +11,10 @@ import numpy as np
from flatland.core.env import Environment
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_generators import get_rnd_agents_pos_tgt_dir_on_rail, AgentGenerator
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.generators import random_rail_generator, RailGenerator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import get_rnd_agents_pos_tgt_dir_on_rail, ScheduleGenerator
m.patch()
......@@ -93,7 +93,7 @@ class RailEnv(Environment):
width,
height,
rail_generator: RailGenerator = random_rail_generator(),
agent_generator: AgentGenerator = get_rnd_agents_pos_tgt_dir_on_rail(),
agent_generator: ScheduleGenerator = get_rnd_agents_pos_tgt_dir_on_rail(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2),
max_episode_steps=None,
......@@ -110,11 +110,11 @@ class RailEnv(Environment):
the env has been reset, and returns a GridTransitionMap object and a list of
starting positions, targets, and initial orientations for agent handle.
The rail_generator can pass a distance map in the hints or information for specific agent_generators.
Implementations can be found in flatland/envs/generators.py
Implementations can be found in flatland/envs/rail_generators.py
agent_generator : function
The agent_generator function is a function that takes the grid, the number of agents and optional hints
and returns a list of starting positions, targets, initial orientations and speed for all agent handles.
Implementations can be found in flatland/envs/agent_generators.py
Implementations can be found in flatland/envs/schedule_generators.py
width : int
The width of the rail map. Potentially in the future,
a range of widths to sample from.
......@@ -133,7 +133,7 @@ class RailEnv(Environment):
"""
self.rail_generator: RailGenerator = rail_generator
self.agent_generator: AgentGenerator = agent_generator
self.agent_generator: ScheduleGenerator = agent_generator
self.rail = None
self.width = width
self.height = height
......
"""Agent generators (railway undertaking, "EVU")."""
"""Schedule generators (railway undertaking, "EVU")."""
from typing import Tuple, List, Callable, Mapping, Optional, Any
import msgpack
......@@ -9,8 +9,8 @@ from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic
AgentPosition = Tuple[int, int]
AgentGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]]
AgentGenerator = Callable[[GridTransitionMap, int, Optional[Any]], AgentGeneratorProduct]
ScheduleGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]]
ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGeneratorProduct]
def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None) -> List[float]:
......@@ -37,7 +37,7 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float,
return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> AgentGenerator:
def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None):
start_goal = hints['start_goal']
start_dir = hints['start_dir']
......@@ -55,7 +55,7 @@ def complex_rail_generator_agents_placer(speed_ratio_map: Mapping[float, float]
return generator
def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> AgentGenerator:
def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] = None) -> ScheduleGenerator:
"""
Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).
......@@ -73,7 +73,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] =
initial positions, directions, targets speeds
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
def _path_exists(rail, start, direction, end):
# BFS - Check if a path exists between the 2 nodes
......@@ -151,7 +151,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(speed_ratio_map: Mapping[float, float] =
return generator
def agents_from_file(filename) -> AgentGenerator:
def agents_from_file(filename) -> ScheduleGenerator:
"""
Utility to load pickle file
......@@ -165,7 +165,7 @@ def agents_from_file(filename) -> AgentGenerator:
initial positions, directions, targets speeds
"""
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> AgentGeneratorProduct:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False)
......
import redis
import hashlib
import json
import logging
import os
import numpy as np
import random
import time
import msgpack
import msgpack_numpy as m
import hashlib
import random
from flatland.evaluators import messages
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import rail_from_file
import numpy as np
import redis
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
import time
import logging
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.evaluators import messages
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
m.patch()
......@@ -22,8 +25,8 @@ def are_dicts_equal(d1, d2):
""" return True if all keys and values are the same """
return all(k in d2 and d1[k] == d2[k]
for k in d1) \
and all(k in d1 and d1[k] == d2[k]
for k in d2)
and all(k in d1 and d1[k] == d2[k]
for k in d2)
class FlatlandRemoteClient(object):
......@@ -41,39 +44,40 @@ class FlatlandRemoteClient(object):
where `service_id` is either provided as an `env` variable or is
instantiated to "flatland_rl_redis_service_id"
"""
def __init__(self,
remote_host='127.0.0.1',
remote_port=6379,
remote_db=0,
remote_password=None,
test_envs_root=None,
verbose=False):
def __init__(self,
remote_host='127.0.0.1',
remote_port=6379,
remote_db=0,
remote_password=None,
test_envs_root=None,
verbose=False):
self.remote_host = remote_host
self.remote_port = remote_port
self.remote_db = remote_db
self.remote_password = remote_password
self.redis_pool = redis.ConnectionPool(
host=remote_host,
port=remote_port,
db=remote_db,
password=remote_password)
host=remote_host,
port=remote_port,
db=remote_db,
password=remote_password)
self.namespace = "flatland-rl"
self.service_id = os.getenv(
'FLATLAND_RL_SERVICE_ID',
'FLATLAND_RL_SERVICE_ID'
)
'FLATLAND_RL_SERVICE_ID',
'FLATLAND_RL_SERVICE_ID'
)
self.command_channel = "{}::{}::commands".format(
self.namespace,
self.service_id
)
self.namespace,
self.service_id
)
if test_envs_root:
self.test_envs_root = test_envs_root
else:
self.test_envs_root = os.getenv(
'AICROWD_TESTS_FOLDER',
'/tmp/flatland_envs'
)
'AICROWD_TESTS_FOLDER',
'/tmp/flatland_envs'
)
self.verbose = verbose
......@@ -85,12 +89,12 @@ class FlatlandRemoteClient(object):
def _generate_response_channel(self):
random_hash = hashlib.md5(
"{}".format(
random.randint(0, 10**10)
).encode('utf-8')).hexdigest()
"{}".format(
random.randint(0, 10 ** 10)
).encode('utf-8')).hexdigest()
response_channel = "{}::{}::response::{}".format(self.namespace,
self.service_id,
random_hash)
self.service_id,
random_hash)
return response_channel
def _blocking_request(self, _request):
......@@ -124,9 +128,9 @@ class FlatlandRemoteClient(object):
if self.verbose:
print("Response : ", _response)
_response = msgpack.unpackb(
_response,
object_hook=m.decode,
encoding="utf8")
_response,
object_hook=m.decode,
encoding="utf8")
if _response['type'] == messages.FLATLAND_RL.ERROR:
raise Exception(str(_response["payload"]))
else:
......@@ -181,7 +185,7 @@ class FlatlandRemoteClient(object):
"Did you remember to set the AICROWD_TESTS_FOLDER environment variable "
"to point to the location of the Tests folder ? \n"
"We are currently looking at `{}` for the tests".format(self.test_envs_root)
)
)
print("Current env path : ", test_env_file_path)
self.env = RailEnv(
width=1,
......@@ -207,7 +211,7 @@ class FlatlandRemoteClient(object):
_request['payload']['action'] = action
_response = self._blocking_request(_request)
_payload = _response['payload']
# remote_observation = _payload['observation']
remote_reward = _payload['reward']
remote_done = _payload['done']
......@@ -216,14 +220,14 @@ class FlatlandRemoteClient(object):
# Replicate the action in the local env
local_observation, local_reward, local_done, local_info = \
self.env.step(action)
print(local_reward)
if not are_dicts_equal(remote_reward, local_reward):
raise Exception("local and remote `reward` are diverging")
print(remote_reward, local_reward)
if not are_dicts_equal(remote_done, local_done):
raise Exception("local and remote `done` are diverging")
# Return local_observation instead of remote_observation
# as the remote_observation is build using a dummy observation
# builder
......@@ -250,21 +254,23 @@ class FlatlandRemoteClient(object):
if __name__ == "__main__":
remote_client = FlatlandRemoteClient()
def my_controller(obs, _env):
_action = {}
for _idx, _ in enumerate(_env.agents):
_action[_idx] = np.random.randint(0, 5)
return _action
my_observation_builder = TreeObsForRailEnv(max_depth=3,
predictor=ShortestPathPredictorForRailEnv())
predictor=ShortestPathPredictorForRailEnv())
episode = 0
obs = True
while obs:
while obs:
obs = remote_client.env_create(
obs_builder_object=my_observation_builder
)
obs_builder_object=my_observation_builder
)
if not obs:
"""
The remote env returns False as the first obs
......@@ -285,7 +291,5 @@ if __name__ == "__main__":
print("Reward : ", sum(list(all_rewards.values())))
break
print("Evaluation Complete...")
print("Evaluation Complete...")
print(remote_client.submit())
#!/usr/bin/env python
from __future__ import print_function
import redis
from flatland.envs.generators import rail_from_file
from flatland.envs.rail_env import RailEnv
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.evaluators import messages
from flatland.evaluators import aicrowd_helpers
from flatland.utils.rendertools import RenderTool
import numpy as np
import msgpack
import msgpack_numpy as m
import os
import glob
import os
import random
import shutil
import time
import traceback
import crowdai_api
import msgpack
import msgpack_numpy as m
import numpy as np
import redis
import timeout_decorator
import random
from flatland.core.env_observation_builder import DummyObservationBuilder
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_file
from flatland.evaluators import aicrowd_helpers
from flatland.evaluators import messages
from flatland.utils.rendertools import RenderTool
use_signals_in_timeout = True
if os.name == 'nt':
......@@ -35,7 +37,7 @@ m.patch()
########################################################
# CONSTANTS
########################################################
PER_STEP_TIMEOUT = 10*60 # 5 minutes
PER_STEP_TIMEOUT = 10 * 60 # 5 minutes
class FlatlandRemoteEvaluationService:
......@@ -59,17 +61,18 @@ class FlatlandRemoteEvaluationService:
unpacked with `msgpack` (a patched version of msgpack which also supports
numpy arrays).
"""
def __init__(self,
test_env_folder="/tmp",
flatland_rl_service_id='FLATLAND_RL_SERVICE_ID',
remote_host='127.0.0.1',
remote_port=6379,
remote_db=0,
remote_password=None,
visualize=False,
video_generation_envs=[],
report=None,
verbose=False):
test_env_folder="/tmp",
flatland_rl_service_id='FLATLAND_RL_SERVICE_ID',
remote_host='127.0.0.1',
remote_port=6379,
remote_db=0,
remote_password=None,
visualize=False,
video_generation_envs=[],
report=None,
verbose=False):
# Test Env folder Paths
self.test_env_folder = test_env_folder
......@@ -83,15 +86,15 @@ class FlatlandRemoteEvaluationService:
# Logging and Reporting related vars
self.verbose = verbose
self.report = report
# Communication Protocol Related vars
self.namespace = "flatland-rl"
self.service_id = flatland_rl_service_id
self.command_channel = "{}::{}::commands".format(
self.namespace,
self.service_id
)
self.namespace,
self.service_id
)
# Message Broker related vars
self.remote_host = remote_host
self.remote_port = remote_port
......@@ -114,7 +117,7 @@ class FlatlandRemoteEvaluationService:
"normalized_reward": 0.0
}
}
# RailEnv specific variables
self.env = False
self.env_renderer = False
......@@ -156,7 +159,7 @@ class FlatlandRemoteEvaluationService:
   ├── .......
   ├── .......
└── Level_99.pkl
"""
"""
env_paths = sorted(glob.glob(
os.path.join(
self.test_env_folder,
......@@ -179,16 +182,16 @@ class FlatlandRemoteEvaluationService:
"""
if self.verbose or self.report:
print("Attempting to connect to redis server at {}:{}/{}".format(
self.remote_host,
self.remote_port,
self.remote_db))
self.remote_host,
self.remote_port,
self.remote_db))
self.redis_pool = redis.ConnectionPool(
host=self.remote_host,
port=self.remote_port,
db=self.remote_db,
password=self.remote_password
)
host=self.remote_host,
port=self.remote_port,
db=self.remote_db,
password=self.remote_password
)
def get_redis_connection(self):
"""
......@@ -200,13 +203,13 @@ class FlatlandRemoteEvaluationService:
redis_conn.ping()
except Exception as e:
raise Exception(
"Unable to connect to redis server at {}:{} ."
"Are you sure there is a redis-server running at the "
"specified location ?".format(
self.remote_host,
self.remote_port
)
)
"Unable to connect to redis server at {}:{} ."
"Are you sure there is a redis-server running at the "
"specified location ?".format(
self.remote_host,
self.remote_port
)
)
return redis_conn
def _error_template(self, payload):
......@@ -220,8 +223,8 @@ class FlatlandRemoteEvaluationService:
return _response
@timeout_decorator.timeout(
PER_STEP_TIMEOUT,
use_signals=use_signals_in_timeout) # timeout for each command
PER_STEP_TIMEOUT,
use_signals=use_signals_in_timeout) # timeout for each command
def _get_next_command(self, _redis):
"""
A low level wrapper for obtaining the next command from a
......@@ -231,7 +234,7 @@ class FlatlandRemoteEvaluationService:
"""
command = _redis.brpop(self.command_channel)[1]
return command
def get_next_command(self):
"""
A helper function to obtain the next command, which transparently
......@@ -246,18 +249,18 @@ class FlatlandRemoteEvaluationService:
print("Command Service: ", command)
except timeout_decorator.timeout_decorator.TimeoutError:
raise Exception(
"Timeout in step {} of simulation {}".format(
self.current_step,
self.simulation_count
))