yunusajib committed on
Commit 971be40 · verified · 1 Parent(s): dd2c70a

update model and requirements

Files changed (2)
  1. llava_inference.py +82 -7
  2. requirements.txt +5 -0
llava_inference.py CHANGED
@@ -2,24 +2,99 @@ from llava.model.builder import load_pretrained_model
 from llava.mm_utils import process_images, tokenizer_image_token
 from transformers import AutoTokenizer
 import torch
+import requests
+from PIL import Image
+from io import BytesIO
 
 class LLaVAHelper:
     def __init__(self, model_name="llava-hf/llava-1.5-7b-hf"):
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, force_download=True)
-        self.model, self.image_processor, _ = load_pretrained_model(model_name, None)
+        # Use cache_dir to avoid issues with the default cache location
+        # and disable force_download to use cached versions when available
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            cache_dir="./model_cache",
+            force_download=False,
+            trust_remote_code=True
+        )
+
+        # Load model with same cache directory
+        self.model, self.image_processor, _ = load_pretrained_model(
+            model_name,
+            None,
+            cache_dir="./model_cache"
+        )
         self.model.eval()
+
+        # Move model to appropriate device
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model.to(self.device)
+        print(f"Model loaded on {self.device}")
 
     def generate_answer(self, image, question):
-        # Preprocess
-        image_tensor = process_images([image], self.image_processor, self.model.config)[0].unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
+        """
+        Generate a response to a question about an image
+
+        Args:
+            image: PIL Image or path to image
+            question: String question about the image
+
+        Returns:
+            String response from the model
+        """
+        # Handle image input (either PIL Image or path/URL)
+        if isinstance(image, str):
+            if image.startswith(('http://', 'https://')):
+                response = requests.get(image)
+                image = Image.open(BytesIO(response.content))
+            else:
+                image = Image.open(image)
+
+        # Preprocess image
+        image_tensor = process_images(
+            [image],
+            self.image_processor,
+            self.model.config
+        )[0].unsqueeze(0).to(self.device)
+
+        # Format prompt with question
         prompt = f"###Human: <image>\n{question}\n###Assistant:"
-        input_ids = tokenizer_image_token(prompt, self.tokenizer, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
-
+
+        # Tokenize prompt
+        input_ids = tokenizer_image_token(
+            prompt,
+            self.tokenizer,
+            return_tensors="pt"
+        ).to(self.device)
+
+        # Generate response
         with torch.no_grad():
             output_ids = self.model.generate(
                 input_ids=input_ids.input_ids,
                 images=image_tensor,
-                max_new_tokens=512
+                max_new_tokens=512,
+                do_sample=True,
+                temperature=0.7,
+                top_p=0.9,
             )
+
+        # Decode and extract response
         output = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
         return output.split("###Assistant:")[-1].strip()
+
+# Example usage
+if __name__ == "__main__":
+    try:
+        # Initialize model
+        llava = LLaVAHelper()
+
+        # Example with a local file
+        # response = llava.generate_answer("path/to/your/image.jpg", "What's in this image?")
+
+        # Example with a URL
+        # image_url = "https://example.com/image.jpg"
+        # response = llava.generate_answer(image_url, "Describe this image in detail.")
+
+        # print(response)
+        print("LLaVA model initialized successfully. Ready to process images.")
+    except Exception as e:
+        print(f"Error initializing LLaVA: {e}")
requirements.txt CHANGED
@@ -2,4 +2,9 @@ torch>=2.0.0
 transformers>=4.30.0
 accelerate>=0.20.0
 gradio>=3.35.0
+pillow>=9.0.0
+requests>=2.28.0
+tqdm>=4.65.0
+timm>=0.6.13
+sentencepiece>=0.1.97
 git+https://github.com/haotian-liu/LLaVA.git
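The added pins cover the helper's new imports (pillow, requests) and dependencies commonly pulled in by the LLaVA package (timm, sentencepiece). As an illustrative sketch, not part of the commit, the minimums can be checked against the local environment; the `packaging` import and the `Pillow` distribution name are assumptions about that environment.

```python
# Illustrative check that the minimums from requirements.txt are met locally.
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

# Minimum versions copied from requirements.txt
# (Pillow is the distribution name behind the `pillow` requirement).
minimums = {
    "torch": "2.0.0",
    "transformers": "4.30.0",
    "accelerate": "0.20.0",
    "gradio": "3.35.0",
    "Pillow": "9.0.0",
    "requests": "2.28.0",
    "tqdm": "4.65.0",
    "timm": "0.6.13",
    "sentencepiece": "0.1.97",
}

for name, minimum in minimums.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (requirement is >={minimum})")
        continue
    status = "ok" if Version(installed) >= Version(minimum) else f"below >={minimum}"
    print(f"{name}: {installed} ({status})")
```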