sergey21000 committed
Commit 53f3ca1 · verified · Parent(s): 025a5d0

Upload 4 files

Files changed (4):
  1. app.py +434 -0
  2. config.py +130 -0
  3. requirements.txt +14 -0
  4. utils.py +513 -0
app.py ADDED
@@ -0,0 +1,434 @@
from typing import List, Tuple, Optional

# import Llama before gradio to avoid the error
# "exception: access violation reading 0x0000000000000000"
# https://github.com/abetlen/llama-cpp-python/issues/1581
from llama_cpp import Llama

import gradio as gr
from langchain_core.vectorstores import VectorStore

from config import (
    LLM_MODEL_REPOS,
    START_LLM_MODEL_FILE,
    EMBED_MODEL_REPOS,
    SUBTITLES_LANGUAGES,
    GENERATE_KWARGS,
    CONTEXT_TEMPLATE,
)
from utils import (
    load_llm_model,
    load_embed_model,
    load_documents_and_create_db,
    user_message_to_chatbot,
    update_user_message_with_context,
    get_llm_response,
    get_gguf_model_names,
    add_new_model_repo,
    clear_llm_folder,
    clear_embed_folder,
    get_memory_usage,
)


# ============ INTERFACE COMPONENT INITIALIZATION FUNCS ============

def get_rag_mode_component(db: Optional[VectorStore]) -> gr.Checkbox:
    value = visible = db is not None
    return gr.Checkbox(value=value, label='RAG Mode', scale=1, visible=visible)


def get_rag_settings(
    rag_mode: bool,
    context_template_value: str,
    render: bool = True,
) -> Tuple[gr.component, ...]:
    k = gr.Radio(
        choices=[1, 2, 3, 4, 5, 'all'],
        value=2,
        label='Number of relevant documents for search',
        visible=rag_mode,
        render=render,
    )
    score_threshold = gr.Slider(
        minimum=0,
        maximum=1,
        value=0.5,
        step=0.05,
        label='relevance_scores_threshold',
        visible=rag_mode,
        render=render,
    )
    context_template = gr.Textbox(
        value=context_template_value,
        label='Context Template',
        lines=len(context_template_value.split('\n')),
        visible=rag_mode,
        render=render,
    )
    return k, score_threshold, context_template


def get_user_message_with_context(text: str, rag_mode: bool) -> gr.component:
    max_lines = 10
    num_lines = min(len(text.split('\n')), max_lines)
    return gr.Textbox(
        text,
        visible=rag_mode,
        interactive=False,
        label='User Message With Context',
        lines=num_lines,
    )


def get_system_prompt_component(interactive: bool) -> gr.Textbox:
    value = '' if interactive else 'System prompt is not supported by this model'
    return gr.Textbox(value=value, label='System prompt', interactive=interactive)


def get_generate_args(do_sample: bool) -> List[gr.component]:
    generate_args = [
        gr.Slider(minimum=0.1, maximum=3, value=GENERATE_KWARGS['temperature'], step=0.1, label='temperature', visible=do_sample),
        gr.Slider(minimum=0, maximum=1, value=GENERATE_KWARGS['top_p'], step=0.01, label='top_p', visible=do_sample),
        gr.Slider(minimum=1, maximum=50, value=GENERATE_KWARGS['top_k'], step=1, label='top_k', visible=do_sample),
        gr.Slider(minimum=1, maximum=5, value=GENERATE_KWARGS['repeat_penalty'], step=0.1, label='repeat_penalty', visible=do_sample),
    ]
    return generate_args


# ================ LOADING AND INITIALIZING MODELS ========================

start_llm_model, start_support_system_role, load_log = load_llm_model(
    model_repo=LLM_MODEL_REPOS[0],
    model_file=START_LLM_MODEL_FILE,
)
if start_llm_model['llm_model'] is None:
    raise Exception(f'LLM model not initialized, status message: {load_log}')

start_embed_model, load_log = load_embed_model(
    model_repo=EMBED_MODEL_REPOS[0],
)
if start_embed_model['embed_model'] is None:
    raise Exception(f'Embed model not initialized, status message: {load_log}')

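# note: the loaders wrap the models in plain dicts ({'llm_model': ...}, {'embed_model': ...})
# so a single mutable container can be kept in gr.State and event handlers can check for an
# uninitialized model with dict.get() (see get_llm_response in utils.py)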

# ================== APPLICATION WEB INTERFACE ============================

css = '''
.gradio-container {
    width: 70% !important;
    margin: 0 auto !important;
}
'''

with gr.Blocks(css=css) as interface:

    # ==================== GRADIO STATES ===============================

    documents = gr.State([])
    db = gr.State(None)
    user_message_with_context = gr.State('')
    support_system_role = gr.State(start_support_system_role)
    llm_model_repos = gr.State(LLM_MODEL_REPOS)
    embed_model_repos = gr.State(EMBED_MODEL_REPOS)
    llm_model = gr.State(start_llm_model)
    embed_model = gr.State(start_embed_model)

    # ==================== BOT PAGE =================================

    with gr.Tab(label='Chatbot'):
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    type='messages',  # new in gradio 5+
                    show_copy_button=True,
                    height=480,
                )
                user_message = gr.Textbox(label='User')

                with gr.Row():
                    user_message_btn = gr.Button('Send')
                    stop_btn = gr.Button('Stop')
                    clear_btn = gr.Button('Clear')

            # ------------- GENERATION PARAMETERS -------------------

            with gr.Column(scale=1, min_width=80):
                with gr.Group():
                    gr.Markdown('History size')
                    history_len = gr.Slider(
                        minimum=0,
                        maximum=5,
                        value=0,
                        step=1,
                        info='Number of previous messages taken into account in history',
                        label='history_len',
                        show_label=False,
                    )

                with gr.Group():
                    gr.Markdown('Generation parameters')
                    do_sample = gr.Checkbox(
                        value=False,
                        label='do_sample',
                        info='Activate random sampling',
                    )
                    generate_args = get_generate_args(do_sample.value)
                    do_sample.change(
                        fn=get_generate_args,
                        inputs=do_sample,
                        outputs=generate_args,
                        show_progress=False,
                    )

                rag_mode = get_rag_mode_component(db=db.value)
                k, score_threshold, context_template = get_rag_settings(
                    rag_mode=rag_mode.value,
                    context_template_value=CONTEXT_TEMPLATE,
                    render=False,
                )
                rag_mode.change(
                    fn=get_rag_settings,
                    inputs=[rag_mode, context_template],
                    outputs=[k, score_threshold, context_template],
                )
                with gr.Row():
                    k.render()
                    score_threshold.render()

        # ---------------- SYSTEM PROMPT AND USER MESSAGE -----------

        with gr.Accordion('Prompt', open=True):
            system_prompt = get_system_prompt_component(interactive=support_system_role.value)
            context_template.render()
            user_message_with_context = get_user_message_with_context(text='', rag_mode=rag_mode.value)

        # ---------------- SEND, CLEAR AND STOP BUTTONS ------------

        generate_event = gr.on(
            triggers=[user_message.submit, user_message_btn.click],
            fn=user_message_to_chatbot,
            inputs=[user_message, chatbot],
            outputs=[user_message, chatbot],
            # queue=False,
        ).then(
            fn=update_user_message_with_context,
            inputs=[chatbot, rag_mode, db, k, score_threshold, context_template],
            outputs=[user_message_with_context],
        ).then(
            fn=get_user_message_with_context,
            inputs=[user_message_with_context, rag_mode],
            outputs=[user_message_with_context],
        ).then(
            fn=get_llm_response,
            inputs=[chatbot, llm_model, user_message_with_context, rag_mode, system_prompt,
                    support_system_role, history_len, do_sample, *generate_args],
            outputs=[chatbot],
        )
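        # the chain above runs on Send/Enter: append the user message, build the
        # context-augmented prompt (RAG), display it, then stream the LLM answer;
        # stop_btn below cancels the whole chain mid-generation via cancels=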

        stop_btn.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=generate_event,
            queue=False,
        )

        clear_btn.click(
            fn=lambda: (None, ''),
            inputs=None,
            outputs=[chatbot, user_message_with_context],
            queue=False,
        )

    # ================= DOCUMENT UPLOAD PAGE =========================

    with gr.Tab(label='Load documents'):
        with gr.Row(variant='compact'):
            upload_files = gr.File(file_count='multiple', label='Upload text files')
            web_links = gr.Textbox(lines=6, label='Links to web sites or YouTube')

        with gr.Row(variant='compact'):
            chunk_size = gr.Slider(50, 2000, value=500, step=50, label='Chunk size')
            chunk_overlap = gr.Slider(0, 200, value=20, step=10, label='Chunk overlap')

        subtitles_lang = gr.Radio(
            SUBTITLES_LANGUAGES,
            value=SUBTITLES_LANGUAGES[0],
            label='YouTube subtitle language',
        )

        load_documents_btn = gr.Button(value='Upload documents and initialize database')
        load_docs_log = gr.Textbox(label='Status of loading and splitting documents', interactive=False)

        load_documents_btn.click(
            fn=load_documents_and_create_db,
            inputs=[upload_files, web_links, subtitles_lang, chunk_size, chunk_overlap, embed_model],
            outputs=[documents, db, load_docs_log],
        ).success(
            fn=get_rag_mode_component,
            inputs=[db],
            outputs=[rag_mode],
        )

        gr.HTML("""<h3 style='text-align: center'>
        <a href="https://github.com/sergey21000/chatbot-rag" target='_blank'>GitHub Repository</a></h3>
        """)

    # ================= VIEW PAGE FOR ALL DOCUMENTS =================

    with gr.Tab(label='View documents'):
        view_documents_btn = gr.Button(value='Show uploaded text chunks')
        view_documents_textbox = gr.Textbox(
            lines=1,
            placeholder='To view chunks, load documents in the Load documents tab',
            label='Uploaded chunks',
        )
        sep = '=' * 20
        view_documents_btn.click(
            lambda documents: f'\n{sep}\n\n'.join([doc.page_content for doc in documents]),
            inputs=[documents],
            outputs=[view_documents_textbox],
        )

    # ============== GGUF MODELS DOWNLOAD PAGE =====================

    with gr.Tab('Load LLM model'):
        new_llm_model_repo = gr.Textbox(
            value='',
            label='Add repository',
            placeholder='Link to a repository of HF models in GGUF format',
        )
        new_llm_model_repo_btn = gr.Button('Add repository')
        curr_llm_model_repo = gr.Dropdown(
            choices=LLM_MODEL_REPOS,
            value=None,
            label='HF Model Repository',
        )
        curr_llm_model_path = gr.Dropdown(
            choices=[],
            value=None,
            label='GGUF model file',
        )
        load_llm_model_btn = gr.Button('Load and initialize model')
        load_llm_model_log = gr.Textbox(
            value=f'Model {LLM_MODEL_REPOS[0]} loaded at application startup',
            label='Model loading status',
            lines=6,
        )

        with gr.Group():
            gr.Markdown('Free up disk space by deleting all models except the currently selected one')
            clear_llm_folder_btn = gr.Button('Clear folder')

        new_llm_model_repo_btn.click(
            fn=add_new_model_repo,
            inputs=[new_llm_model_repo, llm_model_repos],
            outputs=[curr_llm_model_repo, load_llm_model_log],
        ).success(
            fn=lambda: '',
            inputs=None,
            outputs=[new_llm_model_repo],
        )

        curr_llm_model_repo.change(
            fn=get_gguf_model_names,
            inputs=[curr_llm_model_repo],
            outputs=[curr_llm_model_path],
        )

        load_llm_model_btn.click(
            fn=load_llm_model,
            inputs=[curr_llm_model_repo, curr_llm_model_path],
            outputs=[llm_model, support_system_role, load_llm_model_log],
        ).success(
            fn=lambda log: log + get_memory_usage(),
            inputs=[load_llm_model_log],
            outputs=[load_llm_model_log],
        ).then(
            fn=get_system_prompt_component,
            inputs=[support_system_role],
            outputs=[system_prompt],
        )

        clear_llm_folder_btn.click(
            fn=clear_llm_folder,
            inputs=[curr_llm_model_path],
            outputs=None,
        ).success(
            fn=lambda model_path: f'Models other than {model_path} removed',
            inputs=[curr_llm_model_path],
            outputs=None,
        )

    # ============== EMBEDDING MODELS DOWNLOAD PAGE =============

    with gr.Tab('Load embed model'):
        new_embed_model_repo = gr.Textbox(
            value='',
            label='Add repository',
            placeholder='Link to an HF model repository',
        )
        new_embed_model_repo_btn = gr.Button('Add repository')
        curr_embed_model_repo = gr.Dropdown(
            choices=EMBED_MODEL_REPOS,
            value=None,
            label='HF model repository',
        )

        load_embed_model_btn = gr.Button('Load and initialize model')
        load_embed_model_log = gr.Textbox(
            value=f'Model {EMBED_MODEL_REPOS[0]} loaded at application startup',
            label='Model loading status',
            lines=7,
        )
        with gr.Group():
            gr.Markdown('Free up disk space by deleting all models except the currently selected one')
            clear_embed_folder_btn = gr.Button('Clear folder')

        new_embed_model_repo_btn.click(
            fn=add_new_model_repo,
            inputs=[new_embed_model_repo, embed_model_repos],
            outputs=[curr_embed_model_repo, load_embed_model_log],
        ).success(
            fn=lambda: '',
            inputs=None,
            outputs=[new_embed_model_repo],
        )

        load_embed_model_btn.click(
            fn=load_embed_model,
            inputs=[curr_embed_model_repo],
            outputs=[embed_model, load_embed_model_log],
        ).success(
            fn=lambda log: log + get_memory_usage(),
            inputs=[load_embed_model_log],
            outputs=[load_embed_model_log],
        )

        clear_embed_folder_btn.click(
            fn=clear_embed_folder,
            inputs=[curr_embed_model_repo],
            outputs=None,
        ).success(
            fn=lambda model_repo: f'Models other than {model_repo} removed',
            inputs=[curr_embed_model_repo],
            outputs=None,
        )


interface.launch(server_name='0.0.0.0', server_port=7860)  # debug=True
config.py ADDED
@@ -0,0 +1,130 @@
from pathlib import Path

import torch
from langchain_community.document_loaders import (
    CSVLoader,
    PDFMinerLoader,
    PyPDFLoader,  # alternative PDF loader
    TextLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
    WebBaseLoader,
    YoutubeLoader,
    DirectoryLoader,
)


# langchain classes for extracting text from various sources
LOADER_CLASSES = {
    '.csv': CSVLoader,
    '.doc': UnstructuredWordDocumentLoader,
    '.docx': UnstructuredWordDocumentLoader,
    '.html': UnstructuredHTMLLoader,
    '.md': UnstructuredMarkdownLoader,
    '.pdf': PDFMinerLoader,
    '.ppt': UnstructuredPowerPointLoader,
    '.pptx': UnstructuredPowerPointLoader,
    '.txt': TextLoader,
    'web': WebBaseLoader,
    'directory': DirectoryLoader,
    'youtube': YoutubeLoader,
}

# languages for YouTube subtitles
SUBTITLES_LANGUAGES = ['ru', 'en']

# prompt template for augmenting the user question with retrieved context (in Russian)
CONTEXT_TEMPLATE = '''Ответь на вопрос при условии контекста.

Контекст:
{context}

Вопрос:
{user_message}

Ответ:'''
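# English translation of the template above: 'Answer the question given the context.
# Context: {context}  Question: {user_message}  Answer:'
# illustrative example of how utils.update_user_message_with_context fills it in
# (the chunk and question values here are hypothetical):
#   prompt = CONTEXT_TEMPLATE.format(
#       context='<retrieved chunks, joined with blank lines>',
#       user_message='<user question>',
#   )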

# paths to LLM and embedding models
LLM_MODELS_PATH = Path('models')
EMBED_MODELS_PATH = Path('embed_models')
LLM_MODELS_PATH.mkdir(exist_ok=True)
EMBED_MODELS_PATH.mkdir(exist_ok=True)

# text generation config
GENERATE_KWARGS = dict(
    temperature=0.2,
    top_p=0.95,
    top_k=40,
    repeat_penalty=1.0,
)

# llama-cpp-python model params
LLAMA_MODEL_KWARGS = dict(
    n_gpu_layers=-1,
    verbose=True,
    n_ctx=4096,  # context size: 2048, 4096, ...
)

# model devices
# EMBED_MODEL_DEVICE = 'cpu'
EMBED_MODEL_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
LLM_MODEL_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

if LLM_MODEL_DEVICE == 'cpu':
    LLAMA_MODEL_KWARGS['n_gpu_layers'] = 0

# LLM models in GGUF format available at application startup
LLM_MODEL_REPOS = [
    # https://huggingface.co/bartowski/google_gemma-3-1b-it-GGUF
    'bartowski/google_gemma-3-1b-it-GGUF',
    # https://huggingface.co/bartowski/Qwen2.5-1.5B-Instruct-GGUF
    'bartowski/Qwen2.5-1.5B-Instruct-GGUF',
    # https://huggingface.co/bartowski/Qwen2.5-3B-Instruct-GGUF
    'bartowski/Qwen2.5-3B-Instruct-GGUF',
    # https://huggingface.co/bartowski/google_gemma-3-4b-it-GGUF
    'bartowski/google_gemma-3-4b-it-GGUF',
    # https://huggingface.co/bartowski/gemma-2-2b-it-GGUF
    'bartowski/gemma-2-2b-it-GGUF',
]

# GGUF filename for LLM_MODEL_REPOS[0]
# START_LLM_MODEL_FILE = 'Qwen2.5-1.5B-Instruct-Q8_0.gguf'
START_LLM_MODEL_FILE = 'google_gemma-3-1b-it-Q8_0.gguf'

# embedding models available at application startup
EMBED_MODEL_REPOS = [
    # https://huggingface.co/Alibaba-NLP/gte-multilingual-base  # 611 MB
    'Alibaba-NLP/gte-multilingual-base',
    # https://huggingface.co/intfloat/multilingual-e5-small  # 471 MB
    'intfloat/multilingual-e5-small',
    # https://huggingface.co/sergeyzh/rubert-tiny-turbo  # 117 MB
    'sergeyzh/rubert-tiny-turbo',
    # https://huggingface.co/sergeyzh/BERTA  # 513 MB
    'sergeyzh/BERTA',
    # https://huggingface.co/cointegrated/rubert-tiny2  # 118 MB
    'cointegrated/rubert-tiny2',
    # https://huggingface.co/cointegrated/LaBSE-en-ru  # 516 MB
    'cointegrated/LaBSE-en-ru',
    # https://huggingface.co/sergeyzh/LaBSE-ru-turbo  # 513 MB
    'sergeyzh/LaBSE-ru-turbo',
    # https://huggingface.co/intfloat/multilingual-e5-large  # 2.24 GB
    'intfloat/multilingual-e5-large',
    # https://huggingface.co/intfloat/multilingual-e5-base  # 1.11 GB
    'intfloat/multilingual-e5-base',
    # https://huggingface.co/intfloat/multilingual-e5-large-instruct  # 1.12 GB
    'intfloat/multilingual-e5-large-instruct',
    # https://huggingface.co/sentence-transformers/all-mpnet-base-v2  # 438 MB
    'sentence-transformers/all-mpnet-base-v2',
    # https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2  # 1.11 GB
    'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
    # https://huggingface.co/ai-forever?search_models=ruElectra  # 356 MB
    'ai-forever/ruElectra-medium',
    # https://huggingface.co/ai-forever/sbert_large_nlu_ru  # 1.71 GB
    'ai-forever/sbert_large_nlu_ru',
]
requirements.txt ADDED
@@ -0,0 +1,14 @@
--extra-index-url https://download.pytorch.org/whl/cpu
--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu
torch==2.6.0
# llama-cpp-python==0.3.8
https://github.com/sergey21000/llama-cpp-python-wheels/releases/download/llama-cpp-python-0.3.8-wheels/llama_cpp_python-0.3.8-cp310-cp310-linux_x86_64.cpu.whl
gradio==5.25.2
langchain==0.3.23
langchain-community==0.3.21
langchain-huggingface==0.1.2
pdfminer.six==20250416
youtube-transcript-api==1.0.3
psutil==7.0.0
faiss-cpu==1.10.0
beautifulsoup4==4.13.4
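# note: the prebuilt llama-cpp-python wheel above is cp310/linux_x86_64 (CPU only),
# so it assumes Python 3.10 on Linux; on other platforms, comment it out and install
# llama-cpp-python==0.3.8 from the CPU extra index listed at the top instead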
utils.py ADDED
@@ -0,0 +1,513 @@
import csv
from pathlib import Path
from shutil import rmtree
from typing import List, Tuple, Dict, Union, Optional, Iterable

from tqdm import tqdm

from llama_cpp import Llama

import psutil
import requests
from requests.exceptions import MissingSchema

import torch
import gradio as gr

from youtube_transcript_api import YouTubeTranscriptApi, NoTranscriptFound, TranscriptsDisabled
from huggingface_hub import list_repo_tree, repo_exists, snapshot_download

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

# imports for annotations
from langchain.docstore.document import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from config import (
    LLM_MODELS_PATH,
    EMBED_MODELS_PATH,
    GENERATE_KWARGS,
    LLAMA_MODEL_KWARGS,
    LOADER_CLASSES,
    EMBED_MODEL_DEVICE,
)


# type annotations
CHAT_HISTORY = List[Optional[Dict[str, Optional[str]]]]
LLM_MODEL_DICT = Dict[str, Llama]
EMBED_MODEL_DICT = Dict[str, Embeddings]


# ===================== ADDITIONAL FUNCS =======================

# get disk, CPU and (if available) GPU memory usage
def get_memory_usage() -> str:
    print_memory = ''

    memory_type = 'Disk'
    psutil_stats = psutil.disk_usage('.')
    memory_total = psutil_stats.total / 1024**3
    memory_usage = psutil_stats.used / 1024**3
    print_memory += f'{memory_type} Memory Usage: {memory_usage:.2f} / {memory_total:.2f} GB\n'

    memory_type = 'CPU'
    psutil_stats = psutil.virtual_memory()
    memory_total = psutil_stats.total / 1024**3
    memory_usage = memory_total - (psutil_stats.available / 1024**3)
    print_memory += f'{memory_type} Memory Usage: {memory_usage:.2f} / {memory_total:.2f} GB\n'

    if torch.cuda.is_available():
        memory_type = 'GPU'
        # mem_get_info returns bytes, so both values are converted to GB
        memory_free, memory_total = torch.cuda.mem_get_info()
        memory_usage = memory_total - memory_free
        print_memory += f'{memory_type} Memory Usage: {memory_usage / 1024**3:.2f} / {memory_total / 1024**3:.2f} GB\n'

    print_memory = f'---------------\n{print_memory}---------------'
    return print_memory


# clean up a list of documents
def clear_documents(documents: Iterable[Document]) -> Iterable[Document]:
    def clear_text(text: str) -> str:
        lines = text.split('\n')
        lines = [line for line in lines if len(line.strip()) > 2]
        text = '\n'.join(lines).strip()
        return text

    output_documents = []
    for document in documents:
        text = clear_text(document.page_content)
        if len(text) > 10:
            document.page_content = text
            output_documents.append(document)
    return output_documents


# ===================== INTERFACE FUNCS =============================


# ------------- LLM AND EMBEDDING MODELS LOADING ------------------------

# download a file by URL, displaying tqdm and gradio progress bars
def download_file(file_url: str, file_path: Union[str, Path]) -> None:
    response = requests.get(file_url, stream=True)
    if response.status_code != 200:
        raise Exception(f'The file is not available for download at the link: {file_url}')
    total_size = int(response.headers.get('content-length', 0))
    progress_tqdm = tqdm(desc='Loading GGUF file', total=total_size, unit='iB', unit_scale=True)
    progress_gradio = gr.Progress()
    completed_size = 0
    with open(file_path, 'wb') as file:
        for data in response.iter_content(chunk_size=4096):
            size = file.write(data)
            progress_tqdm.update(size)
            completed_size += size
            desc = f'Loading GGUF file, {completed_size/1024**3:.3f}/{total_size/1024**3:.3f} GB'
            progress_gradio(completed_size/total_size, desc=desc)


# download and initialize the GGUF model
def load_llm_model(model_repo: str, model_file: str) -> Tuple[LLM_MODEL_DICT, bool, str]:
    llm_model = None
    load_log = ''
    support_system_role = False

    # an empty dropdown passes an empty list instead of a filename
    if isinstance(model_file, list):
        load_log += 'No model selected\n'
        return {'llm_model': llm_model}, support_system_role, load_log

    # strip the file size suffix added by get_gguf_model_names, e.g. 'model.gguf (1.10G)'
    if '(' in model_file:
        model_file = model_file.split('(')[0].rstrip()

    progress = gr.Progress()
    progress(0.3, desc='Step 1/2: Download the GGUF file')
    model_path = LLM_MODELS_PATH / model_file

    if model_path.is_file():
        load_log += f'Model {model_file} already downloaded, reinitializing\n'
    else:
        try:
            gguf_url = f'https://huggingface.co/{model_repo}/resolve/main/{model_file}'
            download_file(gguf_url, model_path)
            load_log += f'Model {model_file} loaded\n'
        except Exception as ex:
            model_path = ''
            load_log += f'Error downloading model, error code:\n{ex}\n'

    if model_path:
        progress(0.7, desc='Step 2/2: Initialize the model')
        try:
            llm_model = Llama(model_path=str(model_path), **LLAMA_MODEL_KWARGS)
            # chat templates that reject the system role embed this message
            support_system_role = 'System role not supported' not in llm_model.metadata.get('tokenizer.chat_template', '')
            load_log += f'Model {model_file} initialized, max context size is {llm_model.n_ctx()} tokens\n'
        except Exception as ex:
            load_log += f'Error initializing model, error code:\n{ex}\n'

    llm_model = {'llm_model': llm_model}
    return llm_model, support_system_role, load_log


# download and initialize the embedding model
def load_embed_model(model_repo: str) -> Tuple[EMBED_MODEL_DICT, str]:
    embed_model = None
    load_log = ''

    # an empty dropdown passes an empty list instead of a repo name
    if isinstance(model_repo, list):
        load_log = 'No model selected'
        return {'embed_model': embed_model}, load_log

    progress = gr.Progress()
    folder_name = model_repo.replace('/', '_')
    folder_path = EMBED_MODELS_PATH / folder_name
    if folder_path.is_dir():
        load_log += f'Reinitializing model {model_repo}\n'
    else:
        progress(0.5, desc='Step 1/2: Download model repository')
        snapshot_download(
            repo_id=model_repo,
            local_dir=folder_path,
            ignore_patterns='*.h5',
        )
        load_log += f'Model {model_repo} loaded\n'

    progress(0.7, desc='Step 2/2: Initialize the model')
    model_kwargs = dict(
        device=EMBED_MODEL_DEVICE,
        trust_remote_code=True,
    )
    embed_model = HuggingFaceEmbeddings(
        model_name=str(folder_path),
        model_kwargs=model_kwargs,
        # encode_kwargs={'normalize_embeddings': True},
    )
    load_log += f'Embeddings model {model_repo} initialized\n'
    load_log += 'Please upload documents and initialize database again\n'
    embed_model = {'embed_model': embed_model}
    return embed_model, load_log


# add a new HF repository new_model_repo to the current list model_repos
def add_new_model_repo(new_model_repo: str, model_repos: List[str]) -> Tuple[gr.Dropdown, str]:
    load_log = ''
    repo = new_model_repo.strip()
    if repo:
        # accept either a full URL or a user/repo id, drop any query string
        repo = repo.split('/')[-2:]
        if len(repo) == 2:
            repo = '/'.join(repo).split('?')[0]
            if repo_exists(repo) and repo not in model_repos:
                model_repos.insert(0, repo)
                load_log += f'Model repository {repo} successfully added\n'
            else:
                load_log += 'Invalid HF repository name or model already in the list\n'
        else:
            load_log += 'Invalid link to HF repository\n'
    else:
        load_log += 'HF repository field is empty\n'
    model_repo_dropdown = gr.Dropdown(choices=model_repos, value=model_repos[0])
    return model_repo_dropdown, load_log


# get a list of GGUF model files from an HF repository
def get_gguf_model_names(model_repo: str) -> gr.Dropdown:
    repo_files = list(list_repo_tree(model_repo))
    repo_files = [file for file in repo_files if file.path.endswith('.gguf')]
    # label each file with its size, e.g. 'model-Q8_0.gguf (1.10G)'
    model_paths = [f'{file.path} ({file.size / 1000 ** 3:.2f}G)' for file in repo_files]
    model_paths_dropdown = gr.Dropdown(
        choices=model_paths,
        value=model_paths[0],
        label='GGUF model file',
    )
    return model_paths_dropdown


# delete model files to free space, keeping the currently selected gguf_filename
def clear_llm_folder(gguf_filename: str) -> None:
    if gguf_filename is None:
        gr.Info('No model file selected to keep')
        return
    if '(' in gguf_filename:
        gguf_filename = gguf_filename.split('(')[0].rstrip()
    for path in LLM_MODELS_PATH.iterdir():
        if path.name == gguf_filename:
            continue
        if path.is_file():
            path.unlink(missing_ok=True)
    gr.Info(f'All files removed from directory {LLM_MODELS_PATH} except {gguf_filename}')


# delete model folders to free space, keeping the currently selected model_repo
def clear_embed_folder(model_repo: str) -> None:
    if model_repo is None:
        gr.Info('No model selected to keep')
        return
    model_folder_name = model_repo.replace('/', '_')
    for path in EMBED_MODELS_PATH.iterdir():
        if path.name == model_folder_name:
            continue
        if path.is_dir():
            rmtree(path, ignore_errors=True)
    gr.Info(f'All directories removed from {EMBED_MODELS_PATH} except {model_folder_name}')


# ------------------------ YOUTUBE ------------------------

# check subtitle availability: returns True and a log if manual or automatic
# subtitles are available, otherwise False and a log
def check_subtitles_available(yt_video_link: str, target_lang: str) -> Tuple[bool, str]:
    video_id = yt_video_link.split('watch?v=')[-1].split('&')[0]
    load_log = ''
    available = True
    try:
        transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
        try:
            transcript = transcript_list.find_transcript([target_lang])
            if transcript.is_generated:
                load_log += f'Automatic subtitles will be loaded, manual ones are not available for video {yt_video_link}\n'
            else:
                load_log += f'Manual subtitles will be downloaded for the video {yt_video_link}\n'
        except NoTranscriptFound:
            load_log += f'Subtitle language {target_lang} is not available for video {yt_video_link}\n'
            available = False
    except TranscriptsDisabled:
        load_log += f'Invalid video url ({yt_video_link}) or current server IP is blocked for YouTube\n'
        available = False
    return available, load_log


# ------------- UPLOADING DOCUMENTS FOR RAG ------------------------

# extract documents (in langchain Document format) from uploaded files
def load_documents_from_files(upload_files: List[str]) -> Tuple[List[Document], str]:
    load_log = ''
    documents = []
    for upload_file in upload_files:
        file_extension = f".{upload_file.split('.')[-1]}"
        if file_extension in LOADER_CLASSES:
            loader_class = LOADER_CLASSES[file_extension]
            loader_kwargs = {}
            if file_extension == '.csv':
                # detect the CSV delimiter from a sample of the file
                with open(upload_file) as csvfile:
                    delimiter = csv.Sniffer().sniff(csvfile.read(4096)).delimiter
                loader_kwargs = {'csv_args': {'delimiter': delimiter}}
            try:
                load_documents = loader_class(upload_file, **loader_kwargs).load()
                documents.extend(load_documents)
            except Exception as ex:
                load_log += f'Error uploading file {upload_file}\n'
                load_log += f'Error code: {ex}\n'
                continue
        else:
            load_log += f'Unsupported file format {upload_file}\n'
            continue
    return documents, load_log


# extract documents (in langchain Document format) from web links
def load_documents_from_links(
    web_links: str,
    subtitles_lang: str,
) -> Tuple[List[Document], str]:

    load_log = ''
    documents = []
    loader_class_kwargs = {}
    web_links = [web_link.strip() for web_link in web_links.split('\n') if web_link.strip()]
    for web_link in web_links:
        if 'youtube.com' in web_link:
            available, log = check_subtitles_available(web_link, subtitles_lang)
            load_log += log
            if not available:
                continue
            loader_class = LOADER_CLASSES['youtube'].from_youtube_url
            loader_class_kwargs = {'language': subtitles_lang}
        else:
            loader_class = LOADER_CLASSES['web']

        try:
            if requests.get(web_link).status_code != 200:
                load_log += f'Link is not available for Python requests: {web_link}\n'
                continue
            load_documents = loader_class(web_link, **loader_class_kwargs).load()
            if len(load_documents) == 0:
                load_log += f'No text chunks were found at the link: {web_link}\n'
                continue
            documents.extend(load_documents)
        except MissingSchema:
            load_log += f'Invalid link: {web_link}\n'
            continue
        except Exception as ex:
            load_log += f'Error loading data by web loader at link: {web_link}\n'
            load_log += f'Error code: {ex}\n'
            continue
    return documents, load_log


# upload files/links, extract documents and build the vector database
def load_documents_and_create_db(
    upload_files: Optional[List[str]],
    web_links: str,
    subtitles_lang: str,
    chunk_size: int,
    chunk_overlap: int,
    embed_model_dict: EMBED_MODEL_DICT,
) -> Tuple[List[Document], Optional[VectorStore], str]:

    load_log = ''
    all_documents = []
    db = None
    progress = gr.Progress()

    embed_model = embed_model_dict.get('embed_model')
    if embed_model is None:
        load_log += 'Embeddings model not initialized, DB cannot be created'
        return all_documents, db, load_log

    if upload_files is None and not web_links:
        load_log = 'No files or links selected'
        return all_documents, db, load_log

    if upload_files is not None:
        progress(0.3, desc='Step 1/2: Upload documents from files')
        docs, log = load_documents_from_files(upload_files)
        all_documents.extend(docs)
        load_log += log

    if web_links:
        progress(0.3 if upload_files is None else 0.5, desc='Step 1/2: Upload documents from links')
        docs, log = load_documents_from_links(web_links, subtitles_lang)
        all_documents.extend(docs)
        load_log += log

    if len(all_documents) == 0:
        load_log += 'Upload was interrupted because no documents were extracted\n'
        load_log += 'RAG mode cannot be activated'
        return all_documents, db, load_log

    load_log += f'Documents loaded: {len(all_documents)}\n'
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    documents = text_splitter.split_documents(all_documents)
    documents = clear_documents(documents)
    load_log += f'Documents are split, number of text chunks: {len(documents)}\n'

    progress(0.7, desc='Step 2/2: Initialize DB')
    db = FAISS.from_documents(documents=documents, embedding=embed_model)
    load_log += 'DB is initialized, RAG mode can now be enabled in the Chatbot tab'
    return documents, db, load_log
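# pipeline summary: files/links -> langchain Documents -> RecursiveCharacterTextSplitter
# chunks -> clear_documents() filtering -> FAISS index built with the current embedding model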


# ------------------ CHATBOT FUNCS ------------------------

# add the user message to the chatbot window
def user_message_to_chatbot(user_message: str, chatbot: CHAT_HISTORY) -> Tuple[str, CHAT_HISTORY]:
    # chatbot.append({'role': 'user', 'metadata': {'title': None}, 'content': user_message})
    chatbot.append({'role': 'user', 'content': user_message})
    return '', chatbot


# format the prompt, adding context if a DB is available and RAG mode is enabled
def update_user_message_with_context(
    chatbot: CHAT_HISTORY,
    rag_mode: bool,
    db: VectorStore,
    k: Union[int, str],
    score_threshold: float,
    context_template: str,
) -> str:

    user_message = chatbot[-1]['content']
    user_message_with_context = ''

    if '{user_message}' not in context_template or '{context}' not in context_template:
        gr.Info('Context template must include {user_message} and {context}')
        return user_message_with_context

    if db is not None and rag_mode and user_message.strip():
        if k == 'all':
            k = len(db.docstore._dict)
        docs_and_distances = db.similarity_search_with_relevance_scores(
            user_message,
            k=k,
            score_threshold=score_threshold,
        )
        if len(docs_and_distances) > 0:
            retriever_context = '\n\n'.join([doc[0].page_content for doc in docs_and_distances])
            user_message_with_context = context_template.format(
                user_message=user_message,
                context=retriever_context,
            )
    return user_message_with_context

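# illustrative example (hypothetical values): with k=2 and score_threshold=0.5,
# similarity_search_with_relevance_scores returns up to two (Document, score) pairs
# scoring >= 0.5; their page_content is joined with blank lines and substituted
# into context_template; an empty string signals "no relevant context found"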


# model response generation
def get_llm_response(
    chatbot: CHAT_HISTORY,
    llm_model_dict: LLM_MODEL_DICT,
    user_message_with_context: str,
    rag_mode: bool,
    system_prompt: str,
    support_system_role: bool,
    history_len: int,
    do_sample: bool,
    *generate_args,
) -> CHAT_HISTORY:

    llm_model = llm_model_dict.get('llm_model')
    if llm_model is None:
        gr.Info('Model not initialized')
        yield chatbot[:-1]
        return

    gen_kwargs = dict(zip(GENERATE_KWARGS.keys(), generate_args))
    gen_kwargs['top_k'] = int(gen_kwargs['top_k'])
    if not do_sample:
        # with sampling disabled, top_k=1 makes llama.cpp always pick the most likely token
        gen_kwargs['top_p'] = 0.0
        gen_kwargs['top_k'] = 1
        gen_kwargs['repeat_penalty'] = 1.0

    user_message = chatbot[-1]['content']
    if not user_message.strip():
        yield chatbot[:-1]
        return

    if rag_mode:
        if user_message_with_context:
            user_message = user_message_with_context
        else:
            gr.Info((
                'No documents relevant to the query were found, generation in RAG mode is not possible.\n'
                'Either the Context template is specified incorrectly,\n'
                'or try reducing search_score_threshold, or disable RAG mode for normal generation'
            ))
            yield chatbot[:-1]
            return

    messages = []
    if support_system_role and system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})

    if history_len != 0:
        # each dialogue turn is two messages (user + assistant)
        messages.extend(chatbot[:-1][-(history_len*2):])

    messages.append({'role': 'user', 'content': user_message})
    stream_response = llm_model.create_chat_completion(
        messages=messages,
        stream=True,
        **gen_kwargs,
    )
    try:
        chatbot.append({'role': 'assistant', 'content': ''})
        for chunk in stream_response:
            token = chunk['choices'][0]['delta'].get('content')
            if token is not None:
                chatbot[-1]['content'] += token
                yield chatbot
    except Exception as ex:
        gr.Info(f'Error generating response, error code: {ex}')
        yield chatbot[:-1]
        return