# webtoon_cropper/scripts/parse_cut_from_page.py
# Author: wise-water | commit 13aa528 ("init commit") | 10.5 kB
import argparse
import logging
import multiprocessing
import re
import sys
from pathlib import Path, PurePath
from typing import List, Tuple
import numpy as np
from PIL import Image
from tqdm import tqdm
current_file_path = Path(__file__).resolve()
# Put the directory two levels above this script at the front of the import path.
sys.path[0:0] = [str(current_file_path.parent.parent)]

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """
    Parse the command-line arguments of the cut-parsing script.

    Args:
        argv (list of str, optional): Arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. the real command line.

    Returns:
        argparse.Namespace: ``lineart``, ``flat``, ``segmentation``, ``color`` and ``output``
        (directory paths as str) and ``num_process`` (int, or None when not given).
    """
    parser = argparse.ArgumentParser(description="Parse cut from page script.")
    parser.add_argument("--lineart", "-l", type=str, required=True, help="Directory of lineart images.")
    parser.add_argument("--flat", "-f", type=str, required=True, help="Directory of flat images.")
    parser.add_argument(
        "--segmentation",
        "-s",
        type=str,
        required=True,
        help="Directory of segmentation images.",
    )
    parser.add_argument("--color", "-c", type=str, required=True, help="Directory of color images.")
    parser.add_argument("--output", "-o", type=str, required=True, help="Output directory for parsed images.")
    parser.add_argument("--num_process", "-n", type=int, default=None, help="Number of processes to use.")
    return parser.parse_args(argv)
def get_image_list(input_dir: str, ext: List[str]) -> List[Path]:
    """
    Get a list of images from the input directory with the specified extensions.

    Only the top level of ``input_dir`` is searched (no recursion). Matching uses
    ``Path.glob``, so extension case sensitivity follows the platform default
    (typically case-sensitive on POSIX, case-insensitive on Windows).

    Args:
        input_dir (str): Directory containing images to filter.
        ext (list): File extensions without the leading dot, e.g. ``["png", "jpg"]``.
            A single extension string is also accepted.

    Returns:
        list: Paths of regular files whose names end in one of the extensions, grouped
        by extension in the order given. Each file appears at most once. The list is
        empty if nothing matches, including when ``input_dir`` does not exist.
    """
    if isinstance(ext, str):
        # A bare string would otherwise be iterated character by character ("p", "n", "g").
        ext = [ext]
    base_dir = Path(input_dir)
    # dict.fromkeys drops duplicates (e.g. ext=["png", "png"]) while keeping first-seen order.
    matches = dict.fromkeys(
        path for suffix in ext for path in base_dir.glob(f"*.{suffix}") if path.is_file()
    )
    return list(matches)
