amal90888 commited on
Commit
4496058
·
verified ·
1 Parent(s): 21a1605

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +28 -1
README.md CHANGED
@@ -63,8 +63,35 @@ This UNet-based segmentation model is designed for **automated segmentation of C
63
  ### **1️⃣ Load the Model**
64
  #### **TensorFlow/Keras**
65
  ```python
 
66
  from huggingface_hub import hf_hub_download
67
  from tensorflow.keras.models import load_model
 
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  model_path = hf_hub_download(repo_id="amal90888/unet-segmentation-model", filename="unet_model.keras")
70
- unet = load_model(model_path)
 
 
 
 
 
 
 
 
 
63
  ### **1️⃣ Load the Model**
64
  #### **TensorFlow/Keras**
65
  ```python
66
+ import os
67
  from huggingface_hub import hf_hub_download
68
  from tensorflow.keras.models import load_model
69
+ from keras.saving import register_keras_serializable
70
+ import tensorflow.keras.backend as K
71
 
72
+ # ✅ Set Keras backend (optional)
73
+ os.environ["KERAS_BACKEND"] = "jax"
74
+
75
+ # ✅ Register and define missing functions
76
+ @register_keras_serializable()
77
+ def dice_coef(y_true, y_pred, smooth=1e-6):
78
+ y_true_f = K.flatten(y_true)
79
+ y_pred_f = K.flatten(y_pred)
80
+ intersection = K.sum(y_true_f * y_pred_f)
81
+ return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
82
+
83
+ @register_keras_serializable()
84
+ def gl_sl(*args, **kwargs):
85
+ pass # Placeholder function (update if needed)
86
+
87
+ # ✅ Download the model from Hugging Face
88
  model_path = hf_hub_download(repo_id="amal90888/unet-segmentation-model", filename="unet_model.keras")
89
+
90
+ # ✅ Load the model with registered custom objects
91
+ unet = load_model(model_path, custom_objects={"dice_coef": dice_coef, "gl_sl": gl_sl}, compile=False)
92
+
93
+ # ✅ Recompile with fresh optimizer and correct loss function
94
+ from tensorflow.keras.optimizers import Adam
95
+ unet.compile(optimizer=Adam(learning_rate=1e-4), loss="binary_crossentropy", metrics=["accuracy", dice_coef])
96
+
97
+ print("✅ Model loaded and recompiled successfully!")