1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144 | using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;
namespace LLama.Examples.Examples;
/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>
/// Generation starts from a single conversation. Every <c>n_split</c> steps each active conversation is
/// forked into two children, so the replies form a binary tree which is printed after <c>n_len</c> steps.
/// </remarks>
public class BatchedExecutorFork
{
    /// <summary>Number of inference steps between each fork of all active conversations.</summary>
    private const int n_split = 16;

    /// <summary>Total number of inference/sampling steps to run.</summary>
    private const int n_len = 72;

    /// <summary>
    /// Runs the example: loads the model, evaluates the user's prompt once, then repeatedly samples and
    /// periodically forks every active conversation, finally printing all branches as a tree.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when an inference step does not succeed.</exception>
    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();
        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Evaluate the initial prompt to create one conversation
        using var start = executor.Create();
        start.Prompt(prompt);
        await InferOrThrow(executor);

        // Create the root node of the tree. Disposing it releases every conversation still owned by the
        // tree (the leaves that were never split again) before the executor is disposed. If the root was
        // never split it disposes `start` too; the `using` above then disposes it a second time, which the
        // original code already relied on being harmless (Split disposes `start` as well).
        using var root = new Node(start);

        await AnsiConsole
            .Progress()
            .StartAsync(async progress =>
            {
                var reporter = progress.AddTask("Running Inference (1)", maxValue: n_len);

                // Run inference loop
                for (var i = 0; i < n_len; i++)
                {
                    // The prompt was already evaluated above, so the first step samples immediately.
                    if (i != 0)
                    {
                        await InferOrThrow(executor);
                    }

                    // Occasionally fork all the active conversations
                    if (i != 0 && i % n_split == 0)
                    {
                        root.Split();
                    }

                    // Sample all active conversations
                    root.Sample();

                    // Update progress bar
                    reporter.Increment(1);
                    reporter.Description($"Running Inference ({root.ActiveConversationCount})");
                }

                // Display results
                var display = new Tree(prompt);
                root.Display(display);
                AnsiConsole.Write(display);
            });
    }

    /// <summary>
    /// Runs one inference step on <paramref name="executor"/> and throws if it did not report success.
    /// </summary>
    /// <remarks>
    /// If a step fails and conversations are left requiring inference, <c>Node.Sample</c> skips them,
    /// which would otherwise silently truncate the generated replies.
    /// </remarks>
    /// <param name="executor">The executor whose pending batch should be evaluated.</param>
    /// <exception cref="InvalidOperationException">Thrown when the result is anything other than <c>Ok</c>.</exception>
    private static async Task InferOrThrow(BatchedExecutor executor)
    {
        var result = await executor.Infer();
        if (result != DecodeResult.Ok)
        {
            throw new InvalidOperationException(
                $"Inference failed with result '{result}'. If the result is 'NoKvSlot' the context is too small: " +
                $"reduce {nameof(n_len)}/{nameof(n_split)} or increase the context size.");
        }
    }

    /// <summary>
    /// A node in the binary tree of forked conversations. A leaf owns a live <see cref="Conversation"/>;
    /// once split, the node gives up its conversation and delegates all work to its two children.
    /// </summary>
    private sealed class Node : IDisposable
    {
        // Markup colours cycled by tree depth when displaying.
        private static readonly string[] Colors = { "red", "green", "blue", "yellow", "white" };

        private readonly StreamingTokenDecoder _decoder;
        private readonly DefaultSamplingPipeline _sampler;
        private Conversation? _conversation;
        private Node? _left;
        private Node? _right;

        /// <summary>
        /// Number of live conversations in this subtree: 1 for a leaf, otherwise the sum of both children.
        /// </summary>
        public int ActiveConversationCount => _conversation != null
            ? 1
            : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

        /// <summary>
        /// Creates a leaf node for <paramref name="conversation"/> with its own sampler and token decoder.
        /// </summary>
        /// <param name="conversation">The conversation this leaf generates text for.</param>
        public Node(Conversation conversation)
        {
            _sampler = new DefaultSamplingPipeline();
            _conversation = conversation;
            _decoder = new StreamingTokenDecoder(conversation.Executor.Context);
        }

        /// <summary>
        /// Samples one token for every leaf in this subtree whose conversation does not require inference,
        /// records it in that leaf's decoder and prompts the conversation with it.
        /// </summary>
        public void Sample()
        {
            if (_conversation == null)
            {
                _left?.Sample();
                _right?.Sample();
                return;
            }

            if (_conversation.RequiresInference)
            {
                return;
            }

            // Sample one token
            var ctx = _conversation.Executor.Context.NativeHandle;
            var token = _sampler.Sample(ctx, _conversation.Sample(), Array.Empty<LLamaToken>());
            _sampler.Accept(ctx, token);
            _decoder.Add(token);

            // Prompt the conversation with this token, to continue generating from there
            _conversation.Prompt(token);
        }

        /// <summary>
        /// Forks every leaf in this subtree into two child nodes, then disposes and drops the leaf's own
        /// conversation. Text decoded before the fork stays in this node; the children start fresh decoders.
        /// </summary>
        public void Split()
        {
            if (_conversation != null)
            {
                _left = new Node(_conversation.Fork());
                _right = new Node(_conversation.Fork());
                _conversation.Dispose();
                _conversation = null;
            }
            else
            {
                _left?.Split();
                _right?.Split();
            }
        }

        /// <summary>
        /// Adds this node's decoded text (line endings removed, markup-escaped) under <paramref name="tree"/>,
        /// coloured by depth, then recursively adds both children beneath it.
        /// </summary>
        /// <param name="tree">The parent tree or tree node to attach to.</param>
        /// <param name="depth">Depth of this node, used to pick a colour.</param>
        public void Display<T>(T tree, int depth = 0)
            where T : IHasTreeNodes
        {
            var color = Colors[depth % Colors.Length];
            var message = Markup.Escape(_decoder.Read().ReplaceLineEndings(""));
            var n = tree.AddNode($"[{color}]{message}[/]");

            _left?.Display(n, depth + 1);
            _right?.Display(n, depth + 1);
        }

        /// <summary>
        /// Disposes this node's conversation (if it still owns one) and, recursively, both children.
        /// </summary>
        public void Dispose()
        {
            _conversation?.Dispose();
            _left?.Dispose();
            _right?.Dispose();
        }
    }
}
|