# Spaces: Sleeping
# (Hosting-page status text captured together with the file; not part of the program.)
from fastapi import FastAPI | |
from fastapi.responses import FileResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.middleware.cors import CORSMiddleware | |
from transformers import pipeline | |
import os | |
import uvicorn | |
import logging | |
# Configure root logging once at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point the Hugging Face cache locations at a directory under /tmp, which is
# intended to be writable.
# NOTE(review): these variables are assigned after `transformers` has already
# been imported at the top of the file; the libraries may have resolved their
# cache paths by then -- confirm that the variables actually take effect.
cache_dir = "/tmp/hf_cache"
os.environ["HF_HOME"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir

# `exist_ok=True` already makes this a no-op when the directory exists, so a
# separate existence check beforehand is unnecessary.
os.makedirs(cache_dir, exist_ok=True)
# Application instance that the middleware and the static mount attach to.
app = FastAPI()

# Allow cross-origin requests from any origin, with any method and header.
# Credentials stay disabled: a wildcard origin must not be combined with
# credentialed requests, and nothing in this file relies on cookies or
# authentication headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve files from the local "static" directory under the /static URL prefix.
# The directory path is relative to the process working directory.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Load the zero-shot classification pipeline once at import time so request
# handlers can reuse the same `classifier` object.
# NOTE(review): zero-shot classification in transformers is NLI-based; confirm
# that the UBC-NLP/ARBERTv2 checkpoint provides a suitable classification head.
logger.info("Loading the model...")
try:
    classifier = pipeline(
        "zero-shot-classification",
        model="UBC-NLP/ARBERTv2",  # Switch to a better Arabic model
        tokenizer="UBC-NLP/ARBERTv2",
        # `model_kwargs` is the documented way to forward options such as
        # `cache_dir` to `from_pretrained`; a bare `cache_dir=` keyword on
        # `pipeline()` is not forwarded there.
        model_kwargs={"cache_dir": cache_dir},
    )
    logger.info("Model loaded successfully!")
except Exception as e:
    # Log for the service logs, then re-raise so startup fails loudly instead
    # of leaving `classifier` undefined.
    logger.error(f"Error loading model: {str(e)}")
    raise
async def index():
    """Return the ``static/index.html`` file as a file response.

    Returns:
        A ``FileResponse`` pointing at ``static/index.html``.
    """
    # NOTE(review): no route decorator is visible on this handler in this view;
    # confirm it is registered on `app`, otherwise it is never served.
    page_path = "static/index.html"
    logger.info("Serving index.html")
    return FileResponse(page_path)
async def classify_text(data: dict):
    """Classify a document against candidate labels with the zero-shot pipeline.

    Args:
        data: Parsed request body. ``"document"`` holds the text to classify
            and ``"labels"`` holds the candidate labels, either as a list or
            as a single comma-separated string.

    Returns:
        On success, a dict with the ``"labels"`` and ``"scores"`` entries of
        the pipeline result. On failure, a ``JSONResponse`` whose body is
        ``{"error": <message>}`` with status 400 (text or labels missing or
        empty) or status 500 (any exception raised while classifying).
    """
    # NOTE(review): no route decorator is visible on this handler in this view;
    # confirm it is registered on `app`, otherwise it is never reachable.

    # Imported locally so this handler brings its own dependency into scope.
    from fastapi.responses import JSONResponse

    logger.info(f"Received classify request with data: {data}")
    try:
        text = data.get("document")
        labels = data.get("labels")

        # Accept a comma-separated string as well as a list, dropping blanks.
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",") if label.strip()]

        # Checked after normalisation so that a string such as " , " (which
        # yields no labels) is rejected as a client error as well.
        if not text or not labels:
            logger.warning("Missing text or labels in request")
            # FastAPI does not interpret a `(body, status)` tuple; the status
            # code has to be set on an explicit response object.
            return JSONResponse(
                status_code=400,
                content={"error": "Please provide both text and labels"},
            )

        logger.info(f"Classifying text: {text[:50]}... with labels: {labels}")
        # NOTE(review): the pipeline is called synchronously inside an async
        # handler; consider a threadpool if inference latency becomes an issue.
        result = classifier(text, labels, multi_label=False)
        logger.info(f"Classification result: {result}")
        return {"labels": result["labels"], "scores": result["scores"]}
    except Exception as e:
        # Boundary handler: record the traceback, then report the failure to
        # the client with a real 500 status.
        logger.exception("Error during classification: %s", e)
        return JSONResponse(status_code=500, content={"error": str(e)})
if __name__ == "__main__":
    # When run as a script, serve `app` on all interfaces. The port is read
    # from the PORT environment variable and falls back to 7860.
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
    )