Commit 011cc73a authored by dzorlu

checkpointing. fix discount bug

parent 9cfcb266
@@ -119,7 +119,7 @@ class DemonstrationRecorder:
     return action

   def step(self, timestep: dm_env.TimeStep, action: np.ndarray):
-    reward = np.array(timestep.reward or 0, np.float32)
+    reward = np.array(timestep.reward or 0., np.float32)
     self._episode_reward += reward
     # this imitates the environment step to create data in the same format.
     new_timestep = self._augment_observation(timestep)
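Note on the first change: np.array(..., np.float32) already coerces its input, so replacing 0 with 0. mostly makes the float fallback explicit; the fallback exists because dm_env returns reward=None on the first timestep of an episode. A minimal check of that behavior (values hypothetical, numpy only):

    import numpy as np

    # First timestep: dm_env gives reward=None, which `or` replaces with 0.
    np.array(None or 0., np.float32)   # -> array(0., dtype=float32)
    # Later timesteps: a nonzero reward passes through unchanged.
    np.array(0.5 or 0., np.float32)    # -> array(0.5, dtype=float32)

One caveat: a legitimate reward of 0.0 is falsy and also takes the fallback branch, which happens to be harmless here because the fallback is itself zero.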
@@ -130,7 +130,7 @@ class DemonstrationRecorder:
   def _augment_observation(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
     ovar = OVAR(observation=timestep.observation['pov'].astype(np.float32),
-                obs_vector=timestep.observation['vector'],
+                obs_vector=timestep.observation['vector'].astype(np.float32),
                 action=self._prev_action,
                 reward=self._prev_reward)
     return timestep._replace(observation=ovar)
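Casting obs_vector to float32 gives every array field of the OVAR tuple a single dtype, which matters once these timesteps are batched into a dataset with a fixed signature (mixed float32/float64 fields would otherwise change the spec). For orientation only, a hypothetical reconstruction of the OVAR container; its real definition lives elsewhere in this file:

    import collections

    # Hypothetical sketch: field names follow the call site above.
    OVAR = collections.namedtuple(
        'OVAR', ['observation', 'obs_vector', 'action', 'reward'])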
@@ -172,15 +172,15 @@ class DemonstrationRecorder:
 def generate_demonstration(env: dm_env.Environment,
                            dat_loader: minerl.data.data_pipeline.DataPipeline,
-                           nb_experts: int = 10):
+                           nb_experts: int = 20):

   # Build demonstrations.
   recorder = DemonstrationRecorder(env)
   recorder._reset_episode()

   # replay trajectories
   trajectories = dat_loader.get_trajectory_names()
-  t = 0
-  while t < nb_experts:
-    for t, trajectory in enumerate(trajectories):
+  for t, trajectory in enumerate(trajectories):
+    if t < nb_experts:
       logger.info({str(t): trajectory})
       for i, (state, a, r, _, done, meta) in enumerate(dat_loader.load_data(trajectory, include_metadata=True)):
         if done:
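The removed loop was subtly broken: the inner for re-bound t on every iteration, so the while t < nb_experts guard was only consulted after the for had already walked every trajectory, replaying far more than nb_experts episodes. The replacement enumerates once and skips the tail with if t < nb_experts. An equivalent, slightly tighter sketch using itertools and only the names in this diff:

    import itertools

    # Take exactly the first nb_experts trajectory names, then replay each.
    for t, trajectory in enumerate(itertools.islice(trajectories, nb_experts)):
        logger.info({str(t): trajectory})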
@@ -189,7 +189,10 @@ def generate_demonstration(env: dm_env.Environment,
           step_type = dm_env.StepType(0)
         else:
           step_type = dm_env.StepType(1)
-        ts = dm_env.TimeStep(observation=state, reward=r, step_type=step_type, discount=0)
+        ts = dm_env.TimeStep(observation=state,
+                             reward=r,
+                             step_type=step_type,
+                             discount=np.array(1., dtype=np.float32))
         yield recorder.step(ts, a)

 def main():
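This hunk is the discount bug from the commit message. In dm_env, discount is the per-step continuation factor, so a constant discount=0 truncated the return at every transition and the learner could never bootstrap across demonstration steps; mid-episode steps should carry discount 1.0 as a float32 array matching the environment spec. Strictly, only a true termination should carry 0. A sketch of that convention, using only names visible in this diff:

    import dm_env
    import numpy as np

    # Terminal frames end the episode (LAST step, discount 0.);
    # every other frame continues it with discount 1.
    if done:
        ts = dm_env.termination(reward=r, observation=state)
    else:
        ts = dm_env.TimeStep(observation=state,
                             reward=r,
                             step_type=dm_env.StepType.MID,
                             discount=np.array(1., dtype=np.float32))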
@@ -230,13 +233,14 @@ def main():
   network = create_network()
   target_network = create_network()
   logger.info(f"model directory: {model_dir}")
+  # sequence_length = burn_in_length + trace_length
   agent = r2d3.R2D3(
+      model_directory=model_dir,
       environment_spec=spec,
       network=network,
       target_network=target_network,
-      demonstration_dataset=demonstration_dataset,
+      demonstration_generator=generator,
       demonstration_ratio=0.5,
       batch_size=8,
       samples_per_insert=2,
......@@ -244,8 +248,8 @@ def main():
burn_in_length=burn_in_length,
trace_length=trace_length,
replay_period=40, # per R2D3 paper.
checkpoint=False,
logger=agent_logger
checkpoint=True,
#logger=agent_logger
)
# Run the env loop
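The constructor changes line up with the commit message: model_directory plus checkpoint=True turns checkpointing on, and the agent now consumes a demonstration generator (the generate_demonstration iterator above) instead of a pre-built demonstration_dataset, mixing demonstrations into replay at demonstration_ratio=0.5. Sampled training sequences span burn_in_length + trace_length steps, per the new comment. A hedged sketch of the surrounding wiring (assuming acme's EnvironmentLoop, which this file appears to build on; the episode budget is hypothetical):

    from acme.environment_loop import EnvironmentLoop

    # Demonstration generator handed to the agent above.
    generator = generate_demonstration(env, dat_loader, nb_experts=20)

    # Standard acme loop: the agent acts, observes, and learns online.
    loop = EnvironmentLoop(env, agent)
    loop.run(num_episodes=100)  # hypothetical episode budget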