rustformers / llm

[Unmaintained, see README] An ecosystem of Rust libraries for working with large language models

Home Page: https://docs.rs/llm/latest/llm/

NaN logits on LLaMA 65B when using 2k+ token contexts

hugoabonizio opened this issue

I'm trying to run inference with contexts longer than 2k tokens, but I can't get it to work with 65B models. The following code works with 7B models, but fails with "token sampling failed (due to nan logits)" when using 65B models.

I'm testing with internal models, but these 7B and 65B models also reproduce the issue.

use std::path::PathBuf;
// The `Model` trait brings `context_size()` and `start_session()` into scope.
use llm::Model;

fn main() {
    let llama = llm::load::<llm::models::Llama>(
        std::path::Path::new("/data/tmp/llama-65b.ggmlv3.q4_0.bin"),
        // std::path::Path::new("/data/tmp/llama-7b.ggmlv3.q4_0.bin"),
        llm::TokenizerSource::HuggingFaceTokenizerFile(PathBuf::from("/data/tmp/tokenizer.json")),
        llm::ModelParameters {
            use_gpu: true,
            gpu_layers: Some(99), // offload all layers to the GPU
            context_size: 8192,
            // Linear RoPE scaling: 2048 (trained context) / 0.25 = 8192
            rope_overrides: Some(llm::RoPEOverrides {
                frequency_scale: 0.25,
                ..Default::default()
            }),
            ..Default::default()
        },
        llm::load_progress_callback_stdout,
    )
    .unwrap_or_else(|err| panic!("Failed to load model: {err}"));

    println!("\n\ncontext_size {}", llama.context_size());

    // 2,800 repetitions of "hello "; prompts up to ~2k tokens work on 65B
    let prompt = "hello ".repeat(2800);

    let mut session = llama.start_session(llm::InferenceSessionConfig {
        n_batch: 256,
        ..Default::default()
    });

    // Feed the whole prompt, then sample a single token.
    let res = session.infer::<std::convert::Infallible>(
        &llama,
        &mut rand::thread_rng(),
        &llm::InferenceRequest {
            prompt: (&prompt).into(),
            parameters: &llm::InferenceParameters::default(),
            play_back_previous_tokens: false,
            maximum_token_count: Some(1),
        },
        &mut Default::default(),
        |_| Ok(llm::InferenceFeedback::Continue),
    );

    match res {
        Ok(result) => println!("\n\nInference stats:\n{result}"),
        Err(err) => println!("\n{err}"),
    }
}
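For context on the `rope_overrides` setting above: with linear RoPE scaling (position interpolation), each position index is multiplied by `frequency_scale` before the rotary angles are computed, so `context_size: 8192` with `frequency_scale: 0.25` maps every position back into LLaMA's 2048-token training range. The sketch below illustrates only that arithmetic. It assumes, as in ggml's RoPE implementation, that the scale multiplies the position, and it uses the standard base of 10000 and a head dimension of 128; `rope_angles` is a hypothetical helper, not part of the `llm` crate.

```rust
/// Hypothetical helper (not an `llm` API): rotary angles for one position,
/// with linear position scaling applied before rotation.
fn rope_angles(position: usize, head_dim: usize, base: f32, frequency_scale: f32) -> Vec<f32> {
    // Linear scaling shrinks the effective position index.
    let scaled_pos = position as f32 * frequency_scale;
    (0..head_dim / 2)
        .map(|i| {
            // theta_i = base^(-2i / d)
            let theta = base.powf(-2.0 * i as f32 / head_dim as f32);
            scaled_pos * theta
        })
        .collect()
}

fn main() {
    let trained_ctx = 2048usize; // LLaMA v1 training context length
    let context_size = 8192usize; // matches ModelParameters in the repro
    let frequency_scale = 0.25f32; // matches RoPEOverrides in the repro

    // Largest position in the extended window, mapped into the trained range.
    let last = context_size - 1;
    let effective = last as f32 * frequency_scale;
    println!("position {last} -> effective {effective}");
    assert!(effective < trained_ctx as f32);

    // Assumed head dimension of 128 and RoPE base of 10000.
    let angles = rope_angles(last, 128, 10000.0, frequency_scale);
    println!("first angle: {}", angles[0]);
}
```

Running it prints `position 8191 -> effective 2047.75` and `first angle: 2047.75`, so the last position of the 8192-token window stays inside the trained range. This only checks the position arithmetic; it does not by itself account for the NaN logits on 65B.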