Spaces:
Running
Running
File size: 3,877 Bytes
3dbb28b de084a6 3dbb28b 5c67dfe 7b31b4e 3dbb28b de084a6 3dbb28b de084a6 3dbb28b 5c67dfe 3dbb28b 5c67dfe 3dbb28b a9e8285 de084a6 a9e8285 3dbb28b 9633067 3dbb28b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import base64
import json
import os
from pathlib import Path
import gradio as gr
import huggingface_hub
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
AdditionalOutputs,
ReplyOnPause,
Stream,
get_stt_model,
get_twilio_turn_credentials,
)
from gradio.utils import get_space
from pydantic import BaseModel
# Populate os.environ from a local .env file before any key is read below.
load_dotenv()

# Directory containing this script; index.html is resolved relative to it.
curr_dir = Path(__file__).parent

# LLM client routed through the SambaNova inference provider.
_sambanova_key = os.environ.get("SAMBANOVA_API_KEY")
client = huggingface_hub.InferenceClient(
    provider="sambanova",
    api_key=_sambanova_key,
)

# Speech-to-text model used to transcribe each user turn.
stt_model = get_stt_model()
def response(
    audio: tuple[int, np.ndarray],
    gradio_chatbot: list[dict] | None = None,
    conversation_state: list[dict] | None = None,
):
    """Transcribe one spoken user turn and answer it with the LLM.

    Generator handler wrapped by ``ReplyOnPause``. It yields
    ``AdditionalOutputs`` twice:

    1. after the user's turn has been recorded, so clients can show it
       immediately;
    2. after the assistant's reply has been appended.

    Args:
        audio: ``(sample_rate, samples)`` pair captured by the stream.
        gradio_chatbot: Messages rendered in the chatbot UI. The user turn is
            stored as a ``gr.Audio`` component. Mutated in place when a
            non-empty list is passed.
        conversation_state: Text-only history sent to the model as
            ``messages``. Mutated in place when a non-empty list is passed.

    Yields:
        ``AdditionalOutputs(gradio_chatbot, conversation_state)``.
    """
    gradio_chatbot = gradio_chatbot or []
    conversation_state = conversation_state or []

    text = stt_model.stt(audio)
    sample_rate, array = audio
    gradio_chatbot.append(
        {"role": "user", "content": gr.Audio((sample_rate, array.squeeze()))}
    )
    # Record the transcript *before* the first yield. The /outputs SSE route
    # pairs state[-1] with the newest chatbot entry. Yielding before this
    # append would pair the user's audio with the previous turn's message,
    # or with an empty state on the very first turn.
    conversation_state.append({"role": "user", "content": text})
    yield AdditionalOutputs(gradio_chatbot, conversation_state)

    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=conversation_state,  # type: ignore
        temperature=0.1,
        top_p=0.1,
    )
    # The completion message's content is typed Optional. Normalise None to ""
    # so the history stays valid for the next request, and for clients that
    # post it back against the `Message(content: str)` schema.
    assistant_message = {
        "role": "assistant",
        "content": completion.choices[0].message.content or "",
    }
    conversation_state.append(assistant_message)
    gradio_chatbot.append(assistant_message)
    yield AdditionalOutputs(gradio_chatbot, conversation_state)
# Shared components: the chatbot display and the text-only conversation history.
chatbot = gr.Chatbot(type="messages", value=[])
state = gr.State(value=[])


def _keep_latest_outputs(*outputs):
    """Return the third and fourth positional values passed by the stream.

    With two additional outputs these are presumably the newly yielded
    (chatbot, state) pair, following the current values. Confirm against the
    fastrtc ``additional_outputs_handler`` docs.
    """
    return outputs[2], outputs[3]


reply_handler = ReplyOnPause(response, input_sample_rate=16000)  # type: ignore

# Settings that only apply when hosted on a Space.
_on_space = get_space()
stream = Stream(
    reply_handler,
    mode="send",
    modality="audio",
    additional_inputs=[chatbot, state],
    additional_outputs=[chatbot, state],
    additional_outputs_handler=_keep_latest_outputs,
    concurrency_limit=20 if _on_space else None,
    rtc_configuration=get_twilio_turn_credentials() if _on_space else None,
)

