K00B404 commited on
Commit
54ff88c
·
verified ·
1 Parent(s): 27027f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -342,14 +342,14 @@ def train_model(epochs):
342
  for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
343
  # Move images and prompt embeddings to the appropriate device (CPU or GPU)
344
  original, target = original.to(device), target.to(device)
345
- original_prompt_tokens = original_prompt_tokens.input_ids.to(device)
346
- enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device)
347
 
348
  optimizer.zero_grad()
349
-
350
  # Forward pass through the model
351
  output = model(target)
352
-
353
  # Compute image reconstruction loss
354
  img_loss = criterion(output, original)
355
 
@@ -359,7 +359,7 @@ def train_model(epochs):
359
  # Combine losses
360
  total_loss = img_loss + 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance
361
  total_loss.backward()
362
-
363
  # Optimizer step
364
  optimizer.step()
365
 
 
342
  for i, (original, target, original_prompt_tokens, enhanced_prompt_tokens) in enumerate(dataloader):
343
  # Move images and prompt embeddings to the appropriate device (CPU or GPU)
344
  original, target = original.to(device), target.to(device)
345
+ original_prompt_tokens = original_prompt_tokens.input_ids.to(device).float() # Convert to float
346
+ enhanced_prompt_tokens = enhanced_prompt_tokens.input_ids.to(device).float() # Convert to float
347
 
348
  optimizer.zero_grad()
349
+
350
  # Forward pass through the model
351
  output = model(target)
352
+
353
  # Compute image reconstruction loss
354
  img_loss = criterion(output, original)
355
 
 
359
  # Combine losses
360
  total_loss = img_loss + 0.1 * prompt_loss # Weight the prompt guidance loss with 0.1 to balance
361
  total_loss.backward()
362
+
363
  # Optimizer step
364
  optimizer.step()
365