Spaces:
Sleeping
Sleeping
File size: 3,745 Bytes
9d4d876 116b6a0 fce60d3 9d4d876 116b6a0 20775db 116b6a0 20775db 116b6a0 20775db 116b6a0 20775db 9d4d876 20775db 9d4d876 116b6a0 20775db 116b6a0 20775db 9d4d876 116b6a0 20775db 116b6a0 20775db 9d4d876 116b6a0 9d4d876 20775db 116b6a0 20775db 116b6a0 20775db 9d4d876 116b6a0 9d4d876 20775db 9d4d876 116b6a0 20775db 9d4d876 20775db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import os
import random
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
from datasets import load_dataset
from groq import Groq
from matplotlib.patches import Rectangle
# Load the dataset
dataset = load_dataset("zimhe/pseudo-floor-plan-12k")
# Initialize the Groq client. The API key is read from the GROQ_API_KEY
# environment variable so that no secret is stored in the source file; the
# key that used to be hard-coded here should be treated as compromised and
# revoked.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# Sanity-check the structure of the loaded dataset
def inspect_dataset():
    """Return an error message if the dataset is unusable, otherwise ``None``.

    The dataset is considered usable when it has a non-empty ``"train"``
    split whose features include a ``"caption"`` field.
    """
    train_split = dataset["train"] if "train" in dataset else None
    if not train_split:
        return "Error: Dataset does not contain a valid 'train' split."

    if "caption" in train_split.features:
        return None
    return "Error: 'caption' field not found in the dataset."
# Function to select a random floor plan template based on caption
def get_floor_plan_by_caption(caption):
    """Pick a random training row whose caption contains ``caption``.

    Matching is a case-insensitive substring test. A ``None`` caption is
    treated like an empty string, which matches every row that has a string
    caption.

    Args:
        caption: Text to look for inside each row's ``"caption"`` field.

    Returns:
        A ``(template, error)`` tuple: ``(row, None)`` on success, or
        ``(None, message)`` when the dataset is unusable or nothing matches.
    """
    error = inspect_dataset()
    if error:
        return None, error

    train_split = dataset["train"]
    needle = (caption or "").lower()

    # Scan only the caption column and remember matching row indices, so that
    # the other fields of every row do not have to be materialised just to
    # compare one string. Non-string captions (e.g. None) are skipped instead
    # of raising AttributeError on ``.lower()``.
    matching_indices = [
        index
        for index, text in enumerate(train_split["caption"])
        if isinstance(text, str) and needle in text.lower()
    ]
    if not matching_indices:
        return None, "Error: No templates available for the specified caption."

    # Only the single chosen row is fetched in full.
    return train_split[random.choice(matching_indices)], None
# Function to create a plot for the floor plan
def create_floor_plan_from_template(template):
    """Draw ``template`` with matplotlib and save the drawing as a PNG file.

    Args:
        template: Mapping with a ``"plot_size"`` entry holding ``"width"``
            and ``"height"``, plus optional ``"rooms"`` and ``"features"``
            lists. Each room provides ``x``, ``y``, ``width``, ``height`` and
            ``name``; each feature provides ``x``, ``y``, ``width``,
            ``height`` and ``type``.

    Returns:
        The path of the saved PNG, or a string starting with ``"Error"`` when
        the plot size information is missing.
    """
    try:
        plot_size = template["plot_size"]
        plot_width, plot_height = plot_size["width"], plot_size["height"]
    except (KeyError, TypeError):
        # TypeError covers a template or plot_size that is None / not a mapping.
        return "Error: Template is missing plot size information."

    feature_colors = {"courtyard": "green", "parking": "red"}

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.set_xlim(0, plot_width)
        ax.set_ylim(0, plot_height)

        # Draw plot boundary
        ax.add_patch(
            Rectangle(
                (0, 0),
                plot_width,
                plot_height,
                edgecolor="black",
                fill=False,
                linewidth=2,
                label="Plot Boundary",
            )
        )

        # Draw rooms based on the template (an absent list draws nothing)
        for room in template.get("rooms") or []:
            x, y = room["x"], room["y"]
            width, height = room["width"], room["height"]
            ax.add_patch(
                Rectangle((x, y), width, height, edgecolor="blue", fill=False, linewidth=1)
            )
            ax.text(
                x + width / 2,
                y + height / 2,
                room["name"],
                ha="center",
                va="center",
                fontsize=8,
            )

        # Add additional features like courtyard, parking, and washrooms if provided
        for feature in template.get("features") or []:
            x, y = feature["x"], feature["y"]
            width, height = feature["width"], feature["height"]
            kind = feature["type"]
            # Any type other than courtyard/parking is outlined in purple.
            color = feature_colors.get(kind, "purple")
            ax.add_patch(
                Rectangle((x, y), width, height, edgecolor=color, fill=False, linewidth=1.5)
            )
            ax.text(
                x + width / 2,
                y + height / 2,
                kind.capitalize(),
                ha="center",
                va="center",
                fontsize=8,
            )

        # Finalize layout
        ax.axis("off")
        fig.tight_layout()

        # Save to a unique file so that overlapping requests cannot overwrite
        # each other's image; delete=False keeps the file for the caller.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
            img_path = handle.name
        fig.savefig(img_path)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
    return img_path
# Gradio callback: caption in, floor plan image out
def floor_plan_with_groq(caption):
    """Build a floor plan image for ``caption``.

    Args:
        caption: Text used to select a template from the dataset.

    Returns:
        The result of ``client.enhance_image`` when the client provides that
        method, otherwise the path of the rendered floor plan image.

    Raises:
        gr.Error: If no template could be selected or the template could not
            be rendered.
    """
    # Fetch a template for the given caption
    template, error = get_floor_plan_by_caption(caption)
    if error:
        # This function feeds an image output, so an error message must be
        # raised rather than returned as if it were image data.
        raise gr.Error(error)

    # Generate the floor plan using the template
    floor_plan_image = create_floor_plan_from_template(template)
    if isinstance(floor_plan_image, str) and floor_plan_image.startswith("Error"):
        raise gr.Error(floor_plan_image)

    # Enhance the floor plan using Groq, when possible.
    # NOTE(review): the Groq SDK client does not appear to expose an
    # `enhance_image` method, so calling it unconditionally would raise
    # AttributeError; fall back to the plain rendering in that case.
    enhance = getattr(client, "enhance_image", None)
    if not callable(enhance):
        return floor_plan_image
    return enhance(floor_plan_image)
# Gradio Interface: one caption textbox in, one image out
interface = gr.Interface(
    fn=floor_plan_with_groq,
    inputs=[
        gr.Textbox(label="Enter Caption for Floor Plan"),
    ],
    outputs="image",
    title="Enhanced Floor Plan Generator with Groq AI",
    description="Generate diverse and realistic floor plans using captions from a pre-trained dataset and Groq's AI capabilities.",
)

# Start the web UI only when the file is run as a script.
if __name__ == "__main__":
    interface.launch()