ggerganov commited on
Commit
7e268a7
·
1 Parent(s): ddc04a3

talk-llama : sync llama.cpp

Browse files
examples/talk-llama/llama.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama.h CHANGED
@@ -81,9 +81,12 @@ extern "C" {
81
  LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
82
  LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
83
  LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
84
- LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10,
85
- LLAMA_VOCAB_PRE_TYPE_OLMO = 11,
86
- LLAMA_VOCAB_PRE_TYPE_DBRX = 12,
 
 
 
87
  };
88
 
89
  // note: these values should be synchronized with ggml_rope
@@ -95,7 +98,7 @@ extern "C" {
95
  LLAMA_ROPE_TYPE_GLM = 4,
96
  };
97
 
98
- enum llama_token_type {
99
  LLAMA_TOKEN_TYPE_UNDEFINED = 0,
100
  LLAMA_TOKEN_TYPE_NORMAL = 1,
101
  LLAMA_TOKEN_TYPE_UNKNOWN = 2,
@@ -105,6 +108,20 @@ extern "C" {
105
  LLAMA_TOKEN_TYPE_BYTE = 6,
106
  };
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  // model file types
109
  enum llama_ftype {
110
  LLAMA_FTYPE_ALL_F32 = 0,
@@ -242,6 +259,9 @@ extern "C" {
242
  // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
243
  const float * tensor_split;
244
 
 
 
 
245
  // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
246
  // If the provided progress_callback returns true, model loading continues.
247
  // If it returns false, model loading is immediately aborted.
@@ -260,6 +280,8 @@ extern "C" {
260
  bool check_tensors; // validate model tensor data
261
  };
262
 
 
 
263
  struct llama_context_params {
264
  uint32_t seed; // RNG seed, -1 for random
265
  uint32_t n_ctx; // text context, 0 = from model
@@ -286,14 +308,14 @@ extern "C" {
286
  ggml_backend_sched_eval_callback cb_eval;
287
  void * cb_eval_user_data;
288
 
289
- enum ggml_type type_k; // data type for K cache
290
- enum ggml_type type_v; // data type for V cache
291
 
292
  // Keep the booleans together to avoid misalignment during copy-by-value.
293
  bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
294
  bool embeddings; // if true, extract embeddings (together with logits)
295
  bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
296
- bool flash_attn; // whether to use flash attention
297
 
298
  // Abort callback
299
  // if it returns true, execution of llama_decode() will be aborted
@@ -344,6 +366,9 @@ extern "C" {
344
  // modifies a preceding LLAMA_GRETYPE_CHAR or
345
  // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
346
  LLAMA_GRETYPE_CHAR_ALT = 6,
 
 
 
347
  };
348
 
349
  typedef struct llama_grammar_element {
@@ -417,8 +442,8 @@ extern "C" {
417
 
418
  LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
419
 
420
- LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
421
- LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
422
 
423
  LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
424
  LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@@ -755,6 +780,12 @@ extern "C" {
755
  // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
756
  LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
757
 
 
 
 
 
 
 
758
  // Set whether to use causal attention or not
759
  // If set to true, the model will only attend to the past tokens
760
  LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
@@ -808,11 +839,14 @@ extern "C" {
808
 
809
  LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
810
 
811
- LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
812
 
813
  // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
814
  LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
815
 
 
 
 
816
  // Special tokens
817
  LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
818
  LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
@@ -1026,49 +1060,9 @@ extern "C" {
1026
  llama_token token);
1027
 
1028
  //
1029
- // Beam search
1030
  //
1031
 
1032
- struct llama_beam_view {
1033
- const llama_token * tokens;
1034
-
1035
- size_t n_tokens;
1036
- float p; // Cumulative beam probability (renormalized relative to all beams)
1037
- bool eob; // Callback should set this to true when a beam is at end-of-beam.
1038
- };
1039
-
1040
- // Passed to beam_search_callback function.
1041
- // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
1042
- // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
1043
- // These pointers are valid only during the synchronous callback, so should not be saved.
1044
- struct llama_beams_state {
1045
- struct llama_beam_view * beam_views;
1046
-
1047
- size_t n_beams; // Number of elements in beam_views[].
1048
- size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
1049
- bool last_call; // True iff this is the last callback invocation.
1050
- };
1051
-
1052
- // Type of pointer to the beam_search_callback function.
1053
- // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
1054
- // passed back to beam_search_callback. This avoids having to use global variables in the callback.
1055
- typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
1056
-
1057
- /// @details Deterministically returns entire sentence constructed by a beam search.
1058
- /// @param ctx Pointer to the llama_context.
1059
- /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
1060
- /// @param callback_data A pointer that is simply passed back to callback.
1061
- /// @param n_beams Number of beams to use.
1062
- /// @param n_past Number of tokens already evaluated.
1063
- /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
1064
- LLAMA_API void llama_beam_search(
1065
- struct llama_context * ctx,
1066
- llama_beam_search_callback_fn_t callback,
1067
- void * callback_data,
1068
- size_t n_beams,
1069
- int32_t n_past,
1070
- int32_t n_predict);
1071
-
1072
  /// @details Build a split GGUF final path for this chunk.
1073
  /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
1074
  // Returns the split_path length.
 
81
  LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
82
  LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
83
  LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
84
+ LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
85
+ LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
86
+ LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
87
+ LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
88
+ LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
89
+ LLAMA_VOCAB_PRE_TYPE_PORO = 15,
90
  };
91
 
92
  // note: these values should be synchronized with ggml_rope
 
98
  LLAMA_ROPE_TYPE_GLM = 4,
99
  };
100
 
101
+ enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
102
  LLAMA_TOKEN_TYPE_UNDEFINED = 0,
103
  LLAMA_TOKEN_TYPE_NORMAL = 1,
104
  LLAMA_TOKEN_TYPE_UNKNOWN = 2,
 
108
  LLAMA_TOKEN_TYPE_BYTE = 6,
109
  };
110
 
111
+ enum llama_token_attr {
112
+ LLAMA_TOKEN_ATTR_UNDEFINED = 0,
113
+ LLAMA_TOKEN_ATTR_UNKNOWN = 1 << 0,
114
+ LLAMA_TOKEN_ATTR_UNUSED = 1 << 1,
115
+ LLAMA_TOKEN_ATTR_NORMAL = 1 << 2,
116
+ LLAMA_TOKEN_ATTR_CONTROL = 1 << 3, // SPECIAL?
117
+ LLAMA_TOKEN_ATTR_USER_DEFINED = 1 << 4,
118
+ LLAMA_TOKEN_ATTR_BYTE = 1 << 5,
119
+ LLAMA_TOKEN_ATTR_NORMALIZED = 1 << 6,
120
+ LLAMA_TOKEN_ATTR_LSTRIP = 1 << 7,
121
+ LLAMA_TOKEN_ATTR_RSTRIP = 1 << 8,
122
+ LLAMA_TOKEN_ATTR_SINGLE_WORD = 1 << 9,
123
+ };
124
+
125
  // model file types
126
  enum llama_ftype {
127
  LLAMA_FTYPE_ALL_F32 = 0,
 
259
  // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
260
  const float * tensor_split;
261
 
262
+ // comma separated list of RPC servers to use for offloading
263
+ const char * rpc_servers;
264
+
265
  // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
266
  // If the provided progress_callback returns true, model loading continues.
267
  // If it returns false, model loading is immediately aborted.
 
280
  bool check_tensors; // validate model tensor data
281
  };
282
 
283
+ // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
284
+ // https://github.com/ggerganov/llama.cpp/pull/7544
285
  struct llama_context_params {
286
  uint32_t seed; // RNG seed, -1 for random
287
  uint32_t n_ctx; // text context, 0 = from model
 
308
  ggml_backend_sched_eval_callback cb_eval;
309
  void * cb_eval_user_data;
310
 
311
+ enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
312
+ enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
313
 
314
  // Keep the booleans together to avoid misalignment during copy-by-value.
315
  bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
316
  bool embeddings; // if true, extract embeddings (together with logits)
317
  bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
318
+ bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
319
 
320
  // Abort callback
321
  // if it returns true, execution of llama_decode() will be aborted
 
366
  // modifies a preceding LLAMA_GRETYPE_CHAR or
367
  // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
368
  LLAMA_GRETYPE_CHAR_ALT = 6,
369
+
370
+ // any character (.)
371
+ LLAMA_GRETYPE_CHAR_ANY = 7,
372
  };
373
 
374
  typedef struct llama_grammar_element {
 
442
 
443
  LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
444
 
445
+ LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
446
+ LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
447
 
448
  LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
449
  LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
 
780
  // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
781
  LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
782
 
783
+ // Get the number of threads used for generation of a single token.
784
+ LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx);
785
+
786
+ // Get the number of threads used for prompt and batch processing (multiple token).
787
+ LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
788
+
789
  // Set whether to use causal attention or not
790
  // If set to true, the model will only attend to the past tokens
791
  LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
 
839
 
840
  LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
841
 
842
+ LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
843
 
844
  // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
845
  LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
846
 
847
+ // Identify if Token Id is a control token or a render-able token
848
+ LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token);
849
+
850
  // Special tokens
851
  LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
852
  LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
 
1060
  llama_token token);
1061
 
1062
  //
1063
+ // Model split
1064
  //
1065
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1066
  /// @details Build a split GGUF final path for this chunk.
1067
  /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf"
1068
  // Returns the split_path length.
examples/talk-llama/unicode-data.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/unicode-data.h CHANGED
@@ -1,17 +1,20 @@
1
  #pragma once
2
 
3
  #include <cstdint>
4
- #include <map>
5
- #include <utility>
6
  #include <vector>
 
 
7
 
8
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number;
9
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
10
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator;
11
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
12
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
13
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
14
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
15
- extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
16
- extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
17
- extern const std::map<char32_t, char32_t> unicode_map_lowercase;
 
 
 
 
1
  #pragma once
2
 
3
  #include <cstdint>
 
 
4
  #include <vector>
5
+ #include <unordered_map>
6
+ #include <unordered_set>
7
 
8
+ struct range_nfd {
9
+ uint32_t first;
10
+ uint32_t last;
11
+ uint32_t nfd;
12
+ };
13
+
14
+ static const uint32_t MAX_CODEPOINTS = 0x110000;
15
+
16
+ extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
17
+ extern const std::unordered_set<uint32_t> unicode_set_whitespace;
18
+ extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
19
+ extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
20
+ extern const std::vector<range_nfd> unicode_ranges_nfd;
examples/talk-llama/unicode.cpp CHANGED
@@ -1,4 +1,4 @@
1
- #include "unicode.h"
2
  #include "unicode-data.h"
3
 
4
  #include <cassert>
@@ -109,57 +109,49 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
109
  // return result;
110
  //}
111
 
112
- static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
113
- std::unordered_map<uint32_t, int> cpt_types;
114
- for (auto p : unicode_ranges_number) {
115
- for (auto i = p.first; i <= p.second; ++i) {
116
- cpt_types[i] = CODEPOINT_TYPE_NUMBER;
117
- }
118
- }
119
- for (auto p : unicode_ranges_letter) {
120
- for (auto i = p.first; i <= p.second; ++i) {
121
- cpt_types[i] = CODEPOINT_TYPE_LETTER;
122
- }
123
- }
124
- for (auto p : unicode_ranges_separator) {
125
- for (auto i = p.first; i <= p.second; ++i) {
126
- cpt_types[i] = CODEPOINT_TYPE_SEPARATOR;
127
  }
128
  }
129
- for (auto p : unicode_ranges_accent_mark) {
130
- for (auto i = p.first; i <= p.second; ++i) {
131
- cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
132
- }
133
  }
134
- for (auto p : unicode_ranges_punctuation) {
135
- for (auto i = p.first; i <= p.second; ++i) {
136
- cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
137
- }
138
  }
139
- for (auto p : unicode_ranges_symbol) {
140
- for (auto i = p.first; i <= p.second; ++i) {
141
- cpt_types[i] = CODEPOINT_TYPE_SYMBOL;
142
- }
143
  }
144
- for (auto p : unicode_ranges_control) {
145
- for (auto i = p.first; i <= p.second; ++i) {
146
- cpt_types[i] = CODEPOINT_TYPE_CONTROL;
147
- }
148
  }
149
- return cpt_types;
 
150
  }
151
 
152
  static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
153
  std::unordered_map<uint8_t, std::string> map;
154
- for (int ch = u'!'; ch <= u'~'; ++ch) {
155
  assert(0 <= ch && ch < 256);
156
  map[ch] = unicode_cpt_to_utf8(ch);
157
  }
158
- for (int ch = u'¡'; ch <= u'¬'; ++ch) {
159
  assert(0 <= ch && ch < 256);
160
  map[ch] = unicode_cpt_to_utf8(ch);
161
  }
162
- for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
163
  assert(0 <= ch && ch < 256);
164
  map[ch] = unicode_cpt_to_utf8(ch);
165
  }
@@ -175,15 +167,15 @@ static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
175
 
176
  static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
177
  std::unordered_map<std::string, uint8_t> map;
178
- for (int ch = u'!'; ch <= u'~'; ++ch) {
179
  assert(0 <= ch && ch < 256);
180
  map[unicode_cpt_to_utf8(ch)] = ch;
181
  }
182
- for (int ch = u'¡'; ch <= u'¬'; ++ch) {
183
  assert(0 <= ch && ch < 256);
184
  map[unicode_cpt_to_utf8(ch)] = ch;
185
  }
186
- for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
187
  assert(0 <= ch && ch < 256);
188
  map[unicode_cpt_to_utf8(ch)] = ch;
189
  }
@@ -238,8 +230,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
238
  return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
239
  };
