# flyte-support
f
When using the training operator and PyTorch elastic with Flyte, how does Flyte decide which pod's error to propagate at the task level? Say I have a 2-node PyTorch elastic training job, and both pods have a failure reported by PyTorch elastic. Which one propagates at the Flyte task level?
f
pod with rank-0
f
That doesn’t match what I’m observing.
Or let me put it this way: I have a counter observation to that.
Also, shouldn't it be the one that failed first? (That's the PyTorch elastic agent's approach for error propagation within a node, and it makes sense to apply the same approach across nodes.)
but we could change that, even add a custom way of handling this
f
I took rank 0 to mean PyTorch rank 0. I think you mean the task with index 0. With the c10d rendezvous, rank 0 and task 0 are not the same.
f
🤷‍♂️
f
@freezing-airport-6809 The code you pointed to above covers the happy case. I was asking about what happens when multiple nodes throw an exception from elastic_launch.
f
but there should only ever be one rank-0 or 0 right?
maybe i am reading it wrong
cc @cool-lifeguard-49380 may know more
f
Here is the code that handles error cases:
```python
try:
    out = elastic_launch(
        config=config,
        entrypoint=launcher_target_func,
    )(*launcher_args)
except ChildFailedError as e:
    _, first_failure = e.get_first_failure()
    if is_recoverable_worker_error(first_failure):
        raise FlyteRecoverableException(e.format_msg())
    else:
        raise RuntimeError(e.format_msg())
except SignalException as e:
    logger.exception(f"Elastic launch agent process terminating: {e}")
    raise IgnoreOutputs()
```
It assumes that elastic_launch either throws ChildFailedError or SignalException. Here is what happens:
• We have a 2-node training (ranks 0-7 on one node, ranks 8-15 on the other).
• So we have 2 elastic_launches across 2 pods.
• One of the ranks on one node throws a ChildFailedError; a rank on the other node throws a RendezvousClosedError.
• Flyte sometimes reports one and sometimes the other as the failure reason of the entire task.
@cool-lifeguard-49380 Would appreciate any insight into how Flyte determines the final error for the elastic task. The code I see in the PyTorch Python plugin cannot handle this properly IMO, especially if we want to report the first error across elastic_launches from different pods.
f
i do not think it reports the first error, you're right Bugra
it returns the error collected on rank-0
returning the first error is indeed hard, as you will need some sort of consensus
f
It doesn't return the error collected on rank 0, based on my reading of the code, but I might be mistaken. SignalExceptions are ignored, and if there is only one ChildFailedError across nodes, then all is well. If there are multiple exceptions across nodes, then the handling is somewhere else, on the Go side I assume. We have training runs with 128 nodes or more; it is not uncommon for the elastic agent to throw various exceptions on multiple nodes. Reporting the first error based on timestamp is almost always reliable. It would be easy to support, but I don't know where to look yet, as the PyTorch plugin doesn't seem to deal with multi-node error propagation.
c
@fierce-oil-47448 you are correct that we don't properly handle this, we discussed this back in this PR 😿 Letting the worker pods (as opposed to the backend) figure out the exception with the earliest timestamp would require opening a new distributed process group and communicating over torch dist. I feel this would be very brittle, especially since some workers might already be down and unable to join the rendezvous. Instead, I think we need to modify flytekit's pod entrypoint to allow choosing different error pb file names, in this case based on the group rank. (Maybe the entrypoint "can ask" the task whether it wants to customize the error file name?) Under the respective bucket uri there would then be multiple error files `error-0.pb`, `error-1.pb`, ... and the backend plugin would need to go through them and pick the one with the earliest timestamp. It would be nice if there was a general solution for plugins to be able to have multiple error files, as torch is not the only plugin with this conceptual problem. Would you be willing to collaborate on this with me @fierce-oil-47448? @freezing-airport-6809 maybe we can discuss in the next contrib sync how exactly this should be done.
f
Hi @cool-lifeguard-49380. Yes, I would be happy to collaborate, and I agree with the technical direction. It would be very fragile to try to do this at the task pod level. I would also love to understand how this works currently: are different task pods overwriting the same error file today? In my minimal repro, we have 2 pods, each getting an exception from its elastic_launch. How does Flyte decide to report one or the other?
c
I never debugged this to the end but I would assume they overwrite each other and there is a race condition 😭
Do you have time to do an experiment? (Please let me know if you prefer to not be given tasks!! 🙂) If the PythonFunctionTask (or some lower base class) had a function called `get_error_file_name`, which by default returned None but which the elastic task would implement to return a file name based on the group rank env var, would we be able to call that function and access the error file name in the pod entrypoint script? That would already solve the uploading + not-overwriting part. We'd still need logic in the backend to read multiple files of course …
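To make the idea concrete, here is a minimal sketch of what such a hook could look like; the stand-in classes, the `GROUP_RANK` env var, and the fallback logic are illustrative assumptions, not existing flytekit API:
```python
import os
import typing


class PythonFunctionTask:  # stand-in for the real flytekit base class
    def get_error_file_name(self) -> typing.Optional[str]:
        # Default: no customization, the entrypoint keeps writing error.pb.
        return None


class PytorchElasticTask(PythonFunctionTask):  # stand-in for the elastic task
    def get_error_file_name(self) -> typing.Optional[str]:
        # Derive a per-worker error file name from a rank/replica env var.
        # (Which env var is actually available in the pod is discussed below.)
        rank = os.getenv("GROUP_RANK")
        return f"error-{rank}.pb" if rank is not None else None


# In the entrypoint (sketch): fall back to the default name when the task
# does not customize it.
# error_file = task.get_error_file_name() or "error.pb"
```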
👍 1
f
Sure, will look into this.
🙏 2
🙏🏽 1
@cool-lifeguard-49380 See https://github.com/flyteorg/flytekit/pull/2607/files for an early attempt at making the changes you suggested. A few notes:
• GROUP_RANK is not available at the level of the entrypoint. It is only available in the PyTorch child processes launched by elastic_launch.
• The concept of a replica index exists at the Training Operator level, which we can use.
• We may need an update to the Training Operator to pass in the replica index as an environment variable. Or perhaps we can do it with some kind of override in Flyte itself. What's needed is to customize the pod spec to set an environment variable that captures the value of the `training.kubeflow.org/replica-index` label:
```yaml
env:
  - name: FLYTE_PYTORCH_TASK_REPLICA_INDEX
    valueFrom:
      fieldRef:
        fieldPath: metadata.labels['training.kubeflow.org/replica-index']
```
Any suggestions on how to do that?
Btw, what are the backend changes to deal with multiple error files?
c
Thank you for looking into this! I replied in the PR.
About backend changes: first of all, a disclaimer. I don't have a complete picture yet, and I think once we have an idea/working prototype we'll need to run this by other maintainers, maybe in the form of an RFC, but I can help with or entirely take over that process, as you prefer 🙂
I did look into potential ways to handle this in the backend:
First of all, I think we might need to add an optional timestamp field to the error proto message; otherwise the backend can't sort by timestamp. I'm unsure how we can get the timestamp from the `ChildFailedError` in the elastic task into this error proto message, but maybe we can use a proxy timestamp measured in the entrypoint? Or we introduce a new exception type in flytekit which has a timestamp attribute, which we can set and raise in the elastic task. The entrypoint would then check whether this type of exception was raised and fill the value in the error proto. (Reading this again, this is probably better than measuring a second timestamp in the entrypoint; I don't want to create another race condition.)
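For illustration, a minimal sketch of the "new exception type with a timestamp attribute" idea; the class name is hypothetical, and the assumption that torch's failure record exposes a usable timestamp is mine, not something decided here:
```python
class FlyteTimestampedException(RuntimeError):
    """Carries the moment the worker actually failed (unix epoch seconds)."""

    def __init__(self, message: str, timestamp: float):
        super().__init__(message)
        # Taken from torch elastic's failure record rather than measured again
        # in the entrypoint, to avoid introducing another race condition.
        self.timestamp = timestamp


# In the elastic task's `except ChildFailedError` block (compare the snippet
# above), assuming the failure record exposes a timestamp:
#     _, first_failure = e.get_first_failure()
#     raise FlyteTimestampedException(e.format_msg(), timestamp=first_failure.timestamp)
```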
Currently, here in the plugin manager, upon completion of a node execution, a new output reader is constructed. The `tCtx.OutputWriter()` which is passed to `outPaths` of `NewRemoteFileOutputReader` contains a reference to the bucket uri which contains the `error.pb`:
`<protocol>://<bucket name>/metadata/propeller/<project>-<domain>-<execution id>/n0/data/0`
Notably, `outPaths` does not directly point to the `error.pb` but to the "directory" containing it. This is the same bucket uri that is passed to `--output-prefix` of `pyflyte-execute`. `RemoteFileOutputReader` implements the `OutputReader` interface.
I think we'll need another implementation of that interface, e.g. called `MultiErrorFileRemoteFileOutputReader` (maybe "earliest timestamp" should make it into the name). This would not try to read `<protocol>://<bucket name>/metadata/propeller/<project>-<domain>-<execution id>/n0/data/0/error.pb` but would instead search the `<protocol>://<bucket name>/metadata/propeller/<project>-<domain>-<execution id>/n0/data/0` output prefix for files called e.g. `error-*.pb`. It would then figure out which one has the earliest timestamp.
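The actual reader would be a Go `OutputReader` implementation in flytepropeller; the sketch below only illustrates the selection rule it would apply, written in Python for brevity and assuming each error file carries the proposed (not yet existing) timestamp field:
```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ErrorFile:
    path: str          # e.g. ".../n0/data/0/error-3.pb"
    timestamp: float   # the proposed timestamp field from the error proto
    message: str


def pick_earliest_error(error_files: List[ErrorFile]) -> Optional[ErrorFile]:
    """Given all error-*.pb files found under the output prefix, return the
    one that occurred first, i.e. the likely root cause."""
    if not error_files:
        return None
    return min(error_files, key=lambda e: e.timestamp)
```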
Here in the plugin manager, we need a way to figure out which output reader the plugin requires. I think we can do something along the following lines: we could add a field `MultipleErrorFiles` to `PluginProperties`, see here. (Again, "earliest timestamp" should probably make it into that name ^^) The PyTorch plugin, for instance, would then pass `true` for `MultipleErrorFiles` here. Or add an OutputReader here?
Currently, here in the plugin manager, where we call `NewRemoteFileOutputReader`, we do have access to `e.plugin`, and thus to `PluginProperties`, and could make use of that information to instantiate another output reader.
I'm fully aware that all this might be super confusing @fierce-oil-47448, as I said, I also don't have a 100% clear picture yet, sorry 🙇
Super happy to discuss this in a call if you want. Also, do you have a setup that allows you to use a debugger/breakpoints in flytepropeller? Otherwise I think it will be difficult to dig into this. Happy to help you get a working setup if not.
@freezing-airport-6809 is there anyone from the Union side who has an overview of this part of the code base and would have some time to sanity check the approach I proposed ☝️? Would be good to know I'm not overlooking a better-fitting place to inject this logic. I can then drive this forward with @fierce-oil-47448.
f
@cool-lifeguard-49380 I think the timestamp is needed in `ContainerError` (here) and not in `Error`, as the former is what the entrypoint uses to capture the error. Everything you described makes a lot of sense, and these changes look reasonably straightforward to me. I can help with the implementation, but I'd be slow (due to other priorities at work). Also, I'll need help with a setup where I can try backend changes (never done that).
c
> Also, I'll need help with a setup where I can try backend changes (never done that).
Happy to help, in which time zone are you?
I can also take over some part of the implementation but I'd also be somewhat slow atm 🙈
🤷
f
I'm PST
c
I’ll dm you
f
The thing is many folks are out this week; they will be back next week
I can definitely get someone to overview
Cc @high-accountant-32689 ?
h
Sorry I missed this. Let me catch up.
c
> The thing is many folks are out this week; they will be back next week
No time pressure
h
Let me summarize the discussion up to this point. The proposed solution handles error propagation in Flyte's multi-node PyTorch elastic training by using multiple error files and selecting the earliest error based on timestamps. The breakdown:
1. Add Timestamp to Error Messages: Introduce an optional timestamp field to the error proto message to enable sorting errors by their occurrence time.
   a. The `error.pb` protobuf message is an ErrorDocument.
   b. Where do we add the timestamp? In `ContainerError` or `ErrorDocument`?
   c. Do we have to worry about time drift? (different containers having different clocks)
2. Custom Error File Naming: Modify the Flytekit pod entrypoint to generate multiple error files (e.g., error-0.pb, error-1.pb) based on the group rank. This avoids overwriting errors from different pods.
   a. In https://github.com/flyteorg/flytekit/pull/2607/ we're generating random suffixes instead.
3. Proposed Backend Changes:
   a. Implement a new output reader (e.g., MultiErrorFileRemoteFileOutputReader) in Flyte's plugin manager to read these multiple error files. This reader will search for files matching a pattern (e.g., error-*.pb) and select the one with the earliest timestamp.
      i. Flyte turns `ErrorDocument` into an `ExecutionError` here, and this is what ends up being reported downstream. We can leverage that in the implementation of `ReadError` in this new output reader. Not sure if we have to turn this into a policy (i.e. does it always make sense to pick the first error sorted by timestamp or should that be configurable?)
   b. Task Pod Modifications: Ensure the task pods write unique error files based on the group rank or replica index.
      i. If we want to encode the rank in the error message, the PyTorch operator exposes the rank as an environment variable (see https://github.com/kubeflow/training-operator/blob/master/pkg/controller.v1/pytorch/envvar.go#L89-L96).
         * This `PET_*` business seems to be a PyTorch quirk introduced some time ago (I couldn't find what `PET` stands for).
   c. Additional updates to the Training Operator might be needed to ensure the replica index is available in the environment variables.
c
> The proposed solution handles error propagation in Flyte's multi-node PyTorch elastic training by using multiple error files and selecting the earliest error based on timestamps.
Yes, but if I'm not mistaken this is not only for torch elastic tasks but also for other distributed training plugins where each worker pod might write the error.pb. It's certainly also the case for non-elastic pytorch trainings and potentially for tensorflow, mpi, ...? (I don't know whether the latter come with a mechanism that propagates any errors to a dedicated pod which is the only one to write the error.pb.) I think it would be good if we build a general solution for all plugins that rely on multiple pods, nothing too specific to elastic.
> c. Do we have to worry about time drift? (different containers having different clocks)
I think we can deal with this in a best-effort manner. If time drift happens and causes us to report the wrong error, one can still combine the logs of all pods in Stackdriver, filter by severity, and search for the first error(s). But it will save a lot of time if we can get it right often enough ^^ (Or is there a low-hanging fruit for how we can sync the clocks?)
> generate multiple error files (e.g., error-0.pb, error-1.pb) based on the group rank.
In the case of elastic, we unfortunately don't have easy access to the group rank in the entrypoint. It is not set as an env var; I ran a test yesterday. What you linked here is only set for non-elastic pytorch tasks. I think it doesn't matter though: if Flyte injected the pod name as an env var (or we use the downward API), we could just use that. In the end, the group rank was just a proposal to allow easy identification in the Flyte console of which pod died first. We can use the pod name for that, which would work for all distributed plugins regardless of what concepts they use for ranks etc.
👍 2
> Not sure if we have to turn this into a policy (i.e. does it always make sense to pick the first error sorted by timestamp or should that be configurable?)
I also thought about making this a config or policy. I think we should do that so that when another distributed plugin needs a slightly different approach in the future, we don't have to write a new output reader but just a new policy.
Since this can be relevant for various plugins which rely on multiple pods being created, let's persist the discussion and outcomes in an RFC 🙂
🎉 3
🔥 2
f
Thanks @cool-lifeguard-49380 for putting up the RFC ❤️. I've added some comments (backward compatibility is an area that may need some work imo).
c
Thank you for the comments 🙏
f
@cool-lifeguard-49380 Here is what I'm thinking in terms of the implementation plan (a sketch of the entrypoint-side naming follows below):
• Set environment variables for FLYTE_INTERNAL_POD_NAME and FLYTE_INTERNAL_ERROR_PROPAGATION
• Add entrypoint support for multi error file upload based on the above
• Add exception timestamps and timestamp handling in the PyTorch plugin
• Add MultiErrorFileRemoteFileOutputReader
• Add configuration support for enabling MultiErrorFileRemoteFileOutputReader for the PyTorch plugin (and other relevant ones)
Let me know what you think.
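A rough sketch of the entrypoint-side naming under this plan; the env var names are the proposed ones above, their exact values/semantics are assumptions, and nothing here is existing flytekit behaviour:
```python
import os


def error_output_name(default: str = "error.pb") -> str:
    """Pick the error file name the entrypoint should upload to."""
    # Only deviate from today's behaviour when multi-error propagation is
    # explicitly enabled for the task (exact flag semantics still open).
    if not os.getenv("FLYTE_INTERNAL_ERROR_PROPAGATION"):
        return default
    pod_name = os.getenv("FLYTE_INTERNAL_POD_NAME")
    return f"error-{pod_name}.pb" if pod_name else default
```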
c
Fully agree 👍
> Add entrypoint support for multi error file upload based on the above
This requires a flyteidl change to add a timestamp to the error document.
f
Can I get an approval on https://github.com/flyteorg/flyte/pull/5616? I think there was a small test file change on the Spark support that still needs someone to approve. cc @high-accountant-32689
h
@fierce-oil-47448, left one minor comment.
f
Thanks @high-accountant-32689, addressed
👍 1
h
@fierce-oil-47448, I think you committed a few changes by mistake in your last update.
f
Really sorry for that @high-accountant-32689. Let me fix. Didn't notice I had a dirty branch.
Fixed
h
go doesn't like python comments 🙂 Fixed in https://github.com/flyteorg/flyte/pull/5616/commits/63a1faaba0b9b0dc642938bdad484d1fe412a630. Should merge as soon as tests pass.
f
Haha, thx. @high-accountant-32689 Would appreciate a look at https://github.com/flyteorg/flyte/pull/5674/files when you get a chance as well. This adds listing support to the storage interfaces; I added it for stow storage. One specific piece of feedback I'm looking for: does this need to be added for all storage kinds?
h
cool, I will take a look later today.
f
@high-accountant-32689 One more approval here https://github.com/flyteorg/flyte/pull/5616/files#diff-4798c57d1e223aea7c8724d4401e259fe395da096b0eed7766493e8ae342a336 Had to touch another test file. All tests look healthy now.
Interesting, I see a flytestdlib failure here: https://github.com/flyteorg/flyte/actions/runs/10478829680/job/29023346175?pr=5616 but completely unrelated to my changes. Will sync with upstream and try again.