import ast
import io
import math
import statistics
import string

import cairosvg
import clip
import cv2
import kagglehub
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from more_itertools import chunked
from PIL import Image, ImageFilter
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    PaliGemmaForConditionalGeneration,
)

svg_constraints = kagglehub.package_import('metric/svg-constraints')


class ParticipantVisibleError(Exception):
    pass


def score(
    solution: pd.DataFrame, submission: pd.DataFrame, row_id_column_name: str, random_seed: int = 0
) -> float:
    """Calculates a fidelity score by comparing generated SVG images to target text descriptions.

    Parameters
    ----------
    solution : pd.DataFrame
        A DataFrame containing target questions, choices, and answers about an SVG image.
    submission : pd.DataFrame
        A DataFrame containing generated SVG strings. Must have a column named 'svg'.
    row_id_column_name : str
        The name of the column containing row identifiers. This column is not used during scoring.
    random_seed : int
        A seed to set the random state.

    Returns
    -------
    float
        The mean fidelity score (a value between 0 and 1) representing the average similarity between the generated SVGs and their descriptions.
        A higher score indicates better fidelity.

    Raises
    ------
    ParticipantVisibleError
        If the 'svg' column in the submission DataFrame is not of string type or if validation of the SVG fails.

    Examples
    --------
    >>> import pandas as pd
    >>> solution = pd.DataFrame({
    ...     'id': ["abcde"],
    ...     'question': ['["Is there a red circle?", "What shape is present?"]'],
    ...     'choices': ['[["yes", "no"], ["square", "circle", "triangle", "hexagon"]]'],
    ...     'answer': ['["yes", "circle"]'],
    ... })
    >>> submission = pd.DataFrame({
    ...     'id': ["abcde"],
    ...     'svg': ['<svg viewBox="0 0 100 100"><circle cx="50" cy="50" r="40" fill="red"/></svg>'],
    ... })
    >>> score(solution, submission, 'row_id', random_seed=42)
    0...
    """
    # Convert solution fields to list dtypes and expand
    for colname in ['question', 'choices', 'answer']:
        solution[colname] = solution[colname].apply(ast.literal_eval)
    solution = solution.explode(['question', 'choices', 'answer'])

    # Validate
    if not pd.api.types.is_string_dtype(submission.loc[:, 'svg']):
        raise ParticipantVisibleError('svg must be a string.')

    # Check that SVG code meets defined constraints
    constraints = svg_constraints.SVGConstraints()
    try:
        for svg in submission.loc[:, 'svg']:
            constraints.validate_svg(svg)
    except Exception as e:
        raise ParticipantVisibleError('SVG code violates constraints.') from e

    # Score
    vqa_evaluator = VQAEvaluator()
    aesthetic_evaluator = AestheticEvaluator()
    results = []
    rng = np.random.RandomState(random_seed)
    try:
        df = solution.merge(submission, on='id')
        for _, group in df.loc[
            :, ['id', 'question', 'choices', 'answer', 'svg']
        ].groupby('id'):
            questions, choices, answers, svg = [
                group[col_name].to_list()
                for col_name in group.drop('id', axis=1).columns
            ]
            svg = svg[0]  # unpack singleton from list
            group_seed = rng.randint(0, np.iinfo(np.int32).max)
            image_processor = ImageProcessor(image=svg_to_png(svg), seed=group_seed).apply()
            image = image_processor.image.copy()
            aesthetic_score = aesthetic_evaluator.score(image)
            vqa_score = vqa_evaluator.score(questions, choices, answers, image)
            image_processor.reset().apply_random_crop_resize().apply_jpeg_compression(quality=90)
            ocr_score = vqa_evaluator.ocr(image_processor.image)
            instance_score = (
                harmonic_mean(vqa_score, aesthetic_score, beta=0.5) * ocr_score
            )
            results.append(instance_score)
    except Exception as e:
        raise ParticipantVisibleError('SVG failed to score.') from e

    fidelity = statistics.mean(results)
    return float(fidelity)
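

# Worked example of the per-instance formula (illustrative numbers only): with
# vqa_score=0.9, aesthetic_score=0.6 and beta=0.5, the weighted harmonic mean
# is 1.25 * 0.54 / (0.225 + 0.6) ~ 0.818; an image with no detected text keeps
# ocr_score=1.0, so the instance score is ~0.818.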


class VQAEvaluator:
    """Evaluates images based on their similarity to a given text description using multiple choice questions."""

    def __init__(self):
        self.quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        self.letters = string.ascii_uppercase
        self.model_path = kagglehub.model_download(
            'google/paligemma-2/transformers/paligemma2-10b-mix-448'
        )
        self.processor = AutoProcessor.from_pretrained(self.model_path)
        self.model = PaliGemmaForConditionalGeneration.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            quantization_config=self.quantization_config,
        ).to('cuda')

    def score(self, questions, choices, answers, image, n=4):
        """Scores the image against all questions, processing them in batches of size `n`."""
        scores = []
        batches = (chunked(qs, n) for qs in [questions, choices, answers])
        for question_batch, choice_batch, answer_batch in zip(*batches, strict=True):
            scores.extend(
                self.score_batch(
                    image,
                    question_batch,
                    choice_batch,
                    answer_batch,
                )
            )
        return statistics.mean(scores)

    def score_batch(
        self,
        image: Image.Image,
        questions: list[str],
        choices_list: list[list[str]],
        answers: list[str],
    ) -> list[float]:
        """Evaluates the image based on multiple choice questions and answers.

        Parameters
        ----------
        image : PIL.Image.Image
            The image to evaluate.
        questions : list[str]
            List of questions about the image.
        choices_list : list[list[str]]
            List of lists of possible answer choices, corresponding to each question.
        answers : list[str]
            List of correct answers from the choices, corresponding to each question.

        Returns
        -------
        list[float]
            List of scores (values between 0 and 1) representing the probability of the correct answer for each question.
        """
        prompts = [
            self.format_prompt(question, choices)
            for question, choices in zip(questions, choices_list, strict=True)
        ]
        batched_choice_probabilities = self.get_choice_probability(
            image, prompts, choices_list
        )

        scores = []
        for i, _ in enumerate(questions):
            choice_probabilities = batched_choice_probabilities[i]
            answer = answers[i]
            answer_probability = 0.0
            for choice, prob in choice_probabilities.items():
                if choice == answer:
                    answer_probability = prob
                    break
            scores.append(answer_probability)
        return scores

    def format_prompt(self, question: str, choices: list[str]) -> str:
        """Builds the PaliGemma multiple-choice prompt for a single question."""
        prompt = f'<image>answer en Question: {question}\nChoices:\n'
        for i, choice in enumerate(choices):
            prompt += f'{self.letters[i]}. {choice}\n'
        return prompt
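
    # Example prompt produced by `format_prompt` (illustrative):
    #
    #   <image>answer en Question: Is there a red circle?
    #   Choices:
    #   A. yes
    #   B. no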

    def mask_choices(self, logits, choices_list):
        """Masks logits for the first token of each choice letter for each question in the batch."""
        batch_size = logits.shape[0]
        masked_logits = torch.full_like(logits, float('-inf'))

        for batch_idx in range(batch_size):
            choices = choices_list[batch_idx]
            for i in range(len(choices)):
                letter_token = self.letters[i]

                first_token = self.processor.tokenizer.encode(
                    letter_token, add_special_tokens=False
                )[0]
                first_token_with_space = self.processor.tokenizer.encode(
                    ' ' + letter_token, add_special_tokens=False
                )[0]

                if isinstance(first_token, int):
                    masked_logits[batch_idx, first_token] = logits[
                        batch_idx, first_token
                    ]
                if isinstance(first_token_with_space, int):
                    masked_logits[batch_idx, first_token_with_space] = logits[
                        batch_idx, first_token_with_space
                    ]

        return masked_logits

    def get_choice_probability(self, image, prompts, choices_list) -> list[dict]:
        """Returns, for each prompt, a dict mapping each choice to its renormalized probability."""
        inputs = self.processor(
            images=[image] * len(prompts),
            text=prompts,
            return_tensors='pt',
            padding='longest',
        ).to('cuda')

        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits[:, -1, :]  # Logits for the last (predicted) token
            masked_logits = self.mask_choices(logits, choices_list)
            probabilities = torch.softmax(masked_logits, dim=-1)

        batched_choice_probabilities = []
        for batch_idx in range(len(prompts)):
            choice_probabilities = {}
            choices = choices_list[batch_idx]
            for i, choice in enumerate(choices):
                letter_token = self.letters[i]
                first_token = self.processor.tokenizer.encode(
                    letter_token, add_special_tokens=False
                )[0]
                first_token_with_space = self.processor.tokenizer.encode(
                    ' ' + letter_token, add_special_tokens=False
                )[0]
                prob = 0.0
                if isinstance(first_token, int):
                    prob += probabilities[batch_idx, first_token].item()
                if isinstance(first_token_with_space, int):
                    prob += probabilities[batch_idx, first_token_with_space].item()
                choice_probabilities[choice] = prob

            # Renormalize probabilities for each question
            total_prob = sum(choice_probabilities.values())
            if total_prob > 0:
                renormalized_probabilities = {
                    choice: prob / total_prob
                    for choice, prob in choice_probabilities.items()
                }
            else:
                renormalized_probabilities = (
                    choice_probabilities  # Avoid division by zero if total_prob is 0
                )
            batched_choice_probabilities.append(renormalized_probabilities)

        return batched_choice_probabilities
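
    # Renormalization example (illustrative numbers): raw masked probabilities
    # {'yes': 0.02, 'no': 0.06} sum to 0.08, so they renormalize to
    # {'yes': 0.25, 'no': 0.75}.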

    def ocr(self, image, free_chars=4):
        """Scores the image by how little text PaliGemma's OCR detects in it."""
        inputs = (
            self.processor(
                text='<image>ocr\n',
                images=image,
                return_tensors='pt',
            )
            .to(torch.float16)
            .to(self.model.device)
        )
        input_len = inputs['input_ids'].shape[-1]

        with torch.inference_mode():
            outputs = self.model.generate(**inputs, max_new_tokens=32, do_sample=False)
            outputs = outputs[0][input_len:]
            decoded = self.processor.decode(outputs, skip_special_tokens=True)

        num_char = len(decoded)

        # Exponentially decreasing towards 0.0 if more than free_chars detected
        return min(1.0, math.exp(-num_char + free_chars))
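

# OCR penalty sanity check (illustrative): with free_chars=4, detecting four or
# fewer characters scores 1.0, and each extra character multiplies the score by
# 1/e, e.g. six detected characters give exp(-2) ~ 0.135.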


class AestheticPredictor(nn.Module):
    """Small MLP head that maps a CLIP image embedding to a single aesthetic score."""

    def __init__(self, input_size):
        super().__init__()
        self.input_size = input_size
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        return self.layers(x)
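

# Shape sanity check (illustrative; 768 matches CLIP ViT-L/14 image embeddings):
# AestheticPredictor(768)(torch.zeros(1, 768)).shape == torch.Size([1, 1])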


class AestheticEvaluator:
    def __init__(self):
        self.model_path = 'improved-aesthetic-predictor/sac+logos+ava1-l14-linearMSE.pth'
        self.clip_model_path = 'ViT-L/14'
        self.predictor, self.clip_model, self.preprocessor = self.load()

    def load(self):
        """Loads the aesthetic predictor model and CLIP model."""
        state_dict = torch.load(self.model_path, weights_only=True, map_location='cuda')

        # CLIP embedding dim is 768 for CLIP ViT-L/14
        predictor = AestheticPredictor(768)
        predictor.load_state_dict(state_dict)
        predictor.to('cuda')
        predictor.eval()
        clip_model, preprocessor = clip.load(self.clip_model_path, device='cuda')

        return predictor, clip_model, preprocessor

    def score(self, image: Image.Image) -> float:
        """Predicts the CLIP aesthetic score of an image."""
        image = self.preprocessor(image).unsqueeze(0).to('cuda')

        with torch.no_grad():
            image_features = self.clip_model.encode_image(image)
            # l2 normalize
            image_features /= image_features.norm(dim=-1, keepdim=True)
            image_features = image_features.cpu().detach().numpy()
        score = self.predictor(torch.from_numpy(image_features).to('cuda').float())

        return score.item() / 10.0  # scale to [0, 1]
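

# Example usage (a minimal sketch; assumes the aesthetic predictor weights are
# on disk at the path above and a CUDA device is available):
#
#   evaluator = AestheticEvaluator()
#   print(evaluator.score(Image.new('RGB', (448, 448), 'white')))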


def harmonic_mean(a: float, b: float, beta: float = 1.0) -> float:
    """
    Calculate the harmonic mean of two values, weighted using a beta parameter.

    Args:
        a: First value (e.g., precision)
        b: Second value (e.g., recall)
        beta: Weighting parameter

    Returns:
        Weighted harmonic mean
    """
    # Handle zero values to prevent division by zero
    if a <= 0 or b <= 0:
        return 0.0
    return (1 + beta**2) * (a * b) / (beta**2 * a + b)
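

# Worked example (illustrative): beta < 1 weights the first argument more
# heavily, with the result approaching `a` as beta goes to 0. For instance,
# harmonic_mean(0.8, 0.6, beta=0.5) = 1.25 * 0.48 / (0.2 + 0.6) ~ 0.75, which
# is why `score` favors vqa_score over aesthetic_score with beta=0.5.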


def svg_to_png(svg_code: str, size: tuple = (384, 384)) -> Image.Image:
    """
    Converts an SVG string to a PNG image using CairoSVG.

    If the SVG does not define a `viewBox`, it will add one using the provided size.

    Parameters
    ----------
    svg_code : str
        The SVG string to convert.
    size : tuple[int, int], default=(384, 384)
        The desired size of the output PNG image (width, height).

    Returns
    -------
    PIL.Image.Image
        The generated PNG image.
    """
    # Ensure SVG has proper size attributes
    if 'viewBox' not in svg_code:
        svg_code = svg_code.replace('<svg', f'<svg viewBox="0 0 {size[0]} {size[1]}"')

    # Convert SVG to PNG
    png_data = cairosvg.svg2png(bytestring=svg_code.encode('utf-8'))
    return Image.open(io.BytesIO(png_data)).convert('RGB').resize(size)
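

# Example (illustrative): a viewBox is injected when missing, and the output is
# always resized to `size`:
#
#   img = svg_to_png('<svg><rect width="100%" height="100%" fill="blue"/></svg>')
#   assert img.size == (384, 384)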


class ImageProcessor:
    def __init__(self, image: Image.Image, seed=None):
        """Initialize with a PIL Image object and an optional seed for reproducible randomness."""
        self.image = image
        self.original_image = self.image.copy()
        if seed is not None:
            self.rng = np.random.RandomState(seed)
        else:
            self.rng = np.random

    def reset(self):
        """Restore the image to its original, unprocessed state."""
        self.image = self.original_image.copy()
        return self

    def visualize_comparison(
        self,
        original_name='Original',
        processed_name='Processed',
        figsize=(10, 5),
        show=True,
    ):
        """Display original and processed images side by side."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
        ax1.imshow(np.asarray(self.original_image))
        ax1.set_title(original_name)
        ax1.axis('off')
        ax2.imshow(np.asarray(self.image))
        ax2.set_title(processed_name)
        ax2.axis('off')

        title = f'{original_name} vs {processed_name}'
        fig.suptitle(title)
        fig.tight_layout()
        if show:
            plt.show()
        return fig

    def apply_median_filter(self, size=3):
        """Apply median filter to remove outlier pixel values.

        Args:
            size: Size of the median filter window.
        """
        self.image = self.image.filter(ImageFilter.MedianFilter(size=size))
        return self

    def apply_bilateral_filter(self, d=9, sigma_color=75, sigma_space=75):
        """Apply bilateral filter to smooth while preserving edges.

        Args:
            d: Diameter of each pixel neighborhood.
            sigma_color: Filter sigma in the color space.
            sigma_space: Filter sigma in the coordinate space.
        """
        # Convert PIL Image to numpy array for OpenCV
        img_array = np.asarray(self.image)

        # Apply bilateral filter
        filtered = cv2.bilateralFilter(img_array, d, sigma_color, sigma_space)

        # Convert back to PIL Image
        self.image = Image.fromarray(filtered)
        return self

    def apply_fft_low_pass(self, cutoff_frequency=0.5):
        """Apply low-pass filter in the frequency domain using FFT.

        Args:
            cutoff_frequency: Normalized cutoff frequency (0-1).
                Lower values remove more high frequencies.
        """
        # Convert to numpy array, ensuring float32 for FFT
        img_array = np.array(self.image, dtype=np.float32)

        # Process each color channel separately
        result = np.zeros_like(img_array)
        for i in range(3):  # For RGB channels
            # Apply FFT
            f = np.fft.fft2(img_array[:, :, i])
            fshift = np.fft.fftshift(f)

            # Create a low-pass filter mask
            rows, cols = img_array[:, :, i].shape
            crow, ccol = rows // 2, cols // 2
            mask = np.zeros((rows, cols), np.float32)
            r = int(min(crow, ccol) * cutoff_frequency)
            center = [crow, ccol]
            x, y = np.ogrid[:rows, :cols]
            mask_area = (x - center[0]) ** 2 + (y - center[1]) ** 2 <= r * r
            mask[mask_area] = 1

            # Apply mask and inverse FFT
            fshift_filtered = fshift * mask
            f_ishift = np.fft.ifftshift(fshift_filtered)
            img_back = np.fft.ifft2(f_ishift)
            img_back = np.real(img_back)

            result[:, :, i] = img_back

        # Clip to 0-255 range and convert to uint8 after processing all channels
        result = np.clip(result, 0, 255).astype(np.uint8)

        # Convert back to PIL Image
        self.image = Image.fromarray(result)
        return self

    def apply_jpeg_compression(self, quality=85):
        """Apply JPEG compression.

        Args:
            quality: JPEG quality (0-95). Lower values increase compression.
        """
        buffer = io.BytesIO()
        self.image.save(buffer, format='JPEG', quality=quality)
        buffer.seek(0)
        self.image = Image.open(buffer)
        return self

    def apply_random_crop_resize(self, crop_percent=0.05):
        """Randomly crop and resize back to original dimensions.

        Args:
            crop_percent: Fraction of each dimension that may be cropped
                from the edges (0-0.4).
        """
        width, height = self.image.size
        crop_pixels_w = int(width * crop_percent)
        crop_pixels_h = int(height * crop_percent)

        left = self.rng.randint(0, crop_pixels_w + 1)
        top = self.rng.randint(0, crop_pixels_h + 1)
        right = width - self.rng.randint(0, crop_pixels_w + 1)
        bottom = height - self.rng.randint(0, crop_pixels_h + 1)

        self.image = self.image.crop((left, top, right, bottom))
        self.image = self.image.resize((width, height), Image.BILINEAR)
        return self

    def apply(self):
        """Apply an ensemble of defenses."""
        return (
            self.apply_random_crop_resize(crop_percent=0.03)
            .apply_jpeg_compression(quality=95)
            .apply_median_filter(size=9)
            .apply_fft_low_pass(cutoff_frequency=0.5)
            .apply_bilateral_filter(d=5, sigma_color=75, sigma_space=75)
            .apply_jpeg_compression(quality=92)
        )
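

# Example end-to-end usage (a minimal sketch mirroring `score`; assumes the
# PaliGemma and aesthetic-predictor weights are available, a CUDA device is
# present, and `svg_code` holds a constraint-satisfying SVG string):
#
#   processor = ImageProcessor(image=svg_to_png(svg_code), seed=0).apply()
#   defended = processor.image.copy()
#   processor.reset().apply_random_crop_resize().apply_jpeg_compression(quality=90)
#   ocr_input = processor.image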