Fixed GIL issue
Race condition between CoreML and causal_mask update
chat.py
CHANGED
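The race addressed here apparently came from run_prefill() and generate_next_token() each rebuilding the causal attention mask with make_causal_mask() while CoreML predict() calls were running on other threads; the commit instead builds the mask once (initialize_causal_mask) and hands read-only slices of it to the prefill and infer calls. The sketch below condenses that pattern. The body of make_causal_mask beyond the two lines visible in the diff, and the concrete sizes, are assumptions for illustration only.

import numpy as np
import torch

def make_causal_mask(length, start):
    # Additive attention mask: 0 where a position may attend, -inf elsewhere.
    # (Only the final masking line and the return appear verbatim in the diff;
    # the construction above them is assumed.)
    mask = np.full((1, 1, length, length), -np.inf, dtype=np.float16)
    row_indices = np.arange(length).reshape(length, 1)
    col_indices = np.arange(length).reshape(1, length)
    mask[:, :, col_indices <= (row_indices + start)] = 0
    return mask

def initialize_causal_mask(context_length):
    # Built once, before any CoreML predict() call, then shared read-only.
    return torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)

causal_mask = initialize_causal_mask(512)               # once, in main()

# Prefill slices a whole batch of rows out of the shared mask ...
batch_causal_mask = causal_mask[:, :, 0:64, :]          # shape (1, 1, 64, 512)
# ... while single-token inference slices exactly one row.
single_causal_mask = causal_mask[:, :, 9:10, :]         # shape (1, 1, 1, 512)
print(batch_causal_mask.shape, single_causal_mask.shape)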
@@ -26,8 +26,10 @@ DARK_BLUE = "\033[34m"
 LIGHT_GREEN = "\033[92m"
 RESET_COLOR = "\033[0m"

-# Add at top with other constants
 WARMUP_TOKEN_LIMIT = 10  # Maximum tokens to generate during warmup

 class TokenPrinter:
     """Handles background printing of generated tokens."""
@@ -40,9 +42,12 @@ class TokenPrinter:
         self.lock = threading.Lock()
         self.thinking = True  # Track if we're still in thinking mode
         self.decoding_buffer = []  # Buffer for token IDs
-        #
         self.start_time = time.time()
         self.token_count = 0
         self.start()

     def start(self):
@@ -103,15 +108,15 @@ class TokenPrinter:
                 self.thread.join(timeout=1.0)
             except Exception:
                 pass
-        #
-        elapsed = time.time() - self.start_time
-        if elapsed > 0 and self.token_count > 0:
-            tokens_per_sec = self.token_count / elapsed
-            print(f"\n{DARK_BLUE}{tokens_per_sec:.1f} t/s{RESET_COLOR}")
-        else:
-            print(RESET_COLOR)  # Reset color at the end
         return self.buffer

 def parse_model_path(path):
     """Parse model path and return full path with .mlmodelc or .mlpackage extension."""
     path = Path(path)
@@ -188,6 +193,89 @@ def load_model(path, function_name=None):
     print("\nTry using the .mlpackage version instead, or recompile the model.")
     raise

 def load_metadata(model,args):
     # Extract metadata and config parameters
     metadata = {}
@@ -386,102 +474,99 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask

-def run_prefill(embed_model, ffn_models, input_ids,
     """Run prefill on the input sequence."""
-    #
-    causal_mask = make_causal_mask(context_length, 0)
-    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)

     # Process in batches
     batch_pos = 0
-    while batch_pos <
-        batch_end = min(batch_pos + batch_size,
         current_batch_size = batch_end - batch_pos

         # Get current batch
         batch_input = input_ids[:, batch_pos:batch_end]

-        #
         batch_input = F.pad(
             batch_input,
             (0, batch_size - current_batch_size),
             value=0
         )

-        # Generate position IDs for
-        position_ids = torch.arange(batch_size, dtype=torch.int32)

         # Run embeddings
         hidden_states = torch.from_numpy(
             embed_model.predict({'input_ids': batch_input.numpy()})['hidden_states']
         )

-        # Run through FFN chunks
         for ffn_model in ffn_models:
             if isinstance(ffn_model, dict):
                 inputs = {
-                    'hidden_states': hidden_states.numpy(),
-                    'position_ids': position_ids.numpy(),
-                    'causal_mask': batch_causal_mask.numpy(),
-                    'current_pos': np.array([batch_pos], dtype=np.int32)
                 }
                 output = ffn_model['prefill'].predict(inputs, state)
                 hidden_states = torch.from_numpy(output['output_hidden_states'])

         batch_pos = batch_end

-    return torch.tensor([

-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state
     """Generate the next token."""
     # Get current token
-    current_token = input_ids[:, pos-1:pos]

     # Run embeddings
     hidden_states = torch.from_numpy(
         embed_model.predict({'input_ids': current_token.numpy()})['hidden_states']
-    )

     # Create masks
     update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
     update_mask[0, 0, pos-1, 0] = 1.0
-    position_ids = torch.tensor([pos-1], dtype=torch.int32)
-    causal_mask = make_causal_mask(context_length, 0)
-    causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16)  # [1, 1, 1, context_length]

-    #
     for ffn_model in ffn_models:
         if isinstance(ffn_model, dict):
             inputs = {
                 'hidden_states': hidden_states.numpy(),
                 'update_mask': update_mask.numpy(),
                 'position_ids': position_ids.numpy(),
-                'causal_mask':
                 'current_pos': position_ids.numpy()
             }
             output = ffn_model['infer'].predict(inputs, state)
             hidden_states = torch.from_numpy(output['output_hidden_states'])

-    # Run LM head
     lm_output = lmhead_model.predict({'hidden_states': hidden_states.numpy()})
-    # Debug print
-    #print("\nLM Head output keys:", list(lm_output.keys()))

-    # Combine logits1-8 if they exist
     if 'logits1' in lm_output:
-        # Concatenate all logits parts
         logits_parts = []
         for i in range(1, 9):
             key = f'logits{i}'
             if key in lm_output:
                 logits_parts.append(torch.from_numpy(lm_output[key]))
-        logits = torch.cat(logits_parts, dim=-1)
     else:
-        # Try output_logits as fallback
         logits = torch.from_numpy(lm_output['output_logits'])

-    # Apply temperature and sample
     if temperature > 0:
         logits = logits / temperature
         probs = F.softmax(logits[0, -1, :], dim=-1)
@@ -503,36 +588,93 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state

-def
     """Interactive chat loop."""
     context_length = metadata.get('context_length')
     batch_size = metadata.get('batch_size', 64)

     if not warmup:
         print(f"\nUsing context length: {context_length}")
         print("\nStarting chat session. Press Ctrl+D to exit.")
-        print("Type your message and press Enter to chat.")
-
-    # Check if tokenizer has chat template and if it works
-    has_chat_template = False
-    try:
-        # Test if chat template works
-        test_messages = [{"role": "user", "content": "test"}]
-        tokenizer.apply_chat_template(test_messages, return_tensors="pt")
-        has_chat_template = True
-        if not warmup:
-            print("\nUsing chat template for prompts")
-    except:
-        if not warmup:
-            print("\nUsing manual formatting for prompts")

     conversation = []

     try:
         while True:
             try:
                 if not warmup:
-                    print(f"\n{LIGHT_GREEN}You:{RESET_COLOR}", end=' ', flush=True)
                 if auto_prompt is not None:
                     user_input = auto_prompt
                     if not warmup:
@@ -543,41 +685,69 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                 if not warmup:
                     print("\nExiting chat...")
                 break
                 if not user_input:
                     continue

-                #
                         return_tensors="pt",
                         add_generation_prompt=True
                     ).to(torch.int32)
                 else:
-                    input_ids = tokenizer(
-                        formatted_prompt,
                         return_tensors="pt",
-                    ).

                 if not warmup:
                     print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)

-                # Initialize token printer
                 token_printer = TokenPrinter(tokenizer)

                 try:
-                    #
-                    prefill_start = time.time()
-
-                    # Run prefill with state
                     current_pos = run_prefill(
                         embed_model,
                         ffn_models,
@@ -585,21 +755,53 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                         context_pos,
                         context_length,
                         batch_size,
-                        state
                     )

-                    #
-                    prefill_time = time.time() - prefill_start
-                    prefill_tokens = context_pos  # Number of tokens in input
-                    prefill_tokens_per_sec = prefill_tokens / prefill_time if prefill_time > 0 else 0
-
-                    # Generation loop with state
-                    input_ids = input_ids
                     pos = context_pos

-                    while
                         # Generate next token
                         next_token = generate_next_token(
                             embed_model,
@@ -608,146 +810,58 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                             input_ids,
                             pos,
                             context_length,
-                            state
                         )

-                        # Add token
-                        input_ids[0, pos] = next_token
-                        else:
-                            input_ids = torch.cat([
-                                input_ids,
-                                torch.tensor([[next_token]], dtype=torch.int32)
-                            ], dim=1)
-
-                        # Add to printer only if not in warmup
                         if not warmup:
                             token_printer.add_token(next_token)
                             token_printer.drain_buffer()

                         pos += 1
                         tokens_generated += 1
-                        inference_tokens += 1

-                        #
                         if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
                             break
                         if next_token == tokenizer.eos_token_id:
                             break

-                    # Calculate inference

-                    #
                     if not warmup:
-                        print(f"
-                    token_printer.stop()  # Clean up without printing stats

-                    # Exit after one response in auto_prompt mode
                     if auto_prompt is not None:
                         break

                 except KeyboardInterrupt:
                     token_printer.stop()
                     continue

         except Exception as e:

-def parse_args():
-    parser = argparse.ArgumentParser(description='Chat with CoreML LLaMA (c) 2025 Anemll')
-
-    # Add meta.yaml option
-    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
-
-    # Model paths
-    parser.add_argument('--d', '--dir', type=str, default='.',
-                        help='Directory containing model files (default: current directory)')
-    parser.add_argument('--embed', type=str, required=False,
-                        help='Path to embeddings model (relative to --dir)')
-    parser.add_argument('--ffn', type=str, required=False,
-                        help='Path to FFN model (can be chunked, relative to --dir)')
-    parser.add_argument('--lmhead', type=str, required=False,
-                        help='Path to LM head model (relative to --dir)')
-    parser.add_argument('--tokenizer', type=str, required=False,
-                        help='Path to tokenizer')
-
-    # Add new argument for auto-generation
-    parser.add_argument('--prompt', type=str,
-                        help='If specified, run once with this prompt and exit')
-
-    # Add no-warmup flag
-    parser.add_argument('--nw', action='store_true',
-                        help='Skip warmup phase')
-
-    # Model configuration
-    parser.add_argument('--context-length', type=int,
-                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
-    parser.add_argument('--batch-size', type=int,
-                        help='Batch size for prefill (default: 64)')
-
-    args = parser.parse_args()
-
-    # If meta.yaml is provided, load parameters from it
-    if args.meta:
-        try:
-            with open(args.meta, 'r') as f:
-                meta = yaml.safe_load(f)
-            params = meta['model_info']['parameters']
-
-            # Set model directory to meta.yaml directory if not specified
-            if not args.d or args.d == '.':
-                args.d = str(Path(args.meta).parent)
-
-            # Build model paths based on parameters
-            prefix = params.get('model_prefix', 'llama')  # Default to 'llama' if not specified
-            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
-            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
-            num_chunks = int(params['num_chunks'])
-
-            # Set model paths if not specified
-            if not args.embed:
-                args.embed = f'{prefix}_embeddings'
-            if not args.lmhead:
-                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
-            if not args.ffn:
-                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
-            if not args.tokenizer:
-                args.tokenizer = args.d
-
-            # Set other parameters if not overridden by command line
-            if args.context_length is None:
-                args.context_length = int(params['context_length'])
-            if args.batch_size is None:
-                args.batch_size = int(params['batch_size'])
-            args.num_chunks = num_chunks
-
-            print(f"\nLoaded parameters from {args.meta}:")
-            print(f"  Context Length: {args.context_length}")
-            print(f"  Batch Size: {args.batch_size}")
-            print(f"  Num Chunks: {args.num_chunks}")
-            print(f"  Models Directory: {args.d}")
-            print(f"  Embeddings: {args.embed}")
-            print(f"  LM Head: {args.lmhead}")
-            print(f"  FFN: {args.ffn}")
-
-        except Exception as e:
-            print(f"\nError loading meta.yaml: {str(e)}")
-            sys.exit(1)
-
-    return args

 def main():
     args = parse_args()
@@ -800,6 +914,9 @@ def main():
     # Create unified state once
    state = create_unified_state(ffn_models, metadata['context_length'])

     # Warmup runs to prevent Python GIL issues with CoreML !
     if not args.nw:
         for i in range(2):
@@ -809,7 +926,8 @@
                 lmhead_model=lmhead_model,
                 tokenizer=tokenizer,
                 metadata=metadata,
-                state=state,
                 warmup=True,
                 auto_prompt="who are you?"
             )
@@ -821,7 +939,8 @@
         lmhead_model=lmhead_model,
         tokenizer=tokenizer,
         metadata=metadata,
-        state=state,
         warmup=False,
         auto_prompt=args.prompt
     )
@@ -26,8 +26,10 @@ DARK_BLUE = "\033[34m"
 LIGHT_GREEN = "\033[92m"
 RESET_COLOR = "\033[0m"

+# Add at the top with other constants
 WARMUP_TOKEN_LIMIT = 10  # Maximum tokens to generate during warmup
+THINKING_MODE = False
+THINKING_PROMPT = """You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem."""

 class TokenPrinter:
     """Handles background printing of generated tokens."""
@@ -40,9 +42,12 @@ class TokenPrinter:
         self.lock = threading.Lock()
         self.thinking = True  # Track if we're still in thinking mode
         self.decoding_buffer = []  # Buffer for token IDs
+        # Timing and stats tracking
         self.start_time = time.time()
         self.token_count = 0
+        self.prefill_time = 0
+        self.inference_time = 0
+        self.context_pos = 0
         self.start()

     def start(self):
@@ -103,15 +108,15 @@ class TokenPrinter:
                 self.thread.join(timeout=1.0)
             except Exception:
                 pass
+        print(RESET_COLOR)  # Reset color at the end
         return self.buffer

+    def set_timing(self, prefill_time, inference_time, context_pos):
+        """Set timing information."""
+        self.prefill_time = prefill_time
+        self.inference_time = inference_time
+        self.context_pos = context_pos
+
 def parse_model_path(path):
     """Parse model path and return full path with .mlmodelc or .mlpackage extension."""
     path = Path(path)
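TokenPrinter now carries prefill_time, inference_time and context_pos alongside its token count, and set_timing() simply stores them; chat_loop later derives the throughput line from the same quantities. A rough trace of that arithmetic with made-up figures (the numbers are assumptions, only the formulas come from the diff):

# Hypothetical run: a 150-token prompt prefilled in 0.5 s, then 120 tokens
# generated in 2.5 s by the token-by-token inference loop.
context_pos, prefill_time, inference_time = 150, 0.5, 2.5
response_token_count = 120

inference_tokens_per_sec = response_token_count / inference_time if inference_time > 0 else 0
prefill_ms = prefill_time * 1000
prefill_tokens_per_sec = context_pos / prefill_time if prefill_time > 0 else 0

print(f"{inference_tokens_per_sec:.1f} t/s, "
      f"TTFT: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s), "
      f"{response_token_count} tokens")
# -> 48.0 t/s, TTFT: 500.0ms (300.0 t/s), 120 tokens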
@@ -188,6 +193,89 @@ def load_model(path, function_name=None):
     print("\nTry using the .mlpackage version instead, or recompile the model.")
     raise

+def parse_args():
+    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
+
+    # Add meta.yaml option
+    parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
+
+    # Add existing arguments
+    parser.add_argument('--d', '--dir', type=str, default='.',
+                        help='Directory containing model files (default: current directory)')
+    parser.add_argument('--embed', type=str, required=False,
+                        help='Path to embeddings model (relative to --dir)')
+    parser.add_argument('--ffn', type=str, required=False,
+                        help='Path to FFN model (can be chunked, relative to --dir)')
+    parser.add_argument('--lmhead', type=str, required=False,
+                        help='Path to LM head model (relative to --dir)')
+    parser.add_argument('--tokenizer', type=str, required=False,
+                        help='Path to tokenizer')
+
+    # Add new argument for auto-generation
+    parser.add_argument('--prompt', type=str,
+                        help='If specified, run once with this prompt and exit')
+
+    # Add no-warmup flag
+    parser.add_argument('--nw', action='store_true',
+                        help='Skip warmup phase')
+
+    # Model configuration
+    parser.add_argument('--context-length', type=int,
+                        help='Context length for the model (default: 512), if not provided, it will be detected from the model directory name ctxNUMBER')
+    parser.add_argument('--batch-size', type=int,
+                        help='Batch size for prefill (default: 64)')
+
+    args = parser.parse_args()
+
+    # If meta.yaml is provided, load parameters from it
+    if args.meta:
+        try:
+            with open(args.meta, 'r') as f:
+                meta = yaml.safe_load(f)
+            params = meta['model_info']['parameters']
+
+            # Set model directory to meta.yaml directory if not specified
+            if not args.d or args.d == '.':
+                args.d = str(Path(args.meta).parent)
+
+            # Build model paths based on parameters
+            prefix = params.get('model_prefix', 'llama')  # Default to 'llama' if not specified
+            lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
+            lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
+            num_chunks = int(params['num_chunks'])
+
+            # Set model paths if not specified
+            if not args.embed:
+                args.embed = f'{prefix}_embeddings'
+            if not args.lmhead:
+                args.lmhead = f'{prefix}_lm_head{lut_lmhead}'
+            if not args.ffn:
+                args.ffn = f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}'
+            if not args.tokenizer:
+                args.tokenizer = args.d
+
+            # Set other parameters if not overridden by command line
+            if args.context_length is None:
+                args.context_length = int(params['context_length'])
+            if args.batch_size is None:
+                args.batch_size = int(params['batch_size'])
+            args.num_chunks = num_chunks
+
+            print(f"\nLoaded parameters from {args.meta}:")
+            print(f"  Context Length: {args.context_length}")
+            print(f"  Batch Size: {args.batch_size}")
+            print(f"  Num Chunks: {args.num_chunks}")
+            print(f"  Models Directory: {args.d}")
+            print(f"  Embeddings: {args.embed}")
+            print(f"  LM Head: {args.lmhead}")
+            print(f"  FFN: {args.ffn}")
+
+        except Exception as e:
+            print(f"\nError loading meta.yaml: {str(e)}")
+            sys.exit(1)
+
+    return args
+
 def load_metadata(model,args):
     # Extract metadata and config parameters
     metadata = {}
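When --meta is given, parse_args() derives the model file names from the parameters block of meta.yaml. A small trace of that naming logic with made-up parameter values (the values are assumptions; the format strings are the ones from parse_args()):

# Hypothetical meta.yaml parameters
params = {'model_prefix': 'llama', 'lut_ffn': '6', 'lut_lmhead': '6',
          'num_chunks': 2, 'context_length': 512, 'batch_size': 64}

prefix = params.get('model_prefix', 'llama')
lut_ffn = f"_lut{params['lut_ffn']}" if params['lut_ffn'] != 'none' else ''
lut_lmhead = f"_lut{params['lut_lmhead']}" if params['lut_lmhead'] != 'none' else ''
num_chunks = int(params['num_chunks'])

print(f'{prefix}_embeddings')                                  # llama_embeddings
print(f'{prefix}_lm_head{lut_lmhead}')                         # llama_lm_head_lut6
print(f'{prefix}_FFN_PF{lut_ffn}_chunk_01of{num_chunks:02d}')  # llama_FFN_PF_lut6_chunk_01of02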
@@ -386,102 +474,99 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask

+def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
     """Run prefill on the input sequence."""
+    #print(f"[DEBUG] Running prefill from 0 to {current_pos}")

     # Process in batches
     batch_pos = 0
+    while batch_pos < current_pos:
+        batch_end = min(batch_pos + batch_size, current_pos)
         current_batch_size = batch_end - batch_pos

+        #print(f"[DEBUG] Prefill batch {batch_pos}-{batch_end} (size={current_batch_size})")
+
         # Get current batch
         batch_input = input_ids[:, batch_pos:batch_end]

+        # Pad to full batch size
         batch_input = F.pad(
             batch_input,
             (0, batch_size - current_batch_size),
             value=0
         )

+        # Generate position IDs for this batch
+        position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
+
+        # Use the pre-initialized causal mask and extract the batch portion
+        batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]

         # Run embeddings
         hidden_states = torch.from_numpy(
             embed_model.predict({'input_ids': batch_input.numpy()})['hidden_states']
         )

+        # Run through FFN chunks
         for ffn_model in ffn_models:
             if isinstance(ffn_model, dict):
                 inputs = {
+                    'hidden_states': hidden_states.numpy(),
+                    'position_ids': position_ids.numpy(),
+                    'causal_mask': batch_causal_mask.numpy(),
+                    'current_pos': np.array([batch_pos], dtype=np.int32)
                 }
                 output = ffn_model['prefill'].predict(inputs, state)
                 hidden_states = torch.from_numpy(output['output_hidden_states'])

         batch_pos = batch_end

+    return torch.tensor([current_pos], dtype=torch.int32)

+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
     """Generate the next token."""
     # Get current token
+    current_token = input_ids[:, pos-1:pos]

     # Run embeddings
     hidden_states = torch.from_numpy(
         embed_model.predict({'input_ids': current_token.numpy()})['hidden_states']
+    )

     # Create masks
     update_mask = torch.zeros((1, 1, context_length, 1), dtype=torch.float16)
     update_mask[0, 0, pos-1, 0] = 1.0
+    position_ids = torch.tensor([pos-1], dtype=torch.int32)

+    # Use the pre-initialized causal mask and extract the single position portion
+    single_causal_mask = causal_mask[:, :, pos-1:pos, :]
+
+    # Run through FFN chunks
     for ffn_model in ffn_models:
         if isinstance(ffn_model, dict):
             inputs = {
                 'hidden_states': hidden_states.numpy(),
                 'update_mask': update_mask.numpy(),
                 'position_ids': position_ids.numpy(),
+                'causal_mask': single_causal_mask.numpy(),
                 'current_pos': position_ids.numpy()
             }
             output = ffn_model['infer'].predict(inputs, state)
             hidden_states = torch.from_numpy(output['output_hidden_states'])

+    # Run LM head and get next token
     lm_output = lmhead_model.predict({'hidden_states': hidden_states.numpy()})

     if 'logits1' in lm_output:
         logits_parts = []
         for i in range(1, 9):
             key = f'logits{i}'
             if key in lm_output:
                 logits_parts.append(torch.from_numpy(lm_output[key]))
+        logits = torch.cat(logits_parts, dim=-1)
     else:
         logits = torch.from_numpy(lm_output['output_logits'])

     if temperature > 0:
         logits = logits / temperature
         probs = F.softmax(logits[0, -1, :], dim=-1)
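run_prefill() now walks the prompt in fixed-size batches and, instead of building a mask per call, slices batch-sized row blocks out of the causal_mask it receives. Tracing the loop with assumed sizes (a 150-token prompt, batch_size 64, context length 512; the zero mask is only a stand-in for the real one):

import torch

context_length, batch_size, current_pos = 512, 64, 150
causal_mask = torch.zeros((1, 1, context_length, context_length), dtype=torch.float16)  # stand-in

batch_pos = 0
while batch_pos < current_pos:
    batch_end = min(batch_pos + batch_size, current_pos)
    current_batch_size = batch_end - batch_pos
    position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
    batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
    print(f"batch {batch_pos}-{batch_end}: pad {batch_size - current_batch_size} tokens, "
          f"mask rows {batch_pos}..{batch_pos + batch_size - 1}, shape {tuple(batch_causal_mask.shape)}")
    batch_pos = batch_end
# batch 0-64: pad 0 tokens, mask rows 0..63, shape (1, 1, 64, 512)
# batch 64-128: pad 0 tokens, mask rows 64..127, shape (1, 1, 64, 512)
# batch 128-150: pad 42 tokens, mask rows 128..191, shape (1, 1, 64, 512)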
@@ -503,36 +588,93 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state

+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
+    causal_mask = make_causal_mask(context_length, 0)
+    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
+def get_user_input():
+    """Get input from user, handling special key combinations."""
+    global THINKING_MODE
+    try:
+        import termios
+        import tty
+        import sys
+
+        def _getch():
+            fd = sys.stdin.fileno()
+            old_settings = termios.tcgetattr(fd)
+            try:
+                tty.setraw(sys.stdin.fileno())
+                ch = sys.stdin.read(1)
+            finally:
+                termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
+            return ch
+
+        buffer = []
+        while True:
+            char = _getch()
+
+            # Debug: print the character code
+            print(f"\nKey pressed: {repr(char)} (hex: {hex(ord(char))})")
+
+            # Check for Enter key
+            if char == '\r' or char == '\n':
+                print()  # Move to next line
+                input_text = ''.join(buffer)
+                # Check if the command is /t
+                if input_text == '/t':
+                    THINKING_MODE = not THINKING_MODE
+                    print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
+                    buffer = []  # Clear buffer
+                    print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
+                    continue
+                return input_text
+
+            # Handle backspace
+            if char == '\x7f':  # backspace
+                if buffer:
+                    buffer.pop()
+                    sys.stdout.write('\b \b')  # Erase character
+                    sys.stdout.flush()
+                continue
+
+            # Handle Ctrl-C
+            if char == '\x03':  # Ctrl-C
+                print("^C")
+                raise KeyboardInterrupt
+
+            # Print character and add to buffer
+            sys.stdout.write(char)
+            sys.stdout.flush()
+            buffer.append(char)
+
+    except ImportError:
+        # Fallback for systems without termios
+        return input("> ")
+
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
+    global THINKING_MODE
     context_length = metadata.get('context_length')
     batch_size = metadata.get('batch_size', 64)

     if not warmup:
         print(f"\nUsing context length: {context_length}")
         print("\nStarting chat session. Press Ctrl+D to exit.")
+        print("Type your message and press Enter to chat. Use /t to toggle thinking mode.")
+        print(f"Thinking mode is {'ON' if THINKING_MODE else 'OFF'}")

+    # Keep track of conversation history
     conversation = []

     try:
         while True:
             try:
                 if not warmup:
+                    print(f"\n{LIGHT_GREEN}You{' (thinking)' if THINKING_MODE else ''}:{RESET_COLOR}", end=' ', flush=True)
                 if auto_prompt is not None:
                     user_input = auto_prompt
                     if not warmup:
@@ -543,41 +685,69 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                 if not warmup:
                     print("\nExiting chat...")
                 break
+
                 if not user_input:
                     continue
+
+                # Handle /t command
+                if user_input == "/t":
+                    THINKING_MODE = not THINKING_MODE
+                    print(f"Thinking mode {'ON' if THINKING_MODE else 'OFF'}")
+                    continue

+                # Add user message to conversation
+                conversation.append({"role": "user", "content": user_input})
+
+                # Format using chat template with full history
+                if THINKING_MODE:
+                    # Add thinking prompt to system message
+                    conversation_with_thinking = [{"role": "system", "content": THINKING_PROMPT}] + conversation
+                    base_input_ids = tokenizer.apply_chat_template(
+                        conversation_with_thinking,
                         return_tensors="pt",
                         add_generation_prompt=True
                     ).to(torch.int32)
                 else:
+                    base_input_ids = tokenizer.apply_chat_template(
+                        conversation,
                         return_tensors="pt",
+                        add_generation_prompt=True
+                    ).to(torch.int32)
+
+                # Check if we need to trim history
+                while base_input_ids.size(1) > context_length - 100:  # Leave room for response
+                    # Remove oldest message pair (user + assistant)
+                    if len(conversation) > 2:
+                        conversation = conversation[2:]  # Remove oldest pair
+                        base_input_ids = tokenizer.apply_chat_template(
+                            conversation,
+                            return_tensors="pt",
+                            add_generation_prompt=True
+                        ).to(torch.int32)
+                    else:
+                        # If only current message remains and still too long, truncate
+                        base_input_ids = base_input_ids[:, -context_length//2:]
+                        break
+
+                context_pos = base_input_ids.size(1)

+                # Pad sequence to context_size
+                input_ids = F.pad(
+                    base_input_ids,
+                    (0, context_length - context_pos),
+                    value=0
+                )

                 if not warmup:
                     print(f"\n{LIGHT_BLUE}Assistant:{RESET_COLOR}", end=' ', flush=True)

+                # Initialize token printer and collect response
                 token_printer = TokenPrinter(tokenizer)
+                response_tokens = []
+                generation_start_time = time.time()

                 try:
+                    # Run prefill on entire context
                     current_pos = run_prefill(
                         embed_model,
                         ffn_models,
@@ -585,21 +755,53 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                         context_pos,
                         context_length,
                         batch_size,
+                        state,
+                        causal_mask
                     )
+                    #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")

+                    # Generation loop
                     pos = context_pos
+                    tokens_generated = 0
+                    inference_start = time.time()  # Start inference timing

+                    while True:
+                        # Check if we need to shift window
+                        if pos >= context_length - 2:
+                            # Calculate shift to maintain full batches
+                            batch_size = metadata.get('batch_size', 64)
+                            # Calculate max batches that fit in context
+                            max_batches = context_length // batch_size
+                            desired_batches = max(1, max_batches - 2)  # Leave room for new tokens
+                            new_size = min(desired_batches * batch_size, context_length - batch_size)
+
+                            # Create shifted input_ids
+                            tmp = torch.zeros((1, context_length), dtype=torch.int32)
+                            tmp[:,0:new_size] = input_ids[:,pos-new_size:pos]
+                            input_ids = tmp
+
+                            # Reset state and run prefill
+                            # keep the same state
+                            #state = create_unified_state(ffn_models, context_length)
+                            current_pos = run_prefill(
+                                embed_model,
+                                ffn_models,
+                                input_ids,
+                                new_size,  # Prefill the entire shifted content
+                                context_length,
+                                batch_size,
+                                state,
+                                causal_mask
+                            )
+
+                            # Start generating from the next position
+                            pos = new_size  # Don't back up, continue from where we left off
+
+                            #print(f"\n[DEBUG] After shift - next token will be at pos {pos}")
+                            #print(f"[DEBUG] Context before next token: {tokenizer.decode(input_ids[0, pos-40:pos])}")
+
+                            window_shifted = True
+
                         # Generate next token
                         next_token = generate_next_token(
                             embed_model,
@@ -608,146 +810,58 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                             input_ids,
                             pos,
                             context_length,
+                            state,
+                            causal_mask
                         )

+                        # Add token
+                        input_ids[0, pos] = next_token
                         if not warmup:
                             token_printer.add_token(next_token)
                             token_printer.drain_buffer()
+                        response_tokens.append(next_token)

                         pos += 1
                         tokens_generated += 1

+                        # In warmup mode, limit tokens
                         if warmup and tokens_generated >= WARMUP_TOKEN_LIMIT:
                             break
+
                         if next_token == tokenizer.eos_token_id:
                             break

+                    inference_time = time.time() - inference_start  # Calculate inference time
+
+                    # Add assistant response to conversation
+                    response_text = token_printer.stop()
+                    conversation.append({"role": "assistant", "content": response_text})

+                    # Print stats only if not in warmup
                     if not warmup:
+                        total_time = time.time() - generation_start_time
+                        prefill_time = total_time - inference_time
+                        inference_tokens_per_sec = len(response_tokens) / inference_time if inference_time > 0 else 0
+                        prefill_ms = prefill_time * 1000
+                        prefill_tokens_per_sec = context_pos / prefill_time if prefill_time > 0 else 0
+                        print(f"{DARK_BLUE}{inference_tokens_per_sec:.1f} t/s, "
+                              f"TTFT: {prefill_ms:.1f}ms ({prefill_tokens_per_sec:.1f} t/s), "
+                              f"{len(response_tokens)} tokens{RESET_COLOR}")

                     if auto_prompt is not None:
                         break

                 except KeyboardInterrupt:
+                    if not warmup:
+                        print("\nGeneration interrupted")
                     token_printer.stop()
                     continue

         except Exception as e:
+            if not warmup:
+                print(f"\nError in chat loop: {str(e)}")
+                import traceback
+                traceback.print_exc()

 def main():
     args = parse_args()
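The window-shift branch a couple of hunks above (entered once pos reaches context_length - 2) keeps whole batches worth of the most recent tokens and re-prefills them against the same shared state and mask. With sizes that match the defaults mentioned in the diff (context_length 512, batch_size 64), the arithmetic works out as follows; the concrete numbers are an assumed illustration:

context_length, batch_size = 512, 64   # assumed values
pos = context_length - 2               # 510: the point at which the shift triggers

max_batches = context_length // batch_size                # 8
desired_batches = max(1, max_batches - 2)                 # 6, leaves room for new tokens
new_size = min(desired_batches * batch_size,              # 384
               context_length - batch_size)               # capped at 448
print(max_batches, desired_batches, new_size)             # 8 6 384

# The last new_size tokens (positions pos-384 .. pos-1) are copied to the front of a
# zeroed (1, context_length) buffer, run_prefill() is called on that prefix, and
# generation continues from pos = new_size.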
@@ -800,6 +914,9 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])

+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
     if not args.nw:
         for i in range(2):
|
| 926 |
lmhead_model=lmhead_model,
|
| 927 |
tokenizer=tokenizer,
|
| 928 |
metadata=metadata,
|
| 929 |
+
state=state, # Pass the state
|
| 930 |
+
causal_mask=causal_mask, # Pass the causal mask
|
| 931 |
warmup=True,
|
| 932 |
auto_prompt="who are you?"
|
| 933 |
)
|
|
|
|
@@ -821,7 +939,8 @@
         lmhead_model=lmhead_model,
         tokenizer=tokenizer,
         metadata=metadata,
+        state=state,  # Pass the state
+        causal_mask=causal_mask,  # Pass the causal mask
         warmup=False,
         auto_prompt=args.prompt
     )