# Copyright (c) 2023 Dhruba Ghosh
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under MIT, with the full license text
# available at https://github.com/djghosh13/geneval/blob/main/LICENSE.
#
# This modified file is released under the same license.

"""Distributed image evaluation.

Each rank evaluates a contiguous shard of the numbered subfolders of
``imagedir``. For every image it runs an mmdet object detector and, where
the metadata asks for colors, a CLIP zero-shot color classifier. Rank 0
gathers all per-image results and writes them as JSON lines to
``--outfile``.
"""

import argparse
import functools
import json
import os
import re
import sys
import time
import warnings

# Installed before the heavy third-party imports below so that their
# import-time warnings are suppressed as well.
warnings.filterwarnings("ignore")

from tqdm import tqdm
import numpy as np
import pandas as pd
from PIL import Image, ImageOps
import torch
import torch.distributed as dist

import mmdet
from mmdet.apis import inference_detector, init_detector

import open_clip
from clip_benchmark.metrics import zeroshot_classification as zsc

# Replace clip_benchmark's internal progress bar with a pass-through so that
# per-object color classification does not spam the console.
zsc.tqdm = lambda it, *args, **kwargs: it


def setup_distributed():
    """Initialize the distributed environment.

    Creates an NCCL process group and binds this process to the GPU given by
    the ``LOCAL_RANK`` environment variable, which must be set by the
    launcher (a ``KeyError`` is raised otherwise).
    """
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


def parse_args():
    """Parse command-line arguments.

    ``--options`` takes ``KEY=VALUE`` pairs (split at the first ``=``) that
    are collected into the ``args.options`` dict. Keys read elsewhere in
    this script are ``model``, ``clip_model``, ``bgcolor``, ``crop``,
    ``threshold``, ``counting_threshold``, ``max_objects``, ``max_overlap``
    and ``position_threshold``.

    If ``--model-config`` is omitted, it defaults to the Mask2Former
    Swin-S config shipped alongside the installed ``mmdet`` package.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("imagedir", type=str)
    parser.add_argument("--outfile", type=str, default="results.jsonl")
    parser.add_argument("--model-config", type=str, default=None)
    parser.add_argument("--model-path", type=str, default="./")
    # Other arguments
    parser.add_argument("--options", nargs="*", type=str, default=[])
    args = parser.parse_args()

    options = {}
    for opt in args.options:
        key, sep, value = opt.partition("=")
        if not sep:
            # Fail with a clear CLI error instead of a cryptic unpacking error.
            parser.error(f"--options entries must have the form KEY=VALUE, got {opt!r}")
        options[key] = value
    args.options = options

    if args.model_config is None:
        args.model_config = os.path.join(
            os.path.dirname(mmdet.__file__),
            "../configs/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.py"
        )
    return args


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Explicit check rather than `assert`, which is stripped under `python -O`.
if DEVICE != "cuda":
    raise RuntimeError("This evaluation script requires a CUDA device.")


def timed(fn):
    """Decorator that prints the wall-clock run time of ``fn`` to stderr."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        startt = time.perf_counter()
        result = fn(*args, **kwargs)
        endt = time.perf_counter()
        print(f'Function {fn.__name__!r} executed in {endt - startt:.3f}s', file=sys.stderr)
        return result
    return wrapper


# Load models
@timed
def load_models(args):
    """Load the object detector, the CLIP model, and the detector class names.

    Returns:
        ``(object_detector, (clip_model, transform, tokenizer), classnames)``.
        ``classnames`` is read from ``object_names.txt`` next to this file,
        one name per line. ``evaluate_image`` indexes the detector's
        per-class outputs by position in this list, so its order must match
        the detector's classes.
    """
    CONFIG_PATH = args.model_config
    OBJECT_DETECTOR = args.options.get('model', "mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco")
    CKPT_PATH = os.path.join(args.model_path, f"{OBJECT_DETECTOR}.pth")
    object_detector = init_detector(CONFIG_PATH, CKPT_PATH, device=DEVICE)

    clip_arch = args.options.get('clip_model', "ViT-L-14")
    clip_model, _, transform = open_clip.create_model_and_transforms(clip_arch, pretrained="openai", device=DEVICE)
    tokenizer = open_clip.get_tokenizer(clip_arch)

    names_path = os.path.join(os.path.dirname(__file__), "object_names.txt")
    with open(names_path, encoding="utf-8") as cls_file:
        classnames = [line.strip() for line in cls_file]

    return object_detector, (clip_model, transform, tokenizer), classnames