240
 
241
- auto _get_cpt_type = [&] (const size_t pos) -> int {
242
- return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
 
243
  };
244
 
245
  size_t _prev_end = offset_ini;
@@ -261,7 +254,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
261
 
262
  for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
263
  const char32_t cpt = _get_cpt(pos);
264
- const int cpt_type = _get_cpt_type(pos);
265
 
266
  // regex: 's|'t|'re|'ve|'m|'ll|'d
267
  if (cpt == '\'' && pos+1 < offset_end) {
@@ -281,39 +274,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
281
  }
282
  }
283
 
284
- char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
285
- int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
286
  // regex: <space>?\p{L}+
287
- if (cpt2_type == CODEPOINT_TYPE_LETTER) {
288
  pos += (cpt == ' ');
289
- while (cpt2_type == CODEPOINT_TYPE_LETTER) {
290
- cpt2_type = _get_cpt_type(++pos);
291
  }
292
  _add_token(pos);
293
  continue;
294
  }
295
  // regex: <space>?\p{N}+
296
- if (cpt2_type == CODEPOINT_TYPE_NUMBER) {
297
  pos += (cpt == ' ');
298
- while (cpt2_type == CODEPOINT_TYPE_NUMBER) {
299
- cpt2_type = _get_cpt_type(++pos);
300
  }
301
  _add_token(pos);
302
  continue;
303
  }
