Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -342,14 +342,14 @@ def train_model(epochs):
|
|
342 |
for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
|
343 |
# Move images and prompt embeddings to the appropriate device (CPU or GPU)
|
344 |
original, target = original.to(device), target.to(device)
|
345 |
-
original_prompt_tokens = original_prompt_tokens.input_ids.to(device)
|
346 |
-
enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device)
|
347 |
|
348 |
optimizer.zero_grad()
|
349 |
-
|
350 |
# Forward pass through the model
|
351 |
output = model(target)
|
352 |
-
|
353 |
# Compute image reconstruction loss
|
354 |
img_loss = criterion(output, original)
|
355 |
|
@@ -359,7 +359,7 @@ def train_model(epochs):
|
|
359 |
# Combine losses
|
360 |
total_loss = img_loss + 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance
|
361 |
total_loss.backward()
|
362 |
-
|
363 |
# Optimizer step
|
364 |
optimizer.step()
|
365 |
|
|
|
342 |
for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
|
343 |
# Move images and prompt embeddings to the appropriate device (CPU or GPU)
|
344 |
original, target = original.to(device), target.to(device)
|
345 |
+
original_prompt_tokens = original_prompt_tokens.input_ids.to(device).float() # Convert to float
|
346 |
+
enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device).float() # Convert to float
|
347 |
|
348 |
optimizer.zero_grad()
|
349 |
+
|
350 |
# Forward pass through the model
|
351 |
output = model(target)
|
352 |
+
|
353 |
# Compute image reconstruction loss
|
354 |
img_loss = criterion(output, original)
|
355 |
|
|
|
359 |
# Combine losses
|
360 |
total_loss = img_loss + 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance
|
361 |
total_loss.backward()
|
362 |
+
|
363 |
# Optimizer step
|
364 |
optimizer.step()
|
365 |
|