diff --git a/tests/test_utils.py b/tests/test_utils.py index 6347bd0f5048350c099ba2568dac7caba74baf2d..967615833c4790d67af80a1d75e35174e2ff5e5a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,5 @@ """Test Utils.""" -from typing import List, Tuple +from typing import List, Tuple, Optional from attr import attrs, attrib @@ -13,6 +13,7 @@ class Replay(object): direction = attrib(type=Grid4TransitionsEnum) action = attrib(type=RailEnvActions) malfunction = attrib(default=0, type=int) + penalty = attrib(default=None, type=Optional[float]) @attrs