app = FastAPI()
stream.mount(app)
# One chat message as posted by the client in an /input_hook request body.
# Comments (not a docstring) are used so the model's JSON schema is unchanged.
class Message(BaseModel):
    role: str  # speaker label; response() itself writes "user" / "assistant"
    content: str  # message text
# Request body for POST /input_hook; its fields are forwarded to
# stream.set_input(). Comments (not a docstring) keep the schema unchanged.
class InputData(BaseModel):
    webrtc_id: str  # identifies the WebRTC session whose inputs are replaced
    chatbot: list[Message]  # chatbot history (first additional input)
    # Conversation history (second additional input); presumably received by
    # response() as `conversation_state` and sent to the LLM. Verify.
    state: list[Message]
@app.get("/")
def _():
    """Serve ``index.html`` with the WebRTC configuration substituted in.

    The ``__RTC_CONFIGURATION__`` placeholder is replaced with the
    JSON-encoded Twilio TURN credentials when ``get_space()`` is truthy
    (i.e. hosted on a Hugging Face Space), and with ``null`` otherwise.

    Declared as a plain ``def`` rather than ``async def`` so FastAPI runs it
    in its threadpool. The credential lookup is a synchronous call
    (presumably a network request) and the template read is blocking file
    I/O. Inside an ``async def`` both would stall the event loop and every
    open ``/outputs`` stream.
    """
    rtc_config = get_twilio_turn_credentials() if get_space() else None
    # Decode explicitly as UTF-8 instead of relying on the locale's default
    # encoding, which differs across platforms.
    html_content = (curr_dir / "index.html").read_text(encoding="utf-8")
    html_content = html_content.replace(
        "__RTC_CONFIGURATION__", json.dumps(rtc_config)
    )
    return HTMLResponse(content=html_content)
@app.post("/input_hook")
async def _(data: InputData):
    """Replace the additional inputs of one WebRTC session.

    The posted chatbot history and conversation state are converted to plain
    dicts and handed to ``stream.set_input`` for ``data.webrtc_id``.
    Returns ``None``.
    """
    serialized = data.model_dump()
    history = serialized["chatbot"]
    conversation = serialized["state"]
    stream.set_input(data.webrtc_id, history, conversation)
def audio_to_base64(file_path):
    """Read *file_path* and return its bytes as a base64 ``data:`` URL.

    The MIME type is always labelled ``audio/wav``; the file's actual format
    is not inspected.
    """
    with open(file_path, "rb") as handle:
        raw = handle.read()
    audio_format = "wav"
    encoded_audio = base64.b64encode(raw).decode("utf-8")
    return f"data:audio/{audio_format};base64,{encoded_audio}"
@app.get("/outputs")
async def _(webrtc_id: str):
    """Stream a session's handler outputs to the browser as Server-Sent Events.

    Each ``output`` event's ``data`` is a JSON object with:

    - ``message``: the last entry of the conversation state, or ``None`` if
      the state is empty;
    - ``audio``: a base64 ``data:`` URL of the user's recording when the
      newest chatbot entry is a user turn, otherwise ``None``.
    """

    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            chatbot, state = output.args[0], output.args[1]
            newest = chatbot[-1] if chatbot else None

            audio = None
            if newest is not None and newest["role"] == "user":
                # User turns hold a gr.Audio component (see response()); its
                # .value is assumed to carry a cached file "path". Confirm.
                audio = audio_to_base64(newest["content"].value["path"])

            data = {
                # Guard empty lists so an IndexError cannot tear down the
                # SSE connection mid-stream.
                "message": state[-1] if state else None,
                "audio": audio,
            }
            yield f"event: output\ndata: {json.dumps(data)}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")
if __name__ == "__main__":
    # The MODE environment variable selects how the app is launched:
    #   "UI"    -> the stream's built-in Gradio UI on port 7860
    #   "PHONE" -> unsupported, raises
    #   other   -> the FastAPI app served by uvicorn on 0.0.0.0:7860
    # `os` is already imported at module level, so it is not re-imported here.
    launch_mode = os.getenv("MODE")
    if launch_mode == "PHONE":
        raise ValueError("Phone mode not supported")
    if launch_mode == "UI":
        stream.ui.launch(server_port=7860)
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)
|