Skip to content
Snippets Groups Projects
Commit 74ea79a4 authored by u214892's avatar u214892
Browse files

renamed 'entering' to 'action_required' and removed two 'if False' conditions...

renamed 'entering' to 'action_required' and removed two 'if False' conditions in sparse_rail_generator
parent 973ced69
No related branches found
No related tags found
No related merge requests found
...@@ -189,18 +189,18 @@ This action is then executed when a step to the next cell is valid. For example ...@@ -189,18 +189,18 @@ This action is then executed when a step to the next cell is valid. For example
- Agents can make observations at any time step. Make sure to discard observations without any information. See this [example](https://gitlab.aicrowd.com/flatland/baselines/blob/master/torch_training/training_navigation.py) for a simple implementation. - Agents can make observations at any time step. Make sure to discard observations without any information. See this [example](https://gitlab.aicrowd.com/flatland/baselines/blob/master/torch_training/training_navigation.py) for a simple implementation.
- The environment checks if agent is allowed to move to next cell only at the time of the switch to the next cell - The environment checks if agent is allowed to move to next cell only at the time of the switch to the next cell
In your controller, you can check whether an agent is entering by checking `info`: In your controller, you can check whether an agent requires an action by checking `info`:
``` ```
obs, rew, done, info = env.step(actions) obs, rew, done, info = env.step(actions)
... ...
action_dict = dict() action_dict = dict()
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
if info['entering'][a] && info['malfunction'][a] == 0 &&: if info['action_required'][a] and info['malfunction'][a] == 0:
action_dict.update({a: ...}) action_dict.update({a: ...})
``` ```
Notice that `info['entering'][a]` does not mean that the action will have an effect: Notice that `info['action_required'][a]` does not mean that the action will have an effect:
if the next cell is blocked or the agent is malfunctioning, the action cannot be performed. if the next cell is blocked or the agent breaks down, the action cannot be performed and an action will be required again in the next step.
## Rail Generators and Schedule Generators ## Rail Generators and Schedule Generators
The separation between rail generator and schedule generator reflects the organisational separation in the railway domain The separation between rail generator and schedule generator reflects the organisational separation in the railway domain
......
...@@ -314,7 +314,7 @@ class RailEnv(Environment): ...@@ -314,7 +314,7 @@ class RailEnv(Environment):
if self.dones["__all__"]: if self.dones["__all__"]:
self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()} self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
info_dict = { info_dict = {
'entering': {i: False for i in range(self.get_num_agents())}, 'action_required': {i: False for i in range(self.get_num_agents())},
'malfunction': {i: 0 for i in range(self.get_num_agents())}, 'malfunction': {i: 0 for i in range(self.get_num_agents())},
'speed': {i: 0 for i in range(self.get_num_agents())} 'speed': {i: 0 for i in range(self.get_num_agents())}
} }
...@@ -457,7 +457,7 @@ class RailEnv(Environment): ...@@ -457,7 +457,7 @@ class RailEnv(Environment):
for k in self.dones.keys(): for k in self.dones.keys():
self.dones[k] = True self.dones[k] = True
entering_agents = { action_required_agents = {
i: self.agents[i].speed_data['position_fraction'] <= epsilon for i in range(self.get_num_agents()) i: self.agents[i].speed_data['position_fraction'] <= epsilon for i in range(self.get_num_agents())
} }
malfunction_agents = { malfunction_agents = {
...@@ -466,7 +466,7 @@ class RailEnv(Environment): ...@@ -466,7 +466,7 @@ class RailEnv(Environment):
speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())} speed_agents = {i: self.agents[i].speed_data['speed'] for i in range(self.get_num_agents())}
info_dict = { info_dict = {
'entering': entering_agents, 'action_required': action_required_agents,
'malfunction': malfunction_agents, 'malfunction': malfunction_agents,
'speed': speed_agents 'speed': speed_agents
} }
......
...@@ -629,13 +629,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -629,13 +629,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
available_nodes_full = np.delete(available_nodes_full, delete_idx, 0) available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
# Priority city to intersection connections # Priority city to intersection connections
if False and current_node < num_cities and len(available_intersections) > 0: if current_node < num_cities and len(available_intersections) > 0:
available_nodes = available_intersections available_nodes = available_intersections
delete_idx = np.where(available_cities == current_node) delete_idx = np.where(available_cities == current_node)
available_cities = np.delete(available_cities, delete_idx, 0) available_cities = np.delete(available_cities, delete_idx, 0)
# Priority intersection to city connections # Priority intersection to city connections
elif False and current_node >= num_cities and len(available_cities) > 0: elif current_node >= num_cities and len(available_cities) > 0:
available_nodes = available_cities available_nodes = available_cities
delete_idx = np.where(available_intersections == current_node) delete_idx = np.where(available_intersections == current_node)
available_intersections = np.delete(available_intersections, delete_idx, 0) available_intersections = np.delete(available_intersections, delete_idx, 0)
......
...@@ -29,7 +29,7 @@ def test_sparse_rail_generator(): ...@@ -29,7 +29,7 @@ def test_sparse_rail_generator():
# TODO test assertions! # TODO test assertions!
def test_rail_env_entering_info(): def test_rail_env_action_required_info():
np.random.seed(0) np.random.seed(0)
speed_ration_map = {1.: 0.25, # Fast passenger train speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Fast freight train 1. / 2.: 0.25, # Fast freight train
...@@ -54,7 +54,7 @@ def test_rail_env_entering_info(): ...@@ -54,7 +54,7 @@ def test_rail_env_entering_info():
number_of_agents=10, number_of_agents=10,
obs_builder_object=GlobalObsForRailEnv()) obs_builder_object=GlobalObsForRailEnv())
np.random.seed(0) np.random.seed(0)
env_only_if_entering = RailEnv(width=50, env_only_if_action_required = RailEnv(width=50,
height=50, height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
num_intersections=10, num_intersections=10,
...@@ -79,28 +79,28 @@ def test_rail_env_entering_info(): ...@@ -79,28 +79,28 @@ def test_rail_env_entering_info():
print("step {}".format(step)) print("step {}".format(step))
action_dict_always_action = dict() action_dict_always_action = dict()
action_dict_only_if_entering = dict() action_dict_only_if_action_required = dict()
# Chose an action for each agent in the environment # Chose an action for each agent in the environment
for a in range(env_always_action.get_num_agents()): for a in range(env_always_action.get_num_agents()):
action = np.random.choice(np.arange(4)) action = np.random.choice(np.arange(4))
action_dict_always_action.update({a: action}) action_dict_always_action.update({a: action})
if step == 0 or info_only_if_entering['entering'][a]: if step == 0 or info_only_if_action_required['action_required'][a]:
action_dict_only_if_entering.update({a: action}) action_dict_only_if_action_required.update({a: action})
else: else:
print("[{}] not entering {}, speed_data={}".format(step, a, env_always_action.agents[a].speed_data)) print("[{}] not action_required {}, speed_data={}".format(step, a, env_always_action.agents[a].speed_data))
obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step( obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
action_dict_always_action) action_dict_always_action)
obs_only_if_entering, rewards_only_if_entering, done_only_if_entering, info_only_if_entering = env_only_if_entering.step( obs_only_if_action_required, rewards_only_if_action_required, done_only_if_action_required, info_only_if_action_required = env_only_if_action_required.step(
action_dict_only_if_entering) action_dict_only_if_action_required)
for a in range(env_always_action.get_num_agents()): for a in range(env_always_action.get_num_agents()):
assert len(obs_always_action[a]) == len(obs_only_if_entering[a]) assert len(obs_always_action[a]) == len(obs_only_if_action_required[a])
for i in range(len(obs_always_action[a])): for i in range(len(obs_always_action[a])):
assert np.array_equal(obs_always_action[a][i], obs_only_if_entering[a][i]) assert np.array_equal(obs_always_action[a][i], obs_only_if_action_required[a][i])
assert np.array_equal(rewards_always_action[a], rewards_only_if_entering[a]) assert np.array_equal(rewards_always_action[a], rewards_only_if_action_required[a])
assert np.array_equal(done_always_action[a], done_only_if_entering[a]) assert np.array_equal(done_always_action[a], done_only_if_action_required[a])
assert info_always_action['entering'][a] == info_only_if_entering['entering'][a] assert info_always_action['action_required'][a] == info_only_if_action_required['action_required'][a]
env_renderer.render_env(show=True, show_observations=False, show_predictions=False) env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment