Commit 7838b8ae authored by nilabha's avatar nilabha

add explore to apx runs and update results


(cherry picked from commit 3b3f0c1d)
parent 8abfda27
Pipeline #5829 passed with stage
in 4 minutes and 12 seconds
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
......@@ -8,6 +8,7 @@ import os
import pickle
import shelve
from pathlib import Path
import random
import gym
import numpy as np
......@@ -47,6 +48,11 @@ load_algorithms(CUSTOM_ALGORITHMS) # Load algorithms
from collections.abc import Mapping
from copy import deepcopy
# Default terminal state epsilon
# https://github.com/ray-project/ray/blob/master/rllib/agents/dqn/dqn.py
final_epsilon = 0.02
random.seed(1)
def val_replace(mapping):
obj = deepcopy(mapping)
if isinstance(mapping, Mapping):
......@@ -287,6 +293,7 @@ def run(args, parser):
if args.eager:
from tensorflow.python.framework.ops import enable_eager_execution
enable_eager_execution()
config['eager'] = True
cls = get_trainable_cls(args.run)
......@@ -422,6 +429,13 @@ def rollout(agent,
policy_id=policy_id)
a_action = flatten_to_single_ndarray(a_action) # ray 0.8.5
# a_action = _flatten_action(a_action) # tuple actions # ray 0.8.4
# Epsilon-greedy action selection for APEX
if hasattr(agent, '_name'):
if agent._name == "APEX":
if random.random() <= final_epsilon:
a_action = random.choice(np.arange(env.action_space.n))
action_dict[agent_id] = a_action
prev_actions[agent_id] = a_action
action = action_dict
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment