Spaces:
Running
Running
Commit
·
13aa528
1
Parent(s):
a56ad34
init commit
Browse files- LICENSE +61 -0
- README.md +72 -5
- app.py +459 -0
- requirements.txt +23 -0
- sample_img/sample_danbooru_dragonball.png +3 -0
- scripts/__init__.py +0 -0
- scripts/__pycache__/__init__.cpython-312.pyc +0 -0
- scripts/__pycache__/parse_cut_from_page.cpython-312.pyc +0 -0
- scripts/convert_psd_to_png.py +106 -0
- scripts/parse_cut_from_page.py +248 -0
- scripts/run_tag_filter.py +147 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/detectors/__init__.py +3 -0
- src/detectors/__pycache__/__init__.cpython-312.pyc +0 -0
- src/detectors/__pycache__/imgutils_detector.cpython-312.pyc +0 -0
- src/detectors/imgutils_detector.py +170 -0
- src/oskar_crop/__pycache__/detect_and_crop.cpython-312.pyc +0 -0
- src/oskar_crop/detect_and_crop.py +56 -0
- src/pipelines/__init__.py +3 -0
- src/pipelines/__pycache__/__init__.cpython-312.pyc +0 -0
- src/pipelines/__pycache__/pipeline_single_character_filtering.cpython-312.pyc +0 -0
- src/pipelines/pipeline_single_character_filtering.py +175 -0
- src/taggers/__init__.py +4 -0
- src/taggers/__pycache__/__init__.cpython-312.pyc +0 -0
- src/taggers/__pycache__/order.cpython-312.pyc +0 -0
- src/taggers/__pycache__/tagger.cpython-312.pyc +0 -0
- src/taggers/filter.py +113 -0
- src/taggers/order.py +85 -0
- src/taggers/tagger.py +215 -0
- src/utils/__pycache__/device.cpython-312.pyc +0 -0
- src/utils/__pycache__/timer.cpython-312.pyc +0 -0
- src/utils/device.py +18 -0
- src/utils/timer.py +51 -0
- src/wise_crop/__pycache__/detect_and_crop.cpython-312.pyc +0 -0
- src/wise_crop/detect_and_crop.py +84 -0
LICENSE
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
This software project ("the Software") comprises original code and incorporates third-party components as detailed below. Each component is subject to its respective license terms.
|
2 |
+
|
3 |
+
1. Original Code
|
4 |
+
|
5 |
+
The original code within this repository, excluding third-party components, is licensed under the MIT License:
|
6 |
+
|
7 |
+
MIT License
|
8 |
+
|
9 |
+
Copyright (c) 2025 AI Lab, Kakao Entertainment
|
10 |
+
|
11 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
12 |
+
of this software and associated documentation files (the "Software"), to deal
|
13 |
+
in the Software without restriction, including without limitation the rights
|
14 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
15 |
+
copies of the Software, and to permit persons to whom the Software is
|
16 |
+
furnished to do so, subject to the following conditions:
|
17 |
+
|
18 |
+
The above copyright notice and this permission notice shall be included in all
|
19 |
+
copies or substantial portions of the Software.
|
20 |
+
|
21 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
22 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
23 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
24 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
25 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
26 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
27 |
+
SOFTWARE.
|
28 |
+
|
29 |
+
2. Third-Party Components
|
30 |
+
a. Google Gemma 3 - 12B Instruction-Tuned Model (google/gemma-3-12b-it)
|
31 |
+
|
32 |
+
This project utilizes the "google/gemma-3-12b-it" model, which is governed by Google's Gemma Terms of Use. Access to and use of this model require agreement to these terms, which include specific restrictions on distribution, modification, and usage. For detailed information, please refer to the Gemma Terms of Use.
|
33 |
+
|
34 |
+
Note: Ensure compliance with Google's terms when using, distributing, or modifying this model.
|
35 |
+
b. imgutils Python Package
|
36 |
+
|
37 |
+
The imgutils package is employed for image processing tasks within this project. As of the latest available information, imgutils does not explicitly specify a license. In the absence of a declared license, the usage rights are undefined, and caution is advised. It's recommended to contact the package maintainers or consult the source repository for clarification before using or distributing this package.
|
38 |
+
|
39 |
+
3. Additional Dependencies
|
40 |
+
|
41 |
+
This project also relies on several other open-source packages, each with its own licensing terms:
|
42 |
+
|
43 |
+
Gradio: Licensed under the Apache License 2.0.
|
44 |
+
|
45 |
+
Pillow (PIL): Licensed under the Historical Permission Notice and Disclaimer (HPND).
|
46 |
+
|
47 |
+
psd-tools: Licensed under the MIT License.
|
48 |
+
|
49 |
+
OpenCV: Licensed under the Apache License 2.0.
|
50 |
+
|
51 |
+
Please refer to each package's documentation for detailed license information.
|
52 |
+
|
53 |
+
4. Usage Guidelines
|
54 |
+
|
55 |
+
Users of this software must ensure compliance with all applicable licenses, especially when distributing or modifying the software. This includes adhering to the terms set forth by third-party components. Failure to comply with these terms may result in legal consequences.
|
56 |
+
|
57 |
+
5. Disclaimer
|
58 |
+
|
59 |
+
This software is provided "as is," without warranty of any kind, express or implied. The authors are not liable for any damages or legal issues arising from the use of this software. Users are responsible for ensuring that their use of the software complies with all applicable laws and regulations.
|
60 |
+
|
61 |
+
For any questions or concerns regarding this license, please contact `[email protected]`.
|
README.md
CHANGED
@@ -1,13 +1,80 @@
|
|
1 |
---
|
2 |
-
title: Webtoon Cropper
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.25.2
|
8 |
app_file: app.py
|
|
|
9 |
pinned: false
|
10 |
-
short_description: Webtoon Cropper
|
11 |
---
|
|
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: Webtoon Lineart Cropper
|
3 |
+
emoji: 🤗
|
4 |
+
colorFrom: yellow
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.25.2
|
8 |
app_file: app.py
|
9 |
+
python_version: 3.12.10
|
10 |
pinned: false
|
|
|
11 |
---
|
12 |
+
# Helix-Painting Data Tool
|
13 |
|
14 |
+
웹툰 PSD 파일을 이미지 파일(PNG)로 변환하고, 컷 추출 및 인물 검출 등을 통해 데이터를 추출하는 도구입니다.
|
15 |
+
|
16 |
+
## Prerequisites
|
17 |
+
|
18 |
+
[Prerequisites](docs/PREREQUISITES.md)
|
19 |
+
|
20 |
+
## Setup a project
|
21 |
+
|
22 |
+
[Setup a project](docs/SETUP.md)
|
23 |
+
|
24 |
+
## Features
|
25 |
+
|
26 |
+
- **PSD → PNG 변환**
|
27 |
+
웹툰 PSD 파일을 고해상도 PNG 이미지로 변환합니다.
|
28 |
+
- **컷 추출 및 필터링**
|
29 |
+
이미지 내 흰색 영역을 기준으로 컷 박스를 추출하고, 인물 검출(face detector)과 태깅(tagger) 기능을 통해 컷 별 데이터를 생성합니다.
|
30 |
+
- **병렬 처리**
|
31 |
+
처리 속도를 높이기 위해 Python의 multiprocessing을 사용한 병렬 이미지 처리 기능을 제공합니다.
|
32 |
+
|
33 |
+
## Usage
|
34 |
+
|
35 |
+
### 1. PSD 파일을 PNG 이미지로 변환
|
36 |
+
`convert_psd_to_png.py` 스크립트를 사용하여 PSD 파일들을 PNG 이미지로 변환합니다.
|
37 |
+
|
38 |
+
```shell
|
39 |
+
python scripts/convert_psd_to_png.py --directory <PSD_directory> --output <output_directory> [--visible_layers layer1 layer2 ...] [--invisible_layers layer3 layer4 ...]
|
40 |
+
```
|
41 |
+
|
42 |
+
- `--directory` : PSD 파일을 검색할 디렉토리
|
43 |
+
- `--output` : 변환된 PNG 파일을 저장할 디렉토리
|
44 |
+
- `--visible_layers` / `--invisible_layers` : 보이거나 숨길 레이어들을 지정
|
45 |
+
|
46 |
+
### 2. 이미지 컷 추출 및 필터링
|
47 |
+
`parse_cut_from_page.py` 스크립트를 사용하여 컷 박스를 추출하고, 각 컷에 대해 필터링된 데이터를 생성합니다.
|
48 |
+
|
49 |
+
```shell
|
50 |
+
python scripts/parse_cut_from_page.py --lineart <lineart_directory> --flat <flat_directory> --segmentation <segmentation_directory> --color <color_directory> --output <output_directory> [--num_process <number_of_processes>]
|
51 |
+
```
|
52 |
+
|
53 |
+
- `--lineart` : 라인아트 이미지가 저장된 디렉토리
|
54 |
+
- `--flat` : 채색 전 평면 이미지가 저장된 디렉토리
|
55 |
+
- `--segmentation` : 세분화 이미지가 저장된 디렉토리
|
56 |
+
- `--color` : 컬러 이미지가 저장된 디렉토리
|
57 |
+
- `--output` : 잘라낸 컷 이미지를 저장할 디렉토리
|
58 |
+
- `--num_process` : 병렬 처리에 사용할 프로세스 수 (선택값)
|
59 |
+
|
60 |
+
### 3. 태깅 및 필터링 실행
|
61 |
+
`run_tag_filter.py` 스크립트를 사용하여 이미지에 태깅 및 필터링을 수행하고, 결과를 저장합니다.
|
62 |
+
|
63 |
+
```shell
|
64 |
+
python scripts/run_tag_filter.py --input_dir <input_directory> --output_dir <output_directory> [--ext png jpg jpeg]
|
65 |
+
```
|
66 |
+
|
67 |
+
- `--input_dir` : 필터링할 이미지들이 저장된 디렉토리
|
68 |
+
- `--output_dir` : 필터링 결과 이미지 및 캡션 파일을 저장할 디렉토리
|
69 |
+
- `--ext` : 처리할 이미지 확장자 목록 (기본값: png)
|
70 |
+
|
71 |
+
---
|
72 |
+
|
73 |
+
이와 같이 각 스크립트는 별도의 인자들을 받아서 작업을 수행하며, 스크립트 내부의 로직에 따라 PSD 변환, 컷 추출, 인물 검출/태깅, 캡션 파일 생성 등의 기능을 제공합니다.
|
74 |
+
|
75 |
+
### 4. Gradio 데모 페이지
|
76 |
+
아래 명렁어를 실행하여 Gradio 데모 페이지로 원하는 PNG파일을 업로드하여 필터링된 이미지 컷 추출 결과를 ZIP파일 혹은 PSD파일로 확인할 수 있습니다.
|
77 |
+
|
78 |
+
```shell
|
79 |
+
python app.py
|
80 |
+
```
|
app.py
ADDED
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import zipfile
|
4 |
+
import shutil # For make_archive
|
5 |
+
import uuid
|
6 |
+
from PIL import Image, ImageDraw
|
7 |
+
from psd_tools import PSDImage
|
8 |
+
from psd_tools.api.layers import PixelLayer
|
9 |
+
import gradio as gr
|
10 |
+
import traceback # For printing stack traces
|
11 |
+
import subprocess
|
12 |
+
|
13 |
+
def install(package):
|
14 |
+
subprocess.check_call([os.sys.executable, "-m", "pip", "install", package])
|
15 |
+
|
16 |
+
install("timm")
|
17 |
+
install("pydantic==2.10.6")
|
18 |
+
install("dghs-imgutils==0.15.0")
|
19 |
+
install("onnxruntime >= 1.17.0")
|
20 |
+
install("psd_tools==1.10.7")
|
21 |
+
# os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
|
22 |
+
# --- Attempt to import the actual function (Detector 1) ---
|
23 |
+
# Let's keep the original import name as requested in the previous version
|
24 |
+
try:
|
25 |
+
# Assuming this is the intended "Detector 1"
|
26 |
+
from src.wise_crop.detect_and_crop import crop_and_mask_characters_gradio
|
27 |
+
detector_1_available = True
|
28 |
+
print("Successfully imported 'crop_and_mask_characters_gradio' as Detector 1.")
|
29 |
+
except ImportError:
|
30 |
+
detector_1_available = False
|
31 |
+
print("Warning: Could not import 'crop_and_mask_characters_gradio'. Using dummy function for Detector 1.")
|
32 |
+
# Define a dummy version for Detector 1 if import fails
|
33 |
+
def crop_and_mask_characters_gradio(image_pil: Image.Image):
|
34 |
+
"""Dummy function 1 if import fails."""
|
35 |
+
print("Using DUMMY Detector 1.")
|
36 |
+
if image_pil is None: return []
|
37 |
+
width, height = image_pil.size
|
38 |
+
boxes = [
|
39 |
+
(0, (int(width * 0.1), int(height * 0.1), int(width * 0.3), int(height * 0.4))),
|
40 |
+
(1, (int(width * 0.6), int(height * 0.5), int(width * 0.25), int(height * 0.35))),
|
41 |
+
]
|
42 |
+
valid_boxes = []
|
43 |
+
for i, (x, y, w, h) in boxes:
|
44 |
+
x1, y1, x2, y2 = max(0, x), max(0, y), min(width, x + w), min(height, y + h)
|
45 |
+
if x2 - x1 > 0 and y2 - y1 > 0: valid_boxes.append((i, (x1, y1, x2 - x1, y2 - y1)))
|
46 |
+
return valid_boxes
|
47 |
+
|
48 |
+
# from src.oskar_crop.detect_and_crop import process_single_image as detector_2_function
|
49 |
+
try:
|
50 |
+
# Assuming this is the intended "Detector 2"
|
51 |
+
# Note: Renamed the import alias to avoid conflict if both imports succeed.
|
52 |
+
# The function call inside process_lineart still uses crop_and_mask_characters_gradio_2
|
53 |
+
from src.oskar_crop.detect_and_crop import process_single_image as detector_2_function
|
54 |
+
detector_2_available = True
|
55 |
+
print("Successfully imported 'process_single_image' as Detector 2.")
|
56 |
+
# Define the function name used in process_lineart
|
57 |
+
def crop_and_mask_characters_gradio_2(image_pil: Image.Image):
|
58 |
+
return detector_2_function(image_pil)
|
59 |
+
|
60 |
+
except ImportError:
|
61 |
+
detector_2_available = False
|
62 |
+
print("Warning: Could not import 'process_single_image'. Using dummy function for Detector 2.")
|
63 |
+
# Define a dummy version for Detector 2 if import fails
|
64 |
+
# --- Define the SECOND Dummy Detection Function ---
|
65 |
+
def crop_and_mask_characters_gradio_2(image_pil: Image.Image):
|
66 |
+
"""
|
67 |
+
SECOND Dummy function to simulate detecting objects and returning bounding boxes.
|
68 |
+
Returns different results than the first function.
|
69 |
+
"""
|
70 |
+
print("Using DUMMY Detector 2.")
|
71 |
+
if image_pil is None:
|
72 |
+
return []
|
73 |
+
|
74 |
+
width, height = image_pil.size
|
75 |
+
print(f"Dummy detection 2 running on image size: {width}x{height}")
|
76 |
+
|
77 |
+
# Define DIFFERENT fixed bounding boxes for demonstration
|
78 |
+
boxes = [
|
79 |
+
(0, (int(width * 0.05), int(height * 0.6), int(width * 0.4), int(height * 0.3))), # Bottom-leftish, wider
|
80 |
+
(1, (int(width * 0.7), int(height * 0.1), int(width * 0.20), int(height * 0.25))), # Top-rightish, smaller
|
81 |
+
(2, (int(width * 0.4), int(height * 0.4), int(width * 0.15), int(height * 0.15))), # Center-ish, very small
|
82 |
+
]
|
83 |
+
|
84 |
+
# Basic validation
|
85 |
+
valid_boxes = []
|
86 |
+
for i, (x, y, w, h) in boxes:
|
87 |
+
x1 = max(0, x)
|
88 |
+
y1 = max(0, y)
|
89 |
+
x2 = min(width, x + w)
|
90 |
+
y2 = min(height, y + h)
|
91 |
+
new_w = x2 - x1
|
92 |
+
new_h = y2 - y1
|
93 |
+
if new_w > 0 and new_h > 0:
|
94 |
+
valid_boxes.append((i, (x1, y1, new_w, new_h)))
|
95 |
+
|
96 |
+
print(f"Dummy detection 2 found {len(valid_boxes)} boxes.")
|
97 |
+
return valid_boxes
|
98 |
+
|
99 |
+
|
100 |
+
# --- Helper Function (make_lineart_transparent - unchanged) ---
|
101 |
+
def make_lineart_transparent(lineart_path, threshold=200):
|
102 |
+
"""Converts a lineart image file to a transparent RGBA PIL Image."""
|
103 |
+
try:
|
104 |
+
# Ensure we handle potential pathlib objects if Gradio passes them
|
105 |
+
lineart_gray = Image.open(str(lineart_path)).convert('L')
|
106 |
+
w, h = lineart_gray.size
|
107 |
+
lineart_rgba = Image.new('RGBA', (w, h), (0, 0, 0, 0))
|
108 |
+
gray_pixels = lineart_gray.load()
|
109 |
+
rgba_pixels = lineart_rgba.load()
|
110 |
+
for y in range(h):
|
111 |
+
for x in range(w):
|
112 |
+
gray_val = gray_pixels[x, y]
|
113 |
+
alpha = 255 - gray_val
|
114 |
+
if gray_val < threshold :
|
115 |
+
rgba_pixels[x, y] = (0, 0, 0, alpha)
|
116 |
+
else:
|
117 |
+
rgba_pixels[x, y] = (0, 0, 0, 0)
|
118 |
+
return lineart_rgba
|
119 |
+
except FileNotFoundError:
|
120 |
+
print(f"Helper Error: Image file not found at {lineart_path}")
|
121 |
+
# Return a blank transparent image or None? Returning None is clearer.
|
122 |
+
return None
|
123 |
+
except Exception as e:
|
124 |
+
print(f"Helper Error processing image {lineart_path}: {e}")
|
125 |
+
return None
|
126 |
+
|
127 |
+
# --- Main Processing Function (modified for better error handling with PIL) ---
|
128 |
+
def process_lineart(input_pil_or_path, detector_choice): # Input can be PIL or path from examples
|
129 |
+
"""
|
130 |
+
Processes the input lineart image using the selected detector.
|
131 |
+
Detects objects (e.g., characters based on head/face), crops them,
|
132 |
+
provides a gallery of crops, a ZIP file of crops, and a PSD file
|
133 |
+
with the original lineart (made transparent) and bounding boxes.
|
134 |
+
"""
|
135 |
+
# --- Initialize variables ---
|
136 |
+
input_pil_image = None
|
137 |
+
temp_input_path = None
|
138 |
+
using_temp_input_path = False
|
139 |
+
status_updates = ["Status: Initializing..."]
|
140 |
+
psd_output_path = None # Initialize to None
|
141 |
+
zip_output_path = None # Initialize to None
|
142 |
+
cropped_images_for_gallery = [] # Initialize to empty list
|
143 |
+
|
144 |
+
try:
|
145 |
+
# --- Handle Input ---
|
146 |
+
if input_pil_or_path is None:
|
147 |
+
gr.Warning("Please upload a PNG image or select an example.")
|
148 |
+
return [], None, None, "Status: No image provided."
|
149 |
+
|
150 |
+
print(f"Input type: {type(input_pil_or_path)}")
|
151 |
+
print(f"Input value: {input_pil_or_path}")
|
152 |
+
# Check if input is already a PIL image (from upload) or a path (from examples)
|
153 |
+
if isinstance(input_pil_or_path, Image.Image):
|
154 |
+
input_pil_image = input_pil_or_path
|
155 |
+
print("Processing PIL image from upload.")
|
156 |
+
# Create a temporary path for make_lineart_transparent if needed later
|
157 |
+
temp_input_fd, temp_input_path = tempfile.mkstemp(suffix=".png")
|
158 |
+
os.close(temp_input_fd)
|
159 |
+
input_pil_image.save(temp_input_path, "PNG")
|
160 |
+
using_temp_input_path = True
|
161 |
+
elif isinstance(input_pil_or_path, str) and os.path.exists(input_pil_or_path):
|
162 |
+
print(f"Processing image from file path: {input_pil_or_path}")
|
163 |
+
try:
|
164 |
+
input_pil_image = Image.open(input_pil_or_path)
|
165 |
+
# Use the example path directly for make_lineart_transparent
|
166 |
+
temp_input_path = input_pil_or_path
|
167 |
+
using_temp_input_path = False # Don't delete the example file later
|
168 |
+
except Exception as e:
|
169 |
+
status_updates.append(f"ERROR: Could not open image file from path '{input_pil_or_path}': {e}")
|
170 |
+
print(status_updates[-1])
|
171 |
+
return [], None, None, "\n".join(status_updates) # Return error status
|
172 |
+
else:
|
173 |
+
status_updates.append(f"ERROR: Invalid input type received: {type(input_pil_or_path)}. Expected PIL image or file path.")
|
174 |
+
print(status_updates[-1])
|
175 |
+
return [], None, None, "\n".join(status_updates) # Return error status
|
176 |
+
|
177 |
+
# --- Ensure RGBA and get dimensions ---
|
178 |
+
try:
|
179 |
+
input_pil_image = input_pil_image.convert("RGBA")
|
180 |
+
width, height = input_pil_image.size
|
181 |
+
except Exception as e:
|
182 |
+
status_updates.append(f"ERROR: Could not process input image (convert/get size): {e}")
|
183 |
+
print(status_updates[-1])
|
184 |
+
# Clean up temp file if created before error
|
185 |
+
if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
|
186 |
+
try: os.remove(temp_input_path)
|
187 |
+
except Exception as e_rem: print(f"Warning: Could not remove temp input file {temp_input_path}: {e_rem}")
|
188 |
+
return [], None, None, "\n".join(status_updates) # Return error status
|
189 |
+
|
190 |
+
status_updates = [f"Status: Processing started using {detector_choice}."] # Reset status
|
191 |
+
print("Starting processing...")
|
192 |
+
|
193 |
+
# --- 1. Detect Objects (Conditional) ---
|
194 |
+
print(f"Selected detector: {detector_choice}")
|
195 |
+
if detector_choice == "Detector 1":
|
196 |
+
if not detector_1_available:
|
197 |
+
status_updates.append("Warning: Using DUMMY Detector 1.")
|
198 |
+
boxes_info = crop_and_mask_characters_gradio(input_pil_image)
|
199 |
+
elif detector_choice == "Detector 2":
|
200 |
+
if not detector_2_available:
|
201 |
+
status_updates.append("Warning: Using DUMMY Detector 2.")
|
202 |
+
boxes_info = crop_and_mask_characters_gradio_2(input_pil_image)
|
203 |
+
else:
|
204 |
+
# This case should ideally not happen with Radio buttons, but good for safety
|
205 |
+
status_updates.append(f"ERROR: Invalid detector choice received: {detector_choice}")
|
206 |
+
print(status_updates[-1])
|
207 |
+
# Clean up temp file if created before error
|
208 |
+
if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
|
209 |
+
try: os.remove(temp_input_path)
|
210 |
+
except Exception as e_rem: print(f"Warning: Could not remove temp input file {temp_input_path}: {e_rem}")
|
211 |
+
return [], None, None, "\n".join(status_updates) # Return error status
|
212 |
+
|
213 |
+
if not boxes_info:
|
214 |
+
gr.Warning("No objects detected.")
|
215 |
+
status_updates.append("No objects detected.")
|
216 |
+
# Clean up temp file if created
|
217 |
+
if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
|
218 |
+
try: os.remove(temp_input_path)
|
219 |
+
except Exception as e_rem: print(f"Warning: Could not remove temp input file {temp_input_path}: {e_rem}")
|
220 |
+
return [], None, None, "\n".join(status_updates)
|
221 |
+
|
222 |
+
status_updates.append(f"Detected {len(boxes_info)} objects.")
|
223 |
+
print(f"Detected boxes: {boxes_info}")
|
224 |
+
|
225 |
+
# --- Temporary file paths (partially adjusted) ---
|
226 |
+
temp_dir_for_outputs = tempfile.gettempdir()
|
227 |
+
unique_id = uuid.uuid4().hex[:8]
|
228 |
+
zip_base_name = os.path.join(temp_dir_for_outputs, f"cropped_images_{unique_id}")
|
229 |
+
zip_output_path = f"{zip_base_name}.zip" # Path for the final zip file
|
230 |
+
psd_output_path = os.path.join(temp_dir_for_outputs, f"lineart_boxes_{unique_id}.psd")
|
231 |
+
# temp_input_path is already handled above based on input source
|
232 |
+
|
233 |
+
# --- 2. Crop Images and Prepare for ZIP ---
|
234 |
+
with tempfile.TemporaryDirectory() as temp_crop_dir:
|
235 |
+
print(f"Saving cropped images to temporary directory: {temp_crop_dir}")
|
236 |
+
for i, (x, y, w, h) in boxes_info:
|
237 |
+
# Ensure box coordinates are within image bounds
|
238 |
+
x1, y1 = max(0, x), max(0, y)
|
239 |
+
x2, y2 = min(width, x + w), min(height, y + h)
|
240 |
+
box = (x1, y1, x2, y2)
|
241 |
+
if box[2] > box[0] and box[3] > box[1]: # Check if width and height are positive
|
242 |
+
try:
|
243 |
+
cropped_img = input_pil_image.crop(box)
|
244 |
+
cropped_images_for_gallery.append(cropped_img)
|
245 |
+
crop_filename = os.path.join(temp_crop_dir, f"cropped_{i}.png")
|
246 |
+
cropped_img.save(crop_filename, "PNG")
|
247 |
+
except Exception as e:
|
248 |
+
print(f"Error cropping or saving box {i} with coords {box}: {e}")
|
249 |
+
status_updates.append(f"Warning: Error processing crop {i}.")
|
250 |
+
else:
|
251 |
+
print(f"Skipping invalid box {i} with coords {box}")
|
252 |
+
status_updates.append(f"Warning: Skipped invalid crop dimensions for box {i}.")
|
253 |
+
|
254 |
+
|
255 |
+
# --- 3. Create ZIP File ---
|
256 |
+
# Check if any PNG files were actually created in the temp dir
|
257 |
+
if any(f.endswith(".png") for f in os.listdir(temp_crop_dir)):
|
258 |
+
print(f"Creating ZIP file: {zip_output_path} from {temp_crop_dir}")
|
259 |
+
try:
|
260 |
+
shutil.make_archive(zip_base_name, 'zip', temp_crop_dir)
|
261 |
+
status_updates.append("Cropped images ZIP created.")
|
262 |
+
# zip_output_path is already correctly set
|
263 |
+
except Exception as e:
|
264 |
+
print(f"Error creating ZIP file: {e}")
|
265 |
+
status_updates.append("Error: Failed to create ZIP file.")
|
266 |
+
zip_output_path = None # Indicate failure
|
267 |
+
else:
|
268 |
+
print("No valid cropped images were saved, skipping ZIP creation.")
|
269 |
+
status_updates.append("Skipping ZIP creation (no valid crops).")
|
270 |
+
zip_output_path = None # No zip file to provide
|
271 |
+
|
272 |
+
# --- 4. Prepare PSD Layers ---
|
273 |
+
# a) Line Layer (Use the temp_input_path which is either the original example path or a temp copy)
|
274 |
+
print(f"Using image path for transparent layer: {temp_input_path}")
|
275 |
+
line_layer_pil = make_lineart_transparent(temp_input_path)
|
276 |
+
if line_layer_pil is None:
|
277 |
+
status_updates.append("Error: Failed to create transparent lineart layer.")
|
278 |
+
print(status_updates[-1])
|
279 |
+
# Don't create PSD if lineart failed, return current results
|
280 |
+
# Clean up temp file if created
|
281 |
+
if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
|
282 |
+
try: os.remove(temp_input_path)
|
283 |
+
except Exception as e_rem: print(f"Warning: Could not remove temp input file {temp_input_path}: {e_rem}")
|
284 |
+
return cropped_images_for_gallery, zip_output_path, None, "\n".join(status_updates) # Return None for PSD
|
285 |
+
|
286 |
+
status_updates.append("Transparent lineart layer created.")
|
287 |
+
|
288 |
+
# b) Box Layer
|
289 |
+
box_layer_pil = Image.new('RGBA', (width, height), (255, 255, 255, 255)) # White background
|
290 |
+
draw = ImageDraw.Draw(box_layer_pil)
|
291 |
+
for i, (x, y, w, h) in boxes_info:
|
292 |
+
# Use validated coords again, ensure they are within bounds
|
293 |
+
x1, y1 = max(0, x), max(0, y)
|
294 |
+
x2, y2 = min(width, x + w), min(height, y + h)
|
295 |
+
if x2 > x1 and y2 > y1: # Check validity again just in case
|
296 |
+
rect = [(x1, y1), (x2, y2)]
|
297 |
+
# Changed to fill for solid boxes, yellow fill, semi-transparent
|
298 |
+
draw.rectangle(rect, fill=(255, 255, 0, 128))
|
299 |
+
status_updates.append("Bounding box layer created.")
|
300 |
+
|
301 |
+
# --- 5. Create PSD File ---
|
302 |
+
print(f"Creating PSD file: {psd_output_path}")
|
303 |
+
# Double check layer sizes before creating PSD object
|
304 |
+
if line_layer_pil.size != (width, height) or box_layer_pil.size != (width, height):
|
305 |
+
size_error_msg = (f"Error: Layer size mismatch during PSD creation. "
|
306 |
+
f"Line: {line_layer_pil.size}, Box: {box_layer_pil.size}, "
|
307 |
+
f"Expected: {(width, height)}")
|
308 |
+
status_updates.append(size_error_msg)
|
309 |
+
print(size_error_msg)
|
310 |
+
# Clean up temp file if created
|
311 |
+
if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
|
312 |
+
try: os.remove(temp_input_path)
|
313 |
+
except Exception as e_rem: print(f"Warning: Could not remove temp input file {temp_input_path}: {e_rem}")
|
314 |
+
return cropped_images_for_gallery, zip_output_path, None, "\n".join(status_updates) # No PSD
|
315 |
+
|
316 |
+
try:
|
317 |
+
psd = PSDImage.new(mode='RGBA', size=(width, height))
|
318 |
+
# Add layers (order matters for visibility in PSD viewers)
|
319 |
+
# Base layer is transparent by default with RGBA
|
320 |
+
psd.append(PixelLayer.frompil(line_layer_pil, layer_name='line', top=0, left=0))
|
321 |
+
psd.append(PixelLayer.frompil(box_layer_pil, layer_name='box', top=0, left=0))
|
322 |
+
psd.save(psd_output_path)
|
323 |
+
status_updates.append("PSD file created.")
|
324 |
+
except Exception as e:
|
325 |
+
print(f"Error saving PSD file: {e}")
|
326 |
+
traceback.print_exc()
|
327 |
+
status_updates.append("Error: Failed to save PSD file.")
|
328 |
+
psd_output_path = None # Indicate failure
|
329 |
+
|
330 |
+
|
331 |
+
print("Processing finished.")
|
332 |
+
status_updates.append("Success!")
|
333 |
+
final_status = "\n".join(status_updates)
|
334 |
+
|
335 |
+
# Return all paths, even if None (Gradio handles None for File output)
|
336 |
+
return cropped_images_for_gallery, zip_output_path, psd_output_path, final_status
|
337 |
+
|
338 |
+
except Exception as e:
|
339 |
+
print(f"An unexpected error occurred in process_lineart: {e}")
|
340 |
+
traceback.print_exc()
|
341 |
+
status_updates.append(f"FATAL ERROR: {e}")
|
342 |
+
final_status = "\n".join(status_updates)
|
343 |
+
# Return empty/None outputs and the error status
|
344 |
+
# Ensure cleanup happens even on fatal error
|
345 |
+
if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
|
346 |
+
try:
|
347 |
+
os.remove(temp_input_path)
|
348 |
+
print(f"Cleaned up temporary input file due to error: {temp_input_path}")
|
349 |
+
except Exception as e_rem:
|
350 |
+
print(f"Warning: Could not remove temp input file {temp_input_path} during error handling: {e_rem}")
|
351 |
+
return [], None, None, final_status # Return safe defaults
|
352 |
+
|
353 |
+
finally:
|
354 |
+
# --- Final Cleanup (Only removes temp input if created from upload) ---
|
355 |
+
if using_temp_input_path and temp_input_path and os.path.exists(temp_input_path):
|
356 |
+
try:
|
357 |
+
os.remove(temp_input_path)
|
358 |
+
print(f"Cleaned up temporary input file: {temp_input_path}")
|
359 |
+
except Exception as e_rem:
|
360 |
+
# This might happen if the file was already removed in an error block
|
361 |
+
print(f"Notice: Could not remove temp input file {temp_input_path} in finally block (may already be removed): {e_rem}")
|
362 |
+
|
363 |
+
|
364 |
+
# --- Gradio Interface Definition (modified) ---
|
365 |
+
css = '''
|
366 |
+
.custom-gallery {
|
367 |
+
height: 500px !important;
|
368 |
+
width: 100%;
|
369 |
+
margin: 10px auto;
|
370 |
+
padding: 0px;
|
371 |
+
overflow-y: auto !important;
|
372 |
+
}
|
373 |
+
'''
|
374 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
375 |
+
gr.Markdown("# Webtoon Lineart Cropper with Filtering by Head-or-Face Detection")
|
376 |
+
gr.Markdown("Upload a PNG lineart image of your webtoon and automatically crop the character's face or head included region. "
|
377 |
+
"This demo leverages some detectors to precisely detect and isolate characters. "
|
378 |
+
"The app will display cropped objects, provide a ZIP of cropped PNGs, "
|
379 |
+
"and a PSD file with transparent lineart and half-transparent yellow-filled box layers. "
|
380 |
+
"We provide two detectors to choose from, each with different filtering methods. ")
|
381 |
+
gr.Markdown("- **Detector 1**: Uses [`imgutils.detect`](https://github.com/deepghs/imgutils/tree/main/imgutils/detect) and VLM-based filtering with [`google/gemma-3-12b-it`](https://huggingface.co/google/gemma-3-12b-it)")
|
382 |
+
gr.Markdown("- **Detector 2**: Uses [`imgutils.detect`](https://github.com/deepghs/imgutils/tree/main/imgutils/detect) and tag-based filtering with [`SmilingWolf/wd-eva02-large-tagger-v3`](https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3)")
|
383 |
+
gr.Markdown("**Note 1:** The app may take a few seconds to process the image, depending on the size and number of characters detected. The example image below is a lineart PNG file created synthetically from images on [Danbooru](https://danbooru.donmai.us/posts?page=1&tags=dragon_ball_z) after [lineart extraction](https://huggingface.co/spaces/carolineec/informativedrawings).")
|
384 |
+
gr.Markdown("**Note 2:** This demo is developed by [Kakao Entertainment](https://kakaoent.com/)'s AI Lab for research purposes, specifically designed to preprocess webtoon image data and is also not intended for production use. It is a research prototype and may not be suitable for all use cases. Please use it at your own risk.")
|
385 |
+
|
386 |
+
with gr.Row():
|
387 |
+
with gr.Column(scale=1):
|
388 |
+
# Input type remains 'filepath' to handle examples cleanly.
|
389 |
+
image_input = gr.Image(type="filepath", label="Upload Lineart PNG or Select Example", image_mode='RGBA', height=400)
|
390 |
+
|
391 |
+
detector_choice_radio = gr.Radio(
|
392 |
+
choices=["Detector 1", "Detector 2"],
|
393 |
+
label="Choose Detection Function",
|
394 |
+
value="Detector 1" # Default value
|
395 |
+
)
|
396 |
+
process_button = gr.Button("Process Uploaded/Modified Image", variant="primary")
|
397 |
+
status_output = gr.Textbox(label="Status", interactive=False, lines=8) # Increased lines slightly more
|
398 |
+
|
399 |
+
with gr.Column(scale=3):
|
400 |
+
gr.Markdown("### Cropped Objects")
|
401 |
+
# Setting height explicitly can sometimes help layout.
|
402 |
+
gallery_output = gr.Gallery(label="Detected Objects (Cropped)", elem_id="gallery_crops", columns=4, height=500, interactive=False, elem_classes="custom-gallery") # object_fit="contain")
|
403 |
+
with gr.Row():
|
404 |
+
zip_output = gr.File(label="Download Cropped Images (ZIP)")
|
405 |
+
psd_output = gr.File(label="Download PSD (Lineart + Boxes)")
|
406 |
+
|
407 |
+
# --- Add Examples ---
|
408 |
+
# IMPORTANT: Make sure 'sample_img.png' exists in the same directory
|
409 |
+
# as this script, or provide the correct relative/absolute path.
|
410 |
+
# Also ensure the image is a valid PNG.
|
411 |
+
example_image_path = "./sample_img/sample_danbooru_dragonball.png"
|
412 |
+
if os.path.exists(example_image_path):
|
413 |
+
gr.Examples(
|
414 |
+
examples=[
|
415 |
+
[example_image_path, "Detector 1"],
|
416 |
+
[example_image_path, "Detector 2"] # Add example for detector 2 as well
|
417 |
+
],
|
418 |
+
# Inputs that the examples populate
|
419 |
+
inputs=[image_input, detector_choice_radio],
|
420 |
+
# Outputs that are updated when an example is clicked AND run_on_click=True
|
421 |
+
outputs=[gallery_output, zip_output, psd_output, status_output],
|
422 |
+
# The function to call when an example is clicked
|
423 |
+
fn=process_lineart,
|
424 |
+
# Make clicking an example automatically run the function
|
425 |
+
run_on_click=True,
|
426 |
+
label="Click Example to Run Automatically", # Updated label
|
427 |
+
cache_examples=True, # Disable caching to ensure fresh processing
|
428 |
+
cache_mode="lazy",
|
429 |
+
)
|
430 |
+
else:
|
431 |
+
gr.Markdown(f"**(Note:** Could not find `{example_image_path}` for examples. Please create it or ensure it's in the correct directory.)")
|
432 |
+
|
433 |
+
|
434 |
+
# --- Button Click Handler (for manual uploads/changes) ---
|
435 |
+
process_button.click(
|
436 |
+
fn=process_lineart,
|
437 |
+
inputs=[image_input, detector_choice_radio],
|
438 |
+
outputs=[gallery_output, zip_output, psd_output, status_output]
|
439 |
+
)
|
440 |
+
|
441 |
+
# --- Launch the Gradio App ---
|
442 |
+
if __name__ == "__main__":
|
443 |
+
# Create a dummy sample image if it doesn't exist for testing
|
444 |
+
if not os.path.exists("./sample_img/sample_danbooru_dragonball.png"):
|
445 |
+
print("Creating a dummy 'sample_danbooru_dragonball.png' for demonstration.")
|
446 |
+
try:
|
447 |
+
img = Image.new('L', (300, 200), color=255) # White background (grayscale)
|
448 |
+
draw = ImageDraw.Draw(img)
|
449 |
+
# Draw some black lines/shapes
|
450 |
+
draw.line((30, 30, 270, 30), fill=0, width=2)
|
451 |
+
draw.rectangle((50, 50, 150, 150), outline=0, width=3)
|
452 |
+
draw.ellipse((180, 70, 250, 130), outline=0, width=3)
|
453 |
+
img.save("./sample_img/sample_danbooru_dragonball.png", "PNG")
|
454 |
+
print("Dummy 'sample_danbooru_dragonball.png' created.")
|
455 |
+
except Exception as e:
|
456 |
+
print(f"Warning: Failed to create dummy sample image: {e}")
|
457 |
+
|
458 |
+
demo.launch()
|
459 |
+
# ssr_mode=False
|
requirements.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
2 |
+
torch
|
3 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
4 |
+
torchvision
|
5 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
6 |
+
torchaudio
|
7 |
+
pydantic==2.10.6
|
8 |
+
timm
|
9 |
+
psd_tools==1.10.7
|
10 |
+
accelerate
|
11 |
+
diffusers
|
12 |
+
transformers
|
13 |
+
xformers
|
14 |
+
opencv-python
|
15 |
+
dghs-imgutils==0.15.0
|
16 |
+
pillow
|
17 |
+
numpy
|
18 |
+
scikit-learn
|
19 |
+
huggingface_hub
|
20 |
+
tqdm
|
21 |
+
opencv-contrib-python
|
22 |
+
pandas
|
23 |
+
scipy
|
sample_img/sample_danbooru_dragonball.png
ADDED
![]() |
Git LFS Details
|
scripts/__init__.py
ADDED
File without changes
|
scripts/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (204 Bytes). View file
|
|
scripts/__pycache__/parse_cut_from_page.cpython-312.pyc
ADDED
Binary file (14 kB). View file
|
|
scripts/convert_psd_to_png.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import multiprocessing
|
4 |
+
import os
|
5 |
+
from typing import Iterable
|
6 |
+
|
7 |
+
from psd_tools import PSDImage
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
+
|
13 |
+
|
14 |
+
def parse_args():
|
15 |
+
parser = argparse.ArgumentParser(description="Convert PSD files to PNG.")
|
16 |
+
parser.add_argument(
|
17 |
+
"-d",
|
18 |
+
"--directory",
|
19 |
+
type=str,
|
20 |
+
default="./",
|
21 |
+
help="Directory to search for PSD files.",
|
22 |
+
)
|
23 |
+
parser.add_argument("-o", "--output", type=str, default="./", help="Directory to save PNG files.")
|
24 |
+
parser.add_argument(
|
25 |
+
"--visible_layers",
|
26 |
+
default=[],
|
27 |
+
nargs="+",
|
28 |
+
type=str,
|
29 |
+
help="List of layer names to make visible.",
|
30 |
+
)
|
31 |
+
parser.add_argument(
|
32 |
+
"--invisible_layers",
|
33 |
+
default=[],
|
34 |
+
nargs="+",
|
35 |
+
type=str,
|
36 |
+
help="List of layer names to make invisible.",
|
37 |
+
)
|
38 |
+
parser.add_argument("--num_processes", "-n", default=None, type=int, help=" Number of processes to use.")
|
39 |
+
return parser.parse_args()
|
40 |
+
|
41 |
+
|
42 |
+
def find_psd_files(directory):
|
43 |
+
psd_files = []
|
44 |
+
for root, dirs, files in os.walk(directory):
|
45 |
+
for file in files:
|
46 |
+
if file.endswith(".psd"):
|
47 |
+
psd_files.append(os.path.join(root, file))
|
48 |
+
return psd_files
|
49 |
+
|
50 |
+
|
51 |
+
def set_layer_visibility(layer, visible_layers, invisible_layers):
|
52 |
+
if layer.name in visible_layers:
|
53 |
+
layer.visible = True
|
54 |
+
if layer.name in invisible_layers:
|
55 |
+
layer.visible = False
|
56 |
+
|
57 |
+
if isinstance(layer, Iterable):
|
58 |
+
for child in layer:
|
59 |
+
set_layer_visibility(child, visible_layers, invisible_layers)
|
60 |
+
|
61 |
+
|
62 |
+
def process_psd_file(task):
|
63 |
+
"""
|
64 |
+
Worker function that processes a single PSD file.
|
65 |
+
Opens the PSD, sets layer visibility, composites the image and saves it as PNG.
|
66 |
+
"""
|
67 |
+
psd_file, output, visible_layers, invisible_layers, force = task
|
68 |
+
try:
|
69 |
+
psd = PSDImage.open(psd_file)
|
70 |
+
if force:
|
71 |
+
for layer in psd:
|
72 |
+
set_layer_visibility(layer, visible_layers, invisible_layers)
|
73 |
+
image = psd.composite(force=force)
|
74 |
+
fname = os.path.basename(psd_file).replace(".psd", ".png")
|
75 |
+
output_file = os.path.join(output, fname)
|
76 |
+
image.save(output_file)
|
77 |
+
except Exception as e:
|
78 |
+
logger.error("Error processing file %s: %s", psd_file, e)
|
79 |
+
|
80 |
+
|
81 |
+
def main(args):
|
82 |
+
# Create output directory if it doesn't exist
|
83 |
+
if not os.path.exists(args.output):
|
84 |
+
os.makedirs(args.output)
|
85 |
+
|
86 |
+
psd_files = find_psd_files(args.directory)
|
87 |
+
# force=True when any layer visibility is provided
|
88 |
+
force = True if len(args.visible_layers) + len(args.invisible_layers) else False
|
89 |
+
|
90 |
+
tasks = [(psd_file, args.output, args.visible_layers, args.invisible_layers, force) for psd_file in psd_files]
|
91 |
+
|
92 |
+
num_processes = args.num_processes if args.num_processes else multiprocessing.cpu_count() // 2
|
93 |
+
# Use multiprocessing to process PSD files in parallel
|
94 |
+
with multiprocessing.Pool(processes=num_processes) as pool:
|
95 |
+
list(
|
96 |
+
tqdm(
|
97 |
+
pool.imap_unordered(process_psd_file, tasks),
|
98 |
+
total=len(tasks),
|
99 |
+
desc="Convert PSD to PNG files",
|
100 |
+
)
|
101 |
+
)
|
102 |
+
|
103 |
+
|
104 |
+
if __name__ == "__main__":
|
105 |
+
args = parse_args()
|
106 |
+
main(args)
|
scripts/parse_cut_from_page.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import multiprocessing
|
4 |
+
import re
|
5 |
+
import sys
|
6 |
+
from pathlib import Path, PurePath
|
7 |
+
from typing import List, Tuple
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
current_file_path = Path(__file__).resolve()
|
14 |
+
sys.path.insert(0, str(current_file_path.parent.parent))
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
logging.basicConfig(level=logging.INFO)
|
18 |
+
|
19 |
+
|
20 |
+
def parse_args():
|
21 |
+
parser = argparse.ArgumentParser(description="Parse cut from page script.")
|
22 |
+
|
23 |
+
parser.add_argument("--lineart", "-l", type=str, required=True, help="Directory of lineart images.")
|
24 |
+
parser.add_argument("--flat", "-f", type=str, required=True, help="Directory of flat images.")
|
25 |
+
parser.add_argument(
|
26 |
+
"--segmentation",
|
27 |
+
"-s",
|
28 |
+
type=str,
|
29 |
+
required=True,
|
30 |
+
help="Directory of segmentatio.",
|
31 |
+
)
|
32 |
+
parser.add_argument("--color", "-c", type=str, required=True, help="Directory of color images.")
|
33 |
+
parser.add_argument("--output", "-o", type=str, required=True, help="Output directory for parsed images.")
|
34 |
+
parser.add_argument("--num_process", "-n", type=int, default=None, help="Number of processes to use.")
|
35 |
+
|
36 |
+
return parser.parse_args()
|
37 |
+
|
38 |
+
|
39 |
+
def get_image_list(input_dir: str, ext: List[str]):
|
40 |
+
"""
|
41 |
+
Get a list of images from the input directory with the specified extensions.
|
42 |
+
Args:
|
43 |
+
input_dir (str): Directory containing images to filter.
|
44 |
+
ext (list): List of file extensions to filter by.
|
45 |
+
Returns:
|
46 |
+
list: List of image file paths.
|
47 |
+
"""
|
48 |
+
image_list = []
|
49 |
+
for ext in ext:
|
50 |
+
image_list.extend(Path(input_dir).glob(f"*.{ext}"))
|
51 |
+
return image_list
|
52 |
+
|
53 |
+
|
54 |
+
def check_image_pair_validity(
|
55 |
+
lineart_list: List[PurePath],
|
56 |
+
flat_list: List[PurePath],
|
57 |
+
segmentation_list: List[PurePath],
|
58 |
+
color_list: List[PurePath],
|
59 |
+
pattern: str = r"\d+_\d+",
|
60 |
+
) -> Tuple[List[PurePath], List[PurePath], List[PurePath], List[PurePath]]:
|
61 |
+
"""
|
62 |
+
Validates and filters lists of image file paths to ensure they correspond to the same IDs
|
63 |
+
based on a given naming pattern. If the lengths of the input lists are mismatched, the
|
64 |
+
function filters the lists to include only matching IDs.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
lineart_list (List[PurePath]): List of file paths for lineart images.
|
68 |
+
flat_path (List[PurePath]): List of file paths for flat images.
|
69 |
+
segmentation_path (List[PurePath]): List of file paths for segmentation images.
|
70 |
+
color_path (List[PurePath]): List of file paths for color images.
|
71 |
+
pattern (str, optional): Regular expression pattern to extract IDs from file names.
|
72 |
+
Defaults to r"\d+_\d+".
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
Tuple[List[PurePath], List[PurePath], List[PurePath], List[PurePath]]:
|
76 |
+
A tuple containing four lists of file paths (lineart, flat, segmentation, color)
|
77 |
+
that have been filtered to ensure matching IDs.
|
78 |
+
"""
|
79 |
+
pattern = re.compile(pattern)
|
80 |
+
|
81 |
+
# Sort the lists based on the pattern
|
82 |
+
lineart_list = sorted(lineart_list, key=lambda x: pattern.match(x.name).group(0))
|
83 |
+
flat_list = sorted(flat_list, key=lambda x: pattern.match(x.name).group(0))
|
84 |
+
segmentation_list = sorted(segmentation_list, key=lambda x: pattern.match(x.name).group(0))
|
85 |
+
color_list = sorted(color_list, key=lambda x: pattern.match(x.name).group(0))
|
86 |
+
|
87 |
+
# Check if the lengths of the lists are equal
|
88 |
+
if (
|
89 |
+
len(lineart_list) != len(flat_list)
|
90 |
+
or len(lineart_list) != len(segmentation_list)
|
91 |
+
or len(lineart_list) != len(color_list)
|
92 |
+
):
|
93 |
+
# If the lengths are not equal, we need to filter the lists based on the pattern
|
94 |
+
logger.warning(
|
95 |
+
f"Length mismatch: lineart({len(lineart_list)}), flat({len(flat_list)}), segmentation({len(segmentation_list)}), color({len(color_list)})"
|
96 |
+
)
|
97 |
+
new_lineart_list = []
|
98 |
+
new_flat_list = []
|
99 |
+
new_segmentation_list = []
|
100 |
+
new_color_list = []
|
101 |
+
for lineart_path in lineart_list:
|
102 |
+
lineart_name = lineart_path.name
|
103 |
+
lineart_match = pattern.match(lineart_name)
|
104 |
+
|
105 |
+
if lineart_match:
|
106 |
+
file_id = lineart_match.group(0)
|
107 |
+
corresponding_flat_files = [p for p in flat_list if file_id in p.name]
|
108 |
+
corresponding_segmentation_files = [p for p in segmentation_list if file_id in p.name]
|
109 |
+
corresponding_color_paths = [p for p in color_list if file_id in p.name]
|
110 |
+
|
111 |
+
if corresponding_flat_files and corresponding_segmentation_files and corresponding_color_paths:
|
112 |
+
new_lineart_list.append(lineart_path)
|
113 |
+
new_flat_list.append(corresponding_flat_files[0])
|
114 |
+
new_segmentation_list.append(corresponding_segmentation_files[0])
|
115 |
+
new_color_list.append(corresponding_color_paths[0])
|
116 |
+
|
117 |
+
return new_lineart_list, new_flat_list, new_segmentation_list, new_color_list
|
118 |
+
else:
|
119 |
+
return lineart_list, flat_list, segmentation_list, color_list
|
120 |
+
|
121 |
+
|
122 |
+
def extract_cutbox_coordinates(image: Image.Image) -> List[Tuple[int, int, int, int]]:
|
123 |
+
"""
|
124 |
+
Extracts bounding box coordinates for non-white regions in an image.
|
125 |
+
|
126 |
+
This function identifies regions in the given image that contain non-white pixels
|
127 |
+
and calculates the bounding box coordinates for each region. The bounding boxes
|
128 |
+
are represented as tuples of (left, top, right, bottom).
|
129 |
+
|
130 |
+
Args:
|
131 |
+
image (Image.Image): The input image as a PIL Image object.
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
List[Tuple[int, int, int, int]]: A list of bounding box coordinates for non-white regions.
|
135 |
+
Each tuple contains four integers representing the left, top, right, and bottom
|
136 |
+
coordinates of a bounding box.
|
137 |
+
"""
|
138 |
+
|
139 |
+
# We'll now detect the bounding box for non-white pixels instead of relying on the alpha channel.
|
140 |
+
# Convert the image to RGB and get the numpy array
|
141 |
+
image_rgb = image.convert("RGB")
|
142 |
+
image_np_rgb = np.array(image_rgb)
|
143 |
+
|
144 |
+
# Define white color threshold (treat pixels close to white as white)
|
145 |
+
threshold = 255 # Allow some margin for near-white
|
146 |
+
non_white_mask = np.any(image_np_rgb < threshold, axis=-1) # Any channel below threshold is considered non-white
|
147 |
+
non_white_mask = non_white_mask.astype(np.uint8) * 255
|
148 |
+
# Image.fromarray(non_white_mask).save("non_white_mask.png")
|
149 |
+
|
150 |
+
# Find rows containing non-white pixels
|
151 |
+
non_white_rows = np.where(non_white_mask.any(axis=1))[0]
|
152 |
+
if len(non_white_rows) == 0:
|
153 |
+
return []
|
154 |
+
|
155 |
+
# Group continuous non-white rows
|
156 |
+
horizontal_line = np.where(np.diff(non_white_rows) != 1)[0] + 1
|
157 |
+
non_white_rows = np.split(non_white_rows, horizontal_line)
|
158 |
+
top_bottom_pairs = [(group[0], group[-1]) for group in non_white_rows]
|
159 |
+
|
160 |
+
# Iterate through each cut and find the left and right bounds
|
161 |
+
bounding_boxes = []
|
162 |
+
for top, bottom in top_bottom_pairs:
|
163 |
+
cut = image_np_rgb[top : bottom + 1]
|
164 |
+
|
165 |
+
non_white_mask = np.any(cut < threshold, axis=-1)
|
166 |
+
non_white_cols = np.where(non_white_mask.any(axis=0))[0]
|
167 |
+
left = non_white_cols[0]
|
168 |
+
right = non_white_cols[-1]
|
169 |
+
|
170 |
+
bounding_boxes.append((left, top, right + 1, bottom + 1))
|
171 |
+
|
172 |
+
return bounding_boxes
|
173 |
+
|
174 |
+
|
175 |
+
def process_single_image(task: Tuple[Image.Image, Image.Image, Image.Image, Image.Image, str]):
|
176 |
+
"""
|
177 |
+
Worker function to process a single set of images.
|
178 |
+
Opens images, extracts bounding boxes, crops and saves the cut images.
|
179 |
+
"""
|
180 |
+
line_path, flat_path, seg_path, color_path, output_str = task
|
181 |
+
output_dir = Path(output_str)
|
182 |
+
try:
|
183 |
+
line_img = Image.open(line_path).convert("RGB")
|
184 |
+
flat_img = Image.open(flat_path).convert("RGB")
|
185 |
+
seg_img = Image.open(seg_path).convert("RGB")
|
186 |
+
color_img = Image.open(color_path).convert("RGB")
|
187 |
+
except Exception as e:
|
188 |
+
logger.error(f"Error opening images for {line_path}: {e}")
|
189 |
+
return
|
190 |
+
|
191 |
+
bounding_boxes = extract_cutbox_coordinates(line_img)
|
192 |
+
fname = line_path.stem
|
193 |
+
match = re.compile(r"\d+_\d+").match(fname)
|
194 |
+
if not match:
|
195 |
+
logger.warning(f"Filename pattern not matched for {line_path.name}")
|
196 |
+
return
|
197 |
+
ep_page_str = match.group(0)
|
198 |
+
|
199 |
+
for i, (left, top, right, bottom) in enumerate(bounding_boxes):
|
200 |
+
try:
|
201 |
+
cut_line = line_img.crop((left, top, right, bottom))
|
202 |
+
cut_flat = flat_img.crop((left, top, right, bottom))
|
203 |
+
cut_seg = seg_img.crop((left, top, right, bottom))
|
204 |
+
cut_color = color_img.crop((left, top, right, bottom))
|
205 |
+
|
206 |
+
cut_line.save(output_dir / "line" / f"{ep_page_str}_{i}_line.png")
|
207 |
+
cut_flat.save(output_dir / "flat" / f"{ep_page_str}_{i}_flat.png")
|
208 |
+
cut_seg.save(output_dir / "segmentation" / f"{ep_page_str}_{i}_segmentation.png")
|
209 |
+
cut_color.save(output_dir / "fullcolor" / f"{ep_page_str}_{i}_fullcolor.png")
|
210 |
+
except Exception as e:
|
211 |
+
logger.error(f"Error processing crop for {line_path.name} at box {i}: {e}")
|
212 |
+
|
213 |
+
|
214 |
+
def main(args):
|
215 |
+
# Prepare output directory
|
216 |
+
output_dir = Path(args.output)
|
217 |
+
if not output_dir.exists():
|
218 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
219 |
+
(output_dir / "line").mkdir(parents=True, exist_ok=True)
|
220 |
+
(output_dir / "flat").mkdir(parents=True, exist_ok=True)
|
221 |
+
(output_dir / "segmentation").mkdir(parents=True, exist_ok=True)
|
222 |
+
(output_dir / "fullcolor").mkdir(parents=True, exist_ok=True)
|
223 |
+
|
224 |
+
# Prepare input images
|
225 |
+
lineart_list = get_image_list(args.lineart, ["png", "jpg", "jpeg"])
|
226 |
+
flat_list = get_image_list(args.flat, ["png", "jpg", "jpeg"])
|
227 |
+
segmentation_list = get_image_list(args.segmentation, ["png", "jpg", "jpeg"])
|
228 |
+
color_list = get_image_list(args.color, ["png", "jpg", "jpeg"])
|
229 |
+
|
230 |
+
# Check image pair validity
|
231 |
+
lineart_list, flat_list, segmentation_list, color_list = check_image_pair_validity(
|
232 |
+
lineart_list, flat_list, segmentation_list, color_list
|
233 |
+
)
|
234 |
+
|
235 |
+
# Prepare tasks for multiprocessing
|
236 |
+
tasks = []
|
237 |
+
for l, f, s, c in zip(lineart_list, flat_list, segmentation_list, color_list):
|
238 |
+
tasks.append((l, f, s, c, str(output_dir)))
|
239 |
+
|
240 |
+
# Use multiprocessing to process images in parallel
|
241 |
+
num_processes = args.num_process if args.num_process else multiprocessing.cpu_count() // 2
|
242 |
+
with multiprocessing.Pool(processes=num_processes) as pool:
|
243 |
+
list(tqdm(pool.imap_unordered(process_single_image, tasks), total=len(tasks), desc="Processing images"))
|
244 |
+
|
245 |
+
|
246 |
+
if __name__ == "__main__":
|
247 |
+
args = parse_args()
|
248 |
+
main(args)
|
scripts/run_tag_filter.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import sys
|
6 |
+
from pathlib import Path, PurePath
|
7 |
+
from typing import List
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
current_file_path = Path(__file__).resolve()
|
12 |
+
sys.path.insert(0, str(current_file_path.parent.parent))
|
13 |
+
|
14 |
+
from src.detectors import AnimeDetector
|
15 |
+
from src.pipelines import TagAndFilteringPipeline
|
16 |
+
from src.taggers import WaifuDiffusionTagger
|
17 |
+
from src.utils.device import determine_accelerator
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
logging.basicConfig(level=logging.INFO)
|
21 |
+
|
22 |
+
|
23 |
+
def parse_args():
|
24 |
+
parser = argparse.ArgumentParser(description="Filtering script")
|
25 |
+
|
26 |
+
parser.add_argument(
|
27 |
+
"--input_dir",
|
28 |
+
"-i",
|
29 |
+
type=str,
|
30 |
+
required=True,
|
31 |
+
help="Directory containing images to filter",
|
32 |
+
)
|
33 |
+
parser.add_argument(
|
34 |
+
"--output_dir",
|
35 |
+
"-o",
|
36 |
+
type=str,
|
37 |
+
required=True,
|
38 |
+
help="Directory to save filtered images",
|
39 |
+
)
|
40 |
+
parser.add_argument(
|
41 |
+
"--ext",
|
42 |
+
"-e",
|
43 |
+
type=str,
|
44 |
+
nargs="+",
|
45 |
+
default=["png"],
|
46 |
+
help="File extension of images to filter (default: png)",
|
47 |
+
)
|
48 |
+
|
49 |
+
return parser.parse_args()
|
50 |
+
|
51 |
+
|
52 |
+
def get_image_list(input_dir: str, ext: List[str]):
|
53 |
+
"""
|
54 |
+
Get a list of images from the input directory with the specified extensions.
|
55 |
+
Args:
|
56 |
+
input_dir (str): Directory containing images to filter.
|
57 |
+
ext (list): List of file extensions to filter by.
|
58 |
+
Returns:
|
59 |
+
list: List of image file paths.
|
60 |
+
"""
|
61 |
+
image_list = []
|
62 |
+
for ext in ext:
|
63 |
+
image_list.extend(Path(input_dir).glob(f"*.{ext}"))
|
64 |
+
return image_list
|
65 |
+
|
66 |
+
|
67 |
+
def write_image_caption_file(image_list: List[PurePath], captions: List[str], output_dir: str):
|
68 |
+
"""
|
69 |
+
Writes a caption file
|
70 |
+
|
71 |
+
This function generates a text file named "captions.txt" in the specified output directory.
|
72 |
+
Each line in the file contains the image name (without extension) followed by its caption,
|
73 |
+
separated by a colon.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
image_list (List[PurePath]): A list of image file paths. Each path should be a PurePath object.
|
77 |
+
captions (List[str]): A list of captions corresponding to the images in `image_list`.
|
78 |
+
output_dir (str): The directory where the "captions.txt" file will be created.
|
79 |
+
|
80 |
+
Example:
|
81 |
+
image_list = [PurePath("image1.jpg"), PurePath("image2.jpg")]
|
82 |
+
captions = ["A beautiful sunset.", "A serene mountain view."]
|
83 |
+
output_dir = "/path/to/output"
|
84 |
+
write_image_caption_file(image_list, captions, output_dir)
|
85 |
+
"""
|
86 |
+
caption_file = Path(output_dir) / "captions.txt"
|
87 |
+
lines = []
|
88 |
+
for img_path, caption in zip(image_list, captions):
|
89 |
+
img_name = img_path.stem
|
90 |
+
line = f"{img_name}: {caption}\n"
|
91 |
+
lines.append(line)
|
92 |
+
|
93 |
+
with open(caption_file, "w") as f:
|
94 |
+
f.writelines(lines)
|
95 |
+
|
96 |
+
|
97 |
+
def main(args):
|
98 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
99 |
+
|
100 |
+
# 1. Initialize the filtering pipeline
|
101 |
+
device = determine_accelerator()
|
102 |
+
logger.info(f"Using device: {device}")
|
103 |
+
|
104 |
+
logger.info("Initializing filtering pipeline...")
|
105 |
+
detector = AnimeDetector(
|
106 |
+
repo_id="deepghs/anime_face_detection",
|
107 |
+
model_name="face_detect_v1.4_s",
|
108 |
+
hf_token=None,
|
109 |
+
)
|
110 |
+
|
111 |
+
tagger = WaifuDiffusionTagger(device=device)
|
112 |
+
|
113 |
+
filtering_pipeline = TagAndFilteringPipeline(tagger=tagger, detector=detector)
|
114 |
+
|
115 |
+
# 2. Load images from the input directory
|
116 |
+
logger.info(f"Loading images from {args.input_dir}...")
|
117 |
+
image_list = get_image_list(args.input_dir, args.ext)
|
118 |
+
images = [Image.open(img_path).convert("RGB") for img_path in image_list]
|
119 |
+
logger.info(f"Found {len(images)} images.")
|
120 |
+
|
121 |
+
# 3. Filter images using the filtering pipeline
|
122 |
+
logger.info(f"Filtering images...")
|
123 |
+
filter_output = filtering_pipeline(images, batch_size=32, tag_threshold=0.3, conf_threshold=0.3, iou_threshold=0.7)
|
124 |
+
|
125 |
+
filter_flags = filter_output.filter_flags
|
126 |
+
tags = filter_output.tags
|
127 |
+
captions = [",".join(tag) for tag in tags]
|
128 |
+
logger.info(f"Filtered {sum(filter_flags)} images out of {len(images)}.")
|
129 |
+
|
130 |
+
# 4. Save filtered images and captions
|
131 |
+
write_image_caption_file(image_list, captions, args.input_dir) # Write captions to input_dir
|
132 |
+
|
133 |
+
logger.info(f"Copying filtered images to {args.output_dir}...")
|
134 |
+
filtered_images = [img for img, flag in zip(image_list, filter_flags) if flag]
|
135 |
+
filtered_captions = [caption for caption, flag in zip(captions, filter_flags) if flag]
|
136 |
+
|
137 |
+
for img_path in filtered_images:
|
138 |
+
img_name = img_path.stem
|
139 |
+
output_path = Path(args.output_dir) / f"{img_name}.png"
|
140 |
+
shutil.copy(img_path, output_path)
|
141 |
+
|
142 |
+
write_image_caption_file(filtered_images, filtered_captions, args.output_dir)
|
143 |
+
|
144 |
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
args = parse_args()
|
147 |
+
main(args)
|
src/__init__.py
ADDED
File without changes
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (200 Bytes). View file
|
|
src/detectors/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .imgutils_detector import AnimeDetector
|
2 |
+
|
3 |
+
__all__ = ["AnimeDetector"]
|
src/detectors/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (296 Bytes). View file
|
|
src/detectors/__pycache__/imgutils_detector.cpython-312.pyc
ADDED
Binary file (7.3 kB). View file
|
|
src/detectors/imgutils_detector.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
from typing import Generic, Optional, TypeVar
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import imgutils
|
6 |
+
import numpy as np
|
7 |
+
from imgutils.generic.yolo import (
|
8 |
+
_image_preprocess,
|
9 |
+
_rtdetr_postprocess,
|
10 |
+
_yolo_postprocess,
|
11 |
+
rgb_encode,
|
12 |
+
)
|
13 |
+
from PIL import Image, ImageDraw
|
14 |
+
|
15 |
+
T = TypeVar("T", int, float)
|
16 |
+
|
17 |
+
REPO_IDS = {
|
18 |
+
"head": "deepghs/anime_head_detection",
|
19 |
+
"face": "deepghs/anime_face_detection",
|
20 |
+
"eye": "deepghs/anime_eye_detection",
|
21 |
+
}
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class DetectorOutput(Generic[T]):
|
26 |
+
bboxes: list[list[T]] = field(default_factory=list)
|
27 |
+
masks: list[Image.Image] = field(default_factory=list)
|
28 |
+
confidences: list[float] = field(default_factory=list)
|
29 |
+
previews: Optional[Image.Image] = None
|
30 |
+
|
31 |
+
|
32 |
+
class AnimeDetector:
|
33 |
+
"""
|
34 |
+
A class used to perform object detection on anime images.
|
35 |
+
Please refer to the `imgutils` documentation for more information on the available models.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, repo_id: str, model_name: str, hf_token: Optional[str] = None):
|
39 |
+
model_manager = imgutils.generic.yolo._open_models_for_repo_id(
|
40 |
+
repo_id, hf_token=hf_token
|
41 |
+
)
|
42 |
+
model, max_infer_size, labels = model_manager._open_model(model_name)
|
43 |
+
|
44 |
+
self.model = model
|
45 |
+
|
46 |
+
self.max_infer_size = max_infer_size
|
47 |
+
self.labels = labels
|
48 |
+
self.model_type = model_manager._get_model_type(model_name)
|
49 |
+
|
50 |
+
def __call__(
|
51 |
+
self,
|
52 |
+
image: Image.Image,
|
53 |
+
conf_threshold: float = 0.3,
|
54 |
+
iou_threshold: float = 0.7,
|
55 |
+
allow_dynamic: bool = False,
|
56 |
+
) -> DetectorOutput[float]:
|
57 |
+
"""
|
58 |
+
Perform object detection on the given image.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
image (Image.Image): The input image on which to perform detection.
|
62 |
+
conf_threshold (float, optional): Confidence threshold for detection. Defaults to 0.3.
|
63 |
+
iou_threshold (float, optional): Intersection over Union (IoU) threshold for detection. Defaults to 0.7.
|
64 |
+
allow_dynamic (bool, optional): Whether to allow dynamic resizing of the image. Defaults to False.
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
DetectorOutput[float]: The detection results, including bounding boxes, masks, confidences, and a preview image.
|
68 |
+
|
69 |
+
Raises:
|
70 |
+
ValueError: If the model type is unknown.
|
71 |
+
"""
|
72 |
+
# Preprocessing
|
73 |
+
new_image, old_size, new_size = _image_preprocess(
|
74 |
+
image, self.max_infer_size, allow_dynamic=allow_dynamic
|
75 |
+
)
|
76 |
+
data = rgb_encode(new_image)[None, ...]
|
77 |
+
|
78 |
+
# Start detection
|
79 |
+
(output,) = self.model.run(["output0"], {"images": data})
|
80 |
+
|
81 |
+
# Postprocessing
|
82 |
+
if self.model_type == "yolo":
|
83 |
+
output = _yolo_postprocess(
|
84 |
+
output=output[0],
|
85 |
+
conf_threshold=conf_threshold,
|
86 |
+
iou_threshold=iou_threshold,
|
87 |
+
old_size=old_size,
|
88 |
+
new_size=new_size,
|
89 |
+
labels=self.labels,
|
90 |
+
)
|
91 |
+
elif self.model_type == "rtdetr":
|
92 |
+
output = _rtdetr_postprocess(
|
93 |
+
output=output[0],
|
94 |
+
conf_threshold=conf_threshold,
|
95 |
+
iou_threshold=iou_threshold,
|
96 |
+
old_size=old_size,
|
97 |
+
new_size=new_size,
|
98 |
+
labels=self.labels,
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
raise ValueError(
|
102 |
+
f"Unknown object detection model type - {self.model_type!r}."
|
103 |
+
) # pragma: no cover
|
104 |
+
|
105 |
+
if len(output) == 0:
|
106 |
+
return DetectorOutput()
|
107 |
+
|
108 |
+
bboxes = [x[0] for x in output] # [x0, y0, x1, y1]
|
109 |
+
masks = create_mask_from_bbox(bboxes, image.size)
|
110 |
+
confidences = [x[2] for x in output]
|
111 |
+
|
112 |
+
# Create a preview image
|
113 |
+
previews = []
|
114 |
+
for mask in masks:
|
115 |
+
np_image = np.array(image)
|
116 |
+
np_mask = np.array(mask)
|
117 |
+
preview = cv2.bitwise_and(
|
118 |
+
np_image, cv2.cvtColor(np_mask, cv2.COLOR_GRAY2BGR)
|
119 |
+
)
|
120 |
+
preview = Image.fromarray(preview)
|
121 |
+
previews.append(preview)
|
122 |
+
|
123 |
+
return DetectorOutput(
|
124 |
+
bboxes=bboxes, masks=masks, confidences=confidences, previews=previews
|
125 |
+
)
|
126 |
+
|
127 |
+
|
128 |
+
def create_mask_from_bbox(
|
129 |
+
bboxes: list[list[float]], shape: tuple[int, int]
|
130 |
+
) -> list[Image.Image]:
|
131 |
+
"""
|
132 |
+
Creates a list of binary masks from bounding boxes.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
bboxes (list[list[float]]): A list of bounding boxes, where each bounding box is represented
|
136 |
+
by a list of four float values [x_min, y_min, x_max, y_max].
|
137 |
+
shape (tuple[int, int]): The shape of the mask (height, width).
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
list[Image.Image]: A list of PIL Image objects representing the binary masks.
|
141 |
+
"""
|
142 |
+
masks = []
|
143 |
+
for bbox in bboxes:
|
144 |
+
mask = Image.new("L", shape, 0)
|
145 |
+
mask_draw = ImageDraw.Draw(mask)
|
146 |
+
mask_draw.rectangle(bbox, fill=255)
|
147 |
+
masks.append(mask)
|
148 |
+
return masks
|
149 |
+
|
150 |
+
|
151 |
+
def create_bbox_from_mask(
|
152 |
+
masks: list[Image.Image], shape: tuple[int, int]
|
153 |
+
) -> list[list[int]]:
|
154 |
+
"""
|
155 |
+
Create bounding boxes from a list of mask images.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
masks (list[Image.Image]): A list of PIL Image objects representing the masks.
|
159 |
+
shape (tuple[int, int]): A tuple representing the desired shape (width, height) to resize the masks.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
list[list[int]]: A list of bounding boxes, where each bounding box is represented as a list of four integers [left, upper, right, lower].
|
163 |
+
"""
|
164 |
+
bboxes = []
|
165 |
+
for mask in masks:
|
166 |
+
mask = mask.resize(shape)
|
167 |
+
bbox = mask.getbbox()
|
168 |
+
if bbox is not None:
|
169 |
+
bboxes.append(list(bbox))
|
170 |
+
return bboxes
|
src/oskar_crop/__pycache__/detect_and_crop.cpython-312.pyc
ADDED
Binary file (3.37 kB). View file
|
|
src/oskar_crop/detect_and_crop.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import List, Tuple
|
4 |
+
import logging
|
5 |
+
from PIL import Image
|
6 |
+
from torchvision.transforms.v2 import ToPILImage
|
7 |
+
from scripts.parse_cut_from_page import extract_cutbox_coordinates
|
8 |
+
|
9 |
+
from src.detectors import AnimeDetector
|
10 |
+
from src.pipelines import TagAndFilteringPipeline
|
11 |
+
from src.taggers import WaifuDiffusionTagger
|
12 |
+
from src.utils.device import determine_accelerator
|
13 |
+
|
14 |
+
topil = ToPILImage()
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
logging.basicConfig(level=logging.INFO)
|
17 |
+
|
18 |
+
# 1. Initialize the filtering pipeline
|
19 |
+
device = determine_accelerator()
|
20 |
+
logger.info(f"Using device: {device}")
|
21 |
+
|
22 |
+
logger.info("Initializing filtering pipeline...")
|
23 |
+
detector = AnimeDetector(
|
24 |
+
repo_id="deepghs/anime_face_detection",
|
25 |
+
model_name="face_detect_v1.4_s",
|
26 |
+
hf_token=None,
|
27 |
+
)
|
28 |
+
|
29 |
+
tagger = WaifuDiffusionTagger(device=device)
|
30 |
+
|
31 |
+
filtering_pipeline = TagAndFilteringPipeline(tagger=tagger, detector=detector)
|
32 |
+
|
33 |
+
|
34 |
+
def process_single_image(lineart_pil_img: Image.Image) -> List[Tuple[int, Tuple[int, int, int, int]]]:
|
35 |
+
"""
|
36 |
+
Worker function to process a single set of images.
|
37 |
+
Opens images, extracts bounding boxes, crops and filters the cut images.
|
38 |
+
"""
|
39 |
+
try:
|
40 |
+
line_img = lineart_pil_img.convert("RGB")
|
41 |
+
except Exception as e:
|
42 |
+
logger.error(f"Error loading images for {lineart_pil_img}: {e}")
|
43 |
+
return
|
44 |
+
|
45 |
+
bounding_boxes = extract_cutbox_coordinates(line_img)
|
46 |
+
|
47 |
+
images = [topil(np.array(line_img)[top:bottom, left:right]) for (left, top, right, bottom) in bounding_boxes]
|
48 |
+
# 3. Filter images using the filtering pipeline
|
49 |
+
logger.info(f"Filtering images...")
|
50 |
+
filter_output = filtering_pipeline(images, batch_size=32, tag_threshold=0.3, conf_threshold=0.3, iou_threshold=0.7)
|
51 |
+
filter_flags = filter_output.filter_flags
|
52 |
+
logger.info(f"Filtered {sum(filter_flags)} images out of {len(images)}.")
|
53 |
+
filtered_bboxes = [bb for bb, flag in zip(bounding_boxes, filter_flags) if flag]
|
54 |
+
index_added_filtered_bboxes = [(i+1, (left, top, right - left, bottom - top)) for i, (left, top, right, bottom) in enumerate(filtered_bboxes)]
|
55 |
+
|
56 |
+
return index_added_filtered_bboxes
|
src/pipelines/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .pipeline_single_character_filtering import TagAndFilteringPipeline
|
2 |
+
|
3 |
+
__all__ = ["TagAndFilteringPipeline"]
|
src/pipelines/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (324 Bytes). View file
|
|
src/pipelines/__pycache__/pipeline_single_character_filtering.cpython-312.pyc
ADDED
Binary file (9.88 kB). View file
|
|
src/pipelines/pipeline_single_character_filtering.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
from tqdm import tqdm
|
7 |
+
|
8 |
+
from ..detectors.imgutils_detector import AnimeDetector
|
9 |
+
from ..taggers import WaifuDiffusionTagger, sort_tags
|
10 |
+
from ..utils.timer import ElapsedTimer
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class TagAndFilteringOutput:
|
17 |
+
filter_flags: List[bool]
|
18 |
+
tags: List[str]
|
19 |
+
|
20 |
+
|
21 |
+
class TagAndFilteringPipeline:
|
22 |
+
"""
|
23 |
+
TagAndFilteringPipeline is a pipeline for processing images by tagging them and filtering based on tags and face detection.
|
24 |
+
|
25 |
+
Attributes:
|
26 |
+
tagger (WaifuDiffusionTagger): An instance of the WaifuDiffusionTagger used for tagging images.
|
27 |
+
detector (AnimeDetector): An instance of the AnimeDetector used for detecting faces in images.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, tagger: WaifuDiffusionTagger, detector: AnimeDetector):
|
31 |
+
self.tagger = tagger
|
32 |
+
self.detector = detector
|
33 |
+
|
34 |
+
def __call__(self, images: List[Image.Image], *args, **kwargs) -> TagAndFilteringOutput:
|
35 |
+
"""
|
36 |
+
Processes a list of images by tagging and filtering them based on tags and face detection.
|
37 |
+
Args:
|
38 |
+
images (List[Image.Image]): A list of images to process.
|
39 |
+
batch_size (int, optional): The batch size for processing images. Default is 32.
|
40 |
+
tag_threshold (float, optional): The threshold for tag confidence. Default is 0.3.
|
41 |
+
sort_mode (str, optional): The mode for sorting tags. Default is "score".
|
42 |
+
include_tags (List[str], optional): Tags to include during filtering. Default is ["solo"].
|
43 |
+
exclude_tags (List[str], optional): Tags to exclude during filtering. Default is ["head_out_of_frame", "out_of_frame"].
|
44 |
+
conf_threshold (float, optional): Confidence threshold for face detection. Default is 0.3.
|
45 |
+
iou_threshold (float, optional): IOU threshold for face detection. Default is 0.7.
|
46 |
+
filter_by_tags (bool, optional): Whether to filter images based on tags. Default is True.
|
47 |
+
filter_by_faces (bool, optional): Whether to filter images based on face detection. Default is True.
|
48 |
+
Returns:
|
49 |
+
FilterOutput: An object containing filter flags and captions for the processed images.
|
50 |
+
"""
|
51 |
+
if not isinstance(images, list):
|
52 |
+
images = [images]
|
53 |
+
|
54 |
+
# Tagging parameters
|
55 |
+
batch_size = kwargs.pop("batch_size", 32)
|
56 |
+
tag_threshold = kwargs.pop("tag_threshold", 0.3)
|
57 |
+
tag_sort_mode = kwargs.pop("sort_mode", "score")
|
58 |
+
include_tags = kwargs.pop("include_tags", ["solo"])
|
59 |
+
exclude_tags = kwargs.pop("exclude_tags", ["head_out_of_frame", "out_of_frame", "chibi", "negative_space"])
|
60 |
+
# Face detection parameters
|
61 |
+
conf_threshold = kwargs.pop("conf_threshold", 0.3)
|
62 |
+
iou_threshold = kwargs.pop("iou_threshold", 0.7)
|
63 |
+
# Etc.
|
64 |
+
minimum_resolution = kwargs.pop("minimum_resolution", 512)
|
65 |
+
|
66 |
+
filter_flags = [True] * len(images)
|
67 |
+
|
68 |
+
if kwargs.pop("fiter_by_resolution", True):
|
69 |
+
with ElapsedTimer("Resolution-based Filtering", logger=logger):
|
70 |
+
for idx, image in enumerate(images):
|
71 |
+
if min(image.size) < minimum_resolution:
|
72 |
+
filter_flags[idx] = False
|
73 |
+
|
74 |
+
logger.info(f"Filtered {sum(filter_flags)} images out of {len(images)} based on resolution. ({minimum_resolution}px)")
|
75 |
+
|
76 |
+
with ElapsedTimer("Tagging", logger=logger):
|
77 |
+
tags = self.tagging(images, threshold=tag_threshold, sort_mode=tag_sort_mode, batch_size=batch_size)
|
78 |
+
|
79 |
+
if kwargs.pop("filter_by_tags", True):
|
80 |
+
with ElapsedTimer("Tag-based Filtering", logger=logger):
|
81 |
+
filter_flags = self.tag_based_filtering(
|
82 |
+
tags,
|
83 |
+
filter_flags,
|
84 |
+
include_tags=include_tags,
|
85 |
+
exclude_tags=exclude_tags,
|
86 |
+
)
|
87 |
+
|
88 |
+
if kwargs.pop("filter_by_faces", True):
|
89 |
+
with ElapsedTimer("Face-based Filtering", logger=logger):
|
90 |
+
filter_flags = self.face_based_filtering(
|
91 |
+
images,
|
92 |
+
filter_flags=filter_flags,
|
93 |
+
conf_threshold=conf_threshold,
|
94 |
+
iou_threshold=iou_threshold,
|
95 |
+
)
|
96 |
+
|
97 |
+
return TagAndFilteringOutput(filter_flags=filter_flags, tags=tags)
|
98 |
+
|
99 |
+
def tagging(
|
100 |
+
self,
|
101 |
+
images,
|
102 |
+
threshold: float = 0.3,
|
103 |
+
sort_mode: str = "score",
|
104 |
+
batch_size: int = 32,
|
105 |
+
) -> List[List[str]]:
|
106 |
+
"""
|
107 |
+
Tags a list of images and returns their captions.
|
108 |
+
Parameters:
|
109 |
+
images (List[Image.Image]): A list of images to tag.
|
110 |
+
threshold (float, optional): The threshold for tag confidence. Default is 0.3.
|
111 |
+
sort_mode (str, optional): The mode for sorting tags. Default is "score".
|
112 |
+
batch_size (int, optional): The batch size for tagging images. Default is 32.
|
113 |
+
Returns:
|
114 |
+
List[str]: A list of captions for the tagged images.
|
115 |
+
"""
|
116 |
+
tags = []
|
117 |
+
for i in tqdm(range(0, len(images), batch_size), desc="Tagging"):
|
118 |
+
batch = images[i : i + batch_size]
|
119 |
+
tagger_output = self.tagger(batch, threshold=threshold)
|
120 |
+
|
121 |
+
tags.extend([sort_tags(tags, mode=sort_mode) for tags in tagger_output])
|
122 |
+
|
123 |
+
return tags
|
124 |
+
|
125 |
+
def tag_based_filtering(
|
126 |
+
self,
|
127 |
+
tags: List[List[str]],
|
128 |
+
filter_flags: List[bool],
|
129 |
+
include_tags=["solo"],
|
130 |
+
exclude_tags=["head_out_of_frame", "out_of_frame"],
|
131 |
+
) -> List[bool]:
|
132 |
+
"""
|
133 |
+
Filters images based on their tags.
|
134 |
+
Parameters:
|
135 |
+
tags (List[List[str]]): A list of tags for the images.
|
136 |
+
filter_flags (List[bool]): A list of boolean flags indicating whether each image passes filtering.
|
137 |
+
include_tags (List[str], optional): Tags to include during filtering. Default is ["solo"].
|
138 |
+
exclude_tags (List[str], optional): Tags to exclude during filtering. Default is ["head_out_of_frame", "out_of_frame"].
|
139 |
+
Returns:
|
140 |
+
Tuple[List[bool], List[str]]: Updated filter flags and captions after tag-based filtering.
|
141 |
+
"""
|
142 |
+
for idx, tag in tqdm(enumerate(tags), desc="Tag-based Filtering", total=len(tags)):
|
143 |
+
if any(include_tag in tag for include_tag in include_tags) and all(
|
144 |
+
exclude_tag not in tag for exclude_tag in exclude_tags
|
145 |
+
):
|
146 |
+
filter_flags[idx] = True
|
147 |
+
else:
|
148 |
+
filter_flags[idx] = False
|
149 |
+
|
150 |
+
return filter_flags
|
151 |
+
|
152 |
+
def face_based_filtering(
|
153 |
+
self, images: List[Image.Image], filter_flags: List[bool], conf_threshold: float = 0.3, iou_threshold=0.7
|
154 |
+
) -> List[bool]:
|
155 |
+
"""
|
156 |
+
Filters images based on face detection.
|
157 |
+
Parameters:
|
158 |
+
images (List[Image.Image]): A list of images to filter.
|
159 |
+
filter_flags (List[bool]): A list of boolean flags indicating whether each image passes filtering.
|
160 |
+
conf_threshold (float, optional): Confidence threshold for face detection. Default is 0.3.
|
161 |
+
iou_threshold (float, optional): IOU threshold for face detection. Default is 0.7.
|
162 |
+
Returns:
|
163 |
+
List[bool]: Updated filter flags after face-based filtering.
|
164 |
+
"""
|
165 |
+
for idx, image in tqdm(enumerate(images), desc="Face-based Filtering", total=len(images)):
|
166 |
+
if not filter_flags[idx]:
|
167 |
+
continue
|
168 |
+
|
169 |
+
detector_output = self.detector(image, conf_threshold=conf_threshold, iou_threshold=iou_threshold)
|
170 |
+
if len(detector_output.bboxes) != 1:
|
171 |
+
filter_flags[idx] = False
|
172 |
+
else:
|
173 |
+
filter_flags[idx] = True
|
174 |
+
|
175 |
+
return filter_flags
|
src/taggers/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .order import sort_tags
|
2 |
+
from .tagger import WaifuDiffusionTagger
|
3 |
+
|
4 |
+
__all__ = ["WaifuDiffusionTagger", "sort_tags"]
|
src/taggers/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (339 Bytes). View file
|
|
src/taggers/__pycache__/order.cpython-312.pyc
ADDED
Binary file (3.78 kB). View file
|
|
src/taggers/__pycache__/tagger.cpython-312.pyc
ADDED
Binary file (12.2 kB). View file
|
|
src/taggers/filter.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Tuple
|
2 |
+
|
3 |
+
TAG_GROUP = {
|
4 |
+
"frame": [
|
5 |
+
"portrait",
|
6 |
+
"upper_body",
|
7 |
+
"lower_body",
|
8 |
+
"cowboy_shot",
|
9 |
+
"feet_out_of_frame",
|
10 |
+
"full_body",
|
11 |
+
"wide_shot",
|
12 |
+
"very_wide_shot",
|
13 |
+
],
|
14 |
+
"frame_2": [
|
15 |
+
"close-up",
|
16 |
+
"cut-in",
|
17 |
+
"cropped",
|
18 |
+
],
|
19 |
+
"view_angle": [
|
20 |
+
"dutch_angle",
|
21 |
+
"from_above",
|
22 |
+
"from_behind",
|
23 |
+
"from_below",
|
24 |
+
"from_side",
|
25 |
+
# "multiple_views",
|
26 |
+
"sideways",
|
27 |
+
"straight-on",
|
28 |
+
"three_quarter_view",
|
29 |
+
"upside-down",
|
30 |
+
],
|
31 |
+
"focus": ["eye_focus"],
|
32 |
+
"lip_action": [
|
33 |
+
"parted_lips",
|
34 |
+
"biting_own_lip",
|
35 |
+
"pursed_lips",
|
36 |
+
"spread_lips",
|
37 |
+
"open_mouth",
|
38 |
+
"closed_mouth",
|
39 |
+
],
|
40 |
+
"eye": ["closed_eyes", "one_eye_closed"],
|
41 |
+
"gaze": [
|
42 |
+
"eye_contact",
|
43 |
+
"looking_afar",
|
44 |
+
"looking_around",
|
45 |
+
"looking_at_another",
|
46 |
+
"looking_at_hand",
|
47 |
+
"looking_at_hands",
|
48 |
+
"looking_at_mirror",
|
49 |
+
"looking_at_self",
|
50 |
+
"looking_at_viewer",
|
51 |
+
"looking_away",
|
52 |
+
"looking_back",
|
53 |
+
"looking_down",
|
54 |
+
"looking_outside",
|
55 |
+
"looking_over_eyewear",
|
56 |
+
"looking_through_own_legs",
|
57 |
+
"looking_to_the_side",
|
58 |
+
"looking_up",
|
59 |
+
],
|
60 |
+
"emotion": [
|
61 |
+
"smile",
|
62 |
+
"angry",
|
63 |
+
"anger_vein",
|
64 |
+
"annoyed",
|
65 |
+
"clenched_teeth",
|
66 |
+
"scowl",
|
67 |
+
"blush",
|
68 |
+
"embarrassed",
|
69 |
+
"bored",
|
70 |
+
"confused",
|
71 |
+
"crazy",
|
72 |
+
"despair",
|
73 |
+
"disappointed",
|
74 |
+
"disgust",
|
75 |
+
"envy",
|
76 |
+
"excited",
|
77 |
+
"exhausted",
|
78 |
+
"expressioinless",
|
79 |
+
"furrowed_brow",
|
80 |
+
"happy",
|
81 |
+
"sad",
|
82 |
+
"depressed",
|
83 |
+
"frown",
|
84 |
+
"tears",
|
85 |
+
"scared",
|
86 |
+
"serious",
|
87 |
+
"sleepy",
|
88 |
+
"surprised",
|
89 |
+
"thinking",
|
90 |
+
"pain",
|
91 |
+
],
|
92 |
+
}
|
93 |
+
|
94 |
+
|
95 |
+
def parse_valid_tags(
|
96 |
+
input_tags: Dict[str, float], valid_tags=TAG_GROUP
|
97 |
+
) -> Dict[str, Tuple[str, float]]:
|
98 |
+
"""
|
99 |
+
Parses valid tags from the input tags based on predefined tag groups.
|
100 |
+
Args:
|
101 |
+
input_tags (Dict[str, float]): A dictionary of tags with their confidence scores, sorted by confidence in descending order.
|
102 |
+
valid_tags (dict, optional): A dictionary where keys are tag groups and values are lists of valid tags. Defaults to TAG_GROUP.
|
103 |
+
Returns:
|
104 |
+
dict: A dictionary where keys are tag groups and values are the first valid tag found in the input tags for each group.
|
105 |
+
"""
|
106 |
+
output_tags = {}
|
107 |
+
for tag_group, tags in valid_tags.items():
|
108 |
+
for tag in tags:
|
109 |
+
if tag.replace(" ", "_") in input_tags:
|
110 |
+
output_tags[tag_group] = (tag, input_tags[tag])
|
111 |
+
break # parse only one tag from each tag group and return the tag group and tag
|
112 |
+
|
113 |
+
return output_tags
|
src/taggers/order.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adopted by https://github.com/deepghs/imgutils/blob/main/imgutils/tagging/order.py
|
2 |
+
import random
|
3 |
+
import re
|
4 |
+
from typing import List, Literal, Mapping, Union
|
5 |
+
|
6 |
+
|
7 |
+
def sort_tags(
|
8 |
+
tags: Union[List[str], Mapping[str, float]], mode: Literal["original", "shuffle", "score"] = "score"
|
9 |
+
) -> List[str]:
|
10 |
+
"""
|
11 |
+
Sort the input list or mapping of tags by specified mode.
|
12 |
+
|
13 |
+
Tags can represent people counts (e.g., '1girl', '2boys'), and 'solo' tags.
|
14 |
+
|
15 |
+
:param tags: List or mapping of tags to be sorted.
|
16 |
+
:type tags: Union[List[str], Mapping[str, float]]
|
17 |
+
:param mode: The mode for sorting the tags. Options: 'original' (original order),
|
18 |
+
'shuffle' (random shuffle), 'score' (sorted by score if available).
|
19 |
+
:type mode: Literal['original', 'shuffle', 'score']
|
20 |
+
:return: Sorted list of tags based on the specified mode.
|
21 |
+
:rtype: List[str]
|
22 |
+
:raises ValueError: If an unknown sort mode is provided.
|
23 |
+
:raises TypeError: If the input tags are of unsupported type or if mode is 'score'
|
24 |
+
and the input is a list (as it does not have scores).
|
25 |
+
|
26 |
+
Examples:
|
27 |
+
Sorting tags in original order:
|
28 |
+
|
29 |
+
>>> from imgutils.tagging import sort_tags
|
30 |
+
>>>
|
31 |
+
>>> tags = ['1girls', 'solo', 'red_hair', 'cat ears']
|
32 |
+
>>> sort_tags(tags, mode='original')
|
33 |
+
['solo', '1girls', 'red_hair', 'cat ears']
|
34 |
+
>>>
|
35 |
+
>>> tags = {'1girls': 0.9, 'solo': 0.95, 'red_hair': 1.0, 'cat_ears': 0.92}
|
36 |
+
>>> sort_tags(tags, mode='original')
|
37 |
+
['solo', '1girls', 'red_hair', 'cat_ears']
|
38 |
+
|
39 |
+
Sorting tags by score (for a mapping of tags with scores):
|
40 |
+
|
41 |
+
>>> from imgutils.tagging import sort_tags
|
42 |
+
>>>
|
43 |
+
>>> tags = {'1girls': 0.9, 'solo': 0.95, 'red_hair': 1.0, 'cat_ears': 0.92}
|
44 |
+
>>> sort_tags(tags)
|
45 |
+
['solo', '1girls', 'red_hair', 'cat_ears']
|
46 |
+
|
47 |
+
Shuffling tags (output is not unique)
|
48 |
+
|
49 |
+
>>> from imgutils.tagging import sort_tags
|
50 |
+
>>>
|
51 |
+
>>> tags = ['1girls', 'solo', 'red_hair', 'cat ears']
|
52 |
+
>>> sort_tags(tags, mode='shuffle')
|
53 |
+
['solo', '1girls', 'red_hair', 'cat ears']
|
54 |
+
>>>
|
55 |
+
>>> tags = {'1girls': 0.9, 'solo': 0.95, 'red_hair': 1.0, 'cat_ears': 0.92}
|
56 |
+
>>> sort_tags(tags, mode='shuffle')
|
57 |
+
['solo', '1girls', 'cat_ears', 'red_hair']
|
58 |
+
"""
|
59 |
+
if mode not in {"original", "shuffle", "score"}:
|
60 |
+
raise ValueError(f"Unknown sort_mode, 'original', " f"'shuffle' or 'score' expected but {mode!r} found.")
|
61 |
+
npeople_tags = []
|
62 |
+
remaining_tags = []
|
63 |
+
|
64 |
+
if "solo" in tags:
|
65 |
+
npeople_tags.append("solo")
|
66 |
+
|
67 |
+
for tag in tags:
|
68 |
+
if tag == "solo":
|
69 |
+
continue
|
70 |
+
if re.fullmatch(r"^\d+\+?(boy|girl)s?$", tag): # 1girl, 1boy, 2girls, 3boys, 9+girls
|
71 |
+
npeople_tags.append(tag)
|
72 |
+
else:
|
73 |
+
remaining_tags.append(tag)
|
74 |
+
|
75 |
+
if mode == "score":
|
76 |
+
if isinstance(tags, dict):
|
77 |
+
remaining_tags = sorted(remaining_tags, key=lambda x: -tags[x])
|
78 |
+
else:
|
79 |
+
raise TypeError(f"Sort mode {mode!r} not supported for list, " f"for it do not have scores.")
|
80 |
+
elif mode == "shuffle":
|
81 |
+
random.shuffle(remaining_tags)
|
82 |
+
else:
|
83 |
+
pass
|
84 |
+
|
85 |
+
return npeople_tags + remaining_tags
|
src/taggers/tagger.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Dict, List, Literal, Optional, Union
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import timm
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
from huggingface_hub.utils import HfHubHTTPError
|
13 |
+
from PIL import Image
|
14 |
+
from timm.data import create_transform, resolve_data_config
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
MODEL_REPO_MAP = {
|
19 |
+
"vit": "SmilingWolf/wd-vit-tagger-v3",
|
20 |
+
"swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
|
21 |
+
"convnext": "SmilingWolf/wd-convnext-tagger-v3",
|
22 |
+
"eva-02": "SmilingWolf/wd-eva02-large-tagger-v3",
|
23 |
+
}
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class LabelData:
|
28 |
+
names: List[str]
|
29 |
+
rating: List[np.int64]
|
30 |
+
general: List[np.int64]
|
31 |
+
character: List[np.int64]
|
32 |
+
|
33 |
+
|
34 |
+
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
|
35 |
+
# convert to RGB/RGBA if not already (deals with palette images etc.)
|
36 |
+
if image.mode not in ["RGB", "RGBA"]:
|
37 |
+
image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
|
38 |
+
# convert RGBA to RGB with white background
|
39 |
+
if image.mode == "RGBA":
|
40 |
+
canvas = Image.new("RGBA", image.size, (255, 255, 255))
|
41 |
+
canvas.alpha_composite(image)
|
42 |
+
image = canvas.convert("RGB")
|
43 |
+
return image
|
44 |
+
|
45 |
+
|
46 |
+
def pil_pad_square(image: Image.Image) -> Image.Image:
|
47 |
+
w, h = image.size
|
48 |
+
# get the largest dimension so we can pad to a square
|
49 |
+
px = max(image.size)
|
50 |
+
# pad to square with white background
|
51 |
+
canvas = Image.new("RGB", (px, px), (255, 255, 255))
|
52 |
+
canvas.paste(image, ((px - w) // 2, (px - h) // 2))
|
53 |
+
return canvas
|
54 |
+
|
55 |
+
|
56 |
+
class WaifuDiffusionTagger:
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
model_name: Literal["vit", "swinv2", "convnext", "eva-02"] = "eva-02",
|
60 |
+
device: str = "cpu",
|
61 |
+
):
|
62 |
+
if model_name not in MODEL_REPO_MAP.keys():
|
63 |
+
raise ValueError(f"Model {model_name} not found. Available models: {MODEL_REPO_MAP.keys()}")
|
64 |
+
|
65 |
+
repo_id = MODEL_REPO_MAP[model_name]
|
66 |
+
|
67 |
+
self.init_model(repo_id, device)
|
68 |
+
self.transform = create_transform(**resolve_data_config(self.model.pretrained_cfg, model=self.model))
|
69 |
+
|
70 |
+
self.labels = self.load_labels_from_hf(repo_id)
|
71 |
+
|
72 |
+
def init_model(self, repo_id: str, device: str = "cpu"):
|
73 |
+
logger.info(f"Loading taggingmodel from {repo_id}")
|
74 |
+
self.model = timm.create_model("hf-hub:" + repo_id, pretrained=True)
|
75 |
+
|
76 |
+
state_dict = timm.models.load_state_dict_from_hf(repo_id)
|
77 |
+
self.model.load_state_dict(state_dict)
|
78 |
+
self.model.to(device).eval()
|
79 |
+
|
80 |
+
def load_labels_from_hf(self, repo_id: str, revision: Optional[str] = None, token: Optional[str] = None):
|
81 |
+
try:
|
82 |
+
csv_path = hf_hub_download(repo_id, filename="selected_tags.csv", revision=revision, token=token)
|
83 |
+
csv_path = Path(csv_path).resolve()
|
84 |
+
except HfHubHTTPError as e:
|
85 |
+
raise FileNotFoundError(f"Failed to download labels from {repo_id}") from e
|
86 |
+
|
87 |
+
df = pd.read_csv(csv_path, usecols=["name", "category"])
|
88 |
+
tag_data = LabelData(
|
89 |
+
names=df["name"],
|
90 |
+
rating=np.where(df["category"] == 9)[0],
|
91 |
+
general=np.where(df["category"] == 0)[0],
|
92 |
+
character=np.where(df["category"] == 4)[0],
|
93 |
+
)
|
94 |
+
return tag_data
|
95 |
+
|
96 |
+
def prepare_inputs(self, images: List[Image.Image]):
|
97 |
+
inputs = []
|
98 |
+
for image in images:
|
99 |
+
image = pil_ensure_rgb(image)
|
100 |
+
image = pil_pad_square(image)
|
101 |
+
inputs += [self.transform(image)]
|
102 |
+
|
103 |
+
inputs = torch.stack(inputs, dim=0)
|
104 |
+
inputs = inputs[:, [2, 1, 0]] # RGB to BGR
|
105 |
+
|
106 |
+
return inputs.to(self.device, dtype=self.dtype)
|
107 |
+
|
108 |
+
def get_tags(self, probs: torch.Tensor, gen_threshold: float) -> List[Dict[str, float]]:
|
109 |
+
"""
|
110 |
+
Generate tags based on prediction probabilities and a confidence threshold.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
probs (torch.Tensor): A tensor of shape [B, num_labels] containing
|
114 |
+
prediction probabilities for each label, where B is the batch size.
|
115 |
+
gen_threshold (float): The confidence threshold for selecting labels.
|
116 |
+
Only labels with probabilities greater than this threshold will be included.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
List[Dict[str, float]]: A list of dictionaries, where each dictionary
|
120 |
+
corresponds to a batch element and contains label names as keys and
|
121 |
+
their associated probabilities as values. The labels are sorted in
|
122 |
+
descending order of probability.
|
123 |
+
"""
|
124 |
+
# probs: [B, num_labels]
|
125 |
+
gen_labels = []
|
126 |
+
for prob in probs:
|
127 |
+
# Convert indices+probs to labels
|
128 |
+
prob = list(zip(self.labels.names, prob.cpu().numpy()))
|
129 |
+
|
130 |
+
# General labels, pick any where prediction confidence > threshold
|
131 |
+
gen_label = [prob[i] for i in self.labels.general]
|
132 |
+
gen_label = dict([x for x in gen_label if x[1] > gen_threshold])
|
133 |
+
gen_label = dict(sorted(gen_label.items(), key=lambda item: item[1], reverse=True))
|
134 |
+
|
135 |
+
gen_labels += [gen_label]
|
136 |
+
|
137 |
+
return gen_labels
|
138 |
+
|
139 |
+
@torch.inference_mode()
|
140 |
+
def __call__(
|
141 |
+
self,
|
142 |
+
images: Union[Image.Image, List[Image.Image]],
|
143 |
+
threshold: float = 0.3,
|
144 |
+
):
|
145 |
+
"""
|
146 |
+
Processes input images through the model and returns predicted labels based on a threshold.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
images (Union[Image.Image, List[Image.Image]]): A single image or a list of images to be processed.
|
150 |
+
threshold (float, optional): The threshold value for determining labels. Defaults to 0.3.
|
151 |
+
|
152 |
+
Returns:
|
153 |
+
List[List[str]]: A list of lists containing predicted labels for each input image.
|
154 |
+
"""
|
155 |
+
if not isinstance(images, list):
|
156 |
+
images = [images]
|
157 |
+
|
158 |
+
inputs = self.prepare_inputs(images)
|
159 |
+
|
160 |
+
outputs = self.model(inputs)
|
161 |
+
outputs = F.sigmoid(outputs)
|
162 |
+
|
163 |
+
labels = self.get_tags(outputs, threshold)
|
164 |
+
return labels
|
165 |
+
|
166 |
+
@torch.inference_mode()
|
167 |
+
def get_image_features(self, images: Union[Image.Image, List[Image.Image]], global_pool: bool = True):
|
168 |
+
"""
|
169 |
+
Extracts features from one or more images using the model.
|
170 |
+
|
171 |
+
Args:
|
172 |
+
images (Union[Image.Image, List[Image.Image]]): A single PIL Image or a list of PIL Images
|
173 |
+
from which features are to be extracted.
|
174 |
+
global_pool (bool, optional): If True, applies global pooling to the extracted features
|
175 |
+
by averaging across all spatial dimensions. If False, returns only the features
|
176 |
+
corresponding to the first token. Defaults to True.
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
torch.Tensor: A tensor containing the extracted features. If `global_pool` is True,
|
180 |
+
the features are averaged across spatial dimensions. Otherwise, the features
|
181 |
+
corresponding to the first token are returned.
|
182 |
+
"""
|
183 |
+
if not isinstance(images, list):
|
184 |
+
images = [images]
|
185 |
+
|
186 |
+
inputs = self.prepare_inputs(images)
|
187 |
+
|
188 |
+
features = self.model.forward_features(inputs)
|
189 |
+
|
190 |
+
if global_pool:
|
191 |
+
return features[:, self.model.num_prefix_tokens :].mean(dim=1)
|
192 |
+
else:
|
193 |
+
return features[:, 0]
|
194 |
+
|
195 |
+
@property
|
196 |
+
def device(self):
|
197 |
+
return next(self.model.parameters()).device
|
198 |
+
|
199 |
+
@property
|
200 |
+
def dtype(self):
|
201 |
+
return next(self.model.parameters()).dtype
|
202 |
+
|
203 |
+
|
204 |
+
def show_result_with_confidence(image, tag_result, ax):
|
205 |
+
ax.imshow(image)
|
206 |
+
|
207 |
+
confidence = [[x] for x in tag_result.values()]
|
208 |
+
rowLabels = list(tag_result[0].keys())
|
209 |
+
ax.table(
|
210 |
+
cellText=confidence,
|
211 |
+
loc="bottom",
|
212 |
+
rowLabels=rowLabels,
|
213 |
+
cellLoc="center",
|
214 |
+
colLabels=["Confidence"],
|
215 |
+
)
|
src/utils/__pycache__/device.cpython-312.pyc
ADDED
Binary file (659 Bytes). View file
|
|
src/utils/__pycache__/timer.cpython-312.pyc
ADDED
Binary file (2.46 kB). View file
|
|
src/utils/device.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def determine_accelerator():
|
5 |
+
"""
|
6 |
+
Determine the accelerator to be used based on the environment.
|
7 |
+
"""
|
8 |
+
|
9 |
+
# Check for CUDA availability
|
10 |
+
if torch.cuda.is_available():
|
11 |
+
return "cuda"
|
12 |
+
|
13 |
+
# Check for MPS (Metal Performance Shaders) availability on macOS
|
14 |
+
if torch.backends.mps.is_available():
|
15 |
+
return "mps"
|
16 |
+
|
17 |
+
# Default to CPU if no accelerators are available
|
18 |
+
return "cpu"
|
src/utils/timer.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import time
|
3 |
+
|
4 |
+
|
5 |
+
def msec2human(ms) -> str:
|
6 |
+
"""
|
7 |
+
Converts milliseconds to a human-readable string representation.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
ms (int): The input number of milliseconds.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
str: The formatted string representing the milliseconds in a human-readable format.
|
14 |
+
"""
|
15 |
+
s = ms // 1000 # Calculate the number of seconds
|
16 |
+
m = s // 60 # Calculate the number of minutes
|
17 |
+
h = m // 60 # Calculate the number of hours
|
18 |
+
|
19 |
+
m %= 60 # Get the remaining minutes after calculating hours
|
20 |
+
s %= 60 # Get the remaining seconds after calculating minutes
|
21 |
+
ms %= 1000 # Get the remaining milliseconds after calculating seconds
|
22 |
+
|
23 |
+
if h:
|
24 |
+
return (
|
25 |
+
f"{h} hour {m:2d} min" # Return the formatted string with hours and minutes
|
26 |
+
)
|
27 |
+
if m:
|
28 |
+
return f"{m} min {s:2d} sec" # Return the formatted string with minutes and seconds
|
29 |
+
if s:
|
30 |
+
return f"{s} sec {ms:3d} msec" # Return the formatted string with seconds and milliseconds
|
31 |
+
return f"{ms} msec" # Return the formatted string with milliseconds
|
32 |
+
|
33 |
+
|
34 |
+
class ElapsedTimer:
|
35 |
+
def __init__(self, name, logger=None, unit="ms"):
|
36 |
+
self.name = name
|
37 |
+
self.logger = logger or logging.getLogger(__name__)
|
38 |
+
|
39 |
+
def __enter__(self):
|
40 |
+
self.start_time = time.perf_counter()
|
41 |
+
self.logger.info(f"<{self.name}>: start")
|
42 |
+
return self
|
43 |
+
|
44 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
45 |
+
elapsed_time = time.perf_counter() - self.start_time
|
46 |
+
elapsed_time = msec2human(int(elapsed_time * 1000))
|
47 |
+
|
48 |
+
if exc_type:
|
49 |
+
self.logger.warning(f"<{self.name}> raised {exc_type}, {elapsed_time}")
|
50 |
+
else:
|
51 |
+
self.logger.info(f"<{self.name}>: {elapsed_time}")
|
src/wise_crop/__pycache__/detect_and_crop.cpython-312.pyc
ADDED
Binary file (10.5 kB). View file
|
|
src/wise_crop/detect_and_crop.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
from torchvision.transforms.v2 import ToPILImage
|
4 |
+
from PIL import Image
|
5 |
+
from transformers import pipeline
|
6 |
+
import torch
|
7 |
+
from imgutils.detect import detect_heads
|
8 |
+
from src.utils.device import determine_accelerator
|
9 |
+
topil = ToPILImage()
|
10 |
+
|
11 |
+
# 1. Initialize the filtering pipeline
|
12 |
+
device = determine_accelerator()
|
13 |
+
|
14 |
+
print("Loading AI Model...")
|
15 |
+
pipe = pipeline(
|
16 |
+
"image-text-to-text",
|
17 |
+
model="google/gemma-3-12b-it",
|
18 |
+
device=device,
|
19 |
+
torch_dtype=torch.bfloat16,
|
20 |
+
)
|
21 |
+
|
22 |
+
def crop_and_mask_characters_gradio(pil_img):
|
23 |
+
"""
|
24 |
+
Crops character regions from an image, saves them as separate files,
|
25 |
+
and generates binary masks for each cropped region using the Gemini 2.0 Flash Exp model.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
uploaded_file_obj (str): The path to the input image.
|
29 |
+
"""
|
30 |
+
img = np.array(pil_img)
|
31 |
+
|
32 |
+
# Convert the image to grayscale
|
33 |
+
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
34 |
+
|
35 |
+
# Apply thresholding to create a binary image
|
36 |
+
_, thresh = cv2.threshold(gray, 253, 255, cv2.THRESH_BINARY_INV)
|
37 |
+
|
38 |
+
# Find contours in the binary image
|
39 |
+
contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
40 |
+
|
41 |
+
# Create output directories if they don't exist
|
42 |
+
# cropped_dir = Path(temp_dir) / 'cropped_dir'
|
43 |
+
# masks_dir = cropped_dir
|
44 |
+
|
45 |
+
# os.makedirs(cropped_dir, exist_ok=True)
|
46 |
+
# os.makedirs(masks_dir, exist_ok=True)
|
47 |
+
coord_info_list = []
|
48 |
+
i = 0
|
49 |
+
# Iterate through the contours and crop the regions
|
50 |
+
for contour in contours:
|
51 |
+
# Get the bounding box of the contour
|
52 |
+
x, y, w, h = cv2.boundingRect(contour)
|
53 |
+
if w < 256 or h < 256: # Skip small contours
|
54 |
+
continue
|
55 |
+
|
56 |
+
# Crop the region
|
57 |
+
cropped_img = img[y:y+h, x:x+w]
|
58 |
+
|
59 |
+
messages = [
|
60 |
+
{
|
61 |
+
"role": "system",
|
62 |
+
"content": [{"type": "text", "text": "You are a helpful assistant."}]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"role": "user",
|
66 |
+
"content": [
|
67 |
+
{"type": "image", "image": topil(cropped_img)},
|
68 |
+
{"type": "text", "text": "You are given a black-and-white line drawing as input. Please analyze the image carefully. If the drawing contains the majority of a head or face—meaning most key facial features or the overall shape of the head are visible—respond with 'True'. Otherwise, respond with 'False'. Do not contain any punctuation or extra spaces in your answer. Just respond with 'True' or 'False'"}
|
69 |
+
]
|
70 |
+
}
|
71 |
+
]
|
72 |
+
result = detect_heads(topil(cropped_img))
|
73 |
+
if len(result) == 0:
|
74 |
+
continue
|
75 |
+
|
76 |
+
output = pipe(text=messages, max_new_tokens=200)
|
77 |
+
if output[0]["generated_text"][-1]["content"] == 'False':
|
78 |
+
# print(f"Skipping character {i+1} as it does not contain a head or face.")
|
79 |
+
continue
|
80 |
+
i += 1
|
81 |
+
# Append the coordinates to the list
|
82 |
+
coord_info_list.append((i,(x,y,w,h)))
|
83 |
+
return coord_info_list
|
84 |
+
|