COLORS = ["red", "orange", "yellow", "green", "blue", "purple", "pink", "brown", "black", "white"]
# Cache of zero-shot color classifiers, keyed by object class name.
COLOR_CLASSIFIERS = {}


# Evaluation parts

class ImageCrops(torch.utils.data.Dataset):
    """Dataset of per-object views of one image, used for CLIP color checks.

    Each item is the image with everything outside the object's mask
    replaced by a background, optionally cropped to the object's box, then
    passed through the CLIP ``transform``. Relies on the module-level
    ``args`` and ``transform`` assigned in the ``__main__`` section.

    ``objects`` is a sequence of ``(box, mask)`` pairs:
        * ``box[:4]`` is used as the crop rectangle.
        * ``mask`` may be ``None``, in which case no masking is applied.
    """

    def __init__(self, image: Image.Image, objects):
        self._image = image.convert("RGB")
        # Background painted outside the object mask. "original" keeps the
        # original pixels, which effectively disables masking.
        bgcolor = args.options.get('bgcolor', "#999")
        if bgcolor == "original":
            self._blank = self._image.copy()
        else:
            self._blank = Image.new("RGB", image.size, color=bgcolor)
        self._objects = objects

    def __len__(self):
        return len(self._objects)

    def __getitem__(self, index):
        box, mask = self._objects[index]
        if mask is not None:
            # PIL size is (width, height); mask shape is (rows, cols).
            assert tuple(self._image.size[::-1]) == tuple(mask.shape), (index, self._image.size[::-1], mask.shape)
            image = Image.composite(self._image, self._blank, Image.fromarray(mask))
        else:
            image = self._image
        if args.options.get('crop', '1') == '1':
            image = image.crop(box[:4])
        # Second element is a placeholder label; presumably needed because
        # zsc.run_classification consumes (image, target) batches.
        return (transform(image), 0)


def color_classification(image, bboxes, classname):
    """Predict a color from ``COLORS`` for each detected ``classname`` object.

    Builds a CLIP zero-shot classifier whose prompts mention ``classname``,
    caching it in ``COLOR_CLASSIFIERS``. It then classifies each object view
    produced by ``ImageCrops(image, bboxes)``.

    Returns:
        A list of color names, one per classified object view.
    """
    if classname not in COLOR_CLASSIFIERS:
        COLOR_CLASSIFIERS[classname] = zsc.zero_shot_classifier(
            clip_model, tokenizer, COLORS,
            [
                f"a photo of a {{c}} {classname}",
                f"a photo of a {{c}}-colored {classname}",
                f"a photo of a {{c}} object"
            ],
            DEVICE
        )
    clf = COLOR_CLASSIFIERS[classname]
    dataloader = torch.utils.data.DataLoader(
        ImageCrops(image, bboxes),
        batch_size=16, num_workers=4
    )
    with torch.no_grad():
        pred, _ = zsc.run_classification(clip_model, clf, dataloader, DEVICE)
    return [COLORS[index.item()] for index in pred.argmax(1)]


