JUNJIE99 commited on
Commit
e319d30
·
verified ·
1 Parent(s): 844b6e2

Delete demo_test.py

Browse files
Files changed (1) hide show
  1. demo_test.py +0 -44
demo_test.py DELETED
@@ -1,44 +0,0 @@
1
- # from modeling_llavanext_for_embedding import LLaVANextForEmbedding
2
- # from transformers import LlavaNextProcessor
3
-
4
- # model = LLaVANextForEmbedding.from_pretrained("/share/junjie/code/VISTA2/240920mllmemb/llm_dense_retriever/MMRet-release/test").cuda()
5
- # processor = LlavaNextProcessor.from_pretrained("/share/junjie/code/VISTA2/240920mllmemb/llm_dense_retriever/MMRet-release/test")
6
-
7
- # texts = "find a image of a dog"
8
-
9
- # inputs = processor(texts, return_tensors="pt").to("cuda")
10
- # outputs = model(**inputs)
11
- # print(outputs)
12
-
13
-
14
- import torch
15
- from transformers import LlavaNextProcessor, AutoModel
16
-
17
- model = AutoModel.from_pretrained("/share/junjie/code/VISTA2/240920mllmemb/llm_dense_retriever/MMRet-release/MMRet-MLLM", trust_remote_code=True).cuda()
18
- model = model.eval()
19
- processor = LlavaNextProcessor.from_pretrained("/share/junjie/code/VISTA2/240920mllmemb/llm_dense_retriever/MMRet-release/MMRet-MLLM")
20
-
21
- texts = "[INST] \n <instruct> <query> find a image of a dog \n [/INST]"
22
-
23
- inputs = processor(texts, return_tensors="pt").to("cuda")
24
- outputs = model(**inputs)[:, -1, :]
25
- embeddings = torch.nn.functional.normalize(outputs, dim=-1)
26
-
27
- print(embeddings)
28
-
29
-
30
-
31
- from transformers import LlavaNextProcessor, AutoModel
32
- import torch
33
-
34
- model_name = "/share/junjie/code/VISTA2/240920mllmemb/llm_dense_retriever/MMRet-release/MMRet-MLLM"
35
- model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
36
- model = model.eval()
37
- model.set_processor(model_name)
38
- inputs = model.data_process(text="find a image of a dog", q_or_c="query")
39
-
40
- model_output = model(**inputs, output_hidden_states=True)
41
- embeddings = model_output[:, -1, :]
42
- embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
43
-
44
- print(embeddings)