<#1576 Feat: Add torchrun plugin> Pull request ope...
# flyte-github
a
#1576 Feat: Add torchrun plugin. Pull request opened by fg91 on <!date^1680530978^{date_short}|2023-04-03T14:09:38Z>. TL;DR: Work in progress. This plugin allows running Torch Elastic (torchrun) distributed training with Flyte.
Copy code
from dataclasses import dataclass

import torch
from dataclasses_json import dataclass_json
from flytekit import dynamic, task, workflow
from flytekitplugins.kfpytorch import PyTorch

from .torch_elastic_task import Elastic


@dataclass_json
@dataclass
class Config:
    """Example configuration object passed through the workflow into ``train``.

    ``train`` prints the whole object and overwrites ``name`` before
    returning it; ``lr`` and ``bs`` are not read anywhere in this example.
    """

    # presumably the learning rate -- not consumed by the visible code
    lr: float = 1e-5
    # presumably the batch size -- not consumed by the visible code
    bs: int = 64
    # Free-form label; ``train`` sets it to "modified".
    name: str = "foo"


@task
def init_model() -> torch.nn.Module:
    """Create the small linear layer that is handed to ``train`` as its input model."""
    return torch.nn.Linear(11, 22)


"""
This doesn't start a Kubeflow PyTorch job yet, but a single Python task Pod which then
runs a local worker group in sub-processes.
The changes in the flyteidl protobuf definitions, the flytekit python api, and the
flytepropeller (operator) which we need to actually make this distributed on multiple nodes
are easy (see RFC document linked in PR description).
"""
@task(
    task_config=Elastic(
        min_replicas=1,
        max_replicas=1,
        start_method="spawn",
    )
)
def train(config: Config, model: torch.nn.Module) -> tuple[str, Config, torch.nn.Module]:
    """Toy "training" step executed by each elastic worker.

    Reads the worker's ``LOCAL_RANK`` environment variable, builds a fresh
    linear layer whose output width depends on that rank, prints the incoming
    config, and renames the config to ``"modified"``.

    Args:
        config: Run configuration; its ``name`` attribute is overwritten in place.
        model: Input model. It is accepted but not used by this example.

    Returns:
        A tuple of (status message containing the local rank, the modified
        config, the newly created linear layer).

    Raises:
        KeyError: If ``LOCAL_RANK`` is not set in the environment.
        ValueError: If ``LOCAL_RANK`` is not an integer string.
    """
    # Function-local imports are kept from the original example
    # (NOTE(review): presumably so spawned worker processes import them -- confirm).
    import os

    import torch

    rank = os.environ["LOCAL_RANK"]

    # The output width grows with the rank so each worker's result is distinguishable.
    out_features = int(rank) * 2000 + 1
    trained = torch.nn.Linear(1000, out_features)

    print(f"Training with config {config}")
    config.name = "modified"

    message = f"result from local rank {rank}"
    return message, config, trained


@workflow
def wf(config: Config = Config()) -> tuple[str, Config, torch.nn.Module]:
    """Create the initial model and pass it, together with ``config``, to ``train``.

    Args:
        config: Run configuration forwarded to ``train``; defaults to ``Config()``.

    Returns:
        Whatever ``train`` returns: a status string, the config, and a model.
    """
    return train(config=config, model=init_model())


if __name__ == "__main__":
    # Run the workflow with a default config and show its outputs.
    outputs = wf(config=Config())
    print(outputs)
Type ☐ Bug Fix ☐ Feature ☐ Plugin Are all requirements met? ☐ Code completed ☐ Smoke tested ☐ Unit tests added ☐ Code documentation added ☐ Any pending items have an associated Issue Complete description How did you fix the bug, make the feature etc. Link to any design docs etc Tracking Issue https://github.com/flyteorg/flyte/issues/ Follow-up issue NA OR https://github.com/flyteorg/flyte/issues/ flyteorg/flytekit All checks have passed 30/30 successful checks