# Batched executor - basic guidance

```cs
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates using a batch to generate two sequences and then using one
/// sequence as the negative guidance ("classifier free guidance") for the other.
/// </summary>
public class BatchedExecutorGuidance
{
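    // Number of tokens to generate for each conversation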
    private const int n_len = 32;

    public static async Task Run()
    {
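        // Load the model weights that both conversations will share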
        string modelPath = UserSettings.GetModelPath();
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

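        // Ask the user for the two prompts and the guidance weight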
        var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
        var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
        var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Load the two prompts into two conversations
        using var guided = executor.Create();
        guided.Prompt(positivePrompt);
        using var guidance = executor.Create();
        guidance.Prompt(negativePrompt);

        // Run inference to evaluate the prompts
        await AnsiConsole
            .Status()
            .Spinner(Spinner.Known.Line)
            .StartAsync("Evaluating Prompts...", _ => executor.Infer());

        // Fork the "guided" conversation. We'll run this one without guidance for comparison
        using var unguided = guided.Fork();

        // Run inference loop
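        // The "guided" sampler applies classifier free guidance using logits from the "guidance" conversation,
        // while the "unguided" sampler (created with a null guidance conversation) does not.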
        var unguidedSampler = new GuidedSampler(null, weight);
        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
        var guidedSampler = new GuidedSampler(guidance, weight);
        var guidedDecoder = new StreamingTokenDecoder(executor.Context);

        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference", maxValue: n_len);

                for (var i = 0; i < n_len; i++)
                {
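                    // Evaluate all conversations in the batch. This is skipped on the first iteration
                    // because the prompts were already evaluated above.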
                    if (i != 0)
                        await executor.Infer();

                    // Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
                    // guidance. This serves as a comparison to show the effect of guidance.
                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
                    unguidedDecoder.Add(u);
                    unguided.Prompt(u);

                    // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
                    // to steer the generation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
                    guidedDecoder.Add(g);

                    // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompt)
                    guided.Prompt(g);
                    guidance.Prompt(g);

                    // Early exit if we reach the natural end of the guided sentence
                    if (g == model.EndOfSentenceToken)
                        break;

                    // Update progress bar
                    reporter.Increment(1);
                }
            });

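        // Print the final decoded output of both conversations for comparison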
        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
    }

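    /// <summary>
    /// A custom sampling pipeline. When a guidance conversation is supplied, its logits are applied as
    /// negative guidance (classifier free guidance) to the logits of the conversation being sampled.
    /// </summary>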
    private class GuidedSampler(Conversation? guidance, float weight)
        : BaseSamplingPipeline
    {
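        // No sampler state needs updating when a token is accepted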
        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
        {
        }

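        // Cloning this pipeline is not needed for this example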
        public override ISamplingPipeline Clone()
        {
            throw new NotSupportedException();
        }

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        {
            if (guidance == null)
                return;

            // Get the logits generated by the guidance sequences
            var guidanceLogits = guidance.Sample();

            // Use those logits to guide this sequence
            NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight);
        }

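        // After guidance has been applied to the logits, sample a token with temperature and top-k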
        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        {
            candidates.Temperature(ctx, 0.8f);
            candidates.TopK(ctx, 25);

            return candidates.SampleToken(ctx);
        }
    }
}
```