import streamlit as st
from pathlib import Path
import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    pipeline,
    AutoModelForCausalLM,
    AutoTokenizer
)
from moviepy.editor import (
    VideoFileClip,
    TextClip,
    CompositeVideoClip,
    AudioFileClip,
    ColorClip,
    vfx,
    concatenate_videoclips
)
import requests
import json
from typing import Dict, List, Optional, Union
import tempfile
import os
from dotenv import load_dotenv
import time
from datetime import datetime
import nltk
from tqdm import tqdm
import pyttsx3
from gtts import gTTS
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import io
import random

# Set a timeout for model downloads
import socket
socket.setdefaulttimeout(30)  # 30-second timeout
# Reduce transformers logging noise and configure tokenizer parallelism
from transformers import logging
logging.set_verbosity_error()  # Reduce logging noise
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
# Download NLTK data at startup
try:
    nltk.data.find('tokenizers/punkt')
    nltk.data.find('taggers/averaged_perceptron_tagger')
except LookupError:
    nltk.download('punkt', quiet=True)
    nltk.download('averaged_perceptron_tagger', quiet=True)

# Load environment variables
load_dotenv()
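
# Note: the Pixabay key used by VideoGenerator below is read from the environment,
# so a local .env file only needs a single line of the form shown here (the value
# is illustrative, not a real key):
#   PIXABAY_API_KEY=your-key-here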

def create_progress_tracker():
    """Create a progress tracking system"""
    progress = st.progress(0)
    status = st.empty()

    def update(percentage: int, message: str):
        progress.progress(percentage)
        status.text(message)

    return update

# Initialize models at module level for caching
def load_models():
    """Load all AI models with better error handling and timeouts"""
    models = {}
    try:
        # GPT-2 (smaller version)
        models['gpt2_model'] = GPT2LMHeadModel.from_pretrained(
            'gpt2',
            low_cpu_mem_usage=True,
            torch_dtype=torch.float32
        )
        models['gpt2_tokenizer'] = GPT2Tokenizer.from_pretrained('gpt2')

        # Use the smaller BLOOM model
        models['bloom_model'] = AutoModelForCausalLM.from_pretrained(
            "bigscience/bloom-560m",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float32
        )
        models['bloom_tokenizer'] = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

        # Use a smaller sentiment model
        models['sentiment_pipeline'] = pipeline(
            'sentiment-analysis',
            model='distilbert-base-uncased-finetuned-sst-2-english',
            device=-1  # Force CPU usage
        )
        return models
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None
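
# A minimal sketch (not used by the app as written): wrapping the loader in
# Streamlit's resource cache so the models are loaded once per process instead of
# on every rerun. Assumes a Streamlit version that provides st.cache_resource.
@st.cache_resource
def load_models_cached():
    """Cached wrapper around load_models(); returns the same dict (or None on failure)."""
    return load_models()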

class ContentStyle:
    def __init__(
        self,
        font_size: int = 40,
        font_color: str = "#FFFFFF",
        background_effect: str = "None",
        transition_effect: str = "fade",
        text_animation: str = "slide",
        theme: str = "modern",
        layout: str = "centered"
    ):
        self.font_size = font_size
        self.font_color = font_color
        self.background_effect = background_effect
        self.transition_effect = transition_effect
        self.text_animation = text_animation
        self.theme = theme
        self.layout = layout

class AIContentEngine:
    def __init__(self, update_progress):
        self.update_progress = update_progress
        # Use only GPT-2 for faster generation
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        # Move the model to CPU explicitly
        self.model = self.model.to('cpu')
        # Set the model to evaluation mode
        self.model.eval()
    def generate_with_timeout(self, prompt: str, max_length: int = 100, timeout: int = 30) -> str:
        """Generate content quickly; the `timeout` argument is accepted but not currently enforced."""
        try:
            # Tokenize the prompt
            inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=max_length)
            # Generate with basic parameters for speed
            with torch.no_grad():  # Disable gradient calculation
                outputs = self.model.generate(
                    inputs['input_ids'],
                    max_length=max_length,
                    num_return_sequences=1,
                    temperature=0.7,
                    top_k=40,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    attention_mask=inputs['attention_mask']
                )
            return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        except Exception as e:
            return f"{prompt} [Error: {str(e)}]"
    def initialize_models(self):
        """Initialize models with retries"""
        max_retries = 3
        retry_count = 0
        while retry_count < max_retries:
            try:
                self.update_progress(5 + retry_count * 5, f"Loading AI models (attempt {retry_count + 1})...")
                self.models = load_models()
                if self.models is not None:
                    return
                retry_count += 1
                time.sleep(2)  # Wait before retrying
            except Exception as e:
                retry_count += 1
                if retry_count == max_retries:
                    raise Exception(f"Failed to load models after {max_retries} attempts: {str(e)}")
                time.sleep(2)
    def generate_dynamic_prompt(self, preferences: Dict) -> str:
        """Generate a context-aware prompt based on detailed user preferences"""
        prompt_template = f"""
        Create {preferences['style']} content about {preferences['topic']}
        with a {preferences['tone']} tone.
        Target audience: {preferences['audience']}.
        Purpose: {preferences['purpose']}.
        Key message: {preferences['message']}.
        Content format: {preferences.get('format', 'General')}.
        Include specific examples and actionable insights.
        Make it engaging and memorable.
        """
        return prompt_template.strip()
    def analyze_content_quality(self, content: str) -> float:
        """Advanced content quality analysis with multiple metrics"""
        score = 0.0
        words = content.split()
        if not words:
            return score

        # Sentiment analysis (only if a sentiment pipeline was loaded via initialize_models)
        sentiment_pipeline = (getattr(self, 'models', None) or {}).get('sentiment_pipeline')
        if sentiment_pipeline is not None:
            sentiment = sentiment_pipeline(content)[0]
            score += sentiment['score'] if sentiment['label'] == 'POSITIVE' else 0

        # Length analysis
        word_count = len(words)
        if 20 <= word_count <= 50:
            score += 0.3

        # Complexity and readability (tagged tokens are computed but not yet scored)
        tokens = nltk.word_tokenize(content)
        pos_tags = nltk.pos_tag(tokens)

        # Vocabulary diversity
        unique_words = len(set(words)) / len(words)
        score += unique_words * 0.3

        # Sentence structure variety
        sentences = nltk.sent_tokenize(content)
        if sentences:
            avg_sentence_length = np.mean([len(nltk.word_tokenize(sent)) for sent in sentences])
            if 10 <= avg_sentence_length <= 20:
                score += 0.2
        return score
    def generate_content_package(self, preferences: Dict) -> Dict[str, str]:
        """Generate a content package with faster processing"""
        content_package = {}
        try:
            # Simplified prompt templates for faster generation
            templates = {
                'main_content': f"Write a short post about {preferences['topic']} that is {preferences['tone']}.",
                'quote': f"A short quote about {preferences['topic']}:",
                'tips': f"Three quick tips about {preferences['topic']}:",
                'call_to_action': f"Call to action for {preferences['topic']}:",
                'hashtags': f"Trending hashtags for {preferences['topic']}:"
            }
            # Generate content with progress updates
            for i, (key, prompt) in enumerate(templates.items()):
                progress = 30 + (i * 10)  # Progress from 30% to 80%
                self.update_progress(progress, f"Generating {key.replace('_', ' ')}...")
                # Generate with a length cap
                content = self.generate_with_timeout(
                    prompt,
                    max_length=100 if key != 'main_content' else 200
                )
                content_package[key] = content
                # Add a small delay between generations
                time.sleep(0.1)
            return content_package
        except Exception as e:
            st.error(f"Error in content generation: {str(e)}")
            return content_package if content_package else None
    def generate_with_models(self, prompt: str, max_length: int = 100) -> str:
        """Generate content with fallback options (expects initialize_models() to have run)"""
        try:
            models = getattr(self, 'models', None) or {}
            # Try GPT-2 first
            if 'gpt2_model' in models:
                inputs = models['gpt2_tokenizer'](prompt, return_tensors='pt', truncation=True)
                outputs = models['gpt2_model'].generate(
                    inputs['input_ids'],
                    max_length=max_length,
                    num_return_sequences=1,
                    temperature=0.7,
                    top_k=50,
                    top_p=0.95,
                    do_sample=True
                )
                return models['gpt2_tokenizer'].decode(outputs[0], skip_special_tokens=True)
            # Fall back to BLOOM if GPT-2 is unavailable
            elif 'bloom_model' in models:
                inputs = models['bloom_tokenizer'](prompt, return_tensors="pt")
                outputs = models['bloom_model'].generate(
                    inputs['input_ids'],
                    max_length=max_length,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True
                )
                return models['bloom_tokenizer'].decode(outputs[0], skip_special_tokens=True)
            else:
                raise Exception("No language models available")
        except Exception as e:
            st.warning(f"Error in content generation: {str(e)}")
            # Return a basic response if all else fails
            return prompt + " [Content generation failed, please try again]"
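
# A minimal sketch (not wired into the app): enforcing a real wall-clock limit around
# generation, since generate_with_timeout() above does not act on its `timeout`
# argument. The worker thread may keep running past the deadline; only the caller
# gives up. The helper name below is illustrative, not part of the original app.
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout

def generate_with_hard_timeout(engine: AIContentEngine, prompt: str,
                               max_length: int = 100, timeout: int = 30) -> str:
    pool = ThreadPoolExecutor(max_workers=1)
    future = pool.submit(engine.generate_with_timeout, prompt, max_length)
    try:
        return future.result(timeout=timeout)
    except FutureTimeout:
        return f"{prompt} [Generation timed out after {timeout}s]"
    finally:
        pool.shutdown(wait=False)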

class VideoGenerator:
    def __init__(self, update_progress):
        self.update_progress = update_progress
        self.temp_dir = tempfile.mkdtemp()
        self.pixabay_api_key = os.getenv("PIXABAY_API_KEY")

    def fetch_background_video(self, query: str) -> Optional[str]:
        """Fetch a background video from Pixabay with error handling"""
        try:
            url = f"https://pixabay.com/api/videos/?key={self.pixabay_api_key}&q={query}"
            response = requests.get(url)
            data = response.json()
            if data.get("hits"):
                video = random.choice(data["hits"])
                video_url = video["videos"]["medium"]["url"]
                video_path = os.path.join(self.temp_dir, "background.mp4")
                response = requests.get(video_url)
                with open(video_path, "wb") as f:
                    f.write(response.content)
                return video_path
            return None
        except Exception as e:
            st.error(f"Error fetching video: {str(e)}")
            return None
    def generate_voiceover(self, text: str, voice_type: str = "gtts") -> Optional[str]:
        """Generate a voiceover with multiple options and error handling"""
        try:
            audio_path = os.path.join(self.temp_dir, "voiceover.mp3")
            if voice_type == "gtts":
                tts = gTTS(text=text, lang='en', tld='com')
                tts.save(audio_path)
            else:  # pyttsx3
                engine = pyttsx3.init()
                engine.save_to_file(text, audio_path)
                engine.runAndWait()
            return audio_path
        except Exception as e:
            st.error(f"Error generating voiceover: {str(e)}")
            return None
    def create_video(self, content: Dict[str, str], preferences: Dict, style: ContentStyle) -> Optional[str]:
        """Create a video with simpler processing"""
        try:
            # Basic video creation
            self.update_progress(85, "Creating video...")
            # Create a colored background instead of downloading a video
            background = ColorClip((1080, 1920), color=(0, 0, 128))  # Navy blue background
            background = background.set_duration(15)
            # Create text clips
            text_clips = []
            # Main content text
            main_text = TextClip(
                txt=content['main_content'][:200],  # Limit text length
                fontsize=40,
                color='white',
                size=(1000, None),
                method='label'
            ).set_duration(15)
            # Position the text in the center
            main_text = main_text.set_position('center')
            text_clips.append(main_text)
            # Compose the final video
            final = CompositeVideoClip([background] + text_clips)
            # Save the video
            self.update_progress(95, "Saving video...")
            output_path = os.path.join(self.temp_dir, f"output_{int(time.time())}.mp4")
            final.write_videofile(output_path, fps=24, codec='libx264', audio=False)
            return output_path
        except Exception as e:
            st.error(f"Error creating video: {str(e)}")
            return None
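
# A minimal sketch (not called by create_video(), which writes audio=False): attaching a
# generated voiceover to a finished clip with MoviePy 1.x. The helper name and signature
# are illustrative, not part of the original app.
def attach_voiceover(video_path: str, audio_path: str, output_path: str) -> str:
    video = VideoFileClip(video_path)
    audio = AudioFileClip(audio_path)
    # Trim the narration if it runs longer than the video.
    if audio.duration > video.duration:
        audio = audio.subclip(0, video.duration)
    video.set_audio(audio).write_videofile(output_path, fps=24, codec='libx264', audio_codec='aac')
    return output_path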

def main():
    st.set_page_config(page_title="Professional Content Generator", layout="wide")
    st.title("🎬 AI Content Generator")

    # Advanced user preferences
    with st.form("content_preferences"):
        col1, col2, col3 = st.columns(3)
        with col1:
            topic = st.text_input("Main Topic", help="What's your content about?")
            style = st.selectbox(
                "Content Style",
                ["Motivational", "Educational", "Business", "Personal Development",
                 "Technical", "Creative", "Storytelling", "News", "Tutorial"]
            )
            tone = st.selectbox(
                "Tone",
                ["Inspiring", "Professional", "Casual", "Intense", "Friendly",
                 "Authoritative", "Empathetic", "Humorous", "Serious"]
            )
        with col2:
            audience = st.selectbox(
                "Target Audience",
                ["Entrepreneurs", "Students", "Professionals", "General Public",
                 "Technical", "Creative", "Business Leaders", "Educators"]
            )
            purpose = st.selectbox(
                "Content Purpose",
                ["Inspire", "Educate", "Motivate", "Transform", "Inform",
                 "Entertain", "Persuade", "Guide", "Analyze"]
            )
            message = st.text_input("Key Message", help="Core message to convey")
        with col3:
            mood = st.selectbox(
                "Visual Mood",
                ["Energetic", "Calm", "Professional", "Creative", "Modern",
                 "Traditional", "Minimalist", "Bold", "Subtle"]
            )
            voice_type = st.selectbox(
                "Voice Type",
                ["Natural", "Professional", "Friendly", "Authoritative",
                 "Casual", "Energetic", "Calm", "Dynamic"]
            )

        # Advanced options in an expandable section
        with st.expander("Advanced Options"):
            col4, col5 = st.columns(2)
            with col4:
                font_size = st.slider("Font Size", 30, 70, 40)
                font_color = st.color_picker("Font Color", "#FFFFFF")
                text_animation = st.selectbox(
                    "Text Animation",
                    ["Fade", "Slide", "Static", "Bounce", "Zoom"]
                )
            with col5:
                background_effect = st.selectbox(
                    "Background Effect",
                    ["None", "Zoom", "Blur", "Bright", "Contrast"]
                )
                transition_effect = st.selectbox(
                    "Transition Effect",
                    ["None", "Fade", "Slide", "Dissolve", "Wipe"]
                )
                theme = st.selectbox(
                    "Visual Theme",
                    ["Modern", "Classic", "Minimal", "Bold", "Corporate", "Creative"]
                )
                layout = st.selectbox(
                    "Content Layout",
                    ["Centered", "Split", "Grid", "Dynamic", "Minimal"]
                )

        submit_button = st.form_submit_button("Generate Content")
    if submit_button:
        if not topic or not message:
            st.error("Please fill in at least the topic and key message fields.")
            return

        # Initialize the style configuration
        style_config = ContentStyle(
            font_size=font_size,
            font_color=font_color,
            background_effect=background_effect,
            transition_effect=transition_effect,
            text_animation=text_animation,
            theme=theme,
            layout=layout
        )

        # Collect preferences before trying to use them
        preferences = {
            'topic': topic,
            'style': style,
            'tone': tone,
            'audience': audience,
            'purpose': purpose,
            'message': message,
            'mood': mood,
            'voice_type': voice_type
        }

        try:
            # Create the progress tracker
            update_progress = create_progress_tracker()

            # Initialize engines with the progress tracker
            update_progress(5, "Initializing AI engine...")
            ai_engine = AIContentEngine(update_progress)
            update_progress(15, "Initializing video generator...")
            video_gen = VideoGenerator(update_progress)

            # Generate content
            update_progress(25, "Starting content generation...")
            content_package = ai_engine.generate_content_package(preferences)

            # Create the video
            update_progress(65, "Starting video creation...")
            video_path = video_gen.create_video(content_package, preferences, style_config)

            # Display the results
            update_progress(100, "Complete!")
            if video_path and os.path.exists(video_path):
                try:
                    with open(video_path, 'rb') as video_file:
                        video_bytes = video_file.read()
                    st.video(video_bytes)
                    # Download button
                    st.download_button(
                        label="Download Video",
                        data=video_bytes,
                        file_name=f"content_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp4",
                        mime="video/mp4"
                    )
                finally:
                    if os.path.exists(video_path):
                        os.remove(video_path)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()