ErenalpCet committed
Commit c1d70a2 · verified · 1 Parent(s): d50ecff

Update app.py

Files changed (1)
  1. app.py +28 -10
app.py CHANGED
@@ -27,30 +27,48 @@ MAX_GPU_MEMORY = "40GiB" # A100 memory allocation
 # --- Model Loading ---
 @GPU(memory=40) # ****** THIS DECORATOR IS ESSENTIAL FOR SPACES STARTUP ******
 def load_model():
-    """Load the LLM model optimized for A100 GPU."""
-    print(f"Attempting to load model: {MODEL_ID}")
+    """Load the LLM model optimized for A100 GPU using 4-bit quantization."""
+    print(f"Attempting to load model: {MODEL_ID} with 4-bit quantization")
     try:
-        # Configure quantization
-        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+        # Configure 4-bit quantization
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",              # NF4 is often recommended
+            bnb_4bit_compute_dtype=torch.bfloat16,  # Use bfloat16 for compute
+            bnb_4bit_use_double_quant=True,         # Double quantization saves additional memory
+        )
 
+        # device_map="auto" lets accelerate place the layers;
+        # max_memory usually does not need to be set explicitly in that case.
         pipe = pipeline(
             "text-generation",
             model=MODEL_ID,
+            # Note: torch_dtype can be ignored when quantization_config is used;
+            # the compute dtype set in BitsAndBytesConfig is what matters.
+            # It is kept here for consistency in case other parts of the code rely on it.
             torch_dtype=torch.bfloat16,
-            device_map="auto", # Relies on accelerate
+            device_map="auto",  # Let accelerate handle layer placement
             model_kwargs={
                 "quantization_config": quantization_config,
                 "use_cache": True,
-                # "max_memory": {0: MAX_GPU_MEMORY} # Often handled by device_map="auto"
+                # "trust_remote_code": True,  # Add if the model requires it (check the model card)
             }
         )
-        print(f"Model {MODEL_ID} loaded successfully on device: {pipe.device}")
+        print(f"Model {MODEL_ID} loaded successfully on device: {pipe.device} (using 4-bit quantization)")
         return pipe
     except Exception as e:
-        print(f"FATAL Error loading model '{MODEL_ID}': {e}")
-        # Raise the error to potentially get more detailed logs in Spaces
+        print(f"FATAL Error loading model '{MODEL_ID}' (check memory/config): {e}")
+        # Raise the error so it is visible in the Spaces logs
         raise e
-    # return None # Returning None might hide the root cause in Spaces logs
+
+# --- REST OF THE CODE REMAINS THE SAME ---
+# (search_person, create_synthetic_profile, extract_text_from_search_results,
+#  parse_llm_output, generate_enhanced_persona, generate_system_prompt_with_llm,
+#  generate_response, PersonaChat class, create_interface function, __main__ block)
+# ... include the rest of the Python code from the previous correct version here ...
+
+# Make sure the rest of your app.py file follows this modified load_model function.
+# Keep all other functions and the Gradio interface definition as they were.
 
 # --- Web Search ---
 # (Keep search_person, create_synthetic_profile, extract_text_from_search_results as before)
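For reference, a minimal sketch of how the 4-bit pipeline returned by load_model() would be invoked. The prompt text and sampling parameters below are illustrative assumptions, not part of this commit:

    pipe = load_model()
    # Example call; prompt and generation settings are placeholders.
    outputs = pipe(
        "Introduce yourself in one sentence.",
        max_new_tokens=128,
        do_sample=True,
        temperature=0.7,
    )
    print(outputs[0]["generated_text"])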