K00B404 commited on
Commit
447706a
·
verified ·
1 Parent(s): 84e482a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -355,10 +355,10 @@ def train_model(epochs):
355
  print(f"Original Prompt Tokens Type: {original_prompt_tokens.dtype}, Shape: {original_prompt_tokens.shape}")
356
  print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
357
  # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
358
- prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
359
 
360
  # Combine losses
361
- total_loss = img_loss + 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance
362
  total_loss.backward()
363
 
364
  # Optimizer step
 
355
  print(f"Original Prompt Tokens Type: {original_prompt_tokens.dtype}, Shape: {original_prompt_tokens.shape}")
356
  print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
357
  # Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
358
+ #prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
359
 
360
  # Combine losses
361
+ total_loss = img_loss #+ 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance
362
  total_loss.backward()
363
 
364
  # Optimizer step