K00B404 commited on
Commit
833a3af
·
verified ·
1 Parent(s): 7976967

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -14,6 +14,7 @@ import numpy as np
14
  from small_256_model import UNet as small_UNet
15
  from big_1024_model import UNet as big_UNet
16
  from CLIP import load as load_clip
 
17
 
18
  # Device configuration
19
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -44,10 +45,10 @@ def load_model():
44
  model.to(device)
45
  model.eval()
46
  global_model = model
47
- print("Model loaded successfully!")
48
  return model
49
  except Exception as e:
50
- print(f"Error loading model: {e}")
51
  model = big_UNet().to(device) if big else small_UNet().to(device)
52
  global_model = model
53
  return model
@@ -253,10 +254,12 @@ checkpoint = torch.load(name)
253
  model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
254
  model.load_state_dict(checkpoint['model_state_dict'])
255
  model.eval()
 
256
 
257
- Model Architecture
258
 
259
  {str(self.model)} """
 
260
  # Save and upload README
261
  with open("README.md", "w") as f:
262
  f.write(model_card)
@@ -352,8 +355,8 @@ def train_model(epochs):
352
 
353
  # Compute image reconstruction loss
354
  img_loss = criterion(output, original)
355
- print(f"Original Prompt Tokens Type: {original_prompt_tokens.dtype}, Shape: {original_prompt_tokens.shape}")
356
- print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
357
  # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
358
  #prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
359
 
@@ -366,7 +369,7 @@ def train_model(epochs):
366
 
367
  if i % 10 == 0:
368
  status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
369
- print(status)
370
  output_text.append(status)
371
 
372
  # Push model to Hugging Face Hub at the end of each epoch
@@ -422,7 +425,7 @@ def train_model_old(epochs):
422
 
423
  if i % 10 == 0:
424
  status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
425
- print(status)
426
  output_text.append(status)
427
 
428
  # Push model to Hugging Face Hub at the end of each epoch
 
14
  from small_256_model import UNet as small_UNet
15
  from big_1024_model import UNet as big_UNet
16
  from CLIP import load as load_clip
17
+ from rich import print as rp
18
 
19
  # Device configuration
20
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
45
  model.to(device)
46
  model.eval()
47
  global_model = model
48
+ rp("Model loaded successfully!")
49
  return model
50
  except Exception as e:
51
+ rp(f"Error loading model: {e}")
52
  model = big_UNet().to(device) if big else small_UNet().to(device)
53
  global_model = model
54
  return model
 
254
  model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
255
  model.load_state_dict(checkpoint['model_state_dict'])
256
  model.eval()
257
+ ```
258
 
259
+ ## Model Architecture
260
 
261
  {str(self.model)} """
262
+ rp(model_card)
263
  # Save and upload README
264
  with open("README.md", "w") as f:
265
  f.write(model_card)
 
355
 
356
  # Compute image reconstruction loss
357
  img_loss = criterion(output, original)
358
+ rp(f"Image {i} Loss:{imag_loss}")
359
+ #print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
360
  # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
361
  #prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
362
 
 
369
 
370
  if i % 10 == 0:
371
  status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
372
+ rp(status)
373
  output_text.append(status)
374
 
375
  # Push model to Hugging Face Hub at the end of each epoch
 
425
 
426
  if i % 10 == 0:
427
  status = f"Epoch [{epoch}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {total_loss.item():.8f}"
428
+ rp(status)
429
  output_text.append(status)
430
 
431
  # Push model to Hugging Face Hub at the end of each epoch