diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index de9ee2a45ebdeabec5202be4f12593a82b4e20e4..6f771eb4d5ba73d0f83949907215ca5280659c28 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -508,14 +508,14 @@ class GlobalObsForRailEnv(ObservationBuilder): - transition map array with dimensions (env.height, env.width, 16),\ assuming 16 bits encoding of transitions. + - A 3D array (map_height, map_width, 4) with + - first channel containing the agents position and direction + - second channel containing the other agents positions and diretion + - third channel containing agent/other agent malfunctions + - fourth channel containing agent/other agent fractional speeds + - Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ target and the positions of the other agents targets. - - - A 3D array (map_height, map_width, 4) wtih - - first channel containing the agents position and direction - - second channel containing the other agents positions and diretions - - third channel containing agent malfunctions - - fourth channel containing agent fractional speeds """ def __init__(self): @@ -535,22 +535,21 @@ class GlobalObsForRailEnv(ObservationBuilder): self.rail_obs[i, j] = np.array(bitlist) def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): + obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - agents = self.env.agents - agent = agents[handle] + obs_agents_state = np.zeros((self.env.height, self.env.width, 4)) - 1 - agent_pos = agents[handle].position - obs_agents_state[agent_pos][0] = agents[handle].direction + agent = self.env.agents[handle] + obs_agents_state[agent.position][0] = agent.direction obs_targets[agent.target][0] = 1 - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - obs_agents_state[agent2.position][1] = agent2.direction - obs_targets[agent2.target][1] = 1 - obs_agents_state[agents[i].position][2] = agents[i].malfunction_data['malfunction'] - obs_agents_state[agents[i].position][3] = agents[i].speed_data['speed'] + for i in range(len(self.env.agents)): + other_agent = self.env.agents[i] + if i != handle: + obs_agents_state[other_agent.position][1] = other_agent.direction + obs_targets[other_agent.target][1] = 1 + obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] + obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] return self.rail_obs, obs_agents_state, obs_targets diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index d2663916a17a70597d10e489da7aead4f8932dc4..0d6d309765690b1f95c681d7d109a13071d7f86b 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -41,7 +41,9 @@ def test_global_obs(): # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell - assert (np.sum(rail_map * global_obs[0][1][:, :, :4].sum(2)) > 0) + obs_agents_state = global_obs[0][1] + obs_agents_state = obs_agents_state + 1 + assert (np.sum(rail_map * obs_agents_state[:, :, :4].sum(2)) > 0) def _step_along_shortest_path(env, obs_builder, rail): diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py new file mode 100644 index 0000000000000000000000000000000000000000..7213560f9e9873ea4488b96d30223bab8128b37b --- /dev/null +++ b/tests/test_global_observation.py @@ -0,0 +1,64 @@ +import numpy as np + +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import sparse_rail_generator +from flatland.envs.schedule_generators import sparse_schedule_generator + + +def test_get_global_observation(): + np.random.seed(1) + number_of_agents = 20 + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 30, # Rate of malfunction occurence + 'min_duration': 3, # Minimal duration of malfunction + 'max_duration': 20 # Max duration of malfunction + } + + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. / 4.: 0.25} # Slow freight train + + env = RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator(num_cities=25, + # Number of cities in map (where train stations are) + num_intersections=10, + # Number of intersections (no start / target) + num_trainstations=50, # Number of possible start/targets on map + min_node_dist=3, # Minimal distance of nodes + node_radius=4, # Proximity of stations to city center + num_neighb=4, + # Number of connections to other cities/intersections + seed=15, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator + obs_builder_object=GlobalObsForRailEnv()) + + obs, all_rewards, done, _ = env.step({0: 0}) + + for i in range(len(env.agents)): + obs_agents_state = obs[i][1] + obs_targets = obs[i][2] + + nr_agents = np.count_nonzero(obs_targets[:, :, 0]) + nr_agents_other = np.count_nonzero(obs_targets[:, :, 1]) + assert nr_agents == 1 + assert nr_agents_other == (number_of_agents - 1) + + # since the array is initialized with -1 add one in order to used np.count_nonzero + obs_agents_state += 1 + obs_agents_state_0 = np.count_nonzero(obs_agents_state[:, :, 0]) + obs_agents_state_1 = np.count_nonzero(obs_agents_state[:, :, 1]) + obs_agents_state_2 = np.count_nonzero(obs_agents_state[:, :, 2]) + obs_agents_state_3 = np.count_nonzero(obs_agents_state[:, :, 3]) + assert obs_agents_state_0 == 1 + assert obs_agents_state_1 == (number_of_agents - 1) + assert obs_agents_state_2 == number_of_agents + assert obs_agents_state_3 == number_of_agents +