Ketan (kumare3)
Sören Brunk
# 03/21/2022, 8:16 AM
class SaveCheckpointCallback(TrainerCallback):
    """Hugging Face ``Trainer`` callback that mirrors saves into a Flyte checkpoint.

    Whenever the trainer fires its ``on_save`` hook, the trainer's whole
    ``args.output_dir`` is passed to the intra-task checkpoint of the current
    Flyte execution. Presumably this is so that a retried attempt can restore
    it. TODO confirm against the restore side of the workflow.
    """

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Save ``args.output_dir`` to the current Flyte checkpoint.

        Args:
            args: Training arguments; only ``output_dir`` is read.
            state: Trainer state (not used here).
            control: Trainer control flags (not used or modified here).
            **kwargs: Additional objects the trainer passes to callbacks (unused).
        """
        # NOTE(review): in multi-process training every rank receives this
        # callback and would upload the same directory; consider guarding with
        # `state.is_world_process_zero` if that is not intended. Confirm.
        cp: Checkpoint = flytekit.current_context().checkpoint
        checkpoint_path = args.output_dir
        logger.info("Saving checkpoint")
        cp.save(checkpoint_path)
# Optionally restore a checkpoint left by a previous attempt of this task, then
# train; resume_from_checkpoint=None makes the trainer start from scratch.
if enable_checkpointing:
    logger.info("Checkpointing enabled. Trying to restore previous checkpoint")
    cp: Checkpoint = flytekit.current_context().checkpoint
    # restore() fetches the previous checkpoint into output_dir and returns
    # the local path, or None when there is nothing to restore (e.g. on the
    # first attempt). Only look for a "checkpoint-N" directory when something
    # was actually restored: get_last_checkpoint(None) would otherwise scan
    # the current working directory instead.
    checkpoint_path = cp.restore(_training_args.output_dir)
    last_checkpoint_path = get_last_checkpoint(checkpoint_path) if checkpoint_path else None
    logger.info("Restored checkpoint" if last_checkpoint_path else "No checkpoint found")
else:
    last_checkpoint_path = None
trainer.train(resume_from_checkpoint=last_checkpoint_path)