jolly-easter-30455
05/19/2023, 11:28 AM
@task
def build_models() -> Tuple[LinearRegression, RandomForestRegressor, Any]:
    """Instantiate the three candidate models with their default arguments.

    Returns:
        A ``(LinearRegression, RandomForestRegressor, HyperModel1)`` tuple,
        in that order.
    """
    return LinearRegression(), RandomForestRegressor(), HyperModel1()
@task
def pipline_(models: Tuple[LinearRegression, RandomForestRegressor, Any], (.....etc)) -> List[Dict]:
def itirate(models):
for i in models:
# Do STUFF
def wf() -> pd.DataFrame:
....
....
models = build_models()
s_summary = HyperSearch(models=models, ....)
return s_summary
billowy-winter-86593
05/19/2023, 12:11 PMList[str]
where you use importlib to import the class by its dotted path, e.g. `sklearn.linear_model.LinearRegression`.
Or you could have a dataclass with more fields; then the training task would have all the info it needs to run. You would only pickle the trained model!
Inputs(model=str, hyperparams=dict, …)
jolly-easter-30455
05/20/2023, 5:36 PMbillowy-winter-86593
05/20/2023, 6:07 PM
import importlib
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple
import pandas as pd
import sklearn.datasets
from dataclasses_json import dataclass_json
from flytekit import map_task, task, workflow
@dataclass_json
@dataclass
class ModelConfig:
    """Description of a model as an import path plus constructor arguments.

    The model itself is not stored; ``load_model`` imports the class named by
    ``model`` and instantiates it with ``hyperparams``.
    """

    model: str  # dotted import path, e.g. "sklearn.linear_model.LinearRegression"
    hyperparams: dict  # keyword arguments forwarded to the model constructor

    def load_model(self):
        """Import the class named by ``model`` and instantiate it.

        Returns:
            A new instance of the class, built with ``hyperparams`` as
            keyword arguments (no arguments when ``hyperparams`` is empty
            or ``None``).

        Raises:
            ValueError: If ``model`` has no module part before the class name.
            ImportError: If the module cannot be imported.
            AttributeError: If the module does not define the class.
        """
        # Split on the last dot only: everything before it is the module.
        module_name, _, class_name = self.model.rpartition(".")
        if not module_name:
            # Without this check the failure is an opaque unpacking or
            # "Empty module name" error that does not mention the bad value.
            raise ValueError(
                f"model must be a dotted path like 'package.module.ClassName', "
                f"got {self.model!r}"
            )
        module = importlib.import_module(module_name)
        model_class = getattr(module, class_name)
        # Tolerate a missing hyperparameter mapping instead of failing on **None.
        return model_class(**(self.hyperparams or {}))
@task
def build_configs() -> List[ModelConfig]:
    """Return the model configurations to evaluate.

    Returns:
        Two configs, in order: a ``LinearRegression`` with no hyperparameters
        and a ``RandomForestRegressor`` with ``max_features="sqrt"``.
    """
    return [
        ModelConfig(model="sklearn.linear_model.LinearRegression", hyperparams={}),
        ModelConfig(
            model="sklearn.ensemble.RandomForestRegressor",
            hyperparams={"max_features": "sqrt"},
        ),
    ]
@task
def find_hyperparms(config: ModelConfig, X: pd.DataFrame, y: pd.DataFrame):
    """Fit the configured model and print its mean prediction on ``X``.

    Args:
        config: Configuration whose ``load_model()`` yields the estimator.
        X: Feature frame passed to ``fit`` and ``predict``.
        y: Frame whose ``target`` column is used as the fit target.
    """
    estimator = config.load_model()
    estimator.fit(X, y.target)
    mean_prediction = estimator.predict(X).mean()
    print(mean_prediction)
