Commit ff83fdeb authored by u229589

Refactoring: add return types for ObservationBuilder.get(self, handle: int = 0) and ObservationBuilder.get_many(self, handles: Optional[List[int]] = None)
parent 15db8d20
Pipeline #2120 failed with stages in 6 minutes and 9 seconds
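
For orientation, the pattern the commit message describes looks roughly like the following sketch. It is not part of the diff: the class name ConstantObs and the (5,) observation shape are illustrative, and the base class is assumed to live at flatland.core.env_observation_builder.ObservationBuilder.

from typing import Dict, List, Optional

import numpy as np

from flatland.core.env_observation_builder import ObservationBuilder  # assumed import path


class ConstantObs(ObservationBuilder):
    """Hypothetical builder: each agent observes a constant vector scaled by its handle."""

    def reset(self):
        pass

    def get(self, handle: int = 0) -> np.ndarray:
        # The annotation documents that a single agent's observation is an ndarray.
        return handle * np.ones((5,))

    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
        # get_many maps each agent handle to the observation produced by get.
        if handles is None:
            handles = []
        return {h: self.get(h) for h in handles}
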
@@ -23,7 +23,7 @@ class SimpleObs(ObservationBuilder):
     def reset(self):
         return
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> np.ndarray:
         observation = handle * np.ones((5,))
         return observation
......
@@ -2,6 +2,7 @@ import getopt
 import random
 import sys
 import time
+from typing import List
 import numpy as np
@@ -36,7 +37,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
......
@@ -2,7 +2,7 @@ import getopt
 import random
 import sys
 import time
-from typing import Optional, List
+from typing import Optional, List, Dict
 import numpy as np
@@ -41,7 +41,7 @@ class ObservePredictions(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
-    def get_many(self, handles: Optional[List[int]] = None):
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
         '''
         Because we do not want to call the predictor separately for every agent, we implement the get_many function.
         Here we can call the predictor just once for all the agents and use the predictions to generate our observations
@@ -69,7 +69,7 @@ class ObservePredictions(TreeObsForRailEnv):
             observations[h] = self.get(h)
         return observations
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> np.ndarray:
         '''
         Let's write a simple observation which just indicates whether or not our own predicted path
         overlaps with other predicted paths at any time. This is useless for the task of navigation but might
......
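
The docstrings above motivate the Dict[int, np.ndarray] and np.ndarray annotations: the predictor is queried once in get_many, and get then derives a per-agent array from the shared predictions. The sketch below shows that shape of the pattern; the predictor interface (self.predictor.get() returning one predicted position sequence per handle) and the overlap encoding are assumptions, not code from this commit.

from typing import Dict, List, Optional

import numpy as np


class PredictionOverlapObs:  # hypothetical stand-in for ObservePredictions
    def __init__(self, predictor):
        self.predictor = predictor
        self.predictions: Dict[int, list] = {}

    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, np.ndarray]:
        # Query the predictor a single time for all agents and cache the result.
        self.predictions = self.predictor.get()  # assumed: {handle: sequence of predicted positions}
        if handles is None:
            handles = []
        return {h: self.get(h) for h in handles}

    def get(self, handle: int = 0) -> np.ndarray:
        # Flag each time step at which this agent's predicted position coincides
        # with another agent's predicted position.
        own_path = self.predictions[handle]
        overlap = np.zeros(len(own_path))
        for other_handle, other_path in self.predictions.items():
            if other_handle == handle:
                continue
            for t, (own_pos, other_pos) in enumerate(zip(own_path, other_path)):
                if np.array_equal(own_pos, other_pos):
                    overlap[t] = 1.0
        return overlap
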
 import random
 import time
+from typing import List
 import numpy as np
@@ -33,7 +34,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
......
@@ -97,8 +97,8 @@ class DummyObservationBuilder(ObservationBuilder):
     def reset(self):
         pass
-    def get_many(self, handles: Optional[List[int]] = None):
+    def get_many(self, handles: Optional[List[int]] = None) -> bool:
         return True
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> bool:
         return True
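
Since DummyObservationBuilder unconditionally returns True from both methods, bool is the natural annotation. A quick illustrative check follows; the import path is an assumption based on the flatland package layout.

from flatland.core.env_observation_builder import DummyObservationBuilder  # assumed import path

builder = DummyObservationBuilder()
# Both calls simply return True, matching the new -> bool annotations.
assert builder.get(0) is True
assert builder.get_many([0, 1, 2]) is True
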
@@ -2,7 +2,7 @@
 Collection of environment-specific ObservationBuilder.
 """
 import pprint
-from typing import Optional, List
+from typing import Optional, List, Dict, Tuple
 import numpy as np
@@ -46,7 +46,7 @@ class TreeObsForRailEnv(ObservationBuilder):
     def reset(self):
         self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
-    def get_many(self, handles: Optional[List[int]] = None):
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, List[int]]:
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
@@ -75,7 +75,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             observations[h] = self.get(h)
         return observations
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> List[int]:
         """
         Computes the current observation for agent `handle` in env
@@ -534,7 +534,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i, j] = np.array(bitlist)
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
         obs_targets = np.zeros((self.env.height, self.env.width, 2))
         obs_agents_state = np.zeros((self.env.height, self.env.width, 4))
         agents = self.env.agents
@@ -600,7 +600,7 @@ class LocalObsForRailEnv(ObservationBuilder):
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i, j] = np.array(bitlist)
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         agents = self.env.agents
         agent = agents[handle]
@@ -640,7 +640,7 @@ class LocalObsForRailEnv(ObservationBuilder):
         direction = np.identity(4)[agent.direction]
         return local_rail_obs, obs_map_state, obs_other_agents_state, direction
-    def get_many(self, handles: Optional[List[int]] = None):
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
         """
         Called whenever an observation has to be computed for the `env` environment, for each agent with handle
         in the `handles` list.
......
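
For the builders above whose single-agent observation is a tuple of arrays, a typing.Tuple annotation (rather than a bare tuple literal) keeps the hints valid: typing.Dict[int, (...)] with a plain tuple raises a TypeError when the module is imported. Below is a compact sketch of the resulting get/get_many shape; the alias name, the builder class, and the placeholder arrays are illustrative assumptions, not code from this commit.

from typing import Dict, List, Optional, Tuple

import numpy as np

# Hypothetical alias for a local observation: rail transitions, map state,
# other agents' state, and the agent's direction (names are illustrative).
LocalObs = Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]


class TupleObsBuilder:  # hypothetical builder used only to show the annotations
    def get(self, handle: int = 0) -> LocalObs:
        # Placeholder arrays standing in for the real observation components.
        return np.zeros((4, 4, 16)), np.zeros((4, 4)), np.zeros((4, 4)), np.identity(4)[0]

    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, LocalObs]:
        # Same dict-of-handles pattern as TreeObsForRailEnv.get_many above.
        if handles is None:
            handles = []
        return {h: self.get(h) for h in handles}
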
 import random
-from typing import Dict
+from typing import Dict, List
 import numpy as np
@@ -31,7 +31,7 @@ class SingleAgentNavigationObs(TreeObsForRailEnv):
         # Recompute the distance map, if the environment has changed.
         super().reset()
-    def get(self, handle: int = 0):
+    def get(self, handle: int = 0) -> List[int]:
         agent = self.env.agents[handle]
         possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
......