Spaces:
Running
Running
""" | |
Utility functions for OCR processing with Mistral AI. | |
Contains helper functions for working with OCR responses and image handling. | |
""" | |
import json | |
import base64 | |
import io | |
from pathlib import Path | |
from typing import Dict, List, Optional, Union, Any | |
try: | |
from PIL import Image | |
PILLOW_AVAILABLE = True | |
except ImportError: | |
PILLOW_AVAILABLE = False | |
from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk | |
def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
    """
    Replace image placeholders in markdown with base64-encoded images.

    A placeholder is a markdown image link whose alt text and target are both
    the image ID, i.e. ``![<id>](<id>)``. The link target is swapped for the
    image's base64 data so the picture is embedded directly in the document.

    Args:
        markdown_str: Markdown text containing image placeholders
        images_dict: Dictionary mapping image IDs to base64 strings

    Returns:
        Markdown text with images replaced by base64 data. Placeholders whose
        image has no data (``None`` or empty string) are left untouched.
    """
    for img_name, base64_str in images_dict.items():
        # Nothing to embed: keep the placeholder instead of writing a broken
        # link target such as "(None)".
        if not base64_str:
            continue
        markdown_str = markdown_str.replace(
            f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})"
        )
    return markdown_str
def get_combined_markdown(ocr_response) -> str:
    """
    Build a single markdown document from every page of an OCR response.

    For each page, the page's images are embedded into its markdown and every
    newline is doubled; the pages are then joined with a horizontal rule.

    Args:
        ocr_response: Response from OCR processing; must expose ``pages``,
            each with ``markdown`` and ``images`` (items having ``id`` and
            ``image_base64``)
            See https://docs.mistral.ai/capabilities/document/ for API reference

    Returns:
        Combined markdown string with embedded images
    """
    rendered_pages: list[str] = []
    for page in ocr_response.pages:
        # Map each image ID on this page to its base64 payload
        images_by_id = {image.id: image.image_base64 for image in page.images}
        embedded = replace_images_in_markdown(page.markdown, images_by_id)
        # Double every newline so text and images render as separate blocks
        rendered_pages.append(embedded.replace("\n", "\n\n"))

    # A horizontal rule marks the boundary between consecutive pages
    page_separator = "\n\n---\n\n"
    return page_separator.join(rendered_pages)
def encode_image_for_api(image_path: Union[str, Path]) -> str:
    """
    Encode an image file as a base64 data URL for API use.

    The MIME type in the URL is guessed from the file extension. When the
    extension is missing, unknown, or not an image type, ``image/jpeg`` is
    used as the fallback.

    Args:
        image_path: Path to the image file

    Returns:
        Base64 data URL for the image (``data:<mime>;base64,<payload>``)

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file
    """
    # Imported locally so this function does not depend on the module's
    # import block being changed.
    import mimetypes

    image_file = Path(image_path)

    # Verify image exists
    if not image_file.is_file():
        raise FileNotFoundError(f"Image file not found: {image_file}")

    # Label the payload with its real type; a PNG tagged as JPEG is a
    # mismatched data URL.
    mime_type, _ = mimetypes.guess_type(image_file.name)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    encoded = base64.b64encode(image_file.read_bytes()).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"
def process_image_with_ocr(client, image_path: Union[str, Path], model: str = "mistral-ocr-latest"):
    """
    Run OCR on a single image file and return the response.

    Args:
        client: Mistral AI client (its ``ocr.process`` method is called)
        image_path: Path to the image file
        model: OCR model to use

    Returns:
        OCR response object, exactly as returned by ``client.ocr.process``
    """
    # The image is sent inline as a base64 data URL wrapped in an image chunk
    document = ImageURLChunk(image_url=encode_image_for_api(image_path))
    return client.ocr.process(document=document, model=model)
def ocr_response_to_json(ocr_response, indent: int = 4) -> str:
    """
    Serialize an OCR response as an indented JSON string.

    Args:
        ocr_response: OCR response object; must provide ``model_dump_json()``
        indent: Indentation level for JSON formatting

    Returns:
        Formatted JSON string
    """
    raw_json = ocr_response.model_dump_json()
    # Round-trip through Python objects so the text is re-emitted with the
    # requested indentation
    return json.dumps(json.loads(raw_json), indent=indent)
