File size: 3,745 Bytes
9d4d876
 
 
116b6a0
 
fce60d3
9d4d876
116b6a0
 
 
 
20775db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116b6a0
 
20775db
 
116b6a0
20775db
 
 
116b6a0
 
20775db
9d4d876
20775db
 
 
9d4d876
 
 
 
 
 
 
 
116b6a0
20775db
 
116b6a0
20775db
9d4d876
116b6a0
20775db
 
 
116b6a0
20775db
9d4d876
 
 
 
 
 
116b6a0
9d4d876
 
 
 
 
20775db
 
 
 
 
116b6a0
 
20775db
 
 
116b6a0
 
20775db
9d4d876
 
 
 
116b6a0
9d4d876
20775db
9d4d876
 
116b6a0
20775db
9d4d876
 
 
20775db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import random
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
from datasets import load_dataset
from groq import Groq
from matplotlib.patches import Rectangle

# Load the floor-plan dataset from the Hugging Face Hub.
dataset = load_dataset("zimhe/pseudo-floor-plan-12k")

# Initialize the Groq API client.
# The key is read from the GROQ_API_KEY environment variable rather than being
# hard-coded: a key committed to source control is effectively public and
# should be revoked and rotated.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

def inspect_dataset():
    """Check that the module-level ``dataset`` can serve caption lookups.

    Returns:
        An error message string describing the first problem found, or
        ``None`` when ``dataset`` has a present, truthy ``"train"`` split
        whose features include ``caption``.
    """
    # Short-circuit keeps dataset["train"] from being read when it is absent.
    has_train_split = "train" in dataset and bool(dataset["train"])
    if not has_train_split:
        return "Error: Dataset does not contain a valid 'train' split."

    return (
        None
        if "caption" in dataset["train"].features
        else "Error: 'caption' field not found in the dataset."
    )

def get_floor_plan_by_caption(caption):
    """Pick a random training example whose caption contains *caption*.

    Matching is a case-insensitive substring test against each example's
    ``caption`` field; examples with an empty or missing caption never match.

    Args:
        caption: Text typed by the user. ``None`` is treated as empty, and
            surrounding whitespace is ignored.

    Returns:
        A ``(template, error)`` tuple: ``(row, None)`` on success, where
        ``row`` is the selected example of ``dataset["train"]``, or
        ``(None, message)`` when the dataset is unusable, the caption is
        blank, or nothing matches.
    """
    error = inspect_dataset()
    if error:
        return None, error

    needle = (caption or "").strip().lower()
    if not needle:
        # An empty needle is a substring of every caption; reject it instead
        # of silently returning an arbitrary template.
        return None, "Error: Please enter a caption."

    train = dataset["train"]
    # Scan only the caption column and remember matching indices. Iterating
    # the split row by row would build every column of every example just to
    # read one field.
    matches = [
        index
        for index, text in enumerate(train["caption"])
        if text and needle in text.lower()
    ]
    if not matches:
        return None, "Error: No templates available for the specified caption."

    # Fetch the full row only for the one example actually chosen.
    return train[random.choice(matches)], None

def create_floor_plan_from_template(template, output_path=None):
    """Render a floor-plan template to a PNG image.

    Args:
        template: Mapping with a ``plot_size`` mapping (``width``/``height``),
            a ``rooms`` list, and an optional ``features`` list. Each room or
            feature needs ``x``, ``y``, ``width`` and ``height``; rooms also
            need ``name`` and features need a string ``type``.
        output_path: Where to write the PNG. When omitted, a new uniquely
            named temporary file is created, so concurrent requests never
            overwrite each other's image. Temporary files are not deleted here.

    Returns:
        The path of the written PNG, or the string
        ``"Error: Template is missing plot size information."`` when the
        template has no usable plot size.
    """
    try:
        plot_size = template["plot_size"]
        plot_width, plot_height = plot_size["width"], plot_size["height"]
    except (KeyError, TypeError):
        # A missing key, or None / a non-mapping where a mapping is expected,
        # both mean the plot size cannot be determined.
        return "Error: Template is missing plot size information."

    # Outline colour per feature type; any other type is drawn in purple.
    feature_colors = {"courtyard": "green", "parking": "red"}

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.set_xlim(0, plot_width)
        ax.set_ylim(0, plot_height)

        # Draw plot boundary
        ax.add_patch(Rectangle((0, 0), plot_width, plot_height, edgecolor="black", fill=None, linewidth=2, label="Plot Boundary"))

        # Draw rooms based on the template
        for room in template["rooms"]:
            x, y, width, height = room["x"], room["y"], room["width"], room["height"]
            ax.add_patch(Rectangle((x, y), width, height, edgecolor="blue", fill=None, linewidth=1))
            ax.text(x + width / 2, y + height / 2, room["name"], ha="center", va="center", fontsize=8)

        # Additional features (courtyard, parking, washrooms, ...) are optional.
        for feature in template.get("features") or []:
            x, y, width, height = feature["x"], feature["y"], feature["width"], feature["height"]
            color = feature_colors.get(feature["type"], "purple")
            ax.add_patch(Rectangle((x, y), width, height, edgecolor=color, fill=None, linewidth=1.5))
            ax.text(x + width / 2, y + height / 2, feature["type"].capitalize(), ha="center", va="center", fontsize=8)

        # Finalize layout
        ax.axis("off")
        fig.tight_layout()

        if output_path is None:
            # delete=False keeps the file after the handle closes so the
            # caller (Gradio) can read it; the handle is closed before saving
            # so this also works on Windows.
            with tempfile.NamedTemporaryFile(prefix="floor_plan_", suffix=".png", delete=False) as tmp:
                output_path = tmp.name
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if drawing fails, so repeated
        # requests do not accumulate open matplotlib figures.
        plt.close(fig)
    return output_path

# Gradio callback: caption text in, rendered floor-plan image path out.
def floor_plan_with_groq(caption):
    """Render a floor plan for a random template matching *caption*.

    Args:
        caption: Text from the caption textbox.

    Returns:
        Path to the rendered PNG, suitable for the ``"image"`` output.

    Raises:
        gr.Error: When no template can be selected or rendered. Gradio shows
            the message to the user instead of trying to load the error text
            as an image file.
    """
    # Fetch a template for the given caption
    template, error = get_floor_plan_by_caption(caption)
    if error:
        raise gr.Error(error)

    # Generate the floor plan using the template
    floor_plan_image = create_floor_plan_from_template(template)
    if isinstance(floor_plan_image, str) and floor_plan_image.startswith("Error"):
        raise gr.Error(floor_plan_image)

    # NOTE(review): the Groq Python SDK has no image-enhancement API (there is
    # no ``enhance_image`` method on its client), so calling one raised
    # AttributeError on every successful render. The rendered plot is
    # returned directly until a real Groq-based step is designed.
    return floor_plan_image

# Gradio UI: a single caption textbox feeding an image output.
interface = gr.Interface(
    fn=floor_plan_with_groq,
    inputs=gr.Textbox(label="Enter Caption for Floor Plan"),
    outputs="image",
    title="Enhanced Floor Plan Generator with Groq AI",
    description=(
        "Generate diverse and realistic floor plans using captions from a "
        "pre-trained dataset and Groq's AI capabilities."
    ),
)

if __name__ == "__main__":
    # Start the web app only when executed as a script, not when imported.
    interface.launch()