Eli Bixby

03/14/2023, 1:31 PM
Is there an easy way to specify that a task should receive the serialized version of an output from a previous task? Say I have a train task that outputs an nn.Module with the PyTorch plugin, but I want a downstream task to receive the filepath that the nn.Module is saved to. Is there a way to do that?

Samhita Alla

03/14/2023, 1:38 PM

Eli Bixby

03/14/2023, 1:41 PM
Say I want to report the location of the serialized object to a third-party tracking system. How would you suggest I do that? I could subclass the serializer and re-register it with the typeengine, but that seems like overkill.

Samhita Alla

03/14/2023, 1:44 PM
You can manually fetch the path. It should be available on the Flyte UI (in the inputs/outputs pane). But if you'd like to automate this process, I believe you'll need to re-register it.

Eli Bixby

03/14/2023, 1:45 PM
So you'd suggest I have a separate task that manually serializes it as a flyte file and then fetches the remote path of that flyte file.
(this AFAICT duplicates the artifact in storage and is a fair bit slower if it's a large artifact which is a bit of a bummer)
Would be cool to have something like:
output = task_a(...)
report = reporter(literal=flytekit.as_literal(output))

Samhita Alla

03/14/2023, 1:54 PM
That isn't possible because a workflow is essentially a DSL, so you can't perform operations like that inside it. The best way forward IMO is to re-register. Here's the existing type transformer code for your reference: https://github.com/flyteorg/flytekit/blob/152080fa27c6d07edda9fcc2a602ca3ca9f4739a/flytekit/extras/pytorch/native.py#L24-L64.

Eli Bixby

03/14/2023, 1:55 PM
Yes I understand it's not possible now. It would essentially be a flag to the flytekit engine informing it not to deserialize in the consuming input task.
I think I'd actually prefer to subclass the typeengine and do the reporting in the serialization.
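A minimal sketch of the idea described above: a task input that carries only the stored location and skips deserialization unless asked. This is plain Python, not a flytekit feature; `ArtifactRef`, `report_location`, and the example URI are hypothetical names made up for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class ArtifactRef:
    """Hypothetical handle to a stored artifact: a URI plus a loader that runs only on demand."""

    uri: str
    loader: Callable[[str], Any]
    load_calls: int = 0

    def load(self) -> Any:
        # Deserialization happens only if the consumer explicitly asks for it.
        self.load_calls += 1
        return self.loader(self.uri)


def report_location(ref: ArtifactRef, sink: List[str]) -> str:
    # The reporter needs only the location, so it never calls ref.load().
    sink.append(ref.uri)
    return ref.uri


# The URI is made up; the loader stands in for an expensive download-and-deserialize step.
ref = ArtifactRef(uri="s3://example-bucket/model.pt", loader=lambda uri: {"loaded_from": uri})
reported: List[str] = []
report_location(ref, reported)
```

The reporter touches only `ref.uri`, so the artifact is neither downloaded nor duplicated in storage.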

Ketan (kumare3)

03/14/2023, 2:16 PM
@Samhita Alla you do have to manually fetch, @Eli Bixby you are right, you can receive as a FlyteFile[format]
As long as the format matches the compiler will accept it
Why don’t you give it a try and let us know

Eli Bixby

03/14/2023, 2:17 PM
Can you clarify? Do I annotate as:
def reporter(location: FlyteFile[nn.Module]):
Not seeing anything about this in the docs.

Ketan (kumare3)

03/14/2023, 2:18 PM
Format=PyTorchModule
Ya no docs on this today
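A minimal sketch of the matching rule described above: a file input is accepted when its format string equals the format of the upstream output. This is a simplified stand-in, not Flyte's compiler code; `BlobSpec` and `accepts` are hypothetical names, the "SINGLE" dimensionality is an assumption, and only the "PyTorchModule" format string comes from this thread.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BlobSpec:
    """Hypothetical, simplified description of a file-like type."""

    format: str
    dimensionality: str = "SINGLE"  # assumption: a single file rather than a directory


def accepts(upstream: BlobSpec, downstream: BlobSpec) -> bool:
    # Simplified rule from the thread: the two sides connect only when they match.
    return upstream == downstream


module_output = BlobSpec(format="PyTorchModule")  # what the training task produces
file_input = BlobSpec(format="PyTorchModule")     # file input annotated with the same format
other_input = BlobSpec(format="csv")              # arbitrary different format, for contrast

print(accepts(module_output, file_input))   # True
print(accepts(module_output, other_input))  # False
```

The annotation used later in this thread, FlyteFile[PyTorchModuleTransformer.PYTORCH_FORMAT], is how that format string gets attached to a file input.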

Eli Bixby

03/14/2023, 2:18 PM
Awesome thanks!

Ketan (kumare3)

03/14/2023, 2:19 PM
You are in the far trenches
But please contribute
We should expose the format more easily

Eli Bixby

03/14/2023, 2:20 PM
Ya no worries. I'm trying to integrate with a legacy system so that has me sprinting to the edges of the surface 😉

Samhita Alla

03/14/2023, 3:19 PM
you are right, you can receive as a FlyteFile[format]
Oh yeah. This has to work!

Eli Bixby

03/22/2023, 10:39 AM
Heyo, I'm trying this with a huggingface module and am getting:
TypeError: No automatic conversion found from type <class 'transformers.models.esm.modeling_esm.EsmForMaskedLM'> to FlyteFile.Supported (os.PathLike, str, Flytefile)
My type annotation looks like:
FlyteFile[PyTorchModuleTransformer.PYTORCH_FORMAT]
(I'm on flytekit 1.4.2)

Samhita Alla

03/23/2023, 5:06 AM
Looks like you're trying to use two different types. Would you mind sharing your Flyte tasks?

Eli Bixby

03/23/2023, 11:29 AM
Sure. Task in question is:
@wandb_task
def register_model(model: FlyteFile[PyTorchModuleTransformer.PYTORCH_FORMAT],
                   model_name: str,
                   tokenizer: Optional[TokenizerSpec]) -> str:
    api = wandb.Api()
    artifact_id = f'{commons.FLYTE_CONTEXT.project}/{commons.FLYTE_CONTEXT.workflow_name}/{model_name}:{commons.FLYTE_CONTEXT.version}'
    artifact = api.artifact(artifact_id, type='model')
    artifact.add_reference(uri=model.remote_path)
    if tokenizer:
        artifact.metadata['tokenizer'] = tokenizer.to_dict()

    artifact.save()

    return artifact_id
I want to be able to pass it an nn.Module or a FlyteFile that represents an nn.Module. I thought @Ketan (kumare3) was suggesting that an annotated FlyteFile would give me that conversion automatically, skipping deserialization, e.g.:
@workflow
def my_workflow(model: nn.Module) -> str:
    model = train_task(model=model)
    # here the pipeline still sees "model" as an `nn.Module`
    return register_model(model=model) # This task can take an `nn.Module` but will always receive it as a `FlyteFile`
And register_model would get the raw FlyteFile. As you can see, the goal of the "register" task is simply to update a remote server with a reference to the remote location of the artifact. It would be wasteful to download it and then upload it to another location, just to convert the type from nn.Module to FlyteFile, when there is already a remote location. I couldn't find anything in the logic for FlyteFile that would indicate this is the case, so perhaps I misunderstood. I'm wondering what the right solution to this pattern is (ideally one that doesn't involve redownloading the nn.Module just to upload it to a different location where I have the URI). My other solution was essentially to rewrite the PyTorchModuleTransformer to do this registration for me once it gets the remote_uri, but this feels a little wrong.
Never mind, I got this to work. I'm not sure what I was doing wrong; possibly it was because I was trying to use a Union type that contained the FlyteFile type?

Ketan (kumare3)

03/24/2023, 4:58 AM
Yes, a union type won't work,
as a union is not the same as a univariate type.
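A small illustration of the point above using only Python's typing module: a Union annotation is a different type object from any single member, so code that expects one exact type does not see a match. `FileRef` and `Module` are hypothetical stand-ins; how flytekit itself resolves unions is not shown here.

```python
from typing import Union, get_args


class FileRef:
    """Hypothetical stand-in for a file-reference type."""


class Module:
    """Hypothetical stand-in for a model type."""


single = FileRef
either = Union[FileRef, Module]

# The union is its own annotation, distinct from each member it contains.
is_same = either == single

# The members are only reachable by unpacking the union explicitly.
members = get_args(either)
no_members = get_args(single)
```

Here `is_same` is False and `members` holds both classes, while the single type has no members to unpack.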