def get_combined_markdown_compressed(ocr_response, max_width: int = 1200, quality: int = 92) -> str:
    """
    Combine OCR text and images into a single markdown document with compressed images.
    Reduces image sizes to improve performance.

    Falls back to ``get_combined_markdown`` when Pillow is not installed. An
    image that cannot be decoded or re-encoded is embedded unchanged.

    Args:
        ocr_response: Response from OCR processing containing text and images
        max_width: Maximum width to resize images to (preserves aspect ratio)
        quality: JPEG quality (0-100) for compression; only applied when the
            image is saved as JPEG

    Returns:
        Combined markdown string with embedded compressed images
    """
    if not PILLOW_AVAILABLE:
        # Fall back to regular method if PIL is not available
        return get_combined_markdown(ocr_response)

    markdowns: list[str] = []
    for page in ocr_response.pages:
        image_data = {}
        for img in page.images:
            try:
                image_data[img.id] = _compress_image_data_url(
                    img.image_base64, max_width, quality
                )
            except Exception:
                # Deliberately broad: compression is best effort, and any
                # decode/encode failure must not lose the image. Use the
                # original data instead.
                image_data[img.id] = img.image_base64

        # Replace image placeholders with compressed images
        page_markdown = replace_images_in_markdown(page.markdown, image_data)
        # Ensure proper spacing between paragraphs and images
        markdowns.append(page_markdown.replace("\n", "\n\n"))

    # Join pages with clear separators
    return "\n\n---\n\n".join(markdowns)


def _compress_image_data_url(image_base64: str, max_width: int, quality: int) -> str:
    """Re-encode a base64 image (downscaled if wider than max_width) as a data URL."""
    # Accept both a bare base64 payload and a "data:...;base64,<payload>" URL
    payload = image_base64.split(",", 1)[1] if "," in image_base64 else image_base64
    pil_img = Image.open(io.BytesIO(base64.b64decode(payload)))

    # Read the source format NOW: convert() and resize() return new images
    # whose .format is None, which would silently turn the output into JPEG
    # depending on the image's mode and size.
    image_format = (pil_img.format or "JPEG").upper()
    if image_format == "JPG":
        image_format = "JPEG"

    # Convert to RGB if not already
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    # Resize if needed (maintain aspect ratio)
    original_width, original_height = pil_img.size
    if original_width > max_width:
        ratio = max_width / original_width
        # Clamp so extremely wide, flat images never get a zero height
        new_height = max(1, int(original_height * ratio))
        pil_img = pil_img.resize((max_width, new_height), Image.LANCZOS)

    # Convert to bytes with compression; quality is only passed for JPEG
    save_options = {"optimize": True}
    if image_format == "JPEG":
        save_options["quality"] = quality
    buffer = io.BytesIO()
    pil_img.save(buffer, format=image_format, **save_options)

    # Convert back to base64
    compressed_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/{image_format.lower()};base64,{compressed_base64}"
# Notebook helper: only defined when IPython can be imported
try:
    from IPython.display import Markdown, display
except ImportError:
    # IPython not available
    pass
else:
    def display_ocr_with_images(ocr_response):
        """
        Render an OCR response, images included, as markdown in an IPython
        environment.

        Args:
            ocr_response: OCR response object
        """
        display(Markdown(get_combined_markdown(ocr_response)))
