melodic-magician-71351
03/14/2023, 1:31 PM
nn.Module with the pytorch plugin, but I want a downstream task to receive the filepath that the nn.Module is saved to. Is there a way to do that?
tall-lock-23197
melodic-magician-71351
03/14/2023, 1:41 PM
tall-lock-23197
melodic-magician-71351
03/14/2023, 1:45 PM
melodic-magician-71351
03/14/2023, 1:45 PM
melodic-magician-71351
03/14/2023, 1:47 PM
output = task_a(...)
report = reporter(literal=flytekit.as_literal(output))
tall-lock-23197
melodic-magician-71351
03/14/2023, 1:55 PM
melodic-magician-71351
03/14/2023, 1:57 PM
freezing-airport-6809
freezing-airport-6809
freezing-airport-6809
melodic-magician-71351
03/14/2023, 2:17 PM
def reporter(location: FlyteFile[nn.Module]):
melodic-magician-71351
03/14/2023, 2:17 PM
freezing-airport-6809
freezing-airport-6809
melodic-magician-71351
03/14/2023, 2:18 PM
freezing-airport-6809
freezing-airport-6809
freezing-airport-6809
melodic-magician-71351
03/14/2023, 2:20 PM
tall-lock-23197
You are right, you can receive it as a FlyteFile[format].
Oh yeah. This has to work!
melodic-magician-71351
03/22/2023, 10:39 AM
TypeError: No automatic conversion found from type <class 'transformers.models.esm.modeling_esm.EsmForMaskedLM'> to FlyteFile.Supported (os.PathLike, str, Flytefile)
melodic-magician-71351
03/22/2023, 10:40 AM
FlyteFile[PyTorchModuleTransformer.PYTORCH_FORMAT]
melodic-magician-71351
03/22/2023, 10:42 AM
1.4.2)
tall-lock-23197
melodic-magician-71351
03/23/2023, 11:29 AM
@wandb_task
def register_model(model: FlyteFile[PyTorchModuleTransformer.PYTORCH_FORMAT],
                   model_name: str,
                   tokenizer: Optional[TokenizerSpec]) -> str:
    api = wandb.Api()
    artifact_id = f'{commons.FLYTE_CONTEXT.project}/{commons.FLYTE_CONTEXT.workflow_name}/{model_name}:{commons.FLYTE_CONTEXT.version}'
    artifact = api.artifact(artifact_id, type='model')
    artifact.add_reference(uri=model.remote_path)
    if tokenizer:
        artifact.metadata['tokenizer'] = tokenizer.to_dict()
    artifact.save()
    return artifact_id
I want to be able to pass it an nn.Module or a FlyteFile that represents an nn.Module, and I thought what @freezing-airport-6809 was suggesting is that the conversion could happen automatically, without deserializing, by using an annotated FlyteFile. E.g.:
@workflow
def my_workflow(model: nn.Module) -> str:
    model = train_task(model=model)
    # here the pipeline still sees "model" as an `nn.Module`
    return register_model(model=model)  # This task can take an `nn.Module` but will always receive it as a `FlyteFile`
And register_model would get the raw FlyteFile. As you can see, the goal for the "register" task is simply to update a remote server with a reference to the remote location of the artifact. It would be wasteful to download the artifact and then upload it to another location, just to convert the type from nn.Module to FlyteFile, when there is already a remote location.
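To make the intent concrete, here is a minimal sketch of the "pass the reference, not the payload" pattern in plain Python. It deliberately does not use flytekit or wandb: `RemoteRef`, `register_reference`, the in-memory `store`, and the example URI are all hypothetical stand-ins, and whether FlyteFile behaves this way for a given flytekit version is exactly the open question here.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class RemoteRef:
    """Hypothetical stand-in for a file reference: holds a URI, fetches lazily."""
    uri: str
    fetch: Callable[[str], bytes]
    downloads: int = 0  # counts how often the payload was actually fetched

    def read(self) -> bytes:
        # Only this call touches the payload.
        self.downloads += 1
        return self.fetch(self.uri)


def register_reference(registry: Dict[str, str], name: str, ref: RemoteRef) -> str:
    # Record where the artifact already lives; ref.read() is never called.
    registry[name] = ref.uri
    return name


# Fake blob store and URI, made up for the example.
store = {"s3://bucket/model.pt": b"weights"}
ref = RemoteRef(uri="s3://bucket/model.pt", fetch=store.__getitem__)

registry: Dict[str, str] = {}
artifact_id = register_reference(registry, "my-model:v1", ref)
```

In this sketch the registration step only reads `ref.uri`, so `ref.downloads` stays at zero until something explicitly calls `ref.read()`.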
I couldn't find anything in the logic for FlyteFile that would indicate this is the case, so perhaps I misunderstood.
I'm wondering what the right solution to this pattern is (ideally one that doesn't involve redownloading the nn.Module just to upload it to a different location where I have the URI). My other solution was to essentially rewrite the PyTorchModuleTransformer to do this registration for me once it gets the remote_uri, but this feels a little wrong.
melodic-magician-71351
03/23/2023, 1:07 PM
freezing-airport-6809
freezing-airport-6809