Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -352,7 +352,8 @@ def train_model(epochs):
|
|
352 |
|
353 |
# Compute image reconstruction loss
|
354 |
img_loss = criterion(output, original)
|
355 |
-
|
|
|
356 |
# Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
|
357 |
prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
|
358 |
|
|
|
352 |
|
353 |
# Compute image reconstruction loss
|
354 |
img_loss = criterion(output, original)
|
355 |
+
print(f"Original Prompt Tokens Type: {original_prompt_tokens.dtype}, Shape: {original_prompt_tokens.shape}")
|
356 |
+
print(f"Enhanced Prompt Tokens Type: {enhanced_prompt_tokens.dtype}, Shape: {enhanced_prompt_tokens.shape}")
|
357 |
# Compute prompt guidance loss (L2 norm between original and enhanced prompt embeddings)
|
358 |
prompt_loss = torch.norm(original_prompt_tokens - enhanced_prompt_tokens, p=2)
|
359 |
|