avans06 committed
Commit a01cc5b · 1 Parent(s): 20b04f8

Added support for using a gallery as the image input.


You can now upload multiple images at once. Because execution in the online environment is slow, the default image limit is set to 5. You can change this limit with the --input_images_limit option; setting it to -1 removes the limit.

e.g.
python app.py --input_images_limit -1
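
For reference, the limit is enforced in app.py by an --input_images_limit argument plus helpers that trim the Gradio gallery whenever it changes. The sketch below condenses that logic from this commit; it is illustrative, not a drop-in module.

import argparse

input_images_limit = 5  # default used by the online demo

def limit_gallery(gallery):
    # Keep at most input_images_limit images; a non-positive limit disables trimming.
    return gallery[:input_images_limit] if input_images_limit > 0 and gallery else gallery

def append_gallery(gallery, image):
    # Add a single image only while the gallery is below the limit (-1 means unlimited).
    # The second return value clears the single-image input component in the UI.
    if gallery is None:
        gallery = []
    if image and (input_images_limit == -1 or len(gallery) < input_images_limit):
        gallery.append(image)
    return gallery, None

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_images_limit", type=int, default=5)
    input_images_limit = parser.parse_args().input_images_limit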

Files changed (6)
  1. .gitignore +1 -0
  2. README.md +1 -1
  3. app.py +551 -385
  4. requirements.txt +1 -1
  5. utils/dataops.py +1 -5
  6. webui.bat +1 -0
.gitignore CHANGED
@@ -5,6 +5,7 @@ results/*
 tb_logger/*
 wandb/*
 tmp/*
+ /venv
 
 version.py
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📈
 colorFrom: blue
 colorTo: gray
 sdk: gradio
- sdk_version: 5.15.0
+ sdk_version: 5.16.0
 app_file: app.py
 pinned: true
 license: apache-2.0
app.py CHANGED
@@ -8,11 +8,14 @@ import torch
8
  import traceback
9
  import math
10
  import time
 
 
11
  from collections import defaultdict
12
  from facexlib.utils.misc import download_from_url
13
  from basicsr.utils.realesrganer import RealESRGANer
 
14
 
15
-
16
  # Define URLs and their corresponding local storage paths
17
  face_models = {
18
  "GFPGANv1.4.pth" : ["https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
@@ -457,399 +460,420 @@ typed_upscale_models = {get_model_type(key): value[0] for key, value in upscale_
457
 
458
 
459
  class Upscale:
460
- def inference(self, img, face_restoration, upscale_model, scale: float, face_detection, face_detection_threshold: any, face_detection_only_center: bool, outputWithModelName: bool):
461
- print(img)
462
- print(face_restoration, upscale_model, scale)
463
  try:
464
- if not img or (not face_restoration and not upscale_model):
465
  raise ValueError("Invalid parameter setting")
 
 
466
 
467
  timer = Timer() # Create a timer
468
  self.scale = scale
469
- self.img_name = os.path.basename(str(img))
470
- self.basename, self.extension = os.path.splitext(self.img_name)
471
-
472
- img = cv2.imdecode(np.fromfile(img, np.uint8), cv2.IMREAD_UNCHANGED) # numpy.ndarray
473
-
474
- self.img_mode = "RGBA" if len(img.shape) == 3 and img.shape[2] == 4 else None
475
- if len(img.shape) == 2: # for gray inputs
476
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
477
-
478
- self.h_input, self.w_input = img.shape[0:2]
479
- self.realesrganer = None
480
 
481
- modelInUse = ""
482
- upscale_type = None
483
- is_auto_split_upscale = True
484
  if upscale_model:
485
- upscale_type, upscale_model = upscale_model.split(", ", 1)
486
- download_from_url(upscale_models[upscale_model][0], upscale_model, os.path.join("weights", "upscale"))
487
- modelInUse = f"_{os.path.splitext(upscale_model)[0]}"
488
-
489
- self.netscale = 1 if any(sub in upscale_model.lower() for sub in ("x1", "1x")) else (2 if any(sub in upscale_model.lower() for sub in ("x2", "2x")) else 4)
490
- model = None
491
- half = True if torch.cuda.is_available() else False
492
- if upscale_type:
493
- # The values of the following hyperparameters are based on the research findings of the Spandrel project.
494
- # https://github.com/chaiNNer-org/spandrel/tree/main/libs/spandrel/spandrel/architectures
495
- from basicsr.archs.rrdbnet_arch import RRDBNet
496
- loadnet = torch.load(os.path.join("weights", "upscale", upscale_model), map_location=torch.device('cpu'), weights_only=True)
497
- if 'params_ema' in loadnet or 'params' in loadnet:
498
- loadnet = loadnet['params_ema'] if 'params_ema' in loadnet else loadnet['params']
499
-
500
- if upscale_type == "SRVGG":
501
- from basicsr.archs.srvgg_arch import SRVGGNetCompact
502
- body_max_num = self.find_max_numbers(loadnet, "body")
503
- num_feat = loadnet["body.0.weight"].shape[0]
504
- num_in_ch = loadnet["body.0.weight"].shape[1]
505
- num_conv = body_max_num // 2 - 1
506
- model = SRVGGNetCompact(num_in_ch=num_in_ch, num_out_ch=3, num_feat=num_feat, num_conv=num_conv, upscale=self.netscale, act_type='prelu')
507
- elif upscale_type == "RRDB" or upscale_type == "ESRGAN":
508
- if upscale_type == "RRDB":
509
- num_block = self.find_max_numbers(loadnet, "body") + 1
510
- num_feat = loadnet["conv_first.weight"].shape[0]
511
- else:
512
- num_block = self.find_max_numbers(loadnet, "model.1.sub")
513
- num_feat = loadnet["model.0.weight"].shape[0]
514
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=num_feat, num_block=num_block, num_grow_ch=32, scale=self.netscale, is_real_esrgan=upscale_type == "RRDB")
515
- elif upscale_type == "DAT":
516
- from basicsr.archs.dat_arch import DAT
517
- half = False
518
-
519
- in_chans = loadnet["conv_first.weight"].shape[1]
520
- embed_dim = loadnet["conv_first.weight"].shape[0]
521
- num_layers = self.find_max_numbers(loadnet, "layers") + 1
522
- depth = [6] * num_layers
523
- num_heads = [6] * num_layers
524
- for i in range(num_layers):
525
- depth[i] = self.find_max_numbers(loadnet, f"layers.{i}.blocks") + 1
526
- num_heads[i] = loadnet[f"layers.{i}.blocks.1.attn.temperature"].shape[0] if depth[i] >= 2 else \
527
- loadnet[f"layers.{i}.blocks.0.attn.attns.0.pos.pos3.2.weight"].shape[0] * 2
528
-
529
- upsampler = "pixelshuffle" if "conv_last.weight" in loadnet else "pixelshuffledirect"
530
- resi_connection = "1conv" if "conv_after_body.weight" in loadnet else "3conv"
531
- qkv_bias = "layers.0.blocks.0.attn.qkv.bias" in loadnet
532
- expansion_factor = float(loadnet["layers.0.blocks.0.ffn.fc1.weight"].shape[0] / embed_dim)
533
-
534
- img_size = 64
535
- if "layers.0.blocks.2.attn.attn_mask_0" in loadnet:
536
- attn_mask_0_x, attn_mask_0_y, _attn_mask_0_z = loadnet["layers.0.blocks.2.attn.attn_mask_0"].shape
537
- img_size = int(math.sqrt(attn_mask_0_x * attn_mask_0_y))
538
-
539
- split_size = [2, 4]
540
- if "layers.0.blocks.0.attn.attns.0.rpe_biases" in loadnet:
541
- split_sizes = loadnet["layers.0.blocks.0.attn.attns.0.rpe_biases"][-1] + 1
542
- split_size = [int(x) for x in split_sizes]
543
-
544
- model = DAT(img_size=img_size, in_chans=in_chans, embed_dim=embed_dim, split_size=split_size, depth=depth, num_heads=num_heads, expansion_factor=expansion_factor,
545
- qkv_bias=qkv_bias, resi_connection=resi_connection, upsampler=upsampler, upscale=self.netscale)
546
- elif upscale_type == "HAT":
547
- half = False
548
- from basicsr.archs.hat_arch import HAT
549
- in_chans = loadnet["conv_first.weight"].shape[1]
550
- embed_dim = loadnet["conv_first.weight"].shape[0]
551
- window_size = int(math.sqrt(loadnet["relative_position_index_SA"].shape[0]))
552
- num_layers = self.find_max_numbers(loadnet, "layers") + 1
553
- depths = [6] * num_layers
554
- num_heads = [6] * num_layers
555
- for i in range(num_layers):
556
- depths[i] = self.find_max_numbers(loadnet, f"layers.{i}.residual_group.blocks") + 1
557
- num_heads[i] = loadnet[f"layers.{i}.residual_group.overlap_attn.relative_position_bias_table"].shape[1]
558
- resi_connection = "1conv" if "conv_after_body.weight" in loadnet else "identity"
559
-
560
- compress_ratio = self.find_divisor_for_quotient(embed_dim, loadnet["layers.0.residual_group.blocks.0.conv_block.cab.0.weight"].shape[0],)
561
- squeeze_factor = self.find_divisor_for_quotient(embed_dim, loadnet["layers.0.residual_group.blocks.0.conv_block.cab.3.attention.1.weight"].shape[0],)
562
-
563
- qkv_bias = "layers.0.residual_group.blocks.0.attn.qkv.bias" in loadnet
564
- patch_norm = "patch_embed.norm.weight" in loadnet
565
- ape = "absolute_pos_embed" in loadnet
566
-
567
- mlp_hidden_dim = int(loadnet["layers.0.residual_group.blocks.0.mlp.fc1.weight"].shape[0])
568
- mlp_ratio = mlp_hidden_dim / embed_dim
569
- upsampler = "pixelshuffle"
570
-
571
- model = HAT(img_size=64, patch_size=1, in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, compress_ratio=compress_ratio,
572
- squeeze_factor=squeeze_factor, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, ape=ape, patch_norm=patch_norm,
573
- upsampler=upsampler, resi_connection=resi_connection, upscale=self.netscale,)
574
- elif "RealPLKSR" in upscale_type:
575
- from basicsr.archs.realplksr_arch import realplksr
576
- half = False if "RealPLSKR" in upscale_model else half
577
- use_ea = "feats.1.attn.f.0.weight" in loadnet
578
- dim = loadnet["feats.0.weight"].shape[0]
579
- num_feats = self.find_max_numbers(loadnet, "feats") + 1
580
- n_blocks = num_feats - 3
581
- kernel_size = loadnet["feats.1.lk.conv.weight"].shape[2]
582
- split_ratio = loadnet["feats.1.lk.conv.weight"].shape[0] / dim
583
- use_dysample = "to_img.init_pos" in loadnet
584
-
585
- model = realplksr(upscaling_factor=self.netscale, dim=dim, n_blocks=n_blocks, kernel_size=kernel_size, split_ratio=split_ratio, use_ea=use_ea, dysample=use_dysample)
586
- elif upscale_type == "DRCT":
587
- half = False
588
- from basicsr.archs.DRCT_arch import DRCT
589
-
590
- in_chans = loadnet["conv_first.weight"].shape[1]
591
- embed_dim = loadnet["conv_first.weight"].shape[0]
592
- num_layers = self.find_max_numbers(loadnet, "layers") + 1
593
- depths = (6,) * num_layers
594
- num_heads = []
595
- for i in range(num_layers):
596
- num_heads.append(loadnet[f"layers.{i}.swin1.attn.relative_position_bias_table"].shape[1])
597
-
598
- mlp_ratio = loadnet["layers.0.swin1.mlp.fc1.weight"].shape[0] / embed_dim
599
- window_square = loadnet["layers.0.swin1.attn.relative_position_bias_table"].shape[0]
600
- window_size = (math.isqrt(window_square) + 1) // 2
601
- upsampler = "pixelshuffle" if "conv_last.weight" in loadnet else ""
602
- resi_connection = "1conv" if "conv_after_body.weight" in loadnet else ""
603
- qkv_bias = "layers.0.swin1.attn.qkv.bias" in loadnet
604
- gc_adjust1 = loadnet["layers.0.adjust1.weight"].shape[0]
605
- patch_norm = "patch_embed.norm.weight" in loadnet
606
- ape = "absolute_pos_embed" in loadnet
607
-
608
- model = DRCT(in_chans=in_chans, img_size= 64, window_size=window_size, compress_ratio=3,squeeze_factor=30,
609
- conv_scale= 0.01, overlap_ratio= 0.5, img_range= 1., depths=depths, embed_dim=embed_dim, num_heads=num_heads,
610
- mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, ape=ape, patch_norm=patch_norm, use_checkpoint=False,
611
- upscale=self.netscale, upsampler=upsampler, resi_connection=resi_connection, gc =gc_adjust1,)
612
- elif upscale_type == "ATD":
613
- half = False
614
- from basicsr.archs.atd_arch import ATD
615
- in_chans = loadnet["conv_first.weight"].shape[1]
616
- embed_dim = loadnet["conv_first.weight"].shape[0]
617
- window_size = math.isqrt(loadnet["relative_position_index_SA"].shape[0])
618
- num_layers = self.find_max_numbers(loadnet, "layers") + 1
619
- depths = [6] * num_layers
620
- num_heads = [6] * num_layers
621
- for i in range(num_layers):
622
- depths[i] = self.find_max_numbers(loadnet, f"layers.{i}.residual_group.layers") + 1
623
- num_heads[i] = loadnet[f"layers.{i}.residual_group.layers.0.attn_win.relative_position_bias_table"].shape[1]
624
- num_tokens = loadnet["layers.0.residual_group.layers.0.attn_atd.scale"].shape[0]
625
- reducted_dim = loadnet["layers.0.residual_group.layers.0.attn_atd.wq.weight"].shape[0]
626
- convffn_kernel_size = loadnet["layers.0.residual_group.layers.0.convffn.dwconv.depthwise_conv.0.weight"].shape[2]
627
- mlp_ratio = (loadnet["layers.0.residual_group.layers.0.convffn.fc1.weight"].shape[0] / embed_dim)
628
- qkv_bias = "layers.0.residual_group.layers.0.wqkv.bias" in loadnet
629
- ape = "absolute_pos_embed" in loadnet
630
- patch_norm = "patch_embed.norm.weight" in loadnet
631
- resi_connection = "1conv" if "layers.0.conv.weight" in loadnet else "3conv"
632
-
633
- if "conv_up1.weight" in loadnet:
634
- upsampler = "nearest+conv"
635
- elif "conv_before_upsample.0.weight" in loadnet:
636
- upsampler = "pixelshuffle"
637
- elif "conv_last.weight" in loadnet:
638
- upsampler = ""
639
- else:
640
- upsampler = "pixelshuffledirect"
641
-
642
- is_light = upsampler == "pixelshuffledirect" and embed_dim == 48
643
- category_size = 128 if is_light else 256
644
-
645
- model = ATD(in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, category_size=category_size,
646
- num_tokens=num_tokens, reducted_dim=reducted_dim, convffn_kernel_size=convffn_kernel_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, ape=ape,
647
- patch_norm=patch_norm, use_checkpoint=False, upscale=self.netscale, upsampler=upsampler, resi_connection='1conv',)
648
- elif upscale_type == "MoSR":
649
- from basicsr.archs.mosr_arch import mosr
650
- n_block = self.find_max_numbers(loadnet, "gblocks") - 5
651
- in_ch = loadnet["gblocks.0.weight"].shape[1]
652
- out_ch = loadnet["upsampler.end_conv.weight"].shape[0] if "upsampler.init_pos" in loadnet else in_ch
653
- dim = loadnet["gblocks.0.weight"].shape[0]
654
- expansion_ratio = (loadnet["gblocks.1.fc1.weight"].shape[0] / loadnet["gblocks.1.fc1.weight"].shape[1]) / 2
655
- conv_ratio = loadnet["gblocks.1.conv.weight"].shape[0] / dim
656
- kernel_size = loadnet["gblocks.1.conv.weight"].shape[2]
657
- upsampler = "dys" if "upsampler.init_pos" in loadnet else ("gps" if "upsampler.in_to_k.weight" in loadnet else "ps")
658
-
659
- model = mosr(in_ch = in_ch, out_ch = out_ch, upscale = self.netscale, n_block = n_block, dim = dim,
660
- upsampler = upsampler, kernel_size = kernel_size, expansion_ratio = expansion_ratio, conv_ratio = conv_ratio,)
661
- elif upscale_type == "SRFormer":
662
- half = False
663
- from basicsr.archs.srformer_arch import SRFormer
664
- in_chans = loadnet["conv_first.weight"].shape[1]
665
- embed_dim = loadnet["conv_first.weight"].shape[0]
666
- ape = "absolute_pos_embed" in loadnet
667
- patch_norm = "patch_embed.norm.weight" in loadnet
668
- qkv_bias = "layers.0.residual_group.blocks.0.attn.q.bias" in loadnet
669
- mlp_ratio = float(loadnet["layers.0.residual_group.blocks.0.mlp.fc1.weight"].shape[0] / embed_dim)
670
-
671
- num_layers = self.find_max_numbers(loadnet, "layers") + 1
672
- depths = [6] * num_layers
673
- num_heads = [6] * num_layers
674
- for i in range(num_layers):
675
- depths[i] = self.find_max_numbers(loadnet, f"layers.{i}.residual_group.blocks") + 1
676
- num_heads[i] = loadnet[f"layers.{i}.residual_group.blocks.0.attn.relative_position_bias_table"].shape[1]
677
-
678
- if "conv_hr.weight" in loadnet:
679
- upsampler = "nearest+conv"
680
- elif "conv_before_upsample.0.weight" in loadnet:
681
- upsampler = "pixelshuffle"
682
- elif "upsample.0.weight" in loadnet:
683
- upsampler = "pixelshuffledirect"
684
- resi_connection = "1conv" if "conv_after_body.weight" in loadnet else "3conv"
685
-
686
- window_size = int(math.sqrt(loadnet["layers.0.residual_group.blocks.0.attn.relative_position_bias_table"].shape[0])) + 1
687
-
688
- if "layers.0.residual_group.blocks.1.attn_mask" in loadnet:
689
- attn_mask_0 = loadnet["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
690
- patches_resolution = int(math.sqrt(attn_mask_0) * window_size)
691
- else:
692
- patches_resolution = window_size
693
- if ape:
694
- pos_embed_value = loadnet.get("absolute_pos_embed", [None, None])[1]
695
- if pos_embed_value:
696
- patches_resolution = int(math.sqrt(pos_embed_value))
697
-
698
- img_size = patches_resolution
699
- if img_size % window_size != 0:
700
- for nice_number in [512, 256, 128, 96, 64, 48, 32, 24, 16]:
701
- if nice_number % window_size != 0:
702
- nice_number += window_size - (nice_number % window_size)
703
- if nice_number == patches_resolution:
704
- img_size = nice_number
705
- break
706
-
707
- model = SRFormer(img_size=img_size, in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
708
- qkv_bias=qkv_bias, qk_scale=None, ape=ape, patch_norm=patch_norm, upscale=self.netscale, upsampler=upsampler, resi_connection=resi_connection,)
709
-
710
- if model:
711
- self.realesrganer = RealESRGANer(scale=self.netscale, model_path=os.path.join("weights", "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
712
- elif upscale_model:
713
- import PIL
714
- from image_gen_aux import UpscaleWithModel
715
- class UpscaleWithModel_Gfpgan(UpscaleWithModel):
716
- def cv2pil(self, image):
717
- ''' OpenCV type -> PIL type
718
- https://qiita.com/derodero24/items/f22c22b22451609908ee
719
- '''
720
- new_image = image.copy()
721
- if new_image.ndim == 2: # Grayscale
722
- pass
723
- elif new_image.shape[2] == 3: # Color
724
- new_image = cv2.cvtColor(new_image, cv2.COLOR_BGR2RGB)
725
- elif new_image.shape[2] == 4: # Transparency
726
- new_image = cv2.cvtColor(new_image, cv2.COLOR_BGRA2RGBA)
727
- new_image = PIL.Image.fromarray(new_image)
728
- return new_image
729
-
730
- def pil2cv(self, image):
731
- ''' PIL type -> OpenCV type
732
- https://qiita.com/derodero24/items/f22c22b22451609908ee
733
- '''
734
- new_image = np.array(image, dtype=np.uint8)
735
- if new_image.ndim == 2: # Grayscale
736
- pass
737
- elif new_image.shape[2] == 3: # Color
738
- new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
739
- elif new_image.shape[2] == 4: # Transparency
740
- new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
741
- return new_image
742
-
743
- def enhance(self_, img, outscale=None):
744
- # img: numpy
745
- h_input, w_input = img.shape[0:2]
746
- pil_img = self_.cv2pil(img)
747
- pil_img = self_.__call__(pil_img)
748
- cv_image = self_.pil2cv(pil_img)
749
- if outscale is not None and outscale != float(self.netscale):
750
- interpolation = cv2.INTER_AREA if outscale < float(self.netscale) else cv2.INTER_LANCZOS4
751
- cv_image = cv2.resize(
752
- cv_image, (
753
- int(w_input * outscale),
754
- int(h_input * outscale),
755
- ), interpolation=interpolation)
756
- return cv_image, None
757
-
758
- device = "cuda" if torch.cuda.is_available() else "cpu"
759
- upscaler = UpscaleWithModel.from_pretrained(os.path.join("weights", "upscale", upscale_model)).to(device)
760
- upscaler.__class__ = UpscaleWithModel_Gfpgan
761
- self.realesrganer = upscaler
762
  timer.checkpoint("Initialize BG upscale model")
763
 
764
- self.face_enhancer = None
765
-
766
  if face_restoration:
767
- download_from_url(face_models[face_restoration][0], face_restoration, os.path.join("weights", "face"))
768
-
769
- resolution = 512
770
- modelInUse = f"_{os.path.splitext(face_restoration)[0]}" + modelInUse
771
- from gfpgan.utils import GFPGANer
772
- model_rootpath = os.path.join("weights", "face")
773
- model_path = os.path.join(model_rootpath, face_restoration)
774
- channel_multiplier = None
775
-
776
- if face_restoration and face_restoration.startswith("GFPGANv1."):
777
- arch = "clean"
778
- channel_multiplier = 2
779
- elif face_restoration and face_restoration.startswith("RestoreFormer"):
780
- arch = "RestoreFormer++" if face_restoration.startswith("RestoreFormer++") else "RestoreFormer"
781
- elif face_restoration == 'CodeFormer.pth':
782
- arch = "CodeFormer"
783
- elif face_restoration.startswith("GPEN-BFR-"):
784
- arch = "GPEN"
785
- channel_multiplier = 2
786
- if "1024" in face_restoration:
787
- arch = "GPEN-1024"
788
- resolution = 1024
789
- elif "2048" in face_restoration:
790
- arch = "GPEN-2048"
791
- resolution = 2048
792
-
793
- self.face_enhancer = GFPGANer(model_path=model_path, upscale=self.scale, arch=arch, channel_multiplier=channel_multiplier, model_rootpath=model_rootpath, det_model=face_detection, resolution=resolution)
794
  timer.checkpoint("Initialize face enhancer model")
795
 
796
- files = []
 
797
  if not outputWithModelName:
798
- modelInUse = ""
799
-
800
- try:
801
- bg_upsample_img = None
802
- if self.realesrganer and hasattr(self.realesrganer, "enhance"):
803
- from utils.dataops import auto_split_upscale
804
- bg_upsample_img, _ = auto_split_upscale(img, self.realesrganer.enhance, self.scale) if is_auto_split_upscale else self.realesrganer.enhance(img, outscale=self.scale)
805
- timer.checkpoint("Background upscale Section")
806
-
807
- if self.face_enhancer:
808
- cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img, has_aligned=False, only_center_face=face_detection_only_center, paste_back=True, bg_upsample_img=bg_upsample_img, eye_dist_threshold=face_detection_threshold)
809
- # save faces
810
- if cropped_faces and restored_aligned:
811
- for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_aligned)):
812
- # save cropped face
813
- save_crop_path = f"output/{self.basename}{idx:02d}_cropped_faces{modelInUse}.png"
814
- self.imwriteUTF8(save_crop_path, cropped_face)
815
- # save restored face
816
- save_restore_path = f"output/{self.basename}{idx:02d}_restored_faces{modelInUse}.png"
817
- self.imwriteUTF8(save_restore_path, restored_face)
818
- # save comparison image
819
- save_cmp_path = f"output/{self.basename}{idx:02d}_cmp{modelInUse}.png"
820
- cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
821
- self.imwriteUTF8(save_cmp_path, cmp_img)
822
-
823
- files.append(save_crop_path)
824
- files.append(save_restore_path)
825
- files.append(save_cmp_path)
826
- timer.checkpoint("Face enhancer Section")
827
-
828
- restored_img = bg_upsample_img
829
- except RuntimeError as error:
830
- print(traceback.format_exc())
831
- print('Error', error)
832
- finally:
833
- if self.face_enhancer:
834
- self.face_enhancer._cleanup()
835
- else:
836
- # Free GPU memory and clean up resources
837
- torch.cuda.empty_cache()
838
- gc.collect()
839
-
840
- if not self.extension:
841
- self.extension = ".png" if self.img_mode == "RGBA" else ".jpg" # RGBA images should be saved in png format
842
- save_path = f"output/{self.basename}{modelInUse}{self.extension}"
843
- self.imwriteUTF8(save_path, restored_img)
844
-
845
- restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
846
- files.append(save_path)
847
- timer.report() # Print all recorded times
848
- return files, files
849
  except Exception as error:
850
  print(traceback.format_exc())
851
  print("global exception: ", error)
852
  return None, None
853
 
854
  def find_max_numbers(self, state_dict, findkeys):
855
  if isinstance(findkeys, str):
@@ -914,12 +938,26 @@ class Timer:
914
  now = time.perf_counter()
915
  self.checkpoints.append((label, now))
916
 
917
- def report(self):
918
  """Print all recorded checkpoints and total execution time with aligned formatting."""
919
  print("\n> Execution Time Report:")
920
 
921
  # Determine the max label width for alignment
922
- max_label_length = max(len(label) for label, _ in self.checkpoints)
923
 
924
  prev_time = self.start_time
925
  for label, curr_time in self.checkpoints[1:]:
@@ -930,6 +968,112 @@ class Timer:
930
  total_time = self.checkpoints[-1][1] - self.start_time
931
  print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
932
 
933
  def main():
934
  if torch.cuda.is_available():
935
  torch.cuda.set_per_process_memory_fraction(0.975, device='cuda:0')
@@ -981,20 +1125,38 @@ def main():
981
  with gr.Blocks(title = title, css = css) as demo:
982
  gr.Markdown(value=f"<h1 style=\"text-align:center;\">{title}</h1><br>{description}")
983
  with gr.Row():
984
- with gr.Column(variant ="panel"):
985
- input_image = gr.Image(type="filepath", label="Input", format="png")
986
  face_model = gr.Dropdown([None]+list(face_models.keys()), type="value", value='GFPGANv1.4.pth', label='Face Restoration version', info="Face Restoration and RealESR can be freely combined in different ways, or one can be set to \"None\" to use only the other model. Face Restoration is primarily used for face restoration in real-life images, while RealESR serves as a background restoration model.")
987
  upscale_model = gr.Dropdown([None]+list(typed_upscale_models.keys()), type="value", value='SRVGG, realesr-general-x4v3.pth', label='UpScale version')
988
  upscale_scale = gr.Number(label="Rescaling factor", value=4)
989
  face_detection = gr.Dropdown(["retinaface_resnet50", "YOLOv5l", "YOLOv5n"], type="value", value="retinaface_resnet50", label="Face Detection type")
990
  face_detection_threshold = gr.Number(label="Face eye dist threshold", value=10, info="A threshold to filter out faces with too small an eye distance (e.g., side faces).")
991
  face_detection_only_center = gr.Checkbox(value=False, label="Face detection only center", info="If set to True, only the face closest to the center of the image will be kept.")
992
- with_model_name = gr.Checkbox(label="Output image files name with model name", value=True)
993
  with gr.Row():
994
- submit = gr.Button(value="Submit", variant="primary", size="lg")
995
  clear = gr.ClearButton(
996
  components=[
997
- input_image,
998
  face_model,
999
  upscale_model,
1000
  upscale_scale,
@@ -1029,7 +1191,7 @@ def main():
1029
  submit.click(
1030
  upscale.inference,
1031
  inputs=[
1032
- input_image,
1033
  face_model,
1034
  upscale_model,
1035
  upscale_scale,
@@ -1046,4 +1208,8 @@ def main():
1046
 
1047
 
1048
  if __name__ == "__main__":
1049
  main()
 
8
  import traceback
9
  import math
10
  import time
11
+ import ast
12
+ import argparse
13
  from collections import defaultdict
14
  from facexlib.utils.misc import download_from_url
15
  from basicsr.utils.realesrganer import RealESRGANer
16
+ from utils.dataops import auto_split_upscale
17
 
18
+ input_images_limit = 5
19
  # Define URLs and their corresponding local storage paths
20
  face_models = {
21
  "GFPGANv1.4.pth" : ["https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
 
460
 
461
 
462
  class Upscale:
463
+ def __init__(self,):
464
+ self.scale = 4
465
+ self.modelInUse = ""
466
+ self.realesrganer = None
467
+ self.face_enhancer = None
468
+
469
+ def initBGUpscaleModel(self, upscale_model):
470
+ upscale_type, upscale_model = upscale_model.split(", ", 1)
471
+ download_from_url(upscale_models[upscale_model][0], upscale_model, os.path.join("weights", "upscale"))
472
+ self.modelInUse = f"_{os.path.splitext(upscale_model)[0]}"
473
+ netscale = 1 if any(sub in upscale_model.lower() for sub in ("x1", "1x")) else (2 if any(sub in upscale_model.lower() for sub in ("x2", "2x")) else 4)
474
+ model = None
475
+ half = True if torch.cuda.is_available() else False
476
+ if upscale_type:
477
+ # The values of the following hyperparameters are based on the research findings of the Spandrel project.
478
+ # https://github.com/chaiNNer-org/spandrel/tree/main/libs/spandrel/spandrel/architectures
479
+ from basicsr.archs.rrdbnet_arch import RRDBNet
480
+ loadnet = torch.load(os.path.join("weights", "upscale", upscale_model), map_location=torch.device('cpu'), weights_only=True)
481
+ if 'params_ema' in loadnet or 'params' in loadnet:
482
+ loadnet = loadnet['params_ema'] if 'params_ema' in loadnet else loadnet['params']
483
+
484
+ if upscale_type == "SRVGG":
485
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
486
+ body_max_num = self.find_max_numbers(loadnet, "body")
487
+ num_feat = loadnet["body.0.weight"].shape[0]
488
+ num_in_ch = loadnet["body.0.weight"].shape[1]
489
+ num_conv = body_max_num // 2 - 1
490
+ model = SRVGGNetCompact(num_in_ch=num_in_ch, num_out_ch=3, num_feat=num_feat, num_conv=num_conv, upscale=netscale, act_type='prelu')
491
+ elif upscale_type == "RRDB" or upscale_type == "ESRGAN":
492
+ if upscale_type == "RRDB":
493
+ num_block = self.find_max_numbers(loadnet, "body") + 1
494
+ num_feat = loadnet["conv_first.weight"].shape[0]
495
+ else:
496
+ num_block = self.find_max_numbers(loadnet, "model.1.sub")
497
+ num_feat = loadnet["model.0.weight"].shape[0]
498
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=num_feat, num_block=num_block, num_grow_ch=32, scale=netscale, is_real_esrgan=upscale_type == "RRDB")
499
+ elif upscale_type == "DAT":
500
+ from basicsr.archs.dat_arch import DAT
501
+ half = False
502
+
503
+ in_chans = loadnet["conv_first.weight"].shape[1]
504
+ embed_dim = loadnet["conv_first.weight"].shape[0]
505
+ num_layers = self.find_max_numbers(loadnet, "layers") + 1
506
+ depth = [6] * num_layers
507
+ num_heads = [6] * num_layers
508
+ for i in range(num_layers):
509
+ depth[i] = self.find_max_numbers(loadnet, f"layers.{i}.blocks") + 1
510
+ num_heads[i] = loadnet[f"layers.{i}.blocks.1.attn.temperature"].shape[0] if depth[i] >= 2 else \
511
+ loadnet[f"layers.{i}.blocks.0.attn.attns.0.pos.pos3.2.weight"].shape[0] * 2
512
+
513
+ upsampler = "pixelshuffle" if "conv_last.weight" in loadnet else "pixelshuffledirect"
514
+ resi_connection = "1conv" if "conv_after_body.weight" in loadnet else "3conv"
515
+ qkv_bias = "layers.0.blocks.0.attn.qkv.bias" in loadnet
516
+ expansion_factor = float(loadnet["layers.0.blocks.0.ffn.fc1.weight"].shape[0] / embed_dim)
517
+
518
+ img_size = 64
519
+ if "layers.0.blocks.2.attn.attn_mask_0" in loadnet:
520
+ attn_mask_0_x, attn_mask_0_y, _attn_mask_0_z = loadnet["layers.0.blocks.2.attn.attn_mask_0"].shape
521
+ img_size = int(math.sqrt(attn_mask_0_x * attn_mask_0_y))
522
+
523
+ split_size = [2, 4]
524
+ if "layers.0.blocks.0.attn.attns.0.rpe_biases" in loadnet:
525
+ split_sizes = loadnet["layers.0.blocks.0.attn.attns.0.rpe_biases"][-1] + 1
526
+ split_size = [int(x) for x in split_sizes]
527
+
528
+ model = DAT(img_size=img_size, in_chans=in_chans, embed_dim=embed_dim, split_size=split_size, depth=depth, num_heads=num_heads, expansion_factor=expansion_factor,
529
+ qkv_bias=qkv_bias, resi_connection=resi_connection, upsampler=upsampler, upscale=netscale)
530
+ elif upscale_type == "HAT":
531
+ half = False
532
+ from basicsr.archs.hat_arch import HAT
533
+ in_chans = loadnet["conv_first.weight"].shape[1]
534
+ embed_dim = loadnet["conv_first.weight"].shape[0]
535
+ window_size = int(math.sqrt(loadnet["relative_position_index_SA"].shape[0]))
536
+ num_layers = self.find_max_numbers(loadnet, "layers") + 1
537
+ depths = [6] * num_layers
538
+ num_heads = [6] * num_layers
539
+ for i in range(num_layers):
540
+ depths[i] = self.find_max_numbers(loadnet, f"layers.{i}.residual_group.blocks") + 1
541
+ num_heads[i] = loadnet[f"layers.{i}.residual_group.overlap_attn.relative_position_bias_table"].shape[1]
542
+ resi_connection = "1conv" if "conv_after_body.weight" in loadnet else "identity"
543
+
544
+ compress_ratio = self.find_divisor_for_quotient(embed_dim, loadnet["layers.0.residual_group.blocks.0.conv_block.cab.0.weight"].shape[0],)
545
+ squeeze_factor = self.find_divisor_for_quotient(embed_dim, loadnet["layers.0.residual_group.blocks.0.conv_block.cab.3.attention.1.weight"].shape[0],)
546
+
547
+ qkv_bias = "layers.0.residual_group.blocks.0.attn.qkv.bias" in loadnet
548
+ patch_norm = "patch_embed.norm.weight" in loadnet
549
+ ape = "absolute_pos_embed" in loadnet
550
+
551
+ mlp_hidden_dim = int(loadnet["layers.0.residual_group.blocks.0.mlp.fc1.weight"].shape[0])
552
+ mlp_ratio = mlp_hidden_dim / embed_dim
553
+ upsampler = "pixelshuffle"
554
+
555
+ model = HAT(img_size=64, patch_size=1, in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, compress_ratio=compress_ratio,
556
+ squeeze_factor=squeeze_factor, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, ape=ape, patch_norm=patch_norm,
557
+ upsampler=upsampler, resi_connection=resi_connection, upscale=netscale,)
558
+ elif "RealPLKSR" in upscale_type:
559
+ from basicsr.archs.realplksr_arch import realplksr
560
+ half = False if "RealPLSKR" in upscale_model else half
561
+ use_ea = "feats.1.attn.f.0.weight" in loadnet
562
+ dim = loadnet["feats.0.weight"].shape[0]
563
+ num_feats = self.find_max_numbers(loadnet, "feats") + 1
564
+ n_blocks = num_feats - 3
565
+ kernel_size = loadnet["feats.1.lk.conv.weight"].shape[2]
566
+ split_ratio = loadnet["feats.1.lk.conv.weight"].shape[0] / dim
567
+ use_dysample = "to_img.init_pos" in loadnet
568
+
569
+ model = realplksr(upscaling_factor=netscale, dim=dim, n_blocks=n_blocks, kernel_size=kernel_size, split_ratio=split_ratio, use_ea=use_ea, dysample=use_dysample)
570
+ elif upscale_type == "DRCT":
571
+ half = False
572
+ from basicsr.archs.DRCT_arch import DRCT
573
+
574
+ in_chans = loadnet["conv_first.weight"].shape[1]
575
+ embed_dim = loadnet["conv_first.weight"].shape[0]
576
+ num_layers = self.find_max_numbers(loadnet, "layers") + 1
577
+ depths = (6,) * num_layers
578
+ num_heads = []
579
+ for i in range(num_layers):
580
+ num_heads.append(loadnet[f"layers.{i}.swin1.attn.relative_position_bias_table"].shape[1])
581
+
582
+ mlp_ratio = loadnet["layers.0.swin1.mlp.fc1.weight"].shape[0] / embed_dim
583
+ window_square = loadnet["layers.0.swin1.attn.relative_position_bias_table"].shape[0]
584
+ window_size = (math.isqrt(window_square) + 1) // 2
585
+ upsampler = "pixelshuffle" if "conv_last.weight" in loadnet else ""
586
+ resi_connection = "1conv" if "conv_after_body.weight" in loadnet else ""
587
+ qkv_bias = "layers.0.swin1.attn.qkv.bias" in loadnet
588
+ gc_adjust1 = loadnet["layers.0.adjust1.weight"].shape[0]
589
+ patch_norm = "patch_embed.norm.weight" in loadnet
590
+ ape = "absolute_pos_embed" in loadnet
591
+
592
+ model = DRCT(in_chans=in_chans, img_size= 64, window_size=window_size, compress_ratio=3,squeeze_factor=30,
593
+ conv_scale= 0.01, overlap_ratio= 0.5, img_range= 1., depths=depths, embed_dim=embed_dim, num_heads=num_heads,
594
+ mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, ape=ape, patch_norm=patch_norm, use_checkpoint=False,
595
+ upscale=netscale, upsampler=upsampler, resi_connection=resi_connection, gc =gc_adjust1,)
596
+ elif upscale_type == "ATD":
597
+ half = False
598
+ from basicsr.archs.atd_arch import ATD
599
+ in_chans = loadnet["conv_first.weight"].shape[1]
600
+ embed_dim = loadnet["conv_first.weight"].shape[0]
601
+ window_size = math.isqrt(loadnet["relative_position_index_SA"].shape[0])
602
+ num_layers = self.find_max_numbers(loadnet, "layers") + 1
603
+ depths = [6] * num_layers
604
+ num_heads = [6] * num_layers
605
+ for i in range(num_layers):
606
+ depths[i] = self.find_max_numbers(loadnet, f"layers.{i}.residual_group.layers") + 1
607
+ num_heads[i] = loadnet[f"layers.{i}.residual_group.layers.0.attn_win.relative_position_bias_table"].shape[1]
608
+ num_tokens = loadnet["layers.0.residual_group.layers.0.attn_atd.scale"].shape[0]
609
+ reducted_dim = loadnet["layers.0.residual_group.layers.0.attn_atd.wq.weight"].shape[0]
610
+ convffn_kernel_size = loadnet["layers.0.residual_group.layers.0.convffn.dwconv.depthwise_conv.0.weight"].shape[2]
611
+ mlp_ratio = (loadnet["layers.0.residual_group.layers.0.convffn.fc1.weight"].shape[0] / embed_dim)
612
+ qkv_bias = "layers.0.residual_group.layers.0.wqkv.bias" in loadnet
613
+ ape = "absolute_pos_embed" in loadnet
614
+ patch_norm = "patch_embed.norm.weight" in loadnet
615
+ resi_connection = "1conv" if "layers.0.conv.weight" in loadnet else "3conv"
616
+
617
+ if "conv_up1.weight" in loadnet:
618
+ upsampler = "nearest+conv"
619
+ elif "conv_before_upsample.0.weight" in loadnet:
620
+ upsampler = "pixelshuffle"
621
+ elif "conv_last.weight" in loadnet:
622
+ upsampler = ""
623
+ else:
624
+ upsampler = "pixelshuffledirect"
625
+
626
+ is_light = upsampler == "pixelshuffledirect" and embed_dim == 48
627
+ category_size = 128 if is_light else 256
628
+
629
+ model = ATD(in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, category_size=category_size,
630
+ num_tokens=num_tokens, reducted_dim=reducted_dim, convffn_kernel_size=convffn_kernel_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, ape=ape,
631
+ patch_norm=patch_norm, use_checkpoint=False, upscale=netscale, upsampler=upsampler, resi_connection='1conv',)
632
+ elif upscale_type == "MoSR":
633
+ from basicsr.archs.mosr_arch import mosr
634
+ n_block = self.find_max_numbers(loadnet, "gblocks") - 5
635
+ in_ch = loadnet["gblocks.0.weight"].shape[1]
636
+ out_ch = loadnet["upsampler.end_conv.weight"].shape[0] if "upsampler.init_pos" in loadnet else in_ch
637
+ dim = loadnet["gblocks.0.weight"].shape[0]
638
+ expansion_ratio = (loadnet["gblocks.1.fc1.weight"].shape[0] / loadnet["gblocks.1.fc1.weight"].shape[1]) / 2
639
+ conv_ratio = loadnet["gblocks.1.conv.weight"].shape[0] / dim
640
+ kernel_size = loadnet["gblocks.1.conv.weight"].shape[2]
641
+ upsampler = "dys" if "upsampler.init_pos" in loadnet else ("gps" if "upsampler.in_to_k.weight" in loadnet else "ps")
642
+
643
+ model = mosr(in_ch = in_ch, out_ch = out_ch, upscale = netscale, n_block = n_block, dim = dim,
644
+ upsampler = upsampler, kernel_size = kernel_size, expansion_ratio = expansion_ratio, conv_ratio = conv_ratio,)
645
+ elif upscale_type == "SRFormer":
646
+ half = False
647
+ from basicsr.archs.srformer_arch import SRFormer
648
+ in_chans = loadnet["conv_first.weight"].shape[1]
649
+ embed_dim = loadnet["conv_first.weight"].shape[0]
650
+ ape = "absolute_pos_embed" in loadnet
651
+ patch_norm = "patch_embed.norm.weight" in loadnet
652
+ qkv_bias = "layers.0.residual_group.blocks.0.attn.q.bias" in loadnet
653
+ mlp_ratio = float(loadnet["layers.0.residual_group.blocks.0.mlp.fc1.weight"].shape[0] / embed_dim)
654
+
655
+ num_layers = self.find_max_numbers(loadnet, "layers") + 1
656
+ depths = [6] * num_layers
657
+ num_heads = [6] * num_layers
658
+ for i in range(num_layers):
659
+ depths[i] = self.find_max_numbers(loadnet, f"layers.{i}.residual_group.blocks") + 1
660
+ num_heads[i] = loadnet[f"layers.{i}.residual_group.blocks.0.attn.relative_position_bias_table"].shape[1]
661
+
662
+ if "conv_hr.weight" in loadnet:
663
+ upsampler = "nearest+conv"
664
+ elif "conv_before_upsample.0.weight" in loadnet:
665
+ upsampler = "pixelshuffle"
666
+ elif "upsample.0.weight" in loadnet:
667
+ upsampler = "pixelshuffledirect"
668
+ resi_connection = "1conv" if "conv_after_body.weight" in loadnet else "3conv"
669
+
670
+ window_size = int(math.sqrt(loadnet["layers.0.residual_group.blocks.0.attn.relative_position_bias_table"].shape[0])) + 1
671
+
672
+ if "layers.0.residual_group.blocks.1.attn_mask" in loadnet:
673
+ attn_mask_0 = loadnet["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
674
+ patches_resolution = int(math.sqrt(attn_mask_0) * window_size)
675
+ else:
676
+ patches_resolution = window_size
677
+ if ape:
678
+ pos_embed_value = loadnet.get("absolute_pos_embed", [None, None])[1]
679
+ if pos_embed_value:
680
+ patches_resolution = int(math.sqrt(pos_embed_value))
681
+
682
+ img_size = patches_resolution
683
+ if img_size % window_size != 0:
684
+ for nice_number in [512, 256, 128, 96, 64, 48, 32, 24, 16]:
685
+ if nice_number % window_size != 0:
686
+ nice_number += window_size - (nice_number % window_size)
687
+ if nice_number == patches_resolution:
688
+ img_size = nice_number
689
+ break
690
+
691
+ model = SRFormer(img_size=img_size, in_chans=in_chans, embed_dim=embed_dim, depths=depths, num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
692
+ qkv_bias=qkv_bias, qk_scale=None, ape=ape, patch_norm=patch_norm, upscale=netscale, upsampler=upsampler, resi_connection=resi_connection,)
693
+
694
+ if model:
695
+ self.realesrganer = RealESRGANer(scale=netscale, model_path=os.path.join("weights", "upscale", upscale_model), model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
696
+ elif upscale_model:
697
+ import PIL
698
+ from image_gen_aux import UpscaleWithModel
699
+ class UpscaleWithModel_Gfpgan(UpscaleWithModel):
700
+ def cv2pil(self, image):
701
+ ''' OpenCV type -> PIL type
702
+ https://qiita.com/derodero24/items/f22c22b22451609908ee
703
+ '''
704
+ new_image = image.copy()
705
+ if new_image.ndim == 2: # Grayscale
706
+ pass
707
+ elif new_image.shape[2] == 3: # Color
708
+ new_image = cv2.cvtColor(new_image, cv2.COLOR_BGR2RGB)
709
+ elif new_image.shape[2] == 4: # Transparency
710
+ new_image = cv2.cvtColor(new_image, cv2.COLOR_BGRA2RGBA)
711
+ new_image = PIL.Image.fromarray(new_image)
712
+ return new_image
713
+
714
+ def pil2cv(self, image):
715
+ ''' PIL type -> OpenCV type
716
+ https://qiita.com/derodero24/items/f22c22b22451609908ee
717
+ '''
718
+ new_image = np.array(image, dtype=np.uint8)
719
+ if new_image.ndim == 2: # Grayscale
720
+ pass
721
+ elif new_image.shape[2] == 3: # Color
722
+ new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
723
+ elif new_image.shape[2] == 4: # Transparency
724
+ new_image = cv2.cvtColor(new_image, cv2.COLOR_RGBA2BGRA)
725
+ return new_image
726
+
727
+ def enhance(self, img, outscale=None):
728
+ # img: numpy
729
+ h_input, w_input = img.shape[0:2]
730
+ pil_img = self.cv2pil(img)
731
+ pil_img = self.__call__(pil_img)
732
+ cv_image = self.pil2cv(pil_img)
733
+ if outscale is not None and outscale != float(netscale):
734
+ interpolation = cv2.INTER_AREA if outscale < float(netscale) else cv2.INTER_LANCZOS4
735
+ cv_image = cv2.resize(
736
+ cv_image, (
737
+ int(w_input * outscale),
738
+ int(h_input * outscale),
739
+ ), interpolation=interpolation)
740
+ return cv_image, None
741
+
742
+ device = "cuda" if torch.cuda.is_available() else "cpu"
743
+ upscaler = UpscaleWithModel.from_pretrained(os.path.join("weights", "upscale", upscale_model)).to(device)
744
+ upscaler.__class__ = UpscaleWithModel_Gfpgan
745
+ self.realesrganer = upscaler
746
+
747
+
748
+ def initFaceEnhancerModel(self, face_restoration, face_detection, face_detection_threshold: any, face_detection_only_center: bool):
749
+ download_from_url(face_models[face_restoration][0], face_restoration, os.path.join("weights", "face"))
750
+
751
+ resolution = 512
752
+ self.modelInUse = f"_{os.path.splitext(face_restoration)[0]}" + self.modelInUse
753
+ from gfpgan.utils import GFPGANer
754
+ model_rootpath = os.path.join("weights", "face")
755
+ model_path = os.path.join(model_rootpath, face_restoration)
756
+ channel_multiplier = None
757
+
758
+ if face_restoration and face_restoration.startswith("GFPGANv1."):
759
+ arch = "clean"
760
+ channel_multiplier = 2
761
+ elif face_restoration and face_restoration.startswith("RestoreFormer"):
762
+ arch = "RestoreFormer++" if face_restoration.startswith("RestoreFormer++") else "RestoreFormer"
763
+ elif face_restoration == 'CodeFormer.pth':
764
+ arch = "CodeFormer"
765
+ elif face_restoration.startswith("GPEN-BFR-"):
766
+ arch = "GPEN"
767
+ channel_multiplier = 2
768
+ if "1024" in face_restoration:
769
+ arch = "GPEN-1024"
770
+ resolution = 1024
771
+ elif "2048" in face_restoration:
772
+ arch = "GPEN-2048"
773
+ resolution = 2048
774
+
775
+ self.face_enhancer = GFPGANer(model_path=model_path, upscale=self.scale, arch=arch, channel_multiplier=channel_multiplier, model_rootpath=model_rootpath, det_model=face_detection, resolution=resolution)
776
+
777
+
778
+ def inference(self, gallery, face_restoration, upscale_model, scale: float, face_detection, face_detection_threshold: any, face_detection_only_center: bool, outputWithModelName: bool):
779
  try:
780
+ if not gallery or (not face_restoration and not upscale_model):
781
  raise ValueError("Invalid parameter setting")
782
+
783
+ print(face_restoration, upscale_model, scale, f"gallery length: {len(gallery)}")
784
 
785
  timer = Timer() # Create a timer
786
  self.scale = scale
 
 
 
 
 
 
 
 
 
 
 
787
 
 
 
 
788
  if upscale_model:
789
+ self.initBGUpscaleModel(upscale_model)
790
  timer.checkpoint("Initialize BG upscale model")
791
 
 
 
792
  if face_restoration:
793
+ self.initFaceEnhancerModel(face_restoration, face_detection, face_detection_threshold, face_detection_only_center)
794
  timer.checkpoint("Initialize face enhancer model")
795
 
796
+ timer.report()
797
+
798
  if not outputWithModelName:
799
+ self.modelInUse = ""
800
+
801
+ files = []
802
+ is_auto_split_upscale = True
803
+ # Dictionary to track counters for each filename
804
+ name_counters = defaultdict(int)
805
+ for gallery_idx, value in enumerate(gallery):
806
+ try:
807
+ img_path = str(value[0])
808
+ img_name = os.path.basename(img_path)
809
+ # Increment the counter for the current name
810
+ name_counters[img_name] += 1
811
+ if name_counters[img_name] > 1:
812
+ img_name = f"{img_name}_{name_counters[img_name]:02d}"
813
+ basename, extension = os.path.splitext(img_name)
814
+
815
+ img_cv2 = cv2.imdecode(np.fromfile(img_path, np.uint8), cv2.IMREAD_UNCHANGED) # numpy.ndarray
816
+
817
+ img_mode = "RGBA" if len(img_cv2.shape) == 3 and img_cv2.shape[2] == 4 else None
818
+ if len(img_cv2.shape) == 2: # for gray inputs
819
+ img_cv2 = cv2.cvtColor(img_cv2, cv2.COLOR_GRAY2BGR)
820
+ print(f"> image{gallery_idx:02d}, {img_cv2.shape}:")
821
+
822
+ bg_upsample_img = None
823
+ if self.realesrganer and hasattr(self.realesrganer, "enhance"):
824
+ bg_upsample_img, _ = auto_split_upscale(img_cv2, self.realesrganer.enhance, self.scale) if is_auto_split_upscale else self.realesrganer.enhance(img_cv2, outscale=self.scale)
825
+ timer.checkpoint(f"image{gallery_idx:02d}, Background upscale Section")
826
+
827
+ if self.face_enhancer:
828
+ cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img_cv2, has_aligned=False, only_center_face=face_detection_only_center, paste_back=True, bg_upsample_img=bg_upsample_img, eye_dist_threshold=face_detection_threshold)
829
+ # save faces
830
+ if cropped_faces and restored_aligned:
831
+ for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_aligned)):
832
+ # save cropped face
833
+ save_crop_path = f"output/{basename}{idx:02d}_cropped_faces{self.modelInUse}.png"
834
+ self.imwriteUTF8(save_crop_path, cropped_face)
835
+ # save restored face
836
+ save_restore_path = f"output/{basename}{idx:02d}_restored_faces{self.modelInUse}.png"
837
+ self.imwriteUTF8(save_restore_path, restored_face)
838
+ # save comparison image
839
+ save_cmp_path = f"output/{basename}{idx:02d}_cmp{self.modelInUse}.png"
840
+ cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
841
+ self.imwriteUTF8(save_cmp_path, cmp_img)
842
+
843
+ files.append(save_crop_path)
844
+ files.append(save_restore_path)
845
+ files.append(save_cmp_path)
846
+ timer.checkpoint(f"image{gallery_idx:02d}, Face enhancer Section")
847
+
848
+ restored_img = bg_upsample_img
849
+ timer.report()
850
+
851
+ if not extension:
852
+ extension = ".png" if img_mode == "RGBA" else ".jpg" # RGBA images should be saved in png format
853
+ save_path = f"output/{basename}{self.modelInUse}{extension}"
854
+ self.imwriteUTF8(save_path, restored_img)
855
+
856
+ restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
857
+ files.append(save_path)
858
+ except RuntimeError as error:
859
+ print(traceback.format_exc())
860
+ print('Error', error)
861
+
862
+ timer.report_all() # Print all recorded times
863
  except Exception as error:
864
  print(traceback.format_exc())
865
  print("global exception: ", error)
866
  return None, None
867
+ finally:
868
+ if self.face_enhancer:
869
+ self.face_enhancer._cleanup()
870
+ else:
871
+ # Free GPU memory and clean up resources
872
+ torch.cuda.empty_cache()
873
+ gc.collect()
874
+
875
+ return files, files
876
+
877
 
878
  def find_max_numbers(self, state_dict, findkeys):
879
  if isinstance(findkeys, str):
 
938
  now = time.perf_counter()
939
  self.checkpoints.append((label, now))
940
 
941
+ def report(self, is_clear_checkpoints = True):
942
+ # Determine the max label width for alignment
943
+ max_label_length = max(len(label) for label, _ in self.checkpoints)
944
+
945
+ prev_time = self.checkpoints[0][1]
946
+ for label, curr_time in self.checkpoints[1:]:
947
+ elapsed = curr_time - prev_time
948
+ print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
949
+ prev_time = curr_time
950
+
951
+ if is_clear_checkpoints:
952
+ self.checkpoints.clear()
953
+ self.checkpoint() # Store checkpoints
954
+
955
+ def report_all(self):
956
  """Print all recorded checkpoints and total execution time with aligned formatting."""
957
  print("\n> Execution Time Report:")
958
 
959
  # Determine the max label width for alignment
960
+ max_label_length = max(len(label) for label, _ in self.checkpoints) if len(self.checkpoints) > 0 else 0
961
 
962
  prev_time = self.start_time
963
  for label, curr_time in self.checkpoints[1:]:
 
968
  total_time = self.checkpoints[-1][1] - self.start_time
969
  print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
970
 
971
+ self.checkpoints.clear()
972
+
973
+ def restart(self):
974
+ self.start_time = time.perf_counter() # Record the start time
975
+ self.checkpoints = [("Start", self.start_time)] # Store checkpoints
976
+
977
+
978
+ def get_selection_from_gallery(selected_state: gr.SelectData):
979
+ """
980
+ Extracts the selected image path and caption from the gallery selection state.
981
+
982
+ Args:
983
+ selected_state (gr.SelectData): The selection state from a Gradio gallery,
984
+ containing information about the selected image.
985
+
986
+ Returns:
987
+ tuple: A tuple containing:
988
+ - str: The file path of the selected image.
989
+ - str: The caption of the selected image.
990
+ If `selected_state` is None or invalid, it returns `None`.
991
+ """
992
+ if not selected_state:
993
+ return selected_state
994
+
995
+ return (selected_state.value["image"]["path"], selected_state.value["caption"])
996
+
997
+ def limit_gallery(gallery):
998
+ """
999
+ Ensures the gallery does not exceed input_images_limit.
1000
+
1001
+ Args:
1002
+ gallery (list): Current gallery images.
1003
+
1004
+ Returns:
1005
+ list: Trimmed gallery with a maximum of input_images_limit images.
1006
+ """
1007
+ return gallery[:input_images_limit] if input_images_limit > 0 and gallery else gallery
1008
+
1009
+ def append_gallery(gallery: list, image: str):
1010
+ """
1011
+ Append a single image to the gallery while respecting input_images_limit.
1012
+
1013
+ Parameters:
1014
+ - gallery (list): Existing list of images. If None, initializes an empty list.
1015
+ - image (str): The image to be added. If None or empty, no action is taken.
1016
+
1017
+ Returns:
1018
+ - list: Updated gallery.
1019
+ """
1020
+ if gallery is None:
1021
+ gallery = []
1022
+ if not image:
1023
+ return gallery, None
1024
+
1025
+ if input_images_limit == -1 or len(gallery) < input_images_limit:
1026
+ gallery.append(image)
1027
+
1028
+ return gallery, None
1029
+
1030
+
1031
+ def extend_gallery(gallery: list, images):
1032
+ """
1033
+ Extend the gallery with new images while respecting the input_images_limit.
1034
+
1035
+ Parameters:
1036
+ - gallery (list): Existing list of images. If None, initializes an empty list.
1037
+ - images (list): New images to be added. If None, defaults to an empty list.
1038
+
1039
+ Returns:
1040
+ - list: Updated gallery with the new images added.
1041
+ """
1042
+ if gallery is None:
1043
+ gallery = []
1044
+ if not images:
1045
+ return gallery
1046
+
1047
+ # Add new images to the gallery
1048
+ gallery.extend(images)
1049
+
1050
+ # Trim gallery to the specified limit, if applicable
1051
+ if input_images_limit > 0:
1052
+ gallery = gallery[:input_images_limit]
1053
+
1054
+ return gallery
1055
+
1056
+ def remove_image_from_gallery(gallery: list, selected_image: str):
1057
+ """
1058
+ Removes a selected image from the gallery if it exists.
1059
+
1060
+ Args:
1061
+ gallery (list): The current list of images in the gallery.
1062
+ selected_image (str): The image to be removed, represented as a string
1063
+ that needs to be parsed into a tuple.
1064
+
1065
+ Returns:
1066
+ list: The updated gallery after removing the selected image.
1067
+ """
1068
+ if not gallery or not selected_image:
1069
+ return gallery
1070
+
1071
+ selected_image = ast.literal_eval(selected_image) # Use ast.literal_eval to parse text into a tuple in remove_image_from_gallery.
1072
+ # Remove the selected image from the gallery
1073
+ if selected_image in gallery:
1074
+ gallery.remove(selected_image)
1075
+ return gallery
1076
+
1077
  def main():
1078
  if torch.cuda.is_available():
1079
  torch.cuda.set_per_process_memory_fraction(0.975, device='cuda:0')
 
1125
  with gr.Blocks(title = title, css = css) as demo:
1126
  gr.Markdown(value=f"<h1 style=\"text-align:center;\">{title}</h1><br>{description}")
1127
  with gr.Row():
1128
+ with gr.Column(variant="panel"):
1129
+ submit = gr.Button(value="Submit", variant="primary", size="lg")
1130
+ # Create an Image component for uploading images
1131
+ input_image = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", format="png", height=150)
1132
+ with gr.Row():
1133
+ upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
1134
+ remove_button = gr.Button("Remove Selected Image", size="sm")
1135
+ input_gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images" + (f"(The online environment image limit is {input_images_limit})" if input_images_limit > 0 else ""))
1136
  face_model = gr.Dropdown([None]+list(face_models.keys()), type="value", value='GFPGANv1.4.pth', label='Face Restoration version', info="Face Restoration and RealESR can be freely combined in different ways, or one can be set to \"None\" to use only the other model. Face Restoration is primarily used for face restoration in real-life images, while RealESR serves as a background restoration model.")
1137
  upscale_model = gr.Dropdown([None]+list(typed_upscale_models.keys()), type="value", value='SRVGG, realesr-general-x4v3.pth', label='UpScale version')
1138
  upscale_scale = gr.Number(label="Rescaling factor", value=4)
1139
  face_detection = gr.Dropdown(["retinaface_resnet50", "YOLOv5l", "YOLOv5n"], type="value", value="retinaface_resnet50", label="Face Detection type")
1140
  face_detection_threshold = gr.Number(label="Face eye dist threshold", value=10, info="A threshold to filter out faces with too small an eye distance (e.g., side faces).")
1141
  face_detection_only_center = gr.Checkbox(value=False, label="Face detection only center", info="If set to True, only the face closest to the center of the image will be kept.")
1142
+ with_model_name = gr.Checkbox(label="Output image files name with model name", value=True)
1143
+
1144
+ # Define the event listener to add the uploaded image to the gallery
1145
+ input_image.change(append_gallery, inputs=[input_gallery, input_image], outputs=[input_gallery, input_image])
1146
+ # When the upload button is clicked, add the new images to the gallery
1147
+ upload_button.upload(extend_gallery, inputs=[input_gallery, upload_button], outputs=input_gallery)
1148
+ # Event to update the selected image when an image is clicked in the gallery
1149
+ selected_image = gr.Textbox(label="Selected Image", visible=False)
1150
+ input_gallery.select(get_selection_from_gallery, inputs=None, outputs=selected_image)
1151
+ # Trigger update when gallery changes
1152
+ input_gallery.change(limit_gallery, input_gallery, input_gallery)
1153
+ # Event to remove a selected image from the gallery
1154
+ remove_button.click(remove_image_from_gallery, inputs=[input_gallery, selected_image], outputs=input_gallery)
1155
+
1156
  with gr.Row():
 
1157
  clear = gr.ClearButton(
1158
  components=[
1159
+ input_gallery,
1160
  face_model,
1161
  upscale_model,
1162
  upscale_scale,
 
1191
  submit.click(
1192
  upscale.inference,
1193
  inputs=[
1194
+ input_gallery,
1195
  face_model,
1196
  upscale_model,
1197
  upscale_scale,
 
1208
 
1209
 
1210
  if __name__ == "__main__":
1211
+ parser = argparse.ArgumentParser()
1212
+ parser.add_argument("--input_images_limit", type=int, default=5)
1213
+ args = parser.parse_args()
1214
+ input_images_limit = args.input_images_limit
1215
  main()
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 --extra-index-url https://download.pytorch.org/whl/cu124
 
- gradio==5.15.0
+ gradio==5.16.0
 
 basicsr @ git+https://github.com/avan06/BasicSR
 facexlib @ git+https://github.com/avan06/facexlib
utils/dataops.py CHANGED
@@ -53,14 +53,10 @@ def auto_split_upscale(
 # Check to see if its actually the CUDA out of memory error
 if "CUDA" in str(e):
 print("RuntimeError: CUDA out of memory...")
- # Collect garbage (clear VRAM)
- torch.cuda.empty_cache()
- gc.collect()
 # Re-raise the exception if not an OOM error
 else:
 raise RuntimeError(e)
- finally:
- # Free GPU memory and clean up resources
+ # Collect garbage (clear VRAM)
 torch.cuda.empty_cache()
 gc.collect()
webui.bat CHANGED
@@ -1,6 +1,7 @@
 @echo off
 
 :: The source of the webui.bat file is stable-diffusion-webui
+ set COMMANDLINE_ARGS=--input_images_limit -1
 
 if not defined PYTHON (set PYTHON=python)
 if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv")