"""Interactive chat loop for a transformer LLM exported to Core ML.

The script loads three artifacts: a NumPy embedding table, a multi-function
Core ML model with a 64-token prompt-processing function and a single-token
generation function (both sharing a KV-cache state), and a separate LM-head
model that either samples with min-p/temperature or takes the argmax.
"""

import asyncio
import os
import shutil
import time
from argparse import ArgumentParser

import coremltools as ct
import numpy as np
from transformers import AutoTokenizer


def copy_compiled_model(mlmodel: ct.models.MLModel, dest: str):
    # Persist the temporary compiled .mlmodelc so later loads skip compilation.
    compiled_model_path = mlmodel.get_compiled_model_path()
    shutil.copytree(compiled_model_path, dest, dirs_exist_ok=True)


def load_mlmodel(path, function_name, copy_compiled):
    # Load either a precompiled .mlmodelc or an .mlpackage, pinned to
    # CPU + Neural Engine; optionally cache the compiled artifact on disk.
    extension = os.path.splitext(path)[1]
    if extension == ".mlmodelc":
        return ct.models.CompiledMLModel(
            path,
            function_name=function_name,
            compute_units=ct.ComputeUnit.CPU_AND_NE,
        )
    else:
        mlmodel = ct.models.MLModel(
            path,
            function_name=function_name,
            compute_units=ct.ComputeUnit.CPU_AND_NE,
        )
        if copy_compiled:
            copy_compiled_model(mlmodel, path.replace(".mlpackage", ".mlmodelc"))
        return mlmodel


def load_embeddings(path):
    return np.load(path)


async def generate_single_step(
    input_id,
    embed_fn,
    model,
    state,
    position,
    attention_mask_ref,
    lm_head,
):
    # Embed a single token, run one decode step against the shared KV-cache
    # state, and (unless lm_head is None) sample the next token id.
    embd = embed_fn(input_id).transpose(0, 3, 1, 2)
    hidden_states = model.predict(
        {
            "hidden_states": embd,
            "kv_write_idx": np.array([position], dtype=np.int32),
            "positions": np.array([[position]], dtype=np.int32),
            "attention_mask": attention_mask_ref[:, :, [position]],
        },
        state,
    )["output_hidden_states"]
    if lm_head is not None:
        input_id = lm_head(hidden_states)
    return input_id


