From e65e3d7111f4b05525ca426f460aad5abe2276d6 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Tue, 1 Oct 2019 10:59:06 +0200 Subject: [PATCH] initialize grid in global observation with -1 since 0 is a valid direction and add unit test --- flatland/envs/observations.py | 35 +++++++++-------- tests/test_global_observation.py | 64 ++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 18 deletions(-) create mode 100644 tests/test_global_observation.py diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index de9ee2a4..6f771eb4 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -508,14 +508,14 @@ class GlobalObsForRailEnv(ObservationBuilder): - transition map array with dimensions (env.height, env.width, 16),\ assuming 16 bits encoding of transitions. + - A 3D array (map_height, map_width, 4) with + - first channel containing the agents position and direction + - second channel containing the other agents positions and direction + - third channel containing agent/other agent malfunctions + - fourth channel containing agent/other agent fractional speeds + - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ target and the positions of the other agents targets. 
- - - A 3D array (map_height, map_width, 4) wtih - - first channel containing the agents position and direction - - second channel containing the other agents positions and diretions - - third channel containing agent malfunctions - - fourth channel containing agent fractional speeds """ def __init__(self): @@ -535,22 +535,21 @@ class GlobalObsForRailEnv(ObservationBuilder): self.rail_obs[i, j] = np.array(bitlist) def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): + obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - agents = self.env.agents - agent = agents[handle] + obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - 1 - agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = agents[handle].direction + agent = self.env.agents[handle] + obs_agents_state[agent.position][0] = agent.direction obs_targets[agent.target][0] = 1 - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? 
- agent2 = agents[i] - obs_agents_state[agent2.position][1] = agent2.direction - obs_targets[agent2.target][1] = 1 - obs_agents_state[agents[i].position][2] = agents[i].malfunction_data['malfunction'] - obs_agents_state[agents[i].position][3] = agents[i].speed_data['speed'] + for i in range(len(self.env.agents)): + other_agent = self.env.agents[i] + if i != handle: + obs_agents_state[other_agent.position][1] = other_agent.direction + obs_targets[other_agent.target][1] = 1 + obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] + obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] return self.rail_obs, obs_agents_state, obs_targets diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py new file mode 100644 index 00000000..7213560f --- /dev/null +++ b/tests/test_global_observation.py @@ -0,0 +1,64 @@ +import numpy as np + +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator + + +def test_get_global_observation(): + np.random.seed(1) + number_of_agents = 20 + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurrence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 20 # Max duration of malfunction + } + + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. 
/ 4.: 0.25} # Slow freight train + + env = RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator(num_cities=25, + # Number of cities in map (where train stations are) + num_intersections=10, + # Number of intersections (no start / target) + num_trainstations=50, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=4, # Proximity of stations to city center + num_neighb=4, + # Number of connections to other cities/intersections + seed=15, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=GlobalObsForRailEnv()) + + obs, all_rewards, done, _ = env.step({0: 0}) + + for i in range(len(env.agents)): + obs_agents_state = obs[i][1] + obs_targets = obs[i][2] + + nr_agents = np.count_nonzero(obs_targets[:, :, 0]) + nr_agents_other = np.count_nonzero(obs_targets[:, :, 1]) + assert nr_agents == 1 + assert nr_agents_other == (number_of_agents - 1) + + # since the array is initialized with -1 add one in order to use np.count_nonzero + obs_agents_state += 1 + obs_agents_state_0 = np.count_nonzero(obs_agents_state[:, :, 0]) + obs_agents_state_1 = np.count_nonzero(obs_agents_state[:, :, 1]) + obs_agents_state_2 = np.count_nonzero(obs_agents_state[:, :, 2]) + obs_agents_state_3 = np.count_nonzero(obs_agents_state[:, :, 3]) + assert obs_agents_state_0 == 1 + assert obs_agents_state_1 == (number_of_agents - 1) + assert obs_agents_state_2 == number_of_agents + assert obs_agents_state_3 == number_of_agents + -- GitLab