curved-easter-24577
02/26/2024, 3:31 PM
[...] but fails because it does not find a concrete file in the directory called config.json. In the remote S3 folder we can find the config.json. Do we have to wait to download all the files or anything similar?
Thank you!
curved-easter-24577
02/26/2024, 3:32 PM
@task(requests=Resources(cpu="1", mem="2Gi"))
def data_step() -> Tuple[FlyteDirectory, FlyteDirectory, Dict, str, FlyteDirectory]:
    """Queries the two splits of the dataset and tokenizes it for training.

    Args:
    Returns:
        Tuple[FlyteDirectory, FlyteDirectory, Dict, str, FlyteDirectory]
    """
    model_name = "distilbert-base-multilingual-cased"
    df_train, df_val = query.main()
    train_dataset, val_dataset, cat_index, tokenizer = tokenize.main(df_train, df_val, model_name=model_name)
    train_dir = save_dataset(train_dataset, "train")
    val_dir = save_dataset(val_dataset, "val")
    config = AutoConfig.from_pretrained(model_name)
    tokenizer_dir = save_tokenizer(tokenizer, config)
    return train_dir, val_dir, cat_index, model_name, tokenizer_dir

def save_tokenizer(tokenizer: AutoTokenizer, config: AutoConfig) -> FlyteDirectory:
    """Saves a tokenizer and its model config to disk and returns the directory.

    Args:
        tokenizer (transformers.AutoTokenizer): tokenizer to save.
        config (transformers.AutoConfig): configuration for model to save.
    Returns:
        FlyteDirectory: directory containing the saved tokenizer and config files.
    """
    ctx = flytekit.current_context()
    tokenizer_dir = os.path.join(ctx.working_directory, "tokenizer")
    os.makedirs(tokenizer_dir, exist_ok=True)
    # Both calls write into the same directory, so config.json ends up next to the tokenizer files.
    tokenizer.save_pretrained(tokenizer_dir)
    config.save_pretrained(tokenizer_dir)
    LOGGER.debug("Tokenizer saved at %s", tokenizer_dir)
    return FlyteDirectory(path=tokenizer_dir)

@task(requests=Resources(cpu="1", gpu="1", mem="2Gi", ephemeral_storage="20Gi"))
def train_step(
    train_dataset_dir: FlyteDirectory, val_dataset_dir: FlyteDirectory, model_name: str, tokenizer_dir: FlyteDirectory
) -> List[Any]:
    """Trains the model with the given data, and evaluates it on train-distributed validation data.

    Args:
        train_dataset_dir (FlyteDirectory): directory of tokenized train split data.
        val_dataset_dir (FlyteDirectory): directory of tokenized validation split data.
        model_name (str): name of pretrained model to finetune.
        tokenizer_dir (FlyteDirectory): directory of model tokenizer.
    Returns:
        Tuple[FlyteDirectory, Dict]
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, local_files_only=True)  # does not find config.json
curved-easter-24577
02/26/2024, 4:31 PM
freezing-airport-6809
curved-easter-24577
02/26/2024, 4:35 PM
freezing-airport-6809
glamorous-carpet-83516
02/26/2024, 10:40 PM
"but fails because it does not find a concrete file in the directory called config.json. In the remote s3 folder we can find the config.json"
did you download the file? like FlyteDirectory.download()
glamorous-carpet-83516
02/26/2024, 10:41 PM
freezing-airport-6809 download button 🙂
curved-easter-24577
02/27/2024, 9:45 AM
FlyteDirectory.download() as you say and it worked perfectly, thank you for your time ✨
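For anyone landing on this thread later, here is a minimal sketch of the fix that was suggested, not the poster's actual code. It assumes that the directory object exposes a download() method returning the local path, as FlyteDirectory.download() does. The names Downloadable and local_dir_with_config are hypothetical helpers made up for illustration; they are not part of flytekit or transformers.

```python
import os
from typing import Protocol


class Downloadable(Protocol):
    """Hypothetical stand-in for any object with a FlyteDirectory-style download()."""

    def download(self) -> str: ...


def local_dir_with_config(directory: Downloadable, required: str = "config.json") -> str:
    """Force the download, then check that the required file exists locally."""
    # Assumption: download() copies the remote files to local disk and
    # returns the local directory path.
    local_path = directory.download()
    if not os.path.isfile(os.path.join(local_path, required)):
        raise FileNotFoundError(f"{required} not found in {local_path}")
    return local_path


# Inside train_step this would be used roughly as:
#   local_path = local_dir_with_config(tokenizer_dir)
#   tokenizer = AutoTokenizer.from_pretrained(local_path, local_files_only=True)
```

The explicit file check is optional; its only purpose is to fail early with a clear message if config.json is still missing after the download, rather than deep inside from_pretrained.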