Spaces:
Running
Running
import argparse | |
import logging | |
import multiprocessing | |
import re | |
import sys | |
from pathlib import Path, PurePath | |
from typing import List, Tuple | |
import numpy as np | |
from PIL import Image | |
from tqdm import tqdm | |
# Make the directory two levels above this script importable (prepended so it wins).
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))

# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args(argv=None) -> argparse.Namespace:
    """
    Parse the command-line options of the cut-extraction script.

    Args:
        argv (list of str, optional): Arguments to parse. Defaults to None, which makes
            argparse read ``sys.argv[1:]`` (the original behaviour).

    Returns:
        argparse.Namespace: with string attributes ``lineart``, ``flat``,
        ``segmentation``, ``color``, ``output`` and the int-or-None ``num_process``.
    """
    parser = argparse.ArgumentParser(description="Parse cut from page script.")
    parser.add_argument("--lineart", "-l", type=str, required=True, help="Directory of lineart images.")
    parser.add_argument("--flat", "-f", type=str, required=True, help="Directory of flat images.")
    parser.add_argument(
        "--segmentation",
        "-s",
        type=str,
        required=True,
        help="Directory of segmentation images.",
    )
    parser.add_argument("--color", "-c", type=str, required=True, help="Directory of color images.")
    parser.add_argument("--output", "-o", type=str, required=True, help="Output directory for parsed images.")
    # None (the default) lets main() choose a worker count.
    parser.add_argument("--num_process", "-n", type=int, default=None, help="Number of processes to use.")
    return parser.parse_args(argv)
def get_image_list(input_dir: str, ext: List[str]) -> List[Path]:
    """
    Get the image files directly inside ``input_dir`` whose extension is in ``ext``.

    Only the top level of the directory is scanned (no recursion). Extensions are
    compared case-insensitively and may be given with or without a leading dot, so
    ``["png"]`` matches both ``a.png`` and ``B.PNG``. Entries that are not regular
    files (e.g. a sub-directory called ``x.png``) are ignored.

    Args:
        input_dir (str): Directory containing images to filter.
        ext (list): File extensions to keep, e.g. ``["png", "jpg"]``.

    Returns:
        list: Paths of the matching files, sorted by file name so the result does not
        depend on filesystem enumeration order. Empty if ``input_dir`` does not exist
        or is not a directory.
    """
    directory = Path(input_dir)
    if not directory.is_dir():
        return []
    # A set also removes duplicates such as ["png", "PNG"].
    wanted = {e.lower().lstrip(".") for e in ext}
    return sorted(
        (p for p in directory.iterdir() if p.is_file() and p.suffix.lower().lstrip(".") in wanted),
        key=lambda p: p.name,
    )
