import ast
import argparse
import logging
import numpy as np
import pandas as pd
import json
from datetime import datetime
import os
from PIL import Image
from ml import MLModel
from dl import DLModel
from naive import NaiveModel
import cairosvg
import io
from typing import Dict, Any, List, Tuple
from tqdm import tqdm
from metric import harmonic_mean, VQAEvaluator, AestheticEvaluator
import gc
import torch

# Setup logging
os.makedirs("logs", exist_ok=True)
log_file = f"logs/eval_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


# Custom JSON encoder to handle NumPy types
class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NumpyEncoder, self).default(obj)


def svg_to_png(svg_code: str, size: tuple = (384, 384)) -> Image.Image:
    """Converts SVG code to a PNG image.

    Args:
        svg_code (str): SVG code to convert
        size (tuple, optional): Output image size. Defaults to (384, 384).

    Returns:
        PIL.Image.Image: The converted PNG image
    """
    try:
        png_data = cairosvg.svg2png(bytestring=svg_code.encode('utf-8'),
                                    output_width=size[0],
                                    output_height=size[1])
        return Image.open(io.BytesIO(png_data))
    except Exception as e:
        logger.error(f"Error converting SVG to PNG: {e}")
        # Return a default red circle if conversion fails
        # (minimal placeholder markup; the original fallback SVG was lost)
        default_svg = (
            '<svg xmlns="http://www.w3.org/2000/svg" width="384" height="384">'
            '<circle cx="192" cy="192" r="96" fill="red"/></svg>'
        )
        png_data = cairosvg.svg2png(bytestring=default_svg.encode('utf-8'),
                                    output_width=size[0],
                                    output_height=size[1])
        return Image.open(io.BytesIO(png_data))


def load_evaluation_data(eval_csv_path: str, descriptions_csv_path: str,
                         index: int = None) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load evaluation data from CSV files.

    Args:
        eval_csv_path (str): Path to the evaluation CSV
        descriptions_csv_path (str): Path to the descriptions CSV
        index (int, optional): Specific index to load. Defaults to None (load all).

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: Loaded evaluation and descriptions dataframes
    """
    logger.info(f"Loading evaluation data from {eval_csv_path} and {descriptions_csv_path}")
    with tqdm(total=2, desc="Loading data files") as pbar:
        eval_df = pd.read_csv(eval_csv_path)
        pbar.update(1)
        descriptions_df = pd.read_csv(descriptions_csv_path)
        pbar.update(1)

    if index is not None:
        eval_df = eval_df.iloc[[index]]
        descriptions_df = descriptions_df.iloc[[index]]
        logger.info(f"Selected description at index {index}: {descriptions_df.iloc[0]['description']}")

    return eval_df, descriptions_df
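
# Note on input schema (inferred from how columns are accessed in this script,
# not from a separate spec): descriptions.csv is expected to provide a
# 'description' column, and eval.csv is expected to provide 'id', 'question',
# 'choices', and 'answer' columns, where the latter three hold Python-literal
# lists parsed with ast.literal_eval in main(). The two files are assumed to be
# aligned by row index.
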

def generate_svg(model: Any, description: str, eval_data: pd.Series,
                 results_dir: str = "results") -> Dict[str, Any]:
    """Generate SVG using the model and save it.

    Args:
        model (Any): The model to evaluate (MLModel, DLModel, or NaiveModel)
        description (str): Text description to generate SVG from
        eval_data (pd.Series): Evaluation data with questions, choices, and answers
        results_dir (str): Directory to save results to

    Returns:
        Dict[str, Any]: Generation results
    """
    # Create output directories
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(f"{results_dir}/svg", exist_ok=True)
    os.makedirs(f"{results_dir}/png", exist_ok=True)

    model_name = model.__class__.__name__
    results = {
        "description": description,
        "model_type": model_name,
        "id": eval_data.get('id', '0'),
        "category": description.split(',')[-1] if ',' in description else "unknown",
        "timestamp": datetime.now().isoformat(),
    }

    # Generate SVG
    logger.info(f"Generating SVG for description: {description}")
    start_time = datetime.now()
    svg = model.predict(description)
    generation_time = (datetime.now() - start_time).total_seconds()

    results["svg"] = svg
    results["generation_time_seconds"] = generation_time

    # Convert SVG to PNG for visual evaluation
    image = svg_to_png(svg)
    results["image_width"] = image.width
    results["image_height"] = image.height

    # Save the SVG and PNG for inspection
    output_filename = f"{results['id']}_{model_name}"
    with open(f"{results_dir}/svg/{output_filename}.svg", "w") as f:
        f.write(svg)
    image.save(f"{results_dir}/png/{output_filename}.png")

    logger.info(f"Generated SVG for model {model_name} in {generation_time:.2f} seconds")
    return results


def evaluate_results(results_list: List[Dict[str, Any]], vqa_evaluator, aesthetic_evaluator,
                     results_dir: str = "results") -> List[Dict[str, Any]]:
    """Evaluate generated SVGs.

    Args:
        results_list (List[Dict[str, Any]]): List of generation results
        vqa_evaluator: VQA evaluation model
        aesthetic_evaluator: Aesthetic evaluation model
        results_dir (str): Directory with saved results

    Returns:
        List[Dict[str, Any]]: Evaluation results
    """
    evaluated_results = []

    for result in tqdm(results_list, desc="Evaluating results"):
        model_name = result["model_type"]
        output_filename = f"{result['id']}_{model_name}"

        # Load the PNG image
        image = Image.open(f"{results_dir}/png/{output_filename}.png").convert('RGB')

        try:
            # Parse evaluation data
            questions = result.get("questions")
            choices = result.get("choices")
            answers = result.get("answers")

            if not all([questions, choices, answers]):
                logger.warning(f"Missing evaluation data for {output_filename}")
                continue

            # Calculate scores
            logger.info(f"Calculating VQA score for model: {model_name}")
            vqa_score = vqa_evaluator.score(questions, choices, answers, image)

            logger.info(f"Calculating aesthetic score for model: {model_name}")
            aesthetic_score = aesthetic_evaluator.score(image)

            # Calculate final fidelity score using harmonic mean
            instance_score = harmonic_mean(vqa_score, aesthetic_score, beta=0.5)

            # Add scores to results
            result["vqa_score"] = vqa_score
            result["aesthetic_score"] = aesthetic_score
            result["fidelity_score"] = instance_score

            logger.info(f"VQA Score: {vqa_score:.4f}")
            logger.info(f"Aesthetic Score: {aesthetic_score:.4f}")
            logger.info(f"Final Fidelity Score: {instance_score:.4f}")
        except Exception as e:
            logger.error(f"Error during evaluation: {e}")
            result["error"] = str(e)

        evaluated_results.append(result)

    return evaluated_results
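
# Note on the fidelity score computed in evaluate_results(): it is
# harmonic_mean(vqa_score, aesthetic_score, beta=0.5). Assuming metric.py
# implements the usual F-beta-style weighted harmonic mean,
#
#     score = (1 + beta**2) * vqa * aesthetic / (beta**2 * vqa + aesthetic)
#
# a beta of 0.5 weights the VQA score more heavily than the aesthetic score.
# See metric.harmonic_mean for the authoritative definition.
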

def create_model(model_type: str, device: str = "cuda") -> Any:
    """Create a model instance based on model type.

    Args:
        model_type (str): Type of model ('ml', 'dl', or 'naive')
        device (str, optional): Device to run model on. Defaults to "cuda".

    Returns:
        Any: Model instance
    """
    logger.info(f"Creating {model_type.upper()} model on {device}")
    with tqdm(total=1, desc=f"Loading {model_type.upper()} model") as pbar:
        if model_type.lower() == 'ml':
            model = MLModel(device=device)
        elif model_type.lower() == 'dl':
            model = DLModel(device=device)
        elif model_type.lower() == 'naive':
            model = NaiveModel(device=device)
        else:
            raise ValueError(f"Unknown model type: {model_type}")
        pbar.update(1)
    return model


def main():
    parser = argparse.ArgumentParser(description='Evaluate SVG generation models')
    # dl is not working and takes too long, so we don't evaluate it by default
    parser.add_argument('--models', nargs='+', choices=['ml', 'dl', 'naive'], default=['ml', 'naive'],
                        help='Models to evaluate (ml, dl, naive)')
    parser.add_argument('--index', type=int, default=None,
                        help='Index of the description to evaluate (default: None, evaluate all)')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to run models on (default: cuda)')
    parser.add_argument('--eval-csv', type=str, default='data/eval.csv',
                        help='Path to evaluation CSV (default: data/eval.csv)')
    parser.add_argument('--descriptions-csv', type=str, default='data/descriptions.csv',
                        help='Path to descriptions CSV (default: data/descriptions.csv)')
    parser.add_argument('--results-dir', type=str, default='results',
                        help='Directory to save results (default: results)')
    parser.add_argument('--generate-only', action='store_true',
                        help='Only generate SVGs without evaluation')
    parser.add_argument('--evaluate-only', action='store_true',
                        help='Only evaluate previously generated SVGs')
    args = parser.parse_args()

    # Create results directory
    os.makedirs(args.results_dir, exist_ok=True)

    # Load evaluation data
    eval_df, descriptions_df = load_evaluation_data(args.eval_csv, args.descriptions_csv, args.index)

    # Load cached results or initialize new results
    cached_results_file = f"{args.results_dir}/cached_results.json"
    if os.path.exists(cached_results_file) and args.evaluate_only:
        with open(cached_results_file, 'r') as f:
            results = json.load(f)
        logger.info(f"Loaded {len(results)} cached results from {cached_results_file}")
    else:
        results = []

    # Step 1: Generate SVGs if not in evaluate-only mode
    if not args.evaluate_only:
        # Process one model at a time to avoid loading/unloading models repeatedly
        for model_type in args.models:
            logger.info(f"Processing all descriptions with model: {model_type}")
            model = create_model(model_type, args.device)

            # Process all descriptions with the current model
            for idx, (_, desc_row) in enumerate(descriptions_df.iterrows()):
                description = desc_row['description']
                eval_data = eval_df.iloc[idx]

                logger.info(f"Processing description {idx}: {description}")

                # Generate SVG and save
                result = generate_svg(model, description, eval_data, args.results_dir)

                # Add questions, choices and answers to the result
                try:
                    result["questions"] = ast.literal_eval(eval_data['question'])
                    result["choices"] = ast.literal_eval(eval_data['choices'])
                    result["answers"] = ast.literal_eval(eval_data['answer'])
                except Exception as e:
                    logger.error(f"Error parsing evaluation data: {e}")

                results.append(result)
                logger.info(f"Completed SVG generation for description {idx}")

            # Free up memory after processing all descriptions with this model
            logger.info(f"Completed all SVG generations for model: {model_type}")
            del model
            if args.device == 'cuda':
                torch.cuda.empty_cache()
            gc.collect()
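
        # Results are cached to JSON so that a later --evaluate-only run can score
        # previously generated SVGs without re-running generation. The evaluation
        # phase reloads the PNGs from results/png/, so the cache only needs the
        # metadata plus the parsed questions/choices/answers.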
        # Save the results for later evaluation
        with open(cached_results_file, 'w') as f:
            # Remove image data from results to avoid large JSON files
            clean_results = []
            for result in results:
                clean_result = {k: v for k, v in result.items() if k not in ['image', 'svg']}
                clean_results.append(clean_result)
            json.dump(clean_results, f, indent=2, cls=NumpyEncoder)
        logger.info(f"Saved {len(results)} results to {cached_results_file}")

    # Exit if only generating
    if args.generate_only:
        logger.info("Generation completed. Skipping evaluation as requested.")
        return

    # Step 2: Evaluate the generated SVGs
    logger.info("Starting evaluation phase")

    # Initialize evaluators
    logger.info("Initializing VQA evaluator...")
    vqa_evaluator = VQAEvaluator()
    logger.info("Initializing Aesthetic evaluator...")
    aesthetic_evaluator = AestheticEvaluator()

    # Evaluate all results
    evaluated_results = evaluate_results(results, vqa_evaluator, aesthetic_evaluator, args.results_dir)

    # Save final results
    results_file = f"{args.results_dir}/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    with open(results_file, 'w') as f:
        # Remove image data from results to avoid large JSON files
        clean_results = []
        for result in evaluated_results:
            clean_result = {k: v for k, v in result.items() if k not in ['image', 'svg']}
            clean_results.append(clean_result)
        json.dump(clean_results, f, indent=2, cls=NumpyEncoder)

    # Create a summary CSV
    summary_data = []
    for result in evaluated_results:
        summary_data.append({
            'model': result['model_type'],
            'description': result['description'],
            'id': result['id'],
            'category': result['category'],
            'vqa_score': result.get('vqa_score', float('nan')),
            'aesthetic_score': result.get('aesthetic_score', float('nan')),
            'fidelity_score': result.get('fidelity_score', float('nan')),
            'generation_time': result.get('generation_time_seconds', float('nan')),
            'timestamp': result['timestamp']
        })
    summary_df = pd.DataFrame(summary_data)
    summary_file = f"{args.results_dir}/summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    summary_df.to_csv(summary_file, index=False)

    # Print summary
    logger.info("\nEvaluation Summary:")
    for result in evaluated_results:
        logger.info(f"Model: {result['model_type']}")
        logger.info(f"Description: {result['description']}")
        logger.info(f"VQA Score: {result.get('vqa_score', 'N/A')}")
        logger.info(f"Aesthetic Score: {result.get('aesthetic_score', 'N/A')}")
        logger.info(f"Fidelity Score: {result.get('fidelity_score', 'N/A')}")
        logger.info(f"Generation Time: {result.get('generation_time_seconds', 'N/A')} seconds")
        logger.info("---")

    logger.info(f"Results saved to: {results_file}")
    logger.info(f"Summary saved to: {summary_file}")
    logger.info(f"Log file: {log_file}")


if __name__ == "__main__":
    main()
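
# Example usage (flags as defined in main() above; the script filename and the
# data/ layout are assumptions about your checkout):
#
#   # Generate SVGs with the ML and naive models for one description, then score them
#   python eval.py --models ml naive --index 0
#
#   # Generate everything first, then score the cached results in a separate run
#   python eval.py --models ml naive --generate-only
#   python eval.py --evaluate-only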