Spaces:
Sleeping
Sleeping
File size: 5,437 Bytes
33803b5 5645b6a f05270d 33803b5 8aef8ae 33803b5 8aef8ae 33803b5 8aef8ae 33803b5 5645b6a 33803b5 f05270d 8aef8ae 5645b6a 8aef8ae 5645b6a 8aef8ae 5645b6a 33803b5 5645b6a 33803b5 5645b6a 33803b5 8aef8ae 33803b5 f05270d 33803b5 8aef8ae 33803b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import torch
import gradio as gr
import plotly.graph_objects as go
import trimesh
from pathlib import Path
# All inference in this app runs on the CPU.
device = torch.device("cpu")
# map_location remaps any tensors that were saved on another device (e.g. a
# CUDA GPU) while the archive is being loaded. Moving the module with
# .to(device) afterwards is too late when the saved device does not exist on
# this machine, because loading itself would already have failed.
model = torch.jit.load('model_scripted.pt', map_location=device)
def normalize_vertices(verts):
    """Center ``verts`` on their mean and scale each axis into [-1, 1].

    Each axis is scaled independently by its own largest absolute centered
    coordinate, so along every axis the farthest vertex lands at exactly -1 or
    +1. An axis with zero extent (all vertices share that coordinate) is
    divided by 1 instead, which leaves it at 0 rather than producing NaN.

    Args:
        verts: tensor of vertex positions, one row per vertex.

    Returns:
        A new tensor with the same shape as ``verts``.
    """
    centered = verts - verts.mean(dim=0)
    extent = centered.abs().amax(dim=0)
    # Swap zero extents for 1 so a flat axis does not compute 0 / 0.
    safe_extent = torch.where(extent == 0, torch.ones_like(extent), extent)
    return centered / safe_extent
def plot_3d_results(verts, faces, uv_seam_edge_indices):
    """Build a Plotly figure of a translucent mesh with seam edges drawn on top.

    Args:
        verts: tensor of vertex positions; columns 0, 1 and 2 are used as the
            x, y and z coordinates.
        faces: integer tensor of triangles; columns 0, 1 and 2 index into
            ``verts``.
        uv_seam_edge_indices: iterable of vertex-index pairs; each pair is drawn
            as a red line segment.

    Returns:
        A ``go.Figure`` holding the mesh trace and the edge trace.
    """
    points = verts.cpu().numpy()
    triangles = faces.cpu().numpy()

    mesh_trace = go.Mesh3d(
        x=points[:, 0], y=points[:, 1], z=points[:, 2],
        i=triangles[:, 0], j=triangles[:, 1], k=triangles[:, 2],
        color='lightblue', opacity=0.50, name='Mesh',
    )

    # A None between segments breaks the polyline, so every seam edge is
    # rendered as its own independent segment within a single trace.
    seg_x, seg_y, seg_z = [], [], []
    for edge in uv_seam_edge_indices:
        (ax, ay, az), (bx, by, bz) = points[edge[0]], points[edge[1]]
        seg_x += (ax, bx, None)
        seg_y += (ay, by, None)
        seg_z += (az, bz, None)

    edges_trace = go.Scatter3d(
        x=seg_x, y=seg_y, z=seg_z, mode='lines',
        line=dict(color='red', width=2), name='Predicted Edges',
    )

    def _axis(background):
        # Shared styling for all three scene axes; only the background differs.
        return dict(nticks=4, backgroundcolor=background, gridcolor="white",
                    showbackground=True, zerolinecolor="white")

    scene = dict(
        xaxis=_axis("rgb(200, 200, 230)"),
        yaxis=_axis("rgb(230, 200,230)"),
        zaxis=_axis("rgb(230, 230,200)"),
        camera=dict(up=dict(x=0, y=1, z=0), eye=dict(x=1.25, y=1.25, z=1.25)),
    )

    fig = go.Figure(data=[mesh_trace, edges_trace])
    fig.update_layout(scene=scene, title_text='Predicted Edges')
    return fig
