diff --git a/utils/observation_utils.py b/utils/observation_utils.py index 5e01121fe84c24d8d5a46d92bb30578e1dcda2b0..4c4efa2405a01499d067e68cd1e305f40a6e11a7 100644 --- a/utils/observation_utils.py +++ b/utils/observation_utils.py @@ -31,7 +31,7 @@ def min_lt(seq, val): return min -def norm_obs_clip(obs, clip_min=-1, clip_max=1): +def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0): """ This function returns the difference between min and max value of an observation :param obs: Observation that should be normalized @@ -39,8 +39,12 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): :param clip_max: max value where observation will be clipped :return: returnes normalized and clipped observatoin """ - max_obs = max(1, max_lt(obs, 1000)) - min_obs = min(max_obs, min_lt(obs, 0)) + if fixed_radius > 0: + max_obs = fixed_radius + else: + max_obs = max(1, max_lt(obs, 1000)) + 1 + + min_obs = 0 # min(max_obs, min_lt(obs, 0)) if max_obs == min_obs: return np.clip(np.array(obs) / max_obs, clip_min, clip_max)