File size: 5,437 Bytes
33803b5
 
 
5645b6a
f05270d
33803b5
 
 
 
8aef8ae
33803b5
8aef8ae
 
33803b5
8aef8ae
 
 
 
 
 
 
33803b5
 
 
 
 
5645b6a
33803b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f05270d
 
 
8aef8ae
5645b6a
 
8aef8ae
5645b6a
 
 
 
 
8aef8ae
 
5645b6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33803b5
 
 
 
5645b6a
33803b5
 
5645b6a
 
 
33803b5
 
 
 
 
 
 
8aef8ae
33803b5
 
f05270d
 
 
 
 
33803b5
 
8aef8ae
33803b5
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import gradio as gr
import plotly.graph_objects as go
import trimesh
from pathlib import Path

# Inference runs on the CPU.
device = torch.device("cpu")
# map_location remaps tensors that were saved on another device (e.g. CUDA)
# onto `device` while deserializing, so the TorchScript archive also loads on
# hosts without that device; the trailing .to(device) is then a cheap no-op.
model = torch.jit.load('model_scripted.pt', map_location=device).to(device)


def normalize_vertices(verts):
    """Center vertex positions and scale every axis independently into [-1, 1].

    The mean position is subtracted first. Each coordinate axis is then divided
    by its own largest absolute value, so the extreme point along that axis
    ends up at exactly -1 or 1. An axis with zero extent (every point shares
    the same coordinate there) is divided by 1 instead, so it stays at 0
    rather than becoming NaN.

    Args:
        verts: (N, D) float tensor of vertex positions.

    Returns:
        A new tensor of the same shape holding the normalized positions.
    """
    centered = verts - verts.mean(dim=0)
    extent = centered.abs().max(dim=0).values
    # Swap zero extents for 1 so flat axes never trigger 0 / 0.
    extent = extent.masked_fill(extent == 0, 1.0)
    return centered / extent

def plot_3d_results(verts, faces, uv_seam_edge_indices):
    """Render the mesh plus its predicted seam edges as a Plotly 3D figure.

    Args:
        verts: Tensor of vertex positions; columns 0-2 are used as x, y, z.
        faces: Tensor of triangle vertex indices; columns 0-2 are used as the
            triangle corners i, j, k.
        uv_seam_edge_indices: Iterable of ``[a, b]`` vertex-index pairs that
            are drawn as red line segments.

    Returns:
        A ``go.Figure`` containing a translucent light-blue ``Mesh3d`` trace
        and a red ``Scatter3d`` line trace for the seam edges.
    """
    points = verts.cpu().numpy()
    triangles = faces.cpu().numpy()

    surface = go.Mesh3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        i=triangles[:, 0], j=triangles[:, 1], k=triangles[:, 2],
        color='lightblue', opacity=0.50, name='Mesh',
    )

    # Flatten all seam edges into one polyline per axis. The None after each
    # segment tells Plotly to break the line, so segments stay disconnected.
    line_x, line_y, line_z = [], [], []
    for edge in uv_seam_edge_indices:
        (ax, ay, az), (bx, by, bz) = points[edge[0]], points[edge[1]]
        line_x += [ax, bx, None]
        line_y += [ay, by, None]
        line_z += [az, bz, None]

    seams = go.Scatter3d(
        x=line_x, y=line_y, z=line_z, mode='lines',
        line=dict(color='red', width=2), name='Predicted Edges',
    )

    def axis_style(background):
        # Shared styling for the three scene axes; only the backdrop differs.
        return dict(nticks=4, backgroundcolor=background, gridcolor="white",
                    showbackground=True, zerolinecolor="white")

    figure = go.Figure(data=[surface, seams])
    figure.update_layout(
        scene=dict(
            xaxis=axis_style("rgb(200, 200, 230)"),
            yaxis=axis_style("rgb(230, 200,230)"),
            zaxis=axis_style("rgb(230, 230,200)"),
            camera=dict(up=dict(x=0, y=1, z=0), eye=dict(x=1.25, y=1.25, z=1.25)),
        ),
        title_text='Predicted Edges',
    )
    return figure


