acoustic-carpenter-78188
05/11/2023, 7:32 AM
Tasks created with flytekit.task are not hinted as PythonFunctionTask because of the decorator's return typing, so there are spurious type errors when attempting to register such tasks on FlyteRemote objects.
Type
☑︎ Bug Fix
☐ Feature
☐ Plugin
Are all requirements met?
☑︎ Code completed
☐ Smoke tested
☐ Unit tests added
☐ Code documentation added
☐ Any pending items have an associated Issue
Complete description
Here you can see an explanation of the problem and the fix:
import flytekit
import flytekit.configuration as flyte_config
import flytekit.remote

@flytekit.task
def my_task1() -> int:
    return 0

@flytekit.task()
def my_task2() -> int:
    return 0

@flytekit.task(task_config=5)
def my_task3() -> int:
    return 0

my_task4 = flytekit.task(lambda: 0)

remote = flytekit.remote.FlyteRemote(...) # type: ignore
serialization_settings = flyte_config.SerializationSettings(...) # type: ignore

# before
reveal_type(my_task1) # Type of "my_task1" is "((...) -> Any) | PythonFunctionTask"
remote.register_task(my_task1, serialization_settings) # error: Argument of type "((...) -> Any) | PythonFunctionTask" cannot be assigned to parameter "entity" of type "PythonTask" in function "register_task"
reveal_type(my_task2) # Type of "my_task2" is "() -> int"
remote.register_task(my_task2, serialization_settings) # error: Argument of type "() -> int" cannot be assigned to parameter "entity" of type "PythonTask" in function "register_task"
reveal_type(my_task3) # Type of "my_task3" is "() -> int"
remote.register_task(my_task3, serialization_settings) # error: Argument of type "() -> int" cannot be assigned to parameter "entity" of type "PythonTask" in function "register_task"
reveal_type(my_task4) # Type of "my_task4" is "((...) -> Any) | PythonFunctionTask"
remote.register_task(my_task4, serialization_settings) # error: Argument of type "((...) -> Any) | PythonFunctionTask" cannot be assigned to parameter "entity" of type "PythonTask" in function "register_task"
# after
reveal_type(my_task1) # Type of "my_task1" is "PythonFunctionTask"
remote.register_task(my_task1, serialization_settings) # OK
reveal_type(my_task2) # Type of "my_task2" is "PythonFunctionTask[T@task]"
remote.register_task(my_task2, serialization_settings) # OK
reveal_type(my_task3) # Type of "my_task3" is "PythonFunctionTask[int]"
remote.register_task(my_task3, serialization_settings) # OK
reveal_type(my_task4) # Type of "my_task4" is "PythonFunctionTask"
remote.register_task(my_task4, serialization_settings) # OK
I accomplished this fix by using typing.overload to hint a different return type based on whether _task_function is passed to flytekit.task or not.
If nothing is passed, as in the case of my_task2 or my_task3, we hint the return type as Callable[[Callable[..., Any]], PythonFunctionTask[T]], that is, a callable that accepts a single argument (any function) and returns a PythonFunctionTask.
If something is passed for _task_function, as in my_task1 or my_task4, we hint the return type as PythonFunctionTask[T].
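For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the two overloads described above. It does not import flytekit: FunctionTask is a hypothetical stand-in for PythonFunctionTask, and the task signature is reduced to the two parameters that matter here, so treat it as an illustration of the idea rather than the actual flytekit code.

```python
from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload

T = TypeVar("T")


class FunctionTask(Generic[T]):
    """Hypothetical stand-in for PythonFunctionTask (not the flytekit class)."""

    def __init__(self, fn: Callable[..., Any], task_config: Optional[T] = None) -> None:
        self.fn = fn
        self.task_config = task_config

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.fn(*args, **kwargs)


# Overload 1: no function passed (e.g. @task() or @task(task_config=5)),
# so the result is a decorator that still needs to receive the function.
@overload
def task(
    _task_function: None = ..., task_config: Optional[T] = ...
) -> Callable[[Callable[..., Any]], FunctionTask[T]]: ...


# Overload 2: a function is passed directly (e.g. @task or task(lambda: 0)),
# so the result is already a task object.
@overload
def task(
    _task_function: Callable[..., Any], task_config: Optional[T] = ...
) -> FunctionTask[T]: ...


# The single runtime implementation; type checkers only see the overloads.
def task(
    _task_function: Optional[Callable[..., Any]] = None,
    task_config: Optional[T] = None,
) -> Union[Callable[[Callable[..., Any]], FunctionTask[T]], FunctionTask[T]]:
    def wrapper(fn: Callable[..., Any]) -> FunctionTask[T]:
        return FunctionTask(fn, task_config)

    if _task_function is not None:
        return wrapper(_task_function)
    return wrapper


@task
def bare() -> int:
    return 0


@task()
def called() -> int:
    return 0


@task(task_config=5)
def configured() -> int:
    return 0


from_lambda = task(lambda: 0)
```

At runtime all four forms produce a FunctionTask, which is the same outcome the "after" reveal_type output reports for the real decorator. The overloads only change what a static type checker infers; the implementation body is unchanged by them.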
Tracking Issue
N/A
Follow-up issue
N/A
flyteorg/flytekit
Codecov: 71.20% (-0.02%) compared to e44b802
Codecov: 11.11% of diff hit (target 71.22%)
✅ 28 other checks have passed
28/30 successful checks
acoustic-carpenter-78188
05/11/2023, 9:28 PM