# Page header captured with this file: "Spaces: Sleeping" (not part of the program).
import os
import random

import gradio as gr
import matplotlib.pyplot as plt
from datasets import load_dataset
from groq import Groq
from matplotlib.patches import Rectangle
# Load the floor plan dataset once at import time; every request reuses it.
dataset = load_dataset("zimhe/pseudo-floor-plan-12k")

# Initialize the Groq client. The API key is read from the GROQ_API_KEY
# environment variable so that no secret lives in source control.
# NOTE(review): a key was previously hard-coded on this line; treat it as
# compromised and revoke it.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
def inspect_dataset():
    """Validate that the loaded dataset has the structure this app needs.

    Returns:
        An error message string when the 'train' split is missing or empty,
        or when it has no 'caption' feature; otherwise ``None``.
    """
    train_split = dataset["train"] if "train" in dataset else None
    if not train_split:
        return "Error: Dataset does not contain a valid 'train' split."

    if "caption" in train_split.features:
        return None
    return "Error: 'caption' field not found in the dataset."
def get_floor_plan_by_caption(caption):
    """Pick a random dataset row whose caption contains ``caption``.

    The match is a case-insensitive substring test. An empty (or ``None``)
    caption matches every row that has a string caption.

    Args:
        caption: Text to look for inside each row's 'caption' field.

    Returns:
        A ``(template, error)`` tuple. Exactly one element is ``None``:
        ``(row, None)`` on success, ``(None, message)`` on failure.
    """
    error = inspect_dataset()
    if error:
        return None, error

    train_split = dataset["train"]
    needle = (caption or "").lower()

    # Scan only the caption column and remember matching row indices, rather
    # than materialising every full row just to read one field.
    matching_rows = [
        row_index
        for row_index, row_caption in enumerate(train_split["caption"])
        # Rows with a missing/non-text caption cannot match; skip them
        # instead of crashing on ``.lower()``.
        if isinstance(row_caption, str) and needle in row_caption.lower()
    ]
    if not matching_rows:
        return None, "Error: No templates available for the specified caption."

    # Only the chosen row is fetched in full.
    return train_split[random.choice(matching_rows)], None
def create_floor_plan_from_template(template):
    """Render a floor plan template to a PNG file and return its path.

    Args:
        template: Mapping with a 'plot_size' entry ('width', 'height'), a
            'rooms' list (each with 'x', 'y', 'width', 'height', 'name') and
            an optional 'features' list (each with 'x', 'y', 'width',
            'height', 'type').

    Returns:
        The path of the saved image ("floor_plan_template.png"), or a string
        starting with "Error" when the template lacks plot size or room
        information.
    """
    # Validate everything we can before a figure exists, so that a bad
    # template never leaves an open matplotlib figure behind.
    try:
        plot_size = template["plot_size"]
        plot_width, plot_height = plot_size["width"], plot_size["height"]
    except (KeyError, TypeError):
        # TypeError covers a 'plot_size' entry that is present but None.
        return "Error: Template is missing plot size information."

    rooms = template.get("rooms")
    if rooms is None:
        return "Error: Template is missing room information."

    # Features (courtyard, parking, washrooms, ...) are optional.
    features = template.get("features") or []
    feature_colors = {"courtyard": "green", "parking": "red"}

    img_path = "floor_plan_template.png"
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.set_xlim(0, plot_width)
        ax.set_ylim(0, plot_height)

        # Draw plot boundary
        ax.add_patch(Rectangle((0, 0), plot_width, plot_height, edgecolor="black",
                               fill=False, linewidth=2, label="Plot Boundary"))

        # Draw rooms based on the template
        for room in rooms:
            x, y, width, height = room["x"], room["y"], room["width"], room["height"]
            ax.add_patch(Rectangle((x, y), width, height, edgecolor="blue",
                                   fill=False, linewidth=1))
            ax.text(x + width / 2, y + height / 2, room["name"],
                    ha="center", va="center", fontsize=8)

        # Draw additional features; any type other than courtyard/parking
        # is outlined in purple.
        for feature in features:
            x, y, width, height = (feature["x"], feature["y"],
                                   feature["width"], feature["height"])
            color = feature_colors.get(feature["type"], "purple")
            ax.add_patch(Rectangle((x, y), width, height, edgecolor=color,
                                   fill=False, linewidth=1.5))
            ax.text(x + width / 2, y + height / 2, feature["type"].capitalize(),
                    ha="center", va="center", fontsize=8)

        # Finalize layout and write the image
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(img_path)
    finally:
        # Always release the figure, even if a malformed room/feature raised.
        plt.close(fig)
    return img_path
def floor_plan_with_groq(caption):
    """Gradio handler: build a floor plan image for ``caption``.

    Looks up a template matching the caption, renders it to an image and,
    when the client offers an ``enhance_image`` method, passes the rendered
    image through it.

    Args:
        caption: Text entered by the user.

    Returns:
        The image to display (the rendered plan's file path, or whatever
        ``client.enhance_image`` returns for it).

    Raises:
        gr.Error: When no template matches or the template cannot be
            rendered. The output component is an image, so an error string
            cannot be returned as a value; raising lets Gradio show the
            message to the user instead.
    """
    # Fetch a template for the given caption
    template, error = get_floor_plan_by_caption(caption)
    if error:
        raise gr.Error(error)

    # Generate the floor plan using the template
    floor_plan_image = create_floor_plan_from_template(template)
    if isinstance(floor_plan_image, str) and floor_plan_image.startswith("Error"):
        raise gr.Error(floor_plan_image)

    # NOTE(review): `enhance_image` does not appear to be part of the Groq
    # SDK client - confirm. Until it exists, fall back to the rendered plan
    # rather than failing every request with AttributeError.
    enhance_image = getattr(client, "enhance_image", None)
    if not callable(enhance_image):
        return floor_plan_image
    return enhance_image(floor_plan_image)
# Gradio UI: one caption textbox in, one image out.
interface = gr.Interface(
    fn=floor_plan_with_groq,
    inputs=[gr.Textbox(label="Enter Caption for Floor Plan")],
    outputs="image",
    title="Enhanced Floor Plan Generator with Groq AI",
    description=(
        "Generate diverse and realistic floor plans using captions from a "
        "pre-trained dataset and Groq's AI capabilities."
    ),
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()