# conspectum / ui_transcribe.py
import streamlit as st
from streamlit_extras.stylable_container import stylable_container
import os
import time
import pathlib
from datetime import timedelta
import requests

# Disable Streamlit's file watcher before importing torch
# (works around a known watcher/torch module-scanning conflict)
os.environ['STREAMLIT_SERVER_ENABLE_FILE_WATCHER'] = 'false'

import whisper  # openai-whisper
import torch    # used to check for GPU availability
# from models.loader import load_model_sst
from transcriber import Transcription
import matplotlib.colors as mcolors

from utils import load_config, get_secret_api

# base URL of the remote transcription API, fetched at runtime
st.session_state.secret_api = get_secret_api()
# Known Whisper hallucination string on Russian audio ('Subtitles created by
# DimaTorzok'); stripped from the segment text below
trash_str = 'Субтитры создавал DimaTorzok'
st.title('🎙️ Step 2: Speech-to-Text (ASR/STT)')
# Check if audio path exists from previous step
if 'audio_path' not in st.session_state or not st.session_state['audio_path'] or not os.path.exists(st.session_state['audio_path']):
    st.warning('Audio file not found. Please go back to the "**📤 Upload**" page and process a video first.')
    st.stop()
# st.write(f'Audio file to process: `{os.path.basename(audio_path)}`')
st.write(f'Processing audio `{st.session_state.video_input_title}` from video input')
if 'start_time' not in st.session_state:
    st.session_state.start_time = 0

st.audio(st.session_state.audio_path, start_time=st.session_state.start_time)
#
# ==================================================================
#
col_model, col_config = st.columns(2)
# --- Model ---
# with col_model.expander('**MODEL**', expanded=True):
with col_model.container(border=True):
    model_option = st.selectbox(
        'STT Model:',
        ['whisper', 'faster-whisper', 'distill-whisper', 'giga'],
        index=0
    )
    # sst_model = load_model_sst(model_option)
# --- Configuration ---
with col_config.expander('**CONFIG**', expanded=True):
    # Determine compute device (prefer GPU when available)
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = st.radio(
        'Compute device:',
        ('cuda', 'cpu'),
        index=0 if default_device == 'cuda' else 1,
        horizontal=True,
        disabled=not torch.cuda.is_available()
    )
    if device == 'cuda' and not torch.cuda.is_available():
        st.warning('CUDA selected but not available, falling back to CPU')
        device = 'cpu'
    whisper_model_option = st.selectbox(
        'Whisper model type:',
        ['tiny', 'base', 'small', 'medium', 'large-v3', 'turbo'],
        index=5
    )
    pauses = st.checkbox('pauses', value=False, help='Mark pauses longer than 3 s in the transcript')
    # from models.models_sst import Whisper
    # Whisper.config()
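# NOTE: `whisper_model_option` is forwarded to the remote API as the `model` query
# parameter below; `device` currently only matters for the cleanup code in the error
# handler, since local inference is disabled in favor of the remote endpoint.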
##
## --- Transcription ---
##
_, col_button_transcribe, _ = st.columns([2, 1, 2])

if col_button_transcribe.button('Transcribe', type='primary', use_container_width=True):
    st.session_state.transcript = None  # clear previous transcript
    col_info, col_complete, col_next = st.columns(3)
    try:
        with st.spinner(f'Loading Whisper `{whisper_model_option}` model and transcribing..'):
            #-- Perform transcription via the remote API
            start = time.time()
            with open(st.session_state.audio_path, 'rb') as f:
                response = requests.post(
                    # f'{st.session_state.secret_api}/transcribe_faster_whisper',
                    f'{st.session_state.secret_api}/transcribe',
                    params={'model': whisper_model_option},
                    files={'file': f}
                )
            st.write(response)
            response = response.json()
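            # Assumed response schema (inferred from how `response['output']` is
            # consumed below; not an authoritative API contract):
            # {
            #   'output': {
            #     'language': 'ru',
            #     'segments': [{'start': 0.0, 'end': 4.2, 'text': '...',
            #                   'words': [{'word': '...', 'start': 0.0, 'end': 0.4,
            #                              'probability': 0.97}]}]
            #   },
            #   'inference_time': ...,  # extra metadata fields
            #   'model_name': ...
            # }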
            # Wrap the raw API output in a Transcription object so downstream steps
            # get the same interface as a locally produced transcript
            st.session_state.transcript = Transcription(st.session_state.audio_path)
            st.session_state.transcript.output = response['output']
            transcribe_time = time.time() - start
        col_complete.success(f'Transcription complete! (Took {transcribe_time:.2f}s)')
        col_next.page_link('ui_video.py', label='Next Step: **🖼️ Analyze Video**', icon='➡️')
    except Exception as e:
        st.error(f'An error occurred during transcription: {e}')
        # Unload any locally loaded model on error to free memory
        if 'model' in locals():
            del model
        if device == 'cuda':
            torch.cuda.empty_cache()
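# --- Local fallback (illustrative sketch, not wired into the UI) ---
# Assumption: the remote /transcribe endpoint mirrors openai-whisper's output; the
# `whisper` import above is otherwise unused, so the same structure could be produced
# locally. `transcribe_locally` is a hypothetical helper, not part of the original app.
def transcribe_locally(audio_path: str, model_name: str, device: str) -> dict:
    model = whisper.load_model(model_name, device=device)  # e.g. 'turbo' on 'cuda'/'cpu'
    # word_timestamps=True yields per-word 'start'/'end'/'probability' fields;
    # fp16 only helps on GPU
    return model.transcribe(audio_path, word_timestamps=True, fp16=(device == 'cuda'))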
if 'transcript' in st.session_state and st.session_state['transcript']:
    # --- Video Player ---
    with st.expander('**Video Player**', expanded=True):
        col_video, col_segments = st.columns(2)
        col_video.video(st.session_state.video_path, start_time=st.session_state.start_time)
    # --- Display Transcript ---
    prev_word_end = -1
    text = ''
    html_text = ''
    output = st.session_state.transcript.output
    # average word-level confidence over the whole transcript
    avg_confidence_score = 0
    amount_words = 0
    save_dir = str(pathlib.Path(__file__).parent.absolute()) + '/transcripts/'  # only used by the disabled .docx export below
    for segment in output['segments']:
        for w in segment['words']:
            amount_words += 1
            avg_confidence_score += w['probability']
    # Color map for word confidence: dark red (low) -> amber -> green (high)
    colors = [(0.6, 0, 0), (1, 0.7, 0), (0, 0.6, 0)]
    cmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', colors)
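    # Illustrative: cmap(p) maps a word probability in [0, 1] to RGBA, e.g.
    #   cmap(0.1) -> dark red, cmap(0.5) -> amber, cmap(0.95) -> green;
    # the renderer below rescales it to 0-255 RGB per word.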
    with st.expander('**TRANSCRIPT**', expanded=True):
        st.badge(
            f'whisper model: **`{whisper_model_option}`** | ' +
            f'language: **`{output["language"]}`** | ' +
            f'confidence score: **`{round(avg_confidence_score / amount_words, 3)}`**'
        )
        color_coding = st.checkbox(
            'color coding',
            value=True,
            help='Color-codes words by recognition confidence: green (high) to red (low)'
        )
        # scrollable region: https://docs.streamlit.io/develop/api-reference/layout/st.container
        with st.container(height=300, border=False):
            for segment in output['segments']:
                for w in segment['words']:
                    # mark pauses in speech longer than 3s with dots and a duration tag
                    if pauses and prev_word_end != -1 and w['start'] - prev_word_end >= 3:
                        pause_int = int(w['start'] - prev_word_end)
                        html_text += f'{"." * pause_int}{{{pause_int}sec}}'
                        text += f'{"." * pause_int}{{{pause_int}sec}}'
                    prev_word_end = w['end']
                    if color_coding:
                        rgba_color = cmap(w['probability'])
                        rgb_color = tuple(round(x * 255) for x in rgba_color[:3])
                    else:
                        rgb_color = (0, 0, 0)
                    html_text += f"<span style='color:rgb{rgb_color}'>{w['word']}</span>"
                    text += w['word']
                    # paragraph break after sentence-ending punctuation (skip numbers like '3.14')
                    if any(c in w['word'] for c in '!?.') and not any(c.isdigit() for c in w['word']):
                        html_text += '<br><br>'
                        text += '\n\n'
            st.markdown(html_text, unsafe_allow_html=True)
        # doc.add_paragraph(text)
        # if translation:
        #     with st.expander('English translation'):
        #         st.markdown(output['translation'], unsafe_allow_html=True)
        # # save transcript as .docx in a local folder
        # file_name = output['name'] + '-' + whisper_model_option + \
        #     '-' + datetime.today().strftime('%d-%m-%y') + '.docx'
        # doc.save(save_dir + file_name)
        # bio = io.BytesIO()
        # doc.save(bio)
        # st.download_button(
        #     label='Download Transcription',
        #     data=bio.getvalue(),
        #     file_name=file_name,
        #     mime='docx'
        # )
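        # Working plain-text export (illustrative sketch, not in the original app):
        # offers the `text` accumulated above; the file-name scheme is an assumption.
        st.download_button(
            label='Download Transcription (txt)',
            data=text,
            file_name=f'transcript-{whisper_model_option}.txt',
            mime='text/plain'
        )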
    # --- Display Segments with timestamps ---
    format_time = lambda s: str(timedelta(seconds=int(s)))
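    # e.g. format_time(75) == '0:01:15'; int() drops sub-second precision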
    # clickable timecode buttons, cf.
    # https://discuss.streamlit.io/t/replaying-an-audio-file-with-a-timecode-click/48892/9
    st.session_state['transcript_segments'] = ''
    with col_segments.container(height=400, border=False):
        # Style the timecode buttons as links
        with stylable_container(
            key='link_buttons',
            css_styles='''
                button {
                    background: none!important;
                    border: none;
                    padding: 0!important;
                    font-family: arial, sans-serif;
                    color: #069;
                    cursor: pointer;
                }
            ''',
        ):
            for i, segment in enumerate(st.session_state.transcript.output['segments']):
                start = format_time(segment['start'])
                end = format_time(segment['end'])
                text = segment['text'].strip()
                col_timecode, col_text = st.columns([1, 5], vertical_alignment='center')
                # clicking a timecode jumps the audio/video players to that segment;
                # key keeps widget IDs unique even if two segments share a timecode
                if col_timecode.button(f':violet-badge[**{start} - {end}**]', key=f'timecode_{i}', use_container_width=True):
                    # store raw seconds: unambiguous for st.audio/st.video start_time
                    st.session_state['start_time'] = int(segment['start'])
                    st.rerun()
                st.session_state.transcript_segments += f'[**{start} - {end}**] {text}\n'
                col_text.text(text)
    # str.replace returns a new string; assign the result back (was discarded before)
    if trash_str in st.session_state.transcript_segments:
        st.session_state.transcript_segments = st.session_state.transcript_segments.replace(trash_str, '')
# else:
#     st.info('Transcript has not been generated yet.')