Commit a9a7b281 authored by dzorlu's avatar dzorlu

submit. change test script to keep env alive

parent 5efc909d
......@@ -264,8 +264,7 @@ def main():
#
environment = make_environment(num_actions=NUMBER_OF_DISCRETE_ACTIONS, k_means_path=model_dir)
spec = specs.make_environment_spec(environment)
actor = load_actor(spec) #initiate here to keep the state shared across threads
environment.close()
actor = load_actor(spec) #initiate here to keep the state shared across thread
agent = AGENT_TO_TEST()
assert isinstance(agent, MineRLAgentBase)
......@@ -275,11 +274,12 @@ def main():
assert EVALUATION_THREAD_COUNT > 0
# Create the parallel envs (sequentially to prevent issues!)
envs = list()
for _ in range(EVALUATION_THREAD_COUNT):
envs = [environment]
for _ in range(EVALUATION_THREAD_COUNT-1):
environment = make_environment(num_actions=NUMBER_OF_DISCRETE_ACTIONS,
k_means_path=model_dir)
envs.append(environment)
print(len(envs))
# Create the parallel envs (sequentially to prevent issues!)
#envs = [gym.make(MINERL_GYM_ENV) for _ in range(EVALUATION_THREAD_COUNT)]
episodes_per_thread = [MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT for _ in range(EVALUATION_THREAD_COUNT)]
......
......@@ -55,7 +55,7 @@ MINERL_TRAINING_MAX_INSTANCES = int(os.getenv('MINERL_TRAINING_MAX_INSTANCES', 5
# Round 2: Training timeout is 4 days
MINERL_TRAINING_TIMEOUT = int(os.getenv('MINERL_TRAINING_TIMEOUT_MINUTES', 4*24*60))
# The dataset is available in data/ directory from repository root.
MINERL_DATA_ROOT = os.getenv('MINERL_DATA_ROOT', '/hdd/minerl')
MINERL_DATA_ROOT = os.getenv('MINERL_DATA_ROOT', '/data')
# Optional: You can view best effort status of your instances with the help of parser.py
# This will give you current state like number of steps completed, instances launched and so on. Make your you keep a tap on the numbers to avoid breaching any limits.
......@@ -100,6 +100,7 @@ def make_environment(k_means_path: str,
num_actions=num_actions,
dat_loader=dat_loader,
k_means_path=k_means_path,
train=train
),
wrappers.SinglePrecisionWrapper,
])
......@@ -227,6 +228,7 @@ def main():
logger.info("creating environment")
environment = make_environment(num_actions=NUMBER_OF_DISCRETE_ACTIONS,
k_means_path=model_dir,
train=False,
dat_loader=data)
spec = specs.make_environment_spec(environment)
......@@ -251,7 +253,7 @@ def main():
network=network,
target_network=target_network,
demonstration_generator=generator,
demonstration_ratio=0.5,
demonstration_ratio=0.9,
batch_size=8,
samples_per_insert=2,
min_replay_size=1000,
......
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment