davidkim205 / komt

Korean Multi-task Instruction Tuning

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

13b 모델의 inference 파라미터 공유가 가능하신가요?

bodybreaker opened this issue · comments

commented

허깅페이스 모델 상세페이지에 예시가 있더라구요
해당 parameter로 13b 모델을 돌리니 성능이 잘 안나오는거 같아서 문의드립니다.

리소스 문제로 load-in-4bit 와 load-in-8bit 로 테스트 해봤습니다.

### instruction: 자동차 종합(정기)검사 의무기간은 얼마인가요?

### Response:3년,3년이상
자동차 종합(정기)검사는 매년 1회 또는 2회에 한해서만 면제됩니다. 그리고 항공기나 선박과 비교하여 보험료가 월등히 낮아 경우에 따라 유용하지 않을 수도 있으니 주의해야 합니다.

자동차 종합(정기)검사를 받지 않으면 위반일수에 따라 50만원부터 60만원까지 과태료가 부과될 수 있습니다. 단, 신규 차량의 경우 총 4회 중 1회만 검사를 받으면 모두 면제됩니다.

자동차 종합(정기)검사를 받으려면 전국 78개 시군구와 대학교 내 자동차 관련 학과 등에서 본인이 소유한 차량을 대상으로 받을 수 있습니다.

자동차 종합(정기)검사를 받으면 차량의 안전성을 평가하고 운영 상태를 파악하여 사고를 방지하는 역할을 합니다. 또한 차량 소유자들은 자동차 종합(정기)검사를 ��

13b 모델을 load-in-4bit 이나 load-in-8bit 로 하는것은 별로 추천해드리고 싶지 않습니다. 성능이 많이 저하 됩니다. 추론시 52G 이상이 필요합니다. 차라리 ggml이나 lora모델을 사용하는것이 좋을것 같습니다. https://huggingface.co/davidkim205/komt-llama2-13b-v1-lora, https://huggingface.co/davidkim205/komt-llama2-13b-v1-ggml 를 참고하셔서 테스트 해보면 좋을것 같습니다.
위의 log에서 글자가 짤리는것 같은데 max_new_tokens를 1024로 변경하는것이 좋을것 같습니다.
아래는 코드 예입니다.

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import TextStreamer, GenerationConfig

model_name='davidkim205/komt-llama2-13b-v1'
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextStreamer(tokenizer)

def gen(x):
    generation_config = GenerationConfig(
        temperature=0.8,
        top_p=0.8,
        top_k=100,
        max_new_tokens=1024,
        early_stopping=True,
        do_sample=True,
    )
    q = f"### instruction: {x}\n\n### Response: "
    gened = model.generate(
        **tokenizer(
            q,
            return_tensors='pt',
            return_token_type_ids=False
        ).to('cuda'),
        generation_config=generation_config,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    result_str = tokenizer.decode(gened[0])

    start_tag = f"\n\n### Response: "
    start_index = result_str.find(start_tag)

    if start_index != -1:
        result_str = result_str[start_index + len(start_tag):].strip()
    return result_str

result = gen('제주도를 1박2일로 혼자 여행하려고 하는데 여행 코스를 만들어줘')
print('########################')
print(result)