Hi, I need help writing the Databricks workflow; the following code is working fine:
from dataclasses import dataclass
from typing import Dict, Optional, Type
from flytekit.configuration import SerializationSettings
from flytekit.extend import SQLTask
from flytekit.models import task as _task_model
from flytekit.types.schema import FlyteSchema
# Keys of the config dictionary returned by DatabricksTask.get_config; each
# one maps to the DatabricksConfig attribute of the same name.
_SERVER_HOSTNAME_FIELD = "server_hostname"
_HTTP_PATH = "http_path"
_ACCESS_TOKEN = "access_token"
_WAREHOUSE_FIELD = "warehouse"
@dataclass
class DatabricksConfig:
    """Connection settings attached to a Databricks SQL task.

    Every field is optional and defaults to ``None``, so a bare
    ``DatabricksConfig()`` is valid. Field order is part of the interface
    because the generated ``__init__`` accepts the values positionally.
    """

    server_hostname: Optional[str] = None
    http_path: Optional[str] = None
    # NOTE(review): this looks like a credential, and the dataclass-generated
    # repr will print it in clear text. Consider sourcing it from a secret
    # store instead of passing it around in config -- confirm with the owner.
    access_token: Optional[str] = None
    warehouse: Optional[str] = None
class DatabricksTask(SQLTask[DatabricksConfig]):
    """SQL task that pairs a query template with Databricks connection settings.

    The task itself never talks to Databricks: it only exposes the query (via
    ``get_sql``) and the connection settings (via ``get_config``) as plain data.
    """

    # Task type string handed to the SQLTask base class.
    _TASK_TYPE = "databricks"

    def __init__(
        self,
        name: str,
        query_template: str,
        task_config: Optional[DatabricksConfig] = None,
        inputs: Optional[Dict[str, Type]] = None,
        output_schema_type: Optional[Type[FlyteSchema]] = None,
        **kwargs,
    ):
        """Create the task.

        Args:
            name: Name forwarded to the base task.
            query_template: SQL text forwarded to the base task and later
                returned by ``get_sql``.
            task_config: Connection settings; an all-``None``
                ``DatabricksConfig`` is used when omitted.
            inputs: Mapping of input names to Python types, forwarded as-is.
            output_schema_type: Type registered for the single ``"results"``
                output. It is stored on the instance as well.
            **kwargs: Passed through unchanged to ``SQLTask.__init__``.
        """
        # The task always declares exactly one output, named "results".
        outputs = {
            "results": output_schema_type,
        }
        if task_config is None:
            task_config = DatabricksConfig()
        super().__init__(
            name=name,
            task_config=task_config,
            query_template=query_template,
            inputs=inputs,
            outputs=outputs,
            task_type=self._TASK_TYPE,
            **kwargs,
        )
        self._output_schema_type = output_schema_type

    def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
        """Return the connection settings as a flat dictionary.

        Args:
            settings: Serialization settings; not used by this method.

        Returns:
            A dictionary with the four connection keys, each mapped to the
            matching ``task_config`` attribute.
        """
        # NOTE(review): unset fields are emitted as None even though the return
        # annotation promises str values -- confirm the consumer tolerates that.
        # NOTE(review): the access token is returned in clear text here;
        # verify this dictionary is never logged or persisted.
        return {
            _SERVER_HOSTNAME_FIELD: self.task_config.server_hostname,
            _HTTP_PATH: self.task_config.http_path,
            _ACCESS_TOKEN: self.task_config.access_token,
            _WAREHOUSE_FIELD: self.task_config.warehouse,
        }

    def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]:
        """Return the query template wrapped in a ``Sql`` model (ANSI dialect).

        Args:
            settings: Serialization settings; not used by this method.

        Returns:
            A ``Sql`` model whose statement is ``self.query_template``.
        """
        # Deliberately no database connection here: this method only describes
        # the query. Opening one with `sql.connect(...)` in this scope would
        # also fail, because binding a local named `sql` makes the earlier
        # `sql.connect` lookup raise UnboundLocalError.
        return _task_model.Sql(
            statement=self.query_template,
            dialect=_task_model.Sql.Dialect.ANSI,
        )