Commit 110ce35: Update model configuration and enhance initialization logic; adjust BASE_DIR for container, implement model download functionality, and improve health check response
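"""Poetry Generator API.

FastAPI service that generates poems in several styles using a local
Llama model loaded through ctransformers. The model file is downloaded
on demand (see download_model) and initialized once at startup.
"""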
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, ConfigDict, Field
from typing import List
from ctransformers import AutoModelForCausalLM
import time
import logging
from app.config import MODEL_PATH, MODEL_URL

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Poetry Generator API",
    description="An API for generating poetry using a local LLM",
    version="1.0.0"
)

# Global model variable
model = None


class PoetryRequest(BaseModel):
    prompt: str = Field(..., description="The topic or theme for the poem", min_length=1)
    style: str = Field(
        default="free verse",
        description="Style of the poem to generate"
    )
    max_length: int = Field(
        default=200,
        description="Maximum length of the generated poem",
        ge=50,
        le=500
    )
    temperature: float = Field(
        default=0.7,
        description="Temperature for text generation",
        ge=0.1,
        le=2.0
    )
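
# Example request body (illustrative values, within the Field constraints above):
# {
#     "prompt": "autumn rain",
#     "style": "haiku",
#     "max_length": 120,
#     "temperature": 0.8
# }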

class PoetryResponse(BaseModel):
    poem: str
    generation_time: float
    prompt: str
    style: str


class ModelInfo(BaseModel):
    model_config = ConfigDict(protected_namespaces=())
    status: str
    model_path: str
    model_name: str
    supported_styles: List[str]
    max_context_length: int


def initialize_model():
    """Initialize the model and return it, or None on failure."""
    if not MODEL_PATH.exists():
        logger.error(f"Model not found at {MODEL_PATH}")
        return None
    try:
        logger.info(f"Loading model from {MODEL_PATH}")
        return AutoModelForCausalLM.from_pretrained(
            str(MODEL_PATH.parent),
            model_file=MODEL_PATH.name,
            model_type="llama",
            max_new_tokens=512,
            context_length=512,
            gpu_layers=0  # CPU only
        )
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return None


@app.on_event("startup")
async def startup_event():
    """Download (if missing) and initialize the model during startup."""
    global model
    if not MODEL_PATH.exists():
        try:
            download_model()  # fetch weights on first boot
        except Exception:
            logger.warning("Model download failed")
    model = initialize_model()
    if model is None:
        logger.warning("Model failed to load but service will start anyway")


@app.get("/health", response_model=ModelInfo)
async def health_check():
    """Check if the model is loaded and get basic information."""
    model_status = "ready" if model is not None else "not_loaded"
    return ModelInfo(
        status=model_status,
        model_name="Llama-2-7B-Chat",
        model_path=str(MODEL_PATH),
        supported_styles=[
            "free verse",
            "haiku",
            "sonnet",
            "limerick",
            "tanka"
        ],
        max_context_length=512
    )
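
# Example /health response (model_path comes from app.config.MODEL_PATH):
# {"status": "ready", "model_name": "Llama-2-7B-Chat", "model_path": "...",
#  "supported_styles": ["free verse", "haiku", "sonnet", "limerick", "tanka"],
#  "max_context_length": 512}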


@app.post("/generate", response_model=PoetryResponse)
async def generate_poem(request: PoetryRequest):
    """Generate a poem based on the provided prompt and parameters."""
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded. Please check /health endpoint for status."
        )
    try:
        start_time = time.time()
        prompt_templates = {
            "haiku": "Write a haiku about {prompt}. Follow the 5-7-5 syllable pattern:\n\n",
            "sonnet": "Write a Shakespearean sonnet about {prompt}. Follow the traditional 14-line format with rhyme scheme ABAB CDCD EFEF GG:\n\n",
            "limerick": "Write a limerick about {prompt}. Follow the AABBA rhyme scheme:\n\n",
            "free verse": "Write a free verse poem about {prompt}. Make it creative and meaningful:\n\n",
            "tanka": "Write a tanka about {prompt}. Follow the 5-7-5-7-7 syllable pattern:\n\n"
        }
        # Unknown styles fall back to free verse
        template = prompt_templates.get(request.style.lower(), prompt_templates["free verse"])
        full_prompt = template.format(prompt=request.prompt)
        output = model(
            full_prompt,
            max_new_tokens=request.max_length,
            temperature=request.temperature,
            top_p=0.95,
            repetition_penalty=1.2  # ctransformers expects repetition_penalty, not repeat_penalty
        )
        generation_time = time.time() - start_time
        return PoetryResponse(
            poem=output.strip(),
            generation_time=generation_time,
            prompt=request.prompt,
            style=request.style
        )
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate poem: {str(e)}"
        )
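
# Example invocation once the server is running (port 8000 as configured below):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "the sea at dawn", "style": "sonnet"}'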


def download_model():
    """Download the model if it doesn't exist."""
    import requests
    from tqdm import tqdm

    if MODEL_PATH.exists():
        logger.info(f"Model already exists at {MODEL_PATH}")
        return
    logger.info(f"Downloading model to {MODEL_PATH}")
    try:
        # Ensure the target directory exists inside the container
        MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        response = requests.get(MODEL_URL, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(MODEL_PATH, 'wb') as file, tqdm(
            desc="Downloading",
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)
        logger.info("Model downloaded successfully")
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        # Remove any partial download so the next attempt starts clean
        if MODEL_PATH.exists():
            MODEL_PATH.unlink()
        raise
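
# The download can also be run ahead of time, e.g. during the container build:
#   python -c "from app.main import download_model; download_model()"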


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)