def check_image_pair_validity(
    lineart_list: List[PurePath],
    flat_list: List[PurePath],
    segmentation_list: List[PurePath],
    color_list: List[PurePath],
    pattern: str = r"\d+_\d+",
) -> Tuple[List[PurePath], List[PurePath], List[PurePath], List[PurePath]]:
    r"""
    Align four lists of image paths by the ID at the start of each file name.

    A file's ID is the text matched by ``pattern`` at the start of its name
    (``re.match``), e.g. ``"12_3"`` for ``"12_3_line.png"`` with the default pattern.
    IDs are compared for exact equality, so ``"1_1"`` never pairs with ``"11_1"`` or
    ``"1_10"``. Only IDs present in all four lists are kept. Files whose name does not
    match the pattern are skipped with a warning. If one list holds several files with
    the same ID, the one with the smallest file name is used.

    Args:
        lineart_list (List[PurePath]): Paths of lineart images.
        flat_list (List[PurePath]): Paths of flat images.
        segmentation_list (List[PurePath]): Paths of segmentation images.
        color_list (List[PurePath]): Paths of color images.
        pattern (str, optional): Regular expression extracting the ID from a file name.
            Defaults to r"\d+_\d+".

    Returns:
        Tuple[List[PurePath], List[PurePath], List[PurePath], List[PurePath]]:
            The (lineart, flat, segmentation, color) lists, position-aligned so that
            index ``i`` of every list refers to the same ID, ordered by ID string.
    """
    regex = re.compile(pattern)
    named_lists = {
        "lineart": lineart_list,
        "flat": flat_list,
        "segmentation": segmentation_list,
        "color": color_list,
    }

    # ID -> path for every list. Exact dict keys replace the old substring search,
    # and a non-matching name is skipped instead of crashing the sort key.
    indexed = {}
    for label, paths in named_lists.items():
        mapping = {}
        for path in sorted(paths, key=lambda p: p.name):
            match = regex.match(path.name)
            if match is None:
                logger.warning(f"Skipping {label} file not matching '{regex.pattern}': {path.name}")
                continue
            file_id = match.group(0)
            if file_id in mapping:
                logger.warning(
                    f"Duplicate {label} ID {file_id}: keeping {mapping[file_id].name}, ignoring {path.name}"
                )
                continue
            mapping[file_id] = path
        indexed[label] = mapping

    common_ids = sorted(set.intersection(*(set(mapping) for mapping in indexed.values())))

    # Equal list lengths do not imply equal ID sets, so compare against the common IDs.
    if any(len(mapping) != len(common_ids) for mapping in indexed.values()):
        logger.warning(
            f"Image set mismatch: lineart({len(lineart_list)}), flat({len(flat_list)}), "
            f"segmentation({len(segmentation_list)}), color({len(color_list)}); "
            f"keeping {len(common_ids)} complete sets"
        )

    lineart, flat, segmentation, color = (
        [indexed[label][file_id] for file_id in common_ids] for label in named_lists
    )
    return lineart, flat, segmentation, color
def extract_cutbox_coordinates(image: Image.Image, threshold: int = 255) -> List[Tuple[int, int, int, int]]:
    """
    Extract bounding boxes of the horizontal bands of non-white content in an image.

    The image is scanned row by row. Each maximal run of consecutive rows containing at
    least one non-white pixel forms one band. Each band's box is then tightened
    horizontally to its leftmost/rightmost non-white column. Content placed side by
    side within the same rows therefore ends up in a single box.

    Args:
        image (Image.Image): Input PIL image. It is converted to RGB before analysis, so
            any alpha channel is ignored.
        threshold (int, optional): A pixel counts as non-white when any RGB channel is
            strictly below this value. The default of 255 treats only pure white
            (255, 255, 255) as background. Lower it to also ignore near-white noise.

    Returns:
        List[Tuple[int, int, int, int]]: ``(left, top, right, bottom)`` boxes as Python
        ints, ordered top to bottom, with ``right``/``bottom`` exclusive (one past the
        last non-white column/row). Empty if no pixel is non-white.
    """
    pixels = np.asarray(image.convert("RGB"))
    # (H, W) boolean mask, computed once and reused for every band.
    non_white = np.any(pixels < threshold, axis=-1)

    rows = np.flatnonzero(non_white.any(axis=1))
    if rows.size == 0:
        return []

    # Start a new band wherever two consecutive non-white row indices are not adjacent.
    breaks = np.flatnonzero(np.diff(rows) != 1) + 1

    bounding_boxes = []
    for band in np.split(rows, breaks):
        top, bottom = int(band[0]), int(band[-1])
        # Every band contains a non-white pixel by construction, so cols is non-empty.
        cols = np.flatnonzero(non_white[top : bottom + 1].any(axis=0))
        bounding_boxes.append((int(cols[0]), top, int(cols[-1]) + 1, bottom + 1))
    return bounding_boxes
