# contribute
b
hey @cool-lifeguard-49380 I wanted to get your feedback on this idea. Context: I want to improve the UX of using the `Elastic` plugin. Currently, to use it with PyTorch distributed data parallel, you need to manually specify a custom pod template like so:
Namely, in train.py:
from flytekit import PodTemplate
from kubernetes.client import (
    V1Container,
    V1EmptyDirVolumeSource,
    V1PodSpec,
    V1Volume,
    V1VolumeMount,
)

# Mount an emptyDir at /dev/shm so the torch workers get enough shared memory.
custom_pod_template = PodTemplate(
    primary_container_name="flytesnacks-pytorch-lightning",
    pod_spec=V1PodSpec(
        containers=[
            V1Container(
                name="flytesnacks-pytorch-lightning",
                volume_mounts=[V1VolumeMount(mount_path="/dev/shm", name="dshm")],
            )
        ],
        volumes=[
            V1Volume(
                name="dshm",
                empty_dir=V1EmptyDirVolumeSource(medium="", size_limit="200Gi"),
            )
        ],
    ),
)
And then the task def is:
from flytekit import Resources, task
from flytekit.extras.accelerators import T4
from flytekit.types.directory import FlyteDirectory
from flytekitplugins.kfpytorch import Elastic

# custom_image, NUM_NODES and NUM_DEVICES are defined elsewhere in train.py.
@task(
    container_image=custom_image,
    task_config=Elastic(
        nnodes=NUM_NODES,
        nproc_per_node=NUM_DEVICES,
        rdzv_configs={"timeout": 36000, "join_timeout": 36000},
        max_restarts=3,
    ),
    accelerator=T4,
    requests=Resources(mem="32Gi", cpu="48", gpu="8", ephemeral_storage="100Gi"),
    pod_template=custom_pod_template,
)
def train_model() -> FlyteDirectory:
    ...
I’d like to get rid of the boilerplate and automatically add this to the `Elastic` task config, perhaps exposing some options like `dshm_mount_path`, `dshm_size_limit`, etc.
This would require adding `kubernetes` as a dependency to the `flytekitplugins-kfpytorch` plugin. Also need to do this in a way that nicely handles `primary_container_name` in the `PodTemplate` definition, and specifying `pod_template` in the `@task` definition would override the one provided by the `Elastic` task config.
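For illustration only, a sketch of what the boilerplate-free task definition could look like if `Elastic` exposed the proposed `dshm_mount_path` / `dshm_size_limit` options (hypothetical parameters, not an existing API):

# Hypothetical: Elastic would build the /dev/shm pod template itself.
@task(
    container_image=custom_image,
    task_config=Elastic(
        nnodes=NUM_NODES,
        nproc_per_node=NUM_DEVICES,
        rdzv_configs={"timeout": 36000, "join_timeout": 36000},
        max_restarts=3,
        dshm_mount_path="/dev/shm",   # proposed option (does not exist today)
        dshm_size_limit="200Gi",      # proposed option (does not exist today)
    ),
    accelerator=T4,
    requests=Resources(mem="32Gi", cpu="48", gpu="8", ephemeral_storage="100Gi"),
)
def train_model() -> FlyteDirectory:
    ...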
c
💯 Absolutely agree with the goal. Let’s talk about details 🙂
b
OK, will make a GitHub issue.
c
1. Shared memory: Agree that one can’t really do torch multi-worker training without increasing the shared memory. We do so in the default pod template in the flyte namespace as opposed to the task pod template. But since this is more difficult to do for practitioners, I agree that it would be nice if the elastic plugin provided a way. We don’t specify an amount and never ran into any issues with this though:
# on the (default) container:
volumeMounts:
  - mountPath: /dev/shm
    name: dshm
# on the pod spec:
volumes:
  - name: dshm
    emptyDir:
      medium: Memory
So I wonder whether a flag like `increase_shared_memory` is enough. Have you encountered a situation where specifying an amount was required? A question to answer for me would be whether we try to merge this into the pod template a user might have provided, or whether this flag should only work if the user doesn’t provide a pod template (see the merge sketch after this message).
2. primary container name
> Also need to do this in a way that nicely handles `primary_container_name` in the `PodTemplate` definition
In our pod template in the flyte namespace we just name the container `default`, and flyteplugins/pytorch just renames the container to pytorch. What’s the use case for naming the container something different like `flytesnacks-pytorch-lightning`?
3. kubernetes dependency
Flytekit was able to remove this dependency, right? I wouldn’t mind adding it to the elastic plugin though …
4. timeouts
rdzv_configs={"timeout": 36000, "join_timeout": 36000},
I agree with setting default timeouts that we expect should work in a default case for most users. I personally feel ten hours is quite a lot though - GPUs might end up idling for a long time. For the join timeout I feel we should consider the scenario that some workers have a hot start (node is up and image is cached) while other workers have a cold start, i.e. the node needs to be scaled up and the image has to be pulled. How long would this take, maybe 15 minutes tops? I personally feel that if the join timeout needs to be longer than this, e.g. due to resource constraints/nodes not being able to scale up immediately, one should maybe consider gang scheduling so that all of them start at the same time (see “coscheduling” here). Do you happen to know whether the `timeout` is the same timeout as one specifies to `torch.distributed.init_process_group`? If yes, I personally find ten hours also quite a lot because if the training deadlocked, it would be nice to bring it down more quickly than 10h. (For this use case, we set `NCCL_ASYNC_ERROR_HANDLING=1` as an env var in the pod template - see the sketch after this message. We could consider doing this in the plugin as well!)
b
Great feedback! Let’s discuss in the contributor sync: https://github.com/flyteorg/flyte/issues/5339