fancy-yak-23698
09/12/2024, 4:53 PM
With `tensorflow=2.16.2` there are some breaking changes in how Keras works. It's no longer bundled as `tensorflow.keras` and is instead its own package. It also has some breaking changes to how models are saved and loaded. Details are in this migration guide.
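A quick way to tell which Keras an environment is actually using is to check the major version of the installed `keras` distribution. This is a minimal sketch that assumes Python 3.8+ (for `importlib.metadata`). The helper names are hypothetical. It only inspects the `keras` distribution, so an environment that keeps Keras 2 through the separate `tf_keras` package would not be detected by this check.

```python
from importlib import metadata
from typing import Optional


def major_version(version: str) -> int:
    """Return the major component of a version string such as '3.5.0'."""
    return int(version.split(".", 1)[0])


def installed_keras_major() -> Optional[int]:
    """Major version of the installed `keras` distribution, or None if absent."""
    try:
        return major_version(metadata.version("keras"))
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    major = installed_keras_major()
    if major is None:
        print("keras is not installed")
    elif major >= 3:
        print("Keras 3: models should be saved to a path ending in .keras")
    else:
        print("Keras 2: the older tf.keras saving behaviour applies")
```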
I'm getting an error in a Flyte task that returns a trained model. I think this is caused by `TensorFlowModelTransformer` still using the Keras 2 style of saving. To make this work with Keras 3, I think the fix would be something like this:
```diff
- local_path = ctx.file_access.get_random_local_path()
+ local_path = ctx.file_access.get_random_local_path() + ".keras"
  pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
- # save model in SavedModel format
+ # save model in the native Keras 3 (.keras) format
- tf.keras.models.save_model(python_val, local_path)
+ keras.saving.save_model(python_val, local_path)
```
This is untested code, BTW. Loading the model would need similar changes, and it might require a new transformer specifically for Keras 3.
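A minimal sketch of both the save and load halves of that change, assuming Keras 3 is installed as the standalone `keras` package. The helper names (`with_keras_suffix`, `save_model_keras3`, `load_model_keras3`) are hypothetical and not part of flytekit; inside a transformer, `base_path` would be the value returned by `ctx.file_access.get_random_local_path()`.

```python
import pathlib

KERAS_SUFFIX = ".keras"  # Keras 3's native format expects this extension


def with_keras_suffix(base_path: str) -> str:
    """Append '.keras' to a local path (if missing) and create its parent dir."""
    path = pathlib.Path(base_path)
    if path.suffix != KERAS_SUFFIX:
        path = path.with_name(path.name + KERAS_SUFFIX)
    path.parent.mkdir(parents=True, exist_ok=True)
    return str(path)


def save_model_keras3(model, base_path: str) -> str:
    """Save `model` in the native .keras format and return the file path."""
    import keras  # imported lazily so the path helper works without Keras installed

    local_path = with_keras_suffix(base_path)
    keras.saving.save_model(model, local_path)
    return local_path


def load_model_keras3(local_path: str):
    """Load a model previously written by save_model_keras3."""
    import keras  # Keras 3 standalone package (assumption)

    return keras.saving.load_model(local_path)
```

A `.keras` file is a single zip archive rather than a SavedModel directory, so any upload or download logic in the transformer that assumes a directory may also need adjusting.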
I didn't see anything related to this in the GitHub issues. I'm wondering if others have had this problem as well.
freezing-airport-6809
high-accountant-32689
09/18/2024, 4:13 PM