304
  // regex: <space>?[^\s\p{L}\p{N}]+
305
- if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
306
  pos += (cpt == ' ');
307
- while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
308
- cpt2_type = _get_cpt_type(++pos);
309
- cpt2 = _get_cpt(pos);
310
  }
311
  _add_token(pos);
312
  continue;
313
  }
314
 
315
  size_t num_whitespaces = 0;
316
- while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
317
  num_whitespaces++;
318
  }
319
 
@@ -357,8 +348,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
357
  return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
358
  };
359
 
360
- auto _get_cpt_type = [&] (const size_t pos) -> int {
361
- return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
 
362
  };
363
 
364
  size_t _prev_end = offset_ini;
@@ -380,7 +372,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
380
 
381
  for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
382
  const char32_t cpt = _get_cpt(pos);
383
- const int cpt_type = _get_cpt_type(pos);
384
 
385
  // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
386
  if (cpt == '\'' && pos+1 < offset_end) {
@@ -401,10 +393,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
401
  }
402
 
403
  // regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
404
- if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) {
405
- if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) { // one or more letters
406
  pos++;
407
- while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) {
408
  pos++;
409
  }
410
  _add_token(pos);
@@ -413,9 +405,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
413
  }
414
 
415
  // regex: \p{N}{1,3}
416
- if (cpt_type == CODEPOINT_TYPE_NUMBER) {
417
  size_t ini = pos;
418
- while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) {
419
  if (++pos - ini >= 3 ) {
420
  _add_token(pos);
421
  ini = pos;
@@ -426,14 +418,13 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
426
  }
427
 
428
  // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
429
- char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
430
- int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
431
- if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
432
  pos += (cpt == ' ');
433
- while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
434
- cpt2_type = _get_cpt_type(++pos);
435
- cpt2 = _get_cpt(pos);
436
  }
 
437
  while (cpt2 == '\r' || cpt2 == '\n') {
438
  cpt2 = _get_cpt(++pos);
439
  }
@@ -443,7 +434,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
443
 
444
  size_t num_whitespaces = 0;
445
  size_t last_end_r_or_n = 0;
446
- while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
447
  char32_t cpt2 = _get_cpt(pos+num_whitespaces);
448
  if (cpt2 == '\r' || cpt2 == '\n') {
449
  last_end_r_or_n = pos + num_whitespaces + 1;
@@ -589,15 +580,14 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
589
  }
590
 
591
  std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
592
- std::vector<uint32_t> result;
593
- result.reserve(cpts.size());
 
 
594
  for (size_t i = 0; i < cpts.size(); ++i) {
595
- auto it = unicode_map_nfd.find(cpts[i]);
596
- if (it == unicode_map_nfd.end()) {
597
- result.push_back(cpts[i]);
598
- } else {
599
- result.push_back(it->second);
600
- }
601
  }
602
  return result;
603
  }
@@ -611,31 +601,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
611
  return result;
612
  }
613
 
614
- int unicode_cpt_type(uint32_t cp) {
615
- static std::unordered_map<uint32_t, int> cpt_types = unicode_cpt_type_map();
616
- const auto it = cpt_types.find(cp);
617
- return it == cpt_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : it->second;
618
  }
619
 
620
- int unicode_cpt_type(const std::string & utf8) {
621
- if (utf8.length() == 0) {
622
- return CODEPOINT_TYPE_UNIDENTIFIED;
 
623
  }
624
  size_t offset = 0;
625
- return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
626
- }
627
-
628
- bool unicode_cpt_is_whitespace(uint32_t cp) {
629
- static const std::unordered_set<uint32_t> is_whitespace = [] {
630
- std::unordered_set<uint32_t> is_whitespace;
631
- for (auto p : unicode_ranges_whitespace) {
632
- for (auto i = p.first; i <= p.second; ++i) {
633
- is_whitespace.insert(i);
634
- }
635
- }
636
- return is_whitespace;
637
- }();
638
- return (bool)is_whitespace.count(cp);
639
  }
640
 
641
  std::string unicode_byte_to_utf8(uint8_t byte) {
@@ -656,21 +634,21 @@ char32_t unicode_tolower(char32_t cp) {
656
  std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
657
  // unicode categories
658
  static const std::map<std::string, int> k_ucat_enum = {
659
- { "\\p{N}", CODEPOINT_TYPE_NUMBER },
660
- { "\\p{L}", CODEPOINT_TYPE_LETTER },
661
- { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
662
  };
663
 
664
  static const std::map<int, int> k_ucat_cpt = {
665
- { CODEPOINT_TYPE_NUMBER, 0xD1 },
666
- { CODEPOINT_TYPE_LETTER, 0xD2 },
667
- { CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
668
  };
669
 
670
  static const std::map<int, std::string> k_ucat_map = {
671
- { CODEPOINT_TYPE_NUMBER, "\x30-\x39" }, // 0-9
672
- { CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
673
- { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
674
  };
675
 
676
  // compute collapsed codepoints only if needed by at least one regex
@@ -701,10 +679,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
701
  continue;
702
  }
703
 
704
- const int cpt_type = unicode_cpt_type(cpts[i]);
705
 
706
- if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
707
- text_collapsed[i] = k_ucat_cpt.at(cpt_type);
708
  } else {
709
  text_collapsed[i] = (char) 0xD0; // fallback
710
  }
 
1
+ #include "unicode.h"
2
  #include "unicode-data.h"
3
 
4
  #include <cassert>
 
109
  // return result;
110
  //}
111
 
112
+ static std::vector<codepoint_flags> unicode_cpt_flags_array() {
113
+ std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
114
+
115
+ assert (unicode_ranges_flags.front().first == 0);
116
+ assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
117
+ for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
118
+ const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
119
+ const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
120
+ for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
121
+ cpt_flags[cpt] = range_ini.second;
 
 
 
 
 
122
  }
123
  }
124
+
125
+ for (auto cpt : unicode_set_whitespace) {
126
+ cpt_flags[cpt].is_whitespace = true;
 
127
  }
128
+
129
+ for (auto p : unicode_map_lowercase) {
130
+ cpt_flags[p.second].is_lowercase = true;
 
131
  }
132
+
133
+ for (auto p : unicode_map_uppercase) {
134
+ cpt_flags[p.second].is_uppercase = true;
 
135
  }
136
+
137
+ for (auto &range : unicode_ranges_nfd) { // start, last, nfd
138
+ cpt_flags[range.nfd].is_nfd = true;
 
139
  }
140
+
141
+ return cpt_flags;
142
  }
143
 
144
  static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
145
  std::unordered_map<uint8_t, std::string> map;
146
+ for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
147
  assert(0 <= ch && ch < 256);
148
  map[ch] = unicode_cpt_to_utf8(ch);
149
  }
150
+ for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
151
  assert(0 <= ch && ch < 256);
152
  map[ch] = unicode_cpt_to_utf8(ch);
153
  }
