Commit ab22a0c6 authored by u214892's avatar u214892
Browse files

observations only for active agents

parent bc0ee393
Pipeline #2324 passed with stages
in 43 minutes and 4 seconds
......@@ -91,3 +91,17 @@ class Environment:
function.
"""
raise NotImplementedError()
def is_active_handle(self, h: int) -> bool:
    """
    Is the agent active and thus observable?

    This base implementation reports every handle as active; an
    environment that tracks per-agent status can override it.

    Parameters
    ----------
    h : int
        agent handle

    Returns
    -------
    bool
        Always True in this implementation.
    """
    return True
......@@ -24,7 +24,7 @@ class ObservationBuilder:
self.env = None
def set_env(self, env: Environment):
self.env = env
self.env: Environment = env
def reset(self):
"""
......@@ -52,7 +52,8 @@ class ObservationBuilder:
if handles is None:
handles = []
for h in handles:
observations[h] = self.get(h)
if self.env.is_active_handle(h):
observations[h] = self.get(h)
return observations
def get(self, handle: int = 0):
......
......@@ -512,6 +512,9 @@ class RailEnv(Environment):
# step penalty if not moving (stopped now or before)
self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed']
def is_active_handle(self, h):
    """
    Is the agent with handle `h` active?

    Looks up the agent at index `h` in `self.agents` and compares its
    `status` against `RailAgentStatus.ACTIVE`.
    """
    agent = self.agents[h]
    return agent.status == RailAgentStatus.ACTIVE
def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent):
"""
......
......@@ -29,6 +29,9 @@ def test_global_obs():
global_obs = env.reset()
# we have to take a step for the agent to enter the grid.
global_obs, _, _, _ = env.step({0: RailEnvActions.MOVE_FORWARD})
assert (global_obs[0][0].shape == rail_map.shape + (16,))
rail_map_recons = np.zeros_like(rail_map)
......
......@@ -3,6 +3,7 @@ import random
import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
......@@ -1549,6 +1550,11 @@ def test_rail_env_action_required_info():
obs_builder_object=GlobalObsForRailEnv())
env_renderer = RenderTool(env_always_action, gl="PILSVG", )
for agent in env_always_action.agents:
agent.status=RailAgentStatus.ACTIVE
for agent in env_only_if_action_required.agents:
agent.status=RailAgentStatus.ACTIVE
for step in range(100):
print("step {}".format(step))
......
......@@ -75,12 +75,15 @@ def test_malfunction_process():
# Check that an initial duration for malfunction was assigned
assert env.agents[0].malfunction_data['next_malfunction'] > 0
for agent in env.agents:
agent.status = RailAgentStatus.ACTIVE
agent_halts = 0
total_down_time = 0
agent_old_position = env.agents[0].position
for step in range(100):
actions = {}
for i in range(len(obs)):
actions[i] = np.argmax(obs[i]) + 1
......@@ -105,7 +108,8 @@ def test_malfunction_process():
total_down_time += env.agents[0].malfunction_data['malfunction']
# Check that the appropriate number of malfunctions is achieved
assert env.agents[0].malfunction_data['nr_malfunctions'] == 21
assert env.agents[0].malfunction_data['nr_malfunctions'] == 21, "Actual {}".format(
env.agents[0].malfunction_data['nr_malfunctions'])
# Check that 20 stops were performed
assert agent_halts == 20
......
import numpy as np
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
......@@ -40,7 +40,7 @@ def test_get_global_observation():
number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=GlobalObsForRailEnv())
obs, all_rewards, done, _ = env.step({0: 0})
obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
for i in range(len(env.agents)):
obs_agents_state = obs[i][1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment