# poetica / main.py
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, ConfigDict, Field
from typing import List
from ctransformers import AutoModelForCausalLM
import time
import logging
from app.config import MODEL_PATH, MODEL_URL
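# NOTE (assumption): app.config is expected to expose MODEL_PATH as a
# pathlib.Path pointing at the local model file and MODEL_URL as its download
# source, inferred from how both names are used below.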
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI(
    title="Poetry Generator API",
    description="An API for generating poetry using a local LLM",
    version="1.0.0"
)
# Global model variable
model = None
class PoetryRequest(BaseModel):
    prompt: str = Field(..., description="The topic or theme for the poem", min_length=1)
    style: str = Field(
        default="free verse",
        description="Style of the poem to generate"
    )
    max_length: int = Field(
        default=200,
        description="Maximum number of new tokens to generate for the poem",
        ge=50,
        le=500
    )
    temperature: float = Field(
        default=0.7,
        description="Sampling temperature for text generation",
        ge=0.1,
        le=2.0
    )
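
# Illustrative /generate request body (example values, not from the source):
# {
#     "prompt": "the sea at dawn",
#     "style": "haiku",
#     "max_length": 100,
#     "temperature": 0.7
# }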
class PoetryResponse(BaseModel):
    poem: str
    generation_time: float
    prompt: str
    style: str
class ModelInfo(BaseModel):
    # Allow field names starting with "model_" (Pydantic v2 reserves that prefix)
    model_config = ConfigDict(protected_namespaces=())

    status: str
    model_path: str
    model_name: str
    supported_styles: List[str]
    max_context_length: int
def initialize_model():
    """Initialize the model and return it"""
    if not MODEL_PATH.exists():
        logger.error(f"Model not found at {MODEL_PATH}")
        return None
    try:
        logger.info(f"Loading model from {MODEL_PATH}")
        return AutoModelForCausalLM.from_pretrained(
            str(MODEL_PATH.parent),
            model_file=MODEL_PATH.name,
            model_type="llama",
            max_new_tokens=512,
            context_length=512,
            gpu_layers=0  # CPU only
        )
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return None
@app.on_event("startup")
async def startup_event():
    """Initialize the model during startup"""
    global model
    # Fetch the model file first if it is missing (see download_model below)
    if not MODEL_PATH.exists():
        try:
            download_model()
        except Exception:
            logger.warning("Model download failed; starting without a model")
    model = initialize_model()
    if model is None:
        logger.warning("Model failed to load but service will start anyway")
@app.get(
    "/health",
    response_model=ModelInfo,
    status_code=status.HTTP_200_OK,
    tags=["Health Check"]
)
async def health_check():
    """Check if the model is loaded and get basic information"""
    model_status = "ready" if model is not None else "not_loaded"
    return ModelInfo(
        status=model_status,
        model_name="Llama-2-7B-Chat",
        model_path=str(MODEL_PATH),
        supported_styles=[
            "free verse",
            "haiku",
            "sonnet",
            "limerick",
            "tanka"
        ],
        max_context_length=512
    )
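
# Illustrative /health response when the model is loaded (the actual
# model_path value comes from MODEL_PATH):
# {"status": "ready", "model_path": "...", "model_name": "Llama-2-7B-Chat",
#  "supported_styles": ["free verse", "haiku", "sonnet", "limerick", "tanka"],
#  "max_context_length": 512}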
@app.post(
    "/generate",
    response_model=PoetryResponse,
    status_code=status.HTTP_200_OK,
    tags=["Generation"]
)
async def generate_poem(request: PoetryRequest):
    """Generate a poem based on the provided prompt and parameters"""
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded. Please check /health endpoint for status."
        )
    try:
        start_time = time.time()
        prompt_templates = {
            "haiku": "Write a haiku about {prompt}. Follow the 5-7-5 syllable pattern:\n\n",
            "sonnet": "Write a Shakespearean sonnet about {prompt}. Follow the traditional 14-line format with rhyme scheme ABAB CDCD EFEF GG:\n\n",
            "limerick": "Write a limerick about {prompt}. Follow the AABBA rhyme scheme:\n\n",
            "free verse": "Write a free verse poem about {prompt}. Make it creative and meaningful:\n\n",
            "tanka": "Write a tanka about {prompt}. Follow the 5-7-5-7-7 syllable pattern:\n\n"
        }
        # Fall back to free verse for unrecognized styles
        template = prompt_templates.get(request.style.lower(), prompt_templates["free verse"])
        full_prompt = template.format(prompt=request.prompt)
        output = model(
            full_prompt,
            max_new_tokens=request.max_length,
            temperature=request.temperature,
            top_p=0.95,
            repetition_penalty=1.2  # ctransformers expects "repetition_penalty", not "repeat_penalty"
        )
        generation_time = time.time() - start_time
        return PoetryResponse(
            poem=output.strip(),
            generation_time=generation_time,
            prompt=request.prompt,
            style=request.style
        )
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate poem: {str(e)}"
        )
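
# Illustrative usage (assumes the server is reachable on localhost:8000):
# curl -X POST http://localhost:8000/generate \
#      -H "Content-Type: application/json" \
#      -d '{"prompt": "autumn rain", "style": "haiku"}'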
def download_model():
    """Download the model if it doesn't exist"""
    import requests
    from tqdm import tqdm

    if MODEL_PATH.exists():
        logger.info(f"Model already exists at {MODEL_PATH}")
        return
    logger.info(f"Downloading model to {MODEL_PATH}")
    try:
        # Ensure the target directory exists before writing
        MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        response = requests.get(MODEL_URL, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(MODEL_PATH, 'wb') as file, tqdm(
            desc="Downloading",
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)
        logger.info("Model downloaded successfully")
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        # Remove any partially downloaded file so the next attempt starts clean
        if MODEL_PATH.exists():
            MODEL_PATH.unlink()
        raise
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)