hskwon7 commited on
Commit
3c49d37
·
verified ·
1 Parent(s): 2a4516f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -3,6 +3,7 @@ from transformers import pipeline
3
  from PIL import Image
4
  import io
5
  from gtts import gTTS
 
6
 
7
  st.title("🖼️ → 📖 Image-to-Story Demo")
8
  st.write("Upload an image and watch as it’s captioned, turned into a short story, and even read aloud!")
@@ -18,19 +19,19 @@ def load_story_gen():
18
  captioner = load_captioner()
19
  story_gen = load_story_gen()
20
 
21
- # 1) Upload
22
  uploaded = st.file_uploader("Upload an image", type=["png","jpg","jpeg"], key="image")
23
  if uploaded:
24
  img = Image.open(uploaded)
25
  st.image(img, use_column_width=True)
26
 
27
- # 2) Caption
28
  if "caption" not in st.session_state:
29
  with st.spinner("Generating caption…"):
30
- st.session_state.caption = captioner(img)[0]["generated_text"]
 
31
  st.write("**Caption:**", st.session_state.caption)
32
 
33
- # 3) Story
34
  if "story" not in st.session_state:
35
  with st.spinner("Spinning up a story…"):
36
  out = story_gen(
@@ -43,7 +44,7 @@ if uploaded:
43
  st.session_state.story = out[0]["generated_text"]
44
  st.write("**Story:**", st.session_state.story)
45
 
46
- # 4) Pre-generate raw MP3 bytes
47
  if "audio_bytes" not in st.session_state:
48
  with st.spinner("Generating audio…"):
49
  tts = gTTS(text=st.session_state.story, lang="en")
@@ -51,7 +52,13 @@ if uploaded:
51
  tts.write_to_fp(buf)
52
  st.session_state.audio_bytes = buf.getvalue()
53
 
54
- # 5) Play on demand
55
  if st.button("🔊 Play Story Audio"):
56
- audio_buffer = io.BytesIO(st.session_state.audio_bytes)
57
- st.audio(audio_buffer, format="audio/mp3")
 
 
 
 
 
 
 
3
  from PIL import Image
4
  import io
5
  from gtts import gTTS
6
+ import tempfile
7
 
8
  st.title("🖼️ → 📖 Image-to-Story Demo")
9
  st.write("Upload an image and watch as it’s captioned, turned into a short story, and even read aloud!")
 
19
  captioner = load_captioner()
20
  story_gen = load_story_gen()
21
 
 
22
  uploaded = st.file_uploader("Upload an image", type=["png","jpg","jpeg"], key="image")
23
  if uploaded:
24
  img = Image.open(uploaded)
25
  st.image(img, use_column_width=True)
26
 
27
+ # Caption
28
  if "caption" not in st.session_state:
29
  with st.spinner("Generating caption…"):
30
+ caps = captioner(img)
31
+ st.session_state.caption = caps[0] if isinstance(caps, list) else caps
32
  st.write("**Caption:**", st.session_state.caption)
33
 
34
+ # Story
35
  if "story" not in st.session_state:
36
  with st.spinner("Spinning up a story…"):
37
  out = story_gen(
 
44
  st.session_state.story = out[0]["generated_text"]
45
  st.write("**Story:**", st.session_state.story)
46
 
47
+ # Prepare audio bytes once
48
  if "audio_bytes" not in st.session_state:
49
  with st.spinner("Generating audio…"):
50
  tts = gTTS(text=st.session_state.story, lang="en")
 
52
  tts.write_to_fp(buf)
53
  st.session_state.audio_bytes = buf.getvalue()
54
 
55
+ # Play button
56
  if st.button("🔊 Play Story Audio"):
57
+ # Write to a temp file
58
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
59
+ tmp.write(st.session_state.audio_bytes)
60
+ tmp.flush()
61
+ tmp_path = tmp.name
62
+ tmp.close()
63
+ # Stream it
64
+ st.audio(tmp_path, format="audio/mp3")