# app.py — created by mabosaimi ("Create app.py", commit bb279b2, verified)
from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
# Load all base models (the first-level classifiers whose predictions are
# stacked and passed to the meta-model below).
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code — these artifacts must come from a trusted source.
model_lr_tfidf = joblib.load("model_lr_tfidf.pkl") # Logistic Regression (TF-IDF)
model_lr_bert = joblib.load("model_lr_bert.pkl") # Logistic Regression (Arabic BERT)
model_lgb = joblib.load("model_lgb.pkl") # LightGBM (Arabic BERT)
model_xgb = joblib.load("model_xgb.pkl") # XGBoost (Arabic BERT)
# Load the final meta-model (ensemble model)
meta_model = joblib.load("ensemble_lightgbm.pkl")
# Load TF-IDF Vectorizer
vectorizer = joblib.load("tfidf_vectorizer.pkl")
# ✅ Load Arabic BERT model (tokenizer and encoder share one checkpoint name;
# the encoder is placed on the CPU)
model_name = "aubmindlab/bert-base-arabertv02"
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModel.from_pretrained(model_name).to("cpu")
# ✅ FastAPI App
app = FastAPI()
# ✅ Input Schema
class MessageInput(BaseModel):
    """Request body for the prediction endpoint: two required string fields."""

    # Text of the message to classify.
    message: str
    # Sender's number, supplied as a string.
    sender_number: str
# ✅ Function to Get Embeddings
def get_text_embedding(text):
    """Return the BERT embedding of *text* as a NumPy array.

    The text is tokenized (padded, and truncated to at most 512 tokens) and
    run through ``bert_model`` with gradient tracking disabled. The vector
    returned is the last hidden state at sequence position 0, i.e. the
    first token of each input sequence.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        hidden_states = bert_model(**encoded).last_hidden_state
    first_token_vectors = hidden_states[:, 0, :]
    return first_token_vectors.cpu().numpy()
# ✅ API Route to Predict Spam/Ham
@app.post("/predict/")
def predict_message(data: MessageInput):
    """Classify a message as "Spam" or "Ham" with the stacked ensemble.

    Pipeline:
      1. Build two feature rows from the message text: TF-IDF + length and
         BERT embedding + length.
      2. Ask each base model for its predicted label.
      3. Feed the four base predictions to the meta-model.

    Args:
        data: Request body. Only ``data.message`` is used as model input;
            ``data.sender_number`` is accepted by the schema but is not a
            feature of any model here.

    Returns:
        A dict with:
          - "prediction": "Spam" if the meta-model predicts class 1,
            otherwise "Ham".
          - "confidence": the meta-model's probability of the spam class
            (column 1 of ``predict_proba``), rounded to 2 decimals. Note
            this is always P(spam), so it is low for a confident "Ham".
    """
    message = data.message

    # Message length (in characters) is appended as one extra column to
    # both feature sets. The original comments give the resulting widths
    # as 1001 (TF-IDF) and 769 (BERT); they must match what the base
    # models were trained on.
    length_feature = np.array([[len(message)]])
    tfidf_features = np.hstack(
        (vectorizer.transform([message]).toarray(), length_feature)
    )
    bert_features = np.hstack(
        (get_text_embedding(message).reshape(1, -1), length_feature)
    )

    # Base-model predictions, in the column order the meta-model expects:
    # LR (TF-IDF), LR (BERT), LightGBM (BERT), XGBoost (BERT).
    base_predictions = [
        model_lr_tfidf.predict(tfidf_features)[0],
        model_lr_bert.predict(bert_features)[0],
        model_lgb.predict(bert_features)[0],
        model_xgb.predict(bert_features)[0],
    ]
    X_input = np.array([base_predictions])

    prediction = meta_model.predict(X_input)[0]
    # Coerce the numpy scalar to a builtin float so the response is
    # JSON-serializable regardless of the meta-model's output dtype.
    spam_probability = float(meta_model.predict_proba(X_input)[0, 1])

    return {
        "prediction": "Spam" if prediction == 1 else "Ham",
        "confidence": round(spam_probability, 2),
    }