Skip to content

Batched executor - multiple outputs from one input

using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache.
/// </summary>
/// <remarks>
/// The prompt is evaluated once into a single conversation. Every <c>n_split</c> steps each active
/// conversation is forked into two children, so the generations form a binary tree which is printed
/// once inference has finished.
/// </remarks>
public class BatchedExecutorFork
{
    // Number of inference steps between successive forks of every active conversation.
    private const int n_split = 16;

    // Number of inference steps to run; each step samples one token per active conversation.
    private const int n_len = 72;

    /// <summary>
    /// Loads the configured model, asks the user for a prompt, generates a tree of continuations
    /// for it and prints that tree to the console.
    /// </summary>
    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        using var start = executor.Create();
        start.Prompt(prompt);
        await executor.Infer();

        // Create the root node of the tree. It is declared after the executor, so disposing it
        // (which releases every conversation and sampler in the tree) happens before the executor
        // itself is disposed.
        using var root = new Node(start);

        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);

                // Run inference loop
                for (var i = 0; i < n_len; i++)
                {
                    // The prompt was already evaluated above, so the first step can sample immediately
                    if (i != 0)
                        await executor.Infer();

                    // Occasionally fork all the active conversations
                    if (i != 0 && i % n_split == 0)
                        root.Split();

                    // Sample all active conversations
                    root.Sample();

                    // Update progress bar
                    reporter.Increment(1);
                    reporter.Description($"Running Inference ({root.ActiveConversationCount})");
                }
            });

        // Display results once the live progress display has completed. The prompt is user input,
        // so it must be escaped before being used as a markup label (e.g. "[" would otherwise throw).
        var display = new Tree(Markup.Escape(prompt));
        root.Display(display);
        AnsiConsole.Write(display);
    }

    /// <summary>
    /// A node in the binary tree of generations. A leaf owns an active conversation which it samples
    /// from; once split, the node holds no conversation and delegates to its two children.
    /// </summary>
    private sealed class Node
        : IDisposable
    {
        // Colours cycled by tree depth when rendering.
        private static readonly string[] Colors = { "red", "green", "blue", "yellow", "white" };

        // Accumulates the text sampled while this node was a leaf.
        private readonly StreamingTokenDecoder _decoder;

        private readonly DefaultSamplingPipeline _sampler;

        // Non-null while this node is a leaf; set to null once the node has been split.
        private Conversation? _conversation;

        private Node? _left;
        private Node? _right;

        /// <summary>
        /// Number of leaves (active conversations) in the subtree rooted at this node.
        /// </summary>
        public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

        /// <summary>
        /// Creates a leaf node which takes ownership of <paramref name="conversation"/>.
        /// </summary>
        public Node(Conversation conversation)
        {
            _sampler = new DefaultSamplingPipeline();
            _conversation = conversation;
            _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
        }

        /// <summary>
        /// Samples one token for every leaf in this subtree that is ready to be sampled, and prompts
        /// each such conversation with its sampled token.
        /// </summary>
        public void Sample()
        {
            if (_conversation == null)
            {
                _left?.Sample();
                _right?.Sample();
                return;
            }

            // Logits are not available until the pending tokens have been evaluated
            if (_conversation.RequiresInference)
                return;

            // Sample one token
            var ctx = _conversation.Executor.Context.NativeHandle;
            var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
            _sampler.Accept(ctx, token);
            _decoder.Add(token);

            // Prompt the conversation with this token, to continue generating from there
            _conversation.Prompt(token);
        }

        /// <summary>
        /// Forks every leaf in this subtree into two new leaves, disposing the forked-from conversation.
        /// </summary>
        public void Split()
        {
            if (_conversation != null)
            {
                _left = new Node(_conversation.Fork());
                _right = new Node(_conversation.Fork());

                _conversation.Dispose();
                _conversation = null;
            }
            else
            {
                _left?.Split();
                _right?.Split();
            }
        }

        /// <summary>
        /// Adds this node's generated text (escaped, with line endings removed) under
        /// <paramref name="tree"/>, then recursively adds both children beneath it.
        /// </summary>
        /// <param name="tree">The tree or tree node to attach this node to.</param>
        /// <param name="depth">Depth of this node, used to pick a display colour.</param>
        public void Display<T>(T tree, int depth = 0)
            where T : IHasTreeNodes
        {
            var color = Colors[depth % Colors.Length];

            var message = Markup.Escape(_decoder.Read().ReplaceLineEndings(""));

            var n = tree.AddNode($"[{color}]{message}[/]");

            _left?.Display(n, depth + 1);
            _right?.Display(n, depth + 1);
        }

        /// <summary>
        /// Releases this node's conversation (if it is still a leaf), its sampler and both child
        /// subtrees.
        /// </summary>
        public void Dispose()
        {
            _conversation?.Dispose();
            _conversation = null;

            _sampler.Dispose();

            _left?.Dispose();
            _right?.Dispose();
        }
    }
}