#announcements

Ketan (kumare3)

03/18/2022, 9:30 PM
@Sören Brunk have you started working on KerasCheckpoint wrapper? cc @Niels Bantilan?

Sören Brunk

03/21/2022, 8:16 AM
The checkpointing I've done was not for Keras but for the Hugging Face Trainer for PyTorch transformer models.
I'm not sure if it's worth it to create a dedicated plugin for this, or if we should just add it to the docs. It's essentially not much more than a callback for saving checkpoints:
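import logging
import flytekit
from flytekit.core.checkpointer import Checkpoint
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
logger = logging.getLogger(__name__)  # assumed; the original logger setup is not shown

class SaveCheckpointCallback(TrainerCallback):
    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        cp: Checkpoint = flytekit.current_context().checkpoint
        checkpoint_path = args.output_dir
        logger.info("Saving checkpoint")
        cp.save(checkpoint_path)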
And then some code for restoring the checkpoint:
if enable_checkpointing:
    logger.info("Checkpointing enabled. Trying to restore previous checkpoint")
    cp: Checkpoint = flytekit.current_context().checkpoint
    checkpoint_path = cp.restore(_training_args.output_dir)
    last_checkpoint_path = get_last_checkpoint(checkpoint_path)
    logger.info("Restored checkpoint" if last_checkpoint_path else "No checkpoint found")
else:
    last_checkpoint_path = None

trainer.train(resume_from_checkpoint=last_checkpoint_path)
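
For anyone reading along: the restore snippet relies on `get_last_checkpoint` (from `transformers.trainer_utils`) to pick which checkpoint to resume from. As far as I understand it, the Trainer writes each checkpoint into a `checkpoint-<step>` subdirectory of `output_dir`, and the helper returns the one with the highest step, or nothing if there is none. Below is a rough stdlib-only sketch of that selection logic. `find_last_checkpoint` is a hypothetical stand-in written for illustration, not the library's code, and the `None` guard is my own assumption to cover the case where nothing was restored.

```python
import re
from pathlib import Path
from typing import Optional

# Hypothetical stand-in for illustration; not the transformers implementation.
# Assumes checkpoints live in subdirectories named "checkpoint-<step>".
_CHECKPOINT_DIR = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(folder: Optional[str]) -> Optional[str]:
    """Return the checkpoint-<step> subdirectory with the highest step, or None."""
    # Assumption: treat a missing restore path as "no checkpoint available".
    if folder is None or not Path(folder).is_dir():
        return None

    candidates = []
    for entry in Path(folder).iterdir():
        match = _CHECKPOINT_DIR.match(entry.name)
        if match and entry.is_dir():
            candidates.append((int(match.group(1)), entry))

    if not candidates:
        return None

    # Compare steps numerically so checkpoint-1000 beats checkpoint-500.
    return str(max(candidates, key=lambda item: item[0])[1])
```

Returning `None` when nothing is found lines up with the restore snippet: passing `resume_from_checkpoint=None` to `trainer.train` starts training from scratch.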