shashankkandimalla committed on
Commit
68f6e3a
·
verified ·
1 Parent(s): 49583f0

added logging

Browse files
Files changed (1) hide show
  1. app.py +202 -53
app.py CHANGED
@@ -1,32 +1,64 @@
1
  import gradio as gr
2
  import weaviate
3
- from weaviate.embedded import EmbeddedOptions
4
  import os
5
  from openai import AsyncOpenAI
6
  from dotenv import load_dotenv
7
- import textwrap
8
  import asyncio
9
- import aiohttp
10
  from functools import wraps
 
 
11
 
 
 
 
12
  # Load environment variables
13
  load_dotenv()
14
 
15
  # Set up AsyncOpenAI client
16
  openai_client = AsyncOpenAI(api_key=os.getenv('OPENAI_API_KEY'))
17
 
18
- # Connect to Weaviate
19
- client = weaviate.Client(
20
- url=os.getenv('WCS_URL'),
21
- auth_client_secret=weaviate.auth.AuthApiKey(os.getenv('WCS_API_KEY')),
22
- additional_headers={
23
- "X-OpenAI-Api-Key": os.getenv('OPENAI_API_KEY')
24
- }
25
- )
26
 
27
  # Get the collection name from environment variable
28
  COLLECTION_NAME = os.getenv('WEAVIATE_COLLECTION_NAME')
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Async-compatible caching decorator
31
  def async_lru_cache(maxsize=128):
32
  cache = {}
@@ -67,43 +99,41 @@ async def search_multimodal(query: str, limit: int = 30, alpha: float = 0.6):
67
  print(f"An error occurred during the search: {str(e)}")
68
  return []
69
 
70
- async def generate_response(query: str, context: str) -> str:
71
  prompt = f"""
72
  You are an AI assistant with extensive expertise in the semiconductor industry. Your knowledge spans a wide range of companies, technologies, and products, including but not limited to: System-on-Chip (SoC) designs, Field-Programmable Gate Arrays (FPGAs), Microcontrollers, Integrated Circuits (ICs), semiconductor manufacturing processes, and emerging technologies like quantum computing and neuromorphic chips.
73
  Use the following context, your vast knowledge, and the user's question to generate an accurate, comprehensive, and insightful answer. While formulating your response, follow these steps internally:
74
-
75
  Analyze the question to identify the main topic and specific information requested.
76
  Evaluate the provided context and identify relevant information.
77
  Retrieve additional relevant knowledge from your semiconductor industry expertise.
78
  Reason and formulate a response by combining context and knowledge.
79
  Generate a detailed response that covers all aspects of the query.
80
  Review and refine your answer for coherence and accuracy.
81
-
82
  In your output, provide only the final, polished response. Do not include your step-by-step reasoning or mention the process you followed.
83
  IMPORTANT: Ensure your response is grounded in factual information. Do not hallucinate or invent information. If you're unsure about any aspect of the answer or if the necessary information is not available in the provided context or your knowledge base, clearly state this uncertainty. It's better to admit lack of information than to provide inaccurate details.
84
  Your response should be:
85
-
86
  Thorough and directly address all aspects of the user's question
87
  Based solely on factual information from the provided context and your reliable knowledge
88
  Include specific examples, data points, or case studies only when you're certain of their accuracy
89
  Explain technical concepts clearly, considering the user may have varying levels of expertise
90
  Clearly indicate any areas where information is limited or uncertain
91
-
92
  Context: {context}
93
  User Question: {query}
94
  Based on the above context and your extensive knowledge of the semiconductor industry, provide your detailed, accurate, and grounded response below. Remember, only include information you're confident is correct, and clearly state any uncertainties:
95
  """
96
 
97
- response = await openai_client.chat.completions.create(
98
  model="gpt-4o",
99
  messages=[
100
  {"role": "system", "content": "You are an expert Semi Conductor industry analyst"},
101
  {"role": "user", "content": prompt}
102
  ],
103
- temperature=0
104
- )
105
-
106
- return response.choices[0].message.content
 
 
107
 
108
  def process_search_result(item):
109
  if item['content_type'] == 'text':
@@ -114,14 +144,12 @@ def process_search_result(item):
114
  return f"Table Description from {item['source_document']} (Page {item['page_number']}): {item['description']}\n\n"
115
  return ""
116
 
117
- async def esg_analysis(user_query: str):
118
  search_results = await search_multimodal(user_query)
119
 
120
  context_parts = await asyncio.gather(*[asyncio.to_thread(process_search_result, item) for item in search_results])
