anemll committed (verified)
Commit a335a46 · Parent(s): 7b1f81c

Fixed GIL issue

Fixes a race condition between CoreML execution and the causal_mask update.
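In short: before this change, run_prefill() and generate_next_token() rebuilt the full fp16 causal mask on every prefill batch and every generated token, so the rebuild could overlap with in-flight CoreML predictions; the mask is now built once in main() and the hot loops only take slices of it. A minimal, self-contained sketch of the new pattern (the mask construction below is an illustrative stand-in, not the make_causal_mask implementation from chat_full.py; the slicing matches the diff below):

import torch

def initialize_causal_mask(context_length):
    # Build the fp16 causal mask once, up front (mirrors the helper added in this
    # commit; this construction is illustrative -- the file builds it via make_causal_mask()).
    mask = torch.full((1, 1, context_length, context_length), float('-inf'), dtype=torch.float16)
    return torch.triu(mask, diagonal=1)  # 0 on/below the diagonal, -inf strictly above

context_length, batch_size = 16, 4
causal_mask = initialize_causal_mask(context_length)

# Prefill: one slice per batch of positions (as in run_prefill after this commit)
for batch_pos in range(0, 8, batch_size):
    batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
    assert batch_causal_mask.shape == (1, 1, batch_size, context_length)

# Decode: one single-row slice per generated token (as in generate_next_token)
pos = 9
single_causal_mask = causal_mask[:, :, pos-1:pos, :]
assert single_causal_mask.shape == (1, 1, 1, context_length)

# The slices are views of the one pre-built tensor, so the generation loop no longer
# allocates or rewrites a mask while CoreML calls are in flight.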

Files changed (1):
chat_full.py (+25 -17)
chat_full.py CHANGED
@@ -194,7 +194,7 @@ def load_model(path, function_name=None):
         raise
 
 def parse_args():
-    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting (c) 2025 Anemll')
+    parser = argparse.ArgumentParser(description='Full Chat with CoreML LLaMA with context window shifting, gil resolved (c) 2025 Anemll')
 
     # Add meta.yaml option
     parser.add_argument('--meta', type=str, help='Path to meta.yaml to load all parameters')
@@ -474,7 +474,7 @@ def make_causal_mask(length, start):
     mask[:, :, col_indices <= (row_indices + start)] = 0
     return mask
 
-def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state):
+def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask):
     """Run prefill on the input sequence."""
     #print(f"[DEBUG] Running prefill from 0 to {current_pos}")
 
@@ -499,9 +499,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
         # Generate position IDs for this batch
         position_ids = torch.arange(batch_pos, batch_pos + batch_size, dtype=torch.int32)
 
-        # Create causal mask for this batch
-        causal_mask = make_causal_mask(context_length, 0) # Always start from 0 for prefill
-        causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+        # Use the pre-initialized causal mask and extract the batch portion
         batch_causal_mask = causal_mask[:, :, batch_pos:batch_pos + batch_size, :]
 
         # Run embeddings
@@ -525,7 +523,7 @@ def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length,
 
     return torch.tensor([current_pos], dtype=torch.int32)
 
-def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state=None, temperature=0.0):
+def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0):
     """Generate the next token."""
     # Get current token
     current_token = input_ids[:, pos-1:pos]
@@ -540,9 +538,8 @@ def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, c
     update_mask[0, 0, pos-1, 0] = 1.0
     position_ids = torch.tensor([pos-1], dtype=torch.int32)
 
-    # Create causal mask for current position
-    causal_mask = make_causal_mask(context_length, 0) # Always start from 0 for generation
-    single_causal_mask = torch.tensor(causal_mask[:, :, pos-1:pos, :], dtype=torch.float16)
+    # Use the pre-initialized causal mask and extract the single position portion
+    single_causal_mask = causal_mask[:, :, pos-1:pos, :]
 
     # Run through FFN chunks
     for ffn_model in ffn_models:
@@ -591,6 +588,13 @@ def create_unified_state(ffn_models, context_length):
     print("\nCreated unified transformer state")
     return state
 
+def initialize_causal_mask(context_length):
+    """Initialize causal mask for transformer attention."""
+    causal_mask = make_causal_mask(context_length, 0)
+    causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
+    print(f"\nInitialized causal mask for context length {context_length}")
+    return causal_mask
+
 def get_user_input():
     """Get input from user, handling special key combinations."""
     global THINKING_MODE
@@ -651,7 +655,7 @@ def get_user_input():
         # Fallback for systems without termios
         return input("> ")
 
-def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, auto_prompt=None, warmup=False):
+def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False):
     """Interactive chat loop."""
     global THINKING_MODE
     context_length = metadata.get('context_length')
@@ -743,10 +747,6 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
         generation_start_time = time.time()
 
         try:
-            # Create initial causal mask
-            causal_mask = make_causal_mask(context_length, 0)
-            causal_mask = torch.tensor(causal_mask, dtype=torch.float16)
-
             # Run prefill on entire context
             current_pos = run_prefill(
                 embed_model,
@@ -755,7 +755,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                 context_pos,
                 context_length,
                 batch_size,
-                state
+                state,
+                causal_mask
             )
             #print(f"\n[DEBUG] After initial prefill - current_pos: {current_pos}")
 
@@ -789,7 +790,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                     new_size, # Prefill the entire shifted content
                     context_length,
                     batch_size,
-                    state
+                    state,
+                    causal_mask
                 )
 
                 # Start generating from the next position
@@ -808,7 +810,8 @@ def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state,
                 input_ids,
                 pos,
                 context_length,
-                state
+                state,
+                causal_mask
             )
 
             # Add token
@@ -911,6 +914,9 @@ def main():
     # Create unified state once
     state = create_unified_state(ffn_models, metadata['context_length'])
 
+    # Initialize causal mask once
+    causal_mask = initialize_causal_mask(metadata['context_length'])
+
     # Warmup runs to prevent Python GIL issues with CoreML !
     if not args.nw:
         for i in range(2):
@@ -921,6 +927,7 @@ def main():
                 tokenizer=tokenizer,
                 metadata=metadata,
                 state=state, # Pass the state
+                causal_mask=causal_mask, # Pass the causal mask
                 warmup=True,
                 auto_prompt="who are you?"
             )
@@ -933,6 +940,7 @@ def main():
             tokenizer=tokenizer,
             metadata=metadata,
            state=state, # Pass the state
+            causal_mask=causal_mask, # Pass the causal mask
            warmup=False,
            auto_prompt=args.prompt
        )
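For quick reference, the signatures touched by this commit, as they read after the change (copied from the diff above; bodies elided):

def initialize_causal_mask(context_length): ...
def run_prefill(embed_model, ffn_models, input_ids, current_pos, context_length, batch_size, state, causal_mask): ...
def generate_next_token(embed_model, ffn_models, lmhead_model, input_ids, pos, context_length, state, causal_mask, temperature=0.0): ...
def chat_loop(embed_model, ffn_models, lmhead_model, tokenizer, metadata, state, causal_mask, auto_prompt=None, warmup=False): ...

The causal_mask built once by initialize_causal_mask() in main() is threaded through chat_loop() into both run_prefill() and generate_next_token(), replacing the per-call mask creation that previously ran alongside CoreML execution.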