Batched executor - basic guidance

using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates using a batch to generate two sequences and then using one
/// sequence as the negative guidance ("classifier free guidance") for the other.
/// </summary>
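/// <remarks>
/// The "guided" and "guidance" conversations are advanced in lockstep. At every step the sampler takes the
/// logits of the guided (positive prompt) conversation and pulls them away from the logits of the guidance
/// (negative prompt) conversation, scaled by the guidance weight, before sampling a token. The same sampled
/// token is then fed to both conversations so they stay in sync.
/// </remarks>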
public class BatchedExecutorGuidance
{
    // Maximum number of tokens to generate for each conversation
    private const int n_len = 32;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var positivePrompt = AnsiConsole.Ask("Positive Prompt (or ENTER for default):", "My favourite colour is").Trim();
        var negativePrompt = AnsiConsole.Ask("Negative Prompt (or ENTER for default):", "I hate the colour red. My favourite colour is").Trim();
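        // Guidance weight: how strongly to push the output away from the negative prompt.
        // A weight of 1 should be (roughly) equivalent to no guidance; higher values steer harder.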
        var weight = AnsiConsole.Ask("Guidance Weight (or ENTER for default):", 2.0f);

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Load the two prompts into two conversations
        using var guided = executor.Create();
        guided.Prompt(positivePrompt);
        using var guidance = executor.Create();
        guidance.Prompt(negativePrompt);

        // Run inference to evaluate prompts
        await AnsiConsole
             .Status()
             .Spinner(Spinner.Known.Line)
             .StartAsync("Evaluating Prompts...", _ => executor.Infer());

        // Fork the "guided" conversation. We'll run this one without guidance for comparison
        using var unguided = guided.Fork();
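        // The fork starts from the state already evaluated for "guided", so its prompt does not need to be evaluated again.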

        // Run inference loop
        var unguidedSampler = new GuidedSampler(null, weight);
        var unguidedDecoder = new StreamingTokenDecoder(executor.Context);
        var guidedSampler = new GuidedSampler(guidance, weight);
        var guidedDecoder = new StreamingTokenDecoder(executor.Context);
        await AnsiConsole
           .Progress()
           .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference", maxValue: n_len);

                for (var i = 0; i < n_len; i++)
                {
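                    // The prompts were evaluated before the loop, so the first iteration can sample immediately.
                    // Later iterations must first evaluate the tokens queued by Prompt() at the end of the previous iteration.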
                    if (i != 0)
                        await executor.Infer();

                    // Sample from the "unguided" conversation. This is just a conversation using the same prompt, without any
                    // guidance. This serves as a comparison to show the effect of guidance.
                    var u = unguidedSampler.Sample(executor.Context.NativeHandle, unguided.Sample(), Array.Empty<LLamaToken>());
                    unguidedDecoder.Add(u);
                    unguided.Prompt(u);

                    // Sample from the "guided" conversation. This sampler will internally use the "guidance" conversation
                    // to steer the conversation. See how this is done in GuidedSampler.ProcessLogits (bottom of this file).
                    var g = guidedSampler.Sample(executor.Context.NativeHandle, guided.Sample(), Array.Empty<LLamaToken>());
                    guidedDecoder.Add(g);

                    // Use this token to advance both guided _and_ guidance, keeping them in sync (except for the initial prompts).
                    guided.Prompt(g);
                    guidance.Prompt(g);

                    // Early exit if we reach the natural end of the guided sentence
                    if (g == model.EndOfSentenceToken)
                        break;

                    // Update progress bar
                    reporter.Increment(1);
                }
            });

        AnsiConsole.MarkupLine($"[green]Unguided:[/][white]{unguidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
        AnsiConsole.MarkupLine($"[green]Guided:[/][white]{guidedDecoder.Read().ReplaceLineEndings(" ")}[/]");
    }

    private class GuidedSampler(Conversation? guidance, float weight)
        : BaseSamplingPipeline
    {
        public override void Accept(SafeLLamaContextHandle ctx, LLamaToken token)
        {
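            // Intentionally empty: this pipeline does not need to track accepted tokens.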
        }

        public override ISamplingPipeline Clone()
        {
            throw new NotSupportedException();
        }

        protected override void ProcessLogits(SafeLLamaContextHandle ctx, Span<float> logits, ReadOnlySpan<LLamaToken> lastTokens)
        {
            if (guidance == null)
                return;

            // Get the logits generated by the guidance sequence
            var guidanceLogits = guidance.Sample();

            // Use those logits to guide this sequence
            NativeApi.llama_sample_apply_guidance(ctx, logits, guidanceLogits, weight);
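            // Roughly speaking, this pushes the logits away from the negative-prompt distribution:
            // logits ≈ guidanceLogits + weight * (logits - guidanceLogits), so a weight of 1 leaves them unchanged.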
        }

        protected override LLamaToken ProcessTokenDataArray(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, ReadOnlySpan<LLamaToken> lastTokens)
        {
            candidates.Temperature(ctx, 0.8f);
            candidates.TopK(ctx, 25);

            return candidates.SampleToken(ctx);
        }
    }
}