# ask-the-community
e
Am I correct in understanding that Flyte does not support custom encoders and decoders for `dataclasses_json`? For example:
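```python
from dataclasses import dataclass, field
from dataclasses_json import config, dataclass_json

@dataclass_json
@dataclass
class MyDataClass:
    foo: Foo = field(metadata=config(encoder=encoder, decoder=decoder))
```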
Is the only alternative to define and register a custom `TypeTransformer`? What's the reasoning behind requiring something comparatively cumbersome?
I ask because I find myself writing the following pattern a lot, and the snippet above is essentially sugar for it.
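```python
from dataclasses import dataclass


@dataclass
class MyDataClass:
    # The only declared field: the encoded string form of foo.
    foo_str: str

    def __post_init__(self):
        # Rebuild the rich object after construction (not a dataclass field).
        self.foo = decoder(self.foo_str)

    @classmethod
    def from_foo(cls, foo: Foo) -> "MyDataClass":
        return cls(foo_str=encoder(foo))
```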
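To make the workaround concrete, here is a runnable sketch that swaps in a stand-in type: `Foo` is replaced by `datetime.date`, and `date.isoformat` / `date.fromisoformat` play the roles of `encoder` / `decoder` (these are assumptions for illustration, not the original types). Because `foo` is assigned in `__post_init__` rather than declared as a field, anything that walks the dataclass fields, such as `dataclasses.asdict`, only ever sees the string.

```python
from dataclasses import asdict, dataclass
from datetime import date


@dataclass
class MyDataClass:
    # The only declared field: the encoded string form of foo.
    foo_str: str

    def __post_init__(self):
        # Stand-in decoder. foo is a plain attribute, not a dataclass field,
        # so asdict() and other field-walking serializers never see it.
        self.foo = date.fromisoformat(self.foo_str)

    @classmethod
    def from_foo(cls, foo: date) -> "MyDataClass":
        # Stand-in encoder: keep only the string form.
        return cls(foo_str=foo.isoformat())


obj = MyDataClass.from_foo(date(2023, 2, 14))
print(repr(obj.foo))  # datetime.date(2023, 2, 14)
print(asdict(obj))    # {'foo_str': '2023-02-14'}
```

The trade-off is that `foo` has no field declaration, so its type is invisible to the dataclass machinery and it is re-decoded on every construction.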
s
The field type, in this case `Foo`, has to be a valid Flyte type. If that isn't the case, I believe the data gets pickled. @Kevin Su, could you please confirm?
k
No, we can't use pickle here, because it's neither a dataclass nor a Python primitive type. I think we can support custom encoders/decoders, though; we just need to update the dataclass transformer. Would you mind creating a ticket and sharing your encoder/decoder code?
[flyte-core]
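One way to picture "update the dataclass transformer" is sketched below. This is not flytekit's actual code: `encode_fields` and `decode_fields` are hypothetical helpers, and the only thing borrowed from `dataclasses_json` is the metadata layout its `config()` helper produces, `{"dataclasses_json": {"encoder": ..., "decoder": ...}}`. Each field's value passes through its registered encoder before JSON serialization, and through its decoder on the way back.

```python
import json
from dataclasses import Field, dataclass, field, fields
from datetime import timedelta
from typing import Any, Dict, Type, TypeVar

T = TypeVar("T")

# Key under which dataclasses_json's config() stores per-field overrides.
LIB_KEY = "dataclasses_json"


def _override(f: Field, name: str):
    """Return the field-level encoder/decoder registered for f, if any."""
    return f.metadata.get(LIB_KEY, {}).get(name)


def encode_fields(obj: Any) -> Dict[str, Any]:
    """Hypothetical: flat dataclass -> JSON-ready dict, honoring encoders."""
    out = {}
    for f in fields(obj):
        value = getattr(obj, f.name)
        encoder = _override(f, "encoder")
        out[f.name] = encoder(value) if encoder else value
    return out


def decode_fields(cls: Type[T], data: Dict[str, Any]) -> T:
    """Hypothetical inverse of encode_fields: apply decoders, then construct."""
    kwargs = {}
    for f in fields(cls):
        if not f.init or f.name not in data:
            continue  # let the dataclass default (or __post_init__) handle it
        decoder = _override(f, "decoder")
        kwargs[f.name] = decoder(data[f.name]) if decoder else data[f.name]
    return cls(**kwargs)


@dataclass
class Job:
    name: str
    # Same metadata shape that config(encoder=..., decoder=...) would build.
    timeout: timedelta = field(
        default=timedelta(minutes=5),
        metadata={LIB_KEY: {"encoder": timedelta.total_seconds,
                            "decoder": lambda s: timedelta(seconds=s)}},
    )


payload = json.dumps(encode_fields(Job("train", timedelta(hours=1))))
restored = decode_fields(Job, json.loads(payload))
```

The sketch only handles flat dataclasses; a real transformer would also need to recurse into nested dataclasses, lists, and `Optional` fields.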
e
@Kevin Su Done, here's the issue: https://github.com/flyteorg/flyte/issues/3359
I think this is super straightforward.
AFAIK, the default encoder/decoder setup (without additional metadata) assumes serialization to a string. If I understand correctly, you can also pass `mm_field` to specify an alternate intermediate marshmallow field type in the schema, which then needs to match the return type of your encoder and the argument type of your decoder, respectively.
But I haven't tried that.
If you're asking what I personally use this for: I use it to serialize type and function objects. For example:
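```python
import importlib
import inspect
from dataclasses import dataclass, field
from typing import Callable, Dict, Type

from dataclasses_json import config, dataclass_json
from flytekit.extras.pytorch import PyTorchCheckpoint
from transformers import PreTrainedModel

# PreprocessorFactory and noop are defined in the author's own codebase (not shown).


def import_from_str(full_name: str):
    """Inverse of full_name_from_obj: "pkg.module.Name" -> the Name object."""
    module_name, obj_name = full_name.rsplit(sep=".", maxsplit=1)
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


def full_name_from_obj(obj) -> str:
    """Serialize a class or function to its import path, e.g. "pkg.module.Name"."""
    return f"{inspect.getmodule(obj).__name__}.{obj.__qualname__}"


# Per-field dataclasses_json override: store the import path, re-import on load.
TYPE_SERIALIZER = config(
    encoder=full_name_from_obj,
    decoder=import_from_str,
)


@dataclass_json
@dataclass
class ModelSpec:
    # A class object is not JSON-serializable on its own, so it needs the serializer too.
    model: Type[PreTrainedModel] = field(metadata=TYPE_SERIALIZER)
    preprocessor_factory: Type[PreprocessorFactory] = field(metadata=TYPE_SERIALIZER)
    # Fields without defaults must come before fields with defaults.
    model_weights: PyTorchCheckpoint
    training_args: Dict = field(default_factory=dict)
    model_args: Dict = field(default_factory=dict)
    weight_preprocessing: Callable = field(metadata=TYPE_SERIALIZER, default=noop)
```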
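One caveat with the `rsplit(".")` approach: it assumes everything before the last dot is a module, so a nested qualname such as `json.encoder.JSONEncoder.encode` cannot be re-imported. A possible variant (an alternative design, not the author's code; `obj_to_path` and `path_to_obj` are hypothetical names) keeps the module and qualname apart with a `:` separator and walks the qualname one attribute at a time. On Python 3.9+, the standard library's `pkgutil.resolve_name` also accepts this `module:qualname` form and could serve as the decoder instead.

```python
import importlib
import json
from functools import reduce


def obj_to_path(obj) -> str:
    """Hypothetical encoder: "module:qualname", e.g. "json.encoder:JSONEncoder.encode"."""
    return f"{obj.__module__}:{obj.__qualname__}"


def path_to_obj(path: str):
    """Hypothetical decoder: import the module, then walk the dotted qualname."""
    module_name, _, qual_name = path.partition(":")
    module = importlib.import_module(module_name)
    # "Outer.Inner" -> getattr(getattr(module, "Outer"), "Inner")
    return reduce(getattr, qual_name.split("."), module)


# A nested qualname survives the round trip; the rsplit(".") version would try
# to import a module named "json.encoder.JSONEncoder" here and fail.
print(obj_to_path(json.JSONEncoder.encode))  # json.encoder:JSONEncoder.encode
print(path_to_obj(obj_to_path(json.JSONEncoder.encode)) is json.JSONEncoder.encode)  # True
```

These would plug into `config(encoder=obj_to_path, decoder=path_to_obj)` the same way. Objects defined inside functions (whose qualname contains `<locals>`) still cannot be re-imported with either version.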
k
Thanks for the clarification, I'll take a look. Contributions are welcome too!
e
Yeah, I may try to contribute if I can find time, since someone snagged the other feature I was looking for 🙂
k
Thanks! Let me know if there's anything I can help with.