fffiloni commited on
Commit
c576d4c
·
verified ·
1 Parent(s): 717456e

Make sure to always load the highest trained safetensors file for all cases

Browse files

### **Summary of Changes in `get_huggingface_safetensors` Function**

This update improves how we select the correct `.safetensors` file from a Hugging Face repository, ensuring the best possible file is chosen in different scenarios.

#### **Changes & Fixes:**
1. **Prioritizing full models:**
- If a `.safetensors` file **without `_000...` (step numbers)** exists, we return it **immediately** (previously, we might have returned a different file).

2. **Selecting the most trained file when only step-based models exist:**
- If **all available `.safetensors` files contain step numbers**, we now return the one with the **highest step count** (previously, this case was not handled explicitly).

3. **Handling repositories with multiple `.safetensors` files but no recognizable step count:**
- If multiple `.safetensors` files exist **but none follow the `_000...` pattern**, we now **return the last one** in the list (previously, selection was arbitrary).

4. **Maintained behavior for images and error handling:**
- Image selection still prioritizes URLs from the model card, falling back to available image files in the repo.
- Improved exception handling and error messages.

#### **Example Scenarios & Expected Behavior:**

| Scenario | Selected File |
|----------|--------------|
| **Repo contains a full model (`carbo-800.safetensors`)** | `carbo-800.safetensors` (immediate return) |
| **Repo contains only step-based files (`_000...`)** | The one with the **highest step count** |
| **Repo contains multiple `.safetensors` files but none follow `_000...`** | The **last one** in the list |
| **Mixed case (some with `_000...`, some without)** | The full model (if available), otherwise the highest step count |

This ensures we always return the **most appropriate model file** based on structure and training progress.

Files changed (1) hide show
  1. app.py +68 -24
app.py CHANGED
@@ -12,6 +12,7 @@ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_d
12
  import copy
13
  import random
14
  import time
 
15
 
16
  # Load LoRAs from JSON file
17
  with open('loras.json', 'r') as f:
@@ -172,30 +173,73 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
172
  yield final_image, seed, gr.update(value=progress_bar, visible=False)
173
 
174
  def get_huggingface_safetensors(link):
175
- split_link = link.split("/")
176
- if(len(split_link) == 2):
177
- model_card = ModelCard.load(link)
178
- base_model = model_card.data.get("base_model")
179
- print(base_model)
180
- if((base_model != "black-forest-labs/FLUX.1-dev") and (base_model != "black-forest-labs/FLUX.1-schnell")):
181
- raise Exception("Not a FLUX LoRA!")
182
- image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
183
- trigger_word = model_card.data.get("instance_prompt", "")
184
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
185
- fs = HfFileSystem()
186
- try:
187
- list_of_files = fs.ls(link, detail=False)
188
- for file in list_of_files:
189
- if(file.endswith(".safetensors")):
190
- safetensors_name = file.split("/")[-1]
191
- if (not image_url and file.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))):
192
- image_elements = file.split("/")
193
- image_url = f"https://huggingface.co/{link}/resolve/main/{image_elements[-1]}"
194
- except Exception as e:
195
- print(e)
196
- gr.Warning(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
197
- raise Exception(f"You didn't include a link neither a valid Hugging Face repository with a *.safetensors LoRA")
198
- return split_link[1], link, safetensors_name, trigger_word, image_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
  def check_custom_model(link):
201
  if(link.startswith("https://")):
 
12
  import copy
13
  import random
14
  import time
15
+ import re
16
 
17
  # Load LoRAs from JSON file
18
  with open('loras.json', 'r') as f:
 
173
  yield final_image, seed, gr.update(value=progress_bar, visible=False)
174
 
175
  def get_huggingface_safetensors(link):
176
+ split_link = link.split("/")
177
+ if len(split_link) != 2:
178
+ raise Exception("Invalid Hugging Face repository link format.")
179
+
180
+ # Load model card
181
+ model_card = ModelCard.load(link)
182
+ base_model = model_card.data.get("base_model")
183
+ print(base_model)
184
+
185
+ # Validate model type
186
+ if base_model not in {"black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"}:
187
+ raise Exception("Not a FLUX LoRA!")
188
+
189
+ # Extract image and trigger word
190
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
191
+ trigger_word = model_card.data.get("instance_prompt", "")
192
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
193
+
194
+ # Initialize Hugging Face file system
195
+ fs = HfFileSystem()
196
+ try:
197
+ list_of_files = fs.ls(link, detail=False)
198
+
199
+ # Initialize variables for safetensors selection
200
+ safetensors_name = None
201
+ highest_trained_file = None
202
+ highest_steps = -1
203
+ last_safetensors_file = None
204
+ step_pattern = re.compile(r"_0{3,}\d+") # Detects step count `_000...`
205
+
206
+ for file in list_of_files:
207
+ filename = file.split("/")[-1]
208
+
209
+ # Select safetensors file
210
+ if filename.endswith(".safetensors"):
211
+ last_safetensors_file = filename # Track last encountered file
212
+
213
+ match = step_pattern.search(filename)
214
+ if not match:
215
+ # Found a full model without step numbers, return immediately
216
+ safetensors_name = filename
217
+ break
218
+ else:
219
+ # Extract step count and track highest
220
+ steps = int(match.group().lstrip("_"))
221
+ if steps > highest_steps:
222
+ highest_trained_file = filename
223
+ highest_steps = steps
224
+
225
+ # Select an image file if not found in model card
226
+ if not image_url and filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
227
+ image_url = f"https://huggingface.co/{link}/resolve/main/{filename}"
228
+
229
+ # If no full model found, fall back to the most trained safetensors file
230
+ if not safetensors_name:
231
+ safetensors_name = highest_trained_file if highest_trained_file else last_safetensors_file
232
+
233
+ # If still no safetensors file found, raise an exception
234
+ if not safetensors_name:
235
+ raise Exception("No valid *.safetensors file found in the repository.")
236
+
237
+ except Exception as e:
238
+ print(e)
239
+ raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
240
+
241
+ return split_link[1], link, safetensors_name, trigger_word, image_url
242
+
243
 
244
  def check_custom_model(link):
245
  if(link.startswith("https://")):