marella / ctransformers

Python bindings for the Transformer models implemented in C/C++ using the GGML library.


Request: `stopping_criteria`

freckletonj opened this issue

Hugging Face transformers offers `stopping_criteria`: https://huggingface.co/transformers/v4.6.0/_modules/transformers/generation_stopping_criteria.html

I use this with a `threading.Event` so I can stop generation from a separate thread, and it works great with transformers:

```python
import threading

import torch
from transformers import StoppingCriteriaList

def custom_stopping_criteria(local_llm_stop_event: threading.Event):
    def f(input_ids: torch.LongTensor,
          scores: torch.FloatTensor,
          **kwargs) -> bool:
        # Stop generation as soon as the event is set from another thread.
        return local_llm_stop_event.is_set()
    return f

stopping_criteria = StoppingCriteriaList([
    custom_stopping_criteria(local_llm_stop_event)
])

output = model.generate(
    ...
    stopping_criteria=stopping_criteria,
)
```

I notice that `ctransformers`' `generate` accepts this kwarg, but silently ignores it.

This resulted in a sneaky segfault, since this library isn't thread-safe: I thought generation had stopped, but it actually hadn't, so a second, concurrent generation would segfault mysteriously.
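Until `stopping_criteria` is supported, one possible workaround (a sketch, assuming the model can stream tokens one at a time, as `ctransformers` does via `llm(prompt, stream=True)`) is to consume the token stream in a loop and break out when the event is set, so the stop check runs cooperatively on the generating thread. The `fake_stream` generator below is a stand-in for the real streaming call:

```python
import threading

def stream_with_stop(token_stream, stop_event: threading.Event):
    # Wrap any token generator and stop yielding once the event is set.
    for token in token_stream:
        if stop_event.is_set():
            break
        yield token

# Hypothetical stand-in for a streaming model call such as
# llm(prompt, stream=True); a real call would yield text tokens.
def fake_stream():
    for i in range(1000):
        yield f"tok{i} "

stop_event = threading.Event()
tokens = []
for i, tok in enumerate(stream_with_stop(fake_stream(), stop_event)):
    tokens.append(tok)
    if i == 4:
        # Simulate another thread requesting a stop mid-generation.
        stop_event.set()

print(len(tokens))  # 5 tokens were emitted before the stop took effect
```

Because the generation loop itself checks the event, no second thread ever touches the model, which sidesteps the thread-safety segfault described above.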