wise-water committed
Commit 13aa528 · 1 Parent(s): a56ad34

init commit

Files changed (36)
  1. LICENSE +61 -0
  2. README.md +72 -5
  3. app.py +459 -0
  4. requirements.txt +23 -0
  5. sample_img/sample_danbooru_dragonball.png +3 -0
  6. scripts/__init__.py +0 -0
  7. scripts/__pycache__/__init__.cpython-312.pyc +0 -0
  8. scripts/__pycache__/parse_cut_from_page.cpython-312.pyc +0 -0
  9. scripts/convert_psd_to_png.py +106 -0
  10. scripts/parse_cut_from_page.py +248 -0
  11. scripts/run_tag_filter.py +147 -0
  12. src/__init__.py +0 -0
  13. src/__pycache__/__init__.cpython-312.pyc +0 -0
  14. src/detectors/__init__.py +3 -0
  15. src/detectors/__pycache__/__init__.cpython-312.pyc +0 -0
  16. src/detectors/__pycache__/imgutils_detector.cpython-312.pyc +0 -0
  17. src/detectors/imgutils_detector.py +170 -0
  18. src/oskar_crop/__pycache__/detect_and_crop.cpython-312.pyc +0 -0
  19. src/oskar_crop/detect_and_crop.py +56 -0
  20. src/pipelines/__init__.py +3 -0
  21. src/pipelines/__pycache__/__init__.cpython-312.pyc +0 -0
  22. src/pipelines/__pycache__/pipeline_single_character_filtering.cpython-312.pyc +0 -0
  23. src/pipelines/pipeline_single_character_filtering.py +175 -0
  24. src/taggers/__init__.py +4 -0
  25. src/taggers/__pycache__/__init__.cpython-312.pyc +0 -0
  26. src/taggers/__pycache__/order.cpython-312.pyc +0 -0
  27. src/taggers/__pycache__/tagger.cpython-312.pyc +0 -0
  28. src/taggers/filter.py +113 -0
  29. src/taggers/order.py +85 -0
  30. src/taggers/tagger.py +215 -0
  31. src/utils/__pycache__/device.cpython-312.pyc +0 -0
  32. src/utils/__pycache__/timer.cpython-312.pyc +0 -0
  33. src/utils/device.py +18 -0
  34. src/utils/timer.py +51 -0
  35. src/wise_crop/__pycache__/detect_and_crop.cpython-312.pyc +0 -0
  36. src/wise_crop/detect_and_crop.py +84 -0
LICENSE ADDED
@@ -0,0 +1,61 @@
+ This software project ("the Software") comprises original code and incorporates third-party components as detailed below. Each component is subject to its respective license terms.
+
+ 1. Original Code
+
+ The original code within this repository, excluding third-party components, is licensed under the MIT License:
+
+ MIT License
+
+ Copyright (c) 2025 AI Lab, Kakao Entertainment
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+
+ 2. Third-Party Components
+
+ a. Google Gemma 3 - 12B Instruction-Tuned Model (google/gemma-3-12b-it)
+
+ This project utilizes the "google/gemma-3-12b-it" model, which is governed by Google's Gemma Terms of Use. Access to and use of this model require agreement to these terms, which include specific restrictions on distribution, modification, and usage. For detailed information, please refer to the Gemma Terms of Use.
+
+ Note: Ensure compliance with Google's terms when using, distributing, or modifying this model.
+
+ b. imgutils Python Package
+
+ The imgutils package is employed for image processing tasks within this project. As of the latest available information, imgutils does not explicitly specify a license. In the absence of a declared license, the usage rights are undefined, and caution is advised. It is recommended to contact the package maintainers or consult the source repository for clarification before using or distributing this package.
+
+ 3. Additional Dependencies
+
+ This project also relies on several other open-source packages, each with its own licensing terms:
+
+ Gradio: Licensed under the Apache License 2.0.
+
+ Pillow (PIL): Licensed under the Historical Permission Notice and Disclaimer (HPND).
+
+ psd-tools: Licensed under the MIT License.
+
+ OpenCV: Licensed under the Apache License 2.0.
+
+ Please refer to each package's documentation for detailed license information.
+
+ 4. Usage Guidelines
+
+ Users of this software must ensure compliance with all applicable licenses, especially when distributing or modifying the software. This includes adhering to the terms set forth by third-party components. Failure to comply with these terms may result in legal consequences.
+
+ 5. Disclaimer
+
+ This software is provided "as is," without warranty of any kind, express or implied. The authors are not liable for any damages or legal issues arising from the use of this software. Users are responsible for ensuring that their use of the software complies with all applicable laws and regulations.
+
+ For any questions or concerns regarding this license, please contact `[email protected]`.
README.md CHANGED
@@ -1,13 +1,80 @@
  ---
- title: Webtoon Cropper
- emoji:
- colorFrom: blue
  colorTo: blue
  sdk: gradio
  sdk_version: 5.25.2
  app_file: app.py
  pinned: false
- short_description: Webtoon Cropper
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: Webtoon Lineart Cropper
+ emoji: 🤗
+ colorFrom: yellow
  colorTo: blue
  sdk: gradio
  sdk_version: 5.25.2
  app_file: app.py
+ python_version: 3.12.10
  pinned: false
  ---
+ # Helix-Painting Data Tool
+
+ A tool that converts webtoon PSD files into image files (PNG) and extracts data through steps such as cut extraction and character detection.
+
+ ## Prerequisites
+
+ [Prerequisites](docs/PREREQUISITES.md)
+
+ ## Setup a project
+
+ [Setup a project](docs/SETUP.md)
+
+ ## Features
+
+ - **PSD → PNG conversion**
+ Converts webtoon PSD files into high-resolution PNG images.
+ - **Cut extraction and filtering**
+ Extracts cut boxes based on the white regions of a page image and generates per-cut data using character detection (a face detector) and tagging (a tagger); a minimal sketch of the cut-box step follows below.
+ - **Parallel processing**
+ Provides parallel image processing via Python's multiprocessing to speed up the pipeline.
+
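+ For reference, the cut-box step scans the page for rows that contain any non-white pixel, groups consecutive rows into horizontal bands, and then tightens each band to its leftmost and rightmost non-white columns. Below is a condensed, self-contained sketch of the logic implemented in `extract_cutbox_coordinates` in `scripts/parse_cut_from_page.py` (the input file name is a placeholder):
+
+ ```python
+ import numpy as np
+ from PIL import Image
+
+ def cut_boxes(image, threshold=255):
+     """Return (left, top, right, bottom) boxes for non-white horizontal bands."""
+     rgb = np.array(image.convert("RGB"))
+     non_white = np.any(rgb < threshold, axis=-1)  # True where any channel is below the threshold
+     rows = np.where(non_white.any(axis=1))[0]     # row indices that contain content
+     if len(rows) == 0:
+         return []
+     # Split the row indices wherever they stop being consecutive.
+     bands = np.split(rows, np.where(np.diff(rows) != 1)[0] + 1)
+     boxes = []
+     for band in bands:
+         top, bottom = band[0], band[-1]
+         cols = np.where(non_white[top:bottom + 1].any(axis=0))[0]
+         boxes.append((int(cols[0]), int(top), int(cols[-1]) + 1, int(bottom) + 1))
+     return boxes
+
+ print(cut_boxes(Image.open("page.png")))  # "page.png" is a placeholder path
+ ```
+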
+ ## Usage
+
+ ### 1. Convert PSD files to PNG images
+ Use the `convert_psd_to_png.py` script to convert PSD files into PNG images (see the psd-tools sketch after the argument list below).
+
+ ```shell
+ python scripts/convert_psd_to_png.py --directory <PSD_directory> --output <output_directory> [--visible_layers layer1 layer2 ...] [--invisible_layers layer3 layer4 ...]
+ ```
+
+ - `--directory` : Directory to search for PSD files
+ - `--output` : Directory to save the converted PNG files
+ - `--visible_layers` / `--invisible_layers` : Layer names to show or hide
+
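+ Internally, the script opens each PSD with psd-tools, toggles layer visibility by name, and composites the result. A minimal sketch using the same psd-tools calls as `scripts/convert_psd_to_png.py` (the file and layer names here are placeholders):
+
+ ```python
+ from psd_tools import PSDImage
+
+ psd = PSDImage.open("page.psd")               # placeholder input path
+ for layer in psd.descendants():
+     if layer.name == "sketch":                # hypothetical layer name to hide
+         layer.visible = False
+ # force=True re-composites the layers so the visibility changes take effect
+ psd.composite(force=True).save("page.png")    # placeholder output path
+ ```
+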
+ ### 2. Extract and filter image cuts
+ Use the `parse_cut_from_page.py` script to extract cut boxes and generate filtered data for each cut.
+
+ ```shell
+ python scripts/parse_cut_from_page.py --lineart <lineart_directory> --flat <flat_directory> --segmentation <segmentation_directory> --color <color_directory> --output <output_directory> [--num_process <number_of_processes>]
+ ```
+
+ - `--lineart` : Directory containing the lineart images
+ - `--flat` : Directory containing the flat (pre-coloring) images
+ - `--segmentation` : Directory containing the segmentation images
+ - `--color` : Directory containing the color images
+ - `--output` : Directory to save the cropped cut images
+ - `--num_process` : Number of processes for parallel processing (optional)
+
+ ### 3. Run tagging and filtering
+ Use the `run_tag_filter.py` script to tag and filter images and save the results; a sketch of driving the same pipeline from Python follows after the argument list below.
+
+ ```shell
+ python scripts/run_tag_filter.py --input_dir <input_directory> --output_dir <output_directory> [--ext png jpg jpeg]
+ ```
+
+ - `--input_dir` : Directory containing the images to filter
+ - `--output_dir` : Directory to save the filtered images and caption files
+ - `--ext` : List of image extensions to process (default: png)
+
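+ The tagging/filtering pipeline can also be driven directly from Python. This is a condensed sketch of what `scripts/run_tag_filter.py` does internally, using the classes shipped in this repository (the input directory is a placeholder):
+
+ ```python
+ from pathlib import Path
+ from PIL import Image
+
+ from src.detectors import AnimeDetector
+ from src.pipelines import TagAndFilteringPipeline
+ from src.taggers import WaifuDiffusionTagger
+ from src.utils.device import determine_accelerator
+
+ detector = AnimeDetector(repo_id="deepghs/anime_face_detection",
+                          model_name="face_detect_v1.4_s", hf_token=None)
+ tagger = WaifuDiffusionTagger(device=determine_accelerator())
+ pipeline = TagAndFilteringPipeline(tagger=tagger, detector=detector)
+
+ paths = sorted(Path("cuts").glob("*.png"))    # placeholder input directory
+ images = [Image.open(p).convert("RGB") for p in paths]
+ out = pipeline(images, batch_size=32, tag_threshold=0.3,
+                conf_threshold=0.3, iou_threshold=0.7)
+ kept = [p for p, flag in zip(paths, out.filter_flags) if flag]  # paths that pass the filter
+ ```
+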
72
+
73
+ 이와 같이 각 스크립트는 별도의 인자들을 받아서 작업을 수행하며, 스크립트 내부의 로직에 따라 PSD 변환, 컷 추출, 인물 검출/태깅, 캡션 파일 생성 등의 기능을 제공합니다.
74
+
75
+ ### 4. Gradio 데모 페이지
76
+ 아래 명렁어를 실행하여 Gradio 데모 페이지로 원하는 PNG파일을 업로드하여 필터링된 이미지 컷 추출 결과를 ZIP파일 혹은 PSD파일로 확인할 수 있습니다.
77
+
78
+ ```shell
79
+ python app.py
80
+ ```
app.py ADDED
@@ -0,0 +1,459 @@
+ import os
+ import sys
+ import tempfile
+ import zipfile
+ import shutil  # For make_archive
+ import uuid
+ from PIL import Image, ImageDraw
+ from psd_tools import PSDImage
+ from psd_tools.api.layers import PixelLayer
+ import gradio as gr
+ import traceback  # For printing stack traces
+ import subprocess
+
+
+ def install(package):
+     subprocess.check_call([sys.executable, "-m", "pip", "install", package])
+
+
+ install("timm")
+ install("pydantic==2.10.6")
+ install("dghs-imgutils==0.15.0")
+ install("onnxruntime >= 1.17.0")
+ install("psd_tools==1.10.7")
+ # os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
+
+ # --- Attempt to import the actual function (Detector 1) ---
+ # Let's keep the original import name as requested in the previous version
+ try:
+     # Assuming this is the intended "Detector 1"
+     from src.wise_crop.detect_and_crop import crop_and_mask_characters_gradio
+     detector_1_available = True
+     print("Successfully imported 'crop_and_mask_characters_gradio' as Detector 1.")
+ except ImportError:
+     detector_1_available = False
+     print("Warning: Could not import 'crop_and_mask_characters_gradio'. Using dummy function for Detector 1.")
+
+     # Define a dummy version for Detector 1 if import fails
+     def crop_and_mask_characters_gradio(image_pil: Image.Image):
+         """Dummy function 1 if import fails."""
+         print("Using DUMMY Detector 1.")
+         if image_pil is None:
+             return []
+         width, height = image_pil.size
+         boxes = [
+             (0, (int(width * 0.1), int(height * 0.1), int(width * 0.3), int(height * 0.4))),
+             (1, (int(width * 0.6), int(height * 0.5), int(width * 0.25), int(height * 0.35))),
+         ]
+         valid_boxes = []
+         for i, (x, y, w, h) in boxes:
+             x1, y1, x2, y2 = max(0, x), max(0, y), min(width, x + w), min(height, y + h)
+             if x2 - x1 > 0 and y2 - y1 > 0:
+                 valid_boxes.append((i, (x1, y1, x2 - x1, y2 - y1)))
+         return valid_boxes
+
+ # from src.oskar_crop.detect_and_crop import process_single_image as detector_2_function
+ try:
+     # Assuming this is the intended "Detector 2"
+     # Note: the import is aliased to avoid conflict if both imports succeed.
+     # The call inside process_lineart still uses crop_and_mask_characters_gradio_2
+     from src.oskar_crop.detect_and_crop import process_single_image as detector_2_function
+     detector_2_available = True
+     print("Successfully imported 'process_single_image' as Detector 2.")
+
+     # Define the function name used in process_lineart
+     def crop_and_mask_characters_gradio_2(image_pil: Image.Image):
+         return detector_2_function(image_pil)
+
+ except ImportError:
+     detector_2_available = False
+     print("Warning: Could not import 'process_single_image'. Using dummy function for Detector 2.")
+
+     # --- Define the SECOND dummy detection function (Detector 2 fallback) ---
+     def crop_and_mask_characters_gradio_2(image_pil: Image.Image):
+         """
+         SECOND dummy function to simulate detecting objects and returning bounding boxes.
+         Returns different results than the first function.
+         """
+         print("Using DUMMY Detector 2.")
+         if image_pil is None:
+             return []
+
+         width, height = image_pil.size
+         print(f"Dummy detection 2 running on image size: {width}x{height}")
+
+         # Define DIFFERENT fixed bounding boxes for demonstration
+         boxes = [
+             (0, (int(width * 0.05), int(height * 0.6), int(width * 0.4), int(height * 0.3))),   # Bottom-leftish, wider
+             (1, (int(width * 0.7), int(height * 0.1), int(width * 0.20), int(height * 0.25))),  # Top-rightish, smaller
+             (2, (int(width * 0.4), int(height * 0.4), int(width * 0.15), int(height * 0.15))),  # Center-ish, very small
+         ]
+
+         # Basic validation
+         valid_boxes = []
+         for i, (x, y, w, h) in boxes:
+             x1 = max(0, x)
+             y1 = max(0, y)
+             x2 = min(width, x + w)
+             y2 = min(height, y + h)
+             new_w = x2 - x1
+             new_h = y2 - y1
+             if new_w > 0 and new_h > 0:
+                 valid_boxes.append((i, (x1, y1, new_w, new_h)))
+
+         print(f"Dummy detection 2 found {len(valid_boxes)} boxes.")
+         return valid_boxes
+
+
+ # --- Helper Function (make_lineart_transparent - unchanged) ---
+ def make_lineart_transparent(lineart_path, threshold=200):
+     """Converts a lineart image file to a transparent RGBA PIL Image."""
+     try:
+         # Ensure we handle potential pathlib objects if Gradio passes them
+         lineart_gray = Image.open(str(lineart_path)).convert('L')
+         w, h = lineart_gray.size
+         lineart_rgba = Image.new('RGBA', (w, h), (0, 0, 0, 0))
+         gray_pixels = lineart_gray.load()
+         rgba_pixels = lineart_rgba.load()
+         for y in range(h):
+             for x in range(w):
+                 gray_val = gray_pixels[x, y]
+                 alpha = 255 - gray_val
+                 if gray_val < threshold:
+                     rgba_pixels[x, y] = (0, 0, 0, alpha)
+                 else:
+                     rgba_pixels[x, y] = (0, 0, 0, 0)
+         return lineart_rgba
+     except FileNotFoundError:
+         print(f"Helper Error: Image file not found at {lineart_path}")
+         # Return a blank transparent image or None? Returning None is clearer.
+         return None
+     except Exception as e:
+         print(f"Helper Error processing image {lineart_path}: {e}")
+         return None
+
+
+ # --- Main Processing Function (modified for better error handling with PIL) ---
+ def process_lineart(input_pil_or_path, detector_choice):  # Input can be PIL or path from examples
+     """
+     Processes the input lineart image using the selected detector.
+     Detects objects (e.g., characters based on head/face), crops them,
+     provides a gallery of crops, a ZIP file of crops, and a PSD file
+     with the original lineart (made transparent) and bounding boxes.
+     """
+     # --- Initialize variables ---
+     input_pil_image = None
+     temp_input_path = None
+     using_temp_input_path = False
+     status_updates = ["Status: Initializing..."]
+     psd_output_path = None  # Initialize to None
+     zip_output_path = None  # Initialize to None
+     cropped_images_for_gallery = []  # Initialize to empty list
+
+     try:
+         # --- Handle Input ---
+         if input_pil_or_path is None:
+             gr.Warning("Please upload a PNG image or select an example.")
+             return [], None, None, "Status: No image provided."
+
+         print(f"Input type: {type(input_pil_or_path)}")
+         print(f"Input value: {input_pil_or_path}")
+         # Check if input is already a PIL image (from upload) or a path (from examples)
+         if isinstance(input_pil_or_path, Image.Image):
+             input_pil_image = input_pil_or_path
+             print("Processing PIL image from upload.")
+             # Create a temporary path for make_lineart_transparent if needed later
+             temp_input_fd, temp_input_path = tempfile.mkstemp(suffix=".png")
+             os.close(temp_input_fd)
+             input_pil_image.save(temp_input_path, "PNG")
+             using_temp_input_path = True
+         elif isinstance(input_pil_or_path, str) and os.path.exists(input_pil_or_path):
+             print(f"Processing image from file path: {input_pil_or_path}")
+             try:
+                 input_pil_image = Image.open(input_pil_or_path)
+                 # Use the example path directly for make_lineart_transparent
+                 temp_input_path = input_pil_or_path
+                 using_temp_input_path = False  # Don't delete the example file later
+             except Exception as e:
+                 status_updates.append(f"ERROR: Could not open image file from path '{input_pil_or_path}': {e}")
+                 print(status_updates[-1])
+                 return [], None, None, "\n".join(status_updates)  # Return error status
+         else:
+             status_updates.append(f"ERROR: Invalid input type received: {type(input_pil_or_path)}. Expected PIL image or file path.")
+             print(status_updates[-1])
+             return [], None, None, "\n".join(status_updates)  # Return error status
+
+         # --- Ensure RGBA and get dimensions ---
+         try:
+             input_pil_image = input_pil_image.convert("RGBA")
+             width, height = input_pil_image.size
+         except Exception as e:
+             status_updates.append(f"ERROR: Could not process input image (convert/get size): {e}")
+             print(status_updates[-1])
+             # Clean up temp file if created before error
+             if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
+                 try: os.remove(temp_input_path)
+                 except Exception as e_rem: print(f"Warning: Could not remove temp input file {temp_input_path}: {e_rem}")
+             return [], None, None, "\n".join(status_updates)  # Return error status
+
+         status_updates = [f"Status: Processing started using {detector_choice}."]  # Reset status
+         print("Starting processing...")
+
+         # --- 1. Detect Objects (Conditional) ---
+         print(f"Selected detector: {detector_choice}")
+         if detector_choice == "Detector 1":
+             if not detector_1_available:
+                 status_updates.append("Warning: Using DUMMY Detector 1.")
+             boxes_info = crop_and_mask_characters_gradio(input_pil_image)
+         elif detector_choice == "Detector 2":
+             if not detector_2_available:
+                 status_updates.append("Warning: Using DUMMY Detector 2.")
+             boxes_info = crop_and_mask_characters_gradio_2(input_pil_image)
+         else:
+             # This case should ideally not happen with Radio buttons, but good for safety
+             status_updates.append(f"ERROR: Invalid detector choice received: {detector_choice}")
+             print(status_updates[-1])
+             # Clean up temp file if created before error
+             if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
+                 try: os.remove(temp_input_path)
+                 except Exception as e_rem: print(f"Warning: Could not remove temp input file {temp_input_path}: {e_rem}")
+             return [], None, None, "\n".join(status_updates)  # Return error status
+
+         if not boxes_info:
+             gr.Warning("No objects detected.")
+             status_updates.append("No objects detected.")
+             # Clean up temp file if created
+             if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
+                 try: os.remove(temp_input_path)
+                 except Exception as e_rem: print(f"Warning: Could not remove temp input file {temp_input_path}: {e_rem}")
+             return [], None, None, "\n".join(status_updates)
+
+         status_updates.append(f"Detected {len(boxes_info)} objects.")
+         print(f"Detected boxes: {boxes_info}")
+
+         # --- Temporary file paths (partially adjusted) ---
+         temp_dir_for_outputs = tempfile.gettempdir()
+         unique_id = uuid.uuid4().hex[:8]
+         zip_base_name = os.path.join(temp_dir_for_outputs, f"cropped_images_{unique_id}")
+         zip_output_path = f"{zip_base_name}.zip"  # Path for the final zip file
+         psd_output_path = os.path.join(temp_dir_for_outputs, f"lineart_boxes_{unique_id}.psd")
+         # temp_input_path is already handled above based on input source
+
+         # --- 2. Crop Images and Prepare for ZIP ---
+         with tempfile.TemporaryDirectory() as temp_crop_dir:
+             print(f"Saving cropped images to temporary directory: {temp_crop_dir}")
+             for i, (x, y, w, h) in boxes_info:
+                 # Ensure box coordinates are within image bounds
+                 x1, y1 = max(0, x), max(0, y)
+                 x2, y2 = min(width, x + w), min(height, y + h)
+                 box = (x1, y1, x2, y2)
+                 if box[2] > box[0] and box[3] > box[1]:  # Check if width and height are positive
+                     try:
+                         cropped_img = input_pil_image.crop(box)
+                         cropped_images_for_gallery.append(cropped_img)
+                         crop_filename = os.path.join(temp_crop_dir, f"cropped_{i}.png")
+                         cropped_img.save(crop_filename, "PNG")
+                     except Exception as e:
+                         print(f"Error cropping or saving box {i} with coords {box}: {e}")
+                         status_updates.append(f"Warning: Error processing crop {i}.")
+                 else:
+                     print(f"Skipping invalid box {i} with coords {box}")
+                     status_updates.append(f"Warning: Skipped invalid crop dimensions for box {i}.")
+
+             # --- 3. Create ZIP File ---
+             # Check if any PNG files were actually created in the temp dir
+             if any(f.endswith(".png") for f in os.listdir(temp_crop_dir)):
+                 print(f"Creating ZIP file: {zip_output_path} from {temp_crop_dir}")
+                 try:
+                     shutil.make_archive(zip_base_name, 'zip', temp_crop_dir)
+                     status_updates.append("Cropped images ZIP created.")
+                     # zip_output_path is already correctly set
+                 except Exception as e:
+                     print(f"Error creating ZIP file: {e}")
+                     status_updates.append("Error: Failed to create ZIP file.")
+                     zip_output_path = None  # Indicate failure
+             else:
+                 print("No valid cropped images were saved, skipping ZIP creation.")
+                 status_updates.append("Skipping ZIP creation (no valid crops).")
+                 zip_output_path = None  # No zip file to provide
+
+         # --- 4. Prepare PSD Layers ---
+         # a) Line Layer (Use the temp_input_path which is either the original example path or a temp copy)
+         print(f"Using image path for transparent layer: {temp_input_path}")
+         line_layer_pil = make_lineart_transparent(temp_input_path)
+         if line_layer_pil is None:
+             status_updates.append("Error: Failed to create transparent lineart layer.")
+             print(status_updates[-1])
+             # Don't create PSD if lineart failed, return current results
+             # Clean up temp file if created
+             if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
+                 try: os.remove(temp_input_path)
+                 except Exception as e_rem: print(f"Warning: Could not remove temp input file {temp_input_path}: {e_rem}")
+             return cropped_images_for_gallery, zip_output_path, None, "\n".join(status_updates)  # Return None for PSD
+
+         status_updates.append("Transparent lineart layer created.")
+
+         # b) Box Layer
+         box_layer_pil = Image.new('RGBA', (width, height), (255, 255, 255, 255))  # White background
+         draw = ImageDraw.Draw(box_layer_pil)
+         for i, (x, y, w, h) in boxes_info:
+             # Use validated coords again, ensure they are within bounds
+             x1, y1 = max(0, x), max(0, y)
+             x2, y2 = min(width, x + w), min(height, y + h)
+             if x2 > x1 and y2 > y1:  # Check validity again just in case
+                 rect = [(x1, y1), (x2, y2)]
+                 # Changed to fill for solid boxes, yellow fill, semi-transparent
+                 draw.rectangle(rect, fill=(255, 255, 0, 128))
+         status_updates.append("Bounding box layer created.")
+
+         # --- 5. Create PSD File ---
+         print(f"Creating PSD file: {psd_output_path}")
+         # Double check layer sizes before creating PSD object
+         if line_layer_pil.size != (width, height) or box_layer_pil.size != (width, height):
+             size_error_msg = (f"Error: Layer size mismatch during PSD creation. "
+                               f"Line: {line_layer_pil.size}, Box: {box_layer_pil.size}, "
+                               f"Expected: {(width, height)}")
+             status_updates.append(size_error_msg)
+             print(size_error_msg)
+             # Clean up temp file if created
+             if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
+                 try: os.remove(temp_input_path)
+                 except Exception as e_rem: print(f"Warning: Could not remove temp input file {temp_input_path}: {e_rem}")
+             return cropped_images_for_gallery, zip_output_path, None, "\n".join(status_updates)  # No PSD
+
+         try:
+             psd = PSDImage.new(mode='RGBA', size=(width, height))
+             # Add layers (order matters for visibility in PSD viewers)
+             # Base layer is transparent by default with RGBA
+             psd.append(PixelLayer.frompil(line_layer_pil, layer_name='line', top=0, left=0))
+             psd.append(PixelLayer.frompil(box_layer_pil, layer_name='box', top=0, left=0))
+             psd.save(psd_output_path)
+             status_updates.append("PSD file created.")
+         except Exception as e:
+             print(f"Error saving PSD file: {e}")
+             traceback.print_exc()
+             status_updates.append("Error: Failed to save PSD file.")
+             psd_output_path = None  # Indicate failure
+
+         print("Processing finished.")
+         status_updates.append("Success!")
+         final_status = "\n".join(status_updates)
+
+         # Return all paths, even if None (Gradio handles None for File output)
+         return cropped_images_for_gallery, zip_output_path, psd_output_path, final_status
+
+     except Exception as e:
+         print(f"An unexpected error occurred in process_lineart: {e}")
+         traceback.print_exc()
+         status_updates.append(f"FATAL ERROR: {e}")
+         final_status = "\n".join(status_updates)
+         # Return empty/None outputs and the error status
+         # Ensure cleanup happens even on fatal error
+         if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
+             try:
+                 os.remove(temp_input_path)
+                 print(f"Cleaned up temporary input file due to error: {temp_input_path}")
+             except Exception as e_rem:
+                 print(f"Warning: Could not remove temp input file {temp_input_path} during error handling: {e_rem}")
+         return [], None, None, final_status  # Return safe defaults
+
+     finally:
+         # --- Final Cleanup (Only removes temp input if created from upload) ---
+         if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
+             try:
+                 os.remove(temp_input_path)
+                 print(f"Cleaned up temporary input file: {temp_input_path}")
+             except Exception as e_rem:
+                 # This might happen if the file was already removed in an error block
+                 print(f"Notice: Could not remove temp input file {temp_input_path} in finally block (may already be removed): {e_rem}")
+
+
+ # --- Gradio Interface Definition (modified) ---
+ css = '''
+ .custom-gallery {
+     height: 500px !important;
+     width: 100%;
+     margin: 10px auto;
+     padding: 0px;
+     overflow-y: auto !important;
+ }
+ '''
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
+     gr.Markdown("# Webtoon Lineart Cropper with Filtering by Head-or-Face Detection")
+     gr.Markdown("Upload a PNG lineart image of your webtoon and automatically crop the regions that include a character's face or head. "
+                 "This demo leverages several detectors to precisely detect and isolate characters. "
+                 "The app will display cropped objects, provide a ZIP of cropped PNGs, "
+                 "and a PSD file with transparent lineart and half-transparent yellow-filled box layers. "
+                 "We provide two detectors to choose from, each with different filtering methods. ")
+     gr.Markdown("- **Detector 1**: Uses [`imgutils.detect`](https://github.com/deepghs/imgutils/tree/main/imgutils/detect) and VLM-based filtering with [`google/gemma-3-12b-it`](https://huggingface.co/google/gemma-3-12b-it)")
+     gr.Markdown("- **Detector 2**: Uses [`imgutils.detect`](https://github.com/deepghs/imgutils/tree/main/imgutils/detect) and tag-based filtering with [`SmilingWolf/wd-eva02-large-tagger-v3`](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3)")
+     gr.Markdown("**Note 1:** The app may take a few seconds to process the image, depending on the size and number of characters detected. The example image below is a lineart PNG file created synthetically from images on [Danbooru](https://danbooru.donmai.us/posts?page=1&tags=dragon_ball_z) after [lineart extraction](https://huggingface.co/spaces/carolineec/informativedrawings).")
+     gr.Markdown("**Note 2:** This demo is developed by [Kakao Entertainment](https://kakaoent.com/)'s AI Lab for research purposes, specifically to preprocess webtoon image data; it is not intended for production use. It is a research prototype and may not be suitable for all use cases. Please use it at your own risk.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Input type remains 'filepath' to handle examples cleanly.
+             image_input = gr.Image(type="filepath", label="Upload Lineart PNG or Select Example", image_mode='RGBA', height=400)
+
+             detector_choice_radio = gr.Radio(
+                 choices=["Detector 1", "Detector 2"],
+                 label="Choose Detection Function",
+                 value="Detector 1"  # Default value
+             )
+             process_button = gr.Button("Process Uploaded/Modified Image", variant="primary")
+             status_output = gr.Textbox(label="Status", interactive=False, lines=8)  # Increased lines slightly more
+
+         with gr.Column(scale=3):
+             gr.Markdown("### Cropped Objects")
+             # Setting height explicitly can sometimes help layout.
+             gallery_output = gr.Gallery(label="Detected Objects (Cropped)", elem_id="gallery_crops", columns=4, height=500, interactive=False, elem_classes="custom-gallery")  # object_fit="contain")
+             with gr.Row():
+                 zip_output = gr.File(label="Download Cropped Images (ZIP)")
+                 psd_output = gr.File(label="Download PSD (Lineart + Boxes)")
+
+     # --- Add Examples ---
+     # IMPORTANT: Make sure the sample image exists at the path below,
+     # or provide the correct relative/absolute path.
+     # Also ensure the image is a valid PNG.
+     example_image_path = "./sample_img/sample_danbooru_dragonball.png"
+     if os.path.exists(example_image_path):
+         gr.Examples(
+             examples=[
+                 [example_image_path, "Detector 1"],
+                 [example_image_path, "Detector 2"]  # Add example for detector 2 as well
+             ],
+             # Inputs that the examples populate
+             inputs=[image_input, detector_choice_radio],
+             # Outputs that are updated when an example is clicked AND run_on_click=True
+             outputs=[gallery_output, zip_output, psd_output, status_output],
+             # The function to call when an example is clicked
+             fn=process_lineart,
+             # Make clicking an example automatically run the function
+             run_on_click=True,
+             label="Click Example to Run Automatically",  # Updated label
+             cache_examples=True,  # Cache example outputs (lazily) so repeat clicks are fast
+             cache_mode="lazy",
+         )
+     else:
+         gr.Markdown(f"**(Note:** Could not find `{example_image_path}` for examples. Please create it or ensure it's in the correct directory.)")
+
+     # --- Button Click Handler (for manual uploads/changes) ---
+     process_button.click(
+         fn=process_lineart,
+         inputs=[image_input, detector_choice_radio],
+         outputs=[gallery_output, zip_output, psd_output, status_output]
+     )
+
+ # --- Launch the Gradio App ---
+ if __name__ == "__main__":
+     # Create a dummy sample image if it doesn't exist for testing
+     if not os.path.exists("./sample_img/sample_danbooru_dragonball.png"):
+         print("Creating a dummy 'sample_danbooru_dragonball.png' for demonstration.")
+         try:
+             os.makedirs("./sample_img", exist_ok=True)  # Make sure the target directory exists before saving
+             img = Image.new('L', (300, 200), color=255)  # White background (grayscale)
+             draw = ImageDraw.Draw(img)
+             # Draw some black lines/shapes
+             draw.line((30, 30, 270, 30), fill=0, width=2)
+             draw.rectangle((50, 50, 150, 150), outline=0, width=3)
+             draw.ellipse((180, 70, 250, 130), outline=0, width=3)
+             img.save("./sample_img/sample_danbooru_dragonball.png", "PNG")
+             print("Dummy 'sample_danbooru_dragonball.png' created.")
+         except Exception as e:
+             print(f"Warning: Failed to create dummy sample image: {e}")
+
+     demo.launch()
+     # ssr_mode=False
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torchvision
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torchaudio
+ pydantic==2.10.6
+ timm
+ psd_tools==1.10.7
+ accelerate
+ diffusers
+ transformers
+ xformers
+ opencv-python
+ dghs-imgutils==0.15.0
+ pillow
+ numpy
+ scikit-learn
+ huggingface_hub
+ tqdm
+ opencv-contrib-python
+ pandas
+ scipy
sample_img/sample_danbooru_dragonball.png ADDED

Git LFS Details

  • SHA256: 203064708dc908c07b95c1e8e53302635d260b6c2b238f29f6580e2af597373c
  • Pointer size: 132 Bytes
  • Size of remote file: 5.56 MB
scripts/__init__.py ADDED
File without changes
scripts/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (204 Bytes).
 
scripts/__pycache__/parse_cut_from_page.cpython-312.pyc ADDED
Binary file (14 kB).
 
scripts/convert_psd_to_png.py ADDED
@@ -0,0 +1,106 @@
+ import argparse
+ import logging
+ import multiprocessing
+ import os
+ from typing import Iterable
+
+ from psd_tools import PSDImage
+ from tqdm import tqdm
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Convert PSD files to PNG.")
+     parser.add_argument(
+         "-d",
+         "--directory",
+         type=str,
+         default="./",
+         help="Directory to search for PSD files.",
+     )
+     parser.add_argument("-o", "--output", type=str, default="./", help="Directory to save PNG files.")
+     parser.add_argument(
+         "--visible_layers",
+         default=[],
+         nargs="+",
+         type=str,
+         help="List of layer names to make visible.",
+     )
+     parser.add_argument(
+         "--invisible_layers",
+         default=[],
+         nargs="+",
+         type=str,
+         help="List of layer names to make invisible.",
+     )
+     parser.add_argument("--num_processes", "-n", default=None, type=int, help="Number of processes to use.")
+     return parser.parse_args()
+
+
+ def find_psd_files(directory):
+     psd_files = []
+     for root, dirs, files in os.walk(directory):
+         for file in files:
+             if file.endswith(".psd"):
+                 psd_files.append(os.path.join(root, file))
+     return psd_files
+
+
+ def set_layer_visibility(layer, visible_layers, invisible_layers):
+     if layer.name in visible_layers:
+         layer.visible = True
+     if layer.name in invisible_layers:
+         layer.visible = False
+
+     # Recurse into layer groups so nested layers are also toggled
+     if isinstance(layer, Iterable):
+         for child in layer:
+             set_layer_visibility(child, visible_layers, invisible_layers)
+
+
+ def process_psd_file(task):
+     """
+     Worker function that processes a single PSD file.
+     Opens the PSD, sets layer visibility, composites the image and saves it as PNG.
+     """
+     psd_file, output, visible_layers, invisible_layers, force = task
+     try:
+         psd = PSDImage.open(psd_file)
+         if force:
+             for layer in psd:
+                 set_layer_visibility(layer, visible_layers, invisible_layers)
+         image = psd.composite(force=force)
+         fname = os.path.basename(psd_file).replace(".psd", ".png")
+         output_file = os.path.join(output, fname)
+         image.save(output_file)
+     except Exception as e:
+         logger.error("Error processing file %s: %s", psd_file, e)
+
+
+ def main(args):
+     # Create output directory if it doesn't exist
+     if not os.path.exists(args.output):
+         os.makedirs(args.output)
+
+     psd_files = find_psd_files(args.directory)
+     # force=True when any layer visibility override is provided
+     force = bool(args.visible_layers or args.invisible_layers)
+
+     tasks = [(psd_file, args.output, args.visible_layers, args.invisible_layers, force) for psd_file in psd_files]
+
+     num_processes = args.num_processes if args.num_processes else multiprocessing.cpu_count() // 2
+     # Use multiprocessing to process PSD files in parallel
+     with multiprocessing.Pool(processes=num_processes) as pool:
+         list(
+             tqdm(
+                 pool.imap_unordered(process_psd_file, tasks),
+                 total=len(tasks),
+                 desc="Convert PSD to PNG files",
+             )
+         )
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
scripts/parse_cut_from_page.py ADDED
@@ -0,0 +1,248 @@
+ import argparse
+ import logging
+ import multiprocessing
+ import re
+ import sys
+ from pathlib import Path, PurePath
+ from typing import List, Tuple
+
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+
+ current_file_path = Path(__file__).resolve()
+ sys.path.insert(0, str(current_file_path.parent.parent))
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Parse cut from page script.")
+
+     parser.add_argument("--lineart", "-l", type=str, required=True, help="Directory of lineart images.")
+     parser.add_argument("--flat", "-f", type=str, required=True, help="Directory of flat images.")
+     parser.add_argument(
+         "--segmentation",
+         "-s",
+         type=str,
+         required=True,
+         help="Directory of segmentation images.",
+     )
+     parser.add_argument("--color", "-c", type=str, required=True, help="Directory of color images.")
+     parser.add_argument("--output", "-o", type=str, required=True, help="Output directory for parsed images.")
+     parser.add_argument("--num_process", "-n", type=int, default=None, help="Number of processes to use.")
+
+     return parser.parse_args()
+
+
+ def get_image_list(input_dir: str, ext: List[str]):
+     """
+     Get a list of images from the input directory with the specified extensions.
+     Args:
+         input_dir (str): Directory containing images to filter.
+         ext (list): List of file extensions to filter by.
+     Returns:
+         list: List of image file paths.
+     """
+     image_list = []
+     for e in ext:
+         image_list.extend(Path(input_dir).glob(f"*.{e}"))
+     return image_list
+
+
+ def check_image_pair_validity(
+     lineart_list: List[PurePath],
+     flat_list: List[PurePath],
+     segmentation_list: List[PurePath],
+     color_list: List[PurePath],
+     pattern: str = r"\d+_\d+",
+ ) -> Tuple[List[PurePath], List[PurePath], List[PurePath], List[PurePath]]:
+     r"""
+     Validates and filters lists of image file paths to ensure they correspond to the same IDs
+     based on a given naming pattern. If the lengths of the input lists are mismatched, the
+     function filters the lists to include only matching IDs.
+
+     Args:
+         lineart_list (List[PurePath]): List of file paths for lineart images.
+         flat_list (List[PurePath]): List of file paths for flat images.
+         segmentation_list (List[PurePath]): List of file paths for segmentation images.
+         color_list (List[PurePath]): List of file paths for color images.
+         pattern (str, optional): Regular expression pattern to extract IDs from file names.
+             Defaults to r"\d+_\d+".
+
+     Returns:
+         Tuple[List[PurePath], List[PurePath], List[PurePath], List[PurePath]]:
+             A tuple containing four lists of file paths (lineart, flat, segmentation, color)
+             that have been filtered to ensure matching IDs.
+     """
+     pattern = re.compile(pattern)
+
+     # Sort the lists based on the pattern
+     lineart_list = sorted(lineart_list, key=lambda x: pattern.match(x.name).group(0))
+     flat_list = sorted(flat_list, key=lambda x: pattern.match(x.name).group(0))
+     segmentation_list = sorted(segmentation_list, key=lambda x: pattern.match(x.name).group(0))
+     color_list = sorted(color_list, key=lambda x: pattern.match(x.name).group(0))
+
+     # Check if the lengths of the lists are equal
+     if (
+         len(lineart_list) != len(flat_list)
+         or len(lineart_list) != len(segmentation_list)
+         or len(lineart_list) != len(color_list)
+     ):
+         # If the lengths are not equal, we need to filter the lists based on the pattern
+         logger.warning(
+             f"Length mismatch: lineart({len(lineart_list)}), flat({len(flat_list)}), segmentation({len(segmentation_list)}), color({len(color_list)})"
+         )
+         new_lineart_list = []
+         new_flat_list = []
+         new_segmentation_list = []
+         new_color_list = []
+         for lineart_path in lineart_list:
+             lineart_name = lineart_path.name
+             lineart_match = pattern.match(lineart_name)
+
+             if lineart_match:
+                 file_id = lineart_match.group(0)
+                 corresponding_flat_files = [p for p in flat_list if file_id in p.name]
+                 corresponding_segmentation_files = [p for p in segmentation_list if file_id in p.name]
+                 corresponding_color_paths = [p for p in color_list if file_id in p.name]
+
+                 if corresponding_flat_files and corresponding_segmentation_files and corresponding_color_paths:
+                     new_lineart_list.append(lineart_path)
+                     new_flat_list.append(corresponding_flat_files[0])
+                     new_segmentation_list.append(corresponding_segmentation_files[0])
+                     new_color_list.append(corresponding_color_paths[0])
+
+         return new_lineart_list, new_flat_list, new_segmentation_list, new_color_list
+     else:
+         return lineart_list, flat_list, segmentation_list, color_list
+
+
+ def extract_cutbox_coordinates(image: Image.Image) -> List[Tuple[int, int, int, int]]:
+     """
+     Extracts bounding box coordinates for non-white regions in an image.
+
+     This function identifies regions in the given image that contain non-white pixels
+     and calculates the bounding box coordinates for each region. The bounding boxes
+     are represented as tuples of (left, top, right, bottom).
+
+     Args:
+         image (Image.Image): The input image as a PIL Image object.
+
+     Returns:
+         List[Tuple[int, int, int, int]]: A list of bounding box coordinates for non-white regions.
+             Each tuple contains four integers representing the left, top, right, and bottom
+             coordinates of a bounding box.
+     """
+
+     # We'll now detect the bounding box for non-white pixels instead of relying on the alpha channel.
+     # Convert the image to RGB and get the numpy array
+     image_rgb = image.convert("RGB")
+     image_np_rgb = np.array(image_rgb)
+
+     # Define the white threshold: with 255, only pure-white pixels count as white;
+     # lower this value to also treat near-white pixels as white.
+     threshold = 255
+     non_white_mask = np.any(image_np_rgb < threshold, axis=-1)  # Any channel below threshold is considered non-white
+     non_white_mask = non_white_mask.astype(np.uint8) * 255
+     # Image.fromarray(non_white_mask).save("non_white_mask.png")
+
+     # Find rows containing non-white pixels
+     non_white_rows = np.where(non_white_mask.any(axis=1))[0]
+     if len(non_white_rows) == 0:
+         return []
+
+     # Group continuous non-white rows
+     horizontal_line = np.where(np.diff(non_white_rows) != 1)[0] + 1
+     non_white_rows = np.split(non_white_rows, horizontal_line)
+     top_bottom_pairs = [(group[0], group[-1]) for group in non_white_rows]
+
+     # Iterate through each cut and find the left and right bounds
+     bounding_boxes = []
+     for top, bottom in top_bottom_pairs:
+         cut = image_np_rgb[top : bottom + 1]
+
+         non_white_mask = np.any(cut < threshold, axis=-1)
+         non_white_cols = np.where(non_white_mask.any(axis=0))[0]
+         left = non_white_cols[0]
+         right = non_white_cols[-1]
+
+         bounding_boxes.append((left, top, right + 1, bottom + 1))
+
+     return bounding_boxes
+
+
+ def process_single_image(task: Tuple[Path, Path, Path, Path, str]):
+     """
+     Worker function to process a single set of images.
+     Opens images, extracts bounding boxes, crops and saves the cut images.
+     """
+     line_path, flat_path, seg_path, color_path, output_str = task
+     output_dir = Path(output_str)
+     try:
+         line_img = Image.open(line_path).convert("RGB")
+         flat_img = Image.open(flat_path).convert("RGB")
+         seg_img = Image.open(seg_path).convert("RGB")
+         color_img = Image.open(color_path).convert("RGB")
+     except Exception as e:
+         logger.error(f"Error opening images for {line_path}: {e}")
+         return
+
+     bounding_boxes = extract_cutbox_coordinates(line_img)
+     fname = line_path.stem
+     match = re.compile(r"\d+_\d+").match(fname)
+     if not match:
+         logger.warning(f"Filename pattern not matched for {line_path.name}")
+         return
+     ep_page_str = match.group(0)
+
+     for i, (left, top, right, bottom) in enumerate(bounding_boxes):
+         try:
+             cut_line = line_img.crop((left, top, right, bottom))
+             cut_flat = flat_img.crop((left, top, right, bottom))
+             cut_seg = seg_img.crop((left, top, right, bottom))
+             cut_color = color_img.crop((left, top, right, bottom))
+
+             cut_line.save(output_dir / "line" / f"{ep_page_str}_{i}_line.png")
+             cut_flat.save(output_dir / "flat" / f"{ep_page_str}_{i}_flat.png")
+             cut_seg.save(output_dir / "segmentation" / f"{ep_page_str}_{i}_segmentation.png")
+             cut_color.save(output_dir / "fullcolor" / f"{ep_page_str}_{i}_fullcolor.png")
+         except Exception as e:
+             logger.error(f"Error processing crop for {line_path.name} at box {i}: {e}")
+
+
+ def main(args):
+     # Prepare output directory
+     output_dir = Path(args.output)
+     if not output_dir.exists():
+         output_dir.mkdir(parents=True, exist_ok=True)
+     (output_dir / "line").mkdir(parents=True, exist_ok=True)
+     (output_dir / "flat").mkdir(parents=True, exist_ok=True)
+     (output_dir / "segmentation").mkdir(parents=True, exist_ok=True)
+     (output_dir / "fullcolor").mkdir(parents=True, exist_ok=True)
+
+     # Prepare input images
+     lineart_list = get_image_list(args.lineart, ["png", "jpg", "jpeg"])
+     flat_list = get_image_list(args.flat, ["png", "jpg", "jpeg"])
+     segmentation_list = get_image_list(args.segmentation, ["png", "jpg", "jpeg"])
+     color_list = get_image_list(args.color, ["png", "jpg", "jpeg"])
+
+     # Check image pair validity
+     lineart_list, flat_list, segmentation_list, color_list = check_image_pair_validity(
+         lineart_list, flat_list, segmentation_list, color_list
+     )
+
+     # Prepare tasks for multiprocessing
+     tasks = []
+     for l, f, s, c in zip(lineart_list, flat_list, segmentation_list, color_list):
+         tasks.append((l, f, s, c, str(output_dir)))
+
+     # Use multiprocessing to process images in parallel
+     num_processes = args.num_process if args.num_process else multiprocessing.cpu_count() // 2
+     with multiprocessing.Pool(processes=num_processes) as pool:
+         list(tqdm(pool.imap_unordered(process_single_image, tasks), total=len(tasks), desc="Processing images"))
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
scripts/run_tag_filter.py ADDED
@@ -0,0 +1,147 @@
+ import argparse
+ import logging
+ import os
+ import shutil
+ import sys
+ from pathlib import Path, PurePath
+ from typing import List
+
+ from PIL import Image
+
+ current_file_path = Path(__file__).resolve()
+ sys.path.insert(0, str(current_file_path.parent.parent))
+
+ from src.detectors import AnimeDetector
+ from src.pipelines import TagAndFilteringPipeline
+ from src.taggers import WaifuDiffusionTagger
+ from src.utils.device import determine_accelerator
+
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Filtering script")
+
+     parser.add_argument(
+         "--input_dir",
+         "-i",
+         type=str,
+         required=True,
+         help="Directory containing images to filter",
+     )
+     parser.add_argument(
+         "--output_dir",
+         "-o",
+         type=str,
+         required=True,
+         help="Directory to save filtered images",
+     )
+     parser.add_argument(
+         "--ext",
+         "-e",
+         type=str,
+         nargs="+",
+         default=["png"],
+         help="File extensions of images to filter (default: png)",
+     )
+
+     return parser.parse_args()
+
+
+ def get_image_list(input_dir: str, ext: List[str]):
+     """
+     Get a list of images from the input directory with the specified extensions.
+     Args:
+         input_dir (str): Directory containing images to filter.
+         ext (list): List of file extensions to filter by.
+     Returns:
+         list: List of image file paths.
+     """
+     image_list = []
+     for e in ext:
+         image_list.extend(Path(input_dir).glob(f"*.{e}"))
+     return image_list
+
+
+ def write_image_caption_file(image_list: List[PurePath], captions: List[str], output_dir: str):
+     """
+     Writes a caption file.
+
+     This function generates a text file named "captions.txt" in the specified output directory.
+     Each line in the file contains the image name (without extension) followed by its caption,
+     separated by a colon.
+
+     Args:
+         image_list (List[PurePath]): A list of image file paths. Each path should be a PurePath object.
+         captions (List[str]): A list of captions corresponding to the images in `image_list`.
+         output_dir (str): The directory where the "captions.txt" file will be created.
+
+     Example:
+         image_list = [PurePath("image1.jpg"), PurePath("image2.jpg")]
+         captions = ["A beautiful sunset.", "A serene mountain view."]
+         output_dir = "/path/to/output"
+         write_image_caption_file(image_list, captions, output_dir)
+     """
+     caption_file = Path(output_dir) / "captions.txt"
+     lines = []
+     for img_path, caption in zip(image_list, captions):
+         img_name = img_path.stem
+         line = f"{img_name}: {caption}\n"
+         lines.append(line)
+
+     with open(caption_file, "w") as f:
+         f.writelines(lines)
+
+
+ def main(args):
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     # 1. Initialize the filtering pipeline
+     device = determine_accelerator()
+     logger.info(f"Using device: {device}")
+
+     logger.info("Initializing filtering pipeline...")
+     detector = AnimeDetector(
+         repo_id="deepghs/anime_face_detection",
+         model_name="face_detect_v1.4_s",
+         hf_token=None,
+     )
+
+     tagger = WaifuDiffusionTagger(device=device)
+
+     filtering_pipeline = TagAndFilteringPipeline(tagger=tagger, detector=detector)
+
+     # 2. Load images from the input directory
+     logger.info(f"Loading images from {args.input_dir}...")
+     image_list = get_image_list(args.input_dir, args.ext)
+     images = [Image.open(img_path).convert("RGB") for img_path in image_list]
+     logger.info(f"Found {len(images)} images.")
+
+     # 3. Filter images using the filtering pipeline
+     logger.info("Filtering images...")
+     filter_output = filtering_pipeline(images, batch_size=32, tag_threshold=0.3, conf_threshold=0.3, iou_threshold=0.7)
+
+     filter_flags = filter_output.filter_flags
+     tags = filter_output.tags
+     captions = [",".join(tag) for tag in tags]
+     logger.info(f"Kept {sum(filter_flags)} of {len(images)} images after filtering.")
+
+     # 4. Save filtered images and captions
+     write_image_caption_file(image_list, captions, args.input_dir)  # Write captions for all images to input_dir
+
+     logger.info(f"Copying filtered images to {args.output_dir}...")
+     filtered_images = [img for img, flag in zip(image_list, filter_flags) if flag]
+     filtered_captions = [caption for caption, flag in zip(captions, filter_flags) if flag]
+
+     for img_path in filtered_images:
+         img_name = img_path.stem
+         output_path = Path(args.output_dir) / f"{img_name}.png"
+         shutil.copy(img_path, output_path)
+
+     write_image_caption_file(filtered_images, filtered_captions, args.output_dir)
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (200 Bytes).
 
src/detectors/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .imgutils_detector import AnimeDetector
+
+ __all__ = ["AnimeDetector"]
src/detectors/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (296 Bytes).
 
src/detectors/__pycache__/imgutils_detector.cpython-312.pyc ADDED
Binary file (7.3 kB).
 
src/detectors/imgutils_detector.py ADDED
@@ -0,0 +1,170 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Generic, Optional, TypeVar
3
+
4
+ import cv2
5
+ import imgutils
6
+ import numpy as np
7
+ from imgutils.generic.yolo import (
8
+ _image_preprocess,
9
+ _rtdetr_postprocess,
10
+ _yolo_postprocess,
11
+ rgb_encode,
12
+ )
13
+ from PIL import Image, ImageDraw
14
+
15
+ T = TypeVar("T", int, float)
16
+
17
+ REPO_IDS = {
18
+ "head": "deepghs/anime_head_detection",
19
+ "face": "deepghs/anime_face_detection",
20
+ "eye": "deepghs/anime_eye_detection",
21
+ }
22
+
23
+
24
+ @dataclass
25
+ class DetectorOutput(Generic[T]):
26
+ bboxes: list[list[T]] = field(default_factory=list)
27
+ masks: list[Image.Image] = field(default_factory=list)
28
+ confidences: list[float] = field(default_factory=list)
29
+ previews: Optional[Image.Image] = None
30
+
31
+
32
+ class AnimeDetector:
33
+ """
34
+ A class used to perform object detection on anime images.
35
+ Please refer to the `imgutils` documentation for more information on the available models.
36
+ """
37
+
38
+ def __init__(self, repo_id: str, model_name: str, hf_token: Optional[str] = None):
39
+ model_manager = imgutils.generic.yolo._open_models_for_repo_id(
40
+ repo_id, hf_token=hf_token
41
+ )
42
+ model, max_infer_size, labels = model_manager._open_model(model_name)
43
+
44
+ self.model = model
45
+
46
+ self.max_infer_size = max_infer_size
47
+ self.labels = labels
48
+ self.model_type = model_manager._get_model_type(model_name)
49
+
50
+ def __call__(
51
+ self,
52
+ image: Image.Image,
53
+ conf_threshold: float = 0.3,
54
+ iou_threshold: float = 0.7,
55
+ allow_dynamic: bool = False,
56
+ ) -> DetectorOutput[float]:
57
+ """
58
+ Perform object detection on the given image.
59
+
60
+ Args:
61
+ image (Image.Image): The input image on which to perform detection.
62
+ conf_threshold (float, optional): Confidence threshold for detection. Defaults to 0.3.
63
+ iou_threshold (float, optional): Intersection over Union (IoU) threshold for detection. Defaults to 0.7.
64
+ allow_dynamic (bool, optional): Whether to allow dynamic resizing of the image. Defaults to False.
65
+
66
+ Returns:
67
+ DetectorOutput[float]: The detection results, including bounding boxes, masks, confidences, and a preview image.
68
+
69
+ Raises:
70
+ ValueError: If the model type is unknown.
71
+ """
72
+ # Preprocessing
73
+ new_image, old_size, new_size = _image_preprocess(
74
+ image, self.max_infer_size, allow_dynamic=allow_dynamic
75
+ )
76
+ data = rgb_encode(new_image)[None, ...]
77
+
78
+ # Start detection
79
+ (output,) = self.model.run(["output0"], {"images": data})
80
+
81
+ # Postprocessing
82
+ if self.model_type == "yolo":
83
+ output = _yolo_postprocess(
84
+ output=output[0],
85
+ conf_threshold=conf_threshold,
86
+ iou_threshold=iou_threshold,
87
+ old_size=old_size,
88
+ new_size=new_size,
89
+ labels=self.labels,
90
+ )
91
+ elif self.model_type == "rtdetr":
92
+ output = _rtdetr_postprocess(
93
+ output=output[0],
94
+ conf_threshold=conf_threshold,
95
+ iou_threshold=iou_threshold,
96
+ old_size=old_size,
97
+ new_size=new_size,
98
+ labels=self.labels,
99
+ )
100
+ else:
101
+ raise ValueError(
102
+ f"Unknown object detection model type - {self.model_type!r}."
103
+ ) # pragma: no cover
104
+
105
+ if len(output) == 0:
106
+ return DetectorOutput()
107
+
108
+ bboxes = [x[0] for x in output] # [x0, y0, x1, y1]
109
+ masks = create_mask_from_bbox(bboxes, image.size)
110
+ confidences = [x[2] for x in output]
111
+
112
+ # Create a preview image
113
+ previews = []
114
+ for mask in masks:
115
+ np_image = np.array(image)
116
+ np_mask = np.array(mask)
117
+ preview = cv2.bitwise_and(
118
+ np_image, cv2.cvtColor(np_mask, cv2.COLOR_GRAY2BGR)
119
+ )
120
+ preview = Image.fromarray(preview)
121
+ previews.append(preview)
122
+
123
+ return DetectorOutput(
124
+ bboxes=bboxes, masks=masks, confidences=confidences, previews=previews
125
+ )
126
+
127
+
128
+ def create_mask_from_bbox(
129
+ bboxes: list[list[float]], shape: tuple[int, int]
130
+ ) -> list[Image.Image]:
131
+ """
132
+ Creates a list of binary masks from bounding boxes.
133
+
134
+ Args:
135
+ bboxes (list[list[float]]): A list of bounding boxes, where each bounding box is represented
136
+ by a list of four float values [x_min, y_min, x_max, y_max].
137
+ shape (tuple[int, int]): The shape of the mask (height, width).
138
+
139
+ Returns:
140
+ list[Image.Image]: A list of PIL Image objects representing the binary masks.
141
+ """
142
+ masks = []
143
+ for bbox in bboxes:
144
+ mask = Image.new("L", shape, 0)
145
+ mask_draw = ImageDraw.Draw(mask)
146
+ mask_draw.rectangle(bbox, fill=255)
147
+ masks.append(mask)
148
+ return masks
149
+
150
+
151
+ def create_bbox_from_mask(
152
+ masks: list[Image.Image], shape: tuple[int, int]
153
+ ) -> list[list[int]]:
154
+ """
155
+ Create bounding boxes from a list of mask images.
156
+
157
+ Args:
158
+ masks (list[Image.Image]): A list of PIL Image objects representing the masks.
159
+ shape (tuple[int, int]): A tuple representing the desired shape (width, height) to resize the masks.
160
+
161
+ Returns:
162
+ list[list[int]]: A list of bounding boxes, where each bounding box is represented as a list of four integers [left, upper, right, lower].
163
+ """
164
+ bboxes = []
165
+ for mask in masks:
166
+ mask = mask.resize(shape)
167
+ bbox = mask.getbbox()
168
+ if bbox is not None:
169
+ bboxes.append(list(bbox))
170
+ return bboxes
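For orientation, a minimal usage sketch of AnimeDetector follows; the repo and model names mirror the ones used in src/oskar_crop/detect_and_crop.py, and the image path is the bundled sample:

from PIL import Image

from src.detectors import AnimeDetector

detector = AnimeDetector(
    repo_id="deepghs/anime_face_detection",
    model_name="face_detect_v1.4_s",
)
image = Image.open("sample_img/sample_danbooru_dragonball.png").convert("RGB")
output = detector(image, conf_threshold=0.3, iou_threshold=0.7)
for bbox, conf in zip(output.bboxes, output.confidences):
    print(f"face at {[round(v) for v in bbox]} (confidence {conf:.2f})")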
src/oskar_crop/__pycache__/detect_and_crop.cpython-312.pyc ADDED
Binary file (3.37 kB).
src/oskar_crop/detect_and_crop.py ADDED
@@ -0,0 +1,56 @@
+ import logging
+ from typing import List, Tuple
+
+ import numpy as np
+ from PIL import Image
+ from torchvision.transforms.v2 import ToPILImage
+
+ from scripts.parse_cut_from_page import extract_cutbox_coordinates
+ from src.detectors import AnimeDetector
+ from src.pipelines import TagAndFilteringPipeline
+ from src.taggers import WaifuDiffusionTagger
+ from src.utils.device import determine_accelerator
+
+ topil = ToPILImage()
+ logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+ # 1. Initialize the filtering pipeline
+ device = determine_accelerator()
+ logger.info(f"Using device: {device}")
+
+ logger.info("Initializing filtering pipeline...")
+ detector = AnimeDetector(
+     repo_id="deepghs/anime_face_detection",
+     model_name="face_detect_v1.4_s",
+     hf_token=None,
+ )
+
+ tagger = WaifuDiffusionTagger(device=device)
+
+ filtering_pipeline = TagAndFilteringPipeline(tagger=tagger, detector=detector)
+
+
+ def process_single_image(lineart_pil_img: Image.Image) -> List[Tuple[int, Tuple[int, int, int, int]]]:
+     """
+     Worker function to process a single line-art page image.
+     Extracts cut bounding boxes, crops each cut, and filters the cropped images.
+     """
+     try:
+         line_img = lineart_pil_img.convert("RGB")
+     except Exception as e:
+         logger.error(f"Error converting input image to RGB: {e}")
+         return []
+
+     # 2. Extract cut bounding boxes from the page and crop each region
+     bounding_boxes = extract_cutbox_coordinates(line_img)
+     images = [topil(np.array(line_img)[top:bottom, left:right]) for (left, top, right, bottom) in bounding_boxes]
+
+     # 3. Filter images using the filtering pipeline
+     logger.info("Filtering images...")
+     filter_output = filtering_pipeline(images, batch_size=32, tag_threshold=0.3, conf_threshold=0.3, iou_threshold=0.7)
+     filter_flags = filter_output.filter_flags
+     logger.info(f"Filtered {sum(filter_flags)} images out of {len(images)}.")
+
+     filtered_bboxes = [bb for bb, flag in zip(bounding_boxes, filter_flags) if flag]
+     index_added_filtered_bboxes = [(i + 1, (left, top, right - left, bottom - top)) for i, (left, top, right, bottom) in enumerate(filtered_bboxes)]
+
+     return index_added_filtered_bboxes
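A sketch of how this worker might be called on a single page; note that importing the module initializes the detector and tagger, and the page path here is illustrative:

from PIL import Image

from src.oskar_crop.detect_and_crop import process_single_image

# Any black-and-white line-art page works; the bundled sample is used for illustration.
page = Image.open("sample_img/sample_danbooru_dragonball.png")
for index, (left, top, width, height) in process_single_image(page):
    print(f"cut {index}: x={left}, y={top}, w={width}, h={height}")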
src/pipelines/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .pipeline_single_character_filtering import TagAndFilteringPipeline
+
+ __all__ = ["TagAndFilteringPipeline"]
src/pipelines/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (324 Bytes).
src/pipelines/__pycache__/pipeline_single_character_filtering.cpython-312.pyc ADDED
Binary file (9.88 kB).
src/pipelines/pipeline_single_character_filtering.py ADDED
@@ -0,0 +1,175 @@
+ import logging
+ from dataclasses import dataclass
+ from typing import List
+
+ from PIL import Image
+ from tqdm import tqdm
+
+ from ..detectors.imgutils_detector import AnimeDetector
+ from ..taggers import WaifuDiffusionTagger, sort_tags
+ from ..utils.timer import ElapsedTimer
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class TagAndFilteringOutput:
+     filter_flags: List[bool]
+     tags: List[str]
+
+
+ class TagAndFilteringPipeline:
+     """
+     TagAndFilteringPipeline is a pipeline for processing images by tagging them and filtering based on tags and face detection.
+
+     Attributes:
+         tagger (WaifuDiffusionTagger): An instance of the WaifuDiffusionTagger used for tagging images.
+         detector (AnimeDetector): An instance of the AnimeDetector used for detecting faces in images.
+     """
+
+     def __init__(self, tagger: WaifuDiffusionTagger, detector: AnimeDetector):
+         self.tagger = tagger
+         self.detector = detector
+
+     def __call__(self, images: List[Image.Image], *args, **kwargs) -> TagAndFilteringOutput:
+         """
+         Processes a list of images by tagging and filtering them based on resolution, tags, and face detection.
+         Args:
+             images (List[Image.Image]): A list of images to process.
+             batch_size (int, optional): The batch size for processing images. Default is 32.
+             tag_threshold (float, optional): The threshold for tag confidence. Default is 0.3.
+             sort_mode (str, optional): The mode for sorting tags. Default is "score".
+             include_tags (List[str], optional): Tags to include during filtering. Default is ["solo"].
+             exclude_tags (List[str], optional): Tags to exclude during filtering. Default is ["head_out_of_frame", "out_of_frame", "chibi", "negative_space"].
+             conf_threshold (float, optional): Confidence threshold for face detection. Default is 0.3.
+             iou_threshold (float, optional): IOU threshold for face detection. Default is 0.7.
+             minimum_resolution (int, optional): Minimum shorter-side resolution in pixels. Default is 512.
+             filter_by_resolution (bool, optional): Whether to filter images based on resolution. Default is True.
+             filter_by_tags (bool, optional): Whether to filter images based on tags. Default is True.
+             filter_by_faces (bool, optional): Whether to filter images based on face detection. Default is True.
+         Returns:
+             TagAndFilteringOutput: An object containing filter flags and tags for the processed images.
+         """
+         if not isinstance(images, list):
+             images = [images]
+
+         # Tagging parameters
+         batch_size = kwargs.pop("batch_size", 32)
+         tag_threshold = kwargs.pop("tag_threshold", 0.3)
+         tag_sort_mode = kwargs.pop("sort_mode", "score")
+         include_tags = kwargs.pop("include_tags", ["solo"])
+         exclude_tags = kwargs.pop("exclude_tags", ["head_out_of_frame", "out_of_frame", "chibi", "negative_space"])
+         # Face detection parameters
+         conf_threshold = kwargs.pop("conf_threshold", 0.3)
+         iou_threshold = kwargs.pop("iou_threshold", 0.7)
+         # Etc.
+         minimum_resolution = kwargs.pop("minimum_resolution", 512)
+
+         filter_flags = [True] * len(images)
+
+         if kwargs.pop("filter_by_resolution", True):
+             with ElapsedTimer("Resolution-based Filtering", logger=logger):
+                 for idx, image in enumerate(images):
+                     if min(image.size) < minimum_resolution:
+                         filter_flags[idx] = False
+
+             logger.info(f"{sum(filter_flags)} of {len(images)} images passed resolution-based filtering ({minimum_resolution}px).")
+
+         with ElapsedTimer("Tagging", logger=logger):
+             tags = self.tagging(images, threshold=tag_threshold, sort_mode=tag_sort_mode, batch_size=batch_size)
+
+         if kwargs.pop("filter_by_tags", True):
+             with ElapsedTimer("Tag-based Filtering", logger=logger):
+                 filter_flags = self.tag_based_filtering(
+                     tags,
+                     filter_flags,
+                     include_tags=include_tags,
+                     exclude_tags=exclude_tags,
+                 )
+
+         if kwargs.pop("filter_by_faces", True):
+             with ElapsedTimer("Face-based Filtering", logger=logger):
+                 filter_flags = self.face_based_filtering(
+                     images,
+                     filter_flags=filter_flags,
+                     conf_threshold=conf_threshold,
+                     iou_threshold=iou_threshold,
+                 )
+
+         return TagAndFilteringOutput(filter_flags=filter_flags, tags=tags)
+
+     def tagging(
+         self,
+         images,
+         threshold: float = 0.3,
+         sort_mode: str = "score",
+         batch_size: int = 32,
+     ) -> List[List[str]]:
+         """
+         Tags a list of images and returns their sorted tag lists.
+         Parameters:
+             images (List[Image.Image]): A list of images to tag.
+             threshold (float, optional): The threshold for tag confidence. Default is 0.3.
+             sort_mode (str, optional): The mode for sorting tags. Default is "score".
+             batch_size (int, optional): The batch size for tagging images. Default is 32.
+         Returns:
+             List[List[str]]: A list of sorted tag lists, one per image.
+         """
+         tags = []
+         for i in tqdm(range(0, len(images), batch_size), desc="Tagging"):
+             batch = images[i : i + batch_size]
+             tagger_output = self.tagger(batch, threshold=threshold)
+
+             tags.extend([sort_tags(image_tags, mode=sort_mode) for image_tags in tagger_output])
+
+         return tags
+
+     def tag_based_filtering(
+         self,
+         tags: List[List[str]],
+         filter_flags: List[bool],
+         include_tags=["solo"],
+         exclude_tags=["head_out_of_frame", "out_of_frame"],
+     ) -> List[bool]:
+         """
+         Filters images based on their tags.
+         Parameters:
+             tags (List[List[str]]): A list of tags for the images.
+             filter_flags (List[bool]): A list of boolean flags indicating whether each image passes filtering.
+             include_tags (List[str], optional): Tags to include during filtering. Default is ["solo"].
+             exclude_tags (List[str], optional): Tags to exclude during filtering. Default is ["head_out_of_frame", "out_of_frame"].
+         Returns:
+             List[bool]: Updated filter flags after tag-based filtering.
+         """
+         for idx, tag in tqdm(enumerate(tags), desc="Tag-based Filtering", total=len(tags)):
+             if not filter_flags[idx]:
+                 continue  # keep the result of earlier filtering stages
+             if any(include_tag in tag for include_tag in include_tags) and all(
+                 exclude_tag not in tag for exclude_tag in exclude_tags
+             ):
+                 filter_flags[idx] = True
+             else:
+                 filter_flags[idx] = False
+
+         return filter_flags
+
+     def face_based_filtering(
+         self, images: List[Image.Image], filter_flags: List[bool], conf_threshold: float = 0.3, iou_threshold=0.7
+     ) -> List[bool]:
+         """
+         Filters images based on face detection, keeping only images with exactly one detected face.
+         Parameters:
+             images (List[Image.Image]): A list of images to filter.
+             filter_flags (List[bool]): A list of boolean flags indicating whether each image passes filtering.
+             conf_threshold (float, optional): Confidence threshold for face detection. Default is 0.3.
+             iou_threshold (float, optional): IOU threshold for face detection. Default is 0.7.
+         Returns:
+             List[bool]: Updated filter flags after face-based filtering.
+         """
+         for idx, image in tqdm(enumerate(images), desc="Face-based Filtering", total=len(images)):
+             if not filter_flags[idx]:
+                 continue
+
+             detector_output = self.detector(image, conf_threshold=conf_threshold, iou_threshold=iou_threshold)
+             filter_flags[idx] = len(detector_output.bboxes) == 1
+
+         return filter_flags
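A hedged end-to-end sketch of the pipeline; the input directory "cuts" is hypothetical, and the detector/tagger choices mirror the rest of the repo:

from pathlib import Path

from PIL import Image

from src.detectors import AnimeDetector
from src.pipelines import TagAndFilteringPipeline
from src.taggers import WaifuDiffusionTagger

detector = AnimeDetector(repo_id="deepghs/anime_face_detection", model_name="face_detect_v1.4_s")
tagger = WaifuDiffusionTagger(device="cpu")
pipeline = TagAndFilteringPipeline(tagger=tagger, detector=detector)

# Hypothetical input directory of cropped cut images.
images = [Image.open(p) for p in sorted(Path("cuts").glob("*.png"))]
output = pipeline(images, batch_size=8, tag_threshold=0.3, include_tags=["solo"])
kept = [img for img, flag in zip(images, output.filter_flags) if flag]
print(f"kept {len(kept)} of {len(images)} images")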
src/taggers/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .order import sort_tags
+ from .tagger import WaifuDiffusionTagger
+
+ __all__ = ["WaifuDiffusionTagger", "sort_tags"]
src/taggers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (339 Bytes).
src/taggers/__pycache__/order.cpython-312.pyc ADDED
Binary file (3.78 kB).
src/taggers/__pycache__/tagger.cpython-312.pyc ADDED
Binary file (12.2 kB).
src/taggers/filter.py ADDED
@@ -0,0 +1,113 @@
+ from typing import Dict, Tuple
+
+ TAG_GROUP = {
+     "frame": [
+         "portrait",
+         "upper_body",
+         "lower_body",
+         "cowboy_shot",
+         "feet_out_of_frame",
+         "full_body",
+         "wide_shot",
+         "very_wide_shot",
+     ],
+     "frame_2": [
+         "close-up",
+         "cut-in",
+         "cropped",
+     ],
+     "view_angle": [
+         "dutch_angle",
+         "from_above",
+         "from_behind",
+         "from_below",
+         "from_side",
+         # "multiple_views",
+         "sideways",
+         "straight-on",
+         "three_quarter_view",
+         "upside-down",
+     ],
+     "focus": ["eye_focus"],
+     "lip_action": [
+         "parted_lips",
+         "biting_own_lip",
+         "pursed_lips",
+         "spread_lips",
+         "open_mouth",
+         "closed_mouth",
+     ],
+     "eye": ["closed_eyes", "one_eye_closed"],
+     "gaze": [
+         "eye_contact",
+         "looking_afar",
+         "looking_around",
+         "looking_at_another",
+         "looking_at_hand",
+         "looking_at_hands",
+         "looking_at_mirror",
+         "looking_at_self",
+         "looking_at_viewer",
+         "looking_away",
+         "looking_back",
+         "looking_down",
+         "looking_outside",
+         "looking_over_eyewear",
+         "looking_through_own_legs",
+         "looking_to_the_side",
+         "looking_up",
+     ],
+     "emotion": [
+         "smile",
+         "angry",
+         "anger_vein",
+         "annoyed",
+         "clenched_teeth",
+         "scowl",
+         "blush",
+         "embarrassed",
+         "bored",
+         "confused",
+         "crazy",
+         "despair",
+         "disappointed",
+         "disgust",
+         "envy",
+         "excited",
+         "exhausted",
+         "expressionless",
+         "furrowed_brow",
+         "happy",
+         "sad",
+         "depressed",
+         "frown",
+         "tears",
+         "scared",
+         "serious",
+         "sleepy",
+         "surprised",
+         "thinking",
+         "pain",
+     ],
+ }
+
+
+ def parse_valid_tags(
+     input_tags: Dict[str, float], valid_tags=TAG_GROUP
+ ) -> Dict[str, Tuple[str, float]]:
+     """
+     Parses valid tags from the input tags based on predefined tag groups.
+     Args:
+         input_tags (Dict[str, float]): A dictionary of tags with their confidence scores, sorted by confidence in descending order.
+         valid_tags (dict, optional): A dictionary where keys are tag groups and values are lists of valid tags. Defaults to TAG_GROUP.
+     Returns:
+         dict: A dictionary mapping each tag group to the first valid (tag, confidence) pair found in the input tags.
+     """
+     output_tags = {}
+     for tag_group, tags in valid_tags.items():
+         for tag in tags:
+             key = tag.replace(" ", "_")
+             if key in input_tags:
+                 output_tags[tag_group] = (tag, input_tags[key])
+                 break  # parse only one tag from each tag group
+
+     return output_tags
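A small sketch of how parse_valid_tags behaves on tagger output; the scores here are made up:

from src.taggers.filter import parse_valid_tags

# Hypothetical tagger output: tag -> confidence, sorted descending.
scores = {"solo": 0.98, "upper_body": 0.91, "smile": 0.88, "looking_at_viewer": 0.85}
print(parse_valid_tags(scores))
# {'frame': ('upper_body', 0.91), 'gaze': ('looking_at_viewer', 0.85), 'emotion': ('smile', 0.88)}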
src/taggers/order.py ADDED
@@ -0,0 +1,85 @@
+ # Adapted from https://github.com/deepghs/imgutils/blob/main/imgutils/tagging/order.py
+ import random
+ import re
+ from typing import List, Literal, Mapping, Union
+
+
+ def sort_tags(
+     tags: Union[List[str], Mapping[str, float]], mode: Literal["original", "shuffle", "score"] = "score"
+ ) -> List[str]:
+     """
+     Sort the input list or mapping of tags by the specified mode.
+
+     Tags can represent people counts (e.g., '1girl', '2boys') and the 'solo' tag,
+     which are always placed first.
+
+     :param tags: List or mapping of tags to be sorted.
+     :type tags: Union[List[str], Mapping[str, float]]
+     :param mode: The mode for sorting the tags. Options: 'original' (original order),
+         'shuffle' (random shuffle), 'score' (sorted by score if available).
+     :type mode: Literal['original', 'shuffle', 'score']
+     :return: Sorted list of tags based on the specified mode.
+     :rtype: List[str]
+     :raises ValueError: If an unknown sort mode is provided.
+     :raises TypeError: If mode is 'score' and the input is a list
+         (as a plain list does not carry scores).
+
+     Examples:
+         Sorting tags in original order:
+
+         >>> from imgutils.tagging import sort_tags
+         >>>
+         >>> tags = ['1girls', 'solo', 'red_hair', 'cat ears']
+         >>> sort_tags(tags, mode='original')
+         ['solo', '1girls', 'red_hair', 'cat ears']
+         >>>
+         >>> tags = {'1girls': 0.9, 'solo': 0.95, 'red_hair': 1.0, 'cat_ears': 0.92}
+         >>> sort_tags(tags, mode='original')
+         ['solo', '1girls', 'red_hair', 'cat_ears']
+
+         Sorting tags by score (for a mapping of tags with scores):
+
+         >>> from imgutils.tagging import sort_tags
+         >>>
+         >>> tags = {'1girls': 0.9, 'solo': 0.95, 'red_hair': 1.0, 'cat_ears': 0.92}
+         >>> sort_tags(tags)
+         ['solo', '1girls', 'red_hair', 'cat_ears']
+
+         Shuffling tags (output is not deterministic):
+
+         >>> from imgutils.tagging import sort_tags
+         >>>
+         >>> tags = ['1girls', 'solo', 'red_hair', 'cat ears']
+         >>> sort_tags(tags, mode='shuffle')
+         ['solo', '1girls', 'red_hair', 'cat ears']
+         >>>
+         >>> tags = {'1girls': 0.9, 'solo': 0.95, 'red_hair': 1.0, 'cat_ears': 0.92}
+         >>> sort_tags(tags, mode='shuffle')
+         ['solo', '1girls', 'cat_ears', 'red_hair']
+     """
+     if mode not in {"original", "shuffle", "score"}:
+         raise ValueError(f"Unknown sort mode: 'original', 'shuffle' or 'score' expected but {mode!r} found.")
+     npeople_tags = []
+     remaining_tags = []
+
+     if "solo" in tags:
+         npeople_tags.append("solo")
+
+     for tag in tags:
+         if tag == "solo":
+             continue
+         if re.fullmatch(r"\d+\+?(boy|girl)s?", tag):  # 1girl, 1boy, 2girls, 3boys, 9+girls
+             npeople_tags.append(tag)
+         else:
+             remaining_tags.append(tag)
+
+     if mode == "score":
+         if isinstance(tags, dict):
+             remaining_tags = sorted(remaining_tags, key=lambda x: -tags[x])
+         else:
+             raise TypeError(f"Sort mode {mode!r} is not supported for lists, as they do not have scores.")
+     elif mode == "shuffle":
+         random.shuffle(remaining_tags)
+     else:
+         pass
+
+     return npeople_tags + remaining_tags
src/taggers/tagger.py ADDED
@@ -0,0 +1,215 @@
+ import logging
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, List, Literal, Optional, Union
+
+ import numpy as np
+ import pandas as pd
+ import timm
+ import torch
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import HfHubHTTPError
+ from PIL import Image
+ from timm.data import create_transform, resolve_data_config
+
+ logger = logging.getLogger(__name__)
+
+ MODEL_REPO_MAP = {
+     "vit": "SmilingWolf/wd-vit-tagger-v3",
+     "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
+     "convnext": "SmilingWolf/wd-convnext-tagger-v3",
+     "eva-02": "SmilingWolf/wd-eva02-large-tagger-v3",
+ }
+
+
+ @dataclass
+ class LabelData:
+     names: List[str]
+     rating: List[np.int64]
+     general: List[np.int64]
+     character: List[np.int64]
+
+
+ def pil_ensure_rgb(image: Image.Image) -> Image.Image:
+     # convert to RGB/RGBA if not already (deals with palette images etc.)
+     if image.mode not in ["RGB", "RGBA"]:
+         image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
+     # convert RGBA to RGB with white background
+     if image.mode == "RGBA":
+         canvas = Image.new("RGBA", image.size, (255, 255, 255))
+         canvas.alpha_composite(image)
+         image = canvas.convert("RGB")
+     return image
+
+
+ def pil_pad_square(image: Image.Image) -> Image.Image:
+     w, h = image.size
+     # get the largest dimension so we can pad to a square
+     px = max(image.size)
+     # pad to square with white background
+     canvas = Image.new("RGB", (px, px), (255, 255, 255))
+     canvas.paste(image, ((px - w) // 2, (px - h) // 2))
+     return canvas
+
+
+ class WaifuDiffusionTagger:
+     def __init__(
+         self,
+         model_name: Literal["vit", "swinv2", "convnext", "eva-02"] = "eva-02",
+         device: str = "cpu",
+     ):
+         if model_name not in MODEL_REPO_MAP.keys():
+             raise ValueError(f"Model {model_name} not found. Available models: {MODEL_REPO_MAP.keys()}")
+
+         repo_id = MODEL_REPO_MAP[model_name]
+
+         self.init_model(repo_id, device)
+         self.transform = create_transform(**resolve_data_config(self.model.pretrained_cfg, model=self.model))
+
+         self.labels = self.load_labels_from_hf(repo_id)
+
+     def init_model(self, repo_id: str, device: str = "cpu"):
+         logger.info(f"Loading tagging model from {repo_id}")
+         self.model = timm.create_model("hf-hub:" + repo_id, pretrained=True)
+
+         state_dict = timm.models.load_state_dict_from_hf(repo_id)
+         self.model.load_state_dict(state_dict)
+         self.model.to(device).eval()
+
+     def load_labels_from_hf(self, repo_id: str, revision: Optional[str] = None, token: Optional[str] = None):
+         try:
+             csv_path = hf_hub_download(repo_id, filename="selected_tags.csv", revision=revision, token=token)
+             csv_path = Path(csv_path).resolve()
+         except HfHubHTTPError as e:
+             raise FileNotFoundError(f"Failed to download labels from {repo_id}") from e
+
+         df = pd.read_csv(csv_path, usecols=["name", "category"])
+         tag_data = LabelData(
+             names=df["name"].tolist(),
+             rating=np.where(df["category"] == 9)[0],
+             general=np.where(df["category"] == 0)[0],
+             character=np.where(df["category"] == 4)[0],
+         )
+         return tag_data
+
+     def prepare_inputs(self, images: List[Image.Image]):
+         inputs = []
+         for image in images:
+             image = pil_ensure_rgb(image)
+             image = pil_pad_square(image)
+             inputs += [self.transform(image)]
+
+         inputs = torch.stack(inputs, dim=0)
+         inputs = inputs[:, [2, 1, 0]]  # RGB to BGR
+
+         return inputs.to(self.device, dtype=self.dtype)
+
+     def get_tags(self, probs: torch.Tensor, gen_threshold: float) -> List[Dict[str, float]]:
+         """
+         Generate tags based on prediction probabilities and a confidence threshold.
+
+         Args:
+             probs (torch.Tensor): A tensor of shape [B, num_labels] containing
+                 prediction probabilities for each label, where B is the batch size.
+             gen_threshold (float): The confidence threshold for selecting labels.
+                 Only labels with probabilities greater than this threshold will be included.
+
+         Returns:
+             List[Dict[str, float]]: A list of dictionaries, where each dictionary
+                 corresponds to a batch element and contains label names as keys and
+                 their associated probabilities as values. The labels are sorted in
+                 descending order of probability.
+         """
+         # probs: [B, num_labels]
+         gen_labels = []
+         for prob in probs:
+             # Convert indices+probs to labels
+             prob = list(zip(self.labels.names, prob.cpu().numpy()))
+
+             # General labels, pick any where prediction confidence > threshold
+             gen_label = [prob[i] for i in self.labels.general]
+             gen_label = dict([x for x in gen_label if x[1] > gen_threshold])
+             gen_label = dict(sorted(gen_label.items(), key=lambda item: item[1], reverse=True))
+
+             gen_labels += [gen_label]
+
+         return gen_labels
+
+     @torch.inference_mode()
+     def __call__(
+         self,
+         images: Union[Image.Image, List[Image.Image]],
+         threshold: float = 0.3,
+     ):
+         """
+         Processes input images through the model and returns predicted labels based on a threshold.
+
+         Args:
+             images (Union[Image.Image, List[Image.Image]]): A single image or a list of images to be processed.
+             threshold (float, optional): The threshold value for determining labels. Defaults to 0.3.
+
+         Returns:
+             List[Dict[str, float]]: A list of label-to-confidence mappings, one per input image.
+         """
+         if not isinstance(images, list):
+             images = [images]
+
+         inputs = self.prepare_inputs(images)
+
+         outputs = self.model(inputs)
+         outputs = torch.sigmoid(outputs)
+
+         labels = self.get_tags(outputs, threshold)
+         return labels
+
+     @torch.inference_mode()
+     def get_image_features(self, images: Union[Image.Image, List[Image.Image]], global_pool: bool = True):
+         """
+         Extracts features from one or more images using the model.
+
+         Args:
+             images (Union[Image.Image, List[Image.Image]]): A single PIL Image or a list of PIL Images
+                 from which features are to be extracted.
+             global_pool (bool, optional): If True, applies global pooling to the extracted features
+                 by averaging across all spatial dimensions. If False, returns only the features
+                 corresponding to the first token. Defaults to True.
+
+         Returns:
+             torch.Tensor: A tensor containing the extracted features. If `global_pool` is True,
+                 the features are averaged across spatial dimensions. Otherwise, the features
+                 corresponding to the first token are returned.
+         """
+         if not isinstance(images, list):
+             images = [images]
+
+         inputs = self.prepare_inputs(images)
+
+         features = self.model.forward_features(inputs)
+
+         if global_pool:
+             return features[:, self.model.num_prefix_tokens :].mean(dim=1)
+         else:
+             return features[:, 0]
+
+     @property
+     def device(self):
+         return next(self.model.parameters()).device
+
+     @property
+     def dtype(self):
+         return next(self.model.parameters()).dtype
+
+
+ def show_result_with_confidence(image, tag_result, ax):
+     """Render an image on a matplotlib axis with a table of tag confidences below it."""
+     ax.imshow(image)
+
+     confidence = [[x] for x in tag_result.values()]
+     rowLabels = list(tag_result.keys())
+     ax.table(
+         cellText=confidence,
+         loc="bottom",
+         rowLabels=rowLabels,
+         cellLoc="center",
+         colLabels=["Confidence"],
+     )
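A minimal tagging sketch; the threshold mirrors the defaults used elsewhere in the repo, and the image path is the bundled sample:

from PIL import Image

from src.taggers import WaifuDiffusionTagger, sort_tags

tagger = WaifuDiffusionTagger(model_name="eva-02", device="cpu")
image = Image.open("sample_img/sample_danbooru_dragonball.png")
(tag_scores,) = tagger(image, threshold=0.3)  # dict: tag -> confidence
print(sort_tags(tag_scores, mode="score")[:10])  # people tags first, then by confidence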
src/utils/__pycache__/device.cpython-312.pyc ADDED
Binary file (659 Bytes).
src/utils/__pycache__/timer.cpython-312.pyc ADDED
Binary file (2.46 kB).
src/utils/device.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+
+
+ def determine_accelerator():
+     """
+     Determine the accelerator to be used based on the environment.
+     """
+
+     # Check for CUDA availability
+     if torch.cuda.is_available():
+         return "cuda"
+
+     # Check for MPS (Metal Performance Shaders) availability on macOS
+     if torch.backends.mps.is_available():
+         return "mps"
+
+     # Default to CPU if no accelerators are available
+     return "cpu"
src/utils/timer.py ADDED
@@ -0,0 +1,51 @@
+ import logging
+ import time
+
+
+ def msec2human(ms) -> str:
+     """
+     Converts milliseconds to a human-readable string representation.
+
+     Args:
+         ms (int): The input number of milliseconds.
+
+     Returns:
+         str: The formatted string representing the milliseconds in a human-readable format.
+     """
+     s = ms // 1000  # Calculate the number of seconds
+     m = s // 60  # Calculate the number of minutes
+     h = m // 60  # Calculate the number of hours
+
+     m %= 60  # Get the remaining minutes after calculating hours
+     s %= 60  # Get the remaining seconds after calculating minutes
+     ms %= 1000  # Get the remaining milliseconds after calculating seconds
+
+     if h:
+         return f"{h} hour {m:2d} min"  # Formatted string with hours and minutes
+     if m:
+         return f"{m} min {s:2d} sec"  # Formatted string with minutes and seconds
+     if s:
+         return f"{s} sec {ms:3d} msec"  # Formatted string with seconds and milliseconds
+     return f"{ms} msec"  # Formatted string with milliseconds only
+
+
+ class ElapsedTimer:
+     def __init__(self, name, logger=None):
+         self.name = name
+         self.logger = logger or logging.getLogger(__name__)
+
+     def __enter__(self):
+         self.start_time = time.perf_counter()
+         self.logger.info(f"<{self.name}>: start")
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         elapsed_time = time.perf_counter() - self.start_time
+         elapsed_time = msec2human(int(elapsed_time * 1000))
+
+         if exc_type:
+             self.logger.warning(f"<{self.name}> raised {exc_type}, {elapsed_time}")
+         else:
+             self.logger.info(f"<{self.name}>: {elapsed_time}")
src/wise_crop/__pycache__/detect_and_crop.cpython-312.pyc ADDED
Binary file (10.5 kB).
src/wise_crop/detect_and_crop.py ADDED
@@ -0,0 +1,84 @@
+ import cv2
+ import numpy as np
+ import torch
+ from imgutils.detect import detect_heads
+ from torchvision.transforms.v2 import ToPILImage
+ from transformers import pipeline
+
+ from src.utils.device import determine_accelerator
+
+ topil = ToPILImage()
+
+ # 1. Initialize the filtering pipeline
+ device = determine_accelerator()
+
+ print("Loading AI Model...")
+ pipe = pipeline(
+     "image-text-to-text",
+     model="google/gemma-3-12b-it",
+     device=device,
+     torch_dtype=torch.bfloat16,
+ )
+
+
+ def crop_and_mask_characters_gradio(pil_img):
+     """
+     Crops character regions from an image and returns their coordinates,
+     keeping only regions that contain most of a head or face according to
+     a head detector and the Gemma 3 12B instruction-tuned model.
+
+     Args:
+         pil_img (Image.Image): The input image.
+
+     Returns:
+         list[tuple[int, tuple[int, int, int, int]]]: (index, (x, y, w, h)) pairs
+         for each kept region.
+     """
+     img = np.array(pil_img)
+
+     # Convert the image to grayscale (arrays from PIL are RGB-ordered)
+     gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+
+     # Apply thresholding to create a binary image
+     _, thresh = cv2.threshold(gray, 253, 255, cv2.THRESH_BINARY_INV)
+
+     # Find contours in the binary image
+     contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+     coord_info_list = []
+     i = 0
+     # Iterate through the contours and crop the regions
+     for contour in contours:
+         # Get the bounding box of the contour
+         x, y, w, h = cv2.boundingRect(contour)
+         if w < 256 or h < 256:  # Skip small contours
+             continue
+
+         # Crop the region
+         cropped_img = img[y:y+h, x:x+w]
+
+         messages = [
+             {
+                 "role": "system",
+                 "content": [{"type": "text", "text": "You are a helpful assistant."}]
+             },
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": topil(cropped_img)},
+                     {"type": "text", "text": "You are given a black-and-white line drawing as input. Please analyze the image carefully. If the drawing contains the majority of a head or face—meaning most key facial features or the overall shape of the head are visible—respond with 'True'. Otherwise, respond with 'False'. Do not contain any punctuation or extra spaces in your answer. Just respond with 'True' or 'False'"}
+                 ]
+             }
+         ]
+         # Cheap pre-filter: skip regions where the head detector finds nothing
+         result = detect_heads(topil(cropped_img))
+         if len(result) == 0:
+             continue
+
+         # Ask the VLM whether the crop contains the majority of a head or face
+         output = pipe(text=messages, max_new_tokens=200)
+         if output[0]["generated_text"][-1]["content"].strip() == "False":
+             continue
+         i += 1
+         # Append the coordinates to the list
+         coord_info_list.append((i, (x, y, w, h)))
+     return coord_info_list
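A hedged usage sketch; importing the module loads the Gemma 3 pipeline, and the page path is the bundled sample, used here purely for illustration:

from PIL import Image

from src.wise_crop.detect_and_crop import crop_and_mask_characters_gradio

page = Image.open("sample_img/sample_danbooru_dragonball.png").convert("RGB")
for index, (x, y, w, h) in crop_and_mask_characters_gradio(page):
    page.crop((x, y, x + w, y + h)).save(f"character_{index:02d}.png")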