Update README.md
Browse files
README.md
CHANGED
@@ -63,8 +63,35 @@ This UNet-based segmentation model is designed for **automated segmentation of C
|
|
63 |
### **1️⃣ Load the Model**
|
64 |
#### **TensorFlow/Keras**
|
65 |
```python
|
|
|
66 |
from huggingface_hub import hf_hub_download
|
67 |
from tensorflow.keras.models import load_model
|
|
|
|
|
68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
model_path = hf_hub_download(repo_id="amal90888/unet-segmentation-model", filename="unet_model.keras")
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
### **1️⃣ Load the Model**
|
64 |
#### **TensorFlow/Keras**
|
65 |
```python
|
66 |
+
import os
|
67 |
from huggingface_hub import hf_hub_download
|
68 |
from tensorflow.keras.models import load_model
|
69 |
+
from keras.saving import register_keras_serializable
|
70 |
+
import tensorflow.keras.backend as K
|
71 |
|
72 |
+
# ✅ Set Keras backend (optional)
|
73 |
+
os.environ["KERAS_BACKEND"] = "jax"
|
74 |
+
|
75 |
+
# ✅ Register and define missing functions
|
76 |
+
@register_keras_serializable()
|
77 |
+
def dice_coef(y_true, y_pred, smooth=1e-6):
|
78 |
+
y_true_f = K.flatten(y_true)
|
79 |
+
y_pred_f = K.flatten(y_pred)
|
80 |
+
intersection = K.sum(y_true_f * y_pred_f)
|
81 |
+
return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
|
82 |
+
|
83 |
+
@register_keras_serializable()
|
84 |
+
def gl_sl(*args, **kwargs):
|
85 |
+
pass # Placeholder function (update if needed)
|
86 |
+
|
87 |
+
# ✅ Download the model from Hugging Face
|
88 |
model_path = hf_hub_download(repo_id="amal90888/unet-segmentation-model", filename="unet_model.keras")
|
89 |
+
|
90 |
+
# ✅ Load the model with registered custom objects
|
91 |
+
unet = load_model(model_path, custom_objects={"dice_coef": dice_coef, "gl_sl": gl_sl}, compile=False)
|
92 |
+
|
93 |
+
# ✅ Recompile with fresh optimizer and correct loss function
|
94 |
+
from tensorflow.keras.optimizers import Adam
|
95 |
+
unet.compile(optimizer=Adam(learning_rate=1e-4), loss="binary_crossentropy", metrics=["accuracy", dice_coef])
|
96 |
+
|
97 |
+
print("✅ Model loaded and recompiled successfully!")
|