alex-remade committed on
Commit
94a5e86
·
1 Parent(s): 9affb79

moderation added

Browse files
Files changed (2) hide show
  1. .gitignore +3 -1
  2. app.py +139 -0
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  .env
2
- venv/
 
 
 
1
  .env
2
+ venv/
3
+ .gitattributes
4
+ __pycache__/
app.py CHANGED
@@ -13,6 +13,8 @@ import mimetypes
13
  from workflow_handler import WanVideoWorkflow
14
  from video_config import MODEL_FRAME_RATES, calculate_frames
15
  import asyncio
 
 
16
 
17
  dotenv.load_dotenv()
18
 
@@ -137,6 +139,9 @@ supabase: Client = create_client(
137
 
138
  )
139
 
 
 
 
140
  def initialize_gcs():
141
  """Initialize Google Cloud Storage client with credentials from environment"""
142
  try:
@@ -353,6 +358,122 @@ def poll_generation_status(generation_id):
353
  print(f"Error polling generation status: {e}")
354
  raise e
355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  async def generate_video(input_image, subject, duration, selected_index, progress=gr.Progress()):
357
  try:
358
  # Initialize workflow handler with explicit paths
@@ -451,6 +572,24 @@ async def handle_generation(image_input, subject, duration, selected_index, prog
451
  try:
452
  if selected_index is None:
453
  raise gr.Error("You must select a LoRA before proceeding.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
 
455
  # Generate the video and get generation ID
456
  generation_id = await generate_video(image_input, subject, duration, selected_index)
 
13
  from workflow_handler import WanVideoWorkflow
14
  from video_config import MODEL_FRAME_RATES, calculate_frames
15
  import asyncio
16
+ from openai import OpenAI
17
+ import base64
18
 
19
  dotenv.load_dotenv()
20
 
 
139
 
140
  )
141
 
142
+ # Initialize OpenAI client
143
+ openai_client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
144
+
145
  def initialize_gcs():
146
  """Initialize Google Cloud Storage client with credentials from environment"""
147
  try:
 
358
  print(f"Error polling generation status: {e}")
359
  raise e
360
 
361
async def moderate_prompt(prompt: str) -> dict:
    """
    Check a text prompt against OpenAI's moderation endpoint.

    The prompt is reported as blocked when the endpoint flags it for any
    category, not only sexual content.

    Args:
        prompt: The user-supplied text to check.

    Returns:
        A dict with two keys: "isNSFW" (bool) and "reason" (a string listing
        the flagged category names, or None when nothing was flagged).

    Note:
        This check fails open: if the moderation request raises, the error is
        printed and the prompt is reported as not flagged.
    """
    try:
        # Nothing to moderate; skip the network round-trip entirely.
        if not prompt or not prompt.strip():
            return {"isNSFW": False, "reason": None}

        # The OpenAI client used here is synchronous. Run it in the default
        # executor so this coroutine does not block the event loop while the
        # HTTP request is in flight.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, lambda: openai_client.moderations.create(input=prompt)
        )
        result = response.results[0]

        if not result.flagged:
            return {"isNSFW": False, "reason": None}

        # Collect the names of the categories that were flagged.
        flagged_categories = [
            category
            for category, flagged in result.categories.model_dump().items()
            if flagged
        ]
        return {
            "isNSFW": True,
            "reason": f"Content flagged for: {', '.join(flagged_categories)}",
        }
    except Exception as e:
        # NOTE(review): deliberate best-effort -- a moderation outage lets the
        # prompt through. Confirm fail-open is the intended policy.
        print(f"Error during prompt moderation: {e}")
        return {"isNSFW": False, "reason": None}
385
+
386
async def moderate_image(image_path: str) -> dict:
    """
    Ask an OpenAI vision model whether an image contains NSFW content.

    The image file is read, base64-encoded and sent as a data URL together
    with a system prompt instructing the model to answer 'NSFW' or 'SAFE'.

    Args:
        image_path: Filesystem path of the image to check.

    Returns:
        A dict with two keys: "isNSFW" (bool) and "reason" (a fixed
        explanation string when flagged, otherwise None).

    Note:
        This check fails open: if the file cannot be read or the request
        raises, the error is printed and the image is reported as not flagged.
    """
    # Local import keeps this block self-contained.
    import mimetypes

    try:
        # Convert image to base64
        with open(image_path, "rb") as image_file:
            base64_image = base64.b64encode(image_file.read()).decode("utf-8")

        # Label the data URL with the file's real image type instead of
        # always claiming JPEG; fall back to JPEG when the type is unknown.
        mime_type, _ = mimetypes.guess_type(image_path)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"

        messages = [
            {
                "role": "system",
                "content": "You are a content moderation system. Your task is to determine if an image contains NSFW content. Respond with only 'NSFW' if the image contains inappropriate sexual content such as nudity, pornography (especially child pornography), or other explicit material. Otherwise, respond with 'SAFE'."
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Is this image appropriate or does it contain NSFW content?"},
                    {"type": "image_url", "image_url": {
                        "url": f"data:{mime_type};base64,{base64_image}"
                    }}
                ]
            }
        ]

        # The OpenAI client used here is synchronous. Run it in the default
        # executor so this coroutine does not block the event loop.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: openai_client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                max_tokens=10,
            ),
        )

        # content may be None (no text returned); treat that as an empty
        # answer. Compare case-insensitively so 'nsfw' is also recognised.
        # NOTE(review): a model refusal is currently treated as not flagged.
        verdict = (response.choices[0].message.content or "").strip().upper()

        if "NSFW" in verdict:
            return {
                "isNSFW": True,
                "reason": "Image contains potentially inappropriate content"
            }

        return {"isNSFW": False, "reason": None}
    except Exception as e:
        # NOTE(review): deliberate best-effort -- a moderation failure lets
        # the image through. Confirm fail-open is the intended policy.
        print(f"Error during image moderation: {e}")
        return {"isNSFW": False, "reason": None}
