Fixed GIL issue
race condition between CoreML and causal_mask update
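
In short, the causal mask is no longer rebuilt inside run_prefill and generate_next_token on every call; it is created once up front and the hot paths only slice views of it. A minimal standalone sketch of that pattern follows; the mask shape, fill value, and the body of make_causal_mask beyond the two lines visible in the diff are assumptions for illustration, not the exact file contents.

import numpy as np
import torch

def make_causal_mask(length, start):
    # Additive attention mask: 0 where a position may attend, -inf where it may not.
    mask = np.full((1, 1, length, length), -np.inf, dtype=np.float32)
    row_indices = np.arange(length).reshape(length, 1)
    col_indices = np.arange(length).reshape(1, length)
    mask[:, :, col_indices <= (row_indices + start)] = 0
    return mask

def initialize_causal_mask(context_length):
    # Build the mask a single time, up front, instead of on every prefill/decode call.
    return torch.tensor(make_causal_mask(context_length, 0), dtype=torch.float16)

causal_mask = initialize_causal_mask(512)
batch_causal_mask = causal_mask[:, :, 0:64, :]    # prefill: rows for one batch
single_causal_mask = causal_mask[:, :, 10:11, :]  # decode: row for the current position
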
- chat_full.py +25 -17
--- a/chat_full.py
+++ b/chat_full.py
@@ -194,7 +194,7 @@ def load_model(path, function_name=None):
         raise
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting (c) 2025 Anemll')
+    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
 
     # Add meta.yaml option
     parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
@@ -474,7 +474,7 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask
 
-def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state):
+def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
     """Run prefill on the input sequence."""
     #print(f"[DEBUG] Running prefill from 0 to {current_pos}")
 
@@ -499,9 +499,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
         # Generate position IDs for this batch
        position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
 
-        #
-        causal_mask = make_causal_mask(context_length, 0) # Always start from 0 for prefill
-        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+        # Use the pre-initialized causal mask and extract the batch portion
         batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
 
         # Run embeddings
@@ -525,7 +523,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
 
     return torch.tensor([current_pos], dtype=torch.int32)
 
-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, temperature=0.0):
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos]
@@ -540,9 +538,8 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32)
 
-    #
-
-    single_causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16)
+    # Use the pre-initialized causal mask and extract the single position portion
+    single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
     # Run through FFN chunks
     for ffn_model in ffn_models:
@@ -591,6 +588,13 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state
 
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
+    causal_mask = make_causal_mask(context_length, 0)
+    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
 def get_user_input():
     """Get input from user, handling special key combinations."""
     global THINKING_MODE
@@ -651,7 +655,7 @@ def get_user_input():
     # Fallback for systems without termios
     return input("> ")
 
-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
     global THINKING_MODE
     context_length = metadata.get('context_length')
@@ -743,10 +747,6 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
         generation_start_time = time.time()
 
         try:
-            # Create initial causal mask
-            causal_mask = make_causal_mask(context_length, 0)
-            causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
-
             # Run prefill on entire context
             current_pos = run_prefill(
                 embed_model,
@@ -755,7 +755,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                 context_pos,
                 context_length,
                 batch_size,
-                state
+                state,
+                causal_mask
             )
             #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
 
@@ -789,7 +790,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     new_size, # Prefill the entire shifted content
                     context_length,
                     batch_size,
-                    state
+                    state,
+                    causal_mask
                 )
 
                 # Start generating from the next position
@@ -808,7 +810,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     input_ids,
                     pos,
                     context_length,
-                    state
+                    state,
+                    causal_mask
                 )
 
                 # Add token
@@ -911,6 +914,9 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])
 
+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
     if not args.nw:
         for i in range(2):
@@ -921,6 +927,7 @@
                 tokenizer=tokenizer,
                 metadata=metadata,
                 state=state, # Pass the state
+                causal_mask=causal_mask, # Pass the causal mask
                 warmup=True,
                 auto_prompt="who are you?"
             )
@@ -933,6 +940,7 @@
             tokenizer=tokenizer,
             metadata=metadata,
            state=state, # Pass the state
+            causal_mask=causal_mask, # Pass the causal mask
             warmup=False,
             auto_prompt=args.prompt
         )
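
Taken together, main() now builds the mask once, right after create_unified_state, and threads it through chat_loop into run_prefill and generate_next_token, so the generation loop only takes views of one shared float16 tensor while CoreML predictions are in flight; presumably this is what removes the race between CoreML execution and the repeated causal_mask rebuilds named in the commit message. (The full [1, 1, context_length, context_length] shape is inferred from the slicing pattern; the diff does not show make_causal_mask in full.)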