# internvit-fix-bfloat / no_flash_attn_test.py
# Uploaded by mknolan with huggingface_hub (commit 0dd9229, verified; 4.98 kB).
#
# Smoke test: load InternViT-6B without flash-attn and run one forward pass,
# triggered from a small Gradio UI.
import torch
import os
import sys
import traceback
import gradio as gr
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor
# Startup report: printed once when the script module is executed.
for banner_line in ("=" * 50, "INTERNVIT-6B MODEL LOADING TEST (NO FLASH-ATTN)", "=" * 50):
    print(banner_line)

# Interpreter / framework / accelerator summary.
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("CUDA is not available. This is a critical issue for model loading.")
else:
    print(f"CUDA version: {torch.version.cuda}")
    gpu_total = torch.cuda.device_count()
    print(f"GPU count: {gpu_total}")
    for gpu_index in range(gpu_total):
        print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

    # Memory figures in decimal gigabytes. Total capacity is read from device 0
    # only; allocated/reserved are queried without a device argument, i.e. for
    # PyTorch's current device.
    bytes_per_gb = 1e9
    device0_props = torch.cuda.get_device_properties(0)
    print(f"Total GPU memory: {device0_props.total_memory / bytes_per_gb:.2f} GB")
    print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / bytes_per_gb:.2f} GB")
    print(f"Reserved GPU memory: {torch.cuda.memory_reserved() / bytes_per_gb:.2f} GB")
# Create a function to load and test the model
def load_and_test_model():
try:
# Monkey patch to disable flash attention
import sys
import types
# Create a fake flash_attn module
flash_attn_module = types.ModuleType("flash_attn")
flash_attn_module.__version__ = "0.0.0-disabled"
sys.modules["flash_attn"] = flash_attn_module
print("\nNOTE: Created dummy flash_attn module to avoid dependency error")
print("This is just for testing basic model loading - some functionality may be disabled")
print("\nLoading model with bfloat16 precision and low_cpu_mem_usage=True...")
model = AutoModel.from_pretrained(
"OpenGVLab/InternViT-6B-224px",
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True)
if torch.cuda.is_available():
print("Moving model to CUDA...")
model = model.cuda()
model.eval()
print("βœ“ Model loaded successfully!")
# Now try to process a test image
print("\nLoading image processor...")
image_processor = CLIPImageProcessor.from_pretrained("OpenGVLab/InternViT-6B-224px")
print("βœ“ Image processor loaded successfully!")
# Create a simple test image
print("\nCreating test image...")
test_image = Image.new("RGB", (224, 224), color="red")
# Process the test image
print("Processing test image...")
pixel_values = image_processor(images=test_image, return_tensors="pt").pixel_values
# FIXED: Always convert to bfloat16 first, then optionally move to CUDA
print("Converting image tensor to bfloat16 to match model dtype...")
pixel_values = pixel_values.to(torch.bfloat16)
if torch.cuda.is_available():
print("Moving image tensor to CUDA...")
pixel_values = pixel_values.cuda()
# Get model parameters
params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {params:,}")
# Forward pass
print("Running forward pass...")
with torch.no_grad():
outputs = model(pixel_values)
print("βœ“ Forward pass successful!")
print(f"Output shape: {outputs.last_hidden_state.shape}")
return f"SUCCESS: Model loaded and test passed!\nParameters: {params:,}\nOutput shape: {outputs.last_hidden_state.shape}"
except Exception as e:
print(f"\n❌ ERROR: {str(e)}")
traceback.print_exc()
return f"FAILED: Error loading model or processing image\nError: {str(e)}"
# Create a simple Gradio interface
def create_interface():
with gr.Blocks(title="InternViT-6B Test") as demo:
gr.Markdown("# InternViT-6B Model Loading Test (without Flash Attention)")
gr.Markdown("### This version uses a dummy flash-attn implementation to avoid compilation issues")
with gr.Row():
test_btn = gr.Button("Test Model Loading")
output = gr.Textbox(label="Test Results", lines=10)
test_btn.click(fn=load_and_test_model, inputs=[], outputs=output)
return demo
# Main function
if __name__ == "__main__":
# Print environment variables
print("\nEnvironment variables:")
relevant_vars = ["CUDA_VISIBLE_DEVICES", "NVIDIA_VISIBLE_DEVICES",
"TRANSFORMERS_CACHE", "HF_HOME", "PYTORCH_CUDA_ALLOC_CONF"]
for var in relevant_vars:
print(f"{var}: {os.environ.get(var, 'Not set')}")
# Set environment variable for better GPU memory management
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
# Create and launch the interface
demo = create_interface()
demo.launch(share=False, server_name="0.0.0.0")