[Bug] Text embedding generator for LLamaSharp fails during token count.
vvdb-architecture opened this issue
Context / Scenario
If you want to use LLamaSharp with the latest version of Kernel Memory and remain on your local hardware, you need to write a text embedding generator. There aren't a thousand ways to write one. Here's mine:
using System;
using System.Threading;
using System.Threading.Tasks;
using LLama;
using LLama.Common;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.AI.LlamaSharp;

internal sealed class TextEmbeddingGenerator : ITextEmbeddingGenerator, IDisposable
{
    private readonly LLamaWeights _weights;
    private readonly LLamaEmbedder _embedder;
    private readonly LLamaContext _context;

    public TextEmbeddingGenerator(LlamaSharpConfig config, ILogger<TextEmbeddingGenerator>? logger)
    {
        var parameters = new ModelParams(config.ModelPath)
        {
            ContextSize = config.MaxTokenTotal
        };
        MaxTokens = (int)config.MaxTokenTotal;

        if (config.GpuLayerCount.HasValue)
            parameters.GpuLayerCount = config.GpuLayerCount.Value;
        if (config.Seed.HasValue)
            parameters.Seed = config.Seed.Value;

        _weights = LLamaWeights.LoadFromFile(parameters);
        _embedder = new LLamaEmbedder(_weights, parameters, logger);
        _context = _weights.CreateContext(parameters, logger);
    }

    // Called by Kernel Memory's text partitioning to measure chunk sizes.
    public int CountTokens(string text)
    {
        return _context.Tokenize(text).Length;
    }

    public Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken)
    {
        var embeddings = _embedder.GetEmbeddings(text);
        return Task.FromResult(new Embedding(embeddings));
    }

    public void Dispose()
    {
        _context.Dispose();
        _embedder.Dispose();
        _weights.Dispose();
    }

    public int MaxTokens { get; }
}
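For completeness, here's roughly how I plug the generator into Kernel Memory. This is a minimal sketch with placeholder values; WithCustomEmbeddingGenerator is the builder hook I'm assuming here, and you'd configure a text generation model alongside it:

// Wiring sketch (placeholder model path and token budget):
var config = new LlamaSharpConfig
{
    ModelPath = @"D:\models\model.gguf",
    MaxTokenTotal = 4096
};

var memory = new KernelMemoryBuilder()
    .WithCustomEmbeddingGenerator(new TextEmbeddingGenerator(config, logger: null))
    // ... text generation model omitted here ...
    .Build();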
What happened?
This works well except when you're doing text partitioning. Then the call to _context.Tokenize(text) throws an exception:
LLama.Exceptions.RuntimeError
HResult=0x80131500
Message=Error happened during tokenization. It's possibly caused by wrong encoding. Please try to specify the encoding.
Source=LLamaSharp
StackTrace:
at LLama.Native.SafeLLamaContextHandle.Tokenize(String text, Boolean add_bos, Boolean special, Encoding encoding)
at LLama.LLamaContext.Tokenize(String text, Boolean addBos, Boolean special)
at ConsoleApp1.TextEmbeddingGenerator.CountTokens(String text) in D:\Source\km\ConsoleApp1\TextEmbeddingGenerator.cs:line 33
at Microsoft.KernelMemory.DataFormats.Text.TextChunker.Split(ReadOnlySpan`1 input, String inputString, Int32 maxTokens, ReadOnlySpan`1 separators, Boolean trim, TokenCounter tokenCounter)
at Microsoft.KernelMemory.DataFormats.Text.TextChunker.Split(ReadOnlySpan`1 input, String inputString, Int32 maxTokens, ReadOnlySpan`1 separators, Boolean trim, TokenCounter tokenCounter)
at Microsoft.KernelMemory.DataFormats.Text.TextChunker.Split(ReadOnlySpan`1 input, String inputString, Int32 maxTokens, ReadOnlySpan`1 separators, Boolean trim, TokenCounter tokenCounter)
at Microsoft.KernelMemory.DataFormats.Text.TextChunker.Split(ReadOnlySpan`1 input, String inputString, Int32 maxTokens, ReadOnlySpan`1 separators, Boolean trim, TokenCounter tokenCounter)
at Microsoft.KernelMemory.DataFormats.Text.TextChunker.Split(ReadOnlySpan`1 input, String inputString, Int32 maxTokens, ReadOnlySpan`1 separators, Boolean trim, TokenCounter tokenCounter)
at Microsoft.KernelMemory.DataFormats.Text.TextChunker.Split(List`1 input, Int32 maxTokens, ReadOnlySpan`1 separators, Boolean trim, TokenCounter tokenCounter)
at Microsoft.KernelMemory.DataFormats.Text.TextChunker.InternalSplitLines(String text, Int32 maxTokensPerLine, Boolean trim, String[] splitOptions, TokenCounter tokenCounter)
at Microsoft.KernelMemory.DataFormats.Text.TextChunker.<>c.<SplitPlainTextParagraphs>b__6_0(String text, Int32 maxTokens, TokenCounter tokenCounter)
at Microsoft.KernelMemory.DataFormats.Text.TextChunker.ProcessParagraphs(List`1 paragraphs, Int32 adjustedMaxTokensPerParagraph, Int32 overlapTokens, String chunkHeader, Func`4 longLinesSplitter, TokenCounter tokenCounter)
at Microsoft.KernelMemory.DataFormats.Text.TextChunker.InternalSplitTextParagraphs(List`1 lines, Int32 maxTokensPerParagraph, Int32 overlapTokens, String chunkHeader, Func`4 longLinesSplitter, TokenCounter tokenCounter)
at Microsoft.KernelMemory.Handlers.TextPartitioningHandler.<InvokeAsync>d__9.MoveNext()
After investigation, it turns out that (contrary to what the exception message suggests) this has nothing to do with encoding. It happens when the text string contains only a newline: "\n".
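A minimal repro of the failing call, assuming the same ModelParams setup as in the constructor above:

// Repro sketch (placeholder model path): a lone newline is enough to crash.
var parameters = new ModelParams(@"D:\models\model.gguf")
{
    ContextSize = 4096
};

using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

var ok = context.Tokenize("hello\nworld"); // tokenizes as expected
var boom = context.Tokenize("\n");         // throws LLama.Exceptions.RuntimeError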
The workaround is to rewrite the method as follows:
public int CountTokens(string text)
{
    if (text == "\n")
        return 0;

    return _context.Tokenize(text).Length;
}
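I have only confirmed the failure for "\n", so if other inputs turn out to trigger it, a more defensive variant (a sketch, not what I'm running) would catch the error instead of special-casing the string:

// Defensive sketch: treat any failing tokenization as zero tokens so the
// chunker doesn't crash. Swallowing RuntimeError is a workaround, not a fix.
public int CountTokens(string text)
{
    if (string.IsNullOrEmpty(text))
        return 0;

    try
    {
        return _context.Tokenize(text).Length;
    }
    catch (LLama.Exceptions.RuntimeError)
    {
        return 0; // observed only for "\n" so far
    }
}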
I am therefore wondering whether this isn't a bug in LLamaContext.Tokenize. Surely it should handle a newline-only string without throwing an exception.
Importance
a fix would make my life easier
Platform, Language, Versions
Windows
Microsoft.KernelMemory.AI.LlamaSharp Version="0.26.240104.1"
Microsoft.KernelMemory.Core Version="0.26.240104.1"
Relevant log output
No response
The exception is coming from LLama.Native.SafeLLamaContextHandle.Tokenize, so it looks like a bug in the LLamaSharp repo (or in the underlying llama.cpp code). Could you report the bug at https://github.com/SciSharp/LLamaSharp/issues?
The issue has been reported to the LLamaSharp maintainers and has since been closed. Thanks!