def _merge_duplicate_vertices(vertices, faces):
    """Merge vertices with identical coordinates, numbering them by first use.

    Only vertices referenced by ``faces`` are kept. They are renumbered in the
    order they are first met while walking the faces corner by corner.

    Args:
        vertices: (V, 3) float tensor of vertex positions.
        faces: List of ``[a, b, c]`` index lists into ``vertices``.

    Returns:
        ``(verts, new_faces)``: a (U, 3) tensor of the unique vertices and the
        faces rewritten as ``[a, b, c]`` index lists into it.
    """
    # Convert every vertex to a hashable tuple once. The previous version
    # indexed the tensor and called .tolist() separately for every face corner.
    coords = [tuple(v) for v in vertices.tolist()]
    coord_to_new = {}
    first_original = []  # original index of each unique vertex, in new order
    new_faces = []
    for face in faces:
        new_face = []
        for orig_index in face:
            key = coords[orig_index]
            new_index = coord_to_new.get(key)
            if new_index is None:
                new_index = len(first_original)
                coord_to_new[key] = new_index
                first_original.append(orig_index)
            new_face.append(new_index)
        new_faces.append(new_face)
    verts = vertices[torch.tensor(first_original, dtype=torch.long)]
    return verts, new_faces


def _unique_edges(faces):
    """Return the unique undirected edges of triangular ``faces``.

    Args:
        faces: Sequence of ``(a, b, c)`` vertex-index triples.

    Returns:
        (E, 2) long tensor. Each row holds an edge's endpoints in ascending
        order.
    """
    edge_set = {
        tuple(sorted(pair))
        for v1, v2, v3 in faces
        for pair in ((v1, v2), (v2, v3), (v1, v3))
    }
    # reshape keeps a well-formed (0, 2) shape even when there are no edges.
    return torch.tensor(list(edge_set), dtype=torch.long).reshape(-1, 2)


def generate_prediction(file_input, treshold_value=0.5):
    """Predict UV-seam edges for a mesh file and return a Plotly figure.

    Processing steps:
    1. Load the mesh with trimesh.
    2. Normalize its vertices with ``normalize_vertices``.
    3. Merge vertices that share exact coordinates.
    4. Score the unique undirected edges of the resulting faces with the global
       TorchScript ``model``.

    Edges whose sigmoid score is above ``treshold_value`` are plotted as seams.

    Args:
        file_input: Path of the mesh file. A falsy value (no selection)
            short-circuits.
        treshold_value: Probability threshold above which an edge counts as a
            seam. The misspelled name is kept for caller compatibility.

    Returns:
        The figure built by ``plot_3d_results``, or ``None`` when no file was
        given.

    Raises:
        gr.Error: If the loaded mesh has no faces.
    """
    if not file_input:
        return None

    mesh = trimesh.load_mesh(file_input)
    if len(mesh.faces) == 0:
        raise gr.Error("The selected mesh has no faces.")

    vertices = normalize_vertices(torch.tensor(mesh.vertices, dtype=torch.float32))
    verts, faces_list = _merge_duplicate_vertices(vertices, mesh.faces.tolist())
    edges = _unique_edges(faces_list)
    faces = torch.tensor(faces_list, dtype=torch.long)

    model.eval()
    with torch.no_grad():
        probabilities = torch.sigmoid(model(verts, edges)).cpu()

    # reshape(-1) instead of squeeze(): squeeze() turns a single-edge output
    # into a 0-d tensor, which would index `edges` incorrectly.
    seam_mask = probabilities.reshape(-1) > treshold_value
    uv_seam_edges = edges[seam_mask].tolist()

    # Return the Plotly figure with the mesh and the predicted seam edges.
    return plot_3d_results(verts, faces, uv_seam_edges)


def run_gradio():
    """Build and launch the Gradio demo UI.

    The page contains:
    - a short description;
    - a file explorer for picking an ``.obj`` mesh (preselected to
      ``randomSphere_180.obj``);
    - a plot output and a threshold slider.

    Clicking "Predict" calls ``generate_prediction`` with the selected file and
    threshold and displays the returned figure. This call blocks in
    ``demo.launch()``.
    """
    with gr.Blocks() as demo:
        # gr.Markdown renders static text. gr.Label is a classification-output
        # component meant for class names with confidence scores.
        gr.Markdown("Proof of concept demo. Predict UV seams on 3D sphere meshes.")

        with gr.Row():
            mesh_picker = gr.FileExplorer(label="Sphere Prototype Model",
                                          file_count='single',
                                          value='randomSphere_180.obj',
                                          glob='**/*.obj')
            with gr.Column():
                prediction_plot = gr.Plot()
                threshold_slider = gr.Slider(minimum=0, maximum=1, value=0.6, label="Threshold")

        predict_button = gr.Button("Predict")
        predict_button.click(generate_prediction,
                             inputs=[mesh_picker, threshold_slider],
                             outputs=prediction_plot)

    demo.launch()


# Start the demo only when executed as a script (e.g. `python app.py`), so that
# importing this module for tests or reuse does not launch a server.
if __name__ == "__main__":
    run_gradio()