ludwig-ai / ludwig

Low-code framework for building custom LLMs, neural networks, and other AI models

Home Page: http://ludwig.ai

repetition_penalty bugged out

rlleshi opened this issue

Describe the bug
When a value for repetition_penalty is set and the model is fine-tuned, prediction fails with the following error:

Prediction:   0%|          | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/projects/llmodels/karli_iqa/train.py", line 93, in <module>
    main()
  File "/projects/llmodels/karli_iqa/train.py", line 73, in main
    predictions = model.predict(test_df, generation_config={'temperature': 0.1, 'max_new_tokens': 26, 'repetition_penalty': 1.1})[0]
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/ludwig/api.py", line 1083, in predict
    predictions = predictor.batch_predict(
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/ludwig/models/predictor.py", line 143, in batch_predict
    preds = self._predict(batch)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/ludwig/models/predictor.py", line 189, in _predict
    outputs = self._predict_on_inputs(inputs)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/ludwig/models/predictor.py", line 345, in _predict_on_inputs
    return self.dist_model.generate(inputs)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/ludwig/models/llm.py", line 334, in generate
    model_outputs = self.model.generate(
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/peft/peft_model.py", line 1130, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/transformers/generation/utils.py", line 1764, in generate
    return self.sample(
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/transformers/generation/utils.py", line 2874, in sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 97, in __call__
    scores = processor(input_ids, scores)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 332, in __call__
    score = torch.gather(scores, 1, input_ids)
RuntimeError: gather(): Expected dtype int64 for index
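For context, torch.gather requires its index tensor to have dtype int64, which is exactly what the last frame of the traceback complains about. A minimal sketch, independent of Ludwig, reproducing the same failure mode:

```python
import torch

# Fake next-token logits: shape (batch, vocab)
scores = torch.randn(1, 10)

# Token ids: integer tensors default to int64, which gather() accepts
ids_ok = torch.tensor([[1, 3, 5]])
gathered = torch.gather(scores, 1, ids_ok)  # works

# The same ids downcast to int32 reproduce the reported error
ids_bad = ids_ok.to(torch.int32)
try:
    torch.gather(scores, 1, ids_bad)
    error_msg = None
except RuntimeError as e:
    # error_msg now holds the same gather() dtype complaint as the traceback
    error_msg = str(e)
```

This suggests the token ids reaching transformers' repetition-penalty processor are not int64 somewhere along the Ludwig prediction path.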

To Reproduce
Steps to reproduce the behavior:

Config file:

model_type: llm
base_model: huggyllama/llama-7b
# base_model: meta-llama/Llama-2-7b-hf
# base_model: meta-llama/Llama-2-13b-hf

model_parameters:
  trust_remote_code: true

backend:
  type: local
  cache_dir: ./ludwig_cache

input_features:
  - name: input
    type: text
    preprocessing:
      max_sequence_length: 128

output_features:
  - name: output
    type: text
    preprocessing:
      max_sequence_length: 64

prompt:
  template: >-
    ### User: {input}

    ### Assistant:


generation:
  temperature: 0.1
  max_new_tokens: 32
  repetition_penalty: 1.1
  # remove_invalid_values: true


adapter:
  type: lora
  dropout: 0.05
  r: 8

quantization:
  bits: 4

preprocessing:
  global_max_sequence_length: 256
  split:
    type: fixed

trainer:
  type: finetune
  epochs: 1
  batch_size: 3
  eval_batch_size: 2
  gradient_accumulation_steps: 16
  learning_rate: 0.0004
  learning_rate_scheduler:
    warmup_fraction: 0.03

LoRA fine-tuning completes without issues, but running inference like so: preds = model.predict(test_df, generation_config={'temperature': 0.1, 'max_new_tokens': 26, 'repetition_penalty': 1.1})[0] raises the error above.
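For what it's worth, the frame that fails is transformers' RepetitionPenaltyLogitsProcessor. A simplified sketch of its logic (the function name and the defensive int64 cast are mine, not part of the library) shows that casting the ids before the gather avoids the crash:

```python
import torch

def apply_repetition_penalty(scores: torch.Tensor, input_ids: torch.Tensor,
                             penalty: float = 1.1) -> torch.Tensor:
    """Simplified mirror of RepetitionPenaltyLogitsProcessor.__call__."""
    # gather()/scatter() require int64 indices; casting here sidesteps the
    # "gather(): Expected dtype int64 for index" RuntimeError.
    input_ids = input_ids.to(torch.int64)
    score = torch.gather(scores, 1, input_ids)
    # Already-generated tokens get their logits pushed down: negative
    # logits are multiplied by the penalty, positive ones divided by it.
    score = torch.where(score < 0, score * penalty, score / penalty)
    return scores.scatter(1, input_ids, score)

# Even int32 ids (the failure mode in the traceback) now go through cleanly
logits = torch.randn(1, 10)
ids = torch.tensor([[2, 4]], dtype=torch.int32)
penalized = apply_repetition_penalty(logits, ids)
```

If this is where the dtype mismatch originates, the fix likely belongs in Ludwig's prediction path, ensuring the ids tensor passed to generate is torch.long.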

Expected behavior
Prediction should run normally; setting repetition_penalty should not break generation.

Environment (please complete the following information):

  • OS: Ubuntu 22.04
  • Python: 3.10.0
  • Ludwig version: 0.9.1