diff --git a/utils/observation_utils.py b/utils/observation_utils.py index 0c97b186a9331f185cf1a1d3f99685581cb551f7..fda5b530fb4f473915b0f043d901d3e8cb4fe727 100644 --- a/utils/observation_utils.py +++ b/utils/observation_utils.py @@ -29,7 +29,7 @@ def min_lt(seq, val): return min -def norm_obs_clip(obs, clip_min=-1, clip_max=1): +def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0): """ This function returns the difference between min and max value of an observation :param obs: Observation that should be normalized @@ -37,8 +37,12 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): :param clip_max: max value where observation will be clipped :return: returnes normalized and clipped observatoin """ - max_obs = max(1, max_lt(obs, 1000)) - min_obs = min(max_obs, min_lt(obs, 0)) + if fixed_radius > 0: + max_obs = fixed_radius + else: + max_obs = max(1, max_lt(obs, 1000)) + + min_obs = 0 # min(max_obs, min_lt(obs, 0)) if max_obs == min_obs: return np.clip(np.array(obs) / max_obs, clip_min, clip_max)