# SchoolSpiritAI / app.py
import os

# HF_HOME must be set before transformers / huggingface_hub are imported,
# otherwise the cache path is frozen to its default at import time.
os.environ["HF_HOME"] = "/data/.huggingface"

import re, time, datetime, threading, traceback, torch, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers.utils import logging as hf_logging

LOG_FILE = "/data/requests.log"
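
# Lightweight logger: print to stdout, then best-effort append to LOG_FILE
# (the /data volume may be absent on Spaces without persistent storage).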
def log(m):
    line = f"[{datetime.datetime.utcnow().strftime('%H:%M:%S.%f')[:-3]}] {m}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:  # missing or read-only /data; stdout already has the line
        pass

MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CTX_TOK, MAX_NEW, TEMP = 1800, 64, 0.6  # prompt token budget, max new tokens, sampling temperature
MAX_IN, RATE_N, RATE_T = 300, 5, 60  # max input chars; RATE_N requests per RATE_T seconds

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the friendly digital mascot of "
    "SchoolSpirit AI LLC, founded by Charles Norton in 2025. "
    "The company installs on‑prem AI chat mascots, fine‑tunes language models, "
    "and ships turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Reply in ≤ 4 sentences unless asked for detail.\n"
    "• No personal‑data collection; no medical/legal/financial advice.\n"
    "• If uncertain, say so and suggest contacting a human.\n"
    "• If you can’t answer, politely direct the user to [email protected].\n"
    "• Keep language age‑appropriate; avoid profanity, politics, mature themes."
)
WELCOME = "Hi there! I’m SchoolSpirit AI. How can I help?"

def strip(s):
    """Collapse runs of whitespace to single spaces and trim the ends."""
    return re.sub(r"\s+", " ", s.strip())

hf_logging.set_verbosity_error()
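
# Load tokenizer and model once at startup; keep any failure message so
# chat_fn can surface it to the user instead of crashing the Space.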
try:
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
        low_cpu_mem_usage=True,
    )
    MODEL_ERR = None
    log("Model loaded")
except Exception as e:
    MODEL_ERR = f"Model load error: {e}"
    log(MODEL_ERR + "\n" + traceback.format_exc())

VISITS = {}
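
# Sliding-window rate limiter: allow at most RATE_N requests per client IP
# in any RATE_T-second window.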
def allowed(ip):
    now = time.time()
    VISITS[ip] = [t for t in VISITS.get(ip, []) if now - t < RATE_T]
    if len(VISITS[ip]) >= RATE_N:
        return False
    VISITS[ip].append(now)
    return True

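
# Render the message list as a plain-text prompt, dropping the oldest
# user/assistant exchange until the encoded prompt fits within CTX_TOK tokens.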
def build_prompt(raw):
    def render(m):
        if m["role"] == "system":
            return m["content"]
        return f"{'User:' if m['role'] == 'user' else 'AI:'} {m['content']}"

    sys, convo = raw[0], raw[1:]
    while True:
        parts = [sys["content"]] + [render(m) for m in convo] + ["AI:"]
        if len(tok.encode("\n".join(parts), add_special_tokens=False)) <= CTX_TOK or len(convo) <= 2:
            return "\n".join(parts)
        convo = convo[2:]

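
# Streaming chat handler. As a generator, every exit path must yield its
# outputs; a bare return would leave Gradio with nothing to render.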
def chat_fn(user_msg, hist, state, request: gr.Request):
    ip = request.client.host if request else "anon"
    if not allowed(ip):
        hist.append((user_msg, "Rate limit exceeded — please wait a minute."))
        yield hist, state, ""
        return
    user_msg = strip(user_msg or "")
    if not user_msg:
        yield hist, state, ""
        return
    if len(user_msg) > MAX_IN:
        hist.append((user_msg, f"Input >{MAX_IN} chars."))
        yield hist, state, ""
        return
    if MODEL_ERR:
        hist.append((user_msg, MODEL_ERR))
        yield hist, state, ""
        return
    hist.append((user_msg, ""))
    state["raw"].append({"role": "user", "content": user_msg})
    prompt = build_prompt(state["raw"])
    ids = tok(prompt, return_tensors="pt").to(model.device).input_ids
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    # Run generation on a worker thread; the streamer feeds tokens back here.
    threading.Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=ids,
            max_new_tokens=MAX_NEW,
            do_sample=True,  # sampling must be enabled for temperature to apply
            temperature=TEMP,
            streamer=streamer,
        ),
    ).start()
    partial = ""
    for piece in streamer:
        partial += piece
        # Cut the reply short if the model starts writing the next turn itself.
        if "User:" in partial or "\nAI:" in partial:
            partial = re.split(r"(?:\n?User:|\n?AI:)", partial)[0].strip()
            break
        hist[-1] = (user_msg, partial)
        yield hist, state, ""
    reply = strip(partial)
    hist[-1] = (user_msg, reply)
    state["raw"].append({"role": "assistant", "content": reply})
    yield hist, state, ""

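
# UI: chat window plus hidden per-session state holding the raw message list.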
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    bot = gr.Chatbot(value=[("", WELCOME)], height=480)
    st = gr.State(
        {
            "raw": [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "assistant", "content": WELCOME},
            ]
        }
    )
    with gr.Row():
        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, lines=1, scale=4)
        send = gr.Button("Send", variant="primary")
    send.click(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])
    txt.submit(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])

demo.launch()