Eli Bixby
03/14/2023, 1:31 PM
`nn.Module` with the pytorch plugin, but I want a downstream task to receive the filepath that the `nn.Module` is saved to. Is there a way to do that?
Samhita Alla
03/14/2023, 1:38 PM
Eli Bixby
03/14/2023, 1:41 PM
Samhita Alla
03/14/2023, 1:44 PM
Eli Bixby
03/14/2023, 1:45 PM
```python
output = task_a(...)
report = reporter(literal=flytekit.as_literal(output))
```
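The underlying request, handing a downstream step the location a model was saved to rather than the model object itself, can be sketched in plain Python. This is a minimal illustration, not Flyte code: a dict of weights stands in for an `nn.Module`, `pickle` stands in for `torch.save`, and `train_and_save` and `report` are hypothetical names.

```python
import os
import pickle
import tempfile


def train_and_save(out_dir: str) -> str:
    # Producer: "train" a model, write it to disk, and return only its path.
    # A plain dict of weights stands in for an nn.Module, pickle for torch.save.
    state = {"weights": [0.1, 0.2, 0.3]}
    path = os.path.join(out_dir, "model.pt")
    with open(path, "wb") as f:
        pickle.dump(state, f)
    return path


def report(model_path: str) -> str:
    # Consumer: works only with the location; the model is never deserialized.
    size = os.path.getsize(model_path)
    return f"model stored at {model_path} ({size} bytes)"


with tempfile.TemporaryDirectory() as out_dir:
    print(report(train_and_save(out_dir)))
```

The consumer's signature takes a `str`, so nothing forces it to load the weights; that is the property the thread is after, only moved into a typed workflow.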
Samhita Alla
03/14/2023, 1:54 PM
Eli Bixby
03/14/2023, 1:55 PM
Ketan (kumare3)
03/14/2023, 2:16 PM
Eli Bixby
03/14/2023, 2:17 PM
```python
def reporter(location: FlyteFile[nn.Module]):
    ...
```
Ketan (kumare3)
03/14/2023, 2:18 PM
Eli Bixby
03/14/2023, 2:18 PM
Ketan (kumare3)
03/14/2023, 2:19 PM
Eli Bixby
03/14/2023, 2:20 PM
Samhita Alla
03/14/2023, 3:19 PM
You are right, you can receive it as a `FlyteFile[format]`.
Oh yeah. This has to work!
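Receiving the output as a format-tagged file means the consumer declares which serialized format it accepts and gets a reference to the blob instead of the object. A toy sketch of that idea, loosely inspired by the `FlyteFile[format]` spelling but not flytekit's implementation; `TypedFile` and `is_compatible` are hypothetical, and the format string is illustrative:

```python
class TypedFile:
    """Toy file reference tagged with a serialization format."""

    format = ""  # empty format means "any file"

    def __class_getitem__(cls, fmt: str):
        # TypedFile["X"] builds a subclass that remembers the format "X".
        return type(f"TypedFile[{fmt}]", (cls,), {"format": fmt})

    def __init__(self, remote_path: str):
        self.remote_path = remote_path


def is_compatible(producer_format: str, consumer_type: type) -> bool:
    # A consumer accepts the blob if it asks for no particular format
    # or for exactly the format the producer wrote.
    return consumer_type.format in ("", producer_format)


ModuleFile = TypedFile["PyTorchModule"]
print(ModuleFile.format)                           # PyTorchModule
print(is_compatible("PyTorchModule", ModuleFile))  # True
print(is_compatible("csv", ModuleFile))            # False
```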
Eli Bixby
03/22/2023, 10:39 AM
```
TypeError: No automatic conversion found from type <class 'transformers.models.esm.modeling_esm.EsmForMaskedLM'> to FlyteFile.Supported (os.PathLike, str, Flytefile)
```
`FlyteFile[PyTorchModuleTransformer.PYTORCH_FORMAT]`
(1.4.2)
Samhita Alla
03/23/2023, 5:06 AM
Eli Bixby
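03/23/2023, 11:29 AM
```python
from typing import Optional

import wandb
from flytekit.extras.pytorch import PyTorchModuleTransformer
from flytekit.types.file import FlyteFile

# wandb_task, commons and TokenizerSpec are defined elsewhere (not shown).
@wandb_task
def register_model(model: FlyteFile[PyTorchModuleTransformer.PYTORCH_FORMAT],
                   model_name: str,
                   tokenizer: Optional[TokenizerSpec]) -> str:
    api = wandb.Api()
    artifact_id = f'{commons.FLYTE_CONTEXT.project}/{commons.FLYTE_CONTEXT.workflow_name}/{model_name}:{commons.FLYTE_CONTEXT.version}'
    artifact = api.artifact(artifact_id, type='model')
    # Record a reference to the model's existing remote location (no copy).
    artifact.add_reference(uri=model.remote_path)
    if tokenizer:
        artifact.metadata['tokenizer'] = tokenizer.to_dict()
    artifact.save()
    return artifact_id
```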
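The `TypeError` from 03/22 names the incoming value's type (`EsmForMaskedLM`) and lists what a `FlyteFile` conversion accepts: `os.PathLike`, `str` and `FlyteFile`. In other words, a file converter was handed the in-memory model object. A toy reproduction of that failure shape in plain Python; `FileRef`, `to_file_literal` and `EsmLikeModel` are hypothetical stand-ins, not flytekit code:

```python
import os
import pathlib


class FileRef:
    """Stand-in for a file-reference type such as FlyteFile."""

    def __init__(self, path: str):
        self.path = path


def to_file_literal(value) -> str:
    # Toy converter for a file-typed input: it only knows how to turn
    # path-like values into a URI, mirroring the "Supported (...)" list.
    if isinstance(value, FileRef):
        return value.path
    if isinstance(value, (str, os.PathLike)):
        return os.fspath(value)
    raise TypeError(
        f"No automatic conversion found from type {type(value)} to FileRef. "
        "Supported (os.PathLike, str, FileRef)"
    )


class EsmLikeModel:
    """Stand-in for an in-memory model object such as EsmForMaskedLM."""


print(to_file_literal(pathlib.PurePosixPath("/tmp/model.pt")))  # /tmp/model.pt
try:
    to_file_literal(EsmLikeModel())
except TypeError as err:
    print(err)
```

Anything that is already a location converts; a live model object does not, because the converter has no way to know where (or whether) it was saved.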
I want to be able to pass it an `nn.Module` or a `FlyteFile` that represents an `nn.Module`, and I thought what @Ketan (kumare3) was suggesting is that I could get the conversion automatically (i.e. skip deserialization) by using an annotated `FlyteFile`, e.g.:
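```python
from flytekit import workflow
from torch import nn

# train_task and register_model are tasks defined elsewhere (not shown).
@workflow
def my_workflow(model: nn.Module) -> str:
    model = train_task(model=model)
    # here the pipeline still sees "model" as an `nn.Module`
    return register_model(model=model)  # This task can take an `nn.Module` but will always receive it as a `FlyteFile`
```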
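What this workflow asks for amounts to letting the receiving task's declared input type decide how an already-stored output is materialized: rebuild the object for an `nn.Module` input, hand over only the URI for a file input. A toy sketch of that dispatch, with hypothetical `BlobLiteral`, `FileHandle`, `Module`, an in-memory `STORE` and `pickle` standing in for Flyte literals, `FlyteFile`, `nn.Module`, blob storage and `torch.load`; it illustrates the idea and is not how flytekit resolves types:

```python
import pickle
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class BlobLiteral:
    """What passes between tasks in this toy: a URI plus a format tag."""
    uri: str
    format: str


@dataclass
class FileHandle:
    """Lazy file input: carries the URI only, nothing is downloaded."""
    uri: str


@dataclass
class Module:
    """Stand-in for an nn.Module rebuilt from stored state."""
    name: str


STORE: Dict[str, bytes] = {}  # in-memory stand-in for blob storage


def load_module(lit: BlobLiteral) -> Module:
    # Deserializing path: fetch the bytes and rebuild the object.
    state = pickle.loads(STORE[lit.uri])
    return Module(name=state["name"])


# Declared input type of the receiving task -> how the literal is materialized.
BINDERS: Dict[type, Callable[[BlobLiteral], Any]] = {
    FileHandle: lambda lit: FileHandle(uri=lit.uri),
    Module: load_module,
}


def bind_input(lit: BlobLiteral, declared: type) -> Any:
    return BINDERS[declared](lit)


# The same stored output, consumed two different ways.
STORE["s3://bucket/run-1/model.pt"] = pickle.dumps({"name": "esm"})
lit = BlobLiteral(uri="s3://bucket/run-1/model.pt", format="PyTorchModule")
print(bind_input(lit, FileHandle))  # FileHandle(uri='s3://bucket/run-1/model.pt')
print(bind_input(lit, Module))      # Module(name='esm')
```

Keeping the choice on the consumer side means the producer never changes: it writes one blob, and each downstream task picks the representation it needs.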
And `register_model` would get the raw `FlyteFile`. As you can see, the goal of the "register" task is simply to update a remote server with a reference to the remote location of the artifact. It would be wasteful to download it and then upload it to another location just to convert the type from `nn.Module` to `FlyteFile` when there is already a remote location.
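The cost argument comes down to reading only the URI of an existing blob. A minimal sketch, assuming a lazy handle that downloads only on request (loosely like reading a `FlyteFile`'s `remote_path`) and a registry that stores references rather than copies; `LazyRemoteFile`, `ModelRegistry` and `register_reference` are hypothetical:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class LazyRemoteFile:
    """Lazy file handle: knows its remote URI, downloads only on demand."""
    remote_path: str
    downloads: int = 0

    def download(self) -> bytes:
        # A real handle would fetch the blob here; the toy just counts calls.
        self.downloads += 1
        return b""


@dataclass
class ModelRegistry:
    """Registry that stores references (URIs) instead of copies of the model."""
    entries: Dict[str, List[str]] = field(default_factory=dict)

    def add_reference(self, artifact_id: str, uri: str) -> None:
        self.entries.setdefault(artifact_id, []).append(uri)


def register_reference(model: LazyRemoteFile, registry: ModelRegistry, artifact_id: str) -> str:
    # Only the URI is read, so no bytes are downloaded or re-uploaded.
    registry.add_reference(artifact_id, model.remote_path)
    return artifact_id


registry = ModelRegistry()
model = LazyRemoteFile(remote_path="s3://bucket/run-1/model.pt")
register_reference(model, registry, "my-project/my-workflow/esm:v1")
print(registry.entries, model.downloads)
```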
I couldn't find anything in the `FlyteFile` logic that would indicate this is the case, so perhaps I misunderstood.
I'm wondering what the right solution to this pattern is, ideally one that doesn't involve re-downloading the `nn.Module` just to upload it to a different location when I already have the URI. My other solution was to essentially rewrite the `PyTorchModuleTransformer` to do this registration for me once it gets the `remote_uri`, but this feels a little wrong.
Ketan (kumare3)
03/24/2023, 4:58 AM