Batched executor - multi-output to one input

using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This example demonstrates generating multiple replies to the same prompt, with a shared prompt cache
/// </summary>
public class BatchedExecutorFork
{
    private const int n_split = 16; // fork every active conversation after every n_split tokens
    private const int n_len = 72;   // total number of tokens to generate

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        using var start = executor.Create();
        start.Prompt(prompt);
        await executor.Infer();

        // Create the root node of the tree
        var root = new Node(start);

        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);

                // Run inference loop
                for (var i = 0; i < n_len; i++)
                {
                    if (i != 0)
                        await executor.Infer();

                    // Occasionally fork all the active conversations
                    if (i != 0 && i % n_split == 0)
                        root.Split();

                    // Sample all active conversations
                    root.Sample();

                    // Update progress bar
                    reporter.Increment(1);
                    reporter.Description($"Running Inference ({root.ActiveConversationCount})");
                }

                // Display results
                var display = new Tree(prompt);
                root.Display(display);
                AnsiConsole.Write(display);
            });
    }

    /// <summary>
    /// A node in the binary tree of conversations: leaves hold a live conversation, inner nodes have been split into two forks
    /// </summary>
    private class Node
    {
        private readonly StreamingTokenDecoder _decoder;

        private readonly DefaultSamplingPipeline _sampler;
        // The conversation held by this node; null once the node has been split into two children
        private Conversation? _conversation;

        private Node? _left;
        private Node? _right;

        // Leaves count as one active conversation; inner nodes sum over both children
        public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

        public Node(Conversation conversation)
        {
            _sampler = new DefaultSamplingPipeline();
            _conversation = conversation;
            _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
        }

        public void Sample()
        {
            if (_conversation == null)
            {
                _left?.Sample();
                _right?.Sample();
                return;
            }

            if (_conversation.RequiresInference)
                return;

            // Sample one token
            var ctx = _conversation.Executor.Context.NativeHandle;
            var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
            _sampler.Accept(ctx, token);
            _decoder.Add(token);

            // Prompt the conversation with this token, to continue generating from there
            _conversation.Prompt(token);
        }

        public void Split()
        {
            if (_conversation != null)
            {
                _left = new Node(_conversation.Fork());
                _right = new Node(_conversation.Fork());

                _conversation.Dispose();
                _conversation = null;
            }
            else
            {
                _left?.Split();
                _right?.Split();
            }
        }

        public void Display<T>(T tree, int depth = 0)
            where T : IHasTreeNodes
        {
            // Colour each level of the tree differently
            var colors = new[] { "red", "green", "blue", "yellow", "white" };
            var color = colors[depth % colors.Length];

            var message = Markup.Escape(_decoder.Read().ReplaceLineEndings(""));

            var n = tree.AddNode($"[{color}]{message}[/]");

            _left?.Display(n, depth + 1);
            _right?.Display(n, depth + 1);
        }
    }
}
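
The core pattern in this example is Conversation.Fork(): the forked conversation continues from the already-evaluated prompt, so the shared prefix is not evaluated again. Stripped of the tree bookkeeping and the Spectre.Console UI, the essential calls look roughly like the sketch below. This is a minimal sketch reusing the same LLama.Batched calls as the example above; the model path is a placeholder and sampling/decoding are omitted for brevity.

using LLama.Batched;
using LLama.Common;

// Minimal sketch of the fork pattern, written as a top-level program.
// Placeholder path - point this at any GGUF model on disk.
var modelPath = "path/to/model.gguf";

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var executor = new BatchedExecutor(model, parameters);

// Evaluate the shared prompt once
using var a = executor.Create();
a.Prompt("Not many people know that");
await executor.Infer();

// Fork: 'b' continues from the same cached prompt without re-evaluating it
using var b = a.Fork();

// From here, each conversation samples its own next token and is prompted with it,
// then a single executor.Infer() call evaluates all pending conversations as one batch
// (see Node.Sample() in the full example above).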