ref: fc1b813cf4527bc7f5c3eb237f5162f421b9677b
dir: /llama2.c/
/* Inference for Llama-2 Transformer model in pure C */
/* github.com/karpathy/llama2.c */
/* MIT License
*
* Copyright (c) 2023 Andrej
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#include <u.h>
#include <libc.h>
#include <ctype.h>
unsigned char quantized8 = 0; // nonzero when the checkpoint holds int8-quantized (version 2) weights
int GS = 32; // quantization group size: number of int8 values sharing one float scale
// ----------------------------------------------------------------------------
// Transformer model
typedef struct {
int dim; // transformer dimension
int hidden_dim; // for ffn layers
int n_layers; // number of layers
int n_heads; // number of query heads
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
int vocab_size; // vocabulary size, usually 256 (byte-level)
int seq_len; // max sequence length
} Config;
#define SIZEOFCONFIG 28 /* 7 ints of 4 bytes each */
typedef struct {
char* q; // quantized values
float* s; // scaling factors
} QuantizedTensor;
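/* illustration: with GS = 32, a row of 64 weights is stored as 64 int8
 * values in q[] plus 2 float scales in s[]; element i dequantizes as
 * q[i] * s[i / GS], so each group of 32 values shares one scale */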
typedef struct {
// token embedding table
void *q_tokens; // (vocab_size, dim)
float* token_embedding_table; // (vocab_size, dim) // dequantized
// weights for rmsnorms
float* rms_att_weight; // (layer, dim) rmsnorm weights
float* rms_ffn_weight; // (layer, dim)
// weights for matmuls. note dim == n_heads * head_size
void* wq; // (layer, dim, n_heads * head_size)
void* wk; // (layer, dim, n_kv_heads * head_size)
void* wv; // (layer, dim, n_kv_heads * head_size)
void* wo; // (layer, n_heads * head_size, dim)
// weights for ffn
void* w1; // (layer, hidden_dim, dim)
void* w2; // (layer, dim, hidden_dim)
void* w3; // (layer, hidden_dim, dim)
// final rmsnorm
float* rms_final_weight; // (dim,)
// (optional) classifier weights for the logits, on the last layer
void* wcls;
} TransformerWeights;
#define SIZEOFTRANSFORMERWEIGHTS (13*sizeof(void*)) /* 13 pointer members */
typedef struct {
// current wave of activations
float *x; // activation at current time stamp (dim,)
float *xb; // same, but inside a residual branch (dim,)
float *xb2; // an additional buffer just for convenience (dim,)
float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
QuantizedTensor *xq; // quantized x (dim,)
QuantizedTensor *hq; // quantized hb (hidden_dim,)
float *q; // query (dim,)
float *k; // key (dim,)
float *v; // value (dim,)
float *att; // buffer for scores/attention values (n_heads, seq_len)
float *logits; // output logits
// kv cache
float* key_cache; // (layer, seq_len, dim)
float* value_cache; // (layer, seq_len, dim)
} RunState;
#define SIZEOFRUNSTATE (14*sizeof(float*)) /* 14 pointer members */
typedef struct {
Config config; // the hyperparameters of the architecture (the blueprint)
TransformerWeights weights; // the weights of the model
RunState state; // buffers for the "wave" of activations in the forward pass
// some more state needed to properly clean up the memory mapping (sigh)
int fd; // file descriptor for memory mapping
float* data; // memory mapped data pointer
long file_size; // size of the checkpoint file in bytes
} Transformer;
#define SIZEOFTRANSFORMER (SIZEOFCONFIG+SIZEOFTRANSFORMERWEIGHTS+SIZEOFRUNSTATE+4+sizeof(float*)+sizeof(long))
void malloc_run_state(RunState* s, Config* p) {
QuantizedTensor *xq;
QuantizedTensor *hq;
// we calloc instead of malloc to keep valgrind happy
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
s->x = calloc(p->dim, sizeof(float));
s->xb = calloc(p->dim, sizeof(float));
s->xb2 = calloc(p->dim, sizeof(float));
s->hb = calloc(p->hidden_dim, sizeof(float));
s->hb2 = calloc(p->hidden_dim, sizeof(float));
    xq = calloc(1, sizeof(QuantizedTensor));
    hq = calloc(1, sizeof(QuantizedTensor));
    if (!xq || !hq) sysfatal("malloc failed!"); // must check before dereferencing below
    xq->q = calloc(p->dim, sizeof(char));
    xq->s = calloc(p->dim, sizeof(float));
    hq->q = calloc(p->hidden_dim, sizeof(char));
    hq->s = calloc(p->hidden_dim, sizeof(float));
    s->xq = xq;
    s->hq = hq;
s->q = calloc(p->dim, sizeof(float));
s->k = calloc(kv_dim, sizeof(float));
s->v = calloc(kv_dim, sizeof(float));
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
s->logits = calloc(p->vocab_size, sizeof(float));
// ensure all mallocs went fine
    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
     || !s->k || !s->v || !s->xq || !s->hq || !s->xq->q || !s->xq->s
     || !s->hq->q || !s->hq->s
     || !s->key_cache || !s->value_cache || !s->att || !s->logits) {
sysfatal("malloc failed!");
}
}
void free_run_state(RunState* s) {
free(s->x);
free(s->xb);
free(s->xb2);
free(s->hb);
free(s->hb2);
free(s->xq->q);
free(s->xq->s);
free(s->hq->q);
free(s->hq->s);
free(s->xq);
free(s->hq);
free(s->q);
    // note: forward() repoints s->k and s->v into the kv cache, so freeing
    // them here could free interior cache pointers; the buffers allocated
    // for them in malloc_run_state() are intentionally leaked instead
    // free(s->k);
    // free(s->v);
free(s->att);
free(s->logits);
free(s->key_cache);
free(s->value_cache);
}
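/* replacement for round(), which this libc lacks: rounds halfway cases
 * away from zero, e.g. 2.5 -> 3 and -2.5 -> -3 */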
float round(float in) {
float f;
f = fmod(in, 1.0);
if (in > 0) {
if (f < 0.5)
return floor(in);
return ceil(in);
}
if (f > -0.5)
return ceil(in);
return floor(in);
}
// ----------------------------------------------------------------------------
// Quantization functions
// dequantize n*m contiguous values from a single QuantizedTensor
// (qx is not treated as an array of tensors; it is only called here with m == 1)
void dequantize(QuantizedTensor *qx, float* x, int n, int m) {
for (int j = 0; j < m; j++) {
for (int i = 0; i < n; i++) {
x[j * n + i] = qx->q[j * n + i] * qx->s[(j * n + i) / GS];
}
}
}
void quantize(QuantizedTensor *qx, float* x, int n) {
int num_groups = n / GS;
float Q_MAX = 127.0f;
for (int group = 0; group < num_groups; group++) {
// find the max absolute value in the current group
float wmax = 0.0;
for (int i = 0; i < GS; i++) {
float val = fabs(x[group * GS + i]);
if (val > wmax) {
wmax = val;
}
}
// calculate and write the scaling factor
        float scale = wmax / Q_MAX;
        if (scale == 0.0f) scale = 1.0f; // all-zero group: avoid 0/0 below; any scale reproduces zeros
        qx->s[group] = scale;
// calculate and write the quantized values
for (int i = 0; i < GS; i++) {
float quant_value = x[group * GS + i] / scale; // scale
            char quantized = (char) round(quant_value); // round; |quant_value| <= Q_MAX by construction, so no clamp is needed
qx->q[group * GS + i] = quantized;
}
}
}
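/* worked example: if a group's max |x| is 6.35, then scale = 6.35/127 = 0.05,
 * so x = 1.27 maps to round(1.27/0.05) = round(25.4) = 25 and dequantizes
 * back to 25 * 0.05 = 1.25, a quantization error of 0.02 */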
/* initialize `n` x quantized tensor (with `size_each` elements), starting from memory pointed at *ptr */
QuantizedTensor *init_quantized_tensors(void **ptr, int n, int size_each) {
void *p = *ptr;
QuantizedTensor *res = malloc(n * sizeof(QuantizedTensor));
for(int i=0; i<n; i++) {
/* map quantized int8 values*/
res[i].q = (char*)p;
p = (char*)p + size_each;
/* map scale factors */
res[i].s = (float*)p;
p = (float*)p + size_each / GS;
}
*ptr = p; // advance ptr to current position
return res;
}
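/* on disk, each quantized tensor is serialized as size_each int8 values
 * followed by size_each/GS float scales; init_quantized_tensors() aliases
 * q and s straight into that buffer without copying */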
void memory_map_weights_q8(TransformerWeights *w, Config* p, void* ptr, uchar shared_classifier) {
int head_size = p->dim / p->n_heads;
// first are the parameters that are kept in fp32 (the rmsnorm (1D) weights)
float* fptr = (float*) ptr; // cast our pointer to float*
w->rms_att_weight = fptr;
fptr += p->n_layers * p->dim;
w->rms_ffn_weight = fptr;
fptr += p->n_layers * p->dim;
w->rms_final_weight = fptr;
fptr += p->dim;
// now read all the quantized weights
ptr = (void*)fptr; // now cast the pointer back to void*
w->q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim);
// dequantize token embedding table
w->token_embedding_table = malloc(p->vocab_size * p->dim * sizeof(float));
dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim, 1);
w->wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size));
w->wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
w->wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
w->wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim);
w->w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
w->w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim);
w->w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
w->wcls = shared_classifier ? w->q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size);
}
void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
int head_size = p->dim / p->n_heads;
// make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
unsigned long long n_layers = p->n_layers;
w->token_embedding_table = ptr;
ptr = &ptr[p->vocab_size * p->dim];
w->rms_att_weight = ptr;
ptr += n_layers * p->dim;
w->wq = ptr;
ptr += n_layers * p->dim * (p->n_heads * head_size);
w->wk = ptr;
ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
w->wv = ptr;
ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
w->wo = ptr;
ptr += n_layers * (p->n_heads * head_size) * p->dim;
w->rms_ffn_weight = ptr;
ptr += n_layers * p->dim;
w->w1 = ptr;
ptr += n_layers * p->dim * p->hidden_dim;
w->w2 = ptr;
ptr += n_layers * p->hidden_dim * p->dim;
w->w3 = ptr;
ptr += n_layers * p->dim * p->hidden_dim;
w->rms_final_weight = ptr;
ptr += p->dim;
ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE)
w->wcls = shared_weights ? w->token_embedding_table : ptr;
}
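/* checkpoint header layouts, as implied by the reads below:
 *   legacy fp32: 7 ints (28 bytes): dim, hidden_dim, n_layers, n_heads,
 *     n_kv_heads, vocab_size, seq_len; weights follow immediately
 *   version 2 (int8): magic "ak42" (0x616b3432), version, the same 7 ints,
 *     a shared_classifier byte and a group_size int, padded to 256 bytes */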
void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
int* fd, float** data, long* file_size) {
uvlong length;
uvlong offset;
int ret;
int fdt;
Dir *dstat;
unsigned int magic;
int header_size = 28;
uchar shared_classifier;
int group_size;
fdt = open(checkpoint, OREAD);
if (fdt < 3) sysfatal("couldn't open file %s", checkpoint);
if (read(fdt, &magic, 4) != 4) sysfatal("read magic");
if (magic == 0x616b3432) {
if (read(fdt, &magic, 4) != 4) sysfatal("read version");
if (magic != 2) sysfatal("version (quantized) is not 2");
quantized8 = 1;
header_size = 256;
if (read(fdt, &magic, 4) != 4) sysfatal("read dim");
}
    // read in the rest of the config header; in the legacy format the first
    // int of the file (already read into magic above) is dim
    config->dim = magic;
if (read(fdt, &config->hidden_dim, 4) != 4) sysfatal("read hidden_dim");
if (read(fdt, &config->n_layers, 4) != 4) sysfatal("read n_layers");
if (read(fdt, &config->n_heads, 4) != 4) sysfatal("read n_heads");
if (read(fdt, &config->n_kv_heads, 4) != 4) sysfatal("read n_kv_heads");
if (read(fdt, &config->vocab_size, 4) != 4) sysfatal("read vocab_size");
if (read(fdt, &config->seq_len, 4) != 4) sysfatal("read seq_len");
if (quantized8 == 1) {
if (read(fdt, &shared_classifier, 1) != 1) sysfatal("read shared_classifier");
if (read(fdt, &group_size, 4) != 4) sysfatal("read group_size");
GS = group_size;
}
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
int shared_weights = config->vocab_size > 0 ? 1 : 0;
config->vocab_size = abs(config->vocab_size);
close(fdt);
dstat = dirstat(checkpoint);
*file_size = dstat->length;
free(dstat);
*fd = open(checkpoint, OREAD);
// memory map the Transformer weights into the data pointer
if (*fd < 3) sysfatal("open failed!");
*data = malloc(*file_size);
length = *file_size;
offset = 0;
while ((ret = read(*fd, (char*)(*data) + offset, length)) > 0) {
length -= ret;
offset += ret;
}
close(*fd);
*fd = open(checkpoint, OREAD);
float* weights_ptr = (float*)((char*)(*data) + header_size);
if (quantized8 == 0)
memory_map_weights(weights, config, weights_ptr, shared_weights);
else
memory_map_weights_q8(weights, config, weights_ptr, shared_classifier);
}
void build_transformer(Transformer *t, char* checkpoint_path) {
// read in the Config and the Weights from the checkpoint
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
// allocate the RunState buffers
malloc_run_state(&t->state, &t->config);
}
void free_transformer(Transformer* t) {
free(t->data);
if (t->fd != -1) { close(t->fd); }
// free the RunState buffers
free_run_state(&t->state);
}
// ----------------------------------------------------------------------------
// neural net blocks; the dynamics of the Transformer
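// rmsnorm computes o[j] = weight[j] * x[j] / sqrt(mean(x^2) + 1e-5),
// i.e. scaling by the reciprocal root-mean-square of x, plus an epsilon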
void rmsnorm(float* o, float* x, float* weight, int size) {
// calculate sum of squares
float ss = 0.0f;
for (int j = 0; j < size; j++) {
ss += x[j] * x[j];
}
ss /= size;
ss += 1e-5f;
ss = 1.0f / sqrt(ss);
// normalize and scale
for (int j = 0; j < size; j++) {
o[j] = weight[j] * (ss * x[j]);
}
}
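// softmax in place: x[i] = exp(x[i] - max) / sum_j exp(x[j] - max);
// subtracting the max first avoids overflow in exp()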
void softmax(float* x, int size) {
// find max value (for numerical stability)
float max_val = x[0];
for (int i = 1; i < size; i++) {
if (x[i] > max_val) {
max_val = x[i];
}
}
// exp and sum
float sum = 0.0f;
for (int i = 0; i < size; i++) {
x[i] = exp(x[i] - max_val);
sum += x[i];
}
// normalize
for (int i = 0; i < size; i++) {
x[i] /= sum;
}
}
void matmul(float* xout, float* x, float* w, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
// by far the most amount of time is spent inside this little function
int i;
#pragma omp parallel for private(i)
for (i = 0; i < d; i++) {
float val = 0.0f;
for (int j = 0; j < n; j++) {
val += w[i * n + j] * x[j];
}
xout[i] = val;
}
}
void matmul_q8(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
// by far the most amount of time is spent inside this little function
// inputs to this function are both quantized
int i;
#pragma omp parallel for private(i)
for (i = 0; i < d; i++) {
float val = 0.0f;
int ival = 0;
int in = i * n;
// do the matmul in groups of GS
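        // each group contributes (sum of GS int8 products) * w_scale * x_scale:
        // the dot product is exact in int32 within a group, and the two float
        // scales are applied once per group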
int j;
for (j = 0; j <= n - GS; j += GS) {
for (int k = 0; k < GS; k++) {
ival += ((int) x->q[j + k]) * ((int) w->q[in + j + k]);
}
val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
ival = 0;
}
xout[i] = val;
}
}
float* forward(Transformer* transformer, int token, int pos) {
// a few convenience variables
Config* p = &transformer->config;
TransformerWeights* w = &transformer->weights;
RunState* s = &transformer->state;
float *x = s->x;
int dim = p->dim;
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
int hidden_dim = p->hidden_dim;
int head_size = dim / p->n_heads;
// copy the token embedding into x
float* content_row = &w->token_embedding_table[token * dim];
memcpy(x, content_row, dim*sizeof(float));
// forward all the layers
for(unsigned long long l = 0; l < p->n_layers; l++) {
// attention rmsnorm
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
// key and value point to the kv cache
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
s->k = s->key_cache + loff + pos * kv_dim;
s->v = s->value_cache + loff + pos * kv_dim;
// qkv matmuls for this position
matmul(s->q, s->xb, (float*)w->wq + l*dim*dim, dim, dim);
matmul(s->k, s->xb, (float*)w->wk + l*dim*kv_dim, dim, kv_dim);
matmul(s->v, s->xb, (float*)w->wv + l*dim*kv_dim, dim, kv_dim);
// RoPE relative positional encoding: complex-valued rotate q and k in each head
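        // each pair (vec[i], vec[i+1]) is treated as a complex number and
        // multiplied by cos(pos*freq) + i*sin(pos*freq), where
        // freq = 1 / 10000^(head_dim / head_size)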
for (int i = 0; i < dim; i+=2) {
int head_dim = i % head_size;
float freq = 1.0f / pow(10000.0f, head_dim / (float)head_size);
float val = pos * freq;
float fcr = cos(val);
float fci = sin(val);
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int v = 0; v < rotn; v++) {
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
float v0 = vec[i];
float v1 = vec[i+1];
vec[i] = v0 * fcr - v1 * fci;
vec[i+1] = v0 * fci + v1 * fcr;
}
}
// multihead attention. iterate over all heads
int h;
#pragma omp parallel for private(h)
for (h = 0; h < p->n_heads; h++) {
// get the query vector for this head
float* q = s->q + h * head_size;
// attention scores for this head
float* att = s->att + h * p->seq_len;
// iterate over all timesteps, including the current one
for (int t = 0; t <= pos; t++) {
// get the key vector for this head and at this timestep
float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
// calculate the attention score as the dot product of q and k
float score = 0.0f;
for (int i = 0; i < head_size; i++) {
score += q[i] * k[i];
}
score /= sqrt(head_size);
// save the score to the attention buffer
att[t] = score;
}
// softmax the scores to get attention weights, from 0..pos inclusively
softmax(att, pos + 1);
// weighted sum of the values, store back into xb
float* xb = s->xb + h * head_size;
            memset(xb, 0, head_size * sizeof(float));
for (int t = 0; t <= pos; t++) {
// get the value vector for this head and at this timestep
float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
// get the attention weight for this timestep
float a = att[t];
// accumulate the weighted value into xb
for (int i = 0; i < head_size; i++) {
xb[i] += a * v[i];
}
}
}
// final matmul to get the output of the attention
matmul(s->xb2, s->xb, (float*)w->wo + l*dim*dim, dim, dim);
// residual connection back into x
for (int i = 0; i < dim; i++) {
x[i] += s->xb2[i];
}
// ffn rmsnorm
rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
matmul(s->hb, s->xb, (float*)w->w1 + l*dim*hidden_dim, dim, hidden_dim);
matmul(s->hb2, s->xb, (float*)w->w3 + l*dim*hidden_dim, dim, hidden_dim);
// SwiGLU non-linearity
for (int i = 0; i < hidden_dim; i++) {
float val = s->hb[i];
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
val *= (1.0f / (1.0f + exp(-val)));
// elementwise multiply with w3(x)
val *= s->hb2[i];
s->hb[i] = val;
}
// final matmul to get the output of the ffn
matmul(s->xb, s->hb, (float*)w->w2 + l*dim*hidden_dim, hidden_dim, dim);
// residual connection
for (int i = 0; i < dim; i++) {
x[i] += s->xb[i];
}
}
// final rmsnorm
rmsnorm(x, x, w->rms_final_weight, dim);
// classifier into logits
    matmul(s->logits, x, (float*)w->wcls, p->dim, p->vocab_size);
return s->logits;
}
float* forward_q8(Transformer* transformer, int token, int pos) {
// a few convenience variables
Config* p = &transformer->config;
TransformerWeights* w = &transformer->weights;
RunState* s = &transformer->state;
float *x = s->x;
int dim = p->dim;
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
int hidden_dim = p->hidden_dim;
int head_size = dim / p->n_heads;
// copy the token embedding into x
memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float));
// forward all the layers
for(int l = 0; l < p->n_layers; l++) {
// attention rmsnorm
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
// qkv matmuls for this position
quantize(s->xq, s->xb, dim);
matmul_q8(s->q, s->xq, (QuantizedTensor*)w->wq + l, dim, dim);
matmul_q8(s->k, s->xq, (QuantizedTensor*)w->wk + l, dim, kv_dim);
matmul_q8(s->v, s->xq, (QuantizedTensor*)w->wv + l, dim, kv_dim);
// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i+=2) {
int head_dim = i % head_size;
float freq = 1.0f / pow(10000.0f, head_dim / (float)head_size);
float val = pos * freq;
float fcr = cos(val);
float fci = sin(val);
int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
for (int v = 0; v < rotn; v++) {
float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
float v0 = vec[i];
float v1 = vec[i+1];
vec[i] = v0 * fcr - v1 * fci;
vec[i+1] = v0 * fci + v1 * fcr;
}
}
// save key,value at this time step (pos) to our kv cache
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
float* key_cache_row = s->key_cache + loff + pos * kv_dim;
float* value_cache_row = s->value_cache + loff + pos * kv_dim;
memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
// multihead attention. iterate over all heads
int h;
#pragma omp parallel for private(h)
for (h = 0; h < p->n_heads; h++) {
// get the query vector for this head
float* q = s->q + h * head_size;
// attention scores for this head
float* att = s->att + h * p->seq_len;
// iterate over all timesteps, including the current one
for (int t = 0; t <= pos; t++) {
// get the key vector for this head and at this timestep
float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
// calculate the attention score as the dot product of q and k
float score = 0.0f;
for (int i = 0; i < head_size; i++) {
score += q[i] * k[i];
}
score /= sqrt(head_size);
// save the score to the attention buffer
att[t] = score;
}
// softmax the scores to get attention weights, from 0..pos inclusively
softmax(att, pos + 1);
// weighted sum of the values, store back into xb
float* xb = s->xb + h * head_size;
memset(xb, 0, head_size * sizeof(float));
for (int t = 0; t <= pos; t++) {
// get the value vector for this head and at this timestep
float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
// get the attention weight for this timestep
float a = att[t];
// accumulate the weighted value into xb
for (int i = 0; i < head_size; i++) {
xb[i] += a * v[i];
}
}
}
// final matmul to get the output of the attention
quantize(s->xq, s->xb, dim);
matmul_q8(s->xb2, s->xq, (QuantizedTensor*)w->wo + l, dim, dim);
// residual connection back into x
for (int i = 0; i < dim; i++) {
x[i] += s->xb2[i];
}
// ffn rmsnorm
rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
quantize(s->xq, s->xb, dim);
matmul_q8(s->hb, s->xq, (QuantizedTensor*)w->w1 + l, dim, hidden_dim);
matmul_q8(s->hb2, s->xq, (QuantizedTensor*)w->w3 + l, dim, hidden_dim);
// SwiGLU non-linearity
for (int i = 0; i < hidden_dim; i++) {
float val = s->hb[i];
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
val *= (1.0f / (1.0f + exp(-val)));
// elementwise multiply with w3(x)
val *= s->hb2[i];
s->hb[i] = val;
}
// final matmul to get the output of the ffn
quantize(s->hq, s->hb, hidden_dim);
matmul_q8(s->xb, s->hq, (QuantizedTensor*)w->w2 + l, hidden_dim, dim);
// residual connection
for (int i = 0; i < dim; i++) {
x[i] += s->xb[i];
}
}
// final rmsnorm
rmsnorm(x, x, w->rms_final_weight, dim);
// classifier into logits
quantize(s->xq, x, dim);
    matmul_q8(s->logits, s->xq, (QuantizedTensor*)w->wcls, dim, p->vocab_size);
return s->logits;
}
// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
typedef struct {
char *str;
int id;
} TokenIndex;
#define SIZEOFTOKENINDEX sizeof(TokenIndex) /* must match the qsort stride, including padding */
typedef struct {
char** vocab;
float* vocab_scores;
TokenIndex *sorted_vocab;
int vocab_size;
unsigned int max_token_length;
unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer;
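/* byte_pieces holds 256 two-byte strings back to back: byte_pieces[2*i] is
 * byte i and byte_pieces[2*i+1] is '\0', so &byte_pieces[2*i] is a valid
 * C string for the raw-byte token <0xXX> */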
int compare_tokens(const void *a, const void *b) {
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}
void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
// i should have written the vocab_size into the tokenizer file... sigh
t->vocab_size = vocab_size;
// malloc space to hold the scores and the strings
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
t->sorted_vocab = nil; // initialized lazily
for (int i = 0; i < 256; i++) {
t->byte_pieces[i * 2] = (unsigned char)i;
t->byte_pieces[i * 2 + 1] = '\0';
}
// read in the file
int fd = open(tokenizer_path, OREAD);
if (fd < 3) sysfatal("couldn't load %s", tokenizer_path);
if (read(fd, &t->max_token_length, sizeof(int)) != sizeof(int)) sysfatal("failed read");
int len;
for (int i = 0; i < vocab_size; i++) {
if (read(fd, t->vocab_scores + i, sizeof(float)) != sizeof(float)) sysfatal("failed read");
if (read(fd, &len, sizeof(int)) != sizeof(int)) sysfatal("failed read");
t->vocab[i] = (char*)malloc(len + 1);
if (read(fd, t->vocab[i], len) != len) sysfatal("failed read");
t->vocab[i][len] = '\0'; // add the string terminating token
}
close(fd);
}
void free_tokenizer(Tokenizer* t) {
for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
free(t->vocab);
free(t->vocab_scores);
free(t->sorted_vocab);
}
char* decode(Tokenizer* t, int prev_token, int token) {
char *piece = t->vocab[token];
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
if (prev_token == 1 && piece[0] == ' ') { piece++; }
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
// parse this and convert and return the actual byte
unsigned char byte_val;
    if (strncmp(piece, "<0x", 3) == 0 && piece[strlen(piece)-1] == '>') {
byte_val = strtol(piece + 1, nil, 16);
piece = (char*)t->byte_pieces + byte_val * 2;
}
return piece;
}
void safe_print(char *piece) {
// piece might be a raw byte token, and we only want to print printable chars or whitespace
// because some of the other bytes can be various control codes, backspace, etc.
if (piece == nil) { return; }
if (piece[0] == '\0') { return; }
if (piece[1] == '\0') {
unsigned char byte_val = piece[0];
if (!(isprint(byte_val) || isspace(byte_val))) {
return; // bad byte, don't print it
}
}
print("%s", piece);
}
void *bsearch(void *key, void *base, int num, int size, int (*comp)(void *a, void *b)) {
    // binary search over a sorted array (this libc provides no bsearch)
    int lo = 0;
    int hi = num - 1;
    while (lo <= hi) {
        int middle = lo + (hi - lo) / 2;
        void *elem = (void*)((uchar*)base + (middle * size));
        int result = comp(key, elem);
        if (result > 0)
            lo = middle + 1; // key sorts after elem: search the upper half
        else if (result < 0)
            hi = middle - 1; // key sorts before elem: search the lower half
        else
            return elem; // exact match
    }
    return nil;
}
int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
// efficiently find the perfect match for str in vocab, return its index or -1 if not found
TokenIndex tok = { .str = str }; // acts as the key to search for
TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, SIZEOFTOKENINDEX, compare_tokens);
return res != nil ? res->id : -1;
}
void encode(Tokenizer* t, char *text, char bos, char eos, int *tokens, int *n_tokens) {
// encode the string text (input) into an upper-bound preallocated tokens[] array
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
if (text == nil) sysfatal("cannot encode NULL text");
if (t->sorted_vocab == nil) {
// lazily malloc and sort the vocabulary
t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
for (int i = 0; i < t->vocab_size; i++) {
t->sorted_vocab[i].str = t->vocab[i];
t->sorted_vocab[i].id = i;
}
qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
}
// create a temporary buffer that will store merge candidates of always two consecutive tokens
// *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
char* str_buffer = malloc((t->max_token_length*2 +1 +2));
long str_len = 0;
// start at 0 tokens
*n_tokens = 0;
// add optional BOS (=1) token, if desired
if (bos) tokens[(*n_tokens)++] = 1;
// add_dummy_prefix is true by default
// so prepend a dummy prefix token to the input string, but only if text != ""
// TODO: pretty sure this isn't correct in the general case but I don't have the
// energy to read more of the sentencepiece code to figure out what it's doing
if (text[0] != '\0') {
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
tokens[(*n_tokens)++] = dummy_prefix;
}
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
// Code point ↔ UTF-8 conversion
// First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
// U+0000 U+007F 0xxxxxxx
// U+0080 U+07FF 110xxxxx 10xxxxxx
// U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
// U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
// process the raw (UTF-8) byte sequence of the input string
for (char *c = text; *c != '\0'; c++) {
// reset buffer if the current byte is ASCII or a leading byte
// 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
// 0x80 is 10000000
// in UTF-8, all continuation bytes start with "10" in first two bits
// so in English this is: "if this byte is not a continuation byte"
if ((*c & 0xC0) != 0x80) {
// this byte must be either a leading byte (11...) or an ASCII char (0x...)
// => reset our location, as we're starting a new UTF-8 codepoint
str_len = 0;
}
// append the current byte to the buffer
str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
str_buffer[str_len] = '\0';
// while the next character is a continuation byte, continue appending
        // but if there are too many of them, just stop to avoid overrunning the str_buffer size.
if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
continue;
}
// ok c+1 is not a continuation byte, so we've read in a full codepoint
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
if (id != -1) {
// we found this codepoint in vocab, add it as a token
tokens[(*n_tokens)++] = id;
} else {
// byte_fallback encoding: just encode each byte as a token
// +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
// so the individual bytes only start at index 3
for (int i=0; i < str_len; i++) {
tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
}
}
str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
}
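    // example: after the byte/codepoint pass, "hello" might be five
    // single-character tokens; the merge loop below then repeatedly fuses
    // the adjacent pair with the best vocab score (e.g. 'h'+'e' -> "he")
    // until no scored pair remains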
// merge the best consecutive pair each iteration, according the scores in vocab_scores
while (1) {
float best_score = -1e10;
int best_id = -1;
int best_idx = -1;
for (int i=0; i < (*n_tokens-1); i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
sprint(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
if (id != -1 && t->vocab_scores[id] > best_score) {
// this merge pair exists in vocab! record its score and position
best_score = t->vocab_scores[id];
best_id = id;
best_idx = i;
}
}
if (best_idx == -1) {
break; // we couldn't find any more pairs to merge, so we're done
}
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
tokens[best_idx] = best_id;
// delete token at position best_idx+1, shift the entire sequence back 1
for (int i = best_idx+1; i < (*n_tokens-1); i++) {
tokens[i] = tokens[i+1];
}
(*n_tokens)--; // token length decreased
}
// add optional EOS (=2) token, if desired
if (eos) tokens[(*n_tokens)++] = 2;
free(str_buffer);
}
// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
typedef struct {
float prob;
int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling
#define SIZEOFPROBINDEX sizeof(ProbIndex)
typedef struct {
int vocab_size;
ProbIndex* probindex; // buffer used in top-p sampling
float temperature;
float topp;
unsigned long long rng_state;
} Sampler;
int sample_argmax(float* probabilities, int n) {
// return the index that has the highest probability
int max_i = 0;
float max_p = probabilities[0];
for (int i = 1; i < n; i++) {
if (probabilities[i] > max_p) {
max_i = i;
max_p = probabilities[i];
}
}
return max_i;
}
int sample_mult(float* probabilities, int n, float coin) {
// sample index from probabilities (they must sum to 1!)
// coin is a random number in [0, 1), usually from random_f32()
float cdf = 0.0f;
for (int i = 0; i < n; i++) {
cdf += probabilities[i];
if (coin < cdf) {
return i;
}
}
return n - 1; // in case of rounding errors
}
int compare(const void* a, const void* b) {
ProbIndex* a_ = (ProbIndex*) a;
ProbIndex* b_ = (ProbIndex*) b;
if (a_->prob > b_->prob) return -1;
if (a_->prob < b_->prob) return 1;
return 0;
}
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability topp. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
// coin is a random number in [0, 1), usually from random_f32()
int n0 = 0;
// quicksort indices in descending order of probabilities
// values smaller than (1 - topp) / (n - 1) cannot be part of the result
// so for efficiency we crop these out as candidates before sorting
const float cutoff = (1.0f - topp) / (n - 1);
for (int i = 0; i < n; i++) {
if (probabilities[i] >= cutoff) {
probindex[n0].index = i;
probindex[n0].prob = probabilities[i];
n0++;
}
}
    qsort(probindex, n0, SIZEOFPROBINDEX, compare);
// truncate the list where cumulative probability exceeds topp
float cumulative_prob = 0.0f;
int last_idx = n0 - 1; // in case of rounding errors consider all elements
for (int i = 0; i < n0; i++) {
cumulative_prob += probindex[i].prob;
if (cumulative_prob > topp) {
last_idx = i;
break; // we've exceeded topp by including last_idx
}
}
// sample from the truncated list
float r = coin * cumulative_prob;
float cdf = 0.0f;
for (int i = 0; i <= last_idx; i++) {
cdf += probindex[i].prob;
if (r < cdf) {
return probindex[i].index;
}
}
return probindex[last_idx].index; // in case of rounding errors
}
void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
sampler->vocab_size = vocab_size;
sampler->temperature = temperature;
sampler->topp = topp;
sampler->rng_state = rng_seed;
// buffer only used with nucleus sampling; may not need but it's ~small
sampler->probindex = malloc(sampler->vocab_size * SIZEOFPROBINDEX);
}
void free_sampler(Sampler* sampler) {
free(sampler->probindex);
}
unsigned int random_u32(unsigned long long *state) {
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
*state ^= *state >> 12;
*state ^= *state << 25;
*state ^= *state >> 27;
return (*state * 0x2545F4914F6CDD1Dull) >> 32;
}
float random_f32(unsigned long long *state) { // random float32 in [0,1)
    return (random_u32(state) >> 8) / 16777216.0f; // keep the top 24 bits; 16777216 = 2^24
}
int sample(Sampler* sampler, float* logits) {
// sample the token given the logits and some hyperparameters
int next;
if (sampler->temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
next = sample_argmax(logits, sampler->vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(logits, sampler->vocab_size);
// flip a (float) coin (this is our source of entropy for sampling)
float coin = random_f32(&sampler->rng_state);
// we sample from this distribution to get the next token
if (sampler->topp <= 0 || sampler->topp >= 1) {
// simply sample from the predicted probability distribution
next = sample_mult(logits, sampler->vocab_size, coin);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
}
}
return next;
}
// ----------------------------------------------------------------------------
// utilities: time
/*long time_in_ms() {
// return time in milliseconds, for benchmarking the model speed
struct timespec time;
clock_gettime(CLOCK_REALTIME, &time);
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
} */
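// replacement: nsec() returns the system nanosecond counter, so
// nsec()/1000000 gives milliseconds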
#define time_in_ms() (nsec()/1000000)
// ----------------------------------------------------------------------------
// generation loop
void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
char *empty_prompt = "";
if (prompt == nil) { prompt = empty_prompt; }
// encode the (string) prompt into tokens sequence
int num_prompt_tokens = 0;
    int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
if (num_prompt_tokens < 1) {
sysfatal("something is wrong, expected at least 1 prompt token");
}
// start the main loop
long start = 0; // used to time our code, only initialized after first iteration
int next; // will store the next token in the sequence
int token = prompt_tokens[0]; // kick off with the first token in the prompt
int pos = 0; // position in the sequence
while (pos < steps) {
float *logits;
// forward the transformer to get logits for the next token
if (quantized8 == 0)
logits = forward(transformer, token, pos);
else
logits = forward_q8(transformer, token, pos);
// advance the state machine
if (pos < (num_prompt_tokens - 1)) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos + 1];
} else {
// otherwise sample the next token from the logits
next = sample(sampler, logits);
}
pos++;
// data-dependent terminating condition: the BOS (=1) token delimits sequences
if (next == 1) { break; }
// print the token as string, decode it with the Tokenizer object
char* piece = decode(tokenizer, token, next);
safe_print(piece); // same as printf("%s", piece), but skips "unsafe" bytes
token = next;
// init the timer here because the first iteration can be slower
if (start == 0) { start = time_in_ms(); }
}
print("\n");
// report achieved tok/s (pos-1 because the timer starts after first iteration)
// if (pos > 1) {
// long end = time_in_ms();
// fprint(2, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
// }
free(prompt_tokens);
}
void read_stdin(const char* guide, char* buffer, long bufsize) {
int r;
// read a line from stdin, up to but not including \n
print("%s", guide);
buffer[0] = '\0';
    if ((r = read(0, buffer, bufsize - 1)) > 0) { // -1 leaves room for the terminator
buffer[r] = '\0';
buffer[strcspn(buffer, "\r\n")] = '\0'; // strip newline
}
}
// ----------------------------------------------------------------------------
// chat loop
// I manually inspected the tokens for a few chat conversations compared to
// python reference and that seemed ok, but this was not thoroughly tested and
// is not safely implemented, it's more a proof of concept atm.
void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
char *cli_user_prompt, char *cli_system_prompt, int steps) {
// buffers for reading the system prompt and user prompt from stdin
    // you'll notice they are somewhat haphazardly and unsafely set atm
char system_prompt[512];
char user_prompt[512];
char rendered_prompt[1152];
int num_prompt_tokens = 0;
int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
int user_idx = 0;
// start the main loop
char user_turn = 1; // user starts
int next = user_turn; // will store the next token in the sequence
int token; // stores the current token to feed into the transformer
int pos = 0; // position in the sequence
while (pos < steps) {
// when it is the user's turn to contribute tokens to the dialog...
if (user_turn) {
// get the (optional) system prompt at position 0
if (pos == 0) {
// at position 0, the user can also contribute a system prompt
if (cli_system_prompt == nil) {
// system prompt was not passed in, attempt to get it from stdin
read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
} else {
// system prompt was passed in, use it
strcpy(system_prompt, cli_system_prompt);
}
}
// get the user prompt
if (pos == 0 && cli_user_prompt != nil) {
// user prompt for position 0 was passed in, use it
strcpy(user_prompt, cli_user_prompt);
} else {
// otherwise get user prompt from stdin
read_stdin("User: ", user_prompt, sizeof(user_prompt));
}
// render user/system prompts into the Llama 2 Chat schema
if (pos == 0 && system_prompt[0] != '\0') {
char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
sprint(rendered_prompt, system_template, system_prompt, user_prompt);
} else {
char user_template[] = "[INST] %s [/INST]";
sprint(rendered_prompt, user_template, user_prompt);
}
// encode the rendered prompt into tokens
encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
user_idx = 0; // reset the user index
user_turn = 0;
print("Assistant: ");
}
// determine the token to pass into the transformer next
if (user_idx < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
token = prompt_tokens[user_idx++];
} else {
// otherwise use the next token sampled from previous turn
token = next;
}
// EOS (=2) token ends the Assistant turn
if (token == 2) { user_turn = 1; }
        // forward the transformer to get logits for the next token
        float* logits;
        if (quantized8 == 0)
            logits = forward(transformer, token, pos);
        else
            logits = forward_q8(transformer, token, pos);
next = sample(sampler, logits);
pos++;
if (user_idx >= num_prompt_tokens && next != 2) {
// the Assistant is responding, so print its output
char* piece = decode(tokenizer, token, next);
safe_print(piece); // same as printf("%s", piece), but skips "unsafe" bytes
}
if (next == 2) { print("\n"); }
}
print("\n");
free(prompt_tokens);
}
// ----------------------------------------------------------------------------
// CLI, include only if not testing
void error_usage(void) {
fprint(2, "Usage: run <checkpoint> [options]\n");
fprint(2, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
fprint(2, "Options:\n");
fprint(2, " -t <float> temperature in [0,inf], default 1.0\n");
fprint(2, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
fprint(2, " -s <int> random seed, default time(nil)\n");
fprint(2, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
fprint(2, " -i <string> input prompt\n");
fprint(2, " -z <string> optional path to custom tokenizer\n");
fprint(2, " -m <string> mode: generate|chat, default: generate\n");
fprint(2, " -y <string> (optional) system prompt in chat mode\n");
exits("usage");
}
void main(int argc, char *argv[]) {
// default parameters
char *checkpoint_path = nil; // e.g. out/model.bin
char *tokenizer_path = "tokenizer.bin";
float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
int steps = 256; // number of steps to run for
char *prompt = nil; // prompt string
unsigned long long rng_seed = 0; // seed rng with time by default
char *mode = "generate"; // generate|chat
char *system_prompt = nil; // the (optional) system prompt to use in chat mode
// poor man's C argparse so we can override the defaults above from the command line
if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
for (int i = 2; i < argc; i+=2) {
// do some basic validation
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
if (argv[i][0] != '-') { error_usage(); } // must start with dash
if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
// read in the args
if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
else { error_usage(); }
}
// parameter validation/overrides
    if (rng_seed == 0) rng_seed = (unsigned int)time(nil);
if (temperature < 0.0) temperature = 0.0;
if (topp < 0.0 || 1.0 < topp) topp = 0.9;
if (steps < 0) steps = 0;
// build the Transformer via the model .bin file
Transformer transformer;
build_transformer(&transformer, checkpoint_path);
if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length
// build the Tokenizer via the tokenizer .bin file
Tokenizer tokenizer;
build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
// build the Sampler
Sampler sampler;
build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
// run!
if (strcmp(mode, "generate") == 0) {
generate(&transformer, &tokenizer, &sampler, prompt, steps);
} else if (strcmp(mode, "chat") == 0) {
chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
} else {
fprint(2, "unknown mode: %s\n", mode);
error_usage();
}
// memory and file handles cleanup
free_sampler(&sampler);
free_tokenizer(&tokenizer);
free_transformer(&transformer);
exits(nil);
}