Commit ad51b49f authored by u214892's avatar u214892
Browse files

#188 approaching...

parent 7ac889fb
Pipeline #2334 failed with stages
in 60 minutes
......@@ -574,11 +574,16 @@ class GlobalObsForRailEnv(ObservationBuilder):
obs_targets = np.zeros((self.env.height, self.env.width, 2))
obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
# TODO can we do this more elegantly?
for r in range(self.env.height):
for c in range(self.env.width):
obs_agents_state[(r, c)][4] = 0
obs_agents_state[_agent_initial_position][0] = agent.direction
obs_targets[agent.target][0] = 1
for i in range(len(self.env.agents)):
other_agent:EnvAgent = self.env.agents[i]
other_agent: EnvAgent = self.env.agents[i]
# ignore other agents not in the grid any more
if other_agent.status == RailAgentStatus.DONE_REMOVED:
......@@ -586,12 +591,16 @@ class GlobalObsForRailEnv(ObservationBuilder):
obs_targets[other_agent.target][1] = 1
# third to fifth channel only if different agent and in the grid
if i != handle and other_agent.position is not None:
obs_agents_state[other_agent.position][1] = other_agent.direction
# second to fourth channel only if in the grid
if other_agent.position is not None:
# second channel only for other agents
if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
# fifth channel: all ready to depart on this position
if other_agent.status == RailAgentStatus.READY_TO_DEPART:
obs_agents_state[other_agent.initial_position][4] += 1
return self.rail_obs, obs_agents_state, obs_targets
......
......@@ -77,3 +77,53 @@ def test_get_global_observation():
assert np.isclose(obs_agents_state[(r, c)][0], -1), \
"agent {} in status {} at {} expected contain -1 found {}" \
.format(i, agent.status, (r, c), obs_agents_state[(r, c)][0])
# test second channel of obs_agents_state: direction at other agents position
for r in range(env.height):
for c in range(env.width):
has_agent = False
for other_i, other_agent in enumerate(env.agents):
if i == other_i:
continue
if other_agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and (
r, c) == other_agent.position:
assert np.isclose(obs_agents_state[(r, c)][1], other_agent.direction), \
"agent {} in status {} at {} should see other agent with direction {}, found = {}" \
.format(i, agent.status, (r, c), other_agent.direction, obs_agents_state[(r, c)][1])
has_agent = True
if not has_agent:
assert np.isclose(obs_agents_state[(r, c)][1], -1), \
"agent {} in status {} at {} should see no other agent direction (-1), found = {}" \
.format(i, agent.status, (r, c), obs_agents_state[(r, c)][1])
# test third and fourth channel of obs_agents_state: malfunction and speed of own or other agent in the grid
for r in range(env.height):
for c in range(env.width):
has_agent = False
for other_i, other_agent in enumerate(env.agents):
if other_agent.status in [RailAgentStatus.ACTIVE,
RailAgentStatus.DONE] and other_agent.position == (r, c):
assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_data['malfunction']), \
"agent {} in status {} at {} should see agent malfunction {}, found = {}" \
.format(i, agent.status, (r, c), other_agent.malfunction_data['malfunction'],
obs_agents_state[(r, c)][2])
assert np.isclose(obs_agents_state[(r, c)][3], other_agent.speed_data['speed'])
has_agent = True
if not has_agent:
assert np.isclose(obs_agents_state[(r, c)][2], -1), \
"agent {} in status {} at {} should see no agent malfunction (-1), found = {}" \
.format(i, agent.status, (r, c), obs_agents_state[(r, c)][2])
assert np.isclose(obs_agents_state[(r, c)][3], -1), \
"agent {} in status {} at {} should see no agent speed (-1), found = {}" \
.format(i, agent.status, (r, c), obs_agents_state[(r, c)][3])
# test fifth channel of obs_agents_state: number of agents ready to depart in to this cell
for r in range(env.height):
for c in range(env.width):
count = 0
for other_i, other_agent in enumerate(env.agents):
if other_agent.status == RailAgentStatus.READY_TO_DEPART and other_agent.initial_position == (r, c):
count += 1
assert np.isclose(obs_agents_state[(r, c)][4], count), \
"agent {} in status {} at {} should see {} agents ready to depart, found{}" \
.format(i, agent.status, (r, c), count, obs_agents_state[(r, c)][4])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment