import asyncio
import os
import shutil
import time
from argparse import ArgumentParser

import coremltools as ct
import numpy as np
from transformers import AutoTokenizer
|
|
def copy_compiled_model(mlmodel: ct.models.MLModel, dest: str):
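    """Copy the compiled .mlmodelc bundle produced by Core ML to `dest`."""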
    compiled_model_path = mlmodel.get_compiled_model_path()
    shutil.copytree(compiled_model_path, dest, dirs_exist_ok=True)
|
|
def load_mlmodel(path, function_name, copy_compiled):
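    """Load a Core ML model targeting CPU + Neural Engine.

    A .mlmodelc path is loaded as a CompiledMLModel; an .mlpackage is loaded as an
    MLModel and, if `copy_compiled` is set, its compiled artifact is copied next to
    the package so later loads can use the .mlmodelc directly.
    """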
    extension = os.path.splitext(path)[1]
    if extension == ".mlmodelc":
        return ct.models.CompiledMLModel(
            path,
            function_name=function_name,
            compute_units=ct.ComputeUnit.CPU_AND_NE,
        )
    else:
        mlmodel = ct.models.MLModel(
            path,
            function_name=function_name,
            compute_units=ct.ComputeUnit.CPU_AND_NE,
        )
        if copy_compiled:
            copy_compiled_model(mlmodel, path.replace(".mlpackage", ".mlmodelc"))
        return mlmodel
|
|
def load_embeddings(path):
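    """Load the token-embedding table from a .npy file."""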
    return np.load(path)
|
|
async def generate_single_step(
    input_id,
    embed_fn,
    model,
    state,
    position,
    attention_mask_ref,
    lm_head,
):
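    """Run one decode step.

    Embeds the current token, runs the stateful transformer at `position` with the
    cached keys/values, and, when `lm_head` is provided, returns the sampled next
    token id; otherwise `input_id` is returned unchanged.
    """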
    embd = embed_fn(input_id).transpose(0, 3, 1, 2)
    hidden_states = model.predict(
        {
            "hidden_states": embd,
            "kv_write_idx": np.array([position], dtype=np.int32),
            "positions": np.array([[position]], dtype=np.int32),
            "attention_mask": attention_mask_ref[:, :, [position]],
        },
        state,
    )["output_hidden_states"]
    if lm_head is not None:
        input_id = lm_head(hidden_states)
    return input_id
|
|
class ModelContainer:
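    """Holds the embedding table, Core ML transformer, and LM head, and drives
    prompt processing and token-by-token generation against a shared KV-cache state."""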
    def __init__(
        self,
        embeddings_path,
        mlmodel_path,
        lm_head_path,
        cache_length,
        hf_model,
        temp=0.7,
        min_p=0.1,
    ):
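        """Load the embeddings, the 1-token generation and 64-token prompt functions
        of the model, the LM head, and the tokenizer, then build the causal mask."""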
        self.mlmodel_path = mlmodel_path
        self.embeddings_path = embeddings_path
        self.lm_head_path = lm_head_path
        self.cache_length = cache_length
        self.temp = temp
        self.min_p = min_p
        print("Loading embeddings...")
        self.embeddings = load_embeddings(embeddings_path)
        print("Loading generation model...")
        self.generation_model = load_mlmodel(
            mlmodel_path, f"model_input_1_cache_{cache_length}", copy_compiled=True
        )

        print("Loading prompt model...")
        self.prompt_model = load_mlmodel(
            mlmodel_path.replace(".mlpackage", ".mlmodelc"),
            f"model_input_64_cache_{cache_length}",
            copy_compiled=False,
        )
        print("Loading lm head model...")
        self.lm_head_model = load_mlmodel(
            lm_head_path,
            "min_p_length_1" if temp > 0 else "lm_head_length_1",
            copy_compiled=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(hf_model)
        self.end_of_response_token_id = self.tokenizer("<|im_end|>").input_ids[0]
        self.end_of_text_token_id = self.tokenizer("<|end_of_text|>").input_ids[0]
        self.break_tokens = [self.end_of_response_token_id, self.end_of_text_token_id]

        self.state = None
        self.position = None
        # Causal mask over the full cache: 0 where a query may attend (key index
        # <= query index), -inf elsewhere, shaped (1, 1, cache_length, cache_length).
        attention_mask = np.arange(self.cache_length, dtype=np.int32)
        attention_mask = attention_mask[:, None] >= attention_mask[None, :]
        attention_mask = attention_mask[None, None, :, :]
        self.attention_mask = np.where(
            attention_mask,
            np.array(0.0, dtype=np.float16),
            np.array(-np.inf, dtype=np.float16),
        )
|
    def initialize_generation(self):
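        """Create a fresh KV-cache state and reset the write position to 0."""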
        self.state = self.generation_model.make_state()
        self.position = 0
|
    def load_prompt_model(self):
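        """Reload the 64-token prompt function if it is not currently loaded."""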
        if self.prompt_model is None:
            self.prompt_model = load_mlmodel(
                self.mlmodel_path,
                f"model_input_64_cache_{self.cache_length}",
                copy_compiled=False,
            )
|
    def unload_prompt_model(self):
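        """Release the prompt model; load_prompt_model reloads it on demand."""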
        del self.prompt_model
        self.prompt_model = None
|
    def embed(self, ids):
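        """Look up rows of the embedding table for the given token ids."""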
        return self.embeddings[ids]
|
    def process_prompt(self, prompt):
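        """Tokenize `prompt` with the chat template and run it through the prompt
        model in chunks of 64 tokens, filling the KV cache.

        Returns the first generated token id, or an array containing -1 if the
        prompt does not fit in the remaining cache.
        """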
        if self.prompt_model is None:
            self.load_prompt_model()
        messages = [{"role": "user", "content": prompt}]
        tokens = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True
        )
        if self.position + len(tokens) >= self.cache_length:
            return np.array([-1])
        stop_processing = False
        start_time = time.perf_counter()
        processed_chunks = 0
        for i in range(0, len(tokens), 64):
            chunk = tokens[i : min(i + 64, len(tokens))]
            if self.position + len(chunk) > self.cache_length:
                stop_processing = True
                break
            processed_chunks += 1
            # (1, n, hidden) -> (1, hidden, 1, n); the model takes a fixed 64
            # positions, so a short final chunk is zero-padded below.
            embds = self.embed([chunk]).transpose(0, 2, 1)[..., None, :]
            if len(chunk) < 64:
                embds = np.concatenate(
                    (
                        embds,
                        np.zeros(
                            (1, embds.shape[1], 1, 64 - len(chunk)), dtype=np.float16
                        ),
                    ),
                    axis=-1,
                )
            kv_write_idx = np.array([self.position], dtype=np.int32)
            positions = np.arange(self.position, self.position + 64, dtype=np.int32)[
                None, :
            ]
            attention_mask = self.attention_mask[
                :, :, self.position : self.position + 64
            ]
            pred = self.prompt_model.predict(
                {
                    "hidden_states": embds,
                    "kv_write_idx": kv_write_idx,
                    "positions": positions,
                    "attention_mask": attention_mask,
                },
                self.state,
            )
            self.position += len(chunk)
        self.unload_prompt_model()
        end_time = time.perf_counter()
        print(
            f"==== Processed {len(tokens)} tokens + {64 - len(chunk)} pad tokens"
            f" in {end_time - start_time:.2f} seconds,"
            f" {processed_chunks * 64 / (end_time - start_time):.2f} tokens per second,"
            f" current position: {self.position}/{self.cache_length}"
        )
        if stop_processing:
            return np.array([-1], dtype=np.int32)
        output_hidden_states = pred["output_hidden_states"][..., [len(chunk) - 1]]
        return self.lm_head(output_hidden_states)
|
    def lm_head(self, hidden_states):
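        """Map final hidden states to a token id: min-p sampling when temp > 0,
        otherwise greedy argmax."""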
        if self.temp > 0:
            input_id = self.lm_head_model.predict(
                {
                    "hidden_states": hidden_states,
                    "temp": np.array([self.temp], dtype=np.float16),
                    "p": np.array([self.min_p], dtype=np.float16),
                    "random_number": np.random.uniform(0.0, 1.0, (1,)),
                }
            )["sampled_index"][:, 0]
        else:
            input_id = self.lm_head_model.predict(
                {
                    "hidden_states": hidden_states,
                }
            )["argmax"][:, 0]
        return input_id
|
    async def generate(self, input_id: np.ndarray):
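        """Stream tokens one at a time until a break token is produced or the
        cache is full, printing each decoded token as it is generated."""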
        continue_generating = True

        generated_tokens = 0
        start_time = time.perf_counter()

        while (self.position < self.cache_length) and continue_generating:
            generated_tokens += 1
            input_id_item = input_id.item()
            if input_id_item in self.break_tokens:
                continue_generating = False
            # Schedule the next decode step, print the current token, then await
            # the newly sampled id.
            task = asyncio.create_task(
                generate_single_step(
                    input_id,
                    self.embed,
                    self.generation_model,
                    self.state,
                    self.position,
                    self.attention_mask,
                    self.lm_head if continue_generating else None,
                )
            )
            self.position += 1
            print(self.tokenizer.decode(input_id_item), end="", flush=True)
            input_id = await task

        print()

        end_time = time.perf_counter()
        print(
            f"==== Generated {generated_tokens} tokens in {end_time - start_time:.2f} seconds,"
            f" {generated_tokens / (end_time - start_time):.2f} tokens per second,"
            f" current position: {self.position}/{self.cache_length}"
        )
|
    def loop(self):
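        """Interactive REPL: read a prompt, process it, generate a reply, and start
        a new conversation whenever the cache length is exhausted."""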
print("--- Begin conversation ---") |
|
while True: |
|
self.initialize_generation() |
|
while True: |
|
print(">>> ", end="", flush=True) |
|
self.load_prompt_model() |
|
prompt = input() |
|
prompt_result = self.process_prompt(prompt) |
|
if prompt_result.item() == -1: |
|
print("\n--- END OF CONVERSATION: MAX CONTEXT LENGTH REACHED ---\n") |
|
print("--- Beginning new conversation ---") |
|
break |
|
|
|
asyncio.run(self.generate(prompt_result)) |
|
if self.position >= (self.cache_length): |
|
print("\n--- END OF CONVERSATION: MAX CONTEXT LENGTH REACHED ---\n") |
|
print("--- Beginning new conversation ---") |
|
break |
|
|
|
|
|
def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--lm_head", type=str, required=True)
    parser.add_argument("--embeddings", type=str, required=True)
    parser.add_argument(
        "--cache_length",
        type=int,
        choices=[512, 1024, 2048, 2048 + 1024, 4096, 4096 + 2048, 8192],
        default=1024,
    )
    parser.add_argument("--min_p", type=float, default=0.1)
    parser.add_argument("--temp", type=float, default=0.7)

    return parser.parse_args()
|
|
def main():
    args = parse_args()
    ModelContainer(
        args.embeddings,
        args.model,
        args.lm_head,
        args.cache_length,
        "tiiuae/Falcon-E-1B-Instruct",
        args.temp,
        args.min_p,
    ).loop()
|
|
if __name__ == "__main__":
    main()