import ast
import argparse
import logging
import numpy as np
import pandas as pd
import json
from datetime import datetime
import os
from PIL import Image
from ml import MLModel
from dl import DLModel
from naive import NaiveModel
import cairosvg
import io
from typing import Dict, Any, List, Tuple
from tqdm import tqdm
from metric import harmonic_mean, VQAEvaluator, AestheticEvaluator
import gc
import torch

# Setup logging
os.makedirs("logs", exist_ok=True)
log_file = f"logs/eval_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Custom JSON encoder to handle NumPy types
class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

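# Minimal usage sketch (illustrative only): NumPy scalars and arrays that would
# otherwise raise a TypeError serialize cleanly when the encoder is passed via `cls`.
#
#   json.dumps({"vqa_score": np.float32(0.5), "ids": np.arange(3)}, cls=NumpyEncoder)
#   # -> '{"vqa_score": 0.5, "ids": [0, 1, 2]}'
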
def svg_to_png(svg_code: str, size: tuple = (384, 384)) -> Image.Image:
    """Converts SVG code to a PNG image.

    Args:
        svg_code (str): SVG code to convert
        size (tuple, optional): Output image size. Defaults to (384, 384).

    Returns:
        PIL.Image.Image: The converted PNG image
    """
    try:
        png_data = cairosvg.svg2png(
            bytestring=svg_code.encode('utf-8'),
            output_width=size[0],
            output_height=size[1],
        )
        return Image.open(io.BytesIO(png_data))
    except Exception as e:
        logger.error(f"Error converting SVG to PNG: {e}")
        # Return a default red circle if conversion fails
        default_svg = """<svg width="384" height="384" viewBox="0 0 256 256"><circle cx="128" cy="128" r="64" fill="red" /></svg>"""
        png_data = cairosvg.svg2png(
            bytestring=default_svg.encode('utf-8'),
            output_width=size[0],
            output_height=size[1],
        )
        return Image.open(io.BytesIO(png_data))

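# Minimal usage sketch (illustrative only): the output is always resized to `size`,
# regardless of the SVG's own width/height attributes.
#
#   img = svg_to_png('<svg viewBox="0 0 10 10"><rect width="10" height="10" fill="blue"/></svg>')
#   img.size  # -> (384, 384)
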
def load_evaluation_data(eval_csv_path: str, descriptions_csv_path: str, index: int = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load evaluation data from CSV files.

    Args:
        eval_csv_path (str): Path to the evaluation CSV
        descriptions_csv_path (str): Path to the descriptions CSV
        index (int, optional): Specific index to load. Defaults to None (load all).

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: Loaded evaluation and descriptions dataframes
    """
    logger.info(f"Loading evaluation data from {eval_csv_path} and {descriptions_csv_path}")
    with tqdm(total=2, desc="Loading data files") as pbar:
        eval_df = pd.read_csv(eval_csv_path)
        pbar.update(1)
        descriptions_df = pd.read_csv(descriptions_csv_path)
        pbar.update(1)
    if index is not None:
        eval_df = eval_df.iloc[[index]]
        descriptions_df = descriptions_df.iloc[[index]]
        logger.info(f"Selected description at index {index}: {descriptions_df.iloc[0]['description']}")
    return eval_df, descriptions_df

def generate_svg(model: Any, description: str, eval_data: pd.Series,
                 results_dir: str = "results") -> Dict[str, Any]:
    """Generate SVG using the model and save it.

    Args:
        model (Any): The model to evaluate (MLModel, DLModel, or NaiveModel)
        description (str): Text description to generate SVG from
        eval_data (pd.Series): Evaluation data with questions, choices, and answers
        results_dir (str): Directory to save results to

    Returns:
        Dict[str, Any]: Generation results
    """
    # Create output directories
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(f"{results_dir}/svg", exist_ok=True)
    os.makedirs(f"{results_dir}/png", exist_ok=True)

    model_name = model.__class__.__name__
    results = {
        "description": description,
        "model_type": model_name,
        "id": eval_data.get('id', '0'),
        "category": description.split(',')[-1] if ',' in description else "unknown",
        "timestamp": datetime.now().isoformat(),
    }

    # Generate SVG
    logger.info(f"Generating SVG for description: {description}")
    start_time = datetime.now()
    svg = model.predict(description)
    generation_time = (datetime.now() - start_time).total_seconds()
    results["svg"] = svg
    results["generation_time_seconds"] = generation_time

    # Convert SVG to PNG for visual evaluation
    image = svg_to_png(svg)
    results["image_width"] = image.width
    results["image_height"] = image.height

    # Save the SVG and PNG for inspection
    output_filename = f"{results['id']}_{model_name}"
    with open(f"{results_dir}/svg/{output_filename}.svg", "w") as f:
        f.write(svg)
    image.save(f"{results_dir}/png/{output_filename}.png")

    logger.info(f"Generated SVG for model {model_name} in {generation_time:.2f} seconds")
    return results

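# Minimal usage sketch (illustrative only; assumes NaiveModel can run without a GPU).
# The returned dict carries the raw SVG plus bookkeeping fields, and the rendered
# files land under results/svg/ and results/png/.
#
#   row = pd.Series({"id": "42"})
#   result = generate_svg(NaiveModel(device="cpu"), "a red circle on white", row)
#   result["generation_time_seconds"]  # float, wall-clock seconds
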
def evaluate_results(results_list: List[Dict[str, Any]],
                     vqa_evaluator, aesthetic_evaluator,
                     results_dir: str = "results") -> List[Dict[str, Any]]:
    """Evaluate generated SVGs.

    Args:
        results_list (List[Dict[str, Any]]): List of generation results
        vqa_evaluator: VQA evaluation model
        aesthetic_evaluator: Aesthetic evaluation model
        results_dir (str): Directory with saved results

    Returns:
        List[Dict[str, Any]]: Evaluation results
    """
    evaluated_results = []
    for result in tqdm(results_list, desc="Evaluating results"):
        model_name = result["model_type"]
        output_filename = f"{result['id']}_{model_name}"

        # Load the PNG image
        image = Image.open(f"{results_dir}/png/{output_filename}.png").convert('RGB')

        try:
            # Parse evaluation data
            questions = result.get("questions")
            choices = result.get("choices")
            answers = result.get("answers")
            if not all([questions, choices, answers]):
                logger.warning(f"Missing evaluation data for {output_filename}")
                continue

            # Calculate scores
            logger.info(f"Calculating VQA score for model: {model_name}")
            vqa_score = vqa_evaluator.score(questions, choices, answers, image)
            logger.info(f"Calculating aesthetic score for model: {model_name}")
            aesthetic_score = aesthetic_evaluator.score(image)

            # Calculate final fidelity score using harmonic mean
            instance_score = harmonic_mean(vqa_score, aesthetic_score, beta=0.5)

            # Add scores to results
            result["vqa_score"] = vqa_score
            result["aesthetic_score"] = aesthetic_score
            result["fidelity_score"] = instance_score

            logger.info(f"VQA Score: {vqa_score:.4f}")
            logger.info(f"Aesthetic Score: {aesthetic_score:.4f}")
            logger.info(f"Final Fidelity Score: {instance_score:.4f}")
        except Exception as e:
            logger.error(f"Error during evaluation: {e}")
            result["error"] = str(e)

        evaluated_results.append(result)
    return evaluated_results

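# Note on the fidelity score: the authoritative definition lives in metric.harmonic_mean.
# A common shape for a beta-weighted harmonic mean (an assumption here, analogous to an
# F-beta score) is sketched below; with beta < 1 the first argument (the VQA score)
# dominates as beta approaches 0.
#
#   def harmonic_mean_sketch(a: float, b: float, beta: float = 0.5) -> float:
#       return (1 + beta**2) * a * b / (beta**2 * a + b) if a and b else 0.0
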
def create_model(model_type: str, device: str = "cuda") -> Any:
    """Create a model instance based on model type.

    Args:
        model_type (str): Type of model ('ml', 'dl', or 'naive')
        device (str, optional): Device to run model on. Defaults to "cuda".

    Returns:
        Any: Model instance
    """
    logger.info(f"Creating {model_type.upper()} model on {device}")
    with tqdm(total=1, desc=f"Loading {model_type.upper()} model") as pbar:
        if model_type.lower() == 'ml':
            model = MLModel(device=device)
        elif model_type.lower() == 'dl':
            model = DLModel(device=device)
        elif model_type.lower() == 'naive':
            model = NaiveModel(device=device)
        else:
            raise ValueError(f"Unknown model type: {model_type}")
        pbar.update(1)
    return model

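# Minimal usage sketch (illustrative only): build the lightweight baseline on CPU
# and generate a single SVG from a free-form description.
#
#   model = create_model("naive", device="cpu")
#   svg = model.predict("a purple forest at dusk")
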
def main():
    parser = argparse.ArgumentParser(description='Evaluate SVG generation models')
    # dl is not working and takes too long, so we don't evaluate it by default
    parser.add_argument('--models', nargs='+', choices=['ml', 'dl', 'naive'], default=['ml', 'naive'],
                        help='Models to evaluate (ml, dl, naive)')
    parser.add_argument('--index', type=int, default=None,
                        help='Index of the description to evaluate (default: None, evaluate all)')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to run models on (default: cuda)')
    parser.add_argument('--eval-csv', type=str, default='data/eval.csv',
                        help='Path to evaluation CSV (default: data/eval.csv)')
    parser.add_argument('--descriptions-csv', type=str, default='data/descriptions.csv',
                        help='Path to descriptions CSV (default: data/descriptions.csv)')
    parser.add_argument('--results-dir', type=str, default='results',
                        help='Directory to save results (default: results)')
    parser.add_argument('--generate-only', action='store_true',
                        help='Only generate SVGs without evaluation')
    parser.add_argument('--evaluate-only', action='store_true',
                        help='Only evaluate previously generated SVGs')
    args = parser.parse_args()

    # Create results directory
    os.makedirs(args.results_dir, exist_ok=True)

    # Load evaluation data
    eval_df, descriptions_df = load_evaluation_data(args.eval_csv, args.descriptions_csv, args.index)

    # Load cached results or initialize new results
    cached_results_file = f"{args.results_dir}/cached_results.json"
    if os.path.exists(cached_results_file) and args.evaluate_only:
        with open(cached_results_file, 'r') as f:
            results = json.load(f)
        logger.info(f"Loaded {len(results)} cached results from {cached_results_file}")
    else:
        results = []

    # Step 1: Generate SVGs if not in evaluate-only mode
    if not args.evaluate_only:
        # Process one model at a time to avoid loading/unloading models repeatedly
        for model_type in args.models:
            logger.info(f"Processing all descriptions with model: {model_type}")
            model = create_model(model_type, args.device)

            # Process all descriptions with the current model
            for idx, (_, desc_row) in enumerate(descriptions_df.iterrows()):
                description = desc_row['description']
                eval_data = eval_df.iloc[idx]
                logger.info(f"Processing description {idx}: {description}")

                # Generate SVG and save
                result = generate_svg(model, description, eval_data, args.results_dir)

                # Add questions, choices and answers to the result
                try:
                    result["questions"] = ast.literal_eval(eval_data['question'])
                    result["choices"] = ast.literal_eval(eval_data['choices'])
                    result["answers"] = ast.literal_eval(eval_data['answer'])
                except Exception as e:
                    logger.error(f"Error parsing evaluation data: {e}")

                results.append(result)
                logger.info(f"Completed SVG generation for description {idx}")

            # Free up memory after processing all descriptions with this model
            logger.info(f"Completed all SVG generations for model: {model_type}")
            del model
            if args.device == 'cuda':
                torch.cuda.empty_cache()
            gc.collect()

        # Save the results for later evaluation
        with open(cached_results_file, 'w') as f:
            # Remove image data from results to avoid large JSON files
            clean_results = []
            for result in results:
                clean_result = {k: v for k, v in result.items() if k not in ['image', 'svg']}
                clean_results.append(clean_result)
            json.dump(clean_results, f, indent=2, cls=NumpyEncoder)
        logger.info(f"Saved {len(results)} results to {cached_results_file}")

    # Exit if only generating
    if args.generate_only:
        logger.info("Generation completed. Skipping evaluation as requested.")
        return

    # Step 2: Evaluate the generated SVGs
    logger.info("Starting evaluation phase")

    # Initialize evaluators
    logger.info("Initializing VQA evaluator...")
    vqa_evaluator = VQAEvaluator()
    logger.info("Initializing Aesthetic evaluator...")
    aesthetic_evaluator = AestheticEvaluator()

    # Evaluate all results
    evaluated_results = evaluate_results(results, vqa_evaluator, aesthetic_evaluator, args.results_dir)

    # Save final results
    results_file = f"{args.results_dir}/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    with open(results_file, 'w') as f:
        # Remove image data from results to avoid large JSON files
        clean_results = []
        for result in evaluated_results:
            clean_result = {k: v for k, v in result.items() if k not in ['image', 'svg']}
            clean_results.append(clean_result)
        json.dump(clean_results, f, indent=2, cls=NumpyEncoder)

    # Create a summary CSV
    summary_data = []
    for result in evaluated_results:
        summary_data.append({
            'model': result['model_type'],
            'description': result['description'],
            'id': result['id'],
            'category': result['category'],
            'vqa_score': result.get('vqa_score', float('nan')),
            'aesthetic_score': result.get('aesthetic_score', float('nan')),
            'fidelity_score': result.get('fidelity_score', float('nan')),
            'generation_time': result.get('generation_time_seconds', float('nan')),
            'timestamp': result['timestamp']
        })
    summary_df = pd.DataFrame(summary_data)
    summary_file = f"{args.results_dir}/summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    summary_df.to_csv(summary_file, index=False)

    # Print summary
    logger.info("\nEvaluation Summary:")
    for result in evaluated_results:
        logger.info(f"Model: {result['model_type']}")
        logger.info(f"Description: {result['description']}")
        logger.info(f"VQA Score: {result.get('vqa_score', 'N/A')}")
        logger.info(f"Aesthetic Score: {result.get('aesthetic_score', 'N/A')}")
        logger.info(f"Fidelity Score: {result.get('fidelity_score', 'N/A')}")
        logger.info(f"Generation Time: {result.get('generation_time_seconds', 'N/A')} seconds")
        logger.info("---")

    logger.info(f"Results saved to: {results_file}")
    logger.info(f"Summary saved to: {summary_file}")
    logger.info(f"Log file: {log_file}")


if __name__ == "__main__":
    main()
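
# Example invocations (script filename assumed; the flags are the ones defined above):
#
#   python evaluate.py --models ml naive --index 0 --device cuda
#   python evaluate.py --generate-only            # render SVGs/PNGs, skip scoring
#   python evaluate.py --evaluate-only            # re-score results/cached_results.json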