Spaces:
Running
on
Zero
Make sure to always load the highest trained safetensors file for all cases
Browse files### **Summary of Changes in `get_huggingface_safetensors` Function**
This update improves how we select the correct `.safetensors` file from a Hugging Face repository, ensuring the best possible file is chosen in different scenarios.
#### **Changes & Fixes:**
1. **Prioritizing full models:**
- If a `.safetensors` file **without `_000...` (step numbers)** exists, we return it **immediately** (previously, we might have returned a different file).
2. **Selecting the most trained file when only step-based models exist:**
- If **all available `.safetensors` files contain step numbers**, we now return the one with the **highest step count** (previously, this case was not handled explicitly).
3. **Handling repositories with multiple `.safetensors` files but no recognizable step count:**
- If multiple `.safetensors` files exist **but none follow the `_000...` pattern**, we now **return the last one** in the list (previously, selection was arbitrary).
4. **Maintained behavior for images and error handling:**
- Image selection still prioritizes URLs from the model card, falling back to available image files in the repo.
- Improved exception handling and error messages.
#### **Example Scenarios & Expected Behavior:**
| Scenario | Selected File |
|----------|--------------|
| **Repo contains a full model (`carbo-800.safetensors`)** | `carbo-800.safetensors` (immediate return) |
| **Repo contains only step-based files (`_000...`)** | The one with the **highest step count** |
| **Repo contains multiple `.safetensors` files but none follow `_000...`** | The **last one** in the list |
| **Mixed case (some with `_000...`, some without)** | The full model (if available), otherwise the highest step count |
This ensures we always return the **most appropriate model file** based on structure and training progress.
@@ -12,6 +12,7 @@ from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_d
|
|
12 |
import copy
|
13 |
import random
|
14 |
import time
|
|
|
15 |
|
16 |
# Load LoRAs from JSON file
|
17 |
with open('loras.json', 'r') as f:
|
@@ -172,30 +173,73 @@ def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_ind
|
|
172 |
yield final_image, seed, gr.update(value=progress_bar, visible=False)
|
173 |
|
174 |
def get_huggingface_safetensors(link):
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
|
200 |
def check_custom_model(link):
|
201 |
if(link.startswith("https://")):
|
|
|
12 |
import copy
|
13 |
import random
|
14 |
import time
|
15 |
+
import re
|
16 |
|
17 |
# Load LoRAs from JSON file
|
18 |
with open('loras.json', 'r') as f:
|
|
|
173 |
yield final_image, seed, gr.update(value=progress_bar, visible=False)
|
174 |
|
175 |
def get_huggingface_safetensors(link):
|
176 |
+
split_link = link.split("/")
|
177 |
+
if len(split_link) != 2:
|
178 |
+
raise Exception("Invalid Hugging Face repository link format.")
|
179 |
+
|
180 |
+
# Load model card
|
181 |
+
model_card = ModelCard.load(link)
|
182 |
+
base_model = model_card.data.get("base_model")
|
183 |
+
print(base_model)
|
184 |
+
|
185 |
+
# Validate model type
|
186 |
+
if base_model not in {"black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"}:
|
187 |
+
raise Exception("Not a FLUX LoRA!")
|
188 |
+
|
189 |
+
# Extract image and trigger word
|
190 |
+
image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
|
191 |
+
trigger_word = model_card.data.get("instance_prompt", "")
|
192 |
+
image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
|
193 |
+
|
194 |
+
# Initialize Hugging Face file system
|
195 |
+
fs = HfFileSystem()
|
196 |
+
try:
|
197 |
+
list_of_files = fs.ls(link, detail=False)
|
198 |
+
|
199 |
+
# Initialize variables for safetensors selection
|
200 |
+
safetensors_name = None
|
201 |
+
highest_trained_file = None
|
202 |
+
highest_steps = -1
|
203 |
+
last_safetensors_file = None
|
204 |
+
step_pattern = re.compile(r"_0{3,}\d+") # Detects step count `_000...`
|
205 |
+
|
206 |
+
for file in list_of_files:
|
207 |
+
filename = file.split("/")[-1]
|
208 |
+
|
209 |
+
# Select safetensors file
|
210 |
+
if filename.endswith(".safetensors"):
|
211 |
+
last_safetensors_file = filename # Track last encountered file
|
212 |
+
|
213 |
+
match = step_pattern.search(filename)
|
214 |
+
if not match:
|
215 |
+
# Found a full model without step numbers, return immediately
|
216 |
+
safetensors_name = filename
|
217 |
+
break
|
218 |
+
else:
|
219 |
+
# Extract step count and track highest
|
220 |
+
steps = int(match.group().lstrip("_"))
|
221 |
+
if steps > highest_steps:
|
222 |
+
highest_trained_file = filename
|
223 |
+
highest_steps = steps
|
224 |
+
|
225 |
+
# Select an image file if not found in model card
|
226 |
+
if not image_url and filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
|
227 |
+
image_url = f"https://huggingface.co/{link}/resolve/main/{filename}"
|
228 |
+
|
229 |
+
# If no full model found, fall back to the most trained safetensors file
|
230 |
+
if not safetensors_name:
|
231 |
+
safetensors_name = highest_trained_file if highest_trained_file else last_safetensors_file
|
232 |
+
|
233 |
+
# If still no safetensors file found, raise an exception
|
234 |
+
if not safetensors_name:
|
235 |
+
raise Exception("No valid *.safetensors file found in the repository.")
|
236 |
+
|
237 |
+
except Exception as e:
|
238 |
+
print(e)
|
239 |
+
raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
|
240 |
+
|
241 |
+
return split_link[1], link, safetensors_name, trigger_word, image_url
|
242 |
+
|
243 |
|
244 |
def check_custom_model(link):
|
245 |
if(link.startswith("https://")):
|