# ArcaneGanX / app.py
import os
import torch
import gradio as gr
from facenet_pytorch import MTCNN
from torchvision import transforms
import numpy as np
import onnxruntime as ort
from PIL import Image
from huggingface_hub import hf_hub_download
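# This Space expects an HF_TOKEN secret with read access to the
# Arrcttacsrks/ArcaneGanOnnx model repository; the ONNX weights are
# downloaded on demand with hf_hub_download and cached locally by huggingface_hub.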


class ArcaneGANProcessor:
    def __init__(self):
        self.hf_token = os.getenv('HF_TOKEN')
        if not self.hf_token:
            raise ValueError("HF_TOKEN not found in environment variables")
        print("HF_TOKEN found in environment variables")

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        if self.device.type == 'cpu':
            print("Warning: Using CPU, performance may be reduced.")

        # MTCNN is used both to detect faces and to produce aligned 256x256 crops.
        self.mtcnn = MTCNN(
            image_size=256,
            margin=80,
            keep_all=True,
            device=self.device,
            post_process=True,
            select_largest=True
        )

        # ImageNet normalization; defined here but not applied in process_image,
        # since MTCNN's post-processing already normalizes the face crop.
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Cache ONNX Runtime sessions so each model version is only loaded once.
        self.ort_sessions = {}

    def download_model(self, version):
        """Download the ArcaneGAN ONNX model from the Hugging Face Hub."""
        model_filename = f"ArcaneGAN{version}.onnx"
        try:
            model_path = hf_hub_download(
                repo_id="Arrcttacsrks/ArcaneGanOnnx",
                filename=model_filename,
                token=self.hf_token
            )
            return model_path
        except Exception as e:
            raise RuntimeError(f"Failed to download model: {e}") from e

    def process_image(self, image, version):
        """Run a single image through ArcaneGAN and return the stylized PIL image."""
        if image is None:
            raise ValueError("Input image is None")

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Detect faces first so we can fail with a clear message when none are found.
        boxes, _ = self.mtcnn.detect(image)
        if boxes is None:
            raise ValueError("No face detected in the image")

        # Get the aligned, normalized face crop(s) from MTCNN.
        face = self.mtcnn(image)
        if face is None:
            raise ValueError("Failed to process face")
        # With keep_all=True, MTCNN returns a stack of faces [n, C, H, W];
        # keep the first (largest) face before adding the batch dimension.
        if face.dim() == 4:
            face = face[0]
        face = face.unsqueeze(0)  # Shape: [1, C, H, W]

        # Load (or reuse) the ONNX Runtime session for the requested model version.
        if version not in self.ort_sessions:
            model_path = self.download_model(version)
            self.ort_sessions[version] = ort.InferenceSession(model_path)
        ort_session = self.ort_sessions[version]

        # ONNX Runtime expects a CPU numpy array, so move the tensor off the GPU if needed.
        ort_inputs = {ort_session.get_inputs()[0].name: face.cpu().numpy()}
        ort_output = ort_session.run(None, ort_inputs)[0]

        # Convert the [1, C, H, W] float output back to an HWC uint8 image.
        output = torch.from_numpy(ort_output)
        output = output.squeeze(0).permute(1, 2, 0)
        output = output.clamp(0, 1) * 255
        output = output.numpy().astype(np.uint8)
        return Image.fromarray(output)
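
# A hypothetical programmatic use of the processor, outside the Gradio UI
# (the file names below are placeholders, not part of this Space):
#
#     processor = ArcaneGANProcessor()
#     stylized = processor.process_image(Image.open("portrait.jpg"), "v0.4")
#     stylized.save("portrait_arcane.jpg")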


def create_interface():
    """Create the Gradio interface."""
    processor = ArcaneGANProcessor()

    with gr.Blocks() as demo:
        gr.Markdown("# ArcaneGAN Converter")
        with gr.Row():
            input_image = gr.Image(type="numpy", label="Input Image")
            output_image = gr.Image(type="pil", label="Output Image")
        version = gr.Radio(
            choices=["v0.4", "v0.3", "v0.2"],
            value="v0.4",
            label="Model Version"
        )
        process_button = gr.Button("Convert")
        process_button.click(
            fn=processor.process_image,
            inputs=[input_image, version],
            outputs=output_image
        )
    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
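    # On Hugging Face Spaces the default launch() is sufficient; when running
    # locally, demo.launch(share=True) would additionally create a temporary public link.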