[BUG] srt throws KeyError when sgl.gen(...) regex parameter contains Chinese characters
m0g1cian opened this issue · comments
It seems that `sgl.gen(regex=...)` does not accept Chinese characters in the regex.
Error Details
Exception in ModelRpcClient:
Traceback (most recent call last):
File ".../sglang/python/sglang/srt/managers/router/model_rpc.py", line 175, in exposed_step
self.handle_generate_request(recv_req)
File ".../sglang/python/sglang/srt/managers/router/model_rpc.py", line 271, in handle_generate_request
req.jump_forward_map = self.jump_forward_cache.query(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/base_cache.py", line 34, in query
val = _init_with_timer(key)
^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/base_cache.py", line 18, in _init_with_timer
val = self.init_value(key)
^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 64, in init_value
return JumpForwardMap(regex)
^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 41, in __init__
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/lib/python3.11/site-packages/outlines/caching.py", line 74, in wrapper
result = cached_function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 19, in _init_state_to_jump_forward
for symbol, id_ in symbol_to_id.items():
File "<frozen _collections_abc>", line 861, in __iter__
File ".../sglang/lib/python3.11/site-packages/numba/typed/typeddict.py", line 180, in __getitem__
return _getitem(self, key)
^^^^^^^^^^^^^^^^^^^
File ".../sglang/lib/python3.11/site-packages/numba/typed/dictobject.py", line 778, in impl
raise KeyError()
KeyError
Exception in ModelRpcClient:
Traceback (most recent call last):
File ".../sglang/python/sglang/srt/managers/router/model_rpc.py", line 175, in exposed_step
self.handle_generate_request(recv_req)
File ".../sglang/python/sglang/srt/managers/router/model_rpc.py", line 271, in handle_generate_request
req.jump_forward_map = self.jump_forward_cache.query(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/base_cache.py", line 34, in query
val = _init_with_timer(key)
^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/base_cache.py", line 18, in _init_with_timer
val = self.init_value(key)
^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 64, in init_value
return JumpForwardMap(regex)
^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 41, in __init__
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/lib/python3.11/site-packages/outlines/caching.py", line 74, in wrapper
result = cached_function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 19, in _init_state_to_jump_forward
for symbol, id_ in symbol_to_id.items():
File "<frozen _collections_abc>", line 861, in __iter__
File ".../sglang/lib/python3.11/site-packages/numba/typed/typeddict.py", line 180, in __getitem__
return _getitem(self, key)
^^^^^^^^^^^^^^^^^^^
File ".../sglang/lib/python3.11/site-packages/numba/typed/dictobject.py", line 778, in impl
raise KeyError()
KeyError
Exception in ModelRpcClient:
Traceback (most recent call last):
File ".../sglang/python/sglang/srt/managers/router/model_rpc.py", line 175, in exposed_step
self.handle_generate_request(recv_req)
File ".../sglang/python/sglang/srt/managers/router/model_rpc.py", line 271, in handle_generate_request
req.jump_forward_map = self.jump_forward_cache.query(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/base_cache.py", line 34, in query
val = _init_with_timer(key)
^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/base_cache.py", line 18, in _init_with_timer
val = self.init_value(key)
^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 64, in init_value
return JumpForwardMap(regex)
^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 41, in __init__
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/lib/python3.11/site-packages/outlines/caching.py", line 74, in wrapper
result = cached_function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 19, in _init_state_to_jump_forward
for symbol, id_ in symbol_to_id.items():
File "<frozen _collections_abc>", line 861, in __iter__
File ".../sglang/lib/python3.11/site-packages/numba/typed/typeddict.py", line 180, in __getitem__
return _getitem(self, key)
^^^^^^^^^^^^^^^^^^^
File ".../sglang/lib/python3.11/site-packages/numba/typed/dictobject.py", line 778, in impl
raise KeyError()
KeyError
Exception in ModelRpcClient:
Traceback (most recent call last):
File ".../sglang/python/sglang/srt/managers/router/model_rpc.py", line 175, in exposed_step
self.handle_generate_request(recv_req)
File ".../sglang/python/sglang/srt/managers/router/model_rpc.py", line 271, in handle_generate_request
req.jump_forward_map = self.jump_forward_cache.query(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/base_cache.py", line 34, in query
val = _init_with_timer(key)
^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/base_cache.py", line 18, in _init_with_timer
val = self.init_value(key)
^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 64, in init_value
return JumpForwardMap(regex)
^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 41, in __init__
self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/lib/python3.11/site-packages/outlines/caching.py", line 74, in wrapper
result = cached_function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../sglang/python/sglang/srt/constrained/jump_forward.py", line 19, in _init_state_to_jump_forward
for symbol, id_ in symbol_to_id.items():
File "<frozen _collections_abc>", line 861, in __iter__
File ".../sglang/lib/python3.11/site-packages/numba/typed/typeddict.py", line 180, in __getitem__
return _getitem(self, key)
^^^^^^^^^^^^^^^^^^^
File ".../sglang/lib/python3.11/site-packages/numba/typed/dictobject.py", line 778, in impl
raise KeyError()
KeyError
Minimal reproducible example:
from enum import Enum
from itertools import chain
from typing import Set
import sglang as sgl
from pydantic import BaseModel
from sglang.srt.constrained import build_regex_from_object
# Charging-port types with Chinese display values ("无线" = wireless,
# "未知" = unknown). Per this issue, constraining generation with a regex
# built from these values raises KeyError; PortEN is the English-only variant.
# Member order is preserved because it determines the order of alternatives.
class Port(str, Enum):
    usba = "USB A"
    usbc = "USB C"
    lightning = "Lightning"
    micro_usb = "Micro USB"
    mini_usb = "Mini USB"
    apple_watch = "Apple Watch"
    wireless = "无线"
    dc_2 = "DC 2.0mm"
    dc_35 = "DC 3.5mm"
    unknown = "未知"
# English-only counterpart of Port. Per this issue, using these values
# in the regex does not trigger the KeyError.
class PortEN(str, Enum):
    usba = "USB A"
    usbc = "USB C"
    lightning = "Lightning"
    micro_usb = "Micro USB"
    mini_usb = "Mini USB"
    apple_watch = "Apple Watch"
    wireless = "Wireless"
    dc_2 = "DC 2.0mm"
    dc_35 = "DC 3.5mm"
    unknown = "Unknown"
# Fast-charging protocols with Chinese display values (e.g. "华为" = Huawei,
# "一加" = OnePlus, "高通" = Qualcomm, "联发科" = MediaTek, "小米" = Xiaomi,
# "非快充" = no fast charge, "存疑" = doubtful, "未知" = unknown).
# Per this issue, a regex built from these values raises KeyError;
# FastChargeEN is the English-only variant.
class FastCharge(str, Enum):
    normal = "非快充"
    scp = "华为 SCP"
    fcp = "华为 FCP"
    vooc = "OPPO VOOC"
    dart = "一加 Dart"
    super_vooc = "OPPO SuperVOOC"
    super_dart = "一加 SuperDart"
    flash_charge = "Vivo FlashCharge"
    qc2 = "高通 QC2.0"
    qc3 = "高通 QC3.0"
    pd = "USB-PD"
    pps = "USB-PD PPS"
    pe = "联发科 PE"
    mi = "小米"
    doubt = "存疑"
    unknown = "未知"
# English-only counterpart of FastCharge. Members are declared in the same
# order, and per this issue these values do not trigger the KeyError.
class FastChargeEN(str, Enum):
    normal = "No Fast Charge"
    scp = "Huawei SCP"
    fcp = "Huawei FCP"
    vooc = "OPPO VOOC"
    dart = "OnePlus Dart"
    super_vooc = "OPPO SuperVOOC"
    super_dart = "OnePlus SuperDart"
    flash_charge = "Vivo FlashCharge"
    qc2 = "Qualcomm QC2.0"
    qc3 = "Qualcomm QC3.0"
    pd = "USB-PD"
    pps = "USB-PD PPS"
    pe = "MediaTek PE"
    mi = "XiaoMi"
    doubt = "Doubt"
    unknown = "Unknown"
# Single-field model; pydantic_wizard_gen builds a generation regex from it
# via build_regex_from_object.
class PhoneName(BaseModel):
    name: str
# Single-field model over the Chinese-valued Port enum; per this issue,
# swapping Port for PortEN avoids the KeyError.
class PhonePort(BaseModel):
    port: Set[Port]
# Single-field model over the Chinese-valued FastCharge enum; per this issue,
# swapping FastCharge for FastChargeEN avoids the KeyError.
class PhoneFastCharge(BaseModel):
    fc: Set[FastCharge]
@sgl.function
def pydantic_wizard_gen(s):
    """Generate one regex-constrained JSON object per pydantic model.

    The prompt is forked once per model. Each fork runs ``sgl.gen`` with the
    regex that ``build_regex_from_object`` derives from its model. The
    generated texts are stored as the ``properties`` variable, a dict keyed
    by model field name.
    """
    objs = [PhoneName, PhonePort, PhoneFastCharge]
    # Field names of every model, in model order. These keys are zipped with
    # the forks (one per model) below, so the pairing is only correct while
    # each model has exactly one field, as all three models here do.
    prop_keys = [name for obj in objs for name in obj.model_fields]
    s += "Give me a description about iPhone 15 in the JSON format.\n"
    forks = s.fork(len(objs))
    for f, obj in zip(forks, objs):
        f += sgl.gen(
            "property",
            max_tokens=128,
            temperature=0,
            regex=build_regex_from_object(obj),  # Requires pydantic >= 2.0
        )
    s.set_var(
        "properties",
        {key: f["property"] for key, f in zip(prop_keys, forks)},
    )
if __name__ == "__main__":
    # Expects an sglang runtime to be reachable at localhost:30000 (the
    # endpoint configured below); adjust the URL for other deployments.
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
    outputs = pydantic_wizard_gen.run()
    print(outputs["properties"])
If I change the definitions of `PhonePort` and `PhoneFastCharge` to use their English alternatives (`PortEN` and `FastChargeEN`), no error is thrown.
After some digging through the code, I can confirm there is a bug when outlines creates its `alphabet_symbol_map` while initializing `fsm_info`. It seems that the numba typed dict somehow fails to handle some Chinese characters stored in a numpy array.
I will forward this issue to the outlines GitHub repository.
一样的问题,有什么解决方法吗
一样的问题,有什么解决方法吗
You can temporarily patch outlines with this PR; it should alleviate the KeyError for now.