def check_image_pair_validity(
    lineart_list: List[PurePath],
    flat_list: List[PurePath],
    segmentation_list: List[PurePath],
    color_list: List[PurePath],
    pattern: str = r"\d+_\d+",
) -> Tuple[List[PurePath], List[PurePath], List[PurePath], List[PurePath]]:
    r"""
    Align the four image lists so that index ``i`` of every list refers to the same ID.

    A file's ID is the part of its name matched by ``pattern`` at the start of the name
    (e.g. ``"12_3"`` for ``"12_3_line.png"``). Only IDs present in all four lists are kept.
    The result is ordered by ID (string order), so the returned lists can be zipped safely.
    IDs are compared for exact equality, so ``"1_1"`` does not match ``"1_10_flat.png"``.

    Files whose names do not match ``pattern`` are skipped with a warning. If a list holds
    several files with the same ID (e.g. ``"1_1.png"`` and ``"1_1.jpg"``), the first one in
    name order is kept and the others are skipped with a warning.

    Args:
        lineart_list (List[PurePath]): List of file paths for lineart images.
        flat_list (List[PurePath]): List of file paths for flat images.
        segmentation_list (List[PurePath]): List of file paths for segmentation images.
        color_list (List[PurePath]): List of file paths for color images.
        pattern (str, optional): Regular expression matched at the start of each file name
            to extract its ID. Defaults to r"\d+_\d+".

    Returns:
        Tuple[List[PurePath], List[PurePath], List[PurePath], List[PurePath]]:
            Four equally long lists (lineart, flat, segmentation, color). Entries at the
            same index share the same ID.
    """
    # Loggers are cached by name, so this is the same object as the module-level logger.
    log = logging.getLogger(__name__)
    id_regex = re.compile(pattern)

    def index_by_id(paths, label):
        """Map each ID to its file, skipping unmatched names and duplicate IDs."""
        by_id = {}
        # Sorting by name makes the "first file wins" rule for duplicates deterministic.
        for path in sorted(paths, key=lambda p: p.name):
            found = id_regex.match(path.name)
            if found is None:
                log.warning(f"Skipping {label} file with unexpected name: {path.name}")
                continue
            file_id = found.group(0)
            if file_id in by_id:
                log.warning(
                    f"Duplicate {label} id {file_id}: keeping {by_id[file_id].name}, skipping {path.name}"
                )
                continue
            by_id[file_id] = path
        return by_id

    labeled = (
        ("lineart", lineart_list),
        ("flat", flat_list),
        ("segmentation", segmentation_list),
        ("color", color_list),
    )
    indexes = [index_by_id(paths, label) for label, paths in labeled]
    # Equal lengths do not imply equal IDs, so always intersect rather than trusting counts.
    common_ids = sorted(set(indexes[0]).intersection(*indexes[1:]))

    if any(len(index) != len(common_ids) for index in indexes):
        counts = ", ".join(f"{label}({len(paths)})" for label, paths in labeled)
        log.warning(f"ID mismatch: {counts}; keeping {len(common_ids)} common IDs")

    return tuple([index[file_id] for file_id in common_ids] for index in indexes)
def extract_cutbox_coordinates(image: Image.Image, threshold: int = 255) -> List[Tuple[int, int, int, int]]:
    """
    Extracts bounding box coordinates for non-white regions in an image.

    The page is split into horizontal bands. Each maximal run of consecutive rows that
    contains at least one non-white pixel becomes one cut. Within a band, the left and
    right bounds are the first and last columns that contain a non-white pixel. Cuts that
    sit side by side (separated only by a vertical white gap) are therefore returned as a
    single box.

    Args:
        image (Image.Image): The input image as a PIL Image object. It is converted to RGB,
            so any alpha channel is discarded.
        threshold (int, optional): A pixel counts as non-white when any RGB channel is
            strictly below this value. The default of 255 treats only pure white
            (255, 255, 255) as background. A lower value (e.g. 250) also treats near-white
            noise as background.

    Returns:
        List[Tuple[int, int, int, int]]: (left, top, right, bottom) boxes ordered top to
        bottom, with exclusive right/bottom edges so they can be passed directly to
        ``Image.crop``. The list is empty if the image has no non-white pixel.
    """
    pixels = np.asarray(image.convert("RGB"))
    non_white = np.any(pixels < threshold, axis=-1)

    non_white_rows = np.flatnonzero(non_white.any(axis=1))
    if non_white_rows.size == 0:
        return []

    # Consecutive non-white row indices differ by 1; any larger jump is a white gap between cuts.
    breaks = np.flatnonzero(np.diff(non_white_rows) != 1) + 1
    bands = np.split(non_white_rows, breaks)

    bounding_boxes = []
    for band in bands:
        top, bottom = int(band[0]), int(band[-1])
        # Reuse the page-level mask. Every row in a band has a non-white pixel, so cols is non-empty.
        cols = np.flatnonzero(non_white[top : bottom + 1].any(axis=0))
        left, right = int(cols[0]), int(cols[-1])
        bounding_boxes.append((left, top, right + 1, bottom + 1))
    return bounding_boxes
