Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -355,10 +355,10 @@ def train_model(epochs):
|
|
355 |
print(f"Original Prompt Tokens Type: {original_prompt_tokens.dtype}, Shape: {original_prompt_tokens.shape}")
|
356 |
print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
|
357 |
# Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
|
358 |
-
prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
|
359 |
|
360 |
# Combine losses
|
361 |
-
total_loss = img_loss
|
362 |
total_loss.backward()
|
363 |
|
364 |
# Optimizer step
|
|
|
355 |
print(f"Original Prompt Tokens Type: {original_prompt_tokens.dtype}, Shape: {original_prompt_tokens.shape}")
|
356 |
print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
|
357 |
# Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
|
358 |
+
#prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
|
359 |
|
360 |
# Combine losses
|
361 |
+
total_loss = img_loss #+ 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance
|
362 |
total_loss.backward()
|
363 |
|
364 |
# Optimizer step
|