def _weld_vertices(vertices, faces):
    """Merge coincident vertices and re-index ``faces`` to the merged set.

    Two face corners share a vertex when their coordinates compare exactly
    equal. Unique vertices are kept in the order faces first reference them,
    so vertices that no face uses are dropped.

    Args:
        vertices: float tensor of vertex positions, one row per vertex.
        faces: iterable of faces, each an iterable of indices into ``vertices``.

    Returns:
        ``(verts, new_faces)``, where ``verts`` is a tensor of the unique vertex
        rows and ``new_faces`` is a list of per-face index lists into ``verts``.
    """
    # Convert to Python floats once. Indexing the tensor and calling .tolist()
    # for every face corner was the dominant cost of this step.
    coords = vertices.tolist()
    vertex_mapping = {}
    kept_indices = []
    new_faces = []
    for face in faces:
        new_face = []
        for orig_index in face:
            key = tuple(coords[orig_index])  # hashable, exact-equality key
            if key not in vertex_mapping:
                vertex_mapping[key] = len(kept_indices)
                kept_indices.append(int(orig_index))
            new_face.append(vertex_mapping[key])
        new_faces.append(new_face)
    # One gather instead of stacking per-row tensors; same values and order.
    verts = vertices[torch.tensor(kept_indices, dtype=torch.long)]
    return verts, new_faces


def _unique_edges(faces):
    """Return the undirected edges of triangular ``faces`` as an (E, 2) long tensor.

    Each edge appears once, stored as ``(smaller_index, larger_index)``.
    """
    edge_set = set()
    for v1, v2, v3 in faces:
        edge_set.add(tuple(sorted((v1, v2))))
        edge_set.add(tuple(sorted((v2, v3))))
        edge_set.add(tuple(sorted((v1, v3))))
    return torch.tensor(list(edge_set), dtype=torch.long)


def generate_prediction(file_input, treshold_value=0.5):
    """Predict UV-seam edges for a mesh file and return a Plotly figure.

    The mesh is normalized with ``normalize_vertices``, coincident vertices are
    welded, and every unique edge is scored by ``model``. Edges whose sigmoid
    probability exceeds the threshold are drawn by ``plot_3d_results``.

    Args:
        file_input: path of the mesh file selected in the UI; falsy when nothing
            is selected.
        treshold_value: probability cutoff; an edge counts as a seam when
            ``sigmoid(logit) > treshold_value``. The misspelled name is kept for
            backward compatibility.

    Returns:
        The figure from ``plot_3d_results``, or ``None`` when no file is given.

    Raises:
        gr.Error: if the loaded mesh has no faces.
    """
    if not file_input:
        return None

    mesh = trimesh.load_mesh(file_input)
    if len(mesh.faces) == 0:
        # Without faces there are no edges to score, and the later tensor
        # construction would fail with an opaque error.
        raise gr.Error("The selected file contains no faces to predict UV seams on.")

    vertices = normalize_vertices(torch.tensor(mesh.vertices, dtype=torch.float32))
    verts, new_faces = _weld_vertices(vertices, mesh.faces)
    edges = _unique_edges(new_faces)
    faces = torch.tensor(new_faces, dtype=torch.long)

    model.eval()
    with torch.no_grad():
        # Inputs go to the model's device; results come back to the CPU.
        logits = model(verts.to(device), edges.to(device))
        probabilities = torch.sigmoid(logits).cpu()

    # Assumes one logit per edge (shape (E,) or (E, 1)) — TODO confirm against
    # the model. reshape(-1), unlike squeeze(), keeps a 1-D mask even when E == 1.
    seam_mask = (probabilities > treshold_value).reshape(-1)
    uv_seam_edges = edges[seam_mask].tolist()
    return plot_3d_results(verts, faces, uv_seam_edges)
def run_gradio():
    """Build the Gradio UI for the UV-seam demo and launch it.

    The layout is a file picker for ``.obj`` meshes next to a column holding
    the result plot, a threshold slider and a Predict button. Clicking the
    button calls ``generate_prediction`` with the picked file and the slider
    value, and shows the returned figure in the plot.
    """
    with gr.Blocks() as demo:
        gr.Label("Proof of concept demo. Predict UV seams on 3D sphere meshes.")
        with gr.Row():
            mesh_picker = gr.FileExplorer(label="Sphere Prototype Model",
                                          file_count='single',
                                          value='randomSphere_180.obj',
                                          glob='**/*.obj')
            with gr.Column():
                plot_output = gr.Plot()
                threshold_slider = gr.Slider(minimum=0, maximum=1, value=0.6, label="Threshold")
                predict_button = gr.Button("Predict")
        # Event listeners must be registered inside the Blocks context.
        predict_button.click(generate_prediction,
                             inputs=[mesh_picker, threshold_slider],
                             outputs=plot_output)
    demo.launch()
# Start the UI only when run as a script, so importing this module (e.g. from
# tests or another tool) does not launch a web server as a side effect.
if __name__ == "__main__":
    run_gradio()
|