from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, ConfigDict, Field
from typing import List
from ctransformers import AutoModelForCausalLM
import time
import logging

from app.config import MODEL_PATH, MODEL_URL

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Poetry Generator API",
    description="An API for generating poetry using a local LLM",
    version="1.0.0"
)

# Global model handle; populated on startup
model = None


class PoetryRequest(BaseModel):
    prompt: str = Field(..., description="The topic or theme for the poem", min_length=1)
    style: str = Field(
        default="free verse",
        description="Style of the poem to generate"
    )
    max_length: int = Field(
        default=200,
        description="Maximum length of the generated poem",
        ge=50,
        le=500
    )
    temperature: float = Field(
        default=0.7,
        description="Temperature for text generation",
        ge=0.1,
        le=2.0
    )


class PoetryResponse(BaseModel):
    poem: str
    generation_time: float
    prompt: str
    style: str


class ModelInfo(BaseModel):
    # Allow field names starting with "model_" (Pydantic v2 reserves that prefix)
    model_config = ConfigDict(protected_namespaces=())

    status: str
    model_path: str
    model_name: str
    supported_styles: List[str]
    max_context_length: int


def initialize_model():
    """Initialize the model and return it, or None if loading fails."""
    if not MODEL_PATH.exists():
        logger.error(f"Model not found at {MODEL_PATH}")
        return None

    try:
        logger.info(f"Loading model from {MODEL_PATH}")
        return AutoModelForCausalLM.from_pretrained(
            str(MODEL_PATH.parent),
            model_file=MODEL_PATH.name,
            model_type="llama",
            max_new_tokens=512,
            context_length=512,
            gpu_layers=0  # CPU only
        )
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return None


@app.on_event("startup")
async def startup_event():
    """Initialize the model during startup."""
    global model
    model = initialize_model()
    if model is None:
        logger.warning("Model failed to load but service will start anyway")


@app.get(
    "/health",
    response_model=ModelInfo,
    status_code=status.HTTP_200_OK,
    tags=["Health Check"]
)
async def health_check():
    """Check whether the model is loaded and return basic information."""
    model_status = "ready" if model is not None else "not_loaded"
    return ModelInfo(
        status=model_status,
        model_name="Llama-2-7B-Chat",
        model_path=str(MODEL_PATH),
        supported_styles=[
            "free verse",
            "haiku",
            "sonnet",
            "limerick",
            "tanka"
        ],
        max_context_length=512
    )


@app.post(
    "/generate",
    response_model=PoetryResponse,
    status_code=status.HTTP_200_OK,
    tags=["Generation"]
)
async def generate_poem(request: PoetryRequest):
    """Generate a poem based on the provided prompt and parameters."""
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded. Please check the /health endpoint for status."
        )

    try:
        start_time = time.time()

        prompt_templates = {
            "haiku": "Write a haiku about {prompt}. Follow the 5-7-5 syllable pattern:\n\n",
            "sonnet": "Write a Shakespearean sonnet about {prompt}. Follow the traditional 14-line format with rhyme scheme ABAB CDCD EFEF GG:\n\n",
            "limerick": "Write a limerick about {prompt}. Follow the AABBA rhyme scheme:\n\n",
            "free verse": "Write a free verse poem about {prompt}. Make it creative and meaningful:\n\n",
            "tanka": "Write a tanka about {prompt}. Follow the 5-7-5-7-7 syllable pattern:\n\n"
        }

        # Fall back to free verse for unrecognized styles
        template = prompt_templates.get(request.style.lower(), prompt_templates["free verse"])
        full_prompt = template.format(prompt=request.prompt)

        output = model(
            full_prompt,
            max_new_tokens=request.max_length,
            temperature=request.temperature,
            top_p=0.95,
            repetition_penalty=1.2  # ctransformers expects repetition_penalty (repeat_penalty is the llama-cpp-python name)
        )

        generation_time = time.time() - start_time

        return PoetryResponse(
            poem=output.strip(),
            generation_time=generation_time,
            prompt=request.prompt,
            style=request.style
        )
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate poem: {str(e)}"
        )


def download_model():
    """Download the model if it doesn't exist.

    Note: this helper is not called automatically; run it manually
    (e.g., from a setup script) if the model file is missing.
    """
    import requests
    from tqdm import tqdm

    if MODEL_PATH.exists():
        logger.info(f"Model already exists at {MODEL_PATH}")
        return

    logger.info(f"Downloading model to {MODEL_PATH}")
    try:
        response = requests.get(MODEL_URL, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        with open(MODEL_PATH, 'wb') as file, tqdm(
            desc="Downloading",
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)
        logger.info("Model downloaded successfully")
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        # Remove any partially downloaded file so a retry starts clean
        if MODEL_PATH.exists():
            MODEL_PATH.unlink()
        raise


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
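

# --- Example client call (illustrative sketch, not part of the service) ---
# Assumes the API is running locally on the port configured above (8000);
# adjust the URL, prompt, and style for your setup. Kept commented out so
# importing this module never triggers a network request.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "the sea at dawn", "style": "haiku", "max_length": 100},
#       timeout=120,  # CPU-only generation can be slow
#   )
#   resp.raise_for_status()
#   print(resp.json()["poem"])
#
# The response body matches PoetryResponse, e.g.:
#   {"poem": "...", "generation_time": 3.2, "prompt": "the sea at dawn", "style": "haiku"}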