def compute_iou(box_a, box_b):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2, ...]`` boxes.

    Uses the inclusive-pixel convention (``+ 1`` on widths and heights).
    Returns 0 when the union area is 0.
    """
    area_fn = lambda box: max(box[2] - box[0] + 1, 0) * max(box[3] - box[1] + 1, 0)
    i_area = area_fn([
        max(box_a[0], box_b[0]), max(box_a[1], box_b[1]),
        min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    ])
    u_area = area_fn(box_a) + area_fn(box_b) - i_area
    return i_area / u_area if u_area else 0


def relative_position(obj_a, obj_b):
    """Give position of A relative to B, factoring in object dimensions.

    ``obj_a`` and ``obj_b`` are ``(box, mask)`` pairs; only ``box[:4]`` is used.

    The center offset is shrunk on each axis by ``POSITION_THRESHOLD`` times
    the summed box extents. Returns a set containing any of "left of",
    "right of", "above" and "below" (y grows downwards). The set is empty
    when the shrunk offset is negligible.
    """
    # Each box becomes two (x, y) corner points: shape (2 objects, 2 corners, 2 coords).
    boxes = np.array([obj_a[0], obj_b[0]])[:, :4].reshape(2, 2, 2)
    center_a, center_b = boxes.mean(axis=-2)
    dim_a, dim_b = np.abs(np.diff(boxes, axis=-2))[..., 0, :]
    offset = center_a - center_b
    # Require a clear separation relative to the objects' sizes.
    revised_offset = np.maximum(np.abs(offset) - POSITION_THRESHOLD * (dim_a + dim_b), 0) * np.sign(offset)
    if np.all(np.abs(revised_offset) < 1e-3):
        return set()
    # Direction components (normalized by the raw offset length).
    dx, dy = revised_offset / np.linalg.norm(offset)
    relations = set()
    if dx < -0.5:
        relations.add("left of")
    if dx > 0.5:
        relations.add("right of")
    if dy < -0.5:
        relations.add("above")
    if dy > 0.5:
        relations.add("below")
    return relations


def evaluate(image, objects, metadata):
    """
    Evaluate given image using detected objects on the global metadata specifications.

    Assumptions:
    * Metadata combines 'include' clauses with AND, and 'exclude' clauses with OR
    * All clauses are independent, i.e., duplicating a clause has no effect on the correctness
    * CHANGED: Color and position will only be evaluated on the most confidently predicted objects;
        therefore, objects are expected to appear in sorted order

    Args:
        image: the PIL image being evaluated (used for color checks).
        objects: dict mapping class name -> list of ``(box, mask)`` pairs.
        metadata: dict with optional 'include' and 'exclude' lists.
            * Each 'include' entry has 'class' and 'count', plus optional
              'color' and 'position'. 'position' is ``(relation, index)``,
              where index refers to an earlier 'include' entry.
            * Each 'exclude' entry has 'class' and 'count'.

    Returns:
        ``(correct, reason)``, where ``reason`` joins failure messages with newlines.
    """
    correct = True
    reason = []
    # One entry per 'include' clause: its matched objects, or None if it failed.
    matched_groups = []
    # Check for expected objects
    for req in metadata.get('include', []):
        classname = req['class']
        matched = True
        found_objects = objects.get(classname, [])[:req['count']]
        if len(found_objects) < req['count']:
            correct = matched = False
            reason.append(f"expected {classname}>={req['count']}, found {len(found_objects)}")
        else:
            if 'color' in req:
                # Color check
                colors = color_classification(image, found_objects, classname)
                if colors.count(req['color']) < req['count']:
                    correct = matched = False
                    reason.append(
                        f"expected {req['color']} {classname}>={req['count']}, found "
                        + f"{colors.count(req['color'])} {req['color']}; and "
                        + ", ".join(f"{colors.count(c)} {c}" for c in COLORS if c in colors)
                    )
            if 'position' in req and matched:
                # Relative position check
                expected_rel, target_group = req['position']
                if matched_groups[target_group] is None:
                    correct = matched = False
                    reason.append(f"no target for {classname} to be {expected_rel}")
                else:
                    for obj in found_objects:
                        for target_obj in matched_groups[target_group]:
                            true_rels = relative_position(obj, target_obj)
                            if expected_rel not in true_rels:
                                correct = matched = False
                                reason.append(
                                    f"expected {classname} {expected_rel} target, found "
                                    + f"{' and '.join(true_rels)} target"
                                )
                                break
                        if not matched:
                            break
        if matched:
            matched_groups.append(found_objects)
        else:
            matched_groups.append(None)
    # Check for non-expected objects
    for req in metadata.get('exclude', []):
        classname = req['class']
        # .get() here too: indexing `objects[classname]` would raise KeyError
        # for an undetected class whenever req['count'] <= 0.
        found_count = len(objects.get(classname, []))
        if found_count >= req['count']:
            correct = False
            reason.append(f"expected {classname}<{req['count']}, found {found_count}")
    return correct, "\n".join(reason)


def evaluate_image(filepath, metadata):
    """Detect objects in the image at ``filepath`` and evaluate it against ``metadata``.

    Detections per class are:
        1. kept only above the confidence threshold (``COUNTING_THRESHOLD``
           for "counting" tags, ``THRESHOLD`` otherwise);
        2. sorted by score and capped at ``MAX_OBJECTS``;
        3. greedily de-duplicated by IoU against ``NMS_THRESHOLD``
           (a value of 1 disables this step).

    Returns:
        A dict with keys 'filename', 'tag', 'prompt', 'correct', 'reason',
        'metadata' (JSON string) and 'details' (JSON string mapping class
        name -> list of kept boxes).
    """
    result = inference_detector(object_detector, filepath)
    bbox = result[0] if isinstance(result, tuple) else result
    segm = result[1] if isinstance(result, tuple) and len(result) > 1 else None
    image = ImageOps.exif_transpose(Image.open(filepath))
    detected = {}
    # Determine bounding boxes to keep
    confidence_threshold = THRESHOLD if metadata['tag'] != "counting" else COUNTING_THRESHOLD
    for index, classname in enumerate(classnames):
        class_boxes = bbox[index]  # rows are boxes; column 4 is the score
        ordering = np.argsort(class_boxes[:, 4])[::-1]
        ordering = ordering[class_boxes[ordering, 4] > confidence_threshold]  # Threshold
        ordering = ordering[:MAX_OBJECTS].tolist()  # Limit number of detected objects per class
        kept = []
        while ordering:
            max_obj = ordering.pop(0)
            kept.append((class_boxes[max_obj], None if segm is None else segm[index][max_obj]))
            ordering = [
                obj for obj in ordering
                if NMS_THRESHOLD == 1 or compute_iou(class_boxes[max_obj], class_boxes[obj]) < NMS_THRESHOLD
            ]
        if kept:
            detected[classname] = kept
    # Evaluate
    is_correct, reason = evaluate(image, detected, metadata)
    return {
        'filename': filepath,
        'tag': metadata['tag'],
        'prompt': metadata['prompt'],
        'correct': is_correct,
        'reason': reason,
        'metadata': json.dumps(metadata),
        'details': json.dumps({
            key: [box.tolist() for box, _ in value]
            for key, value in detected.items()
        })
    }


if __name__ == "__main__":
    # These module-level names are read as globals by the functions above.
    args = parse_args()
    THRESHOLD = float(args.options.get('threshold', 0.3))
    COUNTING_THRESHOLD = float(args.options.get('counting_threshold', 0.9))
    MAX_OBJECTS = int(args.options.get('max_objects', 16))
    NMS_THRESHOLD = float(args.options.get('max_overlap', 1.0))
    POSITION_THRESHOLD = float(args.options.get('position_threshold', 0.1))

    # Initialize distributed environment
    setup_distributed()
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Load models (every rank loads its own copy; only rank 0 announces it)
    if rank == 0:
        print("[Rank 0] Loading model...")
    object_detector, (clip_model, transform, tokenizer), classnames = load_models(args)

    full_results = []
    # Sorted so that every rank sees the same list: os.listdir() order is
    # arbitrary, and the contiguous sharding below requires all ranks to agree.
    subfolders = sorted(
        f for f in os.listdir(args.imagedir)
        if os.path.isdir(os.path.join(args.imagedir, f)) and f.isdigit()
    )
    total_subfolders = len(subfolders)

    # Divide subfolders to process by GPU (contiguous, ceil-sized shards).
    # Clamp `start` so ranks beyond the data get an empty shard, not a negative count.
    subfolders_per_gpu = (total_subfolders + world_size - 1) // world_size
    start = min(rank * subfolders_per_gpu, total_subfolders)
    end = min(start + subfolders_per_gpu, total_subfolders)
    print(f"GPU {rank}: Processing {end - start} subfolders (index {start} to {end - 1})")

    # fullmatch (not match) so names like "0.png.bak" are rejected.
    image_name_re = re.compile(r"\d+\.png")

    for subfolder in tqdm(subfolders[start:end]):
        folderpath = os.path.join(args.imagedir, subfolder)
        with open(os.path.join(folderpath, "metadata.jsonl"), encoding="utf-8") as fp:
            metadata = json.load(fp)

        # Evaluate each image
        samples_dir = os.path.join(folderpath, "samples")
        for imagename in sorted(os.listdir(samples_dir)):
            imagepath = os.path.join(samples_dir, imagename)
            if not os.path.isfile(imagepath) or not image_name_re.fullmatch(imagename):
                continue
            full_results.append(evaluate_image(imagepath, metadata))

    # Synchronize results from all GPUs
    all_results = [None] * world_size
    dist.all_gather_object(all_results, full_results)

    if rank == 0:
        # Merge results from all GPUs (in rank order)
        final_results = [result for results in all_results for result in results]

        # Save results
        outdir = os.path.dirname(args.outfile)
        if outdir:
            os.makedirs(outdir, exist_ok=True)
        with open(args.outfile, "w") as fp:
            pd.DataFrame(final_results).to_json(fp, orient="records", lines=True)
        print("All GPUs have completed their tasks and the final results have been saved.")
    else:
        print(f"GPU {rank} has completed all tasks")

    dist.destroy_process_group()