# flyte-support
f
I am experimenting with upgrading my workflow to use `tensorflow=2.16.2`. This includes some breaking changes in how `keras` works: it's no longer bundled as `tensorflow.keras` and is instead its own package. It also has some breaking changes in how models are saved and loaded. Details are in this migration guide. I'm getting an error in a Flyte task that returns a trained model. I think this is caused by the `TensorFlowModelTransformer` using the Keras 2 style. To make this work with Keras 3, I think the fix would be something like this:
```diff
-        local_path = ctx.file_access.get_random_local_path()
+        local_path = ctx.file_access.get_random_local_path() + ".keras"
         pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

-        # save model in SavedModel format
-        tf.keras.models.save_model(python_val, local_path)
+        # save model in the Keras 3 native `.keras` format
+        keras.saving.save_model(python_val, local_path)
```
This is untested code, by the way, and loading the model would need similar changes. It might require a new transformer specifically for Keras 3. I didn't see anything related to this in the GitHub issues. I'm wondering if others have had this problem as well.
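The load side would need the mirror-image change. Below is a minimal sketch of the save/load round trip, assuming Keras 3 is installed as the standalone `keras` package. `with_keras_suffix`, `save_model_keras3`, and `load_model_keras3` are hypothetical helpers for illustration, not flytekit API.

```python
import pathlib
import tempfile


def with_keras_suffix(path: str) -> str:
    """Keras 3's native format expects a `.keras` file extension."""
    return path if path.endswith(".keras") else path + ".keras"


def save_model_keras3(model, local_path: str) -> str:
    # Imported lazily so this module can be imported without Keras installed.
    import keras

    path = with_keras_suffix(local_path)
    pathlib.Path(path).parent.mkdir(parents=True, exist_ok=True)
    keras.saving.save_model(model, path)
    return path


def load_model_keras3(local_path: str):
    import keras

    # The loader must read the same `.keras` path the saver wrote.
    return keras.saving.load_model(with_keras_suffix(local_path))


if __name__ == "__main__":
    try:
        import keras
    except ImportError:
        print("keras not installed; skipping round trip")
    else:
        model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
        with tempfile.TemporaryDirectory() as tmp:
            saved = save_model_keras3(model, str(pathlib.Path(tmp) / "model"))
            restored = load_model_keras3(saved)
            print(type(restored).__name__)
```

Appending the suffix in one helper keeps the save and load paths from drifting apart, which is the main thing the transformer's `to_literal` and `to_python_value` sides have to agree on.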
f
We would love to take this contribution, but @high-accountant-32689 may know why the upgrade hasn't been done. Maybe a problem with the protobuf version?
Or maybe we just didn't get to it.
h
@fancy-yak-23698, please feel free to contribute. We should be able to support Keras 2 and 3 with very little work. Can you open a GitHub issue to track this? Besides the change you proposed above, I can help guide the other flytekit changes needed.
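One way to support both versions is to branch on the installed Keras major version. This is a minimal sketch, assuming `keras.__version__` reflects the Keras that built the model (that would not hold if the legacy `tf_keras` package is used via `tf.keras`). The function names are hypothetical, not flytekit API.

```python
import pathlib


def keras_major_version(version: str) -> int:
    """Return the major component of a version string such as '3.3.3'."""
    return int(version.split(".")[0])


def model_save_path(local_path: str, keras_version: str) -> str:
    # Keras 3's save_model rejects paths without a `.keras` (or `.h5`) suffix;
    # Keras 2 writes a SavedModel directory when given a bare path.
    if keras_major_version(keras_version) >= 3:
        return local_path + ".keras"
    return local_path


def save_any_keras_model(model, local_path: str) -> str:
    """Hypothetical helper that picks the save call based on the installed Keras."""
    import keras  # lazy import: only needed when actually saving

    path = model_save_path(local_path, keras.__version__)
    pathlib.Path(path).parent.mkdir(parents=True, exist_ok=True)
    if keras_major_version(keras.__version__) >= 3:
        keras.saving.save_model(model, path)
    else:
        import tensorflow as tf

        tf.keras.models.save_model(model, path)  # Keras 2: SavedModel directory
    return path
```

A matching loader would need the same branch (`keras.saving.load_model` for `.keras` files, `tf.keras.models.load_model` for SavedModel directories), and it only works if the task that loads the model runs the same Keras major version that saved it.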