Commit e65e3d71 authored by u229589's avatar u229589
Browse files

initialize grid in global observation with -1 since 0 is a valid direction and add unit test

parent 9c91a903
......@@ -508,14 +508,14 @@ class GlobalObsForRailEnv(ObservationBuilder):
- transition map array with dimensions (env.height, env.width, 16),\
assuming 16 bits encoding of transitions.
- A 3D array (map_height, map_width, 4) with
- first channel containing the agents position and direction
- second channel containing the other agents positions and direction
- third channel containing agent/other agent malfunctions
- fourth channel containing agent/other agent fractional speeds
- Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
target and the positions of the other agents targets.
- A 3D array (map_height, map_width, 4) with
- first channel containing the agents position and direction
- second channel containing the other agents positions and directions
- third channel containing agent malfunctions
- fourth channel containing agent fractional speeds
"""
def __init__(self):
......@@ -535,22 +535,21 @@ class GlobalObsForRailEnv(ObservationBuilder):
self.rail_obs[i, j] = np.array(bitlist)
def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
    """Build the global observation for the agent identified by ``handle``.

    Parameters
    ----------
    handle : int
        Index into ``self.env.agents`` of the agent the observation is for.

    Returns
    -------
    tuple
        ``(self.rail_obs, obs_agents_state, obs_targets)`` where

        - ``obs_agents_state`` has shape ``(env.height, env.width, 4)`` and is
          filled with ``-1`` wherever nothing was written. At each agent's
          position the channels hold:
          0 - direction of the agent ``handle`` (only at its own position),
          1 - direction of every other agent,
          2 - ``malfunction_data['malfunction']`` of every agent,
          3 - ``speed_data['speed']`` of every agent.
        - ``obs_targets`` has shape ``(env.height, env.width, 2)`` and is
          zero except for a ``1`` in channel 0 at the target of agent
          ``handle`` and a ``1`` in channel 1 at the other agents' targets.

    Notes
    -----
    ``position`` and ``target`` are used directly as array indices, so they
    are assumed to be ``(row, column)`` tuples inside the grid.
    """
    grid_shape = (self.env.height, self.env.width)
    obs_targets = np.zeros(grid_shape + (2,))
    # 0 is a valid direction, so empty cells are marked with -1 instead of 0.
    obs_agents_state = np.full(grid_shape + (4,), -1.0)

    agent = self.env.agents[handle]
    obs_agents_state[agent.position][0] = agent.direction
    obs_targets[agent.target][0] = 1

    for i, other_agent in enumerate(self.env.agents):
        if i != handle:
            obs_agents_state[other_agent.position][1] = other_agent.direction
            obs_targets[other_agent.target][1] = 1
        # Malfunction and speed are recorded for every agent, the observed
        # one included, so they live outside the ``i != handle`` guard.
        obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
        obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']

    return self.rail_obs, obs_agents_state, obs_targets
......
import numpy as np
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
def test_get_global_observation():
    """Check the per-channel cell counts of the global observation.

    Builds a 50x50 sparse environment with 20 agents, performs one step and,
    for every agent's observation, verifies how many cells were written in
    each channel of the target map and of the agent-state map.
    """
    np.random.seed(1)
    number_of_agents = 20

    malfunction_settings = {
        'prop_malfunction': 1.,  # Percentage of defective agents
        'malfunction_rate': 30,  # Rate of malfunction occurrence
        'min_duration': 3,  # Minimal duration of malfunction
        'max_duration': 20,  # Max duration of malfunction
    }
    speed_ratio_map = {
        1.: 0.25,  # Fast passenger train
        1. / 2.: 0.25,  # Fast freight train
        1. / 3.: 0.25,  # Slow commuter train
        1. / 4.: 0.25,  # Slow freight train
    }

    rail_generator = sparse_rail_generator(
        num_cities=25,  # Number of cities in map (where train stations are)
        num_intersections=10,  # Number of intersections (no start / target)
        num_trainstations=50,  # Number of possible start/targets on map
        min_node_dist=3,  # Minimal distance of nodes
        node_radius=4,  # Proximity of stations to city center
        num_neighb=4,  # Number of connections to other cities/intersections
        seed=15,  # Random seed
        grid_mode=True,
        enhance_intersection=False,
    )
    schedule_generator = sparse_schedule_generator(speed_ratio_map)

    env = RailEnv(
        width=50,
        height=50,
        rail_generator=rail_generator,
        schedule_generator=schedule_generator,
        number_of_agents=number_of_agents,
        stochastic_data=malfunction_settings,  # Malfunction data generator
        obs_builder_object=GlobalObsForRailEnv(),
    )

    obs, _rewards, _dones, _info = env.step({0: 0})

    for handle, _ in enumerate(env.agents):
        agent_obs = obs[handle]
        agents_state = agent_obs[1]
        targets = agent_obs[2]

        # Exactly one own target, and one target per remaining agent.
        own_targets = np.count_nonzero(targets[:, :, 0])
        other_targets = np.count_nonzero(targets[:, :, 1])
        assert own_targets == 1
        assert other_targets == (number_of_agents - 1)

        # Untouched cells hold -1; shifting by one turns them into 0 so that
        # np.count_nonzero counts only the cells that were written.
        agents_state += 1
        own_cells, other_cells, malfunction_cells, speed_cells = (
            np.count_nonzero(agents_state[:, :, channel]) for channel in range(4)
        )
        assert own_cells == 1
        assert other_cells == (number_of_agents - 1)
        assert malfunction_cells == number_of_agents
        assert speed_cells == number_of_agents
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment