using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;
namespace LLama.Examples.Examples;
/// <summary>
/// This demonstrates generating multiple replies to the same prompt by repeatedly forking conversations.
/// Forked conversations share the cache of their common prefix, so shared tokens are only evaluated once.
/// </summary>
public class BatchedExecutorFork
{
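// Fork all active conversations every time this many tokens have been generated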
private const int n_split = 16;
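// Total number of tokens to generate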
private const int n_len = 72;
public static async Task Run()
{
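// Load the model weights from the path chosen by the user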
string modelPath = UserSettings.GetModelPath();
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");
// Create an executor that can evaluate a batch of conversations together
using var executor = new BatchedExecutor(model, parameters);
// Print the model name, read from the model metadata
var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");
// Evaluate the initial prompt to create one conversation
using var start = executor.Create();
start.Prompt(prompt);
await executor.Infer();
// Create the root node of the tree
var root = new Node(start);
await AnsiConsole
.Progress()
.StartAsync(async progress =>
{
var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);
// Run inference loop
for (var i = 0; i < n_len; i++)
{
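// The initial prompt was already inferred above, so skip inference on the first iteration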
if (i != 0)
await executor.Infer();
// Every n_split tokens, fork all the active conversations, doubling the number of branches
if (i != 0 && i % n_split == 0)
root.Split();
// Sample all active conversations
root.Sample();
// Update progress bar
reporter.Increment(1);
reporter.Description($"Running Inference ({root.ActiveConversationCount})");
}
// Display the generated branches as a tree, rooted at the prompt
var display = new Tree(prompt);
root.Display(display);
AnsiConsole.Write(display);
});
}
private class Node
{
private readonly StreamingTokenDecoder _decoder;
private readonly DefaultSamplingPipeline _sampler;
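// Null once this node has been split; the forked children then own the conversation state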
private Conversation? _conversation;
private Node? _left;
private Node? _right;
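// A leaf node holds exactly one active conversation; an inner node sums over its children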
public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;
public Node(Conversation conversation)
{
_sampler = new DefaultSamplingPipeline();
_conversation = conversation;
_decoder = new StreamingTokenDecoder(conversation.Executor.Context);
}
public void Sample()
{
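// This node has been split, so delegate sampling to the children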
if (_conversation == null)
{
_left?.Sample();
_right?.Sample();
return;
}
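// Cannot sample from a conversation which is still waiting for inference to run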
if (_conversation.RequiresInference)
return;
// Sample one token from the logits this conversation produced during the last inference
var ctx = _conversation.Executor.Context.NativeHandle;
var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
_sampler.Accept(ctx, token);
_decoder.Add(token);
// Prompt the conversation with the sampled token, so the next inference continues from it
_conversation.Prompt(token);
}
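// Replace this leaf with two forked children which share the cache of everything generated so far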
public void Split()
{
if (_conversation != null)
{
_left = new Node(_conversation.Fork());
_right = new Node(_conversation.Fork());
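// The children now carry the state forward, so the parent conversation can be disposed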
_conversation.Dispose();
_conversation = null;
}
else
{
_left?.Split();
_right?.Split();
}
}
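// Append this node's generated text to the tree, colour-coded by depth, then recurse into the children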
public void Display<T>(T tree, int depth = 0)
where T : IHasTreeNodes
{
var colors = new[] { "red", "green", "blue", "yellow", "white" };
var color = colors[depth % colors.Length];
var message = Markup.Escape(_decoder.Read().ReplaceLineEndings(""));
var n = tree.AddNode($"[{color}]{message}[/]");
_left?.Display(n, depth + 1);
_right?.Display(n, depth + 1);
}
}
}