@task
def load_data() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the scikit-learn diabetes dataset as two data frames.

    Returns:
        ``(features, target)`` where the target returned by
        ``load_diabetes`` is wrapped in a ``DataFrame``.
    """
    features, target = sklearn.datasets.load_diabetes(
        return_X_y=True, as_frame=True
    )
    target_frame = pd.DataFrame(target)
    return features, target_frame
@workflow
def wf():
    """Load the data, build the model configs, and map the fit task over them."""
    X, y = load_data()
    model_configs = build_configs()
    # Bind the shared data so that only ``config`` varies per mapped call.
    fit_one = partial(find_hyperparms, X=X, y=y)
    # runs in parallel
    mapped_fit = map_task(fit_one)
    mapped_fit(config=model_configs)
# Module-level call: runs the workflow whenever this script is executed.
wf()
billowy-winter-86593
05/20/2023, 6:07 PMdict
type transformer is broken for int types: ints get recast as floats. I will ping the OSS team and create an issue.
billowy-winter-86593
05/20/2023, 6:37 PM
import importlib
import importlib
import json
from dataclasses import dataclass
from functools import partial
from typing import Optional, List, Tuple

import numpy as np
import pandas as pd
import sklearn.datasets
from dataclasses_json import dataclass_json
from flytekit import task, workflow, map_task
@dataclass_json
@dataclass
class ModelConfig:
    """Description of a model as an import path plus JSON-encoded arguments.

    The model itself is not stored; ``load_model`` imports the class named by
    ``model`` and instantiates it with the decoded ``hyperparams``.
    """

    model: str  # dotted import path, e.g. "sklearn.linear_model.LinearRegression"
    # JSON object of constructor keyword arguments.
    # NOTE(review): kept as a JSON string rather than a dict because the dict
    # type transformer was reported to recast ints as floats; confirm before
    # switching back.
    hyperparams: str

    def load_model(self):
        """Import the class named by ``model`` and instantiate it.

        Returns:
            A new instance of the class, built with the decoded
            ``hyperparams`` as keyword arguments (no arguments when
            ``hyperparams`` is empty).

        Raises:
            ValueError: If ``model`` has no module part before the class
                name, or ``hyperparams`` is not valid JSON.
            TypeError: If ``hyperparams`` does not decode to a JSON object.
            ImportError: If the module cannot be imported.
            AttributeError: If the module does not define the class.
        """
        # Split on the last dot only: everything before it is the module.
        module_name, _, class_name = self.model.rpartition(".")
        if not module_name:
            # Without this check the failure is an opaque unpacking or
            # "Empty module name" error that does not mention the bad value.
            raise ValueError(
                f"model must be a dotted path like 'package.module.ClassName', "
                f"got {self.model!r}"
            )
        module = importlib.import_module(module_name)
        model_class = getattr(module, class_name)
        # An empty string means "no hyperparameters" rather than a decode error.
        kwargs = json.loads(self.hyperparams) if self.hyperparams else {}
        if not isinstance(kwargs, dict):
            raise TypeError(
                f"hyperparams must decode to a JSON object, "
                f"got {type(kwargs).__name__}"
            )
        return model_class(**kwargs)
@task
def build_configs() -> List[ModelConfig]:
    """Return the model configurations to evaluate.

    Hyperparameters are passed as JSON strings.

    Returns:
        Two configs, in order: a ``LinearRegression`` with no hyperparameters
        and a ``RandomForestRegressor`` with ``max_features="sqrt"`` and
        ``n_estimators=10``.
    """
    linear_params = json.dumps({})
    forest_params = json.dumps({"max_features": "sqrt", "n_estimators": 10})
    return [
        ModelConfig(
            model="sklearn.linear_model.LinearRegression",
            hyperparams=linear_params,
        ),
        ModelConfig(
            model="sklearn.ensemble.RandomForestRegressor",
            hyperparams=forest_params,
        ),
    ]
@task
def find_hyperparms(config: ModelConfig, X: pd.DataFrame, y: pd.DataFrame):
    """Fit the configured model and print its mean prediction on ``X``.

    Args:
        config: Configuration whose ``load_model()`` yields the estimator.
        X: Feature frame passed to ``fit`` and ``predict``.
        y: Frame whose ``target`` column is used as the fit target.
    """
    estimator = config.load_model()
    estimator.fit(X, y.target)
    predictions = estimator.predict(X)
    print(predictions.mean())
@task
def load_data() -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the scikit-learn diabetes dataset as two data frames.

    Returns:
        ``(features, target)`` where the target returned by
        ``load_diabetes`` is wrapped in a ``DataFrame``.
    """
    dataset = sklearn.datasets.load_diabetes(return_X_y=True, as_frame=True)
    features, target = dataset
    return features, pd.DataFrame(target)
@workflow
def wf():
    """Load the data, build the model configs, and map the fit task over them."""
    X, y = load_data()
    model_configs = build_configs()
    # Bind the shared data so that only ``config`` varies per mapped call.
    fit_one = partial(find_hyperparms, X=X, y=y)
    # runs in parallel
    mapped_fit = map_task(fit_one)
    mapped_fit(config=model_configs)
# Module-level call: runs the workflow whenever this script is executed.
wf()
jolly-easter-30455
05/22/2023, 9:11 PMbillowy-winter-86593
05/22/2023, 9:12 PMjolly-easter-30455
05/22/2023, 9:14 PMbillowy-winter-86593
05/22/2023, 11:43 PM