class ModelContainer:
    def __init__(
        self,
        embeddings_path,
        mlmodel_path,
        lm_head_path,
        cache_length,
        hf_model,
        temp=0.7,
        min_p=0.1,
    ):
        self.mlmodel_path = mlmodel_path
        self.embeddings_path = embeddings_path
        self.lm_head_path = lm_head_path
        self.cache_length = cache_length
        self.temp = temp
        self.min_p = min_p

        print("Loading embeddings...")
        self.embeddings = load_embeddings(embeddings_path)

        print("Loading generation model...")
        self.generation_model = load_mlmodel(
            mlmodel_path, f"model_input_1_cache_{cache_length}", copy_compiled=True
        )

        # self.prompt_model = None
        print("Loading prompt model...")
        self.prompt_model = load_mlmodel(
            mlmodel_path.replace(".mlpackage", ".mlmodelc"),
            f"model_input_64_cache_{cache_length}",
            copy_compiled=False,
        )

        print("Loading lm head model...")
        self.lm_head_model = load_mlmodel(
            lm_head_path,
            "min_p_length_1" if temp > 0 else "lm_head_length_1",
            copy_compiled=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(hf_model)
        self.end_of_response_token_id = self.tokenizer("<|im_end|>").input_ids[0]
        self.end_of_text_token_id = self.tokenizer("<|end_of_text|>").input_ids[0]
        self.break_tokens = [self.end_of_response_token_id, self.end_of_text_token_id]

        self.state = None
        self.position = None

        # Causal mask: 0 where a position may attend, -inf above the diagonal.
        attention_mask = np.arange(self.cache_length, dtype=np.int32)
        attention_mask = attention_mask[:, None] >= attention_mask[None, :]
        attention_mask = attention_mask[None, None, :, :]
        self.attention_mask = np.where(
            attention_mask,
            np.array(0.0, dtype=np.float16),
            np.array(-np.inf, dtype=np.float16),
        )

    def initialize_generation(self):
        self.state = self.generation_model.make_state()
        self.position = 0

    def load_prompt_model(self):
        if self.prompt_model is None:
            self.prompt_model = load_mlmodel(
                self.mlmodel_path,
                f"model_input_64_cache_{self.cache_length}",
                copy_compiled=False,
            )

    def unload_prompt_model(self):
        del self.prompt_model
        self.prompt_model = None

    def embed(self, ids):
        return self.embeddings[ids]  # .transpose(0, 2, 1) # [..., None, :]

    def process_prompt(self, prompt):
        # Tokenize the user turn with the chat template and run it through the
        # 64-token prompt function in chunks, zero-padding the final chunk.
        if self.prompt_model is None:
            self.load_prompt_model()

        messages = [{"role": "user", "content": prompt}]
        tokens = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True
        )
        if self.position + len(tokens) >= self.cache_length:
            return np.array([-1])

        stop_processing = False
        start_time = time.perf_counter()
        processed_chunks = 0
        for i in range(0, len(tokens), 64):
            chunk = tokens[i : min(i + 64, len(tokens))]
            if self.position + len(chunk) > self.cache_length:
                stop_processing = True
                break
            processed_chunks += 1

            embds = self.embed([chunk]).transpose(0, 2, 1)[..., None, :]
            if len(chunk) < 64:
                # Pad the last chunk with zero embeddings up to 64 positions.
                embds = np.concatenate(
                    (
                        embds,
                        np.zeros(
                            (1, embds.shape[1], 1, 64 - len(chunk)), dtype=np.float16
                        ),
                    ),
                    axis=-1,
                )

            kv_write_idx = np.array([self.position], dtype=np.int32)
            positions = np.arange(self.position, self.position + 64, dtype=np.int32)[
                None, :
            ]
            attention_mask = self.attention_mask[
                :, :, self.position : self.position + 64
            ]
            pred = self.prompt_model.predict(
                {
                    "hidden_states": embds,
                    "kv_write_idx": kv_write_idx,
                    "positions": positions,
                    "attention_mask": attention_mask,
                },
                self.state,
            )
            self.position += len(chunk)

        self.unload_prompt_model()
        end_time = time.perf_counter()
        print(
            f"==== Processed {len(tokens)} tokens + {64 - len(chunk)} pad tokens"
            f" in {end_time - start_time:.2f} seconds,"
            f" {processed_chunks * 64 / (end_time - start_time):.2f} tokens per second,"
            f" current position: {self.position}/{self.cache_length}"
        )

        if stop_processing:
            return np.array([-1], dtype=np.int32)

        # Only the hidden state of the last real (non-pad) token feeds the LM head.
        output_hidden_states = pred["output_hidden_states"][..., [len(chunk) - 1]]
        return self.lm_head(output_hidden_states)

    def lm_head(self, hidden_states):
        # temp > 0: min-p sampling inside the Core ML LM head; otherwise argmax.
        if self.temp > 0:
            input_id = self.lm_head_model.predict(
                {
                    "hidden_states": hidden_states,
                    "temp": np.array([self.temp], dtype=np.float16),
                    "p": np.array([self.min_p], dtype=np.float16),
                    "random_number": np.random.uniform(0.0, 1.0, (1,)),
                }
            )["sampled_index"][:, 0]
        else:
            input_id = self.lm_head_model.predict(
                {
                    "hidden_states": hidden_states,
                }
            )["argmax"][:, 0]
        return input_id

    async def generate(self, input_id: np.ndarray):
        continue_generating = True
        # for i in range(max_new_tokens):
        generated_tokens = 0
        start_time = time.perf_counter()
        while (self.position < self.cache_length) and continue_generating:
            generated_tokens += 1
            input_id_item = input_id.item()
            if input_id_item in self.break_tokens:
                continue_generating = False
            # Schedule the next decode step, then print the current token while
            # it runs. The final step (after a break token) only writes the KV
            # cache, so no LM head is passed.
            task = asyncio.create_task(
                generate_single_step(
                    input_id,
                    self.embed,
                    self.generation_model,
                    self.state,
                    self.position,
                    self.attention_mask,
                    self.lm_head if continue_generating else None,
                )
            )
            self.position += 1
            print(self.tokenizer.decode(input_id_item), end="", flush=True)
            input_id = await task
        print()
        end_time = time.perf_counter()
        print(
            f"==== Generated {generated_tokens} tokens"
            f" in {end_time - start_time:.2f} seconds,"
            f" {generated_tokens / (end_time - start_time):.2f} tokens per second,"
            f" current position: {self.position}/{self.cache_length}"
        )
        # if stop_generation:
        #     self.load_prompt_model()

    def loop(self):
        print("--- Begin conversation ---")
        while True:
            self.initialize_generation()
            while True:
                print(">>> ", end="", flush=True)
                self.load_prompt_model()
                prompt = input()
                prompt_result = self.process_prompt(prompt)
                if prompt_result.item() == -1:
                    print("\n--- END OF CONVERSATION: MAX CONTEXT LENGTH REACHED ---\n")
                    print("--- Beginning new conversation ---")
                    break
                # print(self.tokenizer.decode(prompt_result.item()), end="", flush=True)
                asyncio.run(self.generate(prompt_result))
                if self.position >= self.cache_length:
                    print("\n--- END OF CONVERSATION: MAX CONTEXT LENGTH REACHED ---\n")
                    print("--- Beginning new conversation ---")
                    break


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--lm_head", type=str, required=True)
    parser.add_argument("--embeddings", type=str, required=True)
    parser.add_argument(
        "--cache_length",
        type=int,
        choices=[512, 1024, 2048, 2048 + 1024, 4096, 4096 + 2048, 8192],
        default=1024,
    )
    parser.add_argument("--min_p", type=float, default=0.1)
    parser.add_argument("--temp", type=float, default=0.7)
    # parser.add_argument("--hf_model", type=str, default="")
    return parser.parse_args()


def main():
    args = parse_args()
    ModelContainer(
        args.embeddings,
        args.model,
        args.lm_head,
        args.cache_length,
        "tiiuae/Falcon-E-1B-Instruct",
        args.temp,
        args.min_p,
    ).loop()


if __name__ == "__main__":
    main()