Added support for using a gallery as the image input.
You can now upload multiple images at once. Because execution in the online environment is slow, the default image limit is set to 5. You can change this limit with the input_images_limit argument; setting it to -1 removes the limit.
Example:
python app.py --input_images_limit -1
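
For reference, the cap is enforced on the gallery value itself. A simplified sketch of that logic, condensed from the append_gallery and limit_gallery helpers added in app.py (not the full functions), looks like this:

input_images_limit = 5  # default, overridden by --input_images_limit

def append_gallery(gallery, image):
    # Add one image, but only while the gallery is below the limit (-1 means unlimited).
    gallery = gallery or []
    if image and (input_images_limit == -1 or len(gallery) < input_images_limit):
        gallery.append(image)
    return gallery, None

def limit_gallery(gallery):
    # Trim the gallery whenever it grows past the limit (e.g. after a multi-file upload).
    return gallery[:input_images_limit] if input_images_limit > 0 and gallery else gallery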
- .gitignore +1 -0
- README.md +1 -1
- app.py +551 -385
- requirements.txt +1 -1
- utils/dataops.py +1 -5
- webui.bat +1 -0
.gitignore
CHANGED
@@ -5,6 +5,7 @@ results/*
 tb_logger/*
 wandb/*
 tmp/*
+/venv

 version.py

README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 📈
 colorFrom: blue
 colorTo: gray
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.16.0
 app_file: app.py
 pinned: true
 license: apache-2.0
app.py
CHANGED
@@ -8,11 +8,14 @@ import torch
 import traceback
 import math
 import time
+import ast
+import argparse
 from collections import defaultdict
 from facexlib.utils.misc import download_from_url
 from basicsr.utils.realesrganer import RealESRGANer
+from utils.dataops import auto_split_upscale

-
+input_images_limit = 5
 # Define URLs and their corresponding local storage paths
 face_models = {
     "GFPGANv1.4.pth" : ["https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
@@ -457,399 +460,420 @@ typed_upscale_models = {get_model_type(key): value[0] for key, value in upscale_


 class Upscale:
-    def ... [old lines 460-462: the old inference() signature, truncated in the original diff view]
+    def __init__(self,):
+        self.scale = 4
+        self.modelInUse = ""
+        self.realesrganer = None
+        self.face_enhancer = None
+
+    def initBGUpscaleModel(self, upscale_model):
+        upscale_type, upscale_model = upscale_model.split(", ", 1)
+        download_from_url(upscale_models[upscale_model][0], upscale_model, os.path.join("weights", "upscale"))
+        self.modelInUse = f"_{os.path.splitext(upscale_model)[0]}"
+        netscale = 1 if any(sub in upscale_model.lower() for sub in ("x1", "1x")) else (2 if any(sub in upscale_model.lower() for sub in ("x2", "2x")) else 4)
+        model = None
+        half = True if torch.cuda.is_available() else False
+        if upscale_type:
+            # The values of the following hyperparameters are based on the research findings of the Spandrel project.
+            # https://github.com/chaiNNer-org/spandrel/tree/main/libs/spandrel/spandrel/architectures
+            from basicsr.archs.rrdbnet_arch import RRDBNet
+            loadnet = torch.load(os.path.join("weights", "upscale", upscale_model), map_location=torch.device('cpu'), weights_only=True)
+            if 'params_ema' in loadnet or 'params' in loadnet:
+                loadnet = loadnet['params_ema'] if 'params_ema' in loadnet else loadnet['params']
+
+            ... [new lines 484-692: the per-architecture parameter detection and model construction for SRVGG, RRDB/ESRGAN, DAT, HAT, RealPLKSR, DRCT, ATD, MoSR and SRFormer, moved here from inference() essentially unchanged, with self.netscale replaced by the local netscale]
+        if model:
+            self.realesrganer = RealESRGANer(scale=netscale, model_path=os.path.join("weights", "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
+        elif upscale_model:
+            import PIL
+            from image_gen_aux import UpscaleWithModel
+            ... [new lines 699-740: the UpscaleWithModel_Gfpgan wrapper with its cv2pil / pil2cv / enhance helpers, also moved from inference()]
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            upscaler = UpscaleWithModel.from_pretrained(os.path.join("weights", "upscale", upscale_model)).to(device)
+            upscaler.__class__ = UpscaleWithModel_Gfpgan
+            self.realesrganer = upscaler
+
+    def initFaceEnhancerModel(self, face_restoration, face_detection, face_detection_threshold: any, face_detection_only_center: bool):
+        download_from_url(face_models[face_restoration][0], face_restoration, os.path.join("weights", "face"))
+
+        resolution = 512
+        self.modelInUse = f"_{os.path.splitext(face_restoration)[0]}" + self.modelInUse
+        from gfpgan.utils import GFPGANer
+        model_rootpath = os.path.join("weights", "face")
+        model_path = os.path.join(model_rootpath, face_restoration)
+        channel_multiplier = None
+        ... [new lines 758-773: arch, channel_multiplier and resolution selection for GFPGANv1.x, RestoreFormer(++), CodeFormer and GPEN-BFR-1024/2048, moved from inference()]
+        self.face_enhancer = GFPGANer(model_path=model_path, upscale=self.scale, arch=arch, channel_multiplier=channel_multiplier, model_rootpath=model_rootpath, det_model=face_detection, resolution=resolution)
+
+    def inference(self, gallery, face_restoration, upscale_model, scale: float, face_detection, face_detection_threshold: any, face_detection_only_center: bool, outputWithModelName: bool):
         try:
-            if not ...
+            if not gallery or (not face_restoration and not upscale_model):
                 raise ValueError("Invalid parameter setting")
+
+            print(face_restoration, upscale_model, scale, f"gallery length: {len(gallery)}")

             timer = Timer() # Create a timer
             self.scale = scale
-            self.img_name = os.path.basename(str(img))
-            self.basename, self.extension = os.path.splitext(self.img_name)
-
-            img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED) # numpy.ndarray
-
-            self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
-            if len(img.shape) == 2: # for gray inputs
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
-            self.h_input, self.w_input = img.shape[0:2]
-            self.realesrganer = None

-            modelInUse = ""
-            upscale_type = None
-            is_auto_split_upscale = True
             if upscale_model:
-                ... [old lines 485-761: model download, netscale detection and the per-architecture setup that now lives in initBGUpscaleModel()]
+                self.initBGUpscaleModel(upscale_model)
             timer.checkpoint("Initialize BG upscale model")

-            self.face_enhancer = None
-
             if face_restoration:
-                ... [old lines 767-793: the inline face-restoration setup that now lives in initFaceEnhancerModel()]
+                self.initFaceEnhancerModel(face_restoration, face_detection, face_detection_threshold, face_detection_only_center)
             timer.checkpoint("Initialize face enhancer model")

-            ...
+            timer.report()
+
             if not outputWithModelName:
-                modelInUse = ""
-            ... [old lines 799-848: the remainder of the old single-image processing code, truncated in the original diff view]
+                self.modelInUse = ""
+
+            files = []
+            is_auto_split_upscale = True
+            # Dictionary to track counters for each filename
+            name_counters = defaultdict(int)
+            for gallery_idx, value in enumerate(gallery):
+                try:
+                    img_path = str(value[0])
+                    img_name = os.path.basename(img_path)
+                    # Increment the counter for the current name
+                    name_counters[img_name] += 1
+                    if name_counters[img_name] > 1:
+                        img_name = f"{img_name}_{name_counters[img_name]:02d}"
+                    basename, extension = os.path.splitext(img_name)
+
+                    img_cv2 = cv2.imdecode(np.fromfile(img_path, np.uint8), cv2.IMREAD_UNCHANGED) # numpy.ndarray
+
+                    img_mode = "RGBA" if len(img_cv2.shape) == 3 and img_cv2.shape[2] == 4 else None
+                    if len(img_cv2.shape) == 2: # for gray inputs
+                        img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_GRAY2BGR)
+                    print(f"> image{gallery_idx:02d}, {img_cv2.shape}:")
+
+                    bg_upsample_img = None
+                    if self.realesrganer and hasattr(self.realesrganer, "enhance"):
+                        bg_upsample_img, _ = auto_split_upscale(img_cv2, self.realesrganer.enhance, self.scale) if is_auto_split_upscale else self.realesrganer.enhance(img_cv2, outscale=self.scale)
+                    timer.checkpoint(f"image{gallery_idx:02d}, Background upscale Section")
+
+                    if self.face_enhancer:
+                        cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img_cv2, has_aligned=False, only_center_face=face_detection_only_center, paste_back=True, bg_upsample_img=bg_upsample_img, eye_dist_threshold=face_detection_threshold)
+                        # save faces
+                        if cropped_faces and restored_aligned:
+                            for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_aligned)):
+                                # save cropped face
+                                save_crop_path = f"output/{basename}{idx:02d}_cropped_faces{self.modelInUse}.png"
+                                self.imwriteUTF8(save_crop_path, cropped_face)
+                                # save restored face
+                                save_restore_path = f"output/{basename}{idx:02d}_restored_faces{self.modelInUse}.png"
+                                self.imwriteUTF8(save_restore_path, restored_face)
+                                # save comparison image
+                                save_cmp_path = f"output/{basename}{idx:02d}_cmp{self.modelInUse}.png"
+                                cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
+                                self.imwriteUTF8(save_cmp_path, cmp_img)
+
+                                files.append(save_crop_path)
+                                files.append(save_restore_path)
+                                files.append(save_cmp_path)
+                    timer.checkpoint(f"image{gallery_idx:02d}, Face enhancer Section")
+
+                    restored_img = bg_upsample_img
+                    timer.report()
+
+                    if not extension:
+                        extension = ".png" if img_mode == "RGBA" else ".jpg" # RGBA images should be saved in png format
+                    save_path = f"output/{basename}{self.modelInUse}{extension}"
+                    self.imwriteUTF8(save_path, restored_img)
+
+                    restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
+                    files.append(save_path)
+                except RuntimeError as error:
+                    print(traceback.format_exc())
+                    print('Error', error)
+
+            timer.report_all() # Print all recorded times
         except Exception as error:
             print(traceback.format_exc())
             print("global exception: ", error)
             return None, None
+        finally:
+            if self.face_enhancer:
+                self.face_enhancer._cleanup()
+            else:
+                # Free GPU memory and clean up resources
+                torch.cuda.empty_cache()
+                gc.collect()
+
+        return files, files
+

     def find_max_numbers(self, state_dict, findkeys):
         if isinstance(findkeys, str):
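
For orientation, the refactored class is driven roughly as follows; the gallery entries and model names below are illustrative values of the kind the Gradio widgets pass in, not fixed defaults:

upscale = Upscale()
# Each gallery entry arrives from gr.Gallery as a (file_path, caption) pair.
gallery = [("examples/photo01.jpg", None), ("examples/photo02.jpg", None)]
output_gallery, output_files = upscale.inference(
    gallery,
    face_restoration="GFPGANv1.4.pth",                # or None to skip face restoration
    upscale_model="SRVGG, realesr-general-x4v3.pth",  # "type, filename" as listed in typed_upscale_models
    scale=4,
    face_detection="retinaface_resnet50",
    face_detection_threshold=10,
    face_detection_only_center=False,
    outputWithModelName=True,
)
# inference() initializes the models once, then loops over every image in the gallery and
# returns the list of saved result files twice (for the output gallery and the file download).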
@@ -914,12 +938,26 @@ class Timer:
         now = time.perf_counter()
         self.checkpoints.append((label, now))

-    def report(self):
+    def report(self, is_clear_checkpoints = True):
+        # Determine the max label width for alignment
+        max_label_length = max(len(label) for label, _ in self.checkpoints)
+
+        prev_time = self.checkpoints[0][1]
+        for label, curr_time in self.checkpoints[1:]:
+            elapsed = curr_time - prev_time
+            print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
+            prev_time = curr_time
+
+        if is_clear_checkpoints:
+            self.checkpoints.clear()
+            self.checkpoint() # Store checkpoints
+
+    def report_all(self):
         """Print all recorded checkpoints and total execution time with aligned formatting."""
         print("\n> Execution Time Report:")

         # Determine the max label width for alignment
-        max_label_length = max(len(label) for label, _ in self.checkpoints)
+        max_label_length = max(len(label) for label, _ in self.checkpoints) if len(self.checkpoints) > 0 else 0

         prev_time = self.start_time
         for label, curr_time in self.checkpoints[1:]:
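
With this split, report() prints the per-step timings collected so far (and by default clears them so the next image starts fresh), while report_all() prints the overall summary. A typical sequence, mirroring how inference() uses it:

timer = Timer()                                    # starts the clock, as restart() does
timer.checkpoint("Initialize BG upscale model")
timer.checkpoint("image00, Background upscale Section")
timer.report()                                     # per-step durations for image00, then clears
timer.checkpoint("image01, Background upscale Section")
timer.report()
timer.report_all()                                 # final report including Total Execution Time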
@@ -930,6 +968,112 @@ class Timer:
         total_time = self.checkpoints[-1][1] - self.start_time
         print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")

+        self.checkpoints.clear()
+
+    def restart(self):
+        self.start_time = time.perf_counter() # Record the start time
+        self.checkpoints = [("Start", self.start_time)] # Store checkpoints
+
+
+def get_selection_from_gallery(selected_state: gr.SelectData):
+    """Extracts the selected image path and caption from the gallery selection state."""
+    if not selected_state:
+        return selected_state
+
+    return (selected_state.value["image"]["path"], selected_state.value["caption"])
+
+def limit_gallery(gallery):
+    """Ensures the gallery does not exceed input_images_limit."""
+    return gallery[:input_images_limit] if input_images_limit > 0 and gallery else gallery
+
+def append_gallery(gallery: list, image: str):
+    """Append a single image to the gallery while respecting input_images_limit."""
+    if gallery is None:
+        gallery = []
+    if not image:
+        return gallery, None
+
+    if input_images_limit == -1 or len(gallery) < input_images_limit:
+        gallery.append(image)
+
+    return gallery, None
+
+
+def extend_gallery(gallery: list, images):
+    """Extend the gallery with new images while respecting the input_images_limit."""
+    if gallery is None:
+        gallery = []
+    if not images:
+        return gallery
+
+    # Add new images to the gallery
+    gallery.extend(images)
+
+    # Trim gallery to the specified limit, if applicable
+    if input_images_limit > 0:
+        gallery = gallery[:input_images_limit]
+
+    return gallery
+
+def remove_image_from_gallery(gallery: list, selected_image: str):
+    """Removes a selected image from the gallery if it exists."""
+    if not gallery or not selected_image:
+        return gallery
+
+    selected_image = ast.literal_eval(selected_image) # Use ast.literal_eval to parse text into a tuple in remove_image_from_gallery.
+    # Remove the selected image from the gallery
+    if selected_image in gallery:
+        gallery.remove(selected_image)
+    return gallery
+
 def main():
     if torch.cuda.is_available():
         torch.cuda.set_per_process_memory_fraction(0.975, device='cuda:0')
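
These helpers are plain functions over the gallery value, so their behaviour with the default limit of 5 can be illustrated directly (file names are illustrative):

gallery = extend_gallery([], ["01.png", "02.png", "03.png", "04.png", "05.png", "06.png"])
print(len(gallery))                       # 5, trimmed to input_images_limit
gallery, _ = append_gallery(gallery, "07.png")
print(len(gallery))                       # still 5, append refuses once the limit is reached
print(limit_gallery(gallery) == gallery)  # True, already within the limit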
@@ -981,20 +1125,38 @@ def main():
     with gr.Blocks(title = title, css = css) as demo:
         gr.Markdown(value=f"<h1 style=\"text-align:center;\">{title}</h1><br>{description}")
         with gr.Row():
-            with gr.Column(variant...
-                ...
+            with gr.Column(variant="panel"):
+                submit = gr.Button(value="Submit", variant="primary", size="lg")
+                # Create an Image component for uploading images
+                input_image = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", format="png", height=150)
+                with gr.Row():
+                    upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
+                    remove_button = gr.Button("Remove Selected Image", size="sm")
+                input_gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images" + (f"(The online environment image limit is {input_images_limit})" if input_images_limit > 0 else ""))
                 face_model = gr.Dropdown([None]+list(face_models.keys()), type="value", value='GFPGANv1.4.pth', label='Face Restoration version', info="Face Restoration and RealESR can be freely combined in different ways, or one can be set to \"None\" to use only the other model. Face Restoration is primarily used for face restoration in real-life images, while RealESR serves as a background restoration model.")
                 upscale_model = gr.Dropdown([None]+list(typed_upscale_models.keys()), type="value", value='SRVGG, realesr-general-x4v3.pth', label='UpScale version')
                 upscale_scale = gr.Number(label="Rescaling factor", value=4)
                 face_detection = gr.Dropdown(["retinaface_resnet50", "YOLOv5l", "YOLOv5n"], type="value", value="retinaface_resnet50", label="Face Detection type")
                 face_detection_threshold = gr.Number(label="Face eye dist threshold", value=10, info="A threshold to filter out faces with too small an eye distance (e.g., side faces).")
                 face_detection_only_center = gr.Checkbox(value=False, label="Face detection only center", info="If set to True, only the face closest to the center of the image will be kept.")
-                with_model_name ...
+                with_model_name = gr.Checkbox(label="Output image files name with model name", value=True)
+
+                # Define the event listener to add the uploaded image to the gallery
+                input_image.change(append_gallery, inputs=[input_gallery, input_image], outputs=[input_gallery, input_image])
+                # When the upload button is clicked, add the new images to the gallery
+                upload_button.upload(extend_gallery, inputs=[input_gallery, upload_button], outputs=input_gallery)
+                # Event to update the selected image when an image is clicked in the gallery
+                selected_image = gr.Textbox(label="Selected Image", visible=False)
+                input_gallery.select(get_selection_from_gallery, inputs=None, outputs=selected_image)
+                # Trigger update when gallery changes
+                input_gallery.change(limit_gallery, input_gallery, input_gallery)
+                # Event to remove a selected image from the gallery
+                remove_button.click(remove_image_from_gallery, inputs=[input_gallery, selected_image], outputs=input_gallery)
+
                 with gr.Row():
-                    submit = gr.Button(value="Submit", variant="primary", size="lg")
                     clear = gr.ClearButton(
                         components=[
-                            ...
+                            input_gallery,
                             face_model,
                             upscale_model,
                             upscale_scale,
@@ -1029,7 +1191,7 @@ def main():
         submit.click(
             upscale.inference,
             inputs=[
-                ...
+                input_gallery,
                 face_model,
                 upscale_model,
                 upscale_scale,
@@ -1046,4 +1208,8 @@


 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input_images_limit", type=int, default=5)
+    args = parser.parse_args()
+    input_images_limit = args.input_images_limit
     main()
requirements.txt
CHANGED
@@ -1,6 +1,6 @@
 --extra-index-url https://download.pytorch.org/whl/cu124

-gradio==5.
+gradio==5.16.0

 basicsr @ git+https://github.com/avan06/BasicSR
 facexlib @ git+https://github.com/avan06/facexlib
utils/dataops.py
CHANGED
@@ -53,14 +53,10 @@ def auto_split_upscale(
             # Check to see if its actually the CUDA out of memory error
             if "CUDA" in str(e):
                 print("RuntimeError: CUDA out of memory...")
-                # Collect garbage (clear VRAM)
-                torch.cuda.empty_cache()
-                gc.collect()
             # Re-raise the exception if not an OOM error
             else:
                 raise RuntimeError(e)
-
-        # Free GPU memory and clean up resources
+        # Collect garbage (clear VRAM)
         torch.cuda.empty_cache()
         gc.collect()

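
The intent of this change is that the VRAM cleanup is no longer duplicated inside the CUDA out-of-memory branch; it now runs once after the error handling, on both the success path and the recovered-OOM path. A minimal sketch of that control flow, using a hypothetical helper rather than the actual auto_split_upscale signature:

import gc
import torch

def run_with_cleanup(task):
    try:
        return task()
    except RuntimeError as e:
        if "CUDA" in str(e):
            print("RuntimeError: CUDA out of memory...")
            return None  # signal the caller to retry on smaller tiles
        raise
    finally:
        # Collect garbage (clear VRAM) exactly once, whatever happened above
        torch.cuda.empty_cache()
        gc.collect()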
webui.bat
CHANGED
@@ -1,6 +1,7 @@
 @echo off

 :: The source of the webui.bat file is stable-diffusion-webui
+set COMMANDLINE_ARGS=--input_images_limit -1

 if not defined PYTHON (set PYTHON=python)
 if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")