textclarity / app.py
ganna217's picture
update
589b738
raw
history blame contribute delete
2.59 kB
import logging
import os

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from transformers import pipeline
# Configure the root logger at INFO level (basicConfig does nothing if the
# root logger already has handlers) and create this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Keep the Hugging Face caches in a writable location under /tmp.
cache_dir = "/tmp/hf_cache"

# NOTE(review): `transformers` is imported at the top of this file, before
# these assignments run. huggingface_hub typically resolves its default cache
# paths when it is first imported, so these variables may not change the
# defaults for this process; pass the cache directory explicitly when loading
# models. TRANSFORMERS_CACHE is deprecated in recent transformers releases in
# favour of HF_HOME and is kept only for older versions.
os.environ["HF_HOME"] = cache_dir
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir

# exist_ok=True already tolerates an existing directory, so no separate
# (race-prone) existence check is needed.
os.makedirs(cache_dir, exist_ok=True)
app = FastAPI()

# Let browser clients on any origin call the API.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True makes
# recent Starlette versions echo the caller's Origin for credentialed requests,
# i.e. any site may send credentialed requests. Nothing in the code visible
# here relies on cookies or auth headers; consider allow_credentials=False or
# an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Expose the "static" directory (resolved against the working directory)
# under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Load the zero-shot classification pipeline once, at import time, so every
# request reuses the same instance.
# NOTE(review): the "zero-shot-classification" task relies on an NLI-fine-tuned
# checkpoint whose config exposes an "entailment" label. Confirm that
# UBC-NLP/ARBERTv2 is such a checkpoint; a plain pretrained encoder would get a
# freshly initialized classification head and produce meaningless scores.
logger.info("Loading the model...")
try:
    classifier = pipeline(
        "zero-shot-classification",
        model="UBC-NLP/ARBERTv2",
        tokenizer="UBC-NLP/ARBERTv2",
        # A bare `cache_dir=` keyword is passed on to the pipeline class rather
        # than to the loaders; model_kwargs is forwarded to the config,
        # tokenizer and model `from_pretrained` calls, so the cache location
        # is actually honoured.
        model_kwargs={"cache_dir": cache_dir},
    )
except Exception as e:
    # Log with traceback, then re-raise so startup fails instead of the app
    # running without a model.
    logger.exception("Error loading model: %s", e)
    raise
logger.info("Model loaded successfully!")
@app.get("/")
async def index():
    # Serve the frontend entry page. The path is relative, so it resolves
    # against the process's working directory.
    page_path = "static/index.html"
    logger.info("Serving index.html")
    return FileResponse(path=page_path)
@app.post("/classify")
async def classify_text(data: dict):
    """Classify a document against caller-supplied candidate labels.

    The JSON request body is expected to contain:

    * ``document`` -- the text to classify (a non-empty string).
    * ``labels`` -- the candidate labels, either a list of strings or a
      single comma-separated string.

    On success returns ``{"labels": [...], "scores": [...]}`` taken from the
    zero-shot pipeline's output. Invalid input yields an HTTP 400 response and
    a failure during classification yields HTTP 500; both carry an
    ``{"error": "<message>"}`` JSON body.
    """
    logger.info("Received classify request with data: %s", data)

    text = data.get("document")
    labels = data.get("labels")

    # Labels may arrive as one comma-separated string: split it and drop blank
    # entries. This runs before the emptiness check so that input such as
    # " , " is reported as missing labels instead of reaching the model.
    if isinstance(labels, str):
        labels = [label.strip() for label in labels.split(",") if label.strip()]

    if not text or not labels:
        logger.warning("Missing text or labels in request")
        # A `(body, status)` tuple does not set the status code in FastAPI (it
        # is serialized as a JSON array with status 200), so build the
        # response explicitly.
        return JSONResponse(
            status_code=400,
            content={"error": "Please provide both text and labels"},
        )

    # Reject wrong types up front (e.g. a numeric document or non-string
    # labels) rather than letting the pipeline fail with an opaque 500.
    if (
        not isinstance(text, str)
        or not isinstance(labels, list)
        or not all(isinstance(label, str) for label in labels)
    ):
        logger.warning("Invalid text or labels type in request")
        return JSONResponse(
            status_code=400,
            content={
                "error": "'document' must be a string and 'labels' a list of "
                "strings or a comma-separated string"
            },
        )

    logger.info("Classifying text: %s... with labels: %s", text[:50], labels)
    # NOTE(review): classifier() is a synchronous call inside this async
    # handler, so it blocks the event loop while inference runs.
    try:
        result = classifier(text, labels, multi_label=False)
        logger.info("Classification result: %s", result)
        return {"labels": result["labels"], "scores": result["scores"]}
    except Exception as e:
        # Request boundary: log with traceback and report a real HTTP 500.
        logger.exception("Error during classification: %s", e)
        return JSONResponse(status_code=500, content={"error": str(e)})
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)