from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, ConfigDict, Field
from typing import List
from ctransformers import AutoModelForCausalLM
import time
import logging
from app.config import MODEL_PATH, MODEL_URL

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Poetry Generator API",
    description="An API for generating poetry using a local LLM",
    version="1.0.0"
)

# Global model variable
model = None


class PoetryRequest(BaseModel):
    prompt: str = Field(..., description="The topic or theme for the poem", min_length=1)
    style: str = Field(
        default="free verse",
        description="Style of the poem to generate"
    )
    max_length: int = Field(
        default=200,
        description="Maximum number of new tokens to generate",
        ge=50,
        le=500
    )
    temperature: float = Field(
        default=0.7,
        description="Sampling temperature for text generation",
        ge=0.1,
        le=2.0
    )
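
# Example of a JSON body that validates against PoetryRequest (illustrative
# values; style, max_length, and temperature fall back to their defaults
# when omitted):
#   {"prompt": "autumn rain", "style": "haiku", "max_length": 100, "temperature": 0.7}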


class PoetryResponse(BaseModel):
    poem: str
    generation_time: float
    prompt: str
    style: str


class ModelInfo(BaseModel):
    # Pydantic v2 reserves the "model_" attribute namespace by default;
    # clearing protected_namespaces allows the model_* field names below.
    model_config = ConfigDict(protected_namespaces=())

    status: str
    model_path: str
    model_name: str
    supported_styles: List[str]
    max_context_length: int


def initialize_model():
    """Initialize the model and return it"""
    if not MODEL_PATH.exists():
        logger.error(f"Model not found at {MODEL_PATH}")
        return None
    try:
        logger.info(f"Loading model from {MODEL_PATH}")
        return AutoModelForCausalLM.from_pretrained(
            str(MODEL_PATH.parent),
            model_file=MODEL_PATH.name,
            model_type="llama",
            max_new_tokens=512,
            context_length=512,
            gpu_layers=0  # CPU only
        )
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return None
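
# Note: ctransformers loads quantized GGML/GGUF model files. MODEL_PATH is
# assumed to point at one such file (the exact filename comes from app.config
# and is not shown in this module).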


@app.on_event("startup")
async def startup_event():
    """Initialize the model during startup"""
    global model
    model = initialize_model()
    if model is None:
        logger.warning("Model failed to load but service will start anyway")


@app.get(
    "/health",
    response_model=ModelInfo,
    status_code=status.HTTP_200_OK,
    tags=["Health Check"]
)
async def health_check():
    """Check if the model is loaded and get basic information"""
    model_status = "ready" if model is not None else "not_loaded"
    return ModelInfo(
        status=model_status,
        model_name="Llama-2-7B-Chat",
        model_path=str(MODEL_PATH),
        supported_styles=[
            "free verse",
            "haiku",
            "sonnet",
            "limerick",
            "tanka"
        ],
        max_context_length=512
    )
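
# Sketch of the JSON this endpoint returns once the model has loaded
# ("<MODEL_PATH>" is a placeholder; the real value depends on app.config):
#   {"status": "ready", "model_path": "<MODEL_PATH>", "model_name": "Llama-2-7B-Chat",
#    "supported_styles": ["free verse", "haiku", "sonnet", "limerick", "tanka"],
#    "max_context_length": 512}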


@app.post(
    "/generate",
    response_model=PoetryResponse,
    status_code=status.HTTP_200_OK,
    tags=["Generation"]
)
async def generate_poem(request: PoetryRequest):
    """Generate a poem based on the provided prompt and parameters"""
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded. Please check /health endpoint for status."
        )
    try:
        start_time = time.time()
        prompt_templates = {
            "haiku": "Write a haiku about {prompt}. Follow the 5-7-5 syllable pattern:\n\n",
            "sonnet": "Write a Shakespearean sonnet about {prompt}. Follow the traditional 14-line format with rhyme scheme ABAB CDCD EFEF GG:\n\n",
            "limerick": "Write a limerick about {prompt}. Follow the AABBA rhyme scheme:\n\n",
            "free verse": "Write a free verse poem about {prompt}. Make it creative and meaningful:\n\n",
            "tanka": "Write a tanka about {prompt}. Follow the 5-7-5-7-7 syllable pattern:\n\n"
        }
        # Unrecognized styles fall back to free verse
        template = prompt_templates.get(request.style.lower(), prompt_templates["free verse"])
        full_prompt = template.format(prompt=request.prompt)
        output = model(
            full_prompt,
            max_new_tokens=request.max_length,
            temperature=request.temperature,
            top_p=0.95,
            repetition_penalty=1.2  # ctransformers expects repetition_penalty, not repeat_penalty
        )
        generation_time = time.time() - start_time
        return PoetryResponse(
            poem=output.strip(),
            generation_time=generation_time,
            prompt=request.prompt,
            style=request.style
        )
    except Exception as e:
        logger.error(f"Generation error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate poem: {str(e)}"
        )
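
# Minimal client sketch (assumes the server is running locally on port 8000;
# `requests` is already a dependency via download_model below):
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"prompt": "the ocean", "style": "haiku"},
#       timeout=300,  # CPU generation can take a while
#   )
#   print(resp.json()["poem"])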


def download_model():
    """Download the model if it doesn't exist"""
    import requests
    from tqdm import tqdm

    if MODEL_PATH.exists():
        logger.info(f"Model already exists at {MODEL_PATH}")
        return
    logger.info(f"Downloading model to {MODEL_PATH}")
    try:
        # requests has no default timeout; without one a stalled connection
        # would hang forever
        response = requests.get(MODEL_URL, stream=True, timeout=30)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(MODEL_PATH, 'wb') as file, tqdm(
            desc="Downloading",
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)
        logger.info("Model downloaded successfully")
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        # Remove any partial download so the next attempt starts clean
        if MODEL_PATH.exists():
            MODEL_PATH.unlink()
        raise
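
# download_model() is not invoked anywhere in this module. One way to use it
# (a deployment assumption, not something wired up here) is as a one-off
# bootstrap step before starting the server:
#   python -c "from app.main import download_model; download_model()"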


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)