Commit cc5fee4c authored by Guillaume Mollard

Merge branch 'master' of gitlab.aicrowd.com:flatland/baselines

parents de0eff03 368b144e
This repository allows running Rail Environment multi-agent training with the RLlib library.
## Installation:
To run the scripts in this repository, the deep learning library TensorFlow must be installed, along with the following packages:
```sh
-pip install ray
-pip install gin-config
+pip install gym ray==0.7.0 gin-config opencv-python lz4 psutil
```
To start a training run with different parameters, you can create a folder containing a config.gin file (see the example in `experiment_configs/config_example/config.gin`, and the illustrative sketch below).
@@ -57,7 +58,7 @@ More documentation on how to use gin-config can be found on the github repository
## Run an example:
To start a training on a 20x20 map, with different numbers of agents initialized at each episode, one can run the train_experiment.py script:
```
-python baselines/RLLib_training/train_experiment.py
+python RLLib_training/train_experiment.py
```
This will load the gin config file in the folder `experiment_configs/config_examples`.
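For illustration, a config.gin file consists of gin bindings of the form `configurable.parameter = value`. The names below are hypothetical; the actual configurable names are defined in `experiment_configs/config_example/config.gin`:

```
# Hypothetical bindings for illustration only -- see
# experiment_configs/config_example/config.gin for the real parameter names.
run_experiment.name = "my_experiment"
run_experiment.map_width = 20
run_experiment.map_height = 20
run_experiment.n_agents = 3
```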
...
@@ -8,11 +8,11 @@ import numpy as np
np.random.seed(2)
"""
-file_name = "./railway/complex_scene.pkl"
+file_name = "../torch_training/railway/complex_scene.pkl"
env = RailEnv(width=10,
              height=20,
              rail_generator=rail_from_file(file_name),
-              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+              obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()))
x_dim = env.width
y_dim = env.height
@@ -38,8 +38,8 @@ observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles()
n_trials = 1
-max_steps = 3 * (env.height + env.width)
+max_steps = 100 * (env.height + env.width)
-record_images = True
+record_images = False
agent = OrderedAgent()
action_dict = dict()
@@ -63,6 +63,7 @@ for trials in range(1, n_trials + 1):
    for a in range(env.get_num_agents()):
        if done[a]:
            acting_agent += 1
+           print(acting_agent)
        if a == acting_agent:
            action = agent.act(obs[a], eps=0)
        else:
...
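For context, the loop above implements a simple round-robin scheme: agents are released one at a time, and the `acting_agent` index only advances once the current agent is done. A minimal sketch of the idea (the `else` branch is elided in the diff, so the do-nothing action below is an assumption):

```python
acting_agent = 0
for a in range(env.get_num_agents()):
    if done[a]:                     # the previous agent reached its target
        acting_agent += 1           # hand control to the next agent
    if a == acting_agent:
        action = agent.act(obs[a], eps=0)  # greedy action for the active agent
    else:
        action = 0                  # assumption: inactive agents do nothing
    action_dict.update({a: action})
```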
@@ -18,7 +18,7 @@ class OrderedAgent:
        min_dist = min_lt(distance, 0)
        min_direction = np.where(distance == min_dist)
        if len(min_direction[0]) > 1:
-           return min_direction[0][0] + 1
+           return min_direction[0][-1] + 1
        return min_direction[0] + 1

    def step(self, memories):
...
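The changed line alters how ties are broken when several directions are equally close: `np.where` returns the indices of all minima, and the commit switches from the first match to the last. A quick illustration (using a plain minimum here in place of the script's `min_lt` helper):

```python
import numpy as np

distance = np.array([7.0, 2.0, 9.0, 2.0])  # two directions tie for the minimum
min_direction = np.where(distance == distance.min())

print(min_direction[0])          # [1 3] -> indices of all tied minima
print(min_direction[0][0] + 1)   # 2 -> old behavior: first tied direction
print(min_direction[0][-1] + 1)  # 4 -> new behavior: last tied direction
```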
No preview for this file type
No preview for this file type
@@ -16,7 +16,7 @@ from utils.observation_utils import normalize_observation
random.seed(3)
np.random.seed(2)
"""
file_name = "./railway/complex_scene.pkl"
env = RailEnv(width=10,
              height=20,
@@ -27,9 +27,9 @@ y_dim = env.height
""" """
x_dim = 10 # np.random.randint(8, 20) x_dim = 18 # np.random.randint(8, 20)
y_dim = 10 # np.random.randint(8, 20) y_dim = 14 # np.random.randint(8, 20)
n_agents = 5 # np.random.randint(3, 8) n_agents = 7 # np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3) n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim)) min_dist = int(0.75 * min(x_dim, y_dim))
@@ -41,7 +41,7 @@ env = RailEnv(width=x_dim,
              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
              number_of_agents=n_agents)
env.reset(True, True)
"""
tree_depth = 3
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", )
@@ -53,7 +53,7 @@ for i in range(tree_depth + 1):
state_size = num_features_per_node * nr_nodes
action_size = 5
-n_trials = 1
+n_trials = 10
observation_radius = 10
max_steps = int(3 * (env.height + env.width))
eps = 1.
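For reference on the sizes involved: the loop named in this hunk's header sums the nodes of a 4-ary observation tree, so with `tree_depth = 3` there are 1 + 4 + 16 + 64 = 85 nodes. A small sketch of the computation (the per-node feature count of 11 is an assumption for illustration; the script takes it from the observation builder):

```python
tree_depth = 3
nr_nodes = sum(4 ** i for i in range(tree_depth + 1))  # 1 + 4 + 16 + 64 = 85

num_features_per_node = 11  # assumed value, for illustration only
state_size = num_features_per_node * nr_nodes

print(nr_nodes, state_size)  # 85 935
```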
@@ -70,7 +70,7 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "avoid_checkpoint52800.pth") as file_in:
+with path(torch_training.Nets, "avoid_checkpoint46200.pth") as file_in:
    agent.qnetwork_local.load_state_dict(torch.load(file_in))
record_images = False
@@ -98,12 +98,11 @@ for trials in range(1, n_trials + 1):
    for a in range(env.get_num_agents()):
        action = agent.act(agent_obs[a], eps=0)
        action_dict.update({a: action})

    # Environment step
    next_obs, all_rewards, done, _ = env.step(action_dict)
    for a in range(env.get_num_agents()):
-       agent_obs[a] = agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
+       agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
    if done['__all__']:
        break
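Taken together, these hunks form a standard greedy evaluation loop: act with `eps=0` (no exploration), step the environment, then normalize the next observations for the following decision step. A condensed sketch of one trial, assuming `env`, `agent`, `max_steps`, and an initial `obs` dict (e.g. from `env.reset()`) are set up as in the script:

```python
# Normalize the initial observations once before the loop starts.
agent_obs = [normalize_observation(obs[a], observation_radius=10)
             for a in range(env.get_num_agents())]

for _ in range(max_steps):
    # Greedy action selection for every agent (eps=0 disables exploration).
    action_dict = {a: agent.act(agent_obs[a], eps=0)
                   for a in range(env.get_num_agents())}

    # Environment step
    next_obs, all_rewards, done, _ = env.step(action_dict)

    # Prepare the normalized observations for the next step.
    for a in range(env.get_num_agents()):
        agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10)

    if done['__all__']:  # stop once every agent has reached its target
        break
```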