masterblaster22 commited on
Commit
5645b6a
·
verified ·
1 Parent(s): ac37724

working proto

Browse files
Files changed (4) hide show
  1. app.py +46 -13
  2. model_scripted.pt +1 -1
  3. requirements.txt +2 -3
  4. sphere.obj +0 -0
app.py CHANGED
@@ -1,8 +1,7 @@
1
- from pytorch3d.io import load_obj
2
- from pytorch3d.structures import Meshes
3
  import torch
4
  import gradio as gr
5
  import plotly.graph_objects as go
 
6
 
7
  device = torch.device("cpu")
8
  model = torch.jit.load('model_scripted.pt').to(device)
@@ -14,11 +13,10 @@ def normalize_vertices(verts):
14
  scale = max(verts.abs().max(0)[0])
15
  return verts / scale
16
 
17
-
18
  def plot_3d_results(verts, faces, uv_seam_edge_indices):
19
  # Convert vertices to NumPy for easier manipulation
20
  verts_np = verts.cpu().numpy()
21
- faces_np = faces.verts_idx.cpu().numpy()
22
 
23
  # Prepare the vertex coordinates for the Mesh3d plot
24
  x, y, z = verts_np[:, 0], verts_np[:, 1], verts_np[:, 2]
@@ -56,21 +54,56 @@ def plot_3d_results(verts, faces, uv_seam_edge_indices):
56
 
57
 
58
  def generate_prediction(file_input, treshold_value=0.5):
59
- verts, faces, aux = load_obj(file_input)
60
- verts = normalize_vertices(verts)
61
- mesh = Meshes(verts=[verts.to(device)], faces=[faces.verts_idx.to(device)])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  model.eval()
64
 
65
- test_verts = mesh.verts_packed().to(device)
66
- test_edges = mesh.edges_packed().to(device)
67
-
68
  with torch.no_grad():
69
- test_outputs_logits = model(test_verts, test_edges).to(device)
70
  test_outputs = torch.sigmoid(test_outputs_logits).to(device)
71
  test_predictions = (test_outputs > treshold_value).int().cpu()
72
- uv_seam_edges_mask = test_predictions.cpu().squeeze() == 1
73
- uv_seam_edges = test_edges[uv_seam_edges_mask].cpu().tolist()
 
74
 
75
  # Return the HTML content generated by plot_3d_results
76
  return plot_3d_results(verts, faces, uv_seam_edges)
 
 
 
1
  import torch
2
  import gradio as gr
3
  import plotly.graph_objects as go
4
+ import trimesh
5
 
6
  device = torch.device("cpu")
7
  model = torch.jit.load('model_scripted.pt').to(device)
 
13
  scale = max(verts.abs().max(0)[0])
14
  return verts / scale
15
 
 
16
  def plot_3d_results(verts, faces, uv_seam_edge_indices):
17
  # Convert vertices to NumPy for easier manipulation
18
  verts_np = verts.cpu().numpy()
19
+ faces_np = faces.cpu().numpy()
20
 
21
  # Prepare the vertex coordinates for the Mesh3d plot
22
  x, y, z = verts_np[:, 0], verts_np[:, 1], verts_np[:, 2]
 
54
 
55
 
56
  def generate_prediction(file_input, treshold_value=0.5):
57
+ # Load the triangle mesh
58
+ mesh = trimesh.load_mesh(file_input)
59
+
60
+ # For production, we should use a faster method to preprocess the mesh!
61
+
62
+ # Convert vertices to a PyTorch tensor
63
+ vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
64
+
65
+ # Initialize containers for unique vertices and mapping
66
+ unique_vertices = []
67
+ vertex_mapping = {}
68
+ new_faces = []
69
+
70
+ # Populate unique vertices and create new faces with updated indices
71
+ for face in mesh.faces:
72
+ new_face = []
73
+ for orig_index in face:
74
+ vertex = tuple(vertices[orig_index].tolist()) # Convert to tuple (hashable)
75
+ if vertex not in vertex_mapping:
76
+ vertex_mapping[vertex] = len(unique_vertices)
77
+ unique_vertices.append(vertices[orig_index])
78
+ new_face.append(vertex_mapping[vertex])
79
+ new_faces.append(new_face)
80
+
81
+ # Create edge set to ensure uniqueness
82
+ edge_set = set()
83
+ for face in new_faces:
84
+ # Unpack the vertex indices
85
+ v1, v2, v3 = face
86
+ # Create undirected edges (use tuple sorting to ensure uniqueness)
87
+ edge_set.add(tuple(sorted((v1, v2))))
88
+ edge_set.add(tuple(sorted((v2, v3))))
89
+ edge_set.add(tuple(sorted((v1, v3))))
90
+
91
+ # Convert edges back to tensor
92
+ edges = torch.tensor(list(edge_set), dtype=torch.long)
93
+
94
+ # Convert unique vertices and new faces back to tensors
95
+ verts = torch.stack(unique_vertices)
96
+ faces = torch.tensor(new_faces, dtype=torch.long)
97
 
98
  model.eval()
99
 
 
 
 
100
  with torch.no_grad():
101
+ test_outputs_logits = model(verts, edges).to(device)
102
  test_outputs = torch.sigmoid(test_outputs_logits).to(device)
103
  test_predictions = (test_outputs > treshold_value).int().cpu()
104
+
105
+ uv_seam_edges_mask = test_predictions.cpu().squeeze() == 1
106
+ uv_seam_edges = edges[uv_seam_edges_mask].cpu().tolist()
107
 
108
  # Return the HTML content generated by plot_3d_results
109
  return plot_3d_results(verts, faces, uv_seam_edges)
model_scripted.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5475053a74a596b223abc768bf8d926247bf94f826c1ed41b469f593a11828c8
3
  size 255324
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f0d5dcd806540ba8061e3d70ed17eac539e560ae0932f04679f706898126588
3
  size 255324
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- git+https://github.com/facebookresearch/pytorch3d.git
2
  torch
3
- pytorch3d
4
- plotly
 
 
1
  torch
2
+ plotly
3
+ trimesh
sphere.obj CHANGED
The diff for this file is too large to render. See raw diff