I have the following issue with FlyteFile + single...
# flyte-support
f
I have the following issue with FlyteFile + the single-node/multi-GPU PyTorch plugin:
1. When the task is decorated only with @task(), the following code works just fine:
from pathlib import Path

from flytekit import task, workflow
from flytekit.types.file import FlyteFile


@task()
def my_task(dataset: FlyteFile):
    path = dataset.download()
    # works as intended
    assert Path(path).is_file()

@workflow
def my_workflow():
    file = FlyteFile(
        path="s3://path/to/my/file.csv"
    )

    outputs = my_task(dataset=file)
2. When the task is decorated with @task(task_config=Elastic(nnodes=1, nproc_per_node=4)), it breaks:
from pathlib import Path

from flytekit import task, workflow
from flytekit.types.file import FlyteFile
from flytekitplugins.kfpytorch import Elastic


@task(task_config=Elastic(nnodes=1, nproc_per_node=4))
def my_task(dataset: FlyteFile):
    path = dataset.download()
    # .download() returns immediately and the file is not there
    assert Path(path).is_file()  # this will raise

@workflow
def my_workflow():
    file = FlyteFile(
        path="s3://path/to/my/file.csv"
    )

    outputs = my_task(dataset=file)
There's also a warning raised:
.venv/lib/python3.12/site-packages/flytekit/types/file/file.py:356: RuntimeWarning: coroutine 'FileAccessProvider.async_get_data' was never awaited
It seems like FlyteFile does not play well with the underlying multiprocessing spawn. This happens on Flyte 1.15.3.
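For reference, one possible interim workaround is to skip FlyteFile.download() inside the Elastic task and fetch the object directly. This is only a sketch: it assumes fsspec (with s3fs) is installed in the image and that the remote URI is still exposed on dataset.remote_source, which may not hold in every setup.

import os
from pathlib import Path

import fsspec  # assumed to be installed together with s3fs
from flytekit import task
from flytekit.types.file import FlyteFile
from flytekitplugins.kfpytorch import Elastic


@task(task_config=Elastic(nnodes=1, nproc_per_node=4))
def my_task(dataset: FlyteFile):
    # Assumption: the remote URI survives on `remote_source` after deserialization.
    uri = dataset.remote_source
    local_path = os.path.join(os.getcwd(), os.path.basename(uri))
    fs, _, paths = fsspec.get_fs_token_paths(uri)  # resolves s3:// via s3fs
    fs.get(paths[0], local_path)  # blocking download, no asyncio involved
    assert Path(local_path).is_file()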
f
Good catch. It should not stay open in the spawn
In v2 we have changed the architecture of FlyteFile; this should not happen there, but let’s look into this for v1 too
f
"It should not stay open in the spawn"
WDYM?
Hello?
f
For multi-GPU training, torchrun spawns many processes
f
I was asking about "It should not stay open in the spawn"
Can you provide a solution/workaround for that @freezing-airport-6809?
f
But this is just a warning
Should not impact anything
It’s a coroutine handle that was shared
Not run
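To illustrate what that warning means in general (hypothetical names, not flytekit code): a coroutine object gets created but nothing ever runs it, so Python only warns when the object is garbage-collected and the actual work simply never happens.

async def async_get_data():
    # stand-in for the real async download
    return "downloaded"


def child_entrypoint():
    coro = async_get_data()  # coroutine object is created here...
    # ...but nothing awaits or schedules it (e.g. the event loop it was meant
    # to run on does not exist in this spawned process), so at garbage
    # collection Python emits:
    #   RuntimeWarning: coroutine 'async_get_data' was never awaited
    del coro


child_entrypoint()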
f
No, the warning is "additional" - the main reason I wrote this is that it actually breaks
f
Ok I need to reproduce
🤞 1
I won't get to it for a bit, but will do. cc @echoing-account-76888 if you happen to look into this?
🙏 1
🫡 1
e
Sure! I'll look into this
🙏 1
Just a quick update, I can reproduce this on my side. The error occurs on both local and remote. I’m working on a fix. Thank you for your patience! 🙏
🤞 2
Fixed in https://github.com/flyteorg/flytekit/pull/3313. I think it's an edge case using loop_manager.synced + spawn. I left more context in the PR description, feel free to have a look and leave comments!
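For anyone reading along, here is a rough illustration of the pattern involved (an illustrative sketch only, not flytekit's actual loop_manager implementation): a "synced" wrapper runs coroutines on a background event loop so callers get a blocking API, and a child created with the "spawn" start method does not inherit that running loop/thread, so the wrapped call can end up creating a coroutine that nothing ever drives.

import asyncio
import threading


class LoopManagerSketch:
    """Illustrative only: run coroutines on a background event loop so that
    synchronous callers can block on the result."""

    def __init__(self):
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)
        self._thread.start()

    def synced(self, async_fn):
        def wrapper(*args, **kwargs):
            # Submit the coroutine to the background loop and block on the result.
            fut = asyncio.run_coroutine_threadsafe(async_fn(*args, **kwargs), self._loop)
            return fut.result()

        return wrapper


loop_manager = LoopManagerSketch()


async def async_get_data(uri: str) -> str:
    await asyncio.sleep(0)  # stand-in for the real async download
    return f"local copy of {uri}"


get_data = loop_manager.synced(async_get_data)
print(get_data("s3://bucket/key"))  # works here; a spawned child has no such background loop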
🔥 1
f
Cc @thankful-minister-83577 fyi for v2? Shouldn’t need it
f
Awesome @echoing-account-76888, thanks for fixing it! 🎖️ @freezing-airport-6809 will v2 still support PyTorch Elastic (via decorator, same as in v1)?
f
Absolutely
We are working on it
You can contribute too
We already have Ray, Spark, and Dask
Do you have any improvement suggestions?
f
I was quietly hoping for a decorator for PyTorch that wouldn't require installing 3rd-party software in K8s besides Flyte itself 🤞
f
WDYM?
For distributed training?
f
Yes - right now, if you want to do distributed training on > 1 node, you have to install either the KF training operator or Ray on the cluster, which adds complexity. PyTorch is THE FRAMEWORK for neural networks right now, so having first-party support for it in Flyte itself would be 🎖️
f
@flat-waiter-82487 we have started working on native support
I don’t think it’s a lot of work, but we are working through how the experience should be
You are right; we concluded that in Flyte 2 we will ship OOB with JobSet-based multi-node training
💜 1
It will be available in Flyte 1 too
💜 1
f
Aw yeah! 🦜
f
will take some time
@flat-waiter-82487 what would you like to improve in it?
f
Functionally I would like the following things:
1. Ability to decorate a task to make it distributed (so at least 2 params: nnodes, nproc_per_node - same as now?)
2. Support for TorchElastic and fault tolerance during training with auto-recovery
3. No requirement to install 3rd-party extensions on the cluster besides Flyte and its CRDs (CRDs coming from Flyte itself would be fine)
4. Ability to specify nnodes/nproc_per_node dynamically (e.g. via a task param or something) - right now, as far as I know, it has to be "baked in" into the task definition = somehow hardcoded. This is a pain point during development, because in development we usually use small machines with 2-4 GPUs and then go full throttle on larger machines with 8 GPUs, and we have to re-configure the code manually instead of just swapping --nproc_per_node=8 or something (see the sketch after this list for a registration-time workaround).
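On point 4, a possible registration-time workaround (just a sketch, using hypothetical environment variable names and assuming the flytekitplugins-kfpytorch Elastic plugin) is to read the topology from the environment when the workflow is registered. The values still cannot change at run time, but at least they are not hardcoded in the source:

import os

from flytekit import task
from flytekit.types.file import FlyteFile
from flytekitplugins.kfpytorch import Elastic


# Hypothetical env var names; set them before registering/running the workflow.
@task(task_config=Elastic(
    nnodes=int(os.environ.get("TRAIN_NNODES", "1")),
    nproc_per_node=int(os.environ.get("TRAIN_NPROC_PER_NODE", "4")),
))
def train(dataset: FlyteFile):
    path = dataset.download()
    ...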