121
  context = "".join(context_parts)
122
 
123
- response = await generate_response(user_query, context)
124
-
125
  sources = []
126
  for item in search_results[:5]: # Limit to top 5 sources
127
  source = {
@@ -135,22 +163,9 @@ async def esg_analysis(user_query: str):
135
  source["image_path"] = item.get("image_path", "N/A")
136
  sources.append(source)
137
 
138
- return response, sources
139
-
140
- def wrap_text(text, width=120):
141
- return textwrap.fill(text, width=width)
142
-
143
- async def gradio_interface(user_question):
144
- ai_response, sources = await esg_analysis(user_question)
145
-
146
- # Format AI response
147
- formatted_response = f"""
148
- ## AI Response
149
 
150
- {ai_response}
151
- """
152
-
153
- # Format sources
154
  source_text = "## Top 5 Sources\n\n"
155
  for i, source in enumerate(sources, 1):
156
  source_text += f"### Source {i}\n"
@@ -162,20 +177,154 @@ async def gradio_interface(user_question):
162
  if 'image_path' in source:
163
  source_text += f"- **Image Path:** {source['image_path']}\n"
164
  source_text += "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- return formatted_response, source_text
167
-
168
- iface = gr.Interface(
169
- fn=lambda user_question: asyncio.run(gradio_interface(user_question)),
170
- inputs=gr.Textbox(lines=2, placeholder="Enter your question about the semiconductor industry..."),
171
- outputs=[
172
- gr.Markdown(label="AI Response"),
173
- gr.Markdown(label="Sources")
174
- ],
175
- title="Semiconductor Industry ESG Analysis",
176
- description="Ask questions about the semiconductor industry and get AI-powered answers with sources.",
177
- flagging_dir="/app/flagged" # Specify the flagging directory
178
- )
179
 
180
  if __name__ == "__main__":
181
- iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
1
  import gradio as gr
2
  import weaviate
 
3
  import os
4
  from openai import AsyncOpenAI
5
  from dotenv import load_dotenv
 
6
  import asyncio
 
7
  from functools import wraps
8
+ import logging
9
+ import time
10
 
11
+ # Set up logging
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
  # Load environment variables
15
  load_dotenv()
16
 
17
  # Set up AsyncOpenAI client
18
  openai_client = AsyncOpenAI(api_key=os.getenv('OPENAI_API_KEY'))
19
 
20
+ # Initialize client as None
21
+ client = None
 
 
 
 
 
 
22
 
23
  # Get the collection name from environment variable
24
  COLLECTION_NAME = os.getenv('WEAVIATE_COLLECTION_NAME')
25
 
26
+ # Global variable to track connection status
+ connection_status = {"status": "Disconnected", "color": "red"}
29
+
30
+ # Function to initialize the Weaviate client
31
+ async def initialize_weaviate_client(max_retries=3, retry_delay=5):
32
+ global client, connection_status
33
+ retries = 0
34
+ while retries < max_retries:
35
+ connection_status = {"status": "Connecting...", "color": "orange"}
36
+ try:
37
+ logger.info(f"Attempting to connect to Weaviate (Attempt {retries + 1}/{max_retries})")
38
+ client = weaviate.Client(
39
+ url=os.getenv('WCS_URL'),
40
+ auth_client_secret=weaviate.auth.AuthApiKey(os.getenv('WCS_API_KEY')),
41
+ additional_headers={
42
+ "X-OpenAI-Api-Key": os.getenv('OPENAI_API_KEY')
43
+ }
44
+ )
45
+ # Test the connection
46
+ await asyncio.to_thread(client.schema.get)
47
+ connection_status = {"status": "Connected", "color": "green"}
48
+ logger.info("Successfully connected to Weaviate")
49
+ return connection_status
50
+ except Exception as e:
51
+ logger.error(f"Error connecting to Weaviate: {str(e)}")
52
+ connection_status = {"status": f"Error: {str(e)}", "color": "red"}
53
+ retries += 1
54
+ if retries < max_retries:
55
+ logger.info(f"Retrying in {retry_delay} seconds...")
56
+ await asyncio.sleep(retry_delay)
57
+ else:
58
+ logger.error("Max retries reached. Could not connect to Weaviate.")
59
+ return connection_status
60
+
61
+
62
  # Async-compatible caching decorator
63
  def async_lru_cache(maxsize=128):
64
  cache = {}
 
99
  print(f"An error occurred during the search: {str(e)}")
100
  return []
101
 
102
async def generate_response_stream(query: str, context: str):
    """Stream an LLM answer for *query* grounded in *context*.

    Yields response text chunks as they arrive from the OpenAI
    chat-completions streaming API (gpt-4o, temperature 0).
    """
    prompt = f"""
    You are an AI assistant with extensive expertise in the semiconductor industry. Your knowledge spans a wide range of companies, technologies, and products, including but not limited to: System-on-Chip (SoC) designs, Field-Programmable Gate Arrays (FPGAs), Microcontrollers, Integrated Circuits (ICs), semiconductor manufacturing processes, and emerging technologies like quantum computing and neuromorphic chips.
    Use the following context, your vast knowledge, and the user's question to generate an accurate, comprehensive, and insightful answer. While formulating your response, follow these steps internally:
    Analyze the question to identify the main topic and specific information requested.
    Evaluate the provided context and identify relevant information.
    Retrieve additional relevant knowledge from your semiconductor industry expertise.
    Reason and formulate a response by combining context and knowledge.
    Generate a detailed response that covers all aspects of the query.
    Review and refine your answer for coherence and accuracy.
    In your output, provide only the final, polished response. Do not include your step-by-step reasoning or mention the process you followed.
    IMPORTANT: Ensure your response is grounded in factual information. Do not hallucinate or invent information. If you're unsure about any aspect of the answer or if the necessary information is not available in the provided context or your knowledge base, clearly state this uncertainty. It's better to admit lack of information than to provide inaccurate details.
    Your response should be:
    Thorough and directly address all aspects of the user's question
    Based solely on factual information from the provided context and your reliable knowledge
    Include specific examples, data points, or case studies only when you're certain of their accuracy
    Explain technical concepts clearly, considering the user may have varying levels of expertise
    Clearly indicate any areas where information is limited or uncertain
    Context: {context}
    User Question: {query}
    Based on the above context and your extensive knowledge of the semiconductor industry, provide your detailed, accurate, and grounded response below. Remember, only include information you're confident is correct, and clearly state any uncertainties:
    """

    # With stream=True the awaited call returns an async iterator of deltas.
    stream = await openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are an expert Semi Conductor industry analyst"},
            {"role": "user", "content": prompt}
        ],
        temperature=0,
        stream=True
    )
    async for chunk in stream:
        piece = chunk.choices[0].delta.content
        # The final chunk carries content=None; skip it.
        if piece is not None:
            yield piece
