Commit cc5fee4c authored by Guillaume Mollard

Merge branch 'master' of gitlab.aicrowd.com:flatland/baselines

parents de0eff03 368b144e
This repository allows running multi-agent training on the Rail Environment with the RLlib library.
## Installation:
To run the scripts in this repository, the deep learning library TensorFlow must be installed, along with the following packages:
```sh
pip install gym ray==0.7.0 gin-config opencv-python lz4 psutil
```
To start a training with different parameters, you can create a folder containing a `config.gin` file (see the example in `experiment_configs/config_example/config.gin`).
More documentation on how to use gin-config can be found on the GitHub repository.
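As an illustration, a `config.gin` file is just a plain-text list of parameter bindings. The snippet below is a hypothetical sketch only: the actual binding names are determined by the `@gin.configurable` functions in `train_experiment.py`, so check `experiment_configs/config_example/config.gin` for the real ones.
```
# Hypothetical bindings -- the real names come from the @gin.configurable
# functions in train_experiment.py, not from this sketch.
run_experiment.map_width = 20
run_experiment.map_height = 20
run_experiment.n_agents = 3
run_experiment.hidden_sizes = [32, 32]
```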
## Run an example:
To start a training on a 20x20 map, with different numbers of agents initialized at each episode, one can run the `train_experiment.py` script:
```
python RLLib_training/train_experiment.py
```
This will load the gin config file in the folder `experiment_configs/config_examples`.
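For readers new to gin-config, the minimal sketch below shows the pattern a script like `train_experiment.py` typically relies on. It is not the repository's actual entry point; the function and parameter names are illustrative only.
```python
# Minimal gin-config sketch -- illustrative names, not the real train_experiment.py.
import gin


@gin.configurable
def run_experiment(map_width=20, map_height=20, n_agents=3):
    # Bindings from the .gin file override these keyword defaults.
    print(f"Training on a {map_width}x{map_height} map with {n_agents} agents")


if __name__ == "__main__":
    gin.parse_config_file("experiment_configs/config_example/config.gin")
    run_experiment()
```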
@@ -8,11 +8,11 @@ import numpy as np
np.random.seed(2)
"""
-file_name = "./railway/complex_scene.pkl"
+file_name = "../torch_training/railway/complex_scene.pkl"
env = RailEnv(width=10,
              height=20,
              rail_generator=rail_from_file(file_name),
-             obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+             obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()))
x_dim = env.width
y_dim = env.height
@@ -38,8 +38,8 @@ observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", )
handle = env.get_agent_handles()
n_trials = 1
-max_steps = 3 * (env.height + env.width)
-record_images = True
+max_steps = 100 * (env.height + env.width)
+record_images = False
agent = OrderedAgent()
action_dict = dict()
@@ -63,6 +63,7 @@ for trials in range(1, n_trials + 1):
    for a in range(env.get_num_agents()):
        if done[a]:
            acting_agent += 1
+           print(acting_agent)
        if a == acting_agent:
            action = agent.act(obs[a], eps=0)
        else:
@@ -18,7 +18,7 @@ class OrderedAgent:
        min_dist = min_lt(distance, 0)
        min_direction = np.where(distance == min_dist)
        if len(min_direction[0]) > 1:
-            return min_direction[0][0] + 1
+            return min_direction[0][-1] + 1
        return min_direction[0] + 1

    def step(self, memories):
@@ -16,7 +16,7 @@ from utils.observation_utils import normalize_observation
random.seed(3)
np.random.seed(2)
"""
file_name = "./railway/complex_scene.pkl"
env = RailEnv(width=10,
              height=20,
@@ -27,9 +27,9 @@ y_dim = env.height
"""
-x_dim = 10  # np.random.randint(8, 20)
-y_dim = 10  # np.random.randint(8, 20)
-n_agents = 5  # np.random.randint(3, 8)
+x_dim = 18  # np.random.randint(8, 20)
+y_dim = 14  # np.random.randint(8, 20)
+n_agents = 7  # np.random.randint(3, 8)
n_goals = n_agents + np.random.randint(0, 3)
min_dist = int(0.75 * min(x_dim, y_dim))
@@ -41,7 +41,7 @@ env = RailEnv(width=x_dim,
              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
              number_of_agents=n_agents)
env.reset(True, True)
"""
tree_depth = 3
observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", )
@@ -53,7 +53,7 @@ for i in range(tree_depth + 1):
state_size = num_features_per_node * nr_nodes
action_size = 5
-n_trials = 1
+n_trials = 10
observation_radius = 10
max_steps = int(3 * (env.height + env.width))
eps = 1.
@@ -70,7 +70,7 @@ action_prob = [0] * action_size
agent_obs = [None] * env.get_num_agents()
agent_next_obs = [None] * env.get_num_agents()
agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "avoid_checkpoint52800.pth") as file_in:
+with path(torch_training.Nets, "avoid_checkpoint46200.pth") as file_in:
    agent.qnetwork_local.load_state_dict(torch.load(file_in))
record_images = False
@@ -98,12 +98,11 @@ for trials in range(1, n_trials + 1):
    for a in range(env.get_num_agents()):
        action = agent.act(agent_obs[a], eps=0)
        action_dict.update({a: action})
    # Environment step
    next_obs, all_rewards, done, _ = env.step(action_dict)
    for a in range(env.get_num_agents()):
-        agent_obs[a] = agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
+        agent_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
    if done['__all__']:
        break