Andrew Achkar
09/08/2022, 5:11 PM
…`PythonTask` and `PythonInstanceTask`. With the latter, I am having a difficult time writing a passing unit test that verifies my workflow serializes correctly. More in 🧵

```python
import typing as ty
from collections import OrderedDict
from dataclasses import dataclass

import pytest
from dataclasses_json import dataclass_json
from flytekit import PythonInstanceTask
from flytekit import task
from flytekit import workflow
from flytekit.configuration import Image
from flytekit.configuration import ImageConfig
from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.interface import Interface
from flytekit.tools.translator import get_serializable
from flytekit.types.file import FlyteFile


@dataclass_json
@dataclass
class MyTaskConfig:
    string: str


@dataclass_json
@dataclass
class RuntimeOutput:
    model: FlyteFile["tar.gz"]


class MyTask(PythonTask[MyTaskConfig]):
    _TASK_TYPE = "unittest_my_task"

    def __init__(
        self,
        name: str,
        task_config: MyTaskConfig,
        inputs: ty.Dict[str, ty.Any],
        **kwargs,
    ):
        if not isinstance(task_config, MyTaskConfig):
            raise TypeError("Invalid task_config type")
        outputs = {"result": RuntimeOutput}
        super().__init__(
            name=name,
            task_config=task_config,
            task_type=self._TASK_TYPE,
            interface=Interface(inputs=inputs, outputs=outputs),
            **kwargs,
        )

    def execute(self, **kwargs) -> RuntimeOutput:
        print(kwargs)
        return RuntimeOutput(model=FlyteFile("s3://bucket/model.tar.gz"))


def test_workflow_serializes():
    @task
    def get_data() -> ty.Tuple[FlyteFile, FlyteFile]:
        return FlyteFile("s3://bucket/train.csv"), FlyteFile(
            "s3://s3_bucket/valid.csv"
        )

    task_config = MyTaskConfig(
        string="some string",
    )
    training_task = MyTask(
        "unittest-mytask",
        task_config,
        inputs={"train": FlyteFile["csv"], "valid": FlyteFile["csv"]},
    )

    @workflow
    def test_ml_pipeline() -> RuntimeOutput:
        train, valid = get_data()
        training_output = training_task(train=train, valid=valid)
        return training_output

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env=None,
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )
    wf_spec = get_serializable(OrderedDict(), serialization_settings, test_ml_pipeline)
    assert wf_spec is not None
    assert wf_spec.template is not None
```
As-is, with `PythonTask` there is no problem. If I inherit from `PythonInstanceTask` instead, I get this error:
```
    logger.debug(f"Looking for LHS for {self} from {self._instantiated_in}")
    m = _importlib.import_module(self._instantiated_in)
    for k in dir(m):
        try:
            if getattr(m, k) is self:
                logger.debug(f"Found LHS for {self}, {k}")
                self._lhs = k
                return k
        except ValueError as err:
            # Empty pandas dataframes behave weirdly here such that calling `m.df` raises:
            # ValueError: The truth value of a {type(self).__name__} is ambiguous. Use a.empty, a.bool(), a.item(),
            #   a.any() or a.all()
            # Since dataframes aren't registrable entities to begin with we swallow any errors they raise and
            # continue looping through m.
            logger.warning("Caught ValueError {} while attempting to auto-assign name".format(err))
            pass

    logger.error(f"Could not find LHS for {self} in {self._instantiated_in}")
>   raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}")
E   flytekit.exceptions.system.FlyteSystemException: Error looking for LHS in __main__

venv/lib/python3.8/site-packages/flytekit/core/tracker.py:96: FlyteSystemException
```
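The lookup in the traceback can be mimicked in plain Python to show why the task's location matters. The idea is to import the module the object was instantiated in, then scan `dir()` for an attribute bound to that exact object. This is a simplified sketch, not flytekit's code. `find_lhs`, `FakeTask`, and the `demo_workflows` module are hypothetical, and the module is built in memory so the example runs anywhere. An object that lives only in a local variable of a test function is never found. That matches the `Error looking for LHS` failure.

```python
import importlib
import sys
import types


def find_lhs(obj: object, module_name: str) -> str:
    """Return the module-level name bound to `obj` (simplified sketch, not flytekit's code)."""
    module = importlib.import_module(module_name)
    for name in dir(module):
        # Identity check: we want the variable that refers to this exact object.
        if getattr(module, name) is obj:
            return name
    raise LookupError(f"Error looking for LHS in {module_name}")


class FakeTask:
    """Hypothetical stand-in for a PythonInstanceTask instance."""


# Build an in-memory module so the example does not depend on how this file is run.
demo = types.ModuleType("demo_workflows")
sys.modules["demo_workflows"] = demo

# Bound at module level, like moving `training_task = MyTask(...)` to the top of the file.
demo.training_task = FakeTask()


def build_inside_test() -> FakeTask:
    # Like creating the task inside test_workflow_serializes(): only a local
    # variable refers to it, so no module attribute points at it.
    local_task = FakeTask()
    return local_task


print(find_lhs(demo.training_task, "demo_workflows"))  # training_task
try:
    find_lhs(build_inside_test(), "demo_workflows")
except LookupError as err:
    print(err)  # Error looking for LHS in demo_workflows
```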
Geoff Salmon
09/08/2022, 6:23 PM
…`lhs` method.
Does moving
```python
training_task = MyTask(
    "unittest-mytask",
    task_config,
    inputs={"train": FlyteFile["csv"], "valid": FlyteFile["csv"]},
)
```
out of the test method to the top level avoid the error?

Andrew Achkar
09/08/2022, 6:44 PM
…`test_workflow_serializes()` test case block does enable this to work with the `PythonInstanceTask`.
Ketan (kumare3)
```
x = SQLAlchemyTask(....)
```
`PythonTask` is the base for all this.
`PythonFunction` is of type
```
@task
def foo(...):
    ...
```
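The distinction above can be read in plain Python terms. A function object carries its own `__module__` and `__qualname__`. A class instance can see its class's module but has no record of the variable it was assigned to, so that name has to be discovered, for example by scanning the module. This is offered as a hedged illustration of why instance-style tasks depend on an LHS lookup. `SQLTaskLike` and `foo` are hypothetical stand-ins, not flytekit classes.

```python
class SQLTaskLike:
    """Hypothetical stand-in for an instance-style task such as SQLAlchemyTask(...)."""


def foo() -> None:
    """Hypothetical stand-in for a function decorated with @task."""


x = SQLTaskLike()

# A function object records its own module and qualified name...
print(foo.__module__, foo.__qualname__)
# ...an instance can see its class's module, but nothing records the name `x`.
print(x.__module__ == SQLTaskLike.__module__)  # True
print(hasattr(x, "__name__"), hasattr(x, "__qualname__"))  # False False
```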
Andrew Achkar
09/08/2022, 6:59 PM
```
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:581: in get_serializable
    cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options)
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:226: in get_serializable_workflow
    upstream_node_models = [
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:227: in <listcomp>
    get_serializable(entity_mapping, settings, n, options)
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:584: in get_serializable
    cp_entity = get_serializable_node(entity_mapping, settings, entity, options)
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:412: in get_serializable_node
    task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options)
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:578: in get_serializable
    cp_entity = get_serializable_task(entity_mapping, settings, entity)
venv/lib/python3.8/site-packages/flytekit/tools/translator.py:173: in get_serializable_task
    container = entity.get_container(settings)
venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:166: in get_container
    args=self.get_command(settings=settings),
venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:156: in get_command
    return self._get_command_fn(settings)
venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:131: in get_default_command
    *self.task_resolver.loader_args(settings, self),
venv/lib/python3.8/site-packages/flytekit/core/python_auto_container.py:204: in loader_args
    _, m, t, _ = extract_task_module(task)
venv/lib/python3.8/site-packages/flytekit/core/tracker.py:229: in extract_task_module
    name = f.lhs
venv/lib/python3.8/site-packages/flytekit/core/tracker.py:69: in lhs
    return self.find_lhs()
```
Yee
Andrew Achkar
09/08/2022, 9:21 PM
> for testing can you just declare it at the module level?

Yes, that’s possible. Also, I can set `._lhs` manually, which lets `get_serializable` complete. I’m not sure if that invalidates what I’m trying to test.
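Setting `._lhs` by hand skips the module scan, so `get_serializable` can finish. It also means the test no longer checks that the task can be found again by name. Assume the default task resolver reloads a task at run time by importing its module and fetching the attribute with that name. Under that assumption, a hand-set name that does not exist at module level would serialize fine but fail to load. A round-trip check can cover what the override skips, sketched below in plain Python. `resolves_to`, `FakeTask`, and the in-memory `demo_tasks` module are hypothetical.

```python
import importlib
import sys
import types


def resolves_to(obj: object, module_name: str, name: str) -> bool:
    """Round-trip check: does `module_name.name` point back at `obj`?

    Assumes the task is reloaded at run time by importing the module and
    fetching the attribute by name.
    """
    module = importlib.import_module(module_name)
    return getattr(module, name, None) is obj


class FakeTask:
    """Hypothetical stand-in for an instance-tracked task with a cached name."""

    def __init__(self) -> None:
        self._lhs = None


pkg = types.ModuleType("demo_tasks")
sys.modules["demo_tasks"] = pkg

# Task bound at module level: the name it reports really exists in the module.
pkg.training_task = FakeTask()
pkg.training_task._lhs = "training_task"


def build_in_test() -> FakeTask:
    task = FakeTask()
    task._lhs = "training_task_local"  # patched by hand so serialization can proceed
    return task


patched = build_in_test()
print(resolves_to(pkg.training_task, "demo_tasks", pkg.training_task._lhs))  # True
print(resolves_to(patched, "demo_tasks", patched._lhs))  # False
```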
> also you can take a look at these tasks, which arguably are more complicated.

Hmm, so those don’t ever call `get_serializable`, but do use a `serialize_to_model` that seems specific to `PythonCustomizedContainerTask`.
I’d be happy to hear what unit/integration testing strategies you think work best for plugin and workflow authors. I really am just trying to ensure that a workflow defined using my new custom task serializes properly, and to catch any problem before I ship my plugin for use in user code elsewhere. I have also written a test that uses `subprocess` to invoke `pyflyte package`, which I think is a valid (though slightly heavier) way to verify this.

Yee