"""Gradio app that draws floor plans from templates in ``zimhe/pseudo-floor-plan-12k``.

The caption entered by the user is matched (case-insensitive substring) against
the dataset's ``caption`` column. A random matching row is drawn with matplotlib
and the resulting PNG is shown in the UI.
"""

import os
import random

import gradio as gr
import matplotlib.pyplot as plt
from datasets import load_dataset
from groq import Groq
from matplotlib.patches import Rectangle

# Load the dataset once at import time; every request reuses it.
dataset = load_dataset("zimhe/pseudo-floor-plan-12k")

# Initialize the Groq client from the environment rather than a hard-coded
# secret. NOTE(review): the key previously committed here is exposed and must
# be revoked/rotated. `client` is None when GROQ_API_KEY is not set.
_GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
client = Groq(api_key=_GROQ_API_KEY) if _GROQ_API_KEY else None

# Outline colour per feature type; any other type (e.g. washrooms) is purple.
_FEATURE_COLORS = {"courtyard": "green", "parking": "red"}


def inspect_dataset():
    """Validate that the loaded dataset has a non-empty 'train' split with captions.

    Returns:
        None when the dataset is usable, otherwise an error message string.
    """
    if "train" not in dataset or not dataset["train"]:
        return "Error: Dataset does not contain a valid 'train' split."
    if "caption" not in dataset["train"].features:
        return "Error: 'caption' field not found in the dataset."
    return None


def get_floor_plan_by_caption(caption):
    """Pick a random training row whose caption contains *caption*.

    Matching is a case-insensitive substring test; an empty caption matches
    every row.

    Returns:
        ``(row, None)`` on success, or ``(None, error_message)`` on failure.
    """
    error = inspect_dataset()
    if error:
        return None, error

    train = dataset["train"]
    needle = (caption or "").lower()
    # Scan only the caption column so non-matching rows (and any heavy columns
    # they carry) are never materialized; only the chosen row is fetched.
    matches = [
        idx
        for idx, text in enumerate(train["caption"])
        if needle in (text or "").lower()
    ]
    if not matches:
        return None, "Error: No templates available for the specified caption."
    return train[random.choice(matches)], None


def _draw_labeled_rect(ax, rect, label, color, linewidth):
    """Outline *rect* (mapping with x/y/width/height) on *ax* and centre *label* in it."""
    x, y, width, height = rect["x"], rect["y"], rect["width"], rect["height"]
    ax.add_patch(
        Rectangle((x, y), width, height, edgecolor=color, fill=None, linewidth=linewidth)
    )
    ax.text(x + width / 2, y + height / 2, label, ha="center", va="center", fontsize=8)


def create_floor_plan_from_template(template, img_path="floor_plan_template.png"):
    """Render *template* as a floor-plan PNG.

    Args:
        template: A dataset row. Keys used here (assumed from this code, not
            verified against the dataset schema — TODO confirm): ``plot_size``
            with ``width``/``height``, plus optional ``rooms`` (each with
            ``x``, ``y``, ``width``, ``height``, ``name``) and optional
            ``features`` (same geometry plus ``type``).
        img_path: Output file. Defaults to the original fixed filename, which
            concurrent requests would share and overwrite.

    Returns:
        ``img_path`` on success, or a string starting with ``"Error"`` when the
        template has no usable plot size.
    """
    try:
        plot_size = template["plot_size"]
        plot_width, plot_height = plot_size["width"], plot_size["height"]
    except (KeyError, TypeError):
        return "Error: Template is missing plot size information."

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.set_xlim(0, plot_width)
        ax.set_ylim(0, plot_height)

        # Plot boundary.
        ax.add_patch(
            Rectangle(
                (0, 0),
                plot_width,
                plot_height,
                edgecolor="black",
                fill=None,
                linewidth=2,
                label="Plot Boundary",
            )
        )

        # Rooms: blue outlines labelled with the room name.
        for room in template.get("rooms") or []:
            _draw_labeled_rect(ax, room, room["name"], "blue", 1)

        # Optional features such as courtyard, parking and washrooms.
        for feature in template.get("features") or []:
            feature_type = feature["type"]
            color = _FEATURE_COLORS.get(feature_type, "purple")
            _draw_labeled_rect(ax, feature, feature_type.capitalize(), color, 1.5)

        ax.axis("off")
        fig.tight_layout()
        fig.savefig(img_path)
    finally:
        # Release the figure even if drawing fails, so pyplot does not
        # accumulate open figures across requests.
        plt.close(fig)
    return img_path


def floor_plan_with_groq(caption):
    """Gradio handler: caption -> path of a rendered floor-plan image.

    Raises:
        gr.Error: when no template matches or the template cannot be drawn.
            The message is shown in the UI (an "image" output cannot display
            a returned text string).
    """
    template, error = get_floor_plan_by_caption(caption)
    if error:
        raise gr.Error(error)

    floor_plan_image = create_floor_plan_from_template(template)
    if isinstance(floor_plan_image, str) and floor_plan_image.startswith("Error"):
        raise gr.Error(floor_plan_image)

    # The groq Python SDK client does not provide an `enhance_image` method,
    # so the original unconditional call always raised AttributeError. Use it
    # only if the configured client actually offers it; otherwise return the
    # locally rendered plan.
    enhance = getattr(client, "enhance_image", None)
    if callable(enhance):
        return enhance(floor_plan_image)
    return floor_plan_image


# Gradio Interface
interface = gr.Interface(
    fn=floor_plan_with_groq,
    inputs=[
        gr.Textbox(label="Enter Caption for Floor Plan"),
    ],
    outputs="image",
    title="Enhanced Floor Plan Generator with Groq AI",
    description="Generate diverse and realistic floor plans using captions from a pre-trained dataset and Groq's AI capabilities.",
)

if __name__ == "__main__":
    interface.launch()