154
+ for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
155
  assert(0 <= ch && ch < 256);
156
  map[ch] = unicode_cpt_to_utf8(ch);
157
  }
 
167
 
168
  static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
169
  std::unordered_map<std::string, uint8_t> map;
170
+ for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
171
  assert(0 <= ch && ch < 256);
172
  map[unicode_cpt_to_utf8(ch)] = ch;
173
  }
174
+ for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
175
  assert(0 <= ch && ch < 256);
176
  map[unicode_cpt_to_utf8(ch)] = ch;
177
  }
178
+ for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
179
  assert(0 <= ch && ch < 256);
180
  map[unicode_cpt_to_utf8(ch)] = ch;
181
  }
 
230
  return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
231
  };
232
 
233
+ auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
234
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
235
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
236
  };
237
 
238
  size_t _prev_end = offset_ini;
 
254
 
255
  for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
256
  const char32_t cpt = _get_cpt(pos);
257
+ const auto flags = _get_flags(pos);
258
 
259
  // regex: 's|'t|'re|'ve|'m|'ll|'d
260
  if (cpt == '\'' && pos+1 < offset_end) {
 
274
  }
275
  }
276
 
277
+ auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
 
278
  // regex: <space>?\p{L}+
279
+ if (flags2.is_letter) {
280
  pos += (cpt == ' ');
281
+ while (flags2.is_letter) {
282
+ flags2 = _get_flags(++pos);
283
  }
284
  _add_token(pos);
285
  continue;
286
  }
287
  // regex: <space>?\p{N}+
288
+ if (flags2.is_number) {
289
  pos += (cpt == ' ');
290
+ while (flags2.is_number) {
291
+ flags2 = _get_flags(++pos);
292
  }
293
  _add_token(pos);
294
  continue;
295
  }
296
  // regex: <space>?[^\s\p{L}\p{N}]+
