jolly-florist-66728
11/06/2024, 4:46 PM

task that preprocesses data and returns a `list` of `dict` (serialized using `json.dumps`) as a `FlyteFile`.
* Then we want to use `map_task` to process each dict in the list that's in the `FlyteFile`.
```python
@task
def ingest_data() -> pd.DataFrame:
    return pd.DataFrame(...)  # ...some Data

@task
def preprocess_data(data: pd.DataFrame) -> FlyteFile:
    rows = []
    for _, row in data.iterrows():
        # ... some processing
        rows.append(row.to_dict())
    out_path = Path(flytekit.current_context().working_directory) / "preprocessed.json"
    with out_path.open(mode="w") as output_file:
        json.dump(rows, output_file)
    return FlyteFile(path=str(out_path))

@task
def generate_ai_response(row: ?) -> ?:
    ai_response = do_some_ai_stuff()
    return ai_response

@workflow
def ai_workflow():
    data = ingest_data()
    preprocessed_data = preprocess_data(data)
    predicted_responses: ? = map_task(generate_ai_response)(
        row=preprocessed_data
    )
```
We don't understand how to `map_task` over a `FlyteFile`, or whether that's even possible. Can we deserialise the `preprocessed_data` inside the workflow prior to `map_task`? Or how else would we do this?

Thank you!
# Answer
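To use `map_task` over a `FlyteFile`, you need to deserialize the file content into a list of dictionaries before passing it to `map_task`. This can't be done directly in the workflow body: a Flyte workflow is compiled into a graph ahead of execution, so inside `ai_workflow` the value `preprocessed_data` is a promise for a future task output, not a file you can open. Instead, read the `FlyteFile` in a task and return the deserialized list. Here's a step-by-step approach: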
1. Deserialize the `FlyteFile`: create a task that reads the `FlyteFile` and deserializes it into a list of dictionaries.
2. Map over the list: call `map_task` on the deserialized list.
Here's an example:
```python
import json
from pathlib import Path
from typing import List

import flytekit
import pandas as pd
from flytekit import map_task, task, workflow
from flytekit.types.file import FlyteFile


@task
def ingest_data() -> pd.DataFrame:
    return pd.DataFrame(...)  # ...some Data


@task
def preprocess_data(data: pd.DataFrame) -> FlyteFile:
    rows = []
    for _, row in data.iterrows():
        # ... some processing
        rows.append(row.to_dict())  # one plain dict per row, so json.dump can serialize it
    out_path = Path(flytekit.current_context().working_directory) / "preprocessed.json"
    with out_path.open(mode="w") as output_file:
        json.dump(rows, output_file)
    return FlyteFile(path=str(out_path))


@task
def deserialize_file(file: FlyteFile) -> List[dict]:
    # Runs inside a task, so the file is real here: read it and parse the JSON list
    with file.open("r") as f:
        return json.load(f)


@task
def generate_ai_response(row: dict) -> str:
    # do_some_ai_stuff is a placeholder for your own inference logic
    ai_response = do_some_ai_stuff(row)
    return ai_response


@workflow
def ai_workflow() -> List[str]:
    data = ingest_data()
    preprocessed_data = preprocess_data(data=data)
    deserialized_data = deserialize_file(file=preprocessed_data)
    # map_task runs generate_ai_response once per dict and collects the results in a list
    predicted_responses = map_task(generate_ai_response)(row=deserialized_data)
    return predicted_responses
```
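To see the same data flow without a Flyte cluster, here is a minimal, framework-free sketch in plain Python. It assumes the per-row work is an ordinary function; `write_rows`, `read_rows`, and `fake_ai_response` are hypothetical stand-ins for `preprocess_data`, `deserialize_file`, and `generate_ai_response`, and the list comprehension plays the role of `map_task` by applying one function to each dict independently.

```python
import json
import tempfile
from pathlib import Path
from typing import Callable, Dict, List


def write_rows(rows: List[Dict], directory: Path) -> Path:
    """Hypothetical stand-in for preprocess_data: serialize the rows to a JSON file."""
    out_path = directory / "preprocessed.json"
    with out_path.open(mode="w") as output_file:
        json.dump(rows, output_file)
    return out_path


def read_rows(path: Path) -> List[Dict]:
    """Hypothetical stand-in for deserialize_file: load the JSON file back into a list of dicts."""
    with path.open("r") as f:
        return json.load(f)


def fake_ai_response(row: Dict) -> str:
    """Hypothetical stand-in for generate_ai_response (no real model call)."""
    return f"{row['name']}:{row['score'] * 2}"


def run_pipeline(rows: List[Dict], fn: Callable[[Dict], str]) -> List[str]:
    with tempfile.TemporaryDirectory() as tmp:
        path = write_rows(rows, Path(tmp))
        loaded = read_rows(path)  # the file must become a list before we can map over it
    # Like map_task: one independent call per element, results in input order
    return [fn(row) for row in loaded]


if __name__ == "__main__":
    rows = [{"name": "a", "score": 1}, {"name": "b", "score": 2}]
    print(run_pipeline(rows, fake_ai_response))  # ['a:2', 'b:4']
```

Each call sees only its own dict, which mirrors how `map_task` in Flyte runs every list element as a separate task instance; that is why the input to `map_task` has to be a list rather than a single file.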
Slack Link: https://.slack.com/archives/CP2HDHKE1/p1730911589.393329