#ask-the-community

Khalil Almazaideh

05/19/2023, 11:28 AM
Hello all, I'm trying to return machine learning models from a task so that I can find their hyperparameters. The functions in that task build the models, and the models should then be passed to another task that finds the hyperparameters. The issue is that the models are passed as a Tuple, so the hyperparameter task has to expect a tuple of models, yet it fails with: `Transformer for type <class 'tuple'> is restricted currently`. The piece of code is as follows:
```python
from typing import Any, Dict, List, Tuple

import pandas as pd
from flytekit import task, workflow
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression


@task
def build_models() -> Tuple[LinearRegression, RandomForestRegressor, Any]:
    lr = LinearRegression()
    rf = RandomForestRegressor()
    ann = HyperModel1()
    return lr, rf, ann


@task
def pipline_(models: Tuple[LinearRegression, RandomForestRegressor, Any], (.....etc)) -> List[Dict]:
    def itirate(models):
        for i in models:
            # Do STUFF
            ...
```

```python
@workflow
def wf() -> pd.DataFrame:
    ....
    models = build_models()
    s_summary = HyperSearch(models=models, ....)
    return s_summary
```

Evan Sadler

05/19/2023, 12:11 PM
Hi Khalil! Tuple is restricted to only be an output, so you can't pass a tuple from `build_models` into `HyperSearch`. You can pass a List to a map task instead, and it will run in parallel! https://docs.flyte.org/projects/cookbook/en/latest/auto/core/control_flow/map_task.html

Note that outputs from Flyte tasks are serialized and saved in Flyte's metadata store, so initializing and returning untrained models will pickle them, save them, and reload them. If you can, use a `List[str]` of class paths such as `sklearn.linear_model.LinearRegression` that the downstream task imports with `importlib`. Or you could use a dataclass with more fields, so the training task would have all the info it needs to run. That way only the trained model would be pickled!

```
Inputs(model=str, hyperparams=dict, …)
```
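A minimal, stdlib-only sketch of the `List[str]` idea: each string is a dotted class path that the downstream task resolves with `importlib`. The helper names `load_class` and `build_instances` are hypothetical, and standard-library classes (`collections.OrderedDict`, `fractions.Fraction`) stand in for the sklearn estimators so the sketch runs without sklearn installed; in a real Flyte task you would pass paths such as `sklearn.linear_model.LinearRegression` instead.

```python
import importlib
from typing import Any, List


def load_class(path: str) -> type:
    """Resolve a dotted path like 'package.module.ClassName' to the class object (hypothetical helper)."""
    module_name, class_name = path.rsplit(".", 1)  # split module path from class name
    module = importlib.import_module(module_name)  # import the module by name
    return getattr(module, class_name)  # look up the class on the module


def build_instances(paths: List[str]) -> List[Any]:
    # Only plain strings cross the task boundary; objects are created where they are used.
    return [load_class(p)() for p in paths]


# Stand-ins for sklearn class paths, so this runs with the standard library alone.
model_paths = ["collections.OrderedDict", "fractions.Fraction"]
instances = build_instances(model_paths)
print([type(obj).__name__ for obj in instances])  # ['OrderedDict', 'Fraction']
```

Because only short strings are passed between tasks, nothing needs to be pickled until a model has actually been trained.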

Khalil Almazaideh

05/20/2023, 5:36 PM
Really appreciate it! Could you provide a simple code example that implements what you described? I'm not sure I understood it correctly.

Evan Sadler

05/20/2023, 6:07 PM
Something like this?
```python
import importlib
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple

import pandas as pd
import sklearn.datasets
from dataclasses_json import dataclass_json
from flytekit import map_task, task, workflow


@dataclass_json
@dataclass
class ModelConfig:
    model: str  # "sklearn.linear_model.LinearRegression"
    hyperparams: dict

    def load_model(self):
        module_name, class_name = self.model.rsplit(
            ".", 1
        )  # split the module and class
        module = importlib.import_module(module_name)  # import the module
        model_class = getattr(module, class_name)  # get the class from the module
        return model_class(**self.hyperparams)


@task
def build_configs() -> List[ModelConfig]:
    lr_config = ModelConfig(
        model="sklearn.linear_model.LinearRegression", hyperparams={}
    )

    rf_config = ModelConfig(
        model="sklearn.ensemble.RandomForestRegressor",
        hyperparams={"max_features": "sqrt"},
    )
    return [lr_config, rf_config]


@task
def find_hyperparms(config: ModelConfig, X: pd.DataFrame, y: pd.DataFrame):
    model = config.load_model()
    model.fit(X, y.target)

    print(model.predict(X).mean())


@task
def load_data() -> Tuple[pd.DataFrame, pd.DataFrame]:
    X, y = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
    return X, pd.DataFrame(y)


@workflow
def wf():
    X, y = load_data()
    configs = build_configs()
    func = partial(find_hyperparms, X=X, y=y)

    # runs in parallel
    map_task(func)(config=configs)


wf()
```
It seems like the `dict` type transformer is broken for int types: ints get recast as floats. I will ping the OSS team and create an issue.

So it isn't broken, but there is a limitation in Google's protobuf library, which doesn't distinguish between int and float (every number in a protobuf `Struct` is stored as a double). Storing the hyperparameters as a JSON string is kind of hacky... but it works:
```python
import importlib
import json
from dataclasses import dataclass
from functools import partial
from typing import List, Tuple

import pandas as pd
import sklearn.datasets
from dataclasses_json import dataclass_json
from flytekit import map_task, task, workflow


@dataclass_json
@dataclass
class ModelConfig:
    model: str  # "sklearn.linear_model.LinearRegression"
    hyperparams: str  # JSON-encoded dict, so ints are not recast as floats

    def load_model(self):
        module_name, class_name = self.model.rsplit(".", 1)  # split the module and class
        module = importlib.import_module(module_name)  # import the module
        model_class = getattr(module, class_name)  # get the class from the module
        return model_class(**json.loads(self.hyperparams))  # decode the hyperparameters


@task
def build_configs() -> List[ModelConfig]:
    lr_config = ModelConfig(
        model="sklearn.linear_model.LinearRegression",
        hyperparams=json.dumps({}),
    )

    rf_config = ModelConfig(
        model="sklearn.ensemble.RandomForestRegressor",
        hyperparams=json.dumps({"max_features": "sqrt", "n_estimators": 10}),
    )
    return [lr_config, rf_config]


@task
def find_hyperparms(config: ModelConfig, X: pd.DataFrame, y: pd.DataFrame):
    model = config.load_model()
    model.fit(X, y.target)

    print(model.predict(X).mean())


@task
def load_data() -> Tuple[pd.DataFrame, pd.DataFrame]:
    X, y = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
    return X, pd.DataFrame(y)


@workflow
def wf():
    X, y = load_data()
    configs = build_configs()
    func = partial(find_hyperparms, X=X, y=y)

    # runs in parallel
    map_task(func)(config=configs)


wf()
```
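A small stdlib sketch of why the JSON-string workaround helps with the int-to-float issue described above. `simulate_struct_roundtrip` is a hypothetical stand-in that mimics numbers coming back as doubles; it is not a Flyte or protobuf API. A JSON string, by contrast, travels as an opaque `str`, and `json.loads` restores ints as ints.

```python
import json
from typing import Any, Dict


def simulate_struct_roundtrip(params: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical stand-in: every number becomes a float (a double),
    # while strings and bools are left unchanged.
    return {
        k: float(v) if isinstance(v, (int, float)) and not isinstance(v, bool) else v
        for k, v in params.items()
    }


hyperparams = {"max_features": "sqrt", "n_estimators": 10}

# Passing the dict directly: n_estimators comes back as 10.0 (a float)
via_dict = simulate_struct_roundtrip(hyperparams)
print(via_dict["n_estimators"], type(via_dict["n_estimators"]).__name__)  # 10.0 float

# Passing a JSON string: the str is unchanged, and json.loads restores the int
encoded = json.dumps(hyperparams)
via_json = json.loads(encoded)
print(via_json["n_estimators"], type(via_json["n_estimators"]).__name__)  # 10 int
```

The trade-off is that the hyperparameters lose their structure in the Flyte UI (they show up as a single string), in exchange for keeping integer-only parameters like `n_estimators` intact.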

Khalil Almazaideh

05/22/2023, 9:11 PM
Thank you! But I have a question: did you use the `partial` function to provide more than one parameter to the `find_hyperparms` function, which is then mapped?

Evan Sadler

05/22/2023, 9:12 PM
Yeah, you need to do that. It's a limitation of map tasks: inputs that should stay the same for every mapped run have to be bound with `partial`.
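The idea behind the `partial` call can be sketched locally with the standard library. Here the built-in `map` stands in for Flyte's `map_task`, and `fit_one` plus the toy data are hypothetical placeholders for `find_hyperparms`, `X`, and `y`: `partial` binds the inputs that stay the same for every run, leaving a function of the single mapped argument.

```python
from functools import partial
from typing import Dict, List


def fit_one(config: Dict[str, str], X: List[float], y: List[float]) -> str:
    # Placeholder for a training task: report which model ran on how many rows.
    return f"{config['model']} on {len(X)} rows (targets: {len(y)})"


X = [0.1, 0.2, 0.3]
y = [1.0, 2.0, 3.0]
configs = [
    {"model": "sklearn.linear_model.LinearRegression"},
    {"model": "sklearn.ensemble.RandomForestRegressor"},
]

# Bind the shared inputs; the resulting callable only needs `config`.
func = partial(fit_one, X=X, y=y)

# Local, sequential analogue of map_task(func)(config=configs): one call per config.
results = list(map(func, configs))
for line in results:
    print(line)
```

In Flyte the mapped calls run as parallel task executions rather than in a local loop, but the binding works the same way.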

Khalil Almazaideh

05/22/2023, 9:14 PM
Really appreciate your help man

Evan Sadler

05/22/2023, 11:43 PM
Anytime!!!