137
 
138
  def process_search_result(item):
139
  if item['content_type'] == 'text':
 
144
  return f"Table Description from {item['source_document']} (Page {item['page_number']}): {item['description']}\n\n"
145
  return ""
146
 
147
+ async def esg_analysis_stream(user_query: str):
148
  search_results = await search_multimodal(user_query)
149
 
150
  context_parts = await asyncio.gather(*[asyncio.to_thread(process_search_result, item) for item in search_results])
151
  context = "".join(context_parts)
152
 
 
 
153
  sources = []
154
  for item in search_results[:5]: # Limit to top 5 sources
155
  source = {
 
163
  source["image_path"] = item.get("image_path", "N/A")
164
  sources.append(source)
165
 
166
+ return generate_response_stream(user_query, context), sources
 
 
 
 
 
 
 
 
 
 
167
 
168
+ def format_sources(sources):
 
 
 
169
  source_text = "## Top 5 Sources\n\n"
170
  for i, source in enumerate(sources, 1):
171
  source_text += f"### Source {i}\n"
 
177
  if 'image_path' in source:
178
  source_text += f"- **Image Path:** {source['image_path']}\n"
179
  source_text += "\n"
180
+ return source_text
181
+
182
+ # Custom CSS for the status box
183
+ custom_css = """
184
+ #status-box {
185
+ position: absolute;
186
+ top: 10px;
187
+ right: 10px;
188
+ background-color: white;
189
+ padding: 5px 10px;
190
+ border-radius: 5px;
191
+ box-shadow: 0 2px 5px rgba(0,0,0,0.1);
192
+ z-index: 1000;
193
+ display: flex;
194
+ align-items: center;
195
+ }
196
+ #status-light {
197
+ width: 10px;
198
+ height: 10px;
199
+ border-radius: 50%;
200
+ display: inline-block;
201
+ margin-right: 5px;
202
+ }
203
+ #status-text {
204
+ font-size: 14px;
205
+ font-weight: bold;
206
+ }
207
+ """
208
+
209
def get_connection_status():
    """Render the current Weaviate connection state as the status-box HTML snippet."""
    light = f'<div id="status-light" style="background-color: {connection_status["color"]};"></div>'
    label = f'<span id="status-text">{connection_status["status"]}</span>'
    return f'<div id="status-box">{light}{label}</div>'
213
+
214
async def check_connection():
    """Probe Weaviate once and return a fresh status dict.

    Returns ``{"status": "Connected", ...}`` when the client exists and a
    ``schema.get`` round-trip succeeds, ``{"status": "Disconnected", ...}``
    otherwise. Does not mutate module globals (the original declared
    ``global connection_status`` but never assigned it — removed).
    """
    disconnected = {"status": "Disconnected", "color": "red"}
    if not client:
        return disconnected
    try:
        # Cheap server round-trip; run the blocking call off the event loop.
        await asyncio.to_thread(client.schema.get)
    except Exception:
        # Deliberate best-effort probe: any failure reads as "Disconnected".
        return disconnected
    return {"status": "Connected", "color": "green"}
224
+
225
async def update_status():
    """Poll the connection every 5 seconds, yielding the status dict on change.

    Also keeps the module-level ``connection_status`` in sync with what it
    yields.
    """
    global connection_status
    while True:
        probed = await check_connection()
        if probed != connection_status:
            connection_status = probed
            yield probed
        await asyncio.sleep(5)  # Check every 5 seconds
233
+
234
async def gradio_interface(user_question):
    """Answer *user_question*, returning (response_markdown, sources_markdown).

    Refuses to run until the Weaviate connection is reported as up.
    """
    if connection_status["status"] != "Connected":
        return "Error: Database not connected. Please wait for the connection to be established.", ""

    stream, sources = await esg_analysis_stream(user_question)
    sources_md = format_sources(sources)

    # Drain the token stream into one string (this output is not streamed).
    chunks = []
    async for chunk in stream:
        chunks.append(chunk)
    return "".join(chunks), sources_md
246
+
247
# Build the Gradio UI.
# BUG FIX: the original registered the status_updater state plus its
# iface.load / .change wiring TWICE, verbatim — duplicate event handlers
# firing on every 1-second tick. Register each exactly once.
with gr.Blocks(css=custom_css) as iface:
    status_indicator = gr.HTML(get_connection_status())

    with gr.Row():
        gr.Markdown("# Semiconductor Industry Analysis")

    gr.Markdown("Ask questions about the semiconductor industry and get AI-powered answers with sources.")

    # Inputs start disabled; they are enabled once the DB connects.
    user_question = gr.Textbox(lines=2, placeholder="Enter your question about the semiconductor industry...", interactive=False)
    ai_response = gr.Markdown(label="AI Response")
    sources_output = gr.Markdown(label="Sources")

    submit_btn = gr.Button("Submit", interactive=False)

    submit_btn.click(
        fn=gradio_interface,
        inputs=user_question,
        outputs=[ai_response, sources_output],
    )

    def update_status_indicator(status):
        # Re-render the status box HTML from the current global state.
        return get_connection_status()

    def update_input_state(status):
        # Enable the textbox and submit button only while connected.
        is_connected = status["status"] == "Connected"
        return gr.update(interactive=is_connected), gr.update(interactive=is_connected)

    status_updater = gr.State(connection_status)

    # Refresh the tracked status once per second.
    iface.load(
        lambda: connection_status,
        outputs=[status_updater],
        every=1,
    )

    status_updater.change(
        fn=update_status_indicator,
        inputs=[status_updater],
        outputs=[status_indicator],
    )

    status_updater.change(
        fn=update_input_state,
        inputs=[status_updater],
        outputs=[user_question, submit_btn],
    )
314
+
315
async def main():
    """Validate configuration, connect to Weaviate, then launch the UI."""
    # Fail fast with a clear log message if any required setting is missing.
    required_env_vars = ['WCS_URL', 'WCS_API_KEY', 'OPENAI_API_KEY', 'WEAVIATE_COLLECTION_NAME']
    for var in required_env_vars:
        if not os.getenv(var):
            logger.error(f"Environment variable {var} is not set!")
            return

    # Initialize the client before launching the interface
    await initialize_weaviate_client()

    # Launch the interface regardless of connection status.
    # BUG FIX: Blocks.launch() is synchronous and returns a non-awaitable
    # (app, local_url, share_url) tuple — `await iface.launch(...)` raises
    # TypeError. Call it directly; it blocks until the server shuts down.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)


if __name__ == "__main__":
    asyncio.run(main())