# flyte-github
#1631 Improve task decorator type hints with overload

Pull request opened by ringohoffman

## TL;DR

Functions decorated by `flytekit.task` are not hinted as `PythonFunctionTask` because of the decorator's return type annotation. As a result, type checkers report spurious errors when such tasks are registered on `FlyteRemote` objects.

## Type

- ☑︎ Bug Fix
- ☐ Feature
- ☐ Plugin

## Are all requirements met?

- ☑︎ Code completed
- ☐ Smoke tested
- ☐ Unit tests added
- ☐ Code documentation added
- ☐ Any pending items have an associated Issue

## Complete description

Here you can see an explanation of the problem and the fix:
```python
# Static type-checking demonstration: the comments show the type checker's
# reveal_type output and errors. The FlyteRemote(...) and
# SerializationSettings(...) arguments are elided.
import flytekit
import flytekit.configuration as flyte_config
import flytekit.remote


@flytekit.task
def my_task1() -> int:
    return 0


@flytekit.task()
def my_task2() -> int:
    return 0


@flytekit.task(task_config=5)
def my_task3() -> int:
    return 0


my_task4 = flytekit.task(lambda: 0)

remote = flytekit.remote.FlyteRemote(...)  # type: ignore
serialization_settings = flyte_config.SerializationSettings(...)  # type: ignore

# before
reveal_type(my_task1)  # Type of "my_task1" is "((...) -> Any) | PythonFunctionTask"
remote.register_task(my_task1, serialization_settings)  # error: Argument of type "((...) -> Any) | PythonFunctionTask" cannot be assigned to parameter "entity" of type "PythonTask" in function "register_task"

reveal_type(my_task2)  # Type of "my_task2" is "() -> int"
remote.register_task(my_task2, serialization_settings)  # error: Argument of type "() -> int" cannot be assigned to parameter "entity" of type "PythonTask" in function "register_task"

reveal_type(my_task3)  # Type of "my_task3" is "() -> int"
remote.register_task(my_task3, serialization_settings)  # error: Argument of type "() -> int" cannot be assigned to parameter "entity" of type "PythonTask" in function "register_task"

reveal_type(my_task4)  # Type of "my_task4" is "((...) -> Any) | PythonFunctionTask"
remote.register_task(my_task4, serialization_settings)  # error: Argument of type "((...) -> Any) | PythonFunctionTask" cannot be assigned to parameter "entity" of type "PythonTask" in function "register_task"


# after
reveal_type(my_task1)  # Type of "my_task1" is "PythonFunctionTask"
remote.register_task(my_task1, serialization_settings)  # OK

reveal_type(my_task2)  # Type of "my_task2" is "PythonFunctionTask[T@task]"
remote.register_task(my_task2, serialization_settings)  # OK

reveal_type(my_task3)  # Type of "my_task3" is "PythonFunctionTask[int]"
remote.register_task(my_task3, serialization_settings)  # OK

reveal_type(my_task4)  # Type of "my_task4" is "PythonFunctionTask"
remote.register_task(my_task4, serialization_settings)  # OK
```
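The "before" errors exist only statically: at runtime the decorated object already is a task. The following is a minimal sketch, assuming (as the `reveal_type` output for `my_task1` suggests) that the pre-fix decorator carried a single union return annotation. `Task` and `register_task` are hypothetical stand-ins for `PythonFunctionTask` and `FlyteRemote.register_task`, not flytekit code.

```python
from typing import Any, Callable, Optional, Union


class Task:
    """Hypothetical stand-in for flytekit's PythonFunctionTask."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self.fn = fn


def register_task(entity: Task) -> str:
    """Hypothetical stand-in for FlyteRemote.register_task; only accepts a Task."""
    return f"registered {entity.fn.__name__}"


def task(
    _task_function: Optional[Callable[..., Any]] = None,
) -> Union[Callable[..., Any], Task]:
    """Un-overloaded decorator: a type checker only ever sees the union."""

    def wrapper(fn: Callable[..., Any]) -> Task:
        return Task(fn)

    if _task_function is not None:
        return wrapper(_task_function)  # bare @task: a Task right away
    return wrapper  # @task(): a decorator that will produce a Task


@task
def my_task1() -> int:
    return 0


# At runtime my_task1 really is a Task, so this call succeeds and prints
# "registered my_task1". A static checker still rejects it, because the
# declared type of my_task1 is the whole union rather than Task.
print(register_task(my_task1))
```

Callers could silence the checker with `isinstance` narrowing or `typing.cast`, but that has to be repeated at every call site. Fixing the decorator's own annotation avoids this.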
I accomplished this fix by using `typing.overload` to hint a different return type based on whether `_task_function` is passed to `flytekit.task` or not.

If nothing is passed, as in the case of `my_task2` or `my_task3`, we hint the return type as `Callable[[Callable[..., Any]], PythonFunctionTask[T]]`. That is a callable that accepts a single argument, any function, and returns a `PythonFunctionTask`. If something is passed for `_task_function`, as in `my_task1` or `my_task4`, we hint the return type as `PythonFunctionTask[T]`.

## Tracking Issue

N/A

## Follow-up issue

N/A

flyteorg/flytekit checks:

- Codecov: 71.20% (-0.02%) compared to e44b802
- Codecov: 11.11% of diff hit (target 71.22%)
- 28 other checks have passed (28/30 successful checks)
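The overload-based fix described above can be reproduced without flytekit. The following is a minimal sketch, not flytekit's actual `task` signature. `FunctionTask` is a hypothetical stand-in for `PythonFunctionTask[T]`, and only the `_task_function` and `task_config` parameters from the PR are modeled.

- The first overload covers `@task()` and `@task(task_config=...)` and returns a decorator.
- The second overload covers bare `@task` and direct calls such as `task(lambda: 0)` and returns the task itself.

```python
from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload

T = TypeVar("T")


class FunctionTask(Generic[T]):
    """Hypothetical stand-in for flytekit's PythonFunctionTask[T]."""

    def __init__(self, fn: Callable[..., Any], task_config: Optional[T] = None) -> None:
        self.fn = fn
        self.task_config = task_config


# Overload 1: no function passed, so return a decorator that builds the task.
@overload
def task(
    _task_function: None = ..., task_config: Optional[T] = ...
) -> Callable[[Callable[..., Any]], FunctionTask[T]]: ...


# Overload 2: a function passed, so return the task directly.
@overload
def task(
    _task_function: Callable[..., Any], task_config: Optional[T] = ...
) -> FunctionTask[T]: ...


# Single runtime implementation; callers only ever see the overloads above.
def task(
    _task_function: Optional[Callable[..., Any]] = None,
    task_config: Optional[T] = None,
) -> Union[Callable[[Callable[..., Any]], FunctionTask[T]], FunctionTask[T]]:
    def wrapper(fn: Callable[..., Any]) -> FunctionTask[T]:
        return FunctionTask(fn, task_config)

    if _task_function is not None:
        return wrapper(_task_function)
    return wrapper


# The same four call forms as in the PR description.
@task
def my_task1() -> int:
    return 0


@task()
def my_task2() -> int:
    return 0


@task(task_config=5)
def my_task3() -> int:
    return 0


my_task4 = task(lambda: 0)
```

With these overloads, a checker such as mypy or pyright should infer `FunctionTask[int]` for `my_task3`, mirroring the `PythonFunctionTask[int]` shown in the PR. The implementation's union return type is never visible to callers, which is what removes the spurious errors.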