Upload folder using huggingface_hub
modeling_llavanext_for_embedding.py
CHANGED
@@ -257,3 +257,74 @@ class LLaVANextForEmbedding(LlavaNextForConditionalGeneration):
 
         return outputs
 
+    def set_processor(self, model_name):
+        self.processor = LlavaNextProcessor.from_pretrained(model_name)
+    def prepare_text_input(self, image=None, text=None, q_or_c=None, task_instruction=None):
+        task_instruction_example_cir = "Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: "
+
+        assert q_or_c in ["query", "candidate", "q", "c"]
+
+        if "q" in q_or_c:
+            if task_instruction is None:
+                text_input = "[INST] \n <instruct> <query>"
+                print(f"""
+                    Warning: For optimal performance, MMRet-MLLM requires the task instruction to be specified in the query.
+                    For example, for the composed image retrieval task, you might use a specific instruction like: {task_instruction_example_cir}.
+                    Instructions for other tasks can be referenced in the MMEB benchmark.
+                """)
+            elif task_instruction is not None:
+                text_input = f"[INST] \n <instruct> {task_instruction} <query> "
+
+            if text is not None:
+                text_input = f"{text_input} {text} \n"
+            if image is not None:
+                text_input = f"{text_input} <image>"
+
+            text_input = f"{text_input} [/INST]"
+        else:
+            text_input = "[INST] "
+            if text is not None:
+                text_input = f"{text_input} {text} \n"
+            if image is not None:
+                text_input = f"{text_input} <image>"
+            text_input = f"{text_input} [/INST]"
+
+        return text_input
+
+    def data_process(self, images=None, text=None, q_or_c=None, task_instruction=None):
+        if images is not None:
+            _is_list = isinstance(images, list)
+        elif text is not None:
+            _is_list = isinstance(text, list)
+        else:
+            raise ValueError("images and text cannot be both None.")
+
+        assert q_or_c in ["query", "candidate", "q", "c"]
+
+        if not _is_list:
+            text_input = self.prepare_text_input(images, text, q_or_c, task_instruction)
+            text_input = [text_input]
+
+            print(text_input)
+
+            if images is not None:
+                images = Image.open(images).resize((512, 512)).convert("RGB")
+                images = [images]
+                inputs = self.processor(images=images, text=text_input, return_tensors="pt", padding=True)
+            else:
+                inputs = self.processor(text=text_input, return_tensors="pt", padding=True)
+
+        else:
+            text_input = [self.prepare_text_input(_image, _text, q_or_c, task_instruction) for _image, _text in zip(images, text)]
+
+            print(text_input)
+
+            if images is not None:
+                images = [Image.open(_image).resize((512, 512)).convert("RGB") for _image in images]
+                inputs = self.processor(images=images, text=text_input, return_tensors="pt", padding=True)
+            else:
+                inputs = self.processor(text=text_input, return_tensors="pt", padding=True)
+
+        inputs = inputs.to(self.device)
+
+        return inputs
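For reference, a minimal usage sketch of the helpers added in this commit might look like the following. The checkpoint id and image path are placeholders, and the final forward call is an assumption based on the class inheriting from LlavaNextForConditionalGeneration; none of that is shown in this diff.

import torch
from transformers import AutoModel

# Placeholder repository id; replace with the actual MMRet-MLLM checkpoint.
MODEL_NAME = "<mmret-mllm-checkpoint>"

model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
model.set_processor(MODEL_NAME)   # helper added in this commit
model.eval()

with torch.no_grad():
    # Build a query batch from one local image plus a retrieval instruction.
    inputs = model.data_process(
        images="./query_image.jpg",   # placeholder path
        text="Make the background a snowy street.",
        q_or_c="query",
        task_instruction=(
            "Retrieve the target image that best meets the combined criteria by "
            "using both the provided image and the image retrieval instructions: "
        ),
    )
    # data_process already moves the batch to model.device; it can then
    # (presumably) be passed to the model's forward pass to get embeddings.
    outputs = model(**inputs)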