def process_single_image(task: Tuple[PurePath, PurePath, PurePath, PurePath, str]):
    """
    Worker: crop one (lineart, flat, segmentation, color) image set into cuts and save them.

    Cut boxes are detected on the lineart image only, and the same boxes are applied to
    all four images. For box ``i`` the crops are written to
    ``<output>/<kind>/<ID>_<i>_<kind>.png``, where ``<kind>`` is one of ``line``,
    ``flat``, ``segmentation``, ``fullcolor`` and ``<ID>`` is the leading
    ``<digits>_<digits>`` part of the lineart file stem.

    Args:
        task: ``(line_path, flat_path, seg_path, color_path, output_dir)``. These are the
            four image paths plus the output directory as a string, packed into one
            tuple so the call fits ``Pool.imap_unordered``.

    Errors are logged, not raised, so one bad set does not abort the pool. An
    unreadable image or an unexpected file name skips the whole set, and a failing
    crop/save skips only that box.
    """
    line_path, flat_path, seg_path, color_path, output_str = task
    line_path = Path(line_path)
    output_dir = Path(output_str)

    # Check the name before decoding anything: without an ID there is nowhere to save.
    match = re.match(r"\d+_\d+", line_path.stem)
    if not match:
        logger.warning(f"Filename pattern not matched for {line_path.name}")
        return
    ep_page_str = match.group(0)

    images = []
    try:
        for path in (line_path, flat_path, seg_path, color_path):
            # convert() loads the pixels into a new image, so the file handle opened by
            # Image.open can be closed immediately by the context manager.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))
    except Exception as e:
        logger.error(f"Error opening images for {line_path}: {e}")
        return

    sizes = [img.size for img in images]
    if len(set(sizes)) > 1:
        logger.warning(f"Image size mismatch for {line_path.name}: {sizes}; crops may not line up")

    kinds = ("line", "flat", "segmentation", "fullcolor")
    bounding_boxes = extract_cutbox_coordinates(images[0])
    for i, box in enumerate(bounding_boxes):
        try:
            # Crop all four images before saving any of them.
            crops = [img.crop(box) for img in images]
            for kind, crop in zip(kinds, crops):
                crop.save(output_dir / kind / f"{ep_page_str}_{i}_{kind}.png")
        except Exception as e:
            logger.error(f"Error processing crop for {line_path.name} at box {i}: {e}")
def main(args):
    """
    Run the cut-extraction pipeline described by the parsed command-line ``args``.

    Steps:
    1. Create the ``line``/``flat``/``segmentation``/``fullcolor`` sub-directories under
       ``args.output``.
    2. Collect png/jpg/jpeg images from the four input directories and align them by ID.
    3. Crop every complete set in a process pool.

    Args:
        args (argparse.Namespace): Result of ``parse_args()``.
    """
    image_exts = ["png", "jpg", "jpeg"]

    # parents=True also creates output_dir itself when it does not exist yet.
    output_dir = Path(args.output)
    for sub_dir in ("line", "flat", "segmentation", "fullcolor"):
        (output_dir / sub_dir).mkdir(parents=True, exist_ok=True)

    lineart_list = get_image_list(args.lineart, image_exts)
    flat_list = get_image_list(args.flat, image_exts)
    segmentation_list = get_image_list(args.segmentation, image_exts)
    color_list = get_image_list(args.color, image_exts)

    lineart_list, flat_list, segmentation_list, color_list = check_image_pair_validity(
        lineart_list, flat_list, segmentation_list, color_list
    )

    tasks = [
        (line_path, flat_path, seg_path, color_path, str(output_dir))
        for line_path, flat_path, seg_path, color_path in zip(
            lineart_list, flat_list, segmentation_list, color_list
        )
    ]
    if not tasks:
        logger.warning("No complete image sets found; nothing to process.")
        return

    # Default to half the CPUs. cpu_count() // 2 is 0 on a single-core machine, and
    # Pool rejects 0 processes, so clamp to at least one worker. Also never start more
    # workers than there are tasks. A non-positive --num_process also ends up as 1.
    requested = args.num_process if args.num_process else multiprocessing.cpu_count() // 2
    num_processes = max(1, min(requested, len(tasks)))
    with multiprocessing.Pool(processes=num_processes) as pool:
        list(tqdm(pool.imap_unordered(process_single_image, tasks), total=len(tasks), desc="Processing images"))
if __name__ == "__main__":
    # Only runs when the file is executed as a script, not when it is imported.
    args = parse_args()
    main(args)