CraftsMan3D / server.py
wyysf's picture
fix bug
b1311b0
raw
history blame contribute delete
3.99 kB
import argparse
import base64
import os
from datetime import datetime
import traceback
import trimesh
import torch
from craftsman import CraftsManPipeline
# Scratch root for this server process. The process id is part of the path,
# so every server process works under its own directory; request handlers
# create their per-task sub-directories beneath it.
CURRENT_DIR = '/tmp/native3d_server/' + str(os.getpid())
# Create the directory eagerly at import time; an already-existing directory
# is not an error.
os.makedirs(name=CURRENT_DIR, exist_ok=True)
def parse_parameters(argv=None) -> argparse.Namespace:
    """Parse the server's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default, and the behaviour of the original zero-argument
            call) ``argparse`` reads ``sys.argv[1:]``. Passing an explicit
            list lets the function be used programmatically and in tests.

    Returns:
        A namespace with ``host`` (str, default ``"0.0.0.0"``) and
        ``port`` (int, default ``12345``).
    """
    parser = argparse.ArgumentParser("native3d")
    parser.add_argument(
        '--host', default="0.0.0.0", type=str,
        help="address the HTTP server binds to",
    )
    parser.add_argument(
        '--port', default=12345, type=int,
        help="TCP port the HTTP server listens on",
    )
    return parser.parse_args(argv)
# -------------------- fastapi --------------------
from typing import Optional
from pydantic import BaseModel, Field
class Native3DRequestV1(BaseModel):
    """Request body for ``POST /native3d_v1`` (path-based image-to-mesh).

    Both fields are filesystem paths that the handler uses exactly as given,
    so they refer to the filesystem of the machine running this server.
    """
    # Path of the input image; handed unchanged to the generation pipeline.
    image_path: str # input image path
    # Path the generated mesh is exported to; missing parent directories are
    # created before the export. The handler does not validate the extension.
    # NOTE(review): "in clean dir" presumably means the target directory is
    # expected to be empty - confirm with the author.
    mesh_path: str # output mesh path, support glb or obj in clean dir
class Native3DResponseV1(BaseModel):
    """Empty response body for ``POST /native3d_v1``.

    It declares no fields: the v1 handler writes the generated mesh to the
    ``mesh_path`` supplied in the request instead of returning it.
    """
    pass
class Native3DRequestV2(BaseModel):
    """Request body for ``POST /native3d_v2`` (bytes-based image-to-mesh).

    The image travels inside the request, so client and server do not need
    to share a filesystem.
    """
    # Base64-encoded contents of the input image file; the handler decodes
    # it and stores it as ``input_image.png`` before running the pipeline.
    image_bytes: str # input image bytes(base64)
    # Desired output format; the handler accepts only 'obj' or 'glb' and
    # uses the value as the exported file's extension.
    mesh_type: str # output mesh type, support glb or obj
class Native3DResponseV2(BaseModel):
    """Response body for ``POST /native3d_v2``."""
    # Base64 text (ASCII-safe) of the exported mesh file, in the format
    # requested through ``mesh_type``.
    mesh_bytes: str # output mesh bytes(base64)
