Spaces:
Paused
Paused
# Copyright 2025 Bytedance Ltd. and/or its affiliates. | |
# SPDX-License-Identifier: Apache-2.0 | |
import json | |
import os | |
import argparse | |
from collections import defaultdict | |
def calculate_wiscore(consistency, realism, aesthetic_quality): | |
return 0.7 * consistency + 0.2 * realism + 0.1 * aesthetic_quality | |
def cal_culture(file_path): | |
all_scores = [] | |
total_objects = 0 | |
has_9_9 = False | |
with open(file_path, 'r') as file: | |
for line in file: | |
total_objects += 1 | |
data = json.loads(line) | |
if 9.9 in [data['consistency'], data['realism'], data['aesthetic_quality']]: | |
has_9_9 = True | |
wiscore = calculate_wiscore(data['consistency'], data['realism'], data['aesthetic_quality']) | |
all_scores.append(wiscore) | |
if has_9_9 or total_objects < 400: | |
print(f"Skipping file {file_path}: Contains 9.9 or has less than 400 objects.") | |
return None | |
total_score = sum(all_scores) | |
avg_score = total_score / (len(all_scores)*2) if len(all_scores) > 0 else 0 | |
score = { | |
'total': total_score, | |
'average': avg_score | |
} | |
print(f" Cultural - Total: {score['total']:.2f}, Average: {score['average']:.2f}") | |
return avg_score | |
def cal_space_time(file_path): | |
categories = defaultdict(list) | |
total_objects = 0 | |
has_9_9 = False | |
with open(file_path, 'r') as file: | |
for line in file: | |
total_objects += 1 | |
data = json.loads(line) | |
if 9.9 in [data['consistency'], data['realism'], data['aesthetic_quality']]: | |
has_9_9 = True | |
subcategory = data['Subcategory'] | |
wiscore = calculate_wiscore(data['consistency'], data['realism'], data['aesthetic_quality']) | |
if subcategory in ['Longitudinal time', 'Horizontal time']: | |
categories['Time'].append(wiscore) | |
else: | |
categories['Space'].append(wiscore) | |
if has_9_9 or total_objects < 300: | |
print(f"Skipping file {file_path}: Contains 9.9 or has less than 400 objects.") | |
return None | |
total_scores = {category: sum(scores) for category, scores in categories.items()} | |
avg_scores = {category: sum(scores) / (len(scores) * 2 )if len(scores) > 0 else 0 for category, scores in categories.items()} | |
scores = { | |
'total': total_scores, | |
'average': avg_scores | |
} | |
print(f" Time - Total: {scores['total'].get('Time', 0):.2f}, Average: {scores['average'].get('Time', 0):.2f}") | |
print(f" Space - Total: {scores['total'].get('Space', 0):.2f}, Average: {scores['average'].get('Space', 0):.2f}") | |
return avg_scores | |
def cal_science(file_path): | |
categories = defaultdict(list) | |
total_objects = 0 | |
has_9_9 = False | |
with open(file_path, 'r') as file: | |
for line in file: | |
total_objects += 1 | |
data = json.loads(line) | |
if 9.9 in [data['consistency'], data['realism'], data['aesthetic_quality']]: | |
has_9_9 = True | |
prompt_id = data.get('prompt_id', 0) | |
if 701 <= prompt_id <= 800: | |
category = 'Biology' | |
elif 801 <= prompt_id <= 900: | |
category = 'Physics' | |
elif 901 <= prompt_id <= 1000: | |
category = 'Chemistry' | |
else: | |
category = "?" | |
wiscore = calculate_wiscore(data['consistency'], data['realism'], data['aesthetic_quality']) | |
categories[category].append(wiscore) | |
if has_9_9 or total_objects < 300: | |
print(f"Skipping file {file_path}: Contains 9.9 or has less than 300 objects.") | |
return None | |
total_scores = {category: sum(scores) for category, scores in categories.items()} | |
avg_scores = {category: sum(scores) / (len(scores)*2) if len(scores) > 0 else 0 for category, scores in categories.items()} | |
scores = { | |
'total': total_scores, | |
'average': avg_scores | |
} | |
for category in ['Biology', 'Physics', 'Chemistry']: | |
print(f" {category} - Total: {scores['total'].get(category, 0):.2f}, Average: {scores['average'].get(category, 0):.2f}") | |
return avg_scores | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser(description='Image Quality Assessment Tool') | |
parser.add_argument('--output_dir', required=True, | |
help='Path to the output directory') | |
args = parser.parse_args() | |
avg_score = dict() | |
score = cal_culture( | |
os.path.join(args.output_dir, "cultural_common_sense_scores.jsonl") | |
) | |
avg_score['Cultural'] = score | |
scores = cal_space_time( | |
os.path.join(args.output_dir, "spatio-temporal_reasoning_scores.jsonl") | |
) | |
avg_score.update(scores) | |
scores = cal_science( | |
os.path.join(args.output_dir, "natural_science_scores.jsonl") | |
) | |
avg_score.update(scores) | |
avg_all = sum(avg_score.values()) / len(avg_score) | |
avg_score['Overall'] = avg_all | |
keys = "" | |
values = "" | |
for k, v in avg_score.items(): | |
keys += f"{k} " | |
values += f"{v:.2f} " | |
print(keys) | |
print(values) | |
writer = open(os.path.join(args.output_dir, "results.txt"), 'w') | |
print(f"write results to file {os.path.join(args.output_dir, 'results.txt')}") | |
writer.write(keys + "\n") | |
writer.write(values + "\n") | |
writer.close() |