SmartLi8 / stella

text embedding


Loss is 0 on some of the datasets

wang-ironman opened this issue · comments

I am encountering a similar issue. I suspected that Elastic Weight Consolidation (EWC) was the cause, but even after disabling EWC in the compute_loss function the problem persists. Specifically, it appears at the second training step, where the model's output consists entirely of NaN values.
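For anyone hitting the same thing, here is a minimal debugging sketch (my own, not from this repo) that checks the forward-pass output and the loss for non-finite values; the names check_finite, outputs, and loss are placeholders for whatever your training step produces.

import torch

def check_finite(step, outputs, loss):
    # `outputs` is assumed to be the tensor returned by the model's forward pass,
    # `loss` the scalar returned by compute_loss; both names are placeholders.
    if torch.isnan(outputs).any() or torch.isinf(outputs).any():
        print(f"step {step}: non-finite values in model outputs")
    if not torch.isfinite(loss):
        print(f"step {step}: loss is {loss.item()}")

Calling this right after the forward pass at each step makes it clear whether, as described above, the outputs are already NaN at the second step or whether the NaN first shows up in the loss.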

Did you manage to solve this? @binhna @wang-ironman

@sandan000 The issue arises when loading the pre-trained model in float16. It’s also important to be cautious when setting the ewc_ratio, as a high value can result in a large loss. I adjusted it to 0.01 for my dataset, which resolved the problem.

model = MODEL_NAME_INFO[model_name][0].from_pretrained(
        model_dir,
        trust_remote_code=True,
        # torch_dtype=torch.float16  ### I had to comment this out; loading in float16 produced the NaN outputs
    )
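On the ewc_ratio point, below is a hedged sketch of how a ratio like this typically scales the EWC penalty on top of the task loss; the names compute_loss_with_ewc, fisher, and old_params are assumptions for illustration, not this repo's actual compute_loss implementation.

import torch

def compute_loss_with_ewc(task_loss, model, fisher, old_params, ewc_ratio=0.01):
    # Standard EWC penalty: ewc_ratio * sum_i F_i * (theta_i - theta_i_star)^2.
    # `fisher` and `old_params` are assumed to be dicts keyed by parameter name.
    penalty = torch.zeros((), device=task_loss.device)
    for name, param in model.named_parameters():
        if name in fisher:
            penalty = penalty + (fisher[name] * (param - old_params[name]) ** 2).sum()
    # A large ewc_ratio lets the penalty dwarf the task loss (and can overflow in
    # float16); a small value like 0.01 keeps it on a comparable scale.
    return task_loss + ewc_ratio * penalty

With this form of the penalty, a ratio near 1.0 can easily dominate the task loss, which is consistent with dropping it to 0.01 helping here.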

I've tried this before, but it didn't work @binhna