smart-cropper / app.py
mostlycached's picture
Update app.py
d81e58f verified
raw
history blame contribute delete
10.1 kB
import os
import tempfile
import warnings

import gradio as gr
import numpy as np
import torch
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from transformers import SamModel, SamProcessor
# Silence every warning process-wide.
# NOTE(review): this also hides warnings unrelated to this app; consider
# narrowing the filter to specific categories/modules.
warnings.filterwarnings("ignore")
# Check if CUDA is available
# Use a CUDA device when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load SAM model and processor
# Both are created once at import time and are the globals that
# get_sam_mask uses for every request; the model is moved to `device`.
model_id = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(model_id)
model = SamModel.from_pretrained(model_id).to(device)
def get_sam_mask(image, points=None):
    """
    Segment ``image`` with SAM and return a single binary mask.

    Args:
        image: PIL image. Converted to RGB when it is in another mode.
        points: Optional list of ``[x, y]`` prompt points in the image's own
            pixel coordinates, e.g. ``[[120, 45]]``. ``None`` runs SAM with
            no prompt at all.

    Returns:
        ``np.uint8`` array sized like the original image, holding 0
        (background) or 255 (selected region).
        - Without a prompt, the largest candidate mask is returned, i.e. the
          most content to preserve when cropping.
        - With a prompt, the candidate that SAM scores highest by predicted
          IoU is returned.
        - If SAM yields no candidates, an all-255 (whole image) mask is
          returned.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    if points is None:
        inputs = processor(images=image, return_tensors="pt").to(device)
    else:
        # The processor takes one list of prompt points per image in the batch.
        inputs = processor(images=image, input_points=[points], return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # post_process_masks upsamples the low-res predictions back to the
    # original image size and returns one tensor per image; [0][0] selects the
    # single image and its single prompt, leaving the stack of candidate masks.
    candidates = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0].numpy()

    if candidates.shape[0] == 0:
        # Nothing segmented: treat the whole frame as important content.
        return np.ones((image.height, image.width), dtype=np.uint8) * 255

    if points is None:
        # Unprompted: keep the candidate covering the most pixels.
        areas = candidates.reshape(candidates.shape[0], -1).sum(axis=1)
        best = int(np.argmax(areas))
    else:
        # Prompted: keep the candidate SAM is most confident in. [0, 0] selects
        # the single image / prompt, mirroring the indexing used on the masks.
        scores = outputs.iou_scores[0, 0].cpu().numpy()
        best = int(np.argmax(scores))

    return candidates[best].astype(np.uint8) * 255
def find_optimal_crop(image, mask, target_aspect_ratio):
    """
    Find the largest crop with the requested aspect ratio that keeps the
    important (non-zero) region of ``mask`` in view.

    Args:
        image: The source image. Not used by the computation (the mask carries
            all the geometry); kept so existing callers keep working.
        mask: 2-D array of shape ``(height, width)``; non-zero pixels mark
            important content.
        target_aspect_ratio: Desired crop width / height (must be > 0).

    Returns:
        ``(left, top, right, bottom)`` box as accepted by ``PIL.Image.crop``.
        The crop spans the full width or the full height of the image and is
        at least 1 pixel in each dimension.
    """
    h, w = mask.shape

    ys, xs = np.nonzero(mask > 0)
    if ys.size == 0:
        # No important content detected: treat the whole frame as content.
        min_x, min_y, content_width, content_height = 0, 0, w, h
    else:
        min_x, min_y = int(xs.min()), int(ys.min())
        content_width = int(xs.max()) - min_x + 1
        content_height = int(ys.max()) - min_y + 1

    # Largest window with the target ratio that fits inside the image.
    # round() instead of int() so float error (100 * 0.29 == 28.999...) does
    # not lose a pixel; clamp to [1, side] so extreme ratios never produce an
    # empty box.
    if target_aspect_ratio > w / h:
        # Target is wider than the original: keep full width.
        target_w = w
        target_h = min(h, max(1, int(round(w / target_aspect_ratio))))
    else:
        # Target is taller than (or as wide as) the original: keep full height.
        target_h = h
        target_w = min(w, max(1, int(round(h * target_aspect_ratio))))

    center_x = min_x + content_width // 2
    center_y = min_y + content_height // 2
    fits_w = content_width <= target_w
    fits_h = content_height <= target_h

    if fits_w and fits_h:
        # All content fits: center the window on it.
        start_x = center_x - target_w // 2
        start_y = center_y - target_h // 2
    else:
        # Content overflows at least one axis. On an axis where it still fits,
        # align the window with the content's leading edge; on an overflowing
        # axis, center the window on the content.
        start_x = min_x if fits_w else center_x - target_w // 2
        start_y = min_y if fits_h else center_y - target_h // 2

    # Keep the window inside the image bounds.
    x = max(0, min(start_x, w - target_w))
    y = max(0, min(start_y, h - target_h))
    return (x, y, x + target_w, y + target_h)
def smart_crop(input_image, target_aspect_ratio, point_x=None, point_y=None):
    """
    Crop ``input_image`` to ``target_aspect_ratio`` while keeping the content
    SAM marks as important, and render a side-by-side visualization.

    Args:
        input_image: PIL image or numpy array (as delivered by Gradio).
        target_aspect_ratio: Desired crop width / height.
        point_x, point_y: Optional click position in image pixel coordinates.
            Used as a SAM prompt when both are given and non-negative.

    Returns:
        ``(cropped_image, visualization_path)``, or ``(None, None)`` when no
        image was supplied.
    """
    if input_image is None:
        # Same arity as the success path so callers can always unpack.
        return None, None

    # Normalize to an RGB PIL image.
    pil_image = Image.fromarray(input_image) if isinstance(input_image, np.ndarray) else input_image
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # Pixel coordinates start at 0, so a click on the first row/column is a
    # valid prompt; only a missing coordinate (None) means "no prompt".
    points = None
    if point_x is not None and point_y is not None and point_x >= 0 and point_y >= 0:
        points = [[point_x, point_y]]
    mask = get_sam_mask(pil_image, points)

    crop_box = find_optimal_crop(pil_image, mask, target_aspect_ratio)
    cropped_img = pil_image.crop(crop_box)

    vis_path = _save_visualization(pil_image, mask, cropped_img, target_aspect_ratio)
    return cropped_img, vis_path


def _save_visualization(original, mask, cropped, target_aspect_ratio):
    """Render original / mask / crop side by side into a new PNG file and return its path."""
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    try:
        ax[0].imshow(original)
        ax[0].set_title("Original Image")
        ax[0].axis("off")
        ax[1].imshow(mask, cmap='gray')
        ax[1].set_title("SAM Segmentation Mask")
        ax[1].axis("off")
        ax[2].imshow(cropped)
        ax[2].set_title(f"Smart Cropped ({target_aspect_ratio:.2f})")
        ax[2].axis("off")
        fig.tight_layout()
        # A unique file per call: a fixed filename would let concurrent
        # requests overwrite each other's visualization. The file is kept
        # (delete=False) because the caller needs the path after we return.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            fig.savefig(tmp, format="png")
            vis_path = tmp.name
    finally:
        # Close this specific figure; a bare plt.close() closes whichever
        # figure is "current", which need not be ours.
        plt.close(fig)
    return vis_path
def aspect_ratio_options(choice):
    """
    Translate an aspect-ratio dropdown label into a width / height float.

    Labels that are not recognized fall back to 16:9.
    """
    known_ratios = (
        ("16:9 (Landscape)", 16, 9),
        ("9:16 (Portrait)", 9, 16),
        ("4:3 (Standard)", 4, 3),
        ("3:4 (Portrait)", 3, 4),
        ("1:1 (Square)", 1, 1),
        ("21:9 (Ultrawide)", 21, 9),
        ("2:3 (Portrait)", 2, 3),
        ("3:2 (Landscape)", 3, 2),
    )
    lookup = {label: width / height for label, width, height in known_ratios}
    return lookup.get(choice, 16 / 9)
def process_image(input_image, aspect_ratio_choice, point_x=None, point_y=None):
    """
    Gradio click handler: resolve the dropdown label to a numeric ratio and
    run the smart crop.

    Returns ``(cropped_image, visualization_path)``, or ``(None, None)`` when
    no image has been uploaded.
    """
    if input_image is not None:
        ratio_value = aspect_ratio_options(aspect_ratio_choice)
        cropped, vis_file = smart_crop(input_image, ratio_value, point_x, point_y)
        return cropped, vis_file
    return None, None
def create_app():
    """
    Build the Gradio Blocks UI for the smart cropper.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks(title="Smart Image Cropper using SAM") as app:
        gr.Markdown("# Smart Image Cropper using Segment Anything Model (SAM)")
        gr.Markdown("""
Upload an image and choose your target aspect ratio. The app will use the Segment Anything Model (SAM)
to identify important content and crop intelligently to preserve it.
Optionally, you can click on the uploaded image to specify a point of interest.
""")
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="Upload Image")
                aspect_ratio = gr.Dropdown(
                    choices=[
                        "16:9 (Landscape)",
                        "9:16 (Portrait)",
                        "4:3 (Standard)",
                        "3:4 (Portrait)",
                        "1:1 (Square)",
                        "21:9 (Ultrawide)",
                        "2:3 (Portrait)",
                        "3:2 (Landscape)"
                    ],
                    value="16:9 (Landscape)",
                    label="Target Aspect Ratio"
                )
                # Last clicked [x, y] on the input image; [None, None] means
                # "no point of interest".
                point_coords = gr.State(value=[None, None])

                def update_coords(img, evt: gr.SelectData):
                    # evt.index is used as the clicked [x, y] pixel position.
                    return [evt.index[0], evt.index[1]]

                def reset_coords():
                    # A click recorded on a previous image must not be reused
                    # as the prompt for a newly uploaded (or cleared) image.
                    return [None, None]

                input_image.select(update_coords, inputs=[input_image], outputs=[point_coords])
                input_image.change(reset_coords, inputs=None, outputs=[point_coords])
                process_btn = gr.Button("Process Image")
            with gr.Column(scale=2):
                output_image = gr.Image(type="pil", label="Cropped Result")
                visualization = gr.Image(type="filepath", label="Process Visualization")
        process_btn.click(
            fn=lambda img, ratio, coords: process_image(img, ratio, coords[0], coords[1]),
            inputs=[input_image, aspect_ratio, point_coords],
            outputs=[output_image, visualization]
        )
        gr.Markdown("""
## How It Works
1. The Segment Anything Model (SAM) analyzes your image to identify the important content
2. The app finds the optimal crop window that maximizes the preservation of that content
3. The image is cropped to your desired aspect ratio while keeping the important parts
## Tips
- For better results with specific subjects, click on the important object in the image
- Try different aspect ratios to see how the model adapts the cropping
""")
    return app
# Build the UI once at import time.
demo = create_app()
# Direct execution and import both start the server with the same call,
# so launch unconditionally.
# NOTE(review): this launches even when the module is merely imported; wrap
# in `if __name__ == "__main__":` if import-only use is ever needed.
demo.launch()