def create_html_with_images(result_with_pages: dict) -> str:
    """
    Create HTML with embedded images from the OCR result.

    Each page's markdown is split on blank lines. Blocks starting with one to
    six ``#`` characters followed by a space become ``<h1>``..``<h6>``; every
    other block becomes a ``<p>``. Image placeholders of the form
    ``![<id>](<id>)`` are replaced by ``<img>`` tags, and images that were not
    embedded that way are appended at the end of their page. Paragraph and
    heading text is emitted verbatim (not HTML-escaped).

    Args:
        result_with_pages: OCR result. Keys read here: ``has_images`` (bool)
            and ``pages_data`` (list of dicts with optional ``page_number``,
            ``markdown`` and ``images``; each image is a dict with ``id`` and
            ``image_base64``)

    Returns:
        HTML string with embedded images, or a short "no images" paragraph
        when ``has_images`` is falsy or ``pages_data`` is missing
    """
    if not result_with_pages.get('has_images', False) or 'pages_data' not in result_with_pages:
        return "<p>No images available in the document.</p>"

    # Collect fragments and join once at the end instead of growing one
    # string with repeated +=
    chunks = ["""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Document with Images</title>
    <style>
        body {
            font-family: 'Georgia', serif;
            line-height: 1.6;
            margin: 0;
            padding: 20px;
            background-color: #f9f9f9;
            color: #333;
        }
        .container {
            max-width: 1000px;
            margin: 0 auto;
            background-color: #fff;
            padding: 30px;
            border-radius: 8px;
            box-shadow: 0 4px 12px rgba(0,0,0,0.1);
        }
        h1, h2, h3, h4 {
            font-family: 'Bookman', 'Georgia', serif;
            margin-top: 1.5em;
            margin-bottom: 0.5em;
            color: #222;
        }
        h1 { font-size: 2.2em; border-bottom: 2px solid #e0e0e0; padding-bottom: 10px; }
        h2 { font-size: 1.8em; border-bottom: 1px solid #e0e0e0; padding-bottom: 6px; }
        h3 { font-size: 1.5em; }
        h4 { font-size: 1.2em; }
        p { margin-bottom: 1.2em; text-align: justify; }
        img {
            max-width: 100%;
            height: auto;
            margin: 20px 0;
            border: 1px solid #ddd;
            border-radius: 6px;
            box-shadow: 0 3px 6px rgba(0,0,0,0.1);
            display: block;
        }
        .page {
            margin-bottom: 40px;
            padding-bottom: 30px;
            border-bottom: 1px dashed #ccc;
        }
        .page:last-child {
            border-bottom: none;
        }
        .page-title {
            text-align: center;
            color: #555;
            font-style: italic;
            margin: 30px 0;
        }
        pre {
            background-color: #f5f5f5;
            padding: 15px;
            border-radius: 5px;
            overflow-x: auto;
            font-size: 14px;
            line-height: 1.4;
        }
        blockquote {
            border-left: 3px solid #ccc;
            margin: 1.5em 0;
            padding: 0.5em 1.5em;
            background-color: #f5f5f5;
            font-style: italic;
        }
        .poem {
            font-family: 'Baskerville', 'Georgia', serif;
            margin-left: 2em;
            line-height: 1.8;
            white-space: pre-wrap;
        }
    </style>
</head>
<body>
    <div class="container">
"""]

    # Process each page
    pages_data = result_with_pages.get('pages_data', [])
    for page_idx, page in enumerate(pages_data):
        page_number = page.get('page_number', page_idx + 1)
        page_markdown = page.get('markdown', '')
        page_images = page.get('images', [])

        # Add page header; the visible title is only useful with several pages
        chunks.append(f'<div class="page" id="page-{page_number}">\n')
        if len(pages_data) > 1:
            chunks.append(f'<div class="page-title">Page {page_number}</div>\n')

        # IDs of images actually embedded through a placeholder. Tracked
        # explicitly rather than searching the markdown for the raw id, which
        # raises TypeError when an id is None.
        embedded_ids = set()

        if page_markdown:
            # Replace image markers with actual images
            for img in page_images:
                img_id = img.get('id', '')
                img_base64 = img.get('image_base64', '')
                if not (img_id and img_base64):
                    continue
                placeholder = f'![{img_id}]({img_id})'
                if placeholder in page_markdown:
                    page_markdown = page_markdown.replace(
                        placeholder, _html_image_tag(img_id, img_base64)
                    )
                    embedded_ids.add(img_id)

            # Blank-line separated blocks become headings or paragraphs
            for paragraph in page_markdown.split('\n\n'):
                if not paragraph.strip():
                    continue
                # Heading level = number of leading '#', valid only when a
                # space follows ("#tag" stays a paragraph)
                level = len(paragraph) - len(paragraph.lstrip('#'))
                if 1 <= level <= 6 and paragraph[level:level + 1] == ' ':
                    header_text = paragraph[level + 1:].strip()
                    chunks.append(f'<h{level}>{header_text}</h{level}>\n')
                else:
                    chunks.append(f'<p>{paragraph}</p>\n')

        # Add any images that weren't embedded via the markdown
        for img in page_images:
            img_id = img.get('id', '')
            img_base64 = img.get('image_base64', '')
            if img_id and img_base64 and img_id not in embedded_ids:
                chunks.append(_html_image_tag(img_id, img_base64) + '\n')

        # Close page div
        chunks.append('</div>\n')

    # Close main container and document
    chunks.append("""    </div>
</body>
</html>""")
    return "".join(chunks)


def _html_image_tag(img_id, img_base64) -> str:
    """Return an <img> tag for one image, with its attribute values HTML-escaped."""
    # Imported locally so this helper does not depend on the module's import
    # block being changed.
    from html import escape

    src = escape(str(img_base64))
    alt = escape(str(img_id))
    return f'<img src="{src}" alt="Image {alt}" loading="lazy">'