justREE commited on
Commit
434afff
Β·
verified Β·
1 Parent(s): 0566324

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import io
4
+ import wave
5
+ import re
6
+ import streamlit as st
7
+ from transformers import pipeline, SpeechT5Processor, SpeechT5HifiGan
8
+ from datasets import load_dataset
9
+ from PIL import Image
10
+ import numpy as np
11
+ import torch
12
+
13
+ # ─────────────────────────────────────────────────────────────
14
+ # 1) LOAD PIPELINES
15
+ # ─────────────────────────────────────────────────────────────
16
+ @st.cache_resource(show_spinner=False)
17
+ def load_captioner():
18
+ return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base", device="cpu")
19
+
20
+ @st.cache_resource(show_spinner=False)
21
+ def load_story_generator():
22
+ return pipeline("text-generation", model="microsoft/Phi-4-mini-reasoning", device="cpu")
23
+
24
+ @st.cache_resource(show_spinner=False)
25
+ def load_tts_pipe():
26
+ processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
27
+ model = pipeline("text-to-speech", model="microsoft/speecht5_tts", device="cpu")
28
+ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
29
+ speaker_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
30
+ speaker_embedding = torch.tensor(speaker_dataset[7306]["xvector"]).unsqueeze(0)
31
+ return processor, model, vocoder, speaker_embedding
32
+
33
+ # ─────────────────────────────────────────────────────────────
34
+ # 2) PIPELINE FUNCTIONS
35
+ # ─────────────────────────────────────────────────────────────
36
+ def get_caption(image, captioner):
37
+ return captioner(image)[0]['generated_text']
38
+
39
+ def generate_story(caption, generator):
40
+ prompt = f"Write a short, magical story for children aged 3 to 10 based on this scene: {caption}. Keep it under 100 words."
41
+ outputs = generator(
42
+ prompt,
43
+ max_new_tokens=120,
44
+ temperature=0.8,
45
+ top_p=0.95,
46
+ do_sample=True
47
+ )
48
+ story = outputs[0]["generated_text"]
49
+ return clean_story_output(story, prompt)
50
+
51
+ def clean_story_output(story, prompt):
52
+ story = story[len(prompt):].strip() if story.startswith(prompt) else story
53
+ if "." in story:
54
+ story = story[: story.rfind(".") + 1]
55
+ return sentence_case(story)
56
+
57
+ def sentence_case(text):
58
+ parts = re.split(r'([.!?])', text)
59
+ out = []
60
+ for i in range(0, len(parts) - 1, 2):
61
+ sentence = parts[i].strip().capitalize()
62
+ out.append(f"{sentence}{parts[i + 1]}")
63
+ if len(parts) % 2:
64
+ last = parts[-1].strip().capitalize()
65
+ if last:
66
+ out.append(last)
67
+ return " ".join(out)
68
+
69
+ def convert_to_audio(text, processor, tts_pipe, vocoder, speaker_embedding):
70
+ inputs = processor(text=text, return_tensors="pt")
71
+ speech = tts_pipe.model.generate_speech(inputs["input_ids"], speaker_embedding, vocoder=vocoder)
72
+ pcm = (speech.numpy() * 32767).astype(np.int16)
73
+ buffer = io.BytesIO()
74
+ with wave.open(buffer, "wb") as wf:
75
+ wf.setnchannels(1)
76
+ wf.setsampwidth(2)
77
+ wf.setframerate(16000)
78
+ wf.writeframes(pcm.tobytes())
79
+ buffer.seek(0)
80
+ return buffer.read()
81
+
82
+ # ─────────────────────────────────────────────────────────────
83
+ # 3) STREAMLIT APP UI
84
+ # ─────────────────────────────────────────────────────────────
85
+ st.set_page_config(page_title="Magic Storyteller", layout="centered")
86
+ st.title("🧚 Magic Storyteller")
87
+ st.markdown("Upload an image to generate a magical story and hear it read aloud!")
88
+
89
+ uploaded = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
90
+ if uploaded:
91
+ image = Image.open(uploaded)
92
+ st.image(image, caption="Your uploaded image", use_column_width=True)
93
+
94
+ st.subheader("πŸ–ΌοΈ Step 1: Captioning")
95
+ captioner = load_captioner()
96
+ caption = get_caption(image, captioner)
97
+ st.markdown(f"**Caption:** {sentence_case(caption)}")
98
+
99
+ st.subheader("πŸ“– Step 2: Story Generation")
100
+ story_pipe = load_story_generator()
101
+ story = generate_story(caption, story_pipe)
102
+ st.write(story)
103
+
104
+ st.subheader("πŸ”Š Step 3: Listen to the Story")
105
+ processor, tts_pipe, vocoder, speaker_embedding = load_tts_pipe()
106
+ audio_bytes = convert_to_audio(story, processor, tts_pipe, vocoder, speaker_embedding)
107
+ st.audio(audio_bytes, format="audio/wav")
108
+ st.balloons()
109
+ else:
110
+ st.info("Please upload an image to begin.")