import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPVisionModel,
    CLIPImageProcessor
)
from peft import PeftModel
import gradio as gr
from PIL import Image
import os
from prompt_templates import PROMPT_TEMPLATES, RESPONSE_TEMPLATES, ANALYSIS_TEMPLATES
import numpy as np
import cv2
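
# NOTE: prompt_templates is a local module uploaded alongside this app. From its
# use below it is expected to provide PROMPT_TEMPLATES (a dict mapping a prompt
# type such as "basic" to a list of prompt strings), ANALYSIS_TEMPLATES (format
# strings filled in by VLMInference.analyze_image), and RESPONSE_TEMPLATES
# (imported here but not used in this file).
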
class VLMInference:
    def __init__(self):
        # Initialize vision model
        self.vision_model = CLIPVisionModel.from_pretrained(
            "openai/clip-vit-base-patch32",
            device_map="auto",
            torch_dtype=torch.float16
        )
        self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

        # Initialize language model
        self.language_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/phi-2",
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True
        )

        # Find the most recent trained model
        model_dirs = [d for d in os.listdir(".") if d.startswith("best_vlm_model")]
        if model_dirs:
            # Sort directories by trailing timestamp if available, otherwise by name
            def get_timestamp(d):
                try:
                    return int(d.split("_")[-1])
                except ValueError:
                    return 0  # Directories without a numeric timestamp

            latest_model = sorted(model_dirs, key=get_timestamp)[-1]
            model_path = latest_model
            print(f"Loading trained model from: {model_path}")
            # Load the trained LoRA weights on top of the base language model
            self.language_model = PeftModel.from_pretrained(
                self.language_model,
                model_path
            )
            print("Successfully loaded trained LoRA weights")
        else:
            print("No trained model found. Using base model.")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Projection layer mapping CLIP vision features to the phi-2 hidden size.
        # Note: this layer is freshly initialized here; only the LoRA adapter
        # weights are restored from disk, not trained projection weights.
        self.projection = torch.nn.Linear(
            self.vision_model.config.hidden_size,
            self.language_model.config.hidden_size
        ).half()

        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vision_model.to(self.device)
        self.language_model.to(self.device)
        self.projection.to(self.device)
    def analyze_image(self, image):
        """Compute simple color, brightness, edge, and regional statistics and
        return them as human-readable strings built from ANALYSIS_TEMPLATES."""
        # Convert the input to an 8-bit RGB numpy array
        if isinstance(image, torch.Tensor):
            # Tensor from the CLIP image processor: move to CPU, put channels
            # last, and undo the CLIP normalization before scaling back to 0-255.
            img_np = image.float().cpu().numpy().transpose(1, 2, 0)
            mean = np.array(self.image_processor.image_mean)
            std = np.array(self.image_processor.image_std)
            img_np = img_np * std + mean
            img_np = np.clip(img_np * 255, 0, 255).astype(np.uint8)
        else:
            # PIL Image (or array-like): convert directly, forcing RGB for PIL inputs
            if isinstance(image, Image.Image):
                image = image.convert("RGB")
            img_np = np.array(image)

        # Basic color statistics
        r, g, b = np.mean(img_np, axis=(0, 1))
        r_std, g_std, b_std = np.std(img_np, axis=(0, 1))

        # Convert to grayscale for brightness and edge analysis
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        brightness = np.mean(gray)
        contrast = np.std(gray)

        # Edge detection (percentage of pixels marked as edges by Canny)
        edges = cv2.Canny(gray, 100, 200)
        edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1]) * 100

        # Regional brightness analysis (image halves)
        h, w = gray.shape
        top = np.mean(gray[:h // 2, :])
        bottom = np.mean(gray[h // 2:, :])
        left = np.mean(gray[:, :w // 2])
        right = np.mean(gray[:, w // 2:])

        return {
            "color": ANALYSIS_TEMPLATES["color"].format(
                r=int(r), g=int(g), b=int(b),
                r_std=int(r_std), g_std=int(g_std), b_std=int(b_std)
            ),
            "brightness": ANALYSIS_TEMPLATES["brightness"].format(
                brightness=int(brightness),
                contrast=int(contrast)
            ),
            "edges": ANALYSIS_TEMPLATES["edges"].format(
                edge_density=int(edge_density)
            ),
            "regions": ANALYSIS_TEMPLATES["regions"].format(
                top=int(top), bottom=int(bottom),
                left=int(left), right=int(right)
            )
        }
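
    # The exact wording of ANALYSIS_TEMPLATES lives in prompt_templates.py and
    # is not shown here; each entry is assumed to be a format string over the
    # field names passed above, e.g. (illustrative only):
    #   "color": "Average color is RGB({r}, {g}, {b}) with variation ({r_std}, {g_std}, {b_std})."
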
    def process_image(self, image):
        # Process image with the CLIP preprocessor
        image = self.image_processor(image, return_tensors="pt")["pixel_values"][0].to(self.device)

        # Get vision features (mean-pooled CLIP hidden states projected to the LM hidden size)
        with torch.no_grad():
            vision_outputs = self.vision_model(image.unsqueeze(0))
            vision_features = vision_outputs.last_hidden_state.mean(dim=1)
            vision_features = self.projection(vision_features)

        return vision_features
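
    # Note: process_image above is a standalone helper and is not called by the
    # Gradio app below; generate_response repeats the same preprocessing and
    # feature-extraction steps inline.
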
    def generate_response(self, image, prompt_type, custom_prompt=None):
        # Process image with the CLIP preprocessor
        image = self.image_processor(image, return_tensors="pt")["pixel_values"][0].to(self.device)

        # Get vision features. In this script the projected CLIP features are
        # computed but not injected into the language model's inputs; generation
        # below is conditioned on the text prompt plus the numeric image analysis.
        with torch.no_grad():
            vision_outputs = self.vision_model(image.unsqueeze(0))
            vision_features = vision_outputs.last_hidden_state.mean(dim=1)
            vision_features = self.projection(vision_features)

        # Analyze image (the preprocessed tensor is denormalized inside analyze_image)
        analysis = self.analyze_image(image)

        # Format prompt based on type
        if custom_prompt:
            prompt = custom_prompt
        else:
            prompt = np.random.choice(PROMPT_TEMPLATES[prompt_type])

        # Format full prompt with analysis
        full_prompt = f"### Instruction: {prompt}\n\nImage Analysis:\n"
        for key, value in analysis.items():
            full_prompt += f"{value}\n"
        full_prompt += "\n### Response:"

        # Tokenize
        inputs = self.tokenizer(
            full_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.device)

        # Generate response
        with torch.no_grad():
            outputs = self.language_model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=256,  # cap new tokens rather than total length so long prompts still get a reply
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        # Decode and keep only the text after the response marker
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.split("### Response:")[-1].strip()

        return response, analysis
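
# Minimal usage sketch for VLMInference outside the Gradio UI, assuming a local
# image file named "sample.png" (the filename is illustrative only):
#
#   vlm = VLMInference()
#   img = Image.open("sample.png").convert("RGB")
#   text, stats = vlm.generate_response(img, prompt_type="basic")
#   print(text)
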
def create_interface():
    # Initialize model
    model = VLMInference()

    def process_image_and_prompt(image, prompt_type, custom_prompt):
        if image is None:
            return "Please upload an image or select one from the gallery first."
        try:
            response, analysis = model.generate_response(image, prompt_type, custom_prompt)
            # Format the output
            output = f"Response:\n{response}\n\nImage Analysis:\n"
            for key, value in analysis.items():
                output += f"{value}\n"
            return output
        except Exception as e:
            return f"Error: {str(e)}"

    # Load sample images from the enhanced CIFAR10 dataset
    sample_images = []
    sample_labels = []
    dataset_dir = "enhanced_cifar10_dataset"
    if os.path.exists(dataset_dir):
        for filename in os.listdir(dataset_dir):
            if filename.startswith("enhanced_cifar10_") and filename.endswith(".png"):
                class_name = filename.replace("enhanced_cifar10_", "").replace(".png", "")
                image_path = os.path.join(dataset_dir, filename)
                try:
                    # Load and verify the image
                    img = Image.open(image_path)
                    img.verify()  # Verify it's a valid image
                    sample_images.append(image_path)
                    sample_labels.append(class_name)
                except Exception as e:
                    print(f"Error loading image {image_path}: {str(e)}")

    # Create Gradio interface
    with gr.Blocks(title="Vision-Language Model Demo") as interface:
        gr.Markdown("# Vision-Language Model Demo")
        gr.Markdown("Select a sample image from the enhanced CIFAR10 dataset or upload your own image.")

        with gr.Row():
            with gr.Column():
                # Sample images gallery
                if sample_images:
                    gr.Markdown("### Sample Images from Enhanced CIFAR10 Dataset")
                    sample_gallery = gr.Gallery(
                        value=[(img, label) for img, label in zip(sample_images, sample_labels)],
                        label="Select a sample image",
                        columns=5,
                        height="auto",
                        object_fit="contain"
                    )
                else:
                    gr.Markdown("No sample images found in the enhanced CIFAR10 dataset.")

                # Image input
                image_input = gr.Image(type="pil", label="Upload Image")

                # Prompt selection
                prompt_type = gr.Dropdown(
                    choices=list(PROMPT_TEMPLATES.keys()),
                    value="basic",
                    label="Select Prompt Type"
                )
                custom_prompt = gr.Textbox(
                    label="Custom Prompt (optional)",
                    placeholder="Enter your own prompt here..."
                )
                submit_btn = gr.Button("Generate Response")

            with gr.Column():
                output_text = gr.Textbox(
                    label="Model Response and Analysis",
                    lines=15
                )

        # Click event for the sample gallery: clicking a thumbnail loads it into
        # the image input.
        if sample_images:
            def load_selected_image(evt: gr.SelectData):
                if evt.index < len(sample_images):
                    return Image.open(sample_images[evt.index])
                return None

            sample_gallery.select(
                fn=load_selected_image,
                inputs=[],
                outputs=[image_input]
            )

        submit_btn.click(
            fn=process_image_and_prompt,
            inputs=[image_input, prompt_type, custom_prompt],
            outputs=[output_text]
        )

    return interface
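
# Note: share=True below asks Gradio to create a temporary public share link in
# addition to the local server; drop it (or pass share=False) to keep the demo
# local-only.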
if __name__ == "__main__":
    interface = create_interface()
    interface.launch(share=True)