Spaces:
Running
Running
Commit
·
94a5e86
1
Parent(s):
9affb79
moderation added
Browse files- .gitignore +3 -1
- app.py +139 -0
.gitignore
CHANGED
@@ -1,2 +1,4 @@
|
|
1 |
.env
|
2 |
-
venv/
|
|
|
|
|
|
1 |
.env
|
2 |
+
venv/
|
3 |
+
.gitattributes
|
4 |
+
__pycache__/
|
app.py
CHANGED
@@ -13,6 +13,8 @@ import mimetypes
|
|
13 |
from workflow_handler import WanVideoWorkflow
|
14 |
from video_config import MODEL_FRAME_RATES, calculate_frames
|
15 |
import asyncio
|
|
|
|
|
16 |
|
17 |
dotenv.load_dotenv()
|
18 |
|
@@ -137,6 +139,9 @@ supabase: Client = create_client(
|
|
137 |
|
138 |
)
|
139 |
|
|
|
|
|
|
|
140 |
def initialize_gcs():
|
141 |
"""Initialize Google Cloud Storage client with credentials from environment"""
|
142 |
try:
|
@@ -353,6 +358,122 @@ def poll_generation_status(generation_id):
|
|
353 |
print(f"Error polling generation status: {e}")
|
354 |
raise e
|
355 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
356 |
async def generate_video(input_image, subject, duration, selected_index, progress=gr.Progress()):
|
357 |
try:
|
358 |
# Initialize workflow handler with explicit paths
|
@@ -451,6 +572,24 @@ async def handle_generation(image_input, subject, duration, selected_index, prog
|
|
451 |
try:
|
452 |
if selected_index is None:
|
453 |
raise gr.Error("You must select a LoRA before proceeding.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
454 |
|
455 |
# Generate the video and get generation ID
|
456 |
generation_id = await generate_video(image_input, subject, duration, selected_index)
|
|
|
13 |
from workflow_handler import WanVideoWorkflow
|
14 |
from video_config import MODEL_FRAME_RATES, calculate_frames
|
15 |
import asyncio
|
16 |
+
from openai import OpenAI
|
17 |
+
import base64
|
18 |
|
19 |
dotenv.load_dotenv()
|
20 |
|
|
|
139 |
|
140 |
)
|
141 |
|
142 |
+
# Initialize OpenAI client
|
143 |
+
openai_client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
|
144 |
+
|
145 |
def initialize_gcs():
|
146 |
"""Initialize Google Cloud Storage client with credentials from environment"""
|
147 |
try:
|
|
|
358 |
print(f"Error polling generation status: {e}")
|
359 |
raise e
|
360 |
|
361 |
+
async def moderate_prompt(prompt: str) -> dict:
|
362 |
+
"""
|
363 |
+
Check if a text prompt contains NSFW content
|
364 |
+
"""
|
365 |
+
try:
|
366 |
+
response = openai_client.moderations.create(input=prompt)
|
367 |
+
result = response.results[0]
|
368 |
+
|
369 |
+
if result.flagged:
|
370 |
+
# Find which categories were flagged
|
371 |
+
flagged_categories = [
|
372 |
+
category for category, flagged in result.categories.model_dump().items()
|
373 |
+
if flagged
|
374 |
+
]
|
375 |
+
|
376 |
+
return {
|
377 |
+
"isNSFW": True,
|
378 |
+
"reason": f"Content flagged for: {', '.join(flagged_categories)}"
|
379 |
+
}
|
380 |
+
|
381 |
+
return {"isNSFW": False, "reason": None}
|
382 |
+
except Exception as e:
|
383 |
+
print(f"Error during prompt moderation: {e}")
|
384 |
+
return {"isNSFW": False, "reason": None}
|
385 |
+
|
386 |
+
async def moderate_image(image_path: str) -> dict:
|
387 |
+
"""
|
388 |
+
Check if an image contains NSFW content using OpenAI's vision capabilities
|
389 |
+
"""
|
390 |
+
try:
|
391 |
+
# Convert image to base64
|
392 |
+
with open(image_path, "rb") as image_file:
|
393 |
+
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
394 |
+
|
395 |
+
response = openai_client.chat.completions.create(
|
396 |
+
model="gpt-4o",
|
397 |
+
messages=[
|
398 |
+
{
|
399 |
+
"role": "system",
|
400 |
+
"content": "You are a content moderation system. Your task is to determine if an image contains NSFW content. Respond with only 'NSFW' if the image contains inappropriate sexual content such as nudity, pornography (especially child pornography), or other explicit material. Otherwise, respond with 'SAFE'."
|
401 |
+
},
|
402 |
+
{
|
403 |
+
"role": "user",
|
404 |
+
"content": [
|
405 |
+
{"type": "text", "text": "Is this image appropriate or does it contain NSFW content?"},
|
406 |
+
{"type": "image_url", "image_url": {
|
407 |
+
"url": f"data:image/jpeg;base64,{base64_image}"
|
408 |
+
}}
|
409 |
+
]
|
410 |
+
}
|
411 |
+
],
|
412 |
+
max_tokens=10
|
413 |
+
)
|
414 |
+
|
415 |
+
result = response.choices[0].message.content.strip()
|
416 |
+
|
417 |
+
if "NSFW" in result:
|
418 |
+
return {
|
419 |
+
"isNSFW": True,
|
420 |
+
"reason": "Image contains potentially inappropriate content"
|
421 |
+
}
|
422 |
+
|
423 |
+
return {"isNSFW": False, "reason": None}
|
424 |
+
except Exception as e:
|
425 |
+
print(f"Error during image moderation: {e}")
|
426 |
+
return {"isNSFW": False, "reason": None}
|
427 |
+
|
428 |
+
async def moderate_combined(prompt: str, image_path: str) -> dict:
|
429 |
+
"""
|
430 |
+
Check if both a prompt and image contain NSFW content together
|
431 |
+
"""
|
432 |
+
try:
|
433 |
+
# Convert image to base64
|
434 |
+
with open(image_path, "rb") as image_file:
|
435 |
+
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
436 |
+
|
437 |
+
response = openai_client.chat.completions.create(
|
438 |
+
model="gpt-4o",
|
439 |
+
messages=[
|
440 |
+
{
|
441 |
+
"role": "system",
|
442 |
+
"content": "You are a content moderation system. Your task is to determine if an image contains NSFW content. The image + text prompt combined will be generated into a video. Please assess if the generated video would be inappropriate in terms of sexual content only (pornography or child pornography). Respond with 'NSFW: reason' if inappropriate, or 'SAFE' if appropriate."
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"role": "user",
|
446 |
+
"content": [
|
447 |
+
{
|
448 |
+
"type": "text",
|
449 |
+
"text": f'Please moderate this image and prompt combination for an image-to-video generation:\n\nPrompt: "{prompt}"'
|
450 |
+
},
|
451 |
+
{
|
452 |
+
"type": "image_url",
|
453 |
+
"image_url": {
|
454 |
+
"url": f"data:image/jpeg;base64,{base64_image}"
|
455 |
+
}
|
456 |
+
}
|
457 |
+
]
|
458 |
+
}
|
459 |
+
],
|
460 |
+
max_tokens=150
|
461 |
+
)
|
462 |
+
|
463 |
+
result = response.choices[0].message.content.strip()
|
464 |
+
if result.startswith("NSFW:"):
|
465 |
+
return {
|
466 |
+
"isNSFW": True,
|
467 |
+
"reason": result[5:].strip()
|
468 |
+
}
|
469 |
+
return {
|
470 |
+
"isNSFW": False,
|
471 |
+
"reason": None
|
472 |
+
}
|
473 |
+
except Exception as e:
|
474 |
+
print(f"Error during combined moderation: {e}")
|
475 |
+
return {"isNSFW": False, "reason": None}
|
476 |
+
|
477 |
async def generate_video(input_image, subject, duration, selected_index, progress=gr.Progress()):
|
478 |
try:
|
479 |
# Initialize workflow handler with explicit paths
|
|
|
572 |
try:
|
573 |
if selected_index is None:
|
574 |
raise gr.Error("You must select a LoRA before proceeding.")
|
575 |
+
|
576 |
+
# First, moderate the prompt
|
577 |
+
prompt_moderation = await moderate_prompt(subject)
|
578 |
+
print(f"Prompt moderation result: {prompt_moderation}")
|
579 |
+
if prompt_moderation["isNSFW"]:
|
580 |
+
raise gr.Error(f"Content moderation failed: {prompt_moderation['reason']}")
|
581 |
+
|
582 |
+
# Then, moderate the image
|
583 |
+
image_moderation = await moderate_image(image_input)
|
584 |
+
print(f"Image moderation result: {image_moderation}")
|
585 |
+
if image_moderation["isNSFW"]:
|
586 |
+
raise gr.Error(f"Content moderation failed: {image_moderation['reason']}")
|
587 |
+
|
588 |
+
# Finally, check the combination
|
589 |
+
combined_moderation = await moderate_combined(subject, image_input)
|
590 |
+
print(f"Combined moderation result: {combined_moderation}")
|
591 |
+
if combined_moderation["isNSFW"]:
|
592 |
+
raise gr.Error(f"Content moderation failed: {combined_moderation['reason']}")
|
593 |
|
594 |
# Generate the video and get generation ID
|
595 |
generation_id = await generate_video(image_input, subject, duration, selected_index)
|