Skip to content

Batched executor - rewinding to an earlier state

using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating tokens and then rewinding to an earlier state
/// </summary>
public class BatchedExecutorRewind
{
    private const int n_generate = 24;
    private const int n_rewind = 12;
    private const int n_repeats = 6;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        using var conversation = executor.Create();
        conversation.Prompt(prompt);

        // Create the start node wrapping the conversation
        var node = new Node(executor.Context);

        // Print the prompt
        Console.ForegroundColor = ConsoleColor.Green;
        Console.WriteLine(prompt);

        for (var i = 0; i < n_repeats; i++)
        {
            for (var j = 0; j < n_generate; j++)
            {
                // Run inference
                await executor.Infer();

                // Sample a token
                var token = node.Sample(conversation);

                // Continue conversation with this token
                if (j != n_generate - 1)
                    conversation.Prompt(token);
            }

            // Write out what we generated
            node.Write(n_rewind, i + 1);

            // Rewind back a few tokens
            conversation.Rewind(n_rewind + 1);

            // Prompt with a token
            conversation.Prompt(node.GetToken(n_generate - n_rewind - 1));

            // Create a new node around the rewound conversation
            node = new Node(executor.Context);
        }

        Console.WriteLine("Press any key to exit demo");
        Console.ReadKey(true);
    }

    private class Node
    {
        private readonly LLamaContext _context;

        private readonly List<LLamaToken> _tokens = new List<LLamaToken>();
        private readonly DefaultSamplingPipeline Sampler;

        public Node(LLamaContext context)
        {
            _context = context;
            Sampler = new DefaultSamplingPipeline();
        }

        public LLamaToken Sample(Conversation conversation)
        {
            var token = Sampler.Sample(_context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
            _tokens.Add(token);
            return token;
        }

        public void Write(int n_rewind, int depth)
        {
            var decoder = new StreamingTokenDecoder(_context);

            for (var i = 0; i < _tokens.Count - n_rewind; i++)
                decoder.Add(_tokens[i]);

            AnsiConsole.MarkupLine($"[green]{new string(' ', depth * 3) + decoder.Read().ReplaceLineEndings(" ")}[/]");

            for (var i = _tokens.Count - n_rewind; i < _tokens.Count; i++)
                decoder.Add(_tokens[i]);

            AnsiConsole.MarkupLine($"[maroon]{decoder.Read().ReplaceLineEndings(" ")}[/]");
        }

        public LLamaToken GetToken(int index)
        {
            return _tokens[index];
        }
    }
}