Commit 9cfcb266 authored by dzorlu

checkpoint state. fix discount bug

parent 33eaf8ed
@@ -55,7 +55,7 @@ MINERL_TRAINING_MAX_INSTANCES = int(os.getenv('MINERL_TRAINING_MAX_INSTANCES', 5
# Round 2: Training timeout is 4 days
MINERL_TRAINING_TIMEOUT = int(os.getenv('MINERL_TRAINING_TIMEOUT_MINUTES', 4*24*60))
# The dataset is available in data/ directory from repository root.
-MINERL_DATA_ROOT = os.getenv('MINERL_DATA_ROOT', '/data/minerl')
+MINERL_DATA_ROOT = os.getenv('MINERL_DATA_ROOT', '/hdd/minerl')
# Optional: You can view best effort status of your instances with the help of parser.py
# This will give you the current state, like the number of steps completed, instances launched and so on. Make sure you keep a tab on the numbers to avoid breaching any limits.
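
For context, all of the constants in this hunk are read from the environment at import time, so they can be overridden without editing the file. A minimal sketch (illustrative only; the values simply mirror the defaults shown above):

import os

# Illustrative: set these before the training module is imported so that
# os.getenv(...) picks them up instead of the hard-coded defaults.
os.environ.setdefault('MINERL_DATA_ROOT', '/hdd/minerl')
os.environ.setdefault('MINERL_TRAINING_TIMEOUT_MINUTES', str(4 * 24 * 60))
os.environ.setdefault('MINERL_TRAINING_MAX_INSTANCES', '5')
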
@@ -100,10 +100,7 @@ def _nested_stack(sequence: List[Any]):
return tree.map_structure(lambda *x: np.stack(x), *sequence)
class DemonstrationRecorder:
"""Records demonstrations.
A demonstration is a (observation, action, reward, discount) tuple where
every element is a numpy array corresponding to a full episode.
"""Generate (TimeStep, action) tuples
"""
def __init__(self, environment: dm_env.Environment):
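
As an aside, _nested_stack above uses dm-tree to turn a list of per-step structures into a single structure of stacked arrays. A small self-contained illustration (the dict layout is hypothetical, not the recorder's actual observation spec):

import numpy as np
import tree  # dm-tree

# Three fake steps, each an identical (nested) structure of arrays.
steps = [{'pov': np.zeros((4, 4), np.float32), 'reward': np.float32(0.)}
         for _ in range(3)]

# Stacks leaf-by-leaf across the list, which is what _nested_stack does.
stacked = tree.map_structure(lambda *x: np.stack(x), *steps)
assert stacked['pov'].shape == (3, 4, 4)
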
@@ -129,8 +126,7 @@ class DemonstrationRecorder:
discrete_action = self.map_action(action)
self._prev_action = discrete_action
self._prev_reward = reward
-self._episode.append((new_timestep.observation, discrete_action, reward,
-np.array(new_timestep.discount or 0, np.float32)))
+return (new_timestep, discrete_action)
def _augment_observation(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
ovar = OVAR(observation=timestep.observation['pov'].astype(np.float32),
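
The removed append stored np.array(new_timestep.discount or 0, np.float32) per step; after this change, step just hands back the (TimeStep, action) pair and leaves discounting to the consumer. If a consumer still needs a per-step discount, one hedged way to derive it from the TimeStep itself (not part of this commit) is:

import numpy as np
import dm_env

def discount_for(ts: dm_env.TimeStep) -> np.float32:
    # Illustrative helper, not from this repo: terminal steps get discount 0,
    # every other step gets 1, regardless of what ts.discount holds.
    return np.float32(0.0 if ts.last() else 1.0)
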
@@ -140,12 +136,10 @@ class DemonstrationRecorder:
return timestep._replace(observation=ovar)
def record_episode(self):
logger.info(f"episode length of {len(self._episode)}")
self._demos.append(_nested_stack(self._episode))
self._reset_episode()
def discard_episode(self):
self._reset_episode()
def _reset_episode(self):
self._episode = []
self._episode_reward = 0
@@ -176,29 +170,27 @@ class DemonstrationRecorder:
ds = tf.data.Dataset.from_generator(lambda: self._demos, types, shapes)
return ds.repeat().shuffle(len(self._demos))
-def build_demonstrations(env: dm_env.Environment,
+def generate_demonstration(env: dm_env.Environment,
dat_loader: minerl.data.data_pipeline.DataPipeline,
-nb_experts: int = 5):
+nb_experts: int = 10):
# Build demonstrations.
recorder = DemonstrationRecorder(env)
recorder._reset_episode()
# replay trajectories
trajectories = dat_loader.get_trajectory_names()
-for t, trajectory in enumerate(trajectories):
-if t < nb_experts:
-logger.info({str(t): trajectory})
-for i, (state, a, r, _, done, meta) in enumerate(dat_loader.load_data(trajectory, include_metadata=True)):
-if done:
-step_type = dm_env.StepType(2)
-elif i == 0:
-step_type = dm_env.StepType(0)
-else:
-step_type = dm_env.StepType(1)
-ts = dm_env.TimeStep(observation=state, reward=r, step_type=step_type, discount=0)
-recorder.step(ts, a)
-logger.info(f"recording {t} expert")
-recorder.record_episode()
-return recorder.make_tf_dataset()
+t = 0
+while t < nb_experts:
+for t, trajectory in enumerate(trajectories):
+logger.info({str(t): trajectory})
+for i, (state, a, r, _, done, meta) in enumerate(dat_loader.load_data(trajectory, include_metadata=True)):
+if done:
+step_type = dm_env.StepType(2)
+elif i == 0:
+step_type = dm_env.StepType(0)
+else:
+step_type = dm_env.StepType(1)
+ts = dm_env.TimeStep(observation=state, reward=r, step_type=step_type, discount=0)
+yield recorder.step(ts, a)
def main():
"""
@@ -231,7 +223,7 @@ def main():
# Build demonstrations
logger.info("building the demonstration dataset")
-demonstration_dataset = build_demonstrations(environment, data)
+generator = generate_demonstration(environment, data)
logger.info("demonstration dataset is built..")
# Construct the network.
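
The call site above now receives a lazy generator of (TimeStep, action) pairs rather than a finished tf.data dataset. A minimal consumption sketch, assuming environment and data are constructed earlier in main() as in this diff (the loop body is illustrative):

# Illustrative only: walk the expert (TimeStep, action) pairs and count
# transitions between episode boundaries.
n_transitions = 0
for timestep, action in generate_demonstration(environment, data, nb_experts=10):
    n_transitions += 1
    if timestep.last():
        logger.info(f"episode finished after {n_transitions} transitions")
        n_transitions = 0
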
@@ -246,9 +238,9 @@ def main():
target_network=target_network,
demonstration_dataset=demonstration_dataset,
demonstration_ratio=0.5,
-batch_size=10,
+batch_size=8,
samples_per_insert=2,
-min_replay_size=10,
+min_replay_size=1000,
burn_in_length=burn_in_length,
trace_length=trace_length,
replay_period=40, # per R2D3 paper.
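
For reference, the hyperparameters touched in this hunk all feed an agent constructor whose surrounding lines are elided from the diff. A hedged reconstruction of that call, assuming the Acme R2D3 agent and that environment_spec, network and target_network were built earlier in main() (those names are assumptions, not shown here):

from acme.agents.tf import r2d3

# Sketch only; argument names outside the visible hunk are assumed.
agent = r2d3.R2D3(
    environment_spec=environment_spec,
    network=network,
    target_network=target_network,
    demonstration_dataset=demonstration_dataset,
    demonstration_ratio=0.5,
    batch_size=8,
    samples_per_insert=2,
    min_replay_size=1000,
    burn_in_length=burn_in_length,
    trace_length=trace_length,
    replay_period=40,  # per the R2D3 paper
)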