427
+
428
async def moderate_combined(prompt: str, image_path: str) -> dict:
    """
    Ask an OpenAI vision model whether a prompt and image are NSFW together.

    The image and the text prompt are sent in one request, with a system
    prompt asking whether a video generated from the pair would be sexually
    inappropriate. The model is told to answer 'NSFW: reason' or 'SAFE'.

    Args:
        prompt: The user-supplied text prompt.
        image_path: Filesystem path of the input image.

    Returns:
        A dict with two keys: "isNSFW" (bool) and "reason" (the model's
        explanation, or a generic message if it gave none; None when the
        pair is not flagged).

    Note:
        This check fails open: if the file cannot be read or the request
        raises, the error is printed and the pair is reported as not flagged.
    """
    # Local import keeps this block self-contained.
    import mimetypes

    try:
        # Convert image to base64
        with open(image_path, "rb") as image_file:
            base64_image = base64.b64encode(image_file.read()).decode("utf-8")

        # Label the data URL with the file's real image type instead of
        # always claiming JPEG; fall back to JPEG when the type is unknown.
        mime_type, _ = mimetypes.guess_type(image_path)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"

        messages = [
            {
                "role": "system",
                "content": "You are a content moderation system. Your task is to determine if an image contains NSFW content. The image + text prompt combined will be generated into a video. Please assess if the generated video would be inappropriate in terms of sexual content only (pornography or child pornography). Respond with 'NSFW: reason' if inappropriate, or 'SAFE' if appropriate."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f'Please moderate this image and prompt combination for an image-to-video generation:\n\nPrompt: "{prompt}"'
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{base64_image}"
                        }
                    }
                ]
            }
        ]

        # The OpenAI client used here is synchronous. Run it in the default
        # executor so this coroutine does not block the event loop.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: openai_client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                max_tokens=150,
            ),
        )

        # content may be None (no text returned); treat that as an empty answer.
        # NOTE(review): a model refusal is currently treated as not flagged.
        answer = (response.choices[0].message.content or "").strip()

        # Accept 'NSFW: reason', a bare 'NSFW', and lower-case variants, so a
        # slightly off-format verdict is not silently treated as safe.
        if answer.upper().startswith("NSFW"):
            reason = answer[len("NSFW"):].lstrip(" :-\t\r\n").strip()
            return {
                "isNSFW": True,
                # Keep the user-facing error meaningful when no reason is given.
                "reason": reason or "Image and prompt combination is potentially inappropriate"
            }

        return {"isNSFW": False, "reason": None}
    except Exception as e:
        # NOTE(review): deliberate best-effort -- a moderation failure lets
        # the request through. Confirm fail-open is the intended policy.
        print(f"Error during combined moderation: {e}")
        return {"isNSFW": False, "reason": None}
476
+
477
  async def generate_video(input_image, subject, duration, selected_index, progress=gr.Progress()):
478
  try:
479
  # Initialize workflow handler with explicit paths
 
572
  try:
573
  if selected_index is None:
574
  raise gr.Error("You must select a LoRA before proceeding.")
575
+
576
+ # First, moderate the prompt
577
+ prompt_moderation = await moderate_prompt(subject)
578
+ print(f"Prompt moderation result: {prompt_moderation}")
579
+ if prompt_moderation["isNSFW"]:
580
+ raise gr.Error(f"Content moderation failed: {prompt_moderation['reason']}")
581
+
582
+ # Then, moderate the image
583
+ image_moderation = await moderate_image(image_input)
584
+ print(f"Image moderation result: {image_moderation}")
585
+ if image_moderation["isNSFW"]:
586
+ raise gr.Error(f"Content moderation failed: {image_moderation['reason']}")
587
+
588
+ # Finally, check the combination
589
+ combined_moderation = await moderate_combined(subject, image_input)
590
+ print(f"Combined moderation result: {combined_moderation}")
591
+ if combined_moderation["isNSFW"]:
592
+ raise gr.Error(f"Content moderation failed: {combined_moderation['reason']}")
593
 
594
  # Generate the video and get generation ID
595
  generation_id = await generate_video(image_input, subject, duration, selected_index)