# ask-the-community
e
Is there an easy way to specify that a task should receive the serialized version of an output from a previous task? Say I have a train task that outputs an `nn.Module` with the PyTorch plugin, but I want a downstream task to receive the file path that the `nn.Module` is saved to. Is there a way to do that?
s
e
Say I want to report the location of the serialized object to a third-party tracking system. How would you suggest I do that? I could subclass the serializer and re-register it with the TypeEngine, but that seems like overkill.
s
You can manually fetch the path; it should be available in the Flyte UI (in the inputs/outputs pane). But if you'd like to automate this process, I believe you'll need to re-register the transformer.
e
So you'd suggest I have a separate task that manually serializes it as a FlyteFile and then fetches the remote path of that FlyteFile.
(AFAICT this duplicates the artifact in storage and is a fair bit slower for a large artifact, which is a bit of a bummer.)
Would be cool to have something like:
```python
output = task_a(...)
report = reporter(literal=flytekit.as_literal(output))
```
s
That isn't possible because the workflow is essentially a DSL, and you can't perform arbitrary operations on task outputs inside it. The best way forward IMO is to re-register the transformer. Here's the existing type transformer code for your reference: https://github.com/flyteorg/flytekit/blob/152080fa27c6d07edda9fcc2a602ca3ca9f4739a/flytekit/extras/pytorch/native.py#L24-L64.
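To illustrate what "the workflow is essentially a DSL" means, here is a toy model in plain Python. It does not use flytekit, and the `Promise` and `WorkflowGraph` names are hypothetical: calling a task inside the workflow body only records a node and hands back a placeholder, so there is no serialized value or URI yet for a helper like the proposed `as_literal` to inspect.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass(frozen=True)
class Promise:
    """Placeholder for a task output that only exists once the workflow runs."""
    node_id: str
    output_name: str = "o0"


@dataclass
class WorkflowGraph:
    """Records task calls as graph nodes instead of executing them."""
    nodes: List[Dict[str, Any]] = field(default_factory=list)

    def call(self, task_name: str, **inputs: Any) -> Promise:
        node_id = f"n{len(self.nodes)}"
        self.nodes.append({"id": node_id, "task": task_name, "inputs": inputs})
        return Promise(node_id)


graph = WorkflowGraph()
output = graph.call("task_a")
report = graph.call("reporter", model=output)

# `output` is only a reference to node n0's output; the reporter node's
# input is wired to that reference rather than to a concrete value.
print(output)
print(graph.nodes[1]["inputs"])
```

In this toy, the only thing a downstream node can receive is the reference itself, which is why any "give me the serialized form instead" behavior has to live in the type system (how the input is declared) rather than in code inside the workflow body.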
e
Yes, I understand it's not possible now. It would essentially be a flag telling the flytekit engine not to deserialize the input in the consuming task.
I think I'd actually prefer to subclass the type transformer and do the reporting during serialization.
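The subclass-and-report idea can be sketched without flytekit as a toy model. `ToyTransformer`, `ReportingTransformer`, the in-memory `store`, and the `registry` dict are all hypothetical stand-ins; in flytekit the real counterpart would be subclassing the `PyTorchModuleTransformer` linked above and overriding its serialization step, whose exact signature is not reproduced here.

```python
from typing import Any, Callable, Dict, List


class ToyTransformer:
    """Hypothetical stand-in for a type transformer: serializes a value to a URI."""

    def __init__(self, store: Dict[str, bytes]):
        self.store = store

    def to_literal(self, value: Any, uri: str) -> str:
        # Pretend serialization: write a byte representation under the URI.
        self.store[uri] = repr(value).encode()
        return uri


class ReportingTransformer(ToyTransformer):
    """Serializes exactly like the parent, then reports where the bytes went."""

    def __init__(self, store: Dict[str, bytes], report: Callable[[str], None]):
        super().__init__(store)
        self.report = report

    def to_literal(self, value: Any, uri: str) -> str:
        location = super().to_literal(value, uri)
        self.report(location)  # e.g. notify an external tracking system
        return location


store: Dict[str, bytes] = {}
reported: List[str] = []
registry: Dict[type, ToyTransformer] = {dict: ToyTransformer(store)}

# "Re-register": swap in the reporting subclass for the same Python type.
registry[dict] = ReportingTransformer(store, reported.append)
uri = registry[dict].to_literal({"weights": [1, 2]}, "s3://bucket/model.pt")
```

Because the report happens inside serialization, the artifact is written once and never copied; the trade-off is that every value of that type now goes through the reporting path.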
k
@Samhita Alla you don't have to manually fetch, @Eli Bixby you are right, you can receive it as a FlyteFile[format]
As long as the format matches, the compiler will accept it
Why don’t you give it a try and let us know
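Ketan's point is that the compile-time check compares blob formats rather than Python classes. A toy version of that rule is below; it is plain Python with hypothetical `BlobType` and `can_bind` names (the real check happens in Flyte's compiler), and `"PyTorchModule"` is the format Ketan names later in the thread.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BlobType:
    """Toy literal type for a single serialized file tagged with a format."""
    format: str


def can_bind(upstream: BlobType, downstream: BlobType) -> bool:
    """Accept the connection only when both sides agree on the format."""
    return upstream.format == downstream.format


# Output of a task returning nn.Module, as seen by the (toy) compiler.
module_output = BlobType("PyTorchModule")

print(can_bind(module_output, BlobType("PyTorchModule")))
print(can_bind(module_output, BlobType("onnx")))
```

Under this rule, a downstream input declared as a file with the same format tag is type-compatible with an `nn.Module` output, even though the Python types differ.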
e
Can you clarify? Do I annotate as:
```python
def reporter(location: FlyteFile[nn.Module]): ...
```
Not seeing anything about this in the docs.
k
Format=PyTorchModule
Ya no docs on this today
e
Awesome thanks!
k
You are in the far trenches
But please contribute
We should expose the format more easily
e
Ya no worries. I'm trying to integrate with a legacy system so that has me sprinting to the edges of the surface 😉
s
> you are right, you can receive as a FlyteFile[format]
Oh yeah. This has to work!
e
Heyo, I'm trying this with a Hugging Face module and am getting:
TypeError: No automatic conversion found from type <class 'transformers.models.esm.modeling_esm.EsmForMaskedLM'> to FlyteFile.Supported (os.PathLike, str, Flytefile)
My type annotation looks like `FlyteFile[PyTorchModuleTransformer.PYTORCH_FORMAT]` (I'm on flytekit 1.4.2).
s
Looks like you're trying to use two different types. Would you mind sharing your Flyte tasks?
e
Sure. The task in question is:
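```python
from typing import Optional

import wandb
from flytekit.extras.pytorch.native import PyTorchModuleTransformer
from flytekit.types.file import FlyteFile

# wandb_task, commons and TokenizerSpec are project-specific helpers (not shown)


@wandb_task
def register_model(model: FlyteFile[PyTorchModuleTransformer.PYTORCH_FORMAT],
                   model_name: str,
                   tokenizer: Optional[TokenizerSpec]) -> str:
    api = wandb.Api()
    artifact_id = f'{commons.FLYTE_CONTEXT.project}/{commons.FLYTE_CONTEXT.workflow_name}/{model_name}:{commons.FLYTE_CONTEXT.version}'
    artifact = api.artifact(artifact_id, type='model')
    # For a task input, FlyteFile keeps the original blob URI in `remote_source`
    # (`remote_path` is only set when you choose an upload destination).
    artifact.add_reference(uri=model.remote_source)
    if tokenizer:
        artifact.metadata['tokenizer'] = tokenizer.to_dict()

    artifact.save()

    return artifact_id
```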
I want to be able to pass it either an `nn.Module` or a `FlyteFile` that represents an `nn.Module`, and I thought @Ketan (kumare3) was suggesting that I could get automatic conversion (i.e. skip deserialization) by using an annotated FlyteFile, e.g.:
```python
import torch.nn as nn
from flytekit import workflow

# train_task and register_model are tasks defined elsewhere

@workflow
def my_workflow(model: nn.Module) -> str:
    model = train_task(model=model)
    # here the pipeline still sees "model" as an `nn.Module`
    return register_model(model=model)  # This task can take an `nn.Module` but will always receive it as a `FlyteFile`
```
And `register_model` would get the raw `FlyteFile`. As you can see, the goal of the "register" task is simply to update a remote server with a reference to the remote location of the artifact. It would be wasteful to download the model and then upload it to another location just to convert the type from `nn.Module` to `FlyteFile` when a remote location already exists. I couldn't find anything in the `FlyteFile` logic indicating that this conversion happens, so perhaps I misunderstood. I'm wondering what the right solution to this pattern is, ideally one that doesn't involve re-downloading the `nn.Module` just to upload it to a different location to get a URI. My other solution was to essentially rewrite the `PyTorchModuleTransformer` to do this registration for me once it gets the remote URI, but this feels a little wrong.
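The property Eli is after is that a file-typed input can expose its remote URI without being downloaded. flytekit's `FlyteFile` is documented as downloading lazily; the toy class below (hypothetical `LazyFile`, plain Python, no flytekit) only mimics that idea to show why registering a reference never needs to touch the bytes.

```python
from typing import Callable, List, Optional


class LazyFile:
    """Toy file input: remembers its remote URI and downloads only on demand."""

    def __init__(self, remote_source: str, downloader: Callable[[], str]):
        self.remote_source = remote_source
        self._downloader = downloader
        self._local_path: Optional[str] = None

    def download(self) -> str:
        # Fetch at most once; later calls reuse the cached local path.
        if self._local_path is None:
            self._local_path = self._downloader()
        return self._local_path


downloads: List[str] = []


def fake_downloader() -> str:
    downloads.append("s3://bucket/model.pt")
    return "/tmp/model.pt"


model = LazyFile("s3://bucket/model.pt", fake_downloader)
reference = model.remote_source  # all a "register" task needs
```

Reading `remote_source` leaves `downloads` empty; only an explicit `download()` call triggers the (single) fetch.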
Never mind, I got this to work. Not sure what I was doing wrong; potentially it was because I was trying to use a Union type that contained the FlyteFile type?
k
Yes, a Union type won't work, since a Union is not the same as a univariate type.