[ray::ImplicitFunc.train()] No module named 'evaluate_modules' when using `evaluate` metrics in training script ray.tune
jamnicki opened this issue
What happened + What you expected to happen
I was following this tutorial, customizing it for my needs. After adding evaluation metrics to the training script, it raises the following error:
2024-05-05 14:34:09,577 ERROR tune_controller.py:1331 -- Trial task failed for trial train_model_3cb2d_00009
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2623, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 861, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RaySystemError): ray::ImplicitFunc.train() (pid=13793, ip=172.28.0.12, actor_id=65ba4ddb7f85fd31d27ad23301000000, repr=train_model)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "/usr/local/lib/python3.10/dist-packages/ray/air/_internal/util.py", line 98, in run
    self._ret = self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/function_trainable.py", line 45, in <lambda>
    training_func=lambda: self._trainable_func(self.config),
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/function_trainable.py", line 248, in _trainable_func
    output = fn()
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/util.py", line 129, in inner
    fn_kwargs[k] = parameter_registry.get(prefix + k)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/registry.py", line 300, in get
    return ray.get(self.references[k])
ray.exceptions.RaySystemError: System error: No module named 'evaluate_modules'
traceback: Traceback (most recent call last):
ModuleNotFoundError: No module named 'evaluate_modules'
Related issue: huggingface/transformers#22408
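From the traceback, the error surfaces while the trainable fetches the parameters stored by `tune.with_parameters` (the `ray/tune/registry.py` frame), so the metric objects themselves seem to be the problem: `evaluate.load()` builds a dynamic `evaluate_modules` package on the driver, and that package does not exist on the worker that deserializes them. A minimal sketch (hypothetical, not the exact code I ran) of what I believe hits the same path:

```python
import ray
import evaluate

ray.init()

# `evaluate.load` generates a dynamic `evaluate_modules` package in this
# (driver) process; the returned metric object keeps references into it.
metric = evaluate.load("meteor")

@ray.remote
def score(m):
    # Unpickling `m` in the worker process requires `evaluate_modules` to be
    # importable there, which is where the ModuleNotFoundError surfaces.
    return m.compute(predictions=["a test"], references=["a test"])

print(ray.get(score.remote(metric)))
```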
Versions / Dependencies
Environment: Google Colab (T4 GPU)
ray: 2.20.0
evaluate: 0.4.2
transformers: 4.40.1
torch: 2.2.1+cu121
Reproduction script
# Imports shown for completeness; Transformer, train_step, eval_step,
# pad_after_eos, the datasets, and the upper-case constants (DROP_PROB,
# DEVICE, MAX_SEQUENCE_LENGTH, VOCAB_SIZE, EPOCHS, EXP_TUNE_RESULTS_DIR)
# are defined elsewhere in the notebook.
import tempfile
from pathlib import Path

import evaluate
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

import ray
import ray.cloudpickle as cloudpickle  # the standalone `cloudpickle` package also works
from ray.train import Checkpoint, get_checkpoint
from ray.tune.schedulers import ASHAScheduler


def train_model(config,
                device=None,
                train_ds=None,
                val_ds=None,
                meteor_model=None,
                bertscore_model=None,
                ter_model=None,
                sacrebleu=None,
                max_sequence_length=None,
                vocab_size=None,
                epochs=None,
                ):
    transformer = Transformer(
        d_model=config["d_model"],
        ffn_hidden=config["ffn_hidden"],
        num_heads=config["num_heads"],
        drop_prob=DROP_PROB,
        num_layers=config["num_layers"],
        max_sequence_length=max_sequence_length,
        vocab_size=vocab_size,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained("allegro/herbert-base-cased")
    tokenizer.pad_token_id = 1
    tokenizer.bos_token_id = 0
    tokenizer.eos_token_id = 2

    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction="none")
    optim = torch.optim.Adam(transformer.parameters(), lr=config["lr"])

    checkpoint = get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            with open(data_path, "rb") as fp:
                checkpoint_state = cloudpickle.load(fp)
            start_epoch = checkpoint_state["epoch"]
            transformer.load_state_dict(checkpoint_state["net_state_dict"])
            optim.load_state_dict(checkpoint_state["optimizer_state_dict"])
    else:
        start_epoch = 0

    train_dl = DataLoader(train_ds, batch_size=config["batch_size"], pin_memory=True, shuffle=True)
    val_dl = DataLoader(val_ds, batch_size=config["batch_size"], pin_memory=True, shuffle=True)
    n_train_batches = len(train_dl)
    n_val_batches = len(val_dl)

    for epoch in range(start_epoch, epochs):
        epoch_history = {
            "Epoch": epoch + 1,
            "TrainLoss": 0,
            "ValLoss": 0,
            "ValMETEOR": 0,
            "ValBERTScoreAvgF1": 0,
            "ValBERTScoreAvgPrecision": 0,
            "ValBERTScoreAvgRecall": 0,
            "ValTER": 0,
            "ValBLEU": 0,
            # "ValPredictions": [],
        }

        # Epoch Training
        for batch_num, (train_batch, val_batch) in enumerate(zip(train_dl, val_dl)):
            x_train, y_train = train_batch["in_landmarks"].to(device), train_batch["out_polish_token_ids"].to(device)
            train_loss = train_step(x_train, y_train, transformer, tokenizer, criterion, optim)
            epoch_history["TrainLoss"] += train_loss / n_train_batches

        # Epoch Validation
        for batch_num, val_batch in enumerate(val_dl):
            x_val, y_val = val_batch["in_landmarks"].to(device), val_batch["out_polish_token_ids"].to(device)
            val_pred, val_loss = eval_step(x_val, y_val, transformer, tokenizer, criterion)

            y_batch_decoded = tokenizer.batch_decode(y_val, skip_special_tokens=True)
            val_pred_max = torch.max(val_pred, dim=-1)
            fixed_val_pred_indices = pad_after_eos(val_pred_max.indices)
            val_pred_decoded = tokenizer.batch_decode(fixed_val_pred_indices, skip_special_tokens=True)
            # epoch_history["ValPredictions"].extend(np.array(val_pred_decoded).flatten().tolist())

            meteor_score = meteor_model.compute(
                predictions=val_pred_decoded,
                references=y_batch_decoded,
            )
            bert_score = bertscore_model.compute(
                predictions=val_pred_decoded,
                references=y_batch_decoded,
                lang="pl",
            )
            ter_score = ter_model.compute(
                predictions=val_pred_decoded,
                references=y_batch_decoded,
            )
            bleu_score = sacrebleu.compute(
                predictions=val_pred_decoded,
                references=y_batch_decoded,
            )

            epoch_history["ValLoss"] += val_loss / n_val_batches
            epoch_history["ValMETEOR"] += meteor_score["meteor"] / n_val_batches
            epoch_history["ValBERTScoreAvgF1"] += np.mean(bert_score["f1"]) / n_val_batches
            epoch_history["ValBERTScoreAvgPrecision"] += np.mean(bert_score["precision"]) / n_val_batches
            epoch_history["ValBERTScoreAvgRecall"] += np.mean(bert_score["recall"]) / n_val_batches
            epoch_history["ValTER"] += ter_score["score"] / n_val_batches
            epoch_history["ValBLEU"] += bleu_score["score"] / n_val_batches

        checkpoint_data = {
            "epoch": epoch,
            "net_state_dict": transformer.state_dict(),
            "optimizer_state_dict": optim.state_dict(),
        }
        with tempfile.TemporaryDirectory() as checkpoint_dir:
            data_path = Path(checkpoint_dir) / "data.pkl"
            with open(data_path, "wb") as fp:
                cloudpickle.dump(checkpoint_data, fp)
            checkpoint = Checkpoint.from_directory(checkpoint_dir)
            ray.train.report(
                epoch_history,
                checkpoint=checkpoint,
            )
def tune_model(num_samples):
    config = {
        "lr": ray.tune.loguniform(1e-5, 1e-1),
        "batch_size": ray.tune.choice([64]),
        "d_model": ray.tune.choice([256, 512, 1024]),
        "ffn_hidden": ray.tune.choice([512, 1024, 2048]),
        "num_heads": ray.tune.choice([8]),
        "num_layers": ray.tune.choice([10]),
    }

    meteor_model = evaluate.load("meteor")
    bertscore_model = evaluate.load("bertscore")
    ter_model = evaluate.load("ter")
    sacrebleu = evaluate.load("sacrebleu")

    partial_train = ray.tune.with_parameters(
        train_model,
        device=DEVICE,
        train_ds=train_ds,
        val_ds=val_ds,
        meteor_model=meteor_model,
        bertscore_model=bertscore_model,
        ter_model=ter_model,
        sacrebleu=sacrebleu,
        max_sequence_length=MAX_SEQUENCE_LENGTH,
        vocab_size=VOCAB_SIZE,
        epochs=EPOCHS,
    )
    scheduler = ASHAScheduler(
        metric="ValLoss",
        mode="min",
        max_t=EPOCHS,
        grace_period=1,
        reduction_factor=2,
    )
    result = ray.tune.run(
        partial_train,
        resources_per_trial={"cpu": 8, "gpu": 1},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        storage_path=EXP_TUNE_RESULTS_DIR,
    )
    return result
Issue Severity
High: It blocks me from completing my task.
@jamnicki - the Ray Train team will try to take a look, but can you reach out to the Torch committers for that example as well? The example wasn't originally reviewed by us.
@jamnicki, if I remember correctly, there would be potential serialization issues if you initialize the evaluate modules outside of the training function. Instead of passing them via with_parameters, can you try moving these lines into the train_model function?
meteor_model = evaluate.load("meteor")
bertscore_model = evaluate.load("bertscore")
ter_model = evaluate.load("ter")
sacrebleu = evaluate.load("sacrebleu")
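That is, something like the following sketch (only the relevant part shown; the rest of train_model and tune_model stay as in your script):

```python
def train_model(config,
                device=None,
                train_ds=None,
                val_ds=None,
                max_sequence_length=None,
                vocab_size=None,
                epochs=None,
                ):
    # Loading the metrics inside the trainable means each Ray worker builds
    # its own dynamic `evaluate_modules` package locally, so nothing from
    # `evaluate` has to be pickled through the object store.
    meteor_model = evaluate.load("meteor")
    bertscore_model = evaluate.load("bertscore")
    ter_model = evaluate.load("ter")
    sacrebleu = evaluate.load("sacrebleu")
    ...  # unchanged from here on
```

Then drop the four metric kwargs from the ray.tune.with_parameters(...) call in tune_model.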
Thank you for the workaround @woshiyyya!