def process_single_image(task: Tuple[PurePath, PurePath, PurePath, PurePath, str]):
    """
    Worker function to process a single set of images.

    Detects the cuts on the lineart page and crops the same boxes out of the lineart, flat,
    segmentation and color pages. Each crop is saved as
    ``<output>/<kind>/<ep>_<page>_<cut index>_<kind>.png``, where ``<kind>`` is one of
    ``line``, ``flat``, ``segmentation`` or ``fullcolor``. Failures are logged instead of
    raised, so one bad page does not abort the whole multiprocessing pool.

    Args:
        task: ``(lineart_path, flat_path, segmentation_path, color_path, output_dir)``.
            The boxes come from the lineart page only, so the other three pages are
            expected to be pixel-aligned with it.
    """
    line_path, flat_path, seg_path, color_path, output_str = task
    output_dir = Path(output_str)

    # Check the file name first so a bad name does not cost four image decodes.
    match = re.match(r"\d+_\d+", line_path.stem)
    if not match:
        logger.warning(f"Filename pattern not matched for {line_path.name}")
        return
    ep_page_str = match.group(0)

    pages = []
    try:
        for path in (line_path, flat_path, seg_path, color_path):
            # `with` releases the file handle; convert() returns an independent in-memory image.
            with Image.open(path) as opened:
                pages.append(opened.convert("RGB"))
    except Exception as e:
        logger.error(f"Error opening images for {line_path}: {e}")
        return
    line_img, flat_img, seg_img, color_img = pages

    sizes = {page.size for page in pages}
    if len(sizes) != 1:
        # crop() fills areas outside an image with zeros (black), so misaligned pages would
        # produce silently wrong cuts. Surface the mismatch instead.
        logger.warning(f"Image size mismatch for {line_path.name}: {sorted(sizes)}")

    # Each output directory name doubles as the filename suffix.
    outputs = (
        (line_img, "line"),
        (flat_img, "flat"),
        (seg_img, "segmentation"),
        (color_img, "fullcolor"),
    )
    for i, box in enumerate(extract_cutbox_coordinates(line_img)):
        try:
            for img, kind in outputs:
                img.crop(box).save(output_dir / kind / f"{ep_page_str}_{i}_{kind}.png")
        except Exception as e:
            logger.error(f"Error processing crop for {line_path.name} at box {i}: {e}")
def main(args):
    """
    Crop every cut out of each matched lineart/flat/segmentation/color page set.

    Crops are written to ``<args.output>/{line,flat,segmentation,fullcolor}``; these
    directories are created if needed. Page sets are processed in parallel with
    ``multiprocessing.Pool``.

    Args:
        args (argparse.Namespace): Parsed arguments with ``lineart``, ``flat``,
            ``segmentation``, ``color``, ``output`` and ``num_process`` attributes.
    """
    output_dir = Path(args.output)
    # parents=True also creates output_dir itself; exist_ok=True makes reruns safe.
    for subdir in ("line", "flat", "segmentation", "fullcolor"):
        (output_dir / subdir).mkdir(parents=True, exist_ok=True)

    image_exts = ["png", "jpg", "jpeg"]
    lineart_list, flat_list, segmentation_list, color_list = check_image_pair_validity(
        get_image_list(args.lineart, image_exts),
        get_image_list(args.flat, image_exts),
        get_image_list(args.segmentation, image_exts),
        get_image_list(args.color, image_exts),
    )

    tasks = [
        (line, flat, seg, color, str(output_dir))
        for line, flat, seg, color in zip(lineart_list, flat_list, segmentation_list, color_list)
    ]
    if not tasks:
        logger.warning("No matching image sets found; nothing to process.")
        return

    # cpu_count() // 2 is 0 on a single-core machine, and Pool() rejects processes=0, so
    # clamp the default to at least 1. Starting more workers than tasks gains nothing.
    num_processes = args.num_process if args.num_process else max(1, multiprocessing.cpu_count() // 2)
    num_processes = min(num_processes, len(tasks))
    with multiprocessing.Pool(processes=num_processes) as pool:
        list(tqdm(pool.imap_unordered(process_single_image, tasks), total=len(tasks), desc="Processing images"))
if __name__ == "__main__":
    # Script entry point: parse the command line and run the cropping pipeline.
    main(parse_args())