Frank Shen
03/17/2023, 6:12 PM
Kevin Su
03/17/2023, 6:15 PM
Frank Shen
# 03/17/2023, 6:15 PM
def __init__(
    self,
    bucket: str,
    dataset_type: typing.Type[StructuredDataset],
    config: Optional[XGBParams] = None,
    **kwargs,
):
    """
    A task that trains an XGBoost model.

    Args:
        bucket: Stored on the instance as ``self._bucket``; not used in this
            constructor (presumably the storage bucket for the trained model
            -- TODO confirm).
        dataset_type: Type of the training and validation datasets (a
            ``StructuredDataset`` type); used as the declared type of both
            dataset inputs.
        config: Configuration for the task, forwarded as ``task_config``.
            Contains the params used in the model training.
        **kwargs: Forwarded to the parent task constructor. A caller-supplied
            ``requests`` replaces the default ``Resources(cpu="4")``.

    Task inputs (keyed by the class constants):
        _TEAM, _PROJECT, _VERSION (str); _TRAIN_ARG and _VALIDATION_ARG
        (``dataset_type``); _PARAMS_ARG (XGBParams); _LABEL_COL (str, name of
        the label or target column); _USE_RAY (bool); _NUM_CPUS (int).

    Task outputs:
        _OUTPUT_MODEL (model_loc): The trained model's location as a string.
        _OUTPUT_EVAL_RESULT (evaluation_result): The evaluation result
            against the validation dataset.
    """
    self._config = config
    self._dataset_type = dataset_type
    self._bucket = bucket
    inputs = {
        self._TEAM: str,
        self._PROJECT: str,
        self._VERSION: str,
        self._TRAIN_ARG: dataset_type,
        self._VALIDATION_ARG: dataset_type,
        self._PARAMS_ARG: XGBParams,
        self._LABEL_COL: str,
        self._USE_RAY: bool,
        self._NUM_CPUS: int,
    }
    outputs = {
        self._OUTPUT_MODEL: str,
        self._OUTPUT_EVAL_RESULT: Dict[str, Dict[str, List[float]]],
    }
    # Default to 4 CPUs but let callers override it: passing ``requests``
    # both explicitly and via **kwargs raised "multiple values" TypeError.
    kwargs.setdefault("requests", Resources(cpu="4"))
    super().__init__(
        name=f"{self._TASK_TYPE}",
        task_type=self._TASK_TYPE,
        task_config=config,
        interface=Interface(inputs=inputs, outputs=outputs),
        **kwargs,
    )
Kevin Su
03/17/2023, 6:24 PM
Frank Shen
03/17/2023, 6:37 PM
Kevin Su
# 03/17/2023, 8:03 PM
def subwf():
    """Invoke ``t1`` and then ``t2``; returns ``None``.

    NOTE(review): presumably meant to be a flytekit ``@workflow`` (decorator
    not shown in the thread) so that callers can apply overrides to this
    call -- confirm.
    """
    t1()
    t2()
def wf():
    """Call ``subwf`` and apply node-level overrides to that call.

    NOTE(review): the snippet originally called ``with_override()``, which is
    not a flytekit method; the node-override API is ``with_overrides(...)``.
    ``subwf`` has no return statement, so unless it is decorated as a
    flytekit ``@workflow`` (not shown here) the call returns ``None`` and
    the attribute access fails -- confirm the decorators in real code.
    """
    # No overrides are supplied here; real usage passes keyword arguments,
    # e.g. ``requests=Resources(cpu="4")`` -- TODO confirm intended values.
    subwf().with_overrides()
Frank Shen
03/17/2023, 8:55 PM
Kevin Su
03/20/2023, 11:29 PM
Slackbot
03/20/2023, 11:29 PM
Frank Shen
03/23/2023, 7:32 PM
Kevin Su
03/23/2023, 7:34 PM
Frank Shen
03/23/2023, 7:35 PM