if __name__ == "__main__":
    # Script entry point: load the CraftsMan pipeline once, register the
    # FastAPI routes that close over it, then serve them with uvicorn.
    parse_args = parse_parameters()

    # prepare models
    # NOTE(review): checkpoint path and device are hard-coded for one
    # machine; consider exposing them as command-line options.
    pipeline = CraftsManPipeline.from_pretrained(
        "/home/super/Desktop/8TDisk/weiyu/CraftsMan_gradio/ckpts/craftsman-v1-5",
        device="cuda:0",
        torch_dtype=torch.float32,
    )

    # -------------------- fastapi --------------------
    import time

    from fastapi import FastAPI, HTTPException, Request
    import requests  # kept from the original; not referenced in this block

    app = FastAPI()

    @app.post("/native3d_v1", response_model=Native3DResponseV1)
    async def native3d(request: Request, image_to_mesh_request: Native3DRequestV1):
        """Generate a mesh from an image path and export it to a mesh path.

        Both paths come from the request and are used on the server's own
        filesystem. Failures are logged and swallowed: the endpoint returns
        the same empty response either way (original best-effort behaviour,
        kept for existing clients), so callers have to check whether the
        output file exists.

        NOTE(review): the paths are client-controlled and not restricted to
        any directory - only expose this endpoint to trusted callers.
        """
        try:
            print(f"image_to_mesh_request = {image_to_mesh_request}")
            mesh = pipeline(image_to_mesh_request.image_path).meshes[0]
            # Make sure the destination directory exists before exporting.
            os.makedirs(os.path.dirname(os.path.abspath(image_to_mesh_request.mesh_path)), exist_ok=True)
            mesh.export(image_to_mesh_request.mesh_path)
        except Exception as e:
            traceback.print_exc()
            print(f"generate_model error: {e}")
        return Native3DResponseV1()

    # The handler name is intentionally the same as for v1 so that the route
    # names FastAPI derives from it stay unchanged; each decorator registers
    # its function before the name is rebound.
    @app.post("/native3d_v2", response_model=Native3DResponseV2)
    async def native3d(request: Request, image_to_mesh_request: Native3DRequestV2):
        """Generate a mesh from base64 image bytes and return it as base64.

        Raises:
            HTTPException: 400 when ``mesh_type`` is not ``'obj'`` or
                ``'glb'``; 500 when decoding, generation or export fails.
        """
        mesh_type = image_to_mesh_request.mesh_type
        # Validate explicitly instead of with ``assert`` (asserts are removed
        # under ``python -O``) and outside the ``try`` so that the 400 is not
        # caught by the generic handler below.
        if mesh_type not in ('obj', 'glb'):
            raise HTTPException(
                status_code=400,
                detail=f"unsupported mesh_type {mesh_type!r}; expected 'obj' or 'glb'",
            )

        try:
            # One working directory per request, named after the current
            # timestamp (microsecond resolution).
            # NOTE(review): these directories are never removed, so disk
            # usage under CURRENT_DIR grows with every request.
            task_id = datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f') + '-' + 'native3d'
            current_dir = os.path.join(CURRENT_DIR, task_id)
            os.makedirs(current_dir, exist_ok=True)

            # The pipeline is called with a file path, so persist the image.
            image_path = os.path.join(current_dir, 'input_image.png')
            with open(image_path, 'wb') as f:
                f.write(base64.b64decode(image_to_mesh_request.image_bytes))

            mesh_path = os.path.join(current_dir, f'output_mesh.{mesh_type}')

            start = time.time()
            # Alternative settings previously tried by the author:
            # pipeline(image_path, mc_depth=7, num_inference_steps=50)
            mesh = pipeline(image_path).meshes[0]
            print(f"Time: {time.time() - start}s")

            # Replace whatever visual the pipeline produced with a single
            # untextured PBR material before exporting.
            mesh.visual = trimesh.visual.TextureVisuals(
                material=trimesh.visual.material.PBRMaterial(
                    baseColorFactor=(255, 255, 255), main_color=(255, 255, 255), metallicFactor=0.05, roughnessFactor=1.0
                )
            )
            mesh.export(mesh_path)
            with open(mesh_path, 'rb') as f:
                mesh_bytes = f.read()
        except Exception as e:
            traceback.print_exc()
            print(f"generate_model error: {e}")
            # Previously execution fell through to the ``return`` below with
            # ``mesh_bytes`` unbound, which surfaced as an unrelated
            # UnboundLocalError. Report the failure explicitly instead.
            raise HTTPException(status_code=500, detail="mesh generation failed") from e
        return Native3DResponseV2(mesh_bytes=base64.b64encode(mesh_bytes).decode('utf-8'))

    @app.get("/health")
    async def health():
        """Liveness probe: always answers ``{"status": "OK"}``."""
        return {"status": "OK"}

    import uvicorn
    uvicorn.run(app, host=parse_args.host, port=parse_args.port)