File size: 6,039 Bytes
67dd542
110ce35
67dd542
 
11be554
67dd542
110ce35
c9cdb74
67dd542
 
 
 
 
 
c9cdb74
67dd542
 
 
 
 
c9cdb74
 
67dd542
c9cdb74
f1b3987
c9cdb74
67dd542
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1b3987
c9cdb74
 
11be554
67dd542
 
 
 
110ce35
 
67dd542
 
110ce35
67dd542
 
c9cdb74
110ce35
 
 
 
 
67dd542
110ce35
67dd542
110ce35
67dd542
 
 
 
 
 
 
 
110ce35
 
 
 
 
 
 
 
 
 
cee4b22
67dd542
 
 
 
 
 
 
 
110ce35
67dd542
 
110ce35
67dd542
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9cdb74
67dd542
 
 
 
110ce35
67dd542
 
c9cdb74
11be554
c9cdb74
67dd542
 
 
 
 
 
 
11be554
67dd542
 
11be554
 
 
 
67dd542
11be554
 
c9cdb74
11be554
 
c9cdb74
11be554
67dd542
 
 
 
11be554
 
c9cdb74
67dd542
 
 
 
 
110ce35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67dd542
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, ConfigDict, Field
from typing import Optional, List
from ctransformers import AutoModelForCausalLM
import time
import logging
from app.config import MODEL_PATH, MODEL_URL

# Configure logging for the whole process at INFO level; each record is
# prefixed with timestamp, logger name and level. `logger` is this module's
# own logger and is used by every function below.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Create the FastAPI application object that the route decorators below
# register their endpoints on; title/description/version are its metadata.
app = FastAPI(
    title="Poetry Generator API",
    description="An API for generating poetry using a local LLM",
    version="1.0.0"
)

# Module-level handle to the loaded LLM. It stays None until startup_event()
# stores the result of initialize_model() here; health_check() and
# generate_poem() both test it against None before relying on it.
model = None

class PoetryRequest(BaseModel):
    # Request body accepted by the /generate endpoint.

    # Required, non-empty subject of the poem; substituted into the prompt
    # template chosen by `style`.
    prompt: str = Field(
        ...,
        min_length=1,
        description="The topic or theme for the poem",
    )

    # Poem form. generate_poem() lower-cases it for the template lookup and
    # falls back to the free-verse template for values it does not know.
    style: str = Field(default="free verse", description="Style of the poem to generate")

    # Forwarded to the model as `max_new_tokens`; must lie in 50..500.
    max_length: int = Field(
        default=200, ge=50, le=500,
        description="Maximum length of the generated poem",
    )

    # Sampling temperature forwarded to the model; must lie in 0.1..2.0.
    temperature: float = Field(
        default=0.7, ge=0.1, le=2.0,
        description="Temperature for text generation",
    )

class PoetryResponse(BaseModel):
    # Response body returned by the /generate endpoint.
    poem: str  # model output with leading/trailing whitespace stripped
    generation_time: float  # elapsed seconds measured inside generate_poem()
    prompt: str  # the request's prompt, echoed back unchanged
    style: str  # the request's style, echoed back exactly as submitted

class ModelInfo(BaseModel):
    # Payload returned by the /health endpoint.
    # protected_namespaces is emptied so the `model_path` / `model_name`
    # fields below can use the `model_` prefix that pydantic reserves by
    # default.
    model_config = ConfigDict(protected_namespaces=())
    
    status: str  # "ready" or "not_loaded", as set by health_check()
    model_path: str  # string form of MODEL_PATH
    model_name: str  # human-readable model name reported by health_check()
    supported_styles: List[str]  # style names advertised to clients
    max_context_length: int  # context window size reported by health_check()

def initialize_model():
    """Load the local LLM located at ``MODEL_PATH``.

    Returns:
        The object produced by ``AutoModelForCausalLM.from_pretrained`` on
        success, or ``None`` when the model file does not exist or loading
        raises. Failures are logged and never propagated to the caller.
    """
    llm = None

    if MODEL_PATH.exists():
        try:
            logger.info(f"Loading model from {MODEL_PATH}")
            # The loader is given the containing directory and, separately,
            # the file name inside it.
            load_options = {
                "model_file": MODEL_PATH.name,
                "model_type": "llama",
                "max_new_tokens": 512,
                "context_length": 512,
                "gpu_layers": 0,  # CPU only
            }
            llm = AutoModelForCausalLM.from_pretrained(
                str(MODEL_PATH.parent),
                **load_options,
            )
        except Exception as e:
            # Any loader failure degrades to "no model" instead of crashing.
            logger.error(f"Error loading model: {str(e)}")
    else:
        logger.error(f"Model not found at {MODEL_PATH}")

    return llm

@app.on_event("startup")
async def startup_event():
    """Load the model once when the application starts.

    The result of ``initialize_model()`` is stored in the module-level
    ``model``. A failed load (``None``) only produces a warning, so the
    service still comes up and can report its state.
    """
    global model
    model = initialize_model()
    if model is not None:
        return
    logger.warning("Model failed to load but service will start anyway")