297
+ if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
298
  pos += (cpt == ' ');
299
+ while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
300
+ flags2 = _get_flags(++pos);
 
301
  }
302
  _add_token(pos);
303
  continue;
304
  }
305
 
306
  size_t num_whitespaces = 0;
307
+ while (_get_flags(pos+num_whitespaces).is_whitespace) {
308
  num_whitespaces++;
309
  }
310
 
 
348
  return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
349
  };
350
 
351
+ auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
352
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
353
+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
354
  };
355
 
356
  size_t _prev_end = offset_ini;
 
372
 
373
  for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
374
  const char32_t cpt = _get_cpt(pos);
375
+ const auto flags = _get_flags(pos);
376
 
377
  // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
378
  if (cpt == '\'' && pos+1 < offset_end) {
 
393
  }
394
 
395
  // regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
396
+ if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
397
+ if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
398
  pos++;
399
+ while (_get_flags(pos).is_letter) {
400
  pos++;
401
  }
402
  _add_token(pos);
 
405
  }
406
 
407
  // regex: \p{N}{1,3}
408
+ if (flags.is_number) {
409
  size_t ini = pos;
410
+ while (_get_flags(pos).is_number) {
411
  if (++pos - ini >= 3 ) {
412
  _add_token(pos);
413
  ini = pos;
 
418
  }
419
 
420
  // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
421
+ auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
422
+ if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
 
423
  pos += (cpt == ' ');
424
+ while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
425
+ flags2 = _get_flags(++pos);
 
426
  }
427
+ char32_t cpt2 = _get_cpt(pos);
428
  while (cpt2 == '\r' || cpt2 == '\n') {
429
  cpt2 = _get_cpt(++pos);
430
  }
 
434
 
435
  size_t num_whitespaces = 0;
436
  size_t last_end_r_or_n = 0;
437
+ while (_get_flags(pos+num_whitespaces).is_whitespace) {
438
  char32_t cpt2 = _get_cpt(pos+num_whitespaces);
439
  if (cpt2 == '\r' || cpt2 == '\n') {
440
  last_end_r_or_n = pos + num_whitespaces + 1;
 
580
  }
581
 
582
  std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
583
+ auto comp = [] (const uint32_t cpt, const range_nfd & range) {
584
+ return cpt < range.first;
585
+ };
586
+ std::vector<uint32_t> result(cpts.size());
587
  for (size_t i = 0; i < cpts.size(); ++i) {
588
+ const uint32_t cpt = cpts[i];
589
+ auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
590
+ result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
 
 
 
591
  }
592
  return result;
593
  }
 
601
  return result;
602
  }
603
 
604
+ codepoint_flags unicode_cpt_flags(const uint32_t cp) {
605
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
606
+ static const auto cpt_flags = unicode_cpt_flags_array();
607
+ return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
608
  }
609
 
610
+ codepoint_flags unicode_cpt_flags(const std::string & utf8) {
611
+ static const codepoint_flags undef(codepoint_flags::UNDEFINED);
612
+ if (utf8.empty()) {
613
+ return undef; // undefined
614
  }
615
  size_t offset = 0;
616
+ return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
 
 
 
 
 
 
 
 
 
 
 
 
 
617
  }
618
 
