sgl-project / sglang

SGLang is a structured generation language designed for large language models (LLMs). It makes your interaction with models faster and more controllable.

[BUG] srt throws KeyError when sgl.gen(...) regex parameter contains Chinese characters

m0g1cian opened this issue · comments

It seems that sgl.gen(regex=...) fails when the regex contains Chinese characters: the srt server raises a KeyError while building the jump-forward map.

Error Details

Exception in ModelRpcClient:
Traceback (most recent call last):
  File ".../sglang/python/sglang/srt/managers/router/model_rpc.py", line 175, in exposed_step
    self.handle_generate_request(recv_req)
  File ".../sglang/python/sglang/srt/managers/router/model_rpc.py", line 271, in handle_generate_request
    req.jump_forward_map = self.jump_forward_cache.query(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../sglang/python/sglang/srt/constrained/base_cache.py", line 34, in query
    val = _init_with_timer(key)
          ^^^^^^^^^^^^^^^^^^^^^
  File ".../sglang/python/sglang/srt/constrained/base_cache.py", line 18, in _init_with_timer
    val = self.init_value(key)
          ^^^^^^^^^^^^^^^^^^^^
  File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 64, in init_value
    return JumpForwardMap(regex)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 41, in __init__
    self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../sglang/lib/python3.11/site-packages/outlines/caching.py", line 74, in wrapper
    result = cached_function(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 19, in _init_state_to_jump_forward
    for symbol, id_ in symbol_to_id.items():
  File "<frozen _collections_abc>", line 861, in __iter__
  File ".../sglang/lib/python3.11/site-packages/numba/typed/typeddict.py", line 180, in __getitem__
    return _getitem(self, key)
           ^^^^^^^^^^^^^^^^^^^
  File ".../sglang/lib/python3.11/site-packages/numba/typed/dictobject.py", line 778, in impl
    raise KeyError()
KeyError

Minimal reproducible example:

from enum import Enum
from itertools import chain
from typing import Set
import sglang as sgl
from pydantic import BaseModel
from sglang.srt.constrained import build_regex_from_object

class Port(str, Enum):
    usba = "USB A"
    usbc = "USB C"
    lightning = "Lightning"
    micro_usb = "Micro USB"
    mini_usb = "Mini USB"
    apple_watch = "Apple Watch"
    wireless = "无线"
    dc_2 = "DC 2.0mm"
    dc_35 = "DC 3.5mm"
    unknown = "未知"

class PortEN(str, Enum):
    usba = "USB A"
    usbc = "USB C"
    lightning = "Lightning"
    micro_usb = "Micro USB"
    mini_usb = "Mini USB"
    apple_watch = "Apple Watch"
    wireless = "Wireless"
    dc_2 = "DC 2.0mm"
    dc_35 = "DC 3.5mm"
    unknown = "Unknown"

class FastCharge(str, Enum):
    normal = "非快充"
    scp = "华为 SCP"
    fcp = "华为 FCP"
    vooc = "OPPO VOOC"
    dart = "一加 Dart"
    super_vooc = "OPPO SuperVOOC"
    super_dart = "一加 SuperDart"
    flash_charge = "Vivo FlashCharge"
    qc2 = "高通 QC2.0"
    qc3 = "高通 QC3.0"
    pd = "USB-PD"
    pps = "USB-PD PPS"
    pe = "联发科 PE"
    mi = "小米"
    doubt = "存疑"
    unknown = "未知"

class FastChargeEN(str, Enum):
    normal = "No Fast Charge"
    scp = "Huawei SCP"
    fcp = "Huawei FCP"
    vooc = "OPPO VOOC"
    dart = "OnePlus Dart"
    super_vooc = "OPPO SuperVOOC"
    super_dart = "OnePlus SuperDart"
    flash_charge = "Vivo FlashCharge"
    qc2 = "Qualcomm QC2.0"
    qc3 = "Qualcomm QC3.0"
    pd = "USB-PD"
    pps = "USB-PD PPS"
    pe = "MediaTek PE"
    mi = "XiaoMi"
    doubt = "Doubt"
    unknown = "Unknown"

class PhoneName(BaseModel):
    name: str

class PhonePort(BaseModel):
    port: Set[Port]

class PhoneFastCharge(BaseModel):
    fc: Set[FastCharge]

@sgl.function
def pydantic_wizard_gen(s):
    objs = [PhoneName, PhonePort, PhoneFastCharge]
    prop_fields = list(
        chain.from_iterable([list(p.model_fields.items()) for p in objs])
    )
    prop_keys = [p[0] for p in prop_fields]
    s += "Give me a description about iPhone 15 in the JSON format.\n"
    forks = s.fork(len(objs))
    for f, obj in zip(forks, objs):
        f += sgl.gen(
            "property",
            max_tokens=128,
            temperature=0,
            regex=build_regex_from_object(obj),  # Requires pydantic >= 2.0
        )
    s.set_var("properties", dict(zip(prop_keys, [f["property"] for f in forks])))

if __name__ == "__main__":
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
    outputs = pydantic_wizard_gen.run()
    print(outputs["properties"])

If I change the definitions of PhonePort and PhoneFastCharge to use their English alternatives (PortEN and FastChargeEN), no error is thrown.
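Until the underlying bug is fixed, one possible stopgap is to constrain generation with the English enums and translate the results back afterwards. The sketch below is only an illustration under that assumption: it uses trimmed-down copies of Port and PortEN, and to_localized is a hypothetical helper that is not part of sglang or outlines. It relies on the two enums sharing member names.

```python
from enum import Enum


class Port(str, Enum):
    # Trimmed-down copy of the Chinese-valued enum from the example above.
    usbc = "USB C"
    wireless = "无线"
    unknown = "未知"


class PortEN(str, Enum):
    # English counterpart; member names must match Port exactly.
    usbc = "USB C"
    wireless = "Wireless"
    unknown = "Unknown"


def to_localized(values, src=PortEN, dst=Port):
    """Hypothetical helper: map generated English values to the
    Chinese-valued members that share the same member name."""
    return {dst[src(v).name] for v in values}


# Values as they might come back from a regex built from PortEN.
generated = ["Wireless", "USB C"]
localized = to_localized(generated)
print(sorted(member.name for member in localized))
```

This keeps the regex passed to sgl.gen ASCII-only, at the cost of asking the model to answer in English.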

After some digging through the code, I can confirm that the bug occurs when outlines builds its alphabet_symbol_map while initializing fsm_info. It seems that the numba typed dict somehow fails to handle some of the Chinese characters stored in a NumPy array.
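To check whether a given regex is likely to hit this code path, it can help to list the non-ASCII characters it contains before sending it to the server. This is a minimal sketch: non_ascii_chars is a hypothetical helper, and the regex string is a hand-written stand-in rather than the actual output of build_regex_from_object.

```python
def non_ascii_chars(pattern: str) -> list[str]:
    """Return the distinct non-ASCII characters in pattern, in first-seen order."""
    seen = []
    for ch in pattern:
        if ord(ch) > 127 and ch not in seen:
            seen.append(ch)
    return seen


# Hand-written stand-in for a regex generated from an enum with Chinese values.
regex = r'\{"port": \[("USB A"|"无线"|"未知")\]\}'
suspects = non_ascii_chars(regex)
print(suspects)
```

An empty result means the regex is pure ASCII, which matches the case reported above as working.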

I will forward this issue to the outlines GitHub repository.

I have the same problem. Is there any workaround?

Upstream Issue
Upstream PR

You can try applying this PR as a temporary fix for outlines; it should alleviate the KeyError for now.