Spaces:
Paused
Paused
import concurrent | |
import io | |
import logging | |
import re | |
import cairosvg | |
import kagglehub | |
import torch | |
from lxml import etree | |
from unsloth import FastLanguageModel | |
from unsloth.chat_templates import get_chat_template | |
svg_constraints = kagglehub.package_import('metric/svg-constraints') | |
class NaiveModel:
    """Generate constraint-compliant SVG images from text descriptions.

    The model prompts an Unsloth chat model with a one-shot example, extracts
    the SVG markup from the reply, strips every element and attribute that
    the ``svg_constraints`` package does not allow, and finally checks that
    ``cairosvg`` can rasterize the result. Whenever any of those steps fails
    or takes too long, a fixed fallback SVG is returned instead.

    Parameters
    ----------
    model_name : str
        Model identifier passed to ``FastLanguageModel.from_pretrained``.
    max_seq_length : int
        Maximum sequence length passed to ``FastLanguageModel.from_pretrained``.
    device : str
        Device the tokenized prompt is moved to before generation.
    """

    # Matches one complete <svg>...</svg> fragment (non-greedy, across lines).
    _SVG_RE = re.compile(r"<svg.*?</svg>", re.DOTALL | re.IGNORECASE)

    # Deliberately conservative check of a path's "d" attribute. Compiled once
    # here instead of once per <path> element inside the sanitizing loop.
    _PATH_D_RE = re.compile(
        r'^'  # Start of string
        r'(?:'  # Non-capturing group for each command + numbers block
        r'[MmZzLlHhVvCcSsQqTtAa]'  # Valid SVG path commands (adjusted to exclude extra letters)
        r'\s*'  # Optional whitespace after command
        r'(?:'  # Non-capturing group for optional numbers
        r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'  # First number
        r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'  # Subsequent numbers with mandatory separator(s)
        r')?'  # Numbers are optional (e.g. for Z command)
        r'\s*'  # Optional whitespace after numbers/command block
        r')+'  # One or more command blocks
        r'\s*'  # Optional trailing whitespace
        r'$'  # End of string
    )

    def __init__(self, model_name="unsloth/phi-4-unsloth-bnb-4bit", max_seq_length=2048, device="cuda"):
        self.device = device
        self.max_seq_length = max_seq_length
        self.load_in_4bit = True

        # Load the Unsloth Phi-4 model
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=self.max_seq_length,
            load_in_4bit=self.load_in_4bit
        )

        # Set up chat template
        self.tokenizer = get_chat_template(
            self.tokenizer,
            chat_template="phi-4",
        )

        # Prepare model for inference
        FastLanguageModel.for_inference(self.model)

        # One-shot prompt; the single "{}" placeholder receives the description.
        self.prompt_template = """Generate SVG code to visually represent the following text description, while respecting the given constraints.
<constraints>
* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`
* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`
</constraints>
Please ensure that the generated SVG code is well-formed, valid, and strictly adheres to these constraints. Focus on a clear and concise representation of the input description within the given limitations. Always give the complete SVG code with nothing omitted. Never use an ellipsis.
<description>"A red circle with a blue square inside"</description>
```svg
<svg viewBox="0 0 256 256" width="256" height="256">
<circle cx="50" cy="50" r="40" fill="red"/>
<rect x="30" y="30" width="40" height="40" fill="blue"/>
</svg>
```
<description>"{}"</description>
"""
        # Fallback image returned whenever generation or validation fails.
        self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
        self.constraints = svg_constraints.SVGConstraints()
        # Upper bound on how long predict() waits for a single generation.
        self.timeout_seconds = 90

    @staticmethod
    def _extract_svg(text):
        """Return the last complete ``<svg>...</svg>`` fragment in *text*, or ``None``."""
        matches = NaiveModel._SVG_RE.findall(text)
        return matches[-1] if matches else None

    def predict(self, description: str, max_new_tokens=512) -> str:
        """Generate an SVG for a text description.

        Parameters
        ----------
        description : str
            The text to illustrate.
        max_new_tokens : int
            Generation budget passed to ``model.generate``.

        Returns
        -------
        str
            The sanitized SVG markup, or ``self.default_svg`` when generation
            fails, produces no complete SVG, cannot be rendered by cairosvg,
            or does not finish within ``self.timeout_seconds``.
        """
        # ``import concurrent`` alone does not guarantee that the ``futures``
        # submodule is loaded, so import it explicitly where it is used.
        import concurrent.futures

        def generate_svg():
            try:
                # Format the prompt
                prompt = self.prompt_template.format(description)

                # Create messages in the format expected by the chat template
                messages = [
                    {"role": "user", "content": prompt},
                ]

                # Tokenize the messages
                inputs = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                ).to(self.device)

                # Generate the output
                outputs = self.model.generate(
                    input_ids=inputs,
                    max_new_tokens=max_new_tokens,
                    use_cache=True,
                    temperature=1.0,
                    min_p=0.1,
                    do_sample=True,
                )

                # The generated sequence starts with the prompt tokens. Decode
                # only the continuation; otherwise the example SVG embedded in
                # the prompt would be picked up as the "result" whenever the
                # model emits no complete SVG of its own.
                generated_ids = outputs[0][inputs.shape[-1]:]
                output_decoded = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
                logging.debug('Output decoded from model: %s', output_decoded)

                svg = self._extract_svg(output_decoded)
                if svg is None:
                    return self.default_svg

                logging.debug('Unprocessed SVG: %s', svg)
                svg = self.enforce_constraints(svg)
                logging.debug('Processed SVG: %s', svg)

                # Ensure the generated code can be converted by cairosvg
                cairosvg.svg2png(bytestring=svg.encode('utf-8'))
                return svg
            except Exception as e:
                logging.error('Exception during SVG generation: %s', e)
                return self.default_svg

        # Execute SVG generation in a worker thread to enforce the time limit.
        # The executor is NOT used as a context manager: leaving a ``with``
        # block calls ``shutdown(wait=True)``, which would block until the
        # worker finishes and thereby defeat the timeout.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        future = executor.submit(generate_svg)
        try:
            return future.result(timeout=self.timeout_seconds)
        except concurrent.futures.TimeoutError:
            logging.warning("Prediction timed out after %s seconds.", self.timeout_seconds)
            return self.default_svg
        except Exception as e:
            logging.error("An unexpected error occurred: %s", e)
            return self.default_svg
        finally:
            # A running thread cannot be cancelled; after a timeout the worker
            # keeps going in the background and its result is discarded.
            executor.shutdown(wait=False)

    def enforce_constraints(self, svg_string: str) -> str:
        """Enforces constraints on an SVG string, removing disallowed elements
        and attributes.

        Parameters
        ----------
        svg_string : str
            The SVG string to process.

        Returns
        -------
        str
            The processed SVG string, or the default SVG if constraints
            cannot be satisfied.
        """
        logging.info('Sanitizing SVG...')

        try:
            parser = etree.XMLParser(remove_blank_text=True, remove_comments=True)
            root = etree.fromstring(svg_string, parser=parser)
        except etree.ParseError as e:
            logging.error('SVG Parse Error: %s. Returning default SVG.', e)
            logging.error('SVG string: %s', svg_string)
            return self.default_svg

        # Removal is deferred until after the walk so the tree is not mutated
        # while root.iter() is traversing it.
        elements_to_remove = []
        for element in root.iter():
            tag_name = etree.QName(element.tag).localname

            # Remove disallowed elements
            if tag_name not in self.constraints.allowed_elements:
                elements_to_remove.append(element)
                continue  # Skip attribute checks for removed elements

            # Remove disallowed attributes: an attribute survives when it is
            # allowed for this element or listed under the 'common' entry.
            attrs_to_remove = []
            for attr in element.attrib:
                attr_name = etree.QName(attr).localname
                if (
                    attr_name
                    not in self.constraints.allowed_elements[tag_name]
                    and attr_name
                    not in self.constraints.allowed_elements['common']
                ):
                    attrs_to_remove.append(attr)

            for attr in attrs_to_remove:
                logging.debug(
                    'Attribute "%s" for element "%s" not allowed. Removing.',
                    attr,
                    tag_name,
                )
                del element.attrib[attr]

            # Check and remove invalid href attributes (only internal
            # "#fragment" references are kept). Iterate over a snapshot
            # because entries are deleted inside the loop.
            for attr, value in list(element.attrib.items()):
                if etree.QName(attr).localname == 'href' and not value.startswith('#'):
                    logging.debug(
                        'Removing invalid href attribute in element "%s".', tag_name
                    )
                    del element.attrib[attr]

            # Validate path elements to help ensure SVG conversion
            if tag_name == 'path':
                d_attribute = element.get('d')
                if not d_attribute:
                    logging.warning('Path element is missing "d" attribute. Removing path.')
                    elements_to_remove.append(element)
                    continue  # Skip further checks for this removed element

                # Use regex to validate 'd' attribute format
                if not self._PATH_D_RE.match(d_attribute):
                    logging.warning(
                        'Path element has malformed "d" attribute format. Removing path.'
                    )
                    elements_to_remove.append(element)
                    continue
                logging.debug('Path element "d" attribute validated (regex check).')

        # Remove elements marked for removal. The root has no parent, so a
        # disallowed root element is left in place.
        for element in elements_to_remove:
            if element.getparent() is not None:
                element.getparent().remove(element)
                logging.debug('Removed element: %s', element.tag)

        try:
            cleaned_svg_string = etree.tostring(root, encoding='unicode')
            return cleaned_svg_string
        except ValueError as e:
            logging.error(
                'SVG could not be sanitized to meet constraints: %s', e
            )
            return self.default_svg
if __name__ == "__main__":
    # Demo entry point: generate one SVG for a sample description and
    # rasterize it to "output.png" in the current working directory.
    model = NaiveModel()
    svg = model.predict("a purple forest at dusk")

    try:
        # Render to an in-memory PNG first, then persist the bytes to disk.
        png_bytes = cairosvg.svg2png(bytestring=svg.encode('utf-8'))
        with open("output.png", "wb") as png_file:
            png_file.write(png_bytes)
        print("SVG saved as output.png")
    except Exception as err:
        # Report the failure instead of crashing the demo run.
        print(f"Error converting SVG to PNG: {err}")