harpreetsahota committed
Commit 575fee0 · verified · 1 Parent(s): 93d55ac

Update assistant.py

Files changed (1): assistant.py +44 -49
assistant.py CHANGED
@@ -1,61 +1,56 @@
+from typing import Optional, Dict, Any, Union, Generator
+from huggingface_hub import InferenceClient
 from openai import OpenAI
-from typing import Optional, Dict, Any
+from prompt_template import PromptTemplate
 
 class AIAssistant:
-    """
-    A wrapper class for consistent LLM API interactions.
-
-    This class provides:
-    - Unified interface for different LLM providers
-    - Consistent handling of generation parameters
-    - Support for streaming responses
-
-    Attributes:
-        client: Initialized API client (OpenAI, Anthropic, etc.)
-        model: Name of the model to use
-    """
-    def __init__(self, client: OpenAI, model: str):
+    def __init__(self, client: Union[OpenAI, InferenceClient], model: str):
         self.client = client
         self.model = model
 
-    def generate_response(self,
-                          prompt_template: Any,
-                          generation_params: Optional[Dict] = None,
-                          stream: bool = False,
-                          **kwargs):
+    def generate_response(
+        self,
+        prompt_template: PromptTemplate,
+        messages: list[Dict[str, str]],
+        generation_params: Optional[Dict] = None,
+        stream: bool = True,
+    ) -> Generator[str, None, None]:
         """
-        Generate LLM response using pthe rovided template and parameters.
+        Generate LLM response using the provided template and parameters.
 
         Args:
-            prompt_template: Template object with format method
-            generation_params: Optional generation parameters
+            prompt_template: PromptTemplate object containing template and parameters
+            messages: List of message dictionaries with role and content
+            generation_params: Optional generation parameters (overrides template parameters)
             stream: Whether to stream the response
-            **kwargs: Variables for prompt template
 
-        Returns:
-            API response object or streamed response
-
-        Example:
-            assistant.generate_response(
-                prompt_template=template,
-                temperature=0.7,
-                topic="AI safety"
-            )
+        Yields:
+            Streamed response text
         """
-        messages = prompt_template.format(**kwargs)
-        params = generation_params or {}
-
-        completion = self.client.chat.completions.create(
-            model=self.model,
-            messages=messages,
-            stream=stream,
-            **params
-        )
-
-        if stream:
-            for chunk in completion:
-                if chunk.choices[0].delta.content is not None:
-                    print(chunk.choices[0].delta.content, end="")
-            return completion
-
-        return completion
+        params = generation_params or prompt_template.parameters
+
+        # Ensure messages are in correct format
+        formatted_messages = [
+            {"role": msg["role"], "content": str(msg["content"])}
+            for msg in messages
+        ]
+
+        try:
+            completion = self.client.chat.completions.create(
+                model=self.model,
+                messages=formatted_messages,
+                stream=stream,
+                **params
+            )
+
+            if stream:
+                response = ""
+                for chunk in completion:
+                    if chunk.choices[0].delta.content is not None:
+                        response += chunk.choices[0].delta.content
+                        yield response
+            else:
+                return completion.choices[0].message.content
+
+        except Exception as e:
+            yield f"Error generating response: {str(e)}"