import os

import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
from facenet_pytorch import MTCNN
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms


class ArcaneGANProcessor:
    """Detect a face in an image and stylize it with an ArcaneGAN ONNX model."""

    # Hugging Face Hub repository hosting the ``ArcaneGAN<version>.onnx`` files.
    REPO_ID = "Arrcttacsrks/ArcaneGanOnnx"

    def __init__(self):
        """Read the HF token, select a torch device and build the face detector.

        Raises:
            ValueError: if the ``HF_TOKEN`` environment variable is unset or empty.
        """
        self.hf_token = os.getenv('HF_TOKEN')
        if not self.hf_token:
            raise ValueError("HF_TOKEN not found in environment variables")
        print("HF_TOKEN found in environment variables")

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        if self.device.type == 'cpu':
            print("Warning: Using CPU, performance may be reduced.")

        # keep_all=False makes MTCNN return a single [C, H, W] crop (the largest
        # face, because select_largest=True) or None when nothing is found.
        # With keep_all=True it returned a stack of all faces, which
        # unsqueeze(0) turned into a 5-D tensor the ONNX model cannot take.
        self.mtcnn = MTCNN(
            image_size=256,
            margin=80,
            keep_all=False,
            device=self.device,
            post_process=True,
            select_largest=True,
        )

        # NOTE(review): this ImageNet-normalization transform is never applied;
        # faces reach the model exactly as MTCNN(post_process=True) emits them.
        # Confirm which input normalization the exported ONNX models expect.
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        # ONNX Runtime sessions keyed by model version, created lazily so the
        # model is downloaded and loaded once rather than on every request.
        self._sessions = {}

    def download_model(self, version):
        """Download the ONNX model for *version* from the Hugging Face Hub.

        Args:
            version: version suffix, e.g. ``"v0.4"`` -> ``ArcaneGANv0.4.onnx``.

        Returns:
            Local filesystem path of the model file.

        Raises:
            RuntimeError: if the download fails (the original error is chained).
        """
        model_filename = f"ArcaneGAN{version}.onnx"
        try:
            return hf_hub_download(
                repo_id=self.REPO_ID,
                filename=model_filename,
                token=self.hf_token,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to download model: {str(e)}") from e

    @staticmethod
    def _onnx_providers():
        """Return execution providers: CUDA first when this ORT build has it, then CPU."""
        # GPU builds of ONNX Runtime >= 1.9 refuse to create a session unless
        # providers are passed explicitly.
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            return ['CUDAExecutionProvider', 'CPUExecutionProvider']
        return ['CPUExecutionProvider']

    def _get_session(self, version):
        """Return the cached inference session for *version*, creating it on first use."""
        session = self._sessions.get(version)
        if session is None:
            model_path = self.download_model(version)
            session = ort.InferenceSession(model_path, providers=self._onnx_providers())
            self._sessions[version] = session
        return session

    def process_image(self, image, version):
        """Stylize the most prominent face in *image* with ArcaneGAN *version*.

        Args:
            image: a ``PIL.Image.Image`` or an array accepted by
                ``Image.fromarray`` (the Gradio input passes a numpy array).
            version: model version suffix, e.g. ``"v0.4"``.

        Returns:
            The stylized face crop as a ``PIL.Image.Image``.

        Raises:
            ValueError: if *image* is None or no face is detected.
            RuntimeError: if the model cannot be downloaded.
        """
        if image is None:
            raise ValueError("Input image is None")
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # One MTCNN call both detects and crops; it returns None if no face is found.
        face = self.mtcnn(image)
        if face is None:
            raise ValueError("No face detected in the image")

        # [C, H, W] -> [1, C, H, W] float32 array for ONNX Runtime.
        batch = face.detach().cpu().unsqueeze(0).numpy().astype(np.float32, copy=False)

        session = self._get_session(version)
        input_name = session.get_inputs()[0].name
        ort_output = session.run(None, {input_name: batch})[0]

        # [1, C, H, W] -> [H, W, C] uint8.
        # NOTE(review): assumes the model emits values in [0, 1]; confirm against the export.
        output = torch.from_numpy(ort_output).squeeze(0).permute(1, 2, 0)
        output = (output.clamp(0, 1) * 255).cpu().numpy().astype(np.uint8)
        return Image.fromarray(output)


def create_interface():
    """Build the Gradio Blocks UI wired to a single ArcaneGANProcessor."""
    processor = ArcaneGANProcessor()

    def convert(image, version):
        # Surface expected failures (no face, download error) to the user as
        # gr.Error so the UI shows the message instead of a generic error.
        try:
            return processor.process_image(image, version)
        except (ValueError, RuntimeError) as e:
            raise gr.Error(str(e)) from e

    with gr.Blocks() as demo:
        gr.Markdown("# ArcaneGAN Converter")
        with gr.Row():
            input_image = gr.Image(type="numpy", label="Input Image")
            output_image = gr.Image(type="pil", label="Output Image")
        version = gr.Radio(
            choices=["v0.4", "v0.3", "v0.2"],
            value="v0.4",
            label="Model Version",
        )
        process_button = gr.Button("Convert")
        process_button.click(
            fn=convert,
            inputs=[input_image, version],
            outputs=output_image,
        )
    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()