619
  std::string unicode_byte_to_utf8(uint8_t byte) {
 
634
  std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
635
  // unicode categories
636
  static const std::map<std::string, int> k_ucat_enum = {
637
+ { "\\p{N}", codepoint_flags::NUMBER },
638
+ { "\\p{L}", codepoint_flags::LETTER },
639
+ { "\\p{P}", codepoint_flags::PUNCTUATION },
640
  };
641
 
642
  static const std::map<int, int> k_ucat_cpt = {
643
+ { codepoint_flags::NUMBER, 0xD1 },
644
+ { codepoint_flags::LETTER, 0xD2 },
645
+ { codepoint_flags::PUNCTUATION, 0xD3 },
646
  };
647
 
648
  static const std::map<int, std::string> k_ucat_map = {
649
+ { codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
650
+ { codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
651
+ { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
652
  };
653
 
654
  // compute collapsed codepoints only if needed by at least one regex
 
679
  continue;
680
  }
681
 
682
+ const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
683
 
684
+ if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
685
+ text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
686
  } else {
687
  text_collapsed[i] = (char) 0xD0; // fallback
688
  }
examples/talk-llama/unicode.h CHANGED
@@ -4,24 +4,56 @@
4
  #include <string>
5
  #include <vector>
6
 
7
- #define CODEPOINT_TYPE_UNIDENTIFIED 0
8
- #define CODEPOINT_TYPE_NUMBER 1
9
- #define CODEPOINT_TYPE_LETTER 2
10
- #define CODEPOINT_TYPE_SEPARATOR 3
11
- #define CODEPOINT_TYPE_ACCENT_MARK 4
12
- #define CODEPOINT_TYPE_PUNCTUATION 5
13
- #define CODEPOINT_TYPE_SYMBOL 6
14
- #define CODEPOINT_TYPE_CONTROL 7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  std::string unicode_cpt_to_utf8(uint32_t cp);
17
  std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
18
 
19
  std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
20
 
21
- int unicode_cpt_type(uint32_t cp);
22
- int unicode_cpt_type(const std::string & utf8);
23
-
24
- bool unicode_cpt_is_whitespace(uint32_t cp);
25
 
26
  std::string unicode_byte_to_utf8(uint8_t byte);
27
  uint8_t unicode_utf8_to_byte(const std::string & utf8);
 
4
  #include <string>
5
  #include <vector>
6
 
7
+ struct codepoint_flags {
8
+ enum {
9
+ UNDEFINED = 0x0001,
10
+ NUMBER = 0x0002, // regex: \p{N}
11
+ LETTER = 0x0004, // regex: \p{L}
12
+ SEPARATOR = 0x0008, // regex: \p{Z}
13
+ ACCENT_MARK = 0x0010, // regex: \p{M}
14
+ PUNCTUATION = 0x0020, // regex: \p{P}
15
+ SYMBOL = 0x0040, // regex: \p{S}
16
+ CONTROL = 0x0080, // regex: \p{C}
17
+ MASK_CATEGORIES = 0x00FF,
18
+ };
19
+
20
+ // codepoint type
21
+ uint16_t is_undefined : 1;
22
+ uint16_t is_number : 1; // regex: \p{N}
23
+ uint16_t is_letter : 1; // regex: \p{L}
24
+ uint16_t is_separator : 1; // regex: \p{Z}
25
+ uint16_t is_accent_mark : 1; // regex: \p{M}
26
+ uint16_t is_punctuation : 1; // regex: \p{P}
27
+ uint16_t is_symbol : 1; // regex: \p{S}
28
+ uint16_t is_control : 1; // regex: \p{C}
29
+ // helper flags
30
+ uint16_t is_whitespace : 1; // regex: \s
31
+ uint16_t is_lowercase : 1;
32
+ uint16_t is_uppercase : 1;
33
+ uint16_t is_nfd : 1;
34
+
35
+ // decode from uint16
36
+ inline codepoint_flags(const uint16_t flags=0) {
37
+ *reinterpret_cast<uint16_t*>(this) = flags;
38
+ }
39
+
40
+ inline uint16_t as_uint() const {
41
+ return *reinterpret_cast<const uint16_t*>(this);
42
+ }
43
+
44
+ inline uint16_t category_flag() const {
45
+ return this->as_uint() & MASK_CATEGORIES;
46
+ }
47
+ };
48
+
49
 
50
  std::string unicode_cpt_to_utf8(uint32_t cp);
51
  std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
52
 
53
  std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
54
 
55
+ codepoint_flags unicode_cpt_flags(const uint32_t cp);
56
+ codepoint_flags unicode_cpt_flags(const std::string & utf8);
 
 
57
 
58
  std::string unicode_byte_to_utf8(uint8_t byte);
59
  uint8_t unicode_utf8_to_byte(const std::string & utf8);