thanhtantran committed on
Commit d8fa8a2 · verified · 1 Parent(s): 5bbe13d

Cloned from AITeamVN/Vietnamese_Embedding

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "word_embedding_dimension": 1024,
+   "pooling_mode_cls_token": true,
+   "pooling_mode_mean_tokens": false,
+   "pooling_mode_max_tokens": false,
+   "pooling_mode_mean_sqrt_len_tokens": false,
+   "pooling_mode_weightedmean_tokens": false,
+   "pooling_mode_lasttoken": false,
+   "include_prompt": true
+ }
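For readers skimming the config: `pooling_mode_cls_token: true` with every other mode disabled means the sentence embedding is simply the hidden state of the first (`<s>`/CLS) token, the same dense-retrieval setup BGE-M3 uses. A minimal illustration of that operation, with made-up tensor shapes rather than anything loaded from this repository:

```python
import torch

# Toy stand-in for the transformer output: (batch, seq_len, word_embedding_dimension)
token_embeddings = torch.randn(2, 128, 1024)

# CLS pooling: keep only the first token's hidden state per sentence
sentence_embeddings = token_embeddings[:, 0]  # shape: (2, 1024)
```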
README.md CHANGED
@@ -1,3 +1,90 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ language:
+ - vi
+ base_model:
+ - BAAI/bge-m3
+ pipeline_tag: sentence-similarity
+ library_name: sentence-transformers
+ tags:
+ - Embedding
+ ---
+
+ ## Model Card: Vietnamese_Embedding
+
+ Vietnamese_Embedding is an embedding model fine-tuned from the BGE-M3 model (https://huggingface.co/BAAI/bge-m3) to enhance retrieval capabilities for Vietnamese.
+
+ * The model was trained on approximately 300,000 triplets of Vietnamese queries, positive documents, and negative documents.
+ * The model was trained with a maximum sequence length of 2048.
+
+ ## Model Details
+
+ ### Model Description
+ - **Model Type:** Sentence Transformer
+ - **Base model:** [BAAI/bge-m3](https://huggingface.co/BAAI/bge-m3)
+ - **Maximum Sequence Length:** 2048 tokens
+ - **Output Dimensionality:** 1024 dimensions
+ - **Similarity Function:** Dot-product similarity
+ - **Language:** Vietnamese
+ - **License:** Apache 2.0
+
+ ## Usage
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ model = SentenceTransformer("AITeamVN/Vietnamese_Embedding")
+ model.max_seq_length = 2048
+ sentences_1 = ["Trí tuệ nhân tạo là gì", "Lợi ích của giấc ngủ"]
+ sentences_2 = ["Trí tuệ nhân tạo là công nghệ giúp máy móc suy nghĩ và học hỏi như con người. Nó hoạt động bằng cách thu thập dữ liệu, nhận diện mẫu và đưa ra quyết định.",
+                "Giấc ngủ giúp cơ thể và não bộ nghỉ ngơi, hồi phục năng lượng và cải thiện trí nhớ. Ngủ đủ giấc giúp tinh thần tỉnh táo và làm việc hiệu quả hơn."]
+ query_embeddings = model.encode(sentences_1)
+ doc_embeddings = model.encode(sentences_2)
+ similarity = query_embeddings @ doc_embeddings.T
+ print(similarity)
+
+ '''
+ array([[0.66212064, 0.33066642],
+        [0.25866613, 0.5865289 ]], dtype=float32)
+ '''
+ ```
+
+ ### Evaluation
+
+ - Dataset: the entire training dataset of Legal Zalo 2021. Our model was not trained on this dataset.
+
+ | Model | Accuracy@1 | Accuracy@3 | Accuracy@5 | Accuracy@10 | MRR@10 |
+ |-------|------------|------------|------------|-------------|--------|
+ | Vietnamese_Reranker (Phase 2) | 0.7944 | 0.9324 | 0.9537 | 0.9740 | 0.8672 |
+ | Vietnamese_Embedding (Phase 2) | 0.7262 | 0.8927 | 0.9268 | 0.9578 | 0.8149 |
+ | Vietnamese_Embedding (public) | 0.7274 | 0.8992 | 0.9305 | 0.9568 | 0.8181 |
+ | Vietnamese-bi-encoder (BKAI) | 0.7109 | 0.8680 | 0.9014 | 0.9299 | 0.7951 |
+ | BGE-M3 | 0.5682 | 0.7728 | 0.8382 | 0.8921 | 0.6822 |
+
+ Vietnamese_Reranker (Phase 2) and Vietnamese_Embedding (Phase 2) were trained on 1,100,000 triplets.
+
+ Although Vietnamese_Embedding (Phase 2) scores slightly lower on the legal domain, its much larger training set makes it stronger on other domains.
+
+ You can reproduce the evaluation results by running `python evaluation_model.py` (data downloaded from Kaggle).
+
+ ## Contact
+
+ **Developer**
+
+ Members: Nguyễn Nho Trung, Nguyễn Nhật Quang
+
+ ## Citation
+
+ ```bibtex
+ @misc{Vietnamese_Embedding,
+   title={Vietnamese_Embedding: Embedding model in Vietnamese language.},
+   author={Nguyen Nho Trung and Nguyen Nhat Quang},
+   year={2025},
+   publisher={Huggingface},
+ }
+ ```
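For reference, the Accuracy@k and MRR@10 columns above follow the standard retrieval definitions, which is also how `calculate_metrics` in `evaluation_model.py` (added below) computes them: with $N$ queries and $\mathrm{rank}_i$ the best (smallest) rank at which a relevant article is retrieved for query $i$,

$$
\mathrm{Accuracy@}k = \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\!\left[\mathrm{rank}_i \le k\right],
\qquad
\mathrm{MRR@}k = \frac{1}{N} \sum_{i=1}^{N} \frac{\mathbb{1}\!\left[\mathrm{rank}_i \le k\right]}{\mathrm{rank}_i}.
$$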
config.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "_name_or_path": "/AITeamVN/bge_vi_2048",
+   "architectures": [
+     "XLMRobertaModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 8194,
+   "model_type": "xlm-roberta",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "output_past": true,
+   "pad_token_id": 1,
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.49.0",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 250002
+ }
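The backbone described by this config is a standard XLM-RoBERTa large encoder (24 layers, hidden size 1024, vocabulary of 250,002, positions up to 8194). A quick, optional sanity check with the `transformers` API, assuming the repo id from the README (the intended entry point remains the sentence-transformers snippet above):

```python
from transformers import AutoConfig

# Inspect the backbone configuration without downloading the weights
config = AutoConfig.from_pretrained("AITeamVN/Vietnamese_Embedding")
print(config.model_type, config.num_hidden_layers, config.hidden_size)  # xlm-roberta 24 1024
```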
config_sentence_transformers.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "__version__": {
+     "sentence_transformers": "2.6.1",
+     "transformers": "4.49.0",
+     "pytorch": "2.6.0+cu124"
+   },
+   "prompts": {},
+   "default_prompt_name": null
+ }
evaluation_model.py ADDED
@@ -0,0 +1,191 @@
+ import numpy as np
+ import torch
+ import json
+ import pandas as pd
+ from tqdm import tqdm
+ from typing import List, Dict, Tuple, Set, Union, Optional
+ from langchain.docstore.document import Document
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.vectorstores.faiss import DistanceStrategy
+ from langchain_core.embeddings.embeddings import Embeddings
+ from FlagEmbedding import BGEM3FlagModel
+
+ def setup_gpu_info() -> None:
+     print(f"Number of available GPUs: {torch.cuda.device_count()}")
+     print(f"Current GPU: {torch.cuda.current_device()}")
+     print(f"GPU name: {torch.cuda.get_device_name(0)}")
+
+ def load_model(model_name: str, use_fp16: bool = False) -> BGEM3FlagModel:
+     return BGEM3FlagModel(model_name, use_fp16=use_fp16)
+
+ def load_json_file(file_path: str) -> dict:
+     with open(file_path, 'r', encoding='utf-8') as f:
+         return json.load(f)
+
+ def load_jsonl_file(file_path: str) -> List[Dict]:
+     corpus = []
+     with open(file_path, "r", encoding="utf-8") as file:
+         for line in file:
+             data = json.loads(line.strip())
+             corpus.append(data)
+     return corpus
+
+ def extract_corpus_from_legal_documents(legal_data: dict) -> List[Dict]:
+     corpus = []
+     for document in legal_data:
+         for article in document['articles']:
+             chunk = {
+                 "law_id": document['law_id'],
+                 "article_id": article['article_id'],
+                 "title": article['title'],
+                 "text": article['title'] + '\n' + article['text']
+             }
+             corpus.append(chunk)
+     return corpus
+
+ def convert_corpus_to_documents(corpus: List[Dict[str, str]]) -> List[Document]:
+     documents = []
+     for i in tqdm(range(len(corpus)), desc="Converting corpus to documents"):
+         context = corpus[i]['text']
+         metadata = {
+             'law_id': corpus[i]['law_id'],
+             'article_id': corpus[i]['article_id'],
+             'title': corpus[i]['title']
+         }
+         documents.append(Document(page_content=context, metadata=metadata))
+     return documents
+
+ class CustomEmbedding(Embeddings):
+     """Custom embedding class that uses the BGEM3FlagModel."""
+
+     def __init__(self, model: BGEM3FlagModel, batch_size: int = 1):
+         self.model = model
+         self.batch_size = batch_size
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         embeddings = []
+         for i in tqdm(range(0, len(texts), self.batch_size), desc="Embedding documents"):
+             batch_texts = texts[i:i+self.batch_size]
+             batch_embeddings = self._get_batch_embeddings(batch_texts)
+             embeddings.extend(batch_embeddings)
+             torch.cuda.empty_cache()
+         return np.vstack(embeddings)
+
+     def embed_query(self, text: str) -> List[float]:
+         embedding = self.model.encode(text, max_length=256)['dense_vecs']
+         return embedding
+
+     def _get_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
+         with torch.no_grad():
+             outputs = self.model.encode(texts, batch_size=self.batch_size, max_length=2048)['dense_vecs']
+             batch_embeddings = outputs
+             del outputs
+         return batch_embeddings
+
+
+ class VectorDB:
+     """Vector database for document retrieval."""
+
+     def __init__(
+         self,
+         documents: List[Document],
+         embedding: Embeddings,
+         vector_db=FAISS,
+         index_path: Optional[str] = None
+     ) -> None:
+         self.vector_db = vector_db
+         self.embedding = embedding
+         self.index_path = index_path
+         self.db = self._build_db(documents)
+
+     def _build_db(self, documents: List[Document]):
+         if self.index_path:
+             db = self.vector_db.load_local(
+                 self.index_path,
+                 self.embedding,
+                 allow_dangerous_deserialization=True
+             )
+         else:
+             db = self.vector_db.from_documents(
+                 documents=documents,
+                 embedding=self.embedding,
+                 distance_strategy=DistanceStrategy.DOT_PRODUCT
+             )
+         return db
+
+     def get_retriever(self, search_type: str = "similarity", search_kwargs: dict = {"k": 10}):
+         retriever = self.db.as_retriever(search_type=search_type, search_kwargs=search_kwargs)
+         return retriever
+
+     def save_local(self, folder_path: str) -> None:
+         self.db.save_local(folder_path)
+
+
+ def process_sample(sample: dict, retriever) -> List[int]:
+     # Returns the 1-based rank of each relevant article in the retrieved list (0 if not retrieved).
+     question = sample['question']
+     docs = retriever.invoke(question)
+     retrieved_article_full_ids = [
+         docs[i].metadata['law_id'] + "#" + docs[i].metadata['article_id']
+         for i in range(len(docs))
+     ]
+     indexes = []
+     for article in sample['relevant_articles']:
+         article_full_id = article['law_id'] + "#" + article['article_id']
+         if article_full_id in retrieved_article_full_ids:
+             idx = retrieved_article_full_ids.index(article_full_id) + 1
+             indexes.append(idx)
+         else:
+             indexes.append(0)
+     return indexes
+
+ def calculate_metrics(all_indexes: List[List[int]], num_samples: int, selected_keys: Set[str]) -> Dict[str, float]:
+     count = [len(indexes) for indexes in all_indexes]
+     result = {}
+
+     for thres in [1, 3, 5, 10, 100]:
+         found = [[y for y in x if 0 < y <= thres] for x in all_indexes]
+         found_count = [len(x) for x in found]
+         acc = sum(1 for i in range(num_samples) if found_count[i] > 0) / num_samples
+         rec = sum(found_count[i] / count[i] for i in range(num_samples)) / num_samples
+         pre = sum(found_count[i] / thres for i in range(num_samples)) / num_samples
+         mrr = sum(1 / min(x) if x else 0 for x in found) / num_samples
+
+         if f"Accuracy@{thres}" in selected_keys:
+             result[f"Accuracy@{thres}"] = acc
+         if f"MRR@{thres}" in selected_keys:
+             result[f"MRR@{thres}"] = mrr
+
+     return result
+
+
+ def save_results(result: Dict[str, float], output_path: str) -> None:
+     with open(output_path, "w", encoding="utf-8") as f:
+         json.dump(result, f, indent=4, ensure_ascii=False)
+     print(f"Results saved to {output_path}")
+
+
+ def main():
+     setup_gpu_info()
+     model = load_model('AITeamVN/Vietnamese_Embedding', use_fp16=False)
+     samples = load_json_file('zalo_kaggle/train_question_answer.json')['items']
+     legal_data = load_json_file('zalo_kaggle/legal_corpus.json')
+
+     corpus = extract_corpus_from_legal_documents(legal_data)
+     documents = convert_corpus_to_documents(corpus)
+     embedding = CustomEmbedding(model, batch_size=1)  # batch_size can be raised to speed up indexing if GPU memory allows
+     vectordb = VectorDB(
+         documents=documents,
+         embedding=embedding,
+         vector_db=FAISS,
+         index_path=None
+     )
+     retriever = vectordb.get_retriever(search_type="similarity", search_kwargs={"k": 100})
+     all_indexes = []
+     for sample in tqdm(samples, desc="Processing samples"):
+         all_indexes.append(process_sample(sample, retriever))
+     selected_keys = {"Accuracy@1", "Accuracy@3", "Accuracy@5", "Accuracy@10", "MRR@10", "Accuracy@100"}
+     result = calculate_metrics(all_indexes, len(samples), selected_keys)
+     print(result)
+     save_results(result, "zalo_kaggle/Vietnamese_Embedding.json")
+
+ if __name__ == "__main__":
+     main()
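As `main()` above shows, the script expects the Zalo legal retrieval data under `zalo_kaggle/` (`train_question_answer.json` and `legal_corpus.json`, downloaded from Kaggle), builds a FAISS index over the article corpus, retrieves the top 100 articles per question, and writes the selected metrics to `zalo_kaggle/Vietnamese_Embedding.json`. It is run directly as `python evaluation_model.py`, as noted in the README.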
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f2debafdf03659e8273022a3e902b94deec73cd20c2b7262ab7e21630163f6d
+ size 2271064456
modules.json ADDED
@@ -0,0 +1,20 @@
+ [
+   {
+     "idx": 0,
+     "name": "0",
+     "path": "",
+     "type": "sentence_transformers.models.Transformer"
+   },
+   {
+     "idx": 1,
+     "name": "1",
+     "path": "1_Pooling",
+     "type": "sentence_transformers.models.Pooling"
+   },
+   {
+     "idx": 2,
+     "name": "2",
+     "path": "2_Normalize",
+     "type": "sentence_transformers.models.Normalize"
+   }
+ ]
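These three modules are the whole inference pipeline: a Transformer encoder, CLS pooling (configured in `1_Pooling/config.json`), and L2 normalization. Because the last module normalizes the embeddings, the dot-product scores in the README's usage example are effectively cosine similarities. A hedged sketch of assembling the same stack by hand with sentence-transformers (the normal path is simply `SentenceTransformer("AITeamVN/Vietnamese_Embedding")`, which reads `modules.json` for you):

```python
from sentence_transformers import SentenceTransformer, models

# Transformer -> Pooling(cls) -> Normalize, mirroring modules.json
word_embedding = models.Transformer("AITeamVN/Vietnamese_Embedding", max_seq_length=2048)
pooling = models.Pooling(
    word_embedding.get_word_embedding_dimension(),  # 1024, per 1_Pooling/config.json
    pooling_mode="cls",
)
model = SentenceTransformer(modules=[word_embedding, pooling, models.Normalize()])
```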
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "max_seq_length": 8192,
+   "do_lower_case": false
+ }
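Note the difference from the README: `max_seq_length` here is 8192 (the BGE-M3 backbone's limit, matching `model_max_length` in `tokenizer_config.json` below and `max_position_embeddings` of 8194 in `config.json`), while the README explicitly sets `model.max_seq_length = 2048`, the maximum length used during fine-tuning.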
sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
+ size 5069051
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "cls_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b74659c780d49afad7a7b9799868f75cbd3014fb6c34956e85a793028d38094a
+ size 17098251
tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "<s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "1": {
+       "content": "<pad>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "</s>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "3": {
+       "content": "<unk>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "250001": {
+       "content": "<mask>",
+       "lstrip": true,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "bos_token": "<s>",
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "extra_special_tokens": {},
+   "mask_token": "<mask>",
+   "model_max_length": 8192,
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "sp_model_kwargs": {},
+   "tokenizer_class": "XLMRobertaTokenizer",
+   "unk_token": "<unk>"
+ }