Update app.py
app.py CHANGED
@@ -14,6 +14,7 @@ import numpy as np
 from small_256_model import UNet as small_UNet
 from big_1024_model import UNet as big_UNet
 from CLIP import load as load_clip
+from rich import print as rp
 
 # Device configuration
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
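The only change in this first hunk is the new `rich` dependency, imported under the alias `rp` so the later logging calls stay short. A minimal sketch of what the alias gives over the built-in `print` (the markup tags below are illustrative, not taken from app.py):

```python
from rich import print as rp

# rich.print understands console markup and pretty-prints Python objects.
rp("[bold green]Model loaded successfully![/bold green]")
rp({"epoch": 1, "step": 10, "loss": 0.01234567})
```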
@@ -44,10 +45,10 @@ def load_model():
         model.to(device)
         model.eval()
         global_model = model
-
+        rp("Model loaded successfully!")
         return model
     except Exception as e:
-
+        rp(f"Error loading model: {e}")
         model = big_UNet().to(device) if big else small_UNet().to(device)
         global_model = model
         return model
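This hunk adds success and error logging around `load_model`, which follows a try/except pattern: restore a checkpoint if possible, otherwise fall back to a freshly initialized UNet so the Space can still start. A minimal sketch of that pattern; the checkpoint filename, the `big` flag, and the loading steps inside the `try` block are assumptions not shown in the diff:

```python
import torch
from rich import print as rp
from small_256_model import UNet as small_UNet
from big_1024_model import UNet as big_UNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CKPT_PATH = "model.pth"  # hypothetical filename
big = True               # hypothetical flag selecting the 1024 variant
global_model = None

def load_model():
    """Restore a saved checkpoint; fall back to a fresh model on any failure."""
    global global_model
    try:
        checkpoint = torch.load(CKPT_PATH, map_location=device)
        model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)
        model.eval()
        global_model = model
        rp("Model loaded successfully!")
        return model
    except Exception as e:
        # Missing file or mismatched weights: start with an untrained model instead.
        rp(f"Error loading model: {e}")
        model = big_UNet().to(device) if big else small_UNet().to(device)
        global_model = model
        return model
```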
@@ -253,10 +254,12 @@ checkpoint = torch.load(name)
 model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
 model.load_state_dict(checkpoint['model_state_dict'])
 model.eval()
+```
 
-Model Architecture
+## Model Architecture
 
 {str(self.model)} """
+        rp(model_card)
         # Save and upload README
         with open("README.md", "w") as f:
             f.write(model_card)
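Here the edits touch a model card that app.py assembles as a triple-quoted f-string: the added closing code fence ends the usage snippet inside the card, the plain "Model Architecture" line becomes a Markdown heading, and `rp(model_card)` echoes the finished card before it is written to README.md and uploaded. A cut-down sketch of that save-and-upload step, assuming `huggingface_hub` with an already configured write token; the headings and `repo_id` are illustrative, not app.py's exact card:

```python
from huggingface_hub import HfApi
from rich import print as rp

def save_and_upload_model_card(model, repo_id: str):
    # Hypothetical helper, not the exact card app.py builds.
    model_card = f"""# UNet Image Enhancer

## Model Architecture

{str(model)}
"""
    rp(model_card)                     # echo the card, as the diff now does
    with open("README.md", "w") as f:  # save README locally
        f.write(model_card)
    HfApi().upload_file(               # upload it to the Hub repo
        path_or_fileobj="README.md",
        path_in_repo="README.md",
        repo_id=repo_id,
    )
```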
@@ -352,8 +355,8 @@ def train_model(epochs):
 
             # Compute image reconstruction loss
             img_loss = criterion(output, original)
-
-            print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
+            rp(f"Image {i} Loss:{img_loss}")
+            #print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
             # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
             #prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
 
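In the training loop the commit adds a per-step log of the image reconstruction loss and disables the verbose dtype/shape print; the prompt guidance loss (an L2 norm between the original and enhanced prompt embeddings) stays commented out. A minimal sketch of how the two terms could be combined, with `prompt_weight` as an assumed weighting that does not appear in app.py:

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()
prompt_weight = 0.1  # hypothetical weighting factor

def compute_losses(output, original, original_prompt_tokens, enhanced_prompt_tokens):
    # Image reconstruction loss, as in the diff.
    img_loss = criterion(output, original)
    # Prompt guidance loss: L2 norm between original and enhanced prompt embeddings
    # (this term is still commented out in app.py).
    prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
    total_loss = img_loss + prompt_weight * prompt_loss
    return total_loss, img_loss, prompt_loss
```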
@@ -366,7 +369,7 @@ def train_model(epochs):
 
             if i % 10 == 0:
                 status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
-
+                rp(status)
                 output_text.append(status)
 
             # Push model to Hugging Face Hub at the end of each epoch
@@ -422,7 +425,7 @@ def train_model_old(epochs):
 
             if i % 10 == 0:
                 status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
-
+                rp(status)
                 output_text.append(status)
 
             # Push model to Hugging Face Hub at the end of each epoch
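Both training loops also push the model to the Hugging Face Hub at the end of each epoch. A sketch of one way that checkpoint-and-upload step can look, assuming `huggingface_hub` with a write token already configured; the repo id and filenames are placeholders, and the checkpoint keys mirror the loading code shown earlier in the diff:

```python
import torch
from huggingface_hub import HfApi

def push_checkpoint(model, big: bool, epoch: int, repo_id: str = "your-username/unet-enhancer"):
    # Save weights plus the config needed to rebuild the right UNet variant.
    name = "big_1024_model.pth" if big else "small_256_model.pth"  # placeholder filenames
    torch.save(
        {"model_state_dict": model.state_dict(), "model_config": {"big": big}},
        name,
    )
    HfApi().upload_file(
        path_or_fileobj=name,
        path_in_repo=name,
        repo_id=repo_id,
        commit_message=f"checkpoint after epoch {epoch}",
    )
```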