acoustic-carpenter-78188
01/06/2023, 10:52 PM
from flytekit import task, workflow
from flytekitplugins.mlflow import mlflow_autolog
import mlflow
import tensorflow as tf
@task(disable_deck=False)
@mlflow_autolog(framework=mlflow.keras)
def train_model(epochs: int):
    """Fit a small dense classifier on the Fashion-MNIST training split.

    The network flattens each 28x28 input, applies a 128-unit ReLU layer and
    ends in a 10-unit layer that emits raw logits (the loss is built with
    ``from_logits=True``). Nothing is returned.

    NOTE(review): the decorators presumably make MLflow autologging output
    visible in the task's Flyte Deck -- confirm against the plugin docs.

    Args:
        epochs: Number of epochs forwarded to ``fit``.
    """
    # Refer to <https://www.tensorflow.org/tutorials/keras/classification>
    (images, labels), _ = tf.keras.datasets.fashion_mnist.load_data()
    # The held-out split is discarded; only the training data is used here.
    images = images / 255.0

    layer_stack = [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(10),
    ]
    classifier = tf.keras.Sequential(layer_stack)

    logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    classifier.compile(
        optimizer="adam",
        loss=logits_loss,
        metrics=["accuracy"],
    )
    classifier.fit(images, labels, epochs=epochs)
@workflow
def ml_pipeline(epochs: int):
    """Single-step workflow that forwards ``epochs`` to ``train_model``.

    Args:
        epochs: Number of training epochs handed to the training task.
    """
    train_model(epochs=epochs)
if __name__ == "__main__":
    # ``ml_pipeline`` declares ``epochs`` without a default, so a bare
    # ``ml_pipeline()`` fails on the missing input. Pass it explicitly as a
    # keyword argument; a small value keeps the local smoke run short.
    ml_pipeline(epochs=5)
Type
☐ Bug Fix
☑︎ Feature
☑︎ Plugin
Are all requirements met?
☑︎ Code completed
☑︎ Smoke tested
☑︎ Unit tests added
☐ Code documentation added
☐ Any pending items have an associated Issue
Complete description
image▾
image▾
image▾
image▾
image▾
image▾
acoustic-carpenter-78188
01/06/2023, 10:52 PM