worried-airplane-87065
01/09/2025, 7:03 PMfrom dataclasses import dataclass
import torch
from flytekit import Resources, task, workflow
@dataclass
class RuntimeGPU:
available: bool
device_count: int
# Need to set limits on task.
# <https://github.com/flyteorg/flytekit/pull/2151>
@task(
requests=Resources(cpu="4", mem="32Gi", gpu="2"),
limits=Resources(cpu="4", mem="32Gi", gpu="2"),
)
def cuda_available() -> RuntimeGPU:
return RuntimeGPU(
available=torch.cuda.is_available(), device_count=torch.cuda.device_count()
)
@task
def compute_resources(num_gpus: int) -> Resources:
return Resources(cpu="4", mem="32Gi", gpu=str(num_gpus))
@workflow
def runtime_gpu(num_gpus: int) -> RuntimeGPU:
resources = compute_resources(num_gpus)
gpu_info = cuda_available().with_overrides(
requests=resources,
limits=resources,
)
return gpu_info
freezing-airport-6809
freezing-airport-6809
@dynamic
worried-airplane-87065
01/09/2025, 7:49 PM# change to dynamic
@dynamic
def runtime_gpu(num_gpus: int) -> RuntimeGPU:
resources = compute_resources(num_gpus)
gpu_info = cuda_available().with_overrides(
requests=resources,
limits=resources,
)
return gpu_info
worried-airplane-87065
01/09/2025, 7:50 PMworried-airplane-87065
01/09/2025, 7:53 PMresources:{2 items
limits:[0 items
]
requests:[0 items
]
}
worried-airplane-87065
01/09/2025, 8:02 PMfrom dataclasses import dataclass
import torch
from flytekit import Resources, dynamic, task, workflow
@dataclass
class RuntimeGPU:
available: bool
device_count: int
# Need to set limits on task.
# <https://github.com/flyteorg/flytekit/pull/2151>
@task(
requests=Resources(cpu="4", mem="32Gi", gpu="2"),
limits=Resources(cpu="4", mem="32Gi", gpu="2"),
)
def cuda_available() -> RuntimeGPU:
return RuntimeGPU(
available=torch.cuda.is_available(), device_count=torch.cuda.device_count()
)
@dynamic
def dynamic_gpu_wf(num_gpus: str) -> RuntimeGPU:
gpu_info = cuda_available().with_overrides(
requests=Resources(cpu="4", mem="32Gi", gpu=num_gpus),
limits=Resources(cpu="4", mem="32Gi", gpu=num_gpus),
)
return gpu_info
@workflow
def runtime_gpu(num_gpus: str) -> RuntimeGPU:
return dynamic_gpu_wf(num_gpus=num_gpus)