# ask-the-community
a
Hello, as someone who is creating a new flyte task in flytekit, I am a little confused about the difference between `PythonTask` and `PythonInstanceTask`. With the latter, I am having a hard time writing a unit test that verifies my workflow serializes correctly. More in 🧵
Here is a sample file; call it `test_serializes.py` and invoke it via pytest:
```python
import typing as ty
from collections import OrderedDict
from dataclasses import dataclass

import pytest
from dataclasses_json import dataclass_json
from flytekit import PythonInstanceTask
from flytekit import task
from flytekit import workflow
from flytekit.configuration import Image
from flytekit.configuration import ImageConfig
from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.interface import Interface
from flytekit.tools.translator import get_serializable
from flytekit.types.file import FlyteFile


@dataclass_json
@dataclass
class MyTaskConfig:
    string: str


@dataclass_json
@dataclass
class RuntimeOutput:
    model: FlyteFile["tar.gz"]


class MyTask(PythonTask[MyTaskConfig]):
    _TASK_TYPE = "unittest_my_task"

    def __init__(
        self,
        name: str,
        task_config: MyTaskConfig,
        inputs: ty.Dict[str, ty.Any],
        **kwargs,
    ):
        if not isinstance(task_config, MyTaskConfig):
            raise TypeError("Invalid task_config type")

        outputs = {"result": RuntimeOutput}
        super().__init__(
            name=name,
            task_config=task_config,
            task_type=self._TASK_TYPE,
            interface=Interface(inputs=inputs, outputs=outputs),
            **kwargs,
        )

    def execute(self, **kwargs) -> RuntimeOutput:
        print(kwargs)
        return RuntimeOutput(model=FlyteFile("s3://bucket/model.tar.gz"))


def test_workflow_serializes():
    @task
    def get_data() -> ty.Tuple[FlyteFile, FlyteFile]:
        return FlyteFile("s3://bucket/train.csv"), FlyteFile(
            "s3://s3_bucket/valid.csv"
        )

    task_config = MyTaskConfig(
        string="some string",
    )
    training_task = MyTask(
        "unittest-mytask",
        task_config,
        inputs={"train": FlyteFile["csv"], "valid": FlyteFile["csv"]},
    )

    @workflow
    def test_ml_pipeline() -> RuntimeOutput:
        train, valid = get_data()
        training_output = training_task(train=train, valid=valid)
        return training_output

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    wf_spec = get_serializable(OrderedDict(), serialization_settings, test_ml_pipeline)
    assert wf_spec is not None
    assert wf_spec.template is not None
```
As-is, with `PythonTask` there is no problem. If I inherit from `PythonInstanceTask` instead, I get this error:
```
logger.debug(f"Looking for LHS for {self} from {self._instantiated_in}")
        m = _importlib.import_module(self._instantiated_in)
        for k in dir(m):
            try:
                if getattr(m, k) is self:
                    logger.debug(f"Found LHS for {self}, {k}")
                    self._lhs = k
                    return k
            except ValueError as err:
                # Empty pandas dataframes behave weirdly here such that calling `m.df` raises:
                # ValueError: The truth value of a {type(self).__name__} is ambiguous. Use a.empty, a.bool(), a.item(),
                #   a.any() or a.all()
                # Since dataframes aren't registrable entities to begin with we swallow any errors they raise and
                # continue looping through m.
                logger.warning("Caught ValueError {} while attempting to auto-assign name".format(err))
                pass
    
        logger.error(f"Could not find LHS for {self} in {self._instantiated_in}")
>       raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}")
E       flytekit.exceptions.system.FlyteSystemException: Error looking for LHS in __main__

venv/lib/python3.8/site-packages/flytekit/core/tracker.py:96: FlyteSystemException
```
I was basing my approach on examples in the flytekit repo, but AFAICT, in every case where `get_serializable()` is used in unit tests, the tasks involved do not inherit from `PythonInstanceTask`.
g
What's the stacktrace of the error? I'm wondering what is calling the `lhs` method. Does moving
```python
training_task = MyTask(
    "unittest-mytask",
    task_config,
    inputs={"train": FlyteFile["csv"], "valid": FlyteFile["csv"]},
)
```
out of the test method to the top level avoid the error?
a
Interesting, pulling this line out of the `test_workflow_serializes()` test case block does enable this to work with `PythonInstanceTask`. I guess that tracker logic expects to find tasks defined at module scope, not inside a function block.
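For the record, a minimal sketch of the arrangement that passed (only the placement of `training_task` changed; `get_data` and `serialization_settings` stay exactly as in the example above):
```python
# Sketch of the rearrangement: find_lhs() imports the defining module and
# scans dir(module) for an attribute that *is* this instance, so the
# assignment has to be a module-level name.
task_config = MyTaskConfig(string="some string")

training_task = MyTask(  # now discoverable as a module attribute
    "unittest-mytask",
    task_config,
    inputs={"train": FlyteFile["csv"], "valid": FlyteFile["csv"]},
)


def test_workflow_serializes():
    @workflow
    def test_ml_pipeline() -> RuntimeOutput:
        train, valid = get_data()  # get_data as defined in the example above
        return training_task(train=train, valid=valid)

    # serialization_settings built exactly as in the example above
    wf_spec = get_serializable(OrderedDict(), serialization_settings, test_ml_pipeline)
    assert wf_spec is not None
```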
k
`PythonInstanceTask` is for tasks of the type
```python
x = SQLAlchemyTask(....)
```
`PythonTask` is the base for all of these; `PythonFunctionTask` is for tasks of the type
```python
@task
def foo(...):
    ...
```
cc @Yee
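e.g., a quick sketch of the contrast (assumes the flytekitplugins-sqlalchemy package is installed; the query and DB URI are just placeholders):
```python
from flytekit import kwtypes, task
from flytekitplugins.sqlalchemy import SQLAlchemyConfig, SQLAlchemyTask

# PythonInstanceTask style: a prebuilt task class you *instantiate* at
# module scope. There is no user function body to load, so the loader must
# find the instance by its module-level name.
get_names = SQLAlchemyTask(
    "get_names",
    query_template="select name from users limit {{ .inputs.limit }}",
    inputs=kwtypes(limit=int),
    task_config=SQLAlchemyConfig(uri="sqlite:///test.db"),
)


# PythonFunctionTask style: the @task decorator wraps a function, and the
# module + function name are captured for you at decoration time.
@task
def foo(limit: int) -> int:
    return limit
```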
a
To the above request for the stacktrace:
```
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:581: in get_serializable
    cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options)
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:226: in get_serializable_workflow
    upstream_node_models = [
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:227: in <listcomp>
    get_serializable(entity_mapping, settings, n, options)
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:584: in get_serializable
    cp_entity = get_serializable_node(entity_mapping, settings, entity, options)
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:412: in get_serializable_node
    task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:578: in get_serializable
    cp_entity = get_serializable_task(entity_mapping, settings, entity)
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:173: in get_serializable_task
    container = entity.get_container(settings)
venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:166: in get_container
    args=self.get_command(settings=settings),
venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:156: in get_command
    return self._get_command_fn(settings)
venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:131: in get_default_command
    *self.task_resolver.loader_args(settings, self),
venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:204: in loader_args
    _, m, t, _ = extract_task_module(task)
venv/lib/python3.8/site-packages/flytekit/core/tracker.py:229: in extract_task_module
    name = f.lhs
venv/lib/python3.8/site-packages/flytekit/core/tracker.py:69: in lhs
    return self.find_lhs()
```
y
yes @Andrew Achkar - that is true.
the instance tasks need to be defined at the module level for the most part. (i’m not sure how well tested this is but you can also try creating them in the workflow function in actual user code, but not in other functions).
the loader needs to be able to find them so they have to be at the top level
for testing can you just declare it at the module level?
also you can take a look at these tasks, which arguably are more complicated.
can describe this in more detail if you’d like, too
a
> for testing can you just declare it at the module level?
Yes, that’s possible. Also, I can set `._lhs` manually, which lets `get_serializable` complete. I’m not sure if that invalidates what I’m trying to test, though.
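Concretely, something like this inside the test (a sketch; `_lhs` is a private attribute, so this may break across flytekit versions):
```python
# Sketch of the workaround: pre-assign the private _lhs attribute so the
# tracker never has to search the defining module for the instance.
training_task = MyTask(
    "unittest-mytask",
    task_config,
    inputs={"train": FlyteFile["csv"], "valid": FlyteFile["csv"]},
)
training_task._lhs = "training_task"  # the name find_lhs() would have returned

wf_spec = get_serializable(OrderedDict(), serialization_settings, test_ml_pipeline)
assert wf_spec is not None
```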
> also you can take a look at these tasks, which arguably are more complicated.
Hmm, so those don’t ever call `get_serializable`, but they do use a `serialize_to_model` that seems specific to `PythonCustomizedContainerTask`.
I’d be happy to hear what unit / integration testing strategies you think work best for plugin and workflow authors. I really am just trying to ensure that a workflow defined with my new custom task serializes properly, and to catch any problem before I ship my plugin to be used by user code elsewhere. I have also written a test that uses subprocess to invoke `pyflyte package`, which I think is a valid (though slightly heavier) way to verify this.
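Roughly like this (a sketch; `my_plugin.test_workflows` is a placeholder for the module that defines the workflow, and the flag spellings should be double-checked against your pyflyte version):
```python
import subprocess


def test_pyflyte_package(tmp_path):
    # Package the workflows the same way a user would; a zero exit code
    # plus an archive on disk is taken as "the workflow serializes".
    output = tmp_path / "flyte-package.tgz"
    result = subprocess.run(
        [
            "pyflyte",
            "--pkgs", "my_plugin.test_workflows",  # placeholder package
            "package",
            "--image", "test:tag",
            "--output", str(output),
            "--force",
        ],
        capture_output=True,
        text=True,
    )
    assert result.returncode == 0, result.stderr
    assert output.exists()
```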
y
do you have 20mins to chat rn?