    Andrew Achkar

    2 weeks ago
    Hello, as someone creating a new flyte task in flytekit, I am a little bit confused about the difference between `PythonTask` and `PythonInstanceTask`. With the latter, I am having a difficult time writing a unit test that verifies my workflow serializes correctly. More in 🧵
    Here is a sample file; call it test_serializes.py and invoke it via pytest:
    import typing as ty
    from collections import OrderedDict
    from dataclasses import dataclass
    
    import pytest
    from dataclasses_json import dataclass_json
    from flytekit import PythonInstanceTask
    from flytekit import task
    from flytekit import workflow
    from flytekit.configuration import Image
    from flytekit.configuration import ImageConfig
    from flytekit.configuration import SerializationSettings
    from flytekit.core.base_task import PythonTask
    from flytekit.core.interface import Interface
    from flytekit.tools.translator import get_serializable
    from flytekit.types.file import FlyteFile
    
    
    @dataclass_json
    @dataclass
    class MyTaskConfig:
        string: str
    
    
    @dataclass_json
    @dataclass
    class RuntimeOutput:
        model: FlyteFile["tar.gz"]
    
    
    class MyTask(PythonTask[MyTaskConfig]):
        _TASK_TYPE = "unittest_my_task"
    
        def __init__(
            self,
            name: str,
            task_config: MyTaskConfig,
            inputs: ty.Dict[str, ty.Any],
            **kwargs,
        ):
            if not isinstance(task_config, MyTaskConfig):
                raise TypeError("Invalid task_config type")
    
            outputs = {"result": RuntimeOutput}
            super().__init__(
                name=name,
                task_config=task_config,
                task_type=self._TASK_TYPE,
                interface=Interface(inputs=inputs, outputs=outputs),
                **kwargs,
            )
    
        def execute(self, **kwargs) -> RuntimeOutput:
            print(kwargs)
            return RuntimeOutput(model=FlyteFile("s3://bucket/model.tar.gz"))
    
    
    def test_workflow_serializes():
        @task
        def get_data() -> ty.Tuple[FlyteFile, FlyteFile]:
            return FlyteFile("s3://bucket/train.csv"), FlyteFile("s3://s3_bucket/valid.csv")
    
        task_config = MyTaskConfig(
            string="some string",
        )
        training_task = MyTask(
            "unittest-mytask",
            task_config,
            inputs={"train": FlyteFile["csv"], "valid": FlyteFile["csv"]},
        )
    
        @workflow
        def test_ml_pipeline() -> RuntimeOutput:
            train, valid = get_data()
            training_output = training_task(train=train, valid=valid)
            return training_output
    
        default_img = Image(name="default", fqn="test", tag="tag")
        serialization_settings = SerializationSettings(
            project="project",
            domain="domain",
            version="version",
            env=None,
            image_config=ImageConfig(default_image=default_img, images=[default_img]),
        )
    
        wf_spec = get_serializable(OrderedDict(), serialization_settings, test_ml_pipeline)
        assert wf_spec is not None
        assert wf_spec.template is not None
    As-is, with `PythonTask` there is no problem. If I inherit from `PythonInstanceTask` instead, I get this error:
    logger.debug(f"Looking for LHS for {self} from {self._instantiated_in}")
            m = _importlib.import_module(self._instantiated_in)
            for k in dir(m):
                try:
                    if getattr(m, k) is self:
                        logger.debug(f"Found LHS for {self}, {k}")
                        self._lhs = k
                        return k
                except ValueError as err:
                    # Empty pandas dataframes behave weirdly here such that calling `m.df` raises:
                    # ValueError: The truth value of a {type(self).__name__} is ambiguous. Use a.empty, a.bool(), a.item(),
                    #   a.any() or a.all()
                    # Since dataframes aren't registrable entities to begin with we swallow any errors they raise and
                    # continue looping through m.
                    logger.warning("Caught ValueError {} while attempting to auto-assign name".format(err))
                    pass
        
            logger.error(f"Could not find LHS for {self} in {self._instantiated_in}")
    >       raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}")
    E       flytekit.exceptions.system.FlyteSystemException: Error looking for LHS in __main__
    
    venv/lib/python3.8/site-packages/flytekit/core/tracker.py:96: FlyteSystemException
    I was basing my approach on examples in the flytekit repo, but AFAICT, in all the unit tests that use `get_serializable()`, the tasks involved do not inherit from `PythonInstanceTask`.
    Geoff Salmon

    2 weeks ago
    What's the stacktrace of the error? I'm wondering what is calling the `lhs` method. Does moving
    training_task = MyTask(
            "unittest-mytask",
            task_config,
            inputs={"train": FlyteFile["csv"], "valid": FlyteFile["csv"]},
        )
    out of the test method to the top level avoid the error?
    Andrew Achkar

    2 weeks ago
    Interesting, pulling this line out of the `test_workflow_serializes()` test case block does enable this to work with `PythonInstanceTask`. I guess the tracker logic expects to find tasks defined at module scope, not inside a function block.
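    The tracker behaviour can be sketched without flytekit at all. The real logic lives in flytekit's tracker.py; the `find_lhs` helper, `fake_module`, and the stand-in objects below are all hypothetical simplifications that only show why a module-level assignment is discoverable while a function-local one is not.

```python
import importlib
import sys
import types

# Hypothetical sketch of the LHS lookup: import the module the object claims
# it was instantiated in, then scan that module's attributes for the very
# same object. Only module-level assignments can ever be found this way.
def find_lhs(obj, module_name):
    m = importlib.import_module(module_name)
    for k in dir(m):
        if getattr(m, k) is obj:
            return k
    raise RuntimeError(f"Error looking for LHS in {module_name}")

# A throwaway module standing in for the test file.
mod = types.ModuleType("fake_module")
sys.modules["fake_module"] = mod

mod.training_task = object()  # module-level assignment: discoverable
local_task = object()         # only bound in the current scope: invisible

print(find_lhs(mod.training_task, "fake_module"))  # training_task
```

    An object created inside a test function never becomes a module attribute, so the scan over `dir(m)` cannot see it, which matches the `FlyteSystemException` above.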
    Ketan (kumare3)

    2 weeks ago
    `PythonInstanceTask` is for tasks created by instantiating a class, of the form
    x = SQLAlchemyTask(....)
    `PythonTask` is the base for all of these; `PythonFunctionTask` is the type behind
    @task
    def foo(...):
        ...
    cc @Yee
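    A flytekit-free sketch of the two shapes Ketan contrasts; these stand-in classes are illustrative only and are NOT the real flytekit types.

```python
class InstanceStyleTask:
    """Stand-in for the PythonInstanceTask flavour: a task you create by
    instantiating a class, e.g. x = SQLAlchemyTask(...)."""
    def __init__(self, name: str):
        self.name = name

def task(fn):
    """Stand-in for the PythonFunctionTask flavour: a decorator that turns
    a plain Python function into a task, i.e. the @task style."""
    fn.is_task = True
    return fn

# Instance style: the task exists as a (module-level) object.
query_task = InstanceStyleTask(name="my-query")

# Function style: the decorator wraps a function definition.
@task
def foo(x: int) -> int:
    return x + 1
```

    The instance style is what trips the LHS lookup when the assignment happens inside a function; the decorator style records the function's module and name itself, so it does not have that problem.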
    Andrew Achkar

    2 weeks ago
    To the above request for the stacktrace:
    _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
    venv/lib/python3.8/site-packages/flytekit/tools/translator.py:581: in get_serializable
        cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options)
    venv/lib/python3.8/site-packages/flytekit/tools/translator.py:226: in get_serializable_workflow
        upstream_node_models = [
    venv/lib/python3.8/site-packages/flytekit/tools/translator.py:227: in <listcomp>
        get_serializable(entity_mapping, settings, n, options)
    venv/lib/python3.8/site-packages/flytekit/tools/translator.py:584: in get_serializable
        cp_entity = get_serializable_node(entity_mapping, settings, entity, options)
    venv/lib/python3.8/site-packages/flytekit/tools/translator.py:412: in get_serializable_node
        task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
    venv/lib/python3.8/site-packages/flytekit/tools/translator.py:578: in get_serializable
        cp_entity = get_serializable_task(entity_mapping, settings, entity)
    venv/lib/python3.8/site-packages/flytekit/tools/translator.py:173: in get_serializable_task
        container = entity.get_container(settings)
    venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:166: in get_container
        args=self.get_command(settings=settings),
    venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:156: in get_command
        return self._get_command_fn(settings)
    venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:131: in get_default_command
        *self.task_resolver.loader_args(settings, self),
    venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:204: in loader_args
        _, m, t, _ = extract_task_module(task)
    venv/lib/python3.8/site-packages/flytekit/core/tracker.py:229: in extract_task_module
        name = f.lhs
    venv/lib/python3.8/site-packages/flytekit/core/tracker.py:69: in lhs
        return self.find_lhs()
    Yee

    2 weeks ago
    yes @Andrew Achkar - that is true.
    the instance tasks need to be defined at the module level for the most part. (i’m not sure how well tested this is, but you can also try creating them in the workflow function in actual user code, just not in other functions.)
    the loader needs to be able to find them, so they have to be at the top level
    for testing, can you just declare it at the module level?
    also, you can take a look at these tasks, which arguably are more complicated.
    can describe this in more detail if you’d like too
    Andrew Achkar

    2 weeks ago
    “for testing can you just declare it at the module level?”
    Yes, that’s possible. Also, I can set `._lhs` manually, which lets `get_serializable` complete. I’m not sure if that invalidates what I’m trying to test..
    “also you can take a look at these tasks, which arguably are more complicated.”
    Hmm, so those don’t ever call `get_serializable`, but they do use a `serialize_to_model` that seems specific to `PythonCustomizedContainerTask`.
    I’d be happy to hear what unit / integration testing strategies you think work best for plugin and workflow authors. I really am just trying to ensure that a workflow defined using my new custom task will serialize properly, and to catch problems before I ship my plugin to be used by user code elsewhere. I have also written a test that uses subprocess to invoke `pyflyte package`, which I think is a valid (though slightly heavier) way to verify this.
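    The subprocess approach can be sketched as a tiny helper. The `serializes_ok` name is made up here, and the `pyflyte` invocation shown in the comment is an assumption (check `pyflyte --help` for the exact flags); the helper itself is exercised with stand-in commands.

```python
import subprocess
import sys

def serializes_ok(cmd):
    """Run a packaging/serialization command and report whether it exited
    cleanly; a non-zero exit code means serialization failed."""
    result = subprocess.run(cmd, capture_output=True, text=True)
    return result.returncode == 0

# In a real test this might look like (package name and flags are placeholders):
#   assert serializes_ok(["pyflyte", "--pkgs", "my_plugin.workflows", "package"])
# Here the helper is demonstrated with harmless stand-in commands instead.
print(serializes_ok([sys.executable, "-c", "pass"]))                 # True
print(serializes_ok([sys.executable, "-c", "raise SystemExit(1)"]))  # False
```

    This is heavier than calling `get_serializable` directly, but it exercises the same code path a user's `pyflyte package` run would, including the module-level task discovery.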
    Yee

    2 weeks ago
    do you have 20mins to chat rn?