@app.get(
    "/health",
    response_model=ModelInfo,
    status_code=status.HTTP_200_OK,
    tags=["Health Check"]
)
async def health_check():
    """Check if the model is loaded and get basic information"""
    # The docstring above is left as-is because FastAPI publishes it as the
    # operation description; further notes live in comments.
    #
    # The endpoint always answers 200: whether the model is usable is
    # conveyed by the `status` field, not by the HTTP status code.
    if model is None:
        model_status = "not_loaded"
    else:
        model_status = "ready"

    # Style names advertised to clients.
    supported_styles = ["free verse", "haiku", "sonnet", "limerick", "tanka"]

    return ModelInfo(
        status=model_status,
        model_path=str(MODEL_PATH),
        model_name="Llama-2-7B-Chat",
        supported_styles=supported_styles,
        max_context_length=512,
    )

# Prompt scaffolding for each known style, keyed by lower-case style name.
# Built once at import time instead of on every request.
_PROMPT_TEMPLATES = {
    "haiku": "Write a haiku about {prompt}. Follow the 5-7-5 syllable pattern:\n\n",
    "sonnet": "Write a Shakespearean sonnet about {prompt}. Follow the traditional 14-line format with rhyme scheme ABAB CDCD EFEF GG:\n\n",
    "limerick": "Write a limerick about {prompt}. Follow the AABBA rhyme scheme:\n\n",
    "free verse": "Write a free verse poem about {prompt}. Make it creative and meaningful:\n\n",
    "tanka": "Write a tanka about {prompt}. Follow the 5-7-5-7-7 syllable pattern:\n\n",
}


@app.post(
    "/generate",
    response_model=PoetryResponse,
    status_code=status.HTTP_200_OK,
    tags=["Generation"]
)
async def generate_poem(request: PoetryRequest):
    """Generate a poem based on the provided prompt and parameters"""
    # The docstring above is left as-is because FastAPI publishes it as the
    # operation description; the detailed contract is documented here.
    #
    # Args:
    #     request: validated request body (prompt, style, max_length,
    #         temperature).
    # Returns:
    #     PoetryResponse with the stripped model output, the elapsed
    #     generation time in seconds, and the request's prompt/style echoed.
    # Raises:
    #     HTTPException 503 when no model is loaded.
    #     HTTPException 500 when prompt building or generation raises.
    if model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model not loaded. Please check /health endpoint for status."
        )

    try:
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or turn negative) when the system clock is adjusted.
        start_time = time.perf_counter()

        # Style lookup ignores case and surrounding whitespace; unknown
        # styles fall back to the free-verse template.
        style_key = request.style.strip().lower()
        template = _PROMPT_TEMPLATES.get(style_key, _PROMPT_TEMPLATES["free verse"])
        full_prompt = template.format(prompt=request.prompt)

        # NOTE(review): this is a synchronous call inside an async handler,
        # so the event loop is occupied for the whole generation. Confirm
        # whether the model object is safe to call from a worker thread
        # before offloading it.
        output = model(
            full_prompt,
            max_new_tokens=request.max_length,
            temperature=request.temperature,
            top_p=0.95,
            repeat_penalty=1.2
        )

        generation_time = time.perf_counter() - start_time

        return PoetryResponse(
            poem=output.strip(),
            generation_time=generation_time,
            prompt=request.prompt,
            style=request.style
        )

    except Exception as e:
        # logger.exception keeps the traceback, which a plain error message
        # would discard.
        logger.exception("Generation error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate poem: {str(e)}"
        ) from e
    
def download_model():
    """Download the model file to ``MODEL_PATH`` unless it already exists.

    The payload is streamed from ``MODEL_URL`` into a temporary ``.part``
    file next to the target and renamed to ``MODEL_PATH`` only after the
    transfer completes. An interrupted download therefore never leaves a
    truncated file that a later ``MODEL_PATH.exists()`` check would mistake
    for a usable model.

    Raises:
        Exception: Any error raised while requesting or writing the file is
            logged and re-raised; the partial file is removed first.
    """
    import requests
    from tqdm import tqdm

    if MODEL_PATH.exists():
        logger.info(f"Model already exists at {MODEL_PATH}")
        return

    logger.info(f"Downloading model to {MODEL_PATH}")
    partial_path = MODEL_PATH.with_name(MODEL_PATH.name + ".part")
    try:
        # Make sure the destination directory exists before opening the file.
        MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)

        # (connect, read) timeouts: without them a stalled server would
        # block this call forever. The context manager releases the
        # connection even when the transfer fails midway.
        with requests.get(MODEL_URL, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as file, tqdm(
                desc="Downloading",
                # Unknown length (missing header) -> let tqdm show a plain
                # counter instead of a bar with a total of zero.
                total=total_size or None,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                # 1 MiB chunks: far fewer iterations than 1 KiB chunks for a
                # multi-gigabyte file.
                for data in response.iter_content(chunk_size=1024 * 1024):
                    size = file.write(data)
                    pbar.update(size)

        # Publish the finished file under its final name in one step.
        partial_path.replace(MODEL_PATH)
        logger.info("Model downloaded successfully")
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        raise
    finally:
        # Runs on every exit path, including KeyboardInterrupt; after a
        # successful rename the partial file no longer exists.
        if partial_path.exists():
            partial_path.unlink()

# Script entry point: when this file is executed directly, serve the app with
# uvicorn. The app is referenced by the import string "app.main:app", bound to
# all interfaces on port 8000, with auto-reload enabled (reload=True suggests
# this launcher is meant for development use).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)