Shyamnath committed

Commit 469eae6 · Parent: f526ba5

Push core package and essential files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. LICENSE +26 -0
  2. litellm/__init__.py +1084 -0
  3. litellm/_logging.py +167 -0
  4. litellm/_redis.py +333 -0
  5. litellm/_service_logger.py +311 -0
  6. litellm/_version.py +6 -0
  7. litellm/anthropic_interface/__init__.py +6 -0
  8. litellm/anthropic_interface/messages/__init__.py +117 -0
  9. litellm/anthropic_interface/readme.md +116 -0
  10. litellm/assistants/main.py +1484 -0
  11. litellm/assistants/utils.py +161 -0
  12. litellm/batch_completion/Readme.md +11 -0
  13. litellm/batch_completion/main.py +253 -0
  14. litellm/batches/batch_utils.py +182 -0
  15. litellm/batches/main.py +796 -0
  16. litellm/budget_manager.py +230 -0
  17. litellm/caching/Readme.md +40 -0
  18. litellm/caching/__init__.py +9 -0
  19. litellm/caching/_internal_lru_cache.py +30 -0
  20. litellm/caching/base_cache.py +55 -0
  21. litellm/caching/caching.py +818 -0
  22. litellm/caching/caching_handler.py +938 -0
  23. litellm/caching/disk_cache.py +88 -0
  24. litellm/caching/dual_cache.py +434 -0
  25. litellm/caching/in_memory_cache.py +203 -0
  26. litellm/caching/llm_caching_handler.py +39 -0
  27. litellm/caching/qdrant_semantic_cache.py +442 -0
  28. litellm/caching/redis_cache.py +1162 -0
  29. litellm/caching/redis_cluster_cache.py +59 -0
  30. litellm/caching/redis_semantic_cache.py +450 -0
  31. litellm/caching/s3_cache.py +159 -0
  32. litellm/constants.py +543 -0
  33. litellm/cost.json +5 -0
  34. litellm/cost_calculator.py +1378 -0
  35. litellm/exceptions.py +809 -0
  36. litellm/experimental_mcp_client/Readme.md +6 -0
  37. litellm/experimental_mcp_client/__init__.py +3 -0
  38. litellm/experimental_mcp_client/client.py +0 -0
  39. litellm/experimental_mcp_client/tools.py +111 -0
  40. litellm/files/main.py +891 -0
  41. litellm/fine_tuning/main.py +761 -0
  42. litellm/integrations/Readme.md +5 -0
  43. litellm/integrations/SlackAlerting/Readme.md +13 -0
  44. litellm/integrations/SlackAlerting/batching_handler.py +81 -0
  45. litellm/integrations/SlackAlerting/slack_alerting.py +1825 -0
  46. litellm/integrations/SlackAlerting/utils.py +92 -0
  47. litellm/integrations/__init__.py +1 -0
  48. litellm/integrations/_types/open_inference.py +389 -0
  49. litellm/integrations/additional_logging_utils.py +36 -0
  50. litellm/integrations/agentops/__init__.py +3 -0
LICENSE ADDED
@@ -0,0 +1,26 @@
1
+ Portions of this software are licensed as follows:
2
+
3
+ * All content that resides under the "enterprise/" directory of this repository, if that directory exists, is licensed under the license defined in "enterprise/LICENSE".
4
+ * Content outside of the above mentioned directories or restrictions above is available under the MIT license as defined below.
5
+ ---
6
+ MIT License
7
+
8
+ Copyright (c) 2023 Berri AI
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
litellm/__init__.py ADDED
@@ -0,0 +1,1084 @@
1
+ ### Hide pydantic namespace conflict warnings globally ###
2
+ import warnings
3
+
4
+ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
5
+ ### INIT VARIABLES ###########
6
+ import threading
7
+ import os
8
+ from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
9
+ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
10
+ from litellm.caching.caching import Cache, DualCache, RedisCache, InMemoryCache
11
+ from litellm.caching.llm_caching_handler import LLMClientCache
12
+ from litellm.types.llms.bedrock import COHERE_EMBEDDING_INPUT_TYPES
13
+ from litellm.types.utils import (
14
+ ImageObject,
15
+ BudgetConfig,
16
+ all_litellm_params,
17
+ all_litellm_params as _litellm_completion_params,
18
+ CredentialItem,
19
+ ) # maintain backwards compatibility for root param
20
+ from litellm._logging import (
21
+ set_verbose,
22
+ _turn_on_debug,
23
+ verbose_logger,
24
+ json_logs,
25
+ _turn_on_json,
26
+ log_level,
27
+ )
28
+ import re
29
+ from litellm.constants import (
30
+ DEFAULT_BATCH_SIZE,
31
+ DEFAULT_FLUSH_INTERVAL_SECONDS,
32
+ ROUTER_MAX_FALLBACKS,
33
+ DEFAULT_MAX_RETRIES,
34
+ DEFAULT_REPLICATE_POLLING_RETRIES,
35
+ DEFAULT_REPLICATE_POLLING_DELAY_SECONDS,
36
+ LITELLM_CHAT_PROVIDERS,
37
+ HUMANLOOP_PROMPT_CACHE_TTL_SECONDS,
38
+ OPENAI_CHAT_COMPLETION_PARAMS,
39
+ OPENAI_CHAT_COMPLETION_PARAMS as _openai_completion_params, # backwards compatibility
40
+ OPENAI_FINISH_REASONS,
41
+ OPENAI_FINISH_REASONS as _openai_finish_reasons, # backwards compatibility
42
+ openai_compatible_endpoints,
43
+ openai_compatible_providers,
44
+ openai_text_completion_compatible_providers,
45
+ _openai_like_providers,
46
+ replicate_models,
47
+ clarifai_models,
48
+ huggingface_models,
49
+ empower_models,
50
+ together_ai_models,
51
+ baseten_models,
52
+ REPEATED_STREAMING_CHUNK_LIMIT,
53
+ request_timeout,
54
+ open_ai_embedding_models,
55
+ cohere_embedding_models,
56
+ bedrock_embedding_models,
57
+ known_tokenizer_config,
58
+ BEDROCK_INVOKE_PROVIDERS_LITERAL,
59
+ DEFAULT_MAX_TOKENS,
60
+ DEFAULT_SOFT_BUDGET,
61
+ DEFAULT_ALLOWED_FAILS,
62
+ )
63
+ from litellm.types.guardrails import GuardrailItem
64
+ from litellm.proxy._types import (
65
+ KeyManagementSystem,
66
+ KeyManagementSettings,
67
+ LiteLLM_UpperboundKeyGenerateParams,
68
+ )
69
+ from litellm.types.proxy.management_endpoints.ui_sso import DefaultTeamSSOParams
70
+ from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
71
+ from litellm.integrations.custom_logger import CustomLogger
72
+ from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
73
+ import httpx
74
+ import dotenv
75
+
76
+ litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV"
77
+ if litellm_mode == "DEV":
78
+ dotenv.load_dotenv()
79
+ ################################################
80
+ if set_verbose == True:
81
+ _turn_on_debug()
82
+ ################################################
83
+ ### Callbacks /Logging / Success / Failure Handlers #####
84
+ CALLBACK_TYPES = Union[str, Callable, CustomLogger]
85
+ input_callback: List[CALLBACK_TYPES] = []
86
+ success_callback: List[CALLBACK_TYPES] = []
87
+ failure_callback: List[CALLBACK_TYPES] = []
88
+ service_callback: List[CALLBACK_TYPES] = []
89
+ logging_callback_manager = LoggingCallbackManager()
90
+ _custom_logger_compatible_callbacks_literal = Literal[
91
+ "lago",
92
+ "openmeter",
93
+ "logfire",
94
+ "literalai",
95
+ "dynamic_rate_limiter",
96
+ "langsmith",
97
+ "prometheus",
98
+ "otel",
99
+ "datadog",
100
+ "datadog_llm_observability",
101
+ "galileo",
102
+ "braintrust",
103
+ "arize",
104
+ "arize_phoenix",
105
+ "langtrace",
106
+ "gcs_bucket",
107
+ "azure_storage",
108
+ "opik",
109
+ "argilla",
110
+ "mlflow",
111
+ "langfuse",
112
+ "pagerduty",
113
+ "humanloop",
114
+ "gcs_pubsub",
115
+ "agentops",
116
+ "anthropic_cache_control_hook",
117
+ "bedrock_knowledgebase_hook",
118
+ ]
119
+ logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
120
+ _known_custom_logger_compatible_callbacks: List = list(
121
+ get_args(_custom_logger_compatible_callbacks_literal)
122
+ )
123
+ callbacks: List[
124
+ Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger]
125
+ ] = []
126
+ langfuse_default_tags: Optional[List[str]] = None
127
+ langsmith_batch_size: Optional[int] = None
128
+ prometheus_initialize_budget_metrics: Optional[bool] = False
129
+ require_auth_for_metrics_endpoint: Optional[bool] = False
130
+ argilla_batch_size: Optional[int] = None
131
+ datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
132
+ gcs_pub_sub_use_v1: Optional[bool] = (
133
+ False # if you want to use v1 gcs pubsub logged payload
134
+ )
135
+ argilla_transformation_object: Optional[Dict[str, Any]] = None
136
+ _async_input_callback: List[Union[str, Callable, CustomLogger]] = (
137
+ []
138
+ ) # internal variable - async custom callbacks are routed here.
139
+ _async_success_callback: List[Union[str, Callable, CustomLogger]] = (
140
+ []
141
+ ) # internal variable - async custom callbacks are routed here.
142
+ _async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
143
+ []
144
+ ) # internal variable - async custom callbacks are routed here.
145
+ pre_call_rules: List[Callable] = []
146
+ post_call_rules: List[Callable] = []
147
+ turn_off_message_logging: Optional[bool] = False
148
+ log_raw_request_response: bool = False
149
+ redact_messages_in_exceptions: Optional[bool] = False
150
+ redact_user_api_key_info: Optional[bool] = False
151
+ filter_invalid_headers: Optional[bool] = False
152
+ add_user_information_to_llm_headers: Optional[bool] = (
153
+ None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
154
+ )
155
+ store_audit_logs = False # Enterprise feature, allow users to see audit logs
156
+ ### end of callbacks #############
157
+
158
+ email: Optional[str] = (
159
+ None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
160
+ )
161
+ token: Optional[str] = (
162
+ None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
163
+ )
164
+ telemetry = True
165
+ max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
166
+ drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
167
+ modify_params = bool(os.getenv("LITELLM_MODIFY_PARAMS", False))
168
+ retry = True
169
+ ### AUTH ###
170
+ api_key: Optional[str] = None
171
+ openai_key: Optional[str] = None
172
+ groq_key: Optional[str] = None
173
+ databricks_key: Optional[str] = None
174
+ openai_like_key: Optional[str] = None
175
+ azure_key: Optional[str] = None
176
+ anthropic_key: Optional[str] = None
177
+ replicate_key: Optional[str] = None
178
+ cohere_key: Optional[str] = None
179
+ infinity_key: Optional[str] = None
180
+ clarifai_key: Optional[str] = None
181
+ maritalk_key: Optional[str] = None
182
+ ai21_key: Optional[str] = None
183
+ ollama_key: Optional[str] = None
184
+ openrouter_key: Optional[str] = None
185
+ predibase_key: Optional[str] = None
186
+ huggingface_key: Optional[str] = None
187
+ vertex_project: Optional[str] = None
188
+ vertex_location: Optional[str] = None
189
+ predibase_tenant_id: Optional[str] = None
190
+ togetherai_api_key: Optional[str] = None
191
+ cloudflare_api_key: Optional[str] = None
192
+ baseten_key: Optional[str] = None
193
+ aleph_alpha_key: Optional[str] = None
194
+ nlp_cloud_key: Optional[str] = None
195
+ snowflake_key: Optional[str] = None
196
+ common_cloud_provider_auth_params: dict = {
197
+ "params": ["project", "region_name", "token"],
198
+ "providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
199
+ }
200
+ use_client: bool = False
201
+ ssl_verify: Union[str, bool] = True
202
+ ssl_certificate: Optional[str] = None
203
+ disable_streaming_logging: bool = False
204
+ disable_add_transform_inline_image_block: bool = False
205
+ in_memory_llm_clients_cache: LLMClientCache = LLMClientCache()
206
+ safe_memory_mode: bool = False
207
+ enable_azure_ad_token_refresh: Optional[bool] = False
208
+ ### DEFAULT AZURE API VERSION ###
209
+ AZURE_DEFAULT_API_VERSION = "2025-02-01-preview" # this is updated to the latest
210
+ ### DEFAULT WATSONX API VERSION ###
211
+ WATSONX_DEFAULT_API_VERSION = "2024-03-13"
212
+ ### COHERE EMBEDDINGS DEFAULT TYPE ###
213
+ COHERE_DEFAULT_EMBEDDING_INPUT_TYPE: COHERE_EMBEDDING_INPUT_TYPES = "search_document"
214
+ ### CREDENTIALS ###
215
+ credential_list: List[CredentialItem] = []
216
+ ### GUARDRAILS ###
217
+ llamaguard_model_name: Optional[str] = None
218
+ openai_moderations_model_name: Optional[str] = None
219
+ presidio_ad_hoc_recognizers: Optional[str] = None
220
+ google_moderation_confidence_threshold: Optional[float] = None
221
+ llamaguard_unsafe_content_categories: Optional[str] = None
222
+ blocked_user_list: Optional[Union[str, List]] = None
223
+ banned_keywords_list: Optional[Union[str, List]] = None
224
+ llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
225
+ guardrail_name_config_map: Dict[str, GuardrailItem] = {}
226
+ ##################
227
+ ### PREVIEW FEATURES ###
228
+ enable_preview_features: bool = False
229
+ return_response_headers: bool = (
230
+ False # get response headers from LLM Api providers - example x-remaining-requests,
231
+ )
232
+ enable_json_schema_validation: bool = False
233
+ ##################
234
+ logging: bool = True
235
+ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
236
+ enable_caching_on_provider_specific_optional_params: bool = (
237
+ False # feature-flag for caching on optional params - e.g. 'top_k'
238
+ )
239
+ caching: bool = (
240
+ False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
241
+ )
242
+ caching_with_models: bool = (
243
+ False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
244
+ )
245
+ cache: Optional[Cache] = (
246
+ None # cache object <- use this - https://docs.litellm.ai/docs/caching
247
+ )
248
+ default_in_memory_ttl: Optional[float] = None
249
+ default_redis_ttl: Optional[float] = None
250
+ default_redis_batch_cache_expiry: Optional[float] = None
251
+ model_alias_map: Dict[str, str] = {}
252
+ model_group_alias_map: Dict[str, str] = {}
253
+ max_budget: float = 0.0 # set the max budget across all providers
254
+ budget_duration: Optional[str] = (
255
+ None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
256
+ )
257
+ default_soft_budget: float = (
258
+ DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
259
+ )
260
+ forward_traceparent_to_llm_provider: bool = False
261
+
262
+
263
+ _current_cost = 0.0 # private variable, used if max budget is set
264
+ error_logs: Dict = {}
265
+ add_function_to_prompt: bool = (
266
+ False # if function calling not supported by api, append function call details to system prompt
267
+ )
268
+ client_session: Optional[httpx.Client] = None
269
+ aclient_session: Optional[httpx.AsyncClient] = None
270
+ model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
271
+ model_cost_map_url: str = (
272
+ "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
273
+ )
274
+ suppress_debug_info = False
275
+ dynamodb_table_name: Optional[str] = None
276
+ s3_callback_params: Optional[Dict] = None
277
+ generic_logger_headers: Optional[Dict] = None
278
+ default_key_generate_params: Optional[Dict] = None
279
+ upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
280
+ key_generation_settings: Optional[StandardKeyGenerationConfig] = None
281
+ default_internal_user_params: Optional[Dict] = None
282
+ default_team_params: Optional[Union[DefaultTeamSSOParams, Dict]] = None
283
+ default_team_settings: Optional[List] = None
284
+ max_user_budget: Optional[float] = None
285
+ default_max_internal_user_budget: Optional[float] = None
286
+ max_internal_user_budget: Optional[float] = None
287
+ max_ui_session_budget: Optional[float] = 10 # $10 USD budgets for UI Chat sessions
288
+ internal_user_budget_duration: Optional[str] = None
289
+ tag_budget_config: Optional[Dict[str, BudgetConfig]] = None
290
+ max_end_user_budget: Optional[float] = None
291
+ disable_end_user_cost_tracking: Optional[bool] = None
292
+ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
293
+ custom_prometheus_metadata_labels: List[str] = []
294
+ #### REQUEST PRIORITIZATION ####
295
+ priority_reservation: Optional[Dict[str, float]] = None
296
+ force_ipv4: bool = (
297
+ False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
298
+ )
299
+ module_level_aclient = AsyncHTTPHandler(
300
+ timeout=request_timeout, client_alias="module level aclient"
301
+ )
302
+ module_level_client = HTTPHandler(timeout=request_timeout)
303
+
304
+ #### RETRIES ####
305
+ num_retries: Optional[int] = None # per model endpoint
306
+ max_fallbacks: Optional[int] = None
307
+ default_fallbacks: Optional[List] = None
308
+ fallbacks: Optional[List] = None
309
+ context_window_fallbacks: Optional[List] = None
310
+ content_policy_fallbacks: Optional[List] = None
311
+ allowed_fails: int = 3
312
+ num_retries_per_request: Optional[int] = (
313
+ None # for the request overall (incl. fallbacks + model retries)
314
+ )
315
+ ####### SECRET MANAGERS #####################
316
+ secret_manager_client: Optional[Any] = (
317
+ None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
318
+ )
319
+ _google_kms_resource_name: Optional[str] = None
320
+ _key_management_system: Optional[KeyManagementSystem] = None
321
+ _key_management_settings: KeyManagementSettings = KeyManagementSettings()
322
+ #### PII MASKING ####
323
+ output_parse_pii: bool = False
324
+ #############################################
325
+ from litellm.litellm_core_utils.get_model_cost_map import get_model_cost_map
326
+
327
+ model_cost = get_model_cost_map(url=model_cost_map_url)
328
+ custom_prompt_dict: Dict[str, dict] = {}
329
+ check_provider_endpoint = False
330
+
331
+
332
+ ####### THREAD-SPECIFIC DATA ####################
333
+ class MyLocal(threading.local):
334
+ def __init__(self):
335
+ self.user = "Hello World"
336
+
337
+
338
+ _thread_context = MyLocal()
339
+
340
+
341
+ def identify(event_details):
342
+ # Store user in thread local data
343
+ if "user" in event_details:
344
+ _thread_context.user = event_details["user"]
345
+
346
+
347
+ ####### ADDITIONAL PARAMS ################### configurable params if you use proxy models like Helicone, map spend to org id, etc.
348
+ api_base: Optional[str] = None
349
+ headers = None
350
+ api_version = None
351
+ organization = None
352
+ project = None
353
+ config_path = None
354
+ vertex_ai_safety_settings: Optional[dict] = None
355
+ BEDROCK_CONVERSE_MODELS = [
356
+ "anthropic.claude-3-7-sonnet-20250219-v1:0",
357
+ "anthropic.claude-3-5-haiku-20241022-v1:0",
358
+ "anthropic.claude-3-5-sonnet-20241022-v2:0",
359
+ "anthropic.claude-3-5-sonnet-20240620-v1:0",
360
+ "anthropic.claude-3-opus-20240229-v1:0",
361
+ "anthropic.claude-3-sonnet-20240229-v1:0",
362
+ "anthropic.claude-3-haiku-20240307-v1:0",
363
+ "anthropic.claude-v2",
364
+ "anthropic.claude-v2:1",
365
+ "anthropic.claude-v1",
366
+ "anthropic.claude-instant-v1",
367
+ "ai21.jamba-instruct-v1:0",
368
+ "meta.llama3-70b-instruct-v1:0",
369
+ "meta.llama3-8b-instruct-v1:0",
370
+ "meta.llama3-1-8b-instruct-v1:0",
371
+ "meta.llama3-1-70b-instruct-v1:0",
372
+ "meta.llama3-1-405b-instruct-v1:0",
373
+ "meta.llama3-70b-instruct-v1:0",
374
+ "mistral.mistral-large-2407-v1:0",
375
+ "mistral.mistral-large-2402-v1:0",
376
+ "meta.llama3-2-1b-instruct-v1:0",
377
+ "meta.llama3-2-3b-instruct-v1:0",
378
+ "meta.llama3-2-11b-instruct-v1:0",
379
+ "meta.llama3-2-90b-instruct-v1:0",
380
+ ]
381
+
382
+ ####### COMPLETION MODELS ###################
383
+ open_ai_chat_completion_models: List = []
384
+ open_ai_text_completion_models: List = []
385
+ cohere_models: List = []
386
+ cohere_chat_models: List = []
387
+ mistral_chat_models: List = []
388
+ text_completion_codestral_models: List = []
389
+ anthropic_models: List = []
390
+ openrouter_models: List = []
391
+ vertex_language_models: List = []
392
+ vertex_vision_models: List = []
393
+ vertex_chat_models: List = []
394
+ vertex_code_chat_models: List = []
395
+ vertex_ai_image_models: List = []
396
+ vertex_text_models: List = []
397
+ vertex_code_text_models: List = []
398
+ vertex_embedding_models: List = []
399
+ vertex_anthropic_models: List = []
400
+ vertex_llama3_models: List = []
401
+ vertex_ai_ai21_models: List = []
402
+ vertex_mistral_models: List = []
403
+ ai21_models: List = []
404
+ ai21_chat_models: List = []
405
+ nlp_cloud_models: List = []
406
+ aleph_alpha_models: List = []
407
+ bedrock_models: List = []
408
+ bedrock_converse_models: List = BEDROCK_CONVERSE_MODELS
409
+ fireworks_ai_models: List = []
410
+ fireworks_ai_embedding_models: List = []
411
+ deepinfra_models: List = []
412
+ perplexity_models: List = []
413
+ watsonx_models: List = []
414
+ gemini_models: List = []
415
+ xai_models: List = []
416
+ deepseek_models: List = []
417
+ azure_ai_models: List = []
418
+ jina_ai_models: List = []
419
+ voyage_models: List = []
420
+ infinity_models: List = []
421
+ databricks_models: List = []
422
+ cloudflare_models: List = []
423
+ codestral_models: List = []
424
+ friendliai_models: List = []
425
+ palm_models: List = []
426
+ groq_models: List = []
427
+ azure_models: List = []
428
+ azure_text_models: List = []
429
+ anyscale_models: List = []
430
+ cerebras_models: List = []
431
+ galadriel_models: List = []
432
+ sambanova_models: List = []
433
+ assemblyai_models: List = []
434
+ snowflake_models: List = []
435
+
436
+
437
+ def is_bedrock_pricing_only_model(key: str) -> bool:
438
+ """
439
+ Excludes keys with the pattern 'bedrock/<region>/<model>'. These are in the model_prices_and_context_window.json file for pricing purposes only.
440
+
441
+ Args:
442
+ key (str): A key to filter.
443
+
444
+ Returns:
445
+ bool: True if the key matches the Bedrock pattern, False otherwise.
446
+ """
447
+ # Regex to match 'bedrock/<region>/<model>'
448
+ bedrock_pattern = re.compile(r"^bedrock/[a-zA-Z0-9_-]+/.+$")
449
+
450
+ if "month-commitment" in key:
451
+ return True
452
+
453
+ is_match = bedrock_pattern.match(key)
454
+ return is_match is not None
455
+
456
+
457
+ def is_openai_finetune_model(key: str) -> bool:
458
+ """
459
+ Excludes model cost keys with the pattern 'ft:<model>'. These are in the model_prices_and_context_window.json file for pricing purposes only.
460
+
461
+ Args:
462
+ key (str): A key to filter.
463
+
464
+ Returns:
465
+ bool: True if the key matches the OpenAI finetune pattern, False otherwise.
466
+ """
467
+ return key.startswith("ft:") and not key.count(":") > 1
468
+
469
+
470
+ def add_known_models():
471
+ for key, value in model_cost.items():
472
+ if value.get("litellm_provider") == "openai" and not is_openai_finetune_model(
473
+ key
474
+ ):
475
+ open_ai_chat_completion_models.append(key)
476
+ elif value.get("litellm_provider") == "text-completion-openai":
477
+ open_ai_text_completion_models.append(key)
478
+ elif value.get("litellm_provider") == "azure_text":
479
+ azure_text_models.append(key)
480
+ elif value.get("litellm_provider") == "cohere":
481
+ cohere_models.append(key)
482
+ elif value.get("litellm_provider") == "cohere_chat":
483
+ cohere_chat_models.append(key)
484
+ elif value.get("litellm_provider") == "mistral":
485
+ mistral_chat_models.append(key)
486
+ elif value.get("litellm_provider") == "anthropic":
487
+ anthropic_models.append(key)
488
+ elif value.get("litellm_provider") == "empower":
489
+ empower_models.append(key)
490
+ elif value.get("litellm_provider") == "openrouter":
491
+ openrouter_models.append(key)
492
+ elif value.get("litellm_provider") == "vertex_ai-text-models":
493
+ vertex_text_models.append(key)
494
+ elif value.get("litellm_provider") == "vertex_ai-code-text-models":
495
+ vertex_code_text_models.append(key)
496
+ elif value.get("litellm_provider") == "vertex_ai-language-models":
497
+ vertex_language_models.append(key)
498
+ elif value.get("litellm_provider") == "vertex_ai-vision-models":
499
+ vertex_vision_models.append(key)
500
+ elif value.get("litellm_provider") == "vertex_ai-chat-models":
501
+ vertex_chat_models.append(key)
502
+ elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
503
+ vertex_code_chat_models.append(key)
504
+ elif value.get("litellm_provider") == "vertex_ai-embedding-models":
505
+ vertex_embedding_models.append(key)
506
+ elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
507
+ key = key.replace("vertex_ai/", "")
508
+ vertex_anthropic_models.append(key)
509
+ elif value.get("litellm_provider") == "vertex_ai-llama_models":
510
+ key = key.replace("vertex_ai/", "")
511
+ vertex_llama3_models.append(key)
512
+ elif value.get("litellm_provider") == "vertex_ai-mistral_models":
513
+ key = key.replace("vertex_ai/", "")
514
+ vertex_mistral_models.append(key)
515
+ elif value.get("litellm_provider") == "vertex_ai-ai21_models":
516
+ key = key.replace("vertex_ai/", "")
517
+ vertex_ai_ai21_models.append(key)
518
+ elif value.get("litellm_provider") == "vertex_ai-image-models":
519
+ key = key.replace("vertex_ai/", "")
520
+ vertex_ai_image_models.append(key)
521
+ elif value.get("litellm_provider") == "ai21":
522
+ if value.get("mode") == "chat":
523
+ ai21_chat_models.append(key)
524
+ else:
525
+ ai21_models.append(key)
526
+ elif value.get("litellm_provider") == "nlp_cloud":
527
+ nlp_cloud_models.append(key)
528
+ elif value.get("litellm_provider") == "aleph_alpha":
529
+ aleph_alpha_models.append(key)
530
+ elif value.get(
531
+ "litellm_provider"
532
+ ) == "bedrock" and not is_bedrock_pricing_only_model(key):
533
+ bedrock_models.append(key)
534
+ elif value.get("litellm_provider") == "bedrock_converse":
535
+ bedrock_converse_models.append(key)
536
+ elif value.get("litellm_provider") == "deepinfra":
537
+ deepinfra_models.append(key)
538
+ elif value.get("litellm_provider") == "perplexity":
539
+ perplexity_models.append(key)
540
+ elif value.get("litellm_provider") == "watsonx":
541
+ watsonx_models.append(key)
542
+ elif value.get("litellm_provider") == "gemini":
543
+ gemini_models.append(key)
544
+ elif value.get("litellm_provider") == "fireworks_ai":
545
+ # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
546
+ if "-to-" not in key and "fireworks-ai-default" not in key:
547
+ fireworks_ai_models.append(key)
548
+ elif value.get("litellm_provider") == "fireworks_ai-embedding-models":
549
+ # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
550
+ if "-to-" not in key:
551
+ fireworks_ai_embedding_models.append(key)
552
+ elif value.get("litellm_provider") == "text-completion-codestral":
553
+ text_completion_codestral_models.append(key)
554
+ elif value.get("litellm_provider") == "xai":
555
+ xai_models.append(key)
556
+ elif value.get("litellm_provider") == "deepseek":
557
+ deepseek_models.append(key)
558
+ elif value.get("litellm_provider") == "azure_ai":
559
+ azure_ai_models.append(key)
560
+ elif value.get("litellm_provider") == "voyage":
561
+ voyage_models.append(key)
562
+ elif value.get("litellm_provider") == "infinity":
563
+ infinity_models.append(key)
564
+ elif value.get("litellm_provider") == "databricks":
565
+ databricks_models.append(key)
566
+ elif value.get("litellm_provider") == "cloudflare":
567
+ cloudflare_models.append(key)
568
+ elif value.get("litellm_provider") == "codestral":
569
+ codestral_models.append(key)
570
+ elif value.get("litellm_provider") == "friendliai":
571
+ friendliai_models.append(key)
572
+ elif value.get("litellm_provider") == "palm":
573
+ palm_models.append(key)
574
+ elif value.get("litellm_provider") == "groq":
575
+ groq_models.append(key)
576
+ elif value.get("litellm_provider") == "azure":
577
+ azure_models.append(key)
578
+ elif value.get("litellm_provider") == "anyscale":
579
+ anyscale_models.append(key)
580
+ elif value.get("litellm_provider") == "cerebras":
581
+ cerebras_models.append(key)
582
+ elif value.get("litellm_provider") == "galadriel":
583
+ galadriel_models.append(key)
584
+ elif value.get("litellm_provider") == "sambanova_models":
585
+ sambanova_models.append(key)
586
+ elif value.get("litellm_provider") == "assemblyai":
587
+ assemblyai_models.append(key)
588
+ elif value.get("litellm_provider") == "jina_ai":
589
+ jina_ai_models.append(key)
590
+ elif value.get("litellm_provider") == "snowflake":
591
+ snowflake_models.append(key)
592
+
593
+
594
+ add_known_models()
595
+ # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
596
+
597
+ # this is maintained for Exception Mapping
598
+
599
+
600
+ # used for Cost Tracking & Token counting
601
+ # https://azure.microsoft.com/en-in/pricing/details/cognitive-services/openai-service/
602
+ # Azure returns gpt-35-turbo in their responses, we need to map this to azure/gpt-3.5-turbo for token counting
603
+ azure_llms = {
604
+ "gpt-35-turbo": "azure/gpt-35-turbo",
605
+ "gpt-35-turbo-16k": "azure/gpt-35-turbo-16k",
606
+ "gpt-35-turbo-instruct": "azure/gpt-35-turbo-instruct",
607
+ }
608
+
609
+ azure_embedding_models = {
610
+ "ada": "azure/ada",
611
+ }
612
+
613
+ petals_models = [
614
+ "petals-team/StableBeluga2",
615
+ ]
616
+
617
+ ollama_models = ["llama2"]
618
+
619
+ maritalk_models = ["maritalk"]
620
+
621
+
622
+ model_list = (
623
+ open_ai_chat_completion_models
624
+ + open_ai_text_completion_models
625
+ + cohere_models
626
+ + cohere_chat_models
627
+ + anthropic_models
628
+ + replicate_models
629
+ + openrouter_models
630
+ + huggingface_models
631
+ + vertex_chat_models
632
+ + vertex_text_models
633
+ + ai21_models
634
+ + ai21_chat_models
635
+ + together_ai_models
636
+ + baseten_models
637
+ + aleph_alpha_models
638
+ + nlp_cloud_models
639
+ + ollama_models
640
+ + bedrock_models
641
+ + deepinfra_models
642
+ + perplexity_models
643
+ + maritalk_models
644
+ + vertex_language_models
645
+ + watsonx_models
646
+ + gemini_models
647
+ + text_completion_codestral_models
648
+ + xai_models
649
+ + deepseek_models
650
+ + azure_ai_models
651
+ + voyage_models
652
+ + infinity_models
653
+ + databricks_models
654
+ + cloudflare_models
655
+ + codestral_models
656
+ + friendliai_models
657
+ + palm_models
658
+ + groq_models
659
+ + azure_models
660
+ + anyscale_models
661
+ + cerebras_models
662
+ + galadriel_models
663
+ + sambanova_models
664
+ + azure_text_models
665
+ + assemblyai_models
666
+ + jina_ai_models
667
+ + snowflake_models
668
+ )
669
+
670
+ model_list_set = set(model_list)
671
+
672
+ provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
673
+
674
+
675
+ models_by_provider: dict = {
676
+ "openai": open_ai_chat_completion_models + open_ai_text_completion_models,
677
+ "text-completion-openai": open_ai_text_completion_models,
678
+ "cohere": cohere_models + cohere_chat_models,
679
+ "cohere_chat": cohere_chat_models,
680
+ "anthropic": anthropic_models,
681
+ "replicate": replicate_models,
682
+ "huggingface": huggingface_models,
683
+ "together_ai": together_ai_models,
684
+ "baseten": baseten_models,
685
+ "openrouter": openrouter_models,
686
+ "vertex_ai": vertex_chat_models
687
+ + vertex_text_models
688
+ + vertex_anthropic_models
689
+ + vertex_vision_models
690
+ + vertex_language_models,
691
+ "ai21": ai21_models,
692
+ "bedrock": bedrock_models + bedrock_converse_models,
693
+ "petals": petals_models,
694
+ "ollama": ollama_models,
695
+ "deepinfra": deepinfra_models,
696
+ "perplexity": perplexity_models,
697
+ "maritalk": maritalk_models,
698
+ "watsonx": watsonx_models,
699
+ "gemini": gemini_models,
700
+ "fireworks_ai": fireworks_ai_models + fireworks_ai_embedding_models,
701
+ "aleph_alpha": aleph_alpha_models,
702
+ "text-completion-codestral": text_completion_codestral_models,
703
+ "xai": xai_models,
704
+ "deepseek": deepseek_models,
705
+ "mistral": mistral_chat_models,
706
+ "azure_ai": azure_ai_models,
707
+ "voyage": voyage_models,
708
+ "infinity": infinity_models,
709
+ "databricks": databricks_models,
710
+ "cloudflare": cloudflare_models,
711
+ "codestral": codestral_models,
712
+ "nlp_cloud": nlp_cloud_models,
713
+ "friendliai": friendliai_models,
714
+ "palm": palm_models,
715
+ "groq": groq_models,
716
+ "azure": azure_models + azure_text_models,
717
+ "azure_text": azure_text_models,
718
+ "anyscale": anyscale_models,
719
+ "cerebras": cerebras_models,
720
+ "galadriel": galadriel_models,
721
+ "sambanova": sambanova_models,
722
+ "assemblyai": assemblyai_models,
723
+ "jina_ai": jina_ai_models,
724
+ "snowflake": snowflake_models,
725
+ }
726
+
727
+ # mapping for those models which have larger equivalents
728
+ longer_context_model_fallback_dict: dict = {
729
+ # openai chat completion models
730
+ "gpt-3.5-turbo": "gpt-3.5-turbo-16k",
731
+ "gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301",
732
+ "gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613",
733
+ "gpt-4": "gpt-4-32k",
734
+ "gpt-4-0314": "gpt-4-32k-0314",
735
+ "gpt-4-0613": "gpt-4-32k-0613",
736
+ # anthropic
737
+ "claude-instant-1": "claude-2",
738
+ "claude-instant-1.2": "claude-2",
739
+ # vertexai
740
+ "chat-bison": "chat-bison-32k",
741
+ "chat-bison@001": "chat-bison-32k",
742
+ "codechat-bison": "codechat-bison-32k",
743
+ "codechat-bison@001": "codechat-bison-32k",
744
+ # openrouter
745
+ "openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k",
746
+ "openrouter/anthropic/claude-instant-v1": "openrouter/anthropic/claude-2",
747
+ }
748
+
749
+ ####### EMBEDDING MODELS ###################
750
+
751
+ all_embedding_models = (
752
+ open_ai_embedding_models
753
+ + cohere_embedding_models
754
+ + bedrock_embedding_models
755
+ + vertex_embedding_models
756
+ + fireworks_ai_embedding_models
757
+ )
758
+
759
+ ####### IMAGE GENERATION MODELS ###################
760
+ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
761
+
762
+ from .timeout import timeout
763
+ from .cost_calculator import completion_cost
764
+ from litellm.litellm_core_utils.litellm_logging import Logging, modify_integration
765
+ from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
766
+ from litellm.litellm_core_utils.core_helpers import remove_index_from_tool_calls
767
+ from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
768
+ from .utils import (
769
+ client,
770
+ exception_type,
771
+ get_optional_params,
772
+ get_response_string,
773
+ token_counter,
774
+ create_pretrained_tokenizer,
775
+ create_tokenizer,
776
+ supports_function_calling,
777
+ supports_web_search,
778
+ supports_response_schema,
779
+ supports_parallel_function_calling,
780
+ supports_vision,
781
+ supports_audio_input,
782
+ supports_audio_output,
783
+ supports_system_messages,
784
+ supports_reasoning,
785
+ get_litellm_params,
786
+ acreate,
787
+ get_max_tokens,
788
+ get_model_info,
789
+ register_prompt_template,
790
+ validate_environment,
791
+ check_valid_key,
792
+ register_model,
793
+ encode,
794
+ decode,
795
+ _calculate_retry_after,
796
+ _should_retry,
797
+ get_supported_openai_params,
798
+ get_api_base,
799
+ get_first_chars_messages,
800
+ ModelResponse,
801
+ ModelResponseStream,
802
+ EmbeddingResponse,
803
+ ImageResponse,
804
+ TranscriptionResponse,
805
+ TextCompletionResponse,
806
+ get_provider_fields,
807
+ ModelResponseListIterator,
808
+ )
809
+
810
+ ALL_LITELLM_RESPONSE_TYPES = [
811
+ ModelResponse,
812
+ EmbeddingResponse,
813
+ ImageResponse,
814
+ TranscriptionResponse,
815
+ TextCompletionResponse,
816
+ ]
817
+
818
+ from .llms.custom_llm import CustomLLM
819
+ from .llms.bedrock.chat.converse_transformation import AmazonConverseConfig
820
+ from .llms.openai_like.chat.handler import OpenAILikeChatConfig
821
+ from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
822
+ from .llms.galadriel.chat.transformation import GaladrielChatConfig
823
+ from .llms.github.chat.transformation import GithubChatConfig
824
+ from .llms.empower.chat.transformation import EmpowerChatConfig
825
+ from .llms.huggingface.chat.transformation import HuggingFaceChatConfig
826
+ from .llms.huggingface.embedding.transformation import HuggingFaceEmbeddingConfig
827
+ from .llms.oobabooga.chat.transformation import OobaboogaConfig
828
+ from .llms.maritalk import MaritalkConfig
829
+ from .llms.openrouter.chat.transformation import OpenrouterConfig
830
+ from .llms.anthropic.chat.transformation import AnthropicConfig
831
+ from .llms.anthropic.common_utils import AnthropicModelInfo
832
+ from .llms.groq.stt.transformation import GroqSTTConfig
833
+ from .llms.anthropic.completion.transformation import AnthropicTextConfig
834
+ from .llms.triton.completion.transformation import TritonConfig
835
+ from .llms.triton.completion.transformation import TritonGenerateConfig
836
+ from .llms.triton.completion.transformation import TritonInferConfig
837
+ from .llms.triton.embedding.transformation import TritonEmbeddingConfig
838
+ from .llms.databricks.chat.transformation import DatabricksConfig
839
+ from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
840
+ from .llms.predibase.chat.transformation import PredibaseConfig
841
+ from .llms.replicate.chat.transformation import ReplicateConfig
842
+ from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
843
+ from .llms.snowflake.chat.transformation import SnowflakeConfig
844
+ from .llms.cohere.rerank.transformation import CohereRerankConfig
845
+ from .llms.cohere.rerank_v2.transformation import CohereRerankV2Config
846
+ from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
847
+ from .llms.infinity.rerank.transformation import InfinityRerankConfig
848
+ from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
849
+ from .llms.clarifai.chat.transformation import ClarifaiConfig
850
+ from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
851
+ from .llms.anthropic.experimental_pass_through.messages.transformation import (
852
+ AnthropicMessagesConfig,
853
+ )
854
+ from .llms.together_ai.chat import TogetherAIConfig
855
+ from .llms.together_ai.completion.transformation import TogetherAITextCompletionConfig
856
+ from .llms.cloudflare.chat.transformation import CloudflareChatConfig
857
+ from .llms.deprecated_providers.palm import (
858
+ PalmConfig,
859
+ ) # here to prevent breaking changes
860
+ from .llms.nlp_cloud.chat.handler import NLPCloudConfig
861
+ from .llms.petals.completion.transformation import PetalsConfig
862
+ from .llms.deprecated_providers.aleph_alpha import AlephAlphaConfig
863
+ from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
864
+ VertexGeminiConfig,
865
+ VertexGeminiConfig as VertexAIConfig,
866
+ )
867
+ from .llms.gemini.common_utils import GeminiModelInfo
868
+ from .llms.gemini.chat.transformation import (
869
+ GoogleAIStudioGeminiConfig,
870
+ GoogleAIStudioGeminiConfig as GeminiConfig, # aliased to maintain backwards compatibility
871
+ )
872
+
873
+
874
+ from .llms.vertex_ai.vertex_embeddings.transformation import (
875
+ VertexAITextEmbeddingConfig,
876
+ )
877
+
878
+ vertexAITextEmbeddingConfig = VertexAITextEmbeddingConfig()
879
+
880
+ from .llms.vertex_ai.vertex_ai_partner_models.anthropic.transformation import (
881
+ VertexAIAnthropicConfig,
882
+ )
883
+ from .llms.vertex_ai.vertex_ai_partner_models.llama3.transformation import (
884
+ VertexAILlama3Config,
885
+ )
886
+ from .llms.vertex_ai.vertex_ai_partner_models.ai21.transformation import (
887
+ VertexAIAi21Config,
888
+ )
889
+
890
+ from .llms.ollama.completion.transformation import OllamaConfig
891
+ from .llms.sagemaker.completion.transformation import SagemakerConfig
892
+ from .llms.sagemaker.chat.transformation import SagemakerChatConfig
893
+ from .llms.ollama_chat import OllamaChatConfig
894
+ from .llms.bedrock.chat.invoke_handler import (
895
+ AmazonCohereChatConfig,
896
+ bedrock_tool_name_mappings,
897
+ )
898
+
899
+ from .llms.bedrock.common_utils import (
900
+ AmazonBedrockGlobalConfig,
901
+ )
902
+ from .llms.bedrock.chat.invoke_transformations.amazon_ai21_transformation import (
903
+ AmazonAI21Config,
904
+ )
905
+ from .llms.bedrock.chat.invoke_transformations.amazon_nova_transformation import (
906
+ AmazonInvokeNovaConfig,
907
+ )
908
+ from .llms.bedrock.chat.invoke_transformations.anthropic_claude2_transformation import (
909
+ AmazonAnthropicConfig,
910
+ )
911
+ from .llms.bedrock.chat.invoke_transformations.anthropic_claude3_transformation import (
912
+ AmazonAnthropicClaude3Config,
913
+ )
914
+ from .llms.bedrock.chat.invoke_transformations.amazon_cohere_transformation import (
915
+ AmazonCohereConfig,
916
+ )
917
+ from .llms.bedrock.chat.invoke_transformations.amazon_llama_transformation import (
918
+ AmazonLlamaConfig,
919
+ )
920
+ from .llms.bedrock.chat.invoke_transformations.amazon_deepseek_transformation import (
921
+ AmazonDeepSeekR1Config,
922
+ )
923
+ from .llms.bedrock.chat.invoke_transformations.amazon_mistral_transformation import (
924
+ AmazonMistralConfig,
925
+ )
926
+ from .llms.bedrock.chat.invoke_transformations.amazon_titan_transformation import (
927
+ AmazonTitanConfig,
928
+ )
929
+ from .llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
930
+ AmazonInvokeConfig,
931
+ )
932
+
933
+ from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
934
+ from .llms.bedrock.image.amazon_stability3_transformation import AmazonStability3Config
935
+ from .llms.bedrock.image.amazon_nova_canvas_transformation import AmazonNovaCanvasConfig
936
+ from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
937
+ from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
938
+ AmazonTitanMultimodalEmbeddingG1Config,
939
+ )
940
+ from .llms.bedrock.embed.amazon_titan_v2_transformation import (
941
+ AmazonTitanV2Config,
942
+ )
943
+ from .llms.cohere.chat.transformation import CohereChatConfig
944
+ from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
945
+ from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
946
+ from .llms.openai.image_variations.transformation import OpenAIImageVariationConfig
947
+ from .llms.deepinfra.chat.transformation import DeepInfraConfig
948
+ from .llms.deepgram.audio_transcription.transformation import (
949
+ DeepgramAudioTranscriptionConfig,
950
+ )
951
+ from .llms.topaz.common_utils import TopazModelInfo
952
+ from .llms.topaz.image_variations.transformation import TopazImageVariationConfig
953
+ from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
954
+ from .llms.groq.chat.transformation import GroqChatConfig
955
+ from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
956
+ from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
957
+ from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
958
+ from .llms.mistral.mistral_chat_transformation import MistralConfig
959
+ from .llms.openai.responses.transformation import OpenAIResponsesAPIConfig
960
+ from .llms.azure.responses.transformation import AzureOpenAIResponsesAPIConfig
961
+ from .llms.openai.chat.o_series_transformation import (
962
+ OpenAIOSeriesConfig as OpenAIO1Config, # maintain backwards compatibility
963
+ OpenAIOSeriesConfig,
964
+ )
965
+
966
+ from .llms.snowflake.chat.transformation import SnowflakeConfig
967
+
968
+ openaiOSeriesConfig = OpenAIOSeriesConfig()
969
+ from .llms.openai.chat.gpt_transformation import (
970
+ OpenAIGPTConfig,
971
+ )
972
+ from .llms.openai.transcriptions.whisper_transformation import (
973
+ OpenAIWhisperAudioTranscriptionConfig,
974
+ )
975
+ from .llms.openai.transcriptions.gpt_transformation import (
976
+ OpenAIGPTAudioTranscriptionConfig,
977
+ )
978
+
979
+ openAIGPTConfig = OpenAIGPTConfig()
980
+ from .llms.openai.chat.gpt_audio_transformation import (
981
+ OpenAIGPTAudioConfig,
982
+ )
983
+
984
+ openAIGPTAudioConfig = OpenAIGPTAudioConfig()
985
+
986
+ from .llms.nvidia_nim.chat import NvidiaNimConfig
987
+ from .llms.nvidia_nim.embed import NvidiaNimEmbeddingConfig
988
+
989
+ nvidiaNimConfig = NvidiaNimConfig()
990
+ nvidiaNimEmbeddingConfig = NvidiaNimEmbeddingConfig()
991
+
992
+ from .llms.cerebras.chat import CerebrasConfig
993
+ from .llms.sambanova.chat import SambanovaConfig
994
+ from .llms.ai21.chat.transformation import AI21ChatConfig
995
+ from .llms.fireworks_ai.chat.transformation import FireworksAIConfig
996
+ from .llms.fireworks_ai.completion.transformation import FireworksAITextCompletionConfig
997
+ from .llms.fireworks_ai.audio_transcription.transformation import (
998
+ FireworksAIAudioTranscriptionConfig,
999
+ )
1000
+ from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
1001
+ FireworksAIEmbeddingConfig,
1002
+ )
1003
+ from .llms.friendliai.chat.transformation import FriendliaiChatConfig
1004
+ from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
1005
+ from .llms.xai.chat.transformation import XAIChatConfig
1006
+ from .llms.xai.common_utils import XAIModelInfo
1007
+ from .llms.volcengine import VolcEngineConfig
1008
+ from .llms.codestral.completion.transformation import CodestralTextCompletionConfig
1009
+ from .llms.azure.azure import (
1010
+ AzureOpenAIError,
1011
+ AzureOpenAIAssistantsAPIConfig,
1012
+ )
1013
+
1014
+ from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
1015
+ from .llms.azure.completion.transformation import AzureOpenAITextConfig
1016
+ from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
1017
+ from .llms.llamafile.chat.transformation import LlamafileChatConfig
1018
+ from .llms.litellm_proxy.chat.transformation import LiteLLMProxyChatConfig
1019
+ from .llms.vllm.completion.transformation import VLLMConfig
1020
+ from .llms.deepseek.chat.transformation import DeepSeekChatConfig
1021
+ from .llms.lm_studio.chat.transformation import LMStudioChatConfig
1022
+ from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig
1023
+ from .llms.perplexity.chat.transformation import PerplexityChatConfig
1024
+ from .llms.azure.chat.o_series_transformation import AzureOpenAIO1Config
1025
+ from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
1026
+ from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
1027
+ from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
1028
+ from .main import * # type: ignore
1029
+ from .integrations import *
1030
+ from .exceptions import (
1031
+ AuthenticationError,
1032
+ InvalidRequestError,
1033
+ BadRequestError,
1034
+ NotFoundError,
1035
+ RateLimitError,
1036
+ ServiceUnavailableError,
1037
+ OpenAIError,
1038
+ ContextWindowExceededError,
1039
+ ContentPolicyViolationError,
1040
+ BudgetExceededError,
1041
+ APIError,
1042
+ Timeout,
1043
+ APIConnectionError,
1044
+ UnsupportedParamsError,
1045
+ APIResponseValidationError,
1046
+ UnprocessableEntityError,
1047
+ InternalServerError,
1048
+ JSONSchemaValidationError,
1049
+ LITELLM_EXCEPTION_TYPES,
1050
+ MockException,
1051
+ )
1052
+ from .budget_manager import BudgetManager
1053
+ from .proxy.proxy_cli import run_server
1054
+ from .router import Router
1055
+ from .assistants.main import *
1056
+ from .batches.main import *
1057
+ from .batch_completion.main import * # type: ignore
1058
+ from .rerank_api.main import *
1059
+ from .llms.anthropic.experimental_pass_through.messages.handler import *
1060
+ from .responses.main import *
1061
+ from .realtime_api.main import _arealtime
1062
+ from .fine_tuning.main import *
1063
+ from .files.main import *
1064
+ from .scheduler import *
1065
+ from .cost_calculator import response_cost_calculator, cost_per_token
1066
+
1067
+ ### ADAPTERS ###
1068
+ from .types.adapter import AdapterItem
1069
+ import litellm.anthropic_interface as anthropic
1070
+
1071
+ adapters: List[AdapterItem] = []
1072
+
1073
+ ### CUSTOM LLMs ###
1074
+ from .types.llms.custom_llm import CustomLLMItem
1075
+ from .types.utils import GenericStreamingChunk
1076
+
1077
+ custom_provider_map: List[CustomLLMItem] = []
1078
+ _custom_providers: List[str] = (
1079
+ []
1080
+ ) # internal helper util, used to track names of custom providers
1081
+ disable_hf_tokenizer_download: Optional[bool] = (
1082
+ None # disable huggingface tokenizer download. Defaults to openai clk100
1083
+ )
1084
+ global_disable_no_log_param: bool = False
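
For context on how these module-level settings are typically consumed, here is a minimal usage sketch (not part of this commit). It assumes the package is importable as litellm and that completion() is re-exported by the `from .main import *` line above; the API key and model name are placeholders.

import litellm
from litellm import completion

# Module-level configuration read by later calls (see the variables defined above)
litellm.api_key = "sk-..."                 # placeholder provider API key
litellm.drop_params = True                 # drop provider-unsupported params instead of erroring
litellm.success_callback = ["langfuse"]    # one of the callback literals listed above

# A single chat completion routed through the configured provider
response = completion(
    model="gpt-3.5-turbo",                 # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
)
print(response)
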
litellm/_logging.py ADDED
@@ -0,0 +1,167 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import sys
5
+ from datetime import datetime
6
+ from logging import Formatter
7
+
8
+ set_verbose = False
9
+
10
+ if set_verbose is True:
11
+ logging.warning(
12
+ "`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
13
+ )
14
+ json_logs = bool(os.getenv("JSON_LOGS", False))
15
+ # Create a handler for the logger (you may need to adapt this based on your needs)
16
+ log_level = os.getenv("LITELLM_LOG", "DEBUG")
17
+ numeric_level: str = getattr(logging, log_level.upper())
18
+ handler = logging.StreamHandler()
19
+ handler.setLevel(numeric_level)
20
+
21
+
22
+ class JsonFormatter(Formatter):
23
+ def __init__(self):
24
+ super(JsonFormatter, self).__init__()
25
+
26
+ def formatTime(self, record, datefmt=None):
27
+ # Use datetime to format the timestamp in ISO 8601 format
28
+ dt = datetime.fromtimestamp(record.created)
29
+ return dt.isoformat()
30
+
31
+ def format(self, record):
32
+ json_record = {
33
+ "message": record.getMessage(),
34
+ "level": record.levelname,
35
+ "timestamp": self.formatTime(record),
36
+ }
37
+
38
+ if record.exc_info:
39
+ json_record["stacktrace"] = self.formatException(record.exc_info)
40
+
41
+ return json.dumps(json_record)
42
+
43
+
44
+ # Function to set up exception handlers for JSON logging
45
+ def _setup_json_exception_handlers(formatter):
46
+ # Create a handler with JSON formatting for exceptions
47
+ error_handler = logging.StreamHandler()
48
+ error_handler.setFormatter(formatter)
49
+
50
+ # Setup excepthook for uncaught exceptions
51
+ def json_excepthook(exc_type, exc_value, exc_traceback):
52
+ record = logging.LogRecord(
53
+ name="LiteLLM",
54
+ level=logging.ERROR,
55
+ pathname="",
56
+ lineno=0,
57
+ msg=str(exc_value),
58
+ args=(),
59
+ exc_info=(exc_type, exc_value, exc_traceback),
60
+ )
61
+ error_handler.handle(record)
62
+
63
+ sys.excepthook = json_excepthook
64
+
65
+ # Configure asyncio exception handler if possible
66
+ try:
67
+ import asyncio
68
+
69
+ def async_json_exception_handler(loop, context):
70
+ exception = context.get("exception")
71
+ if exception:
72
+ record = logging.LogRecord(
73
+ name="LiteLLM",
74
+ level=logging.ERROR,
75
+ pathname="",
76
+ lineno=0,
77
+ msg=str(exception),
78
+ args=(),
79
+ exc_info=None,
80
+ )
81
+ error_handler.handle(record)
82
+ else:
83
+ loop.default_exception_handler(context)
84
+
85
+ asyncio.get_event_loop().set_exception_handler(async_json_exception_handler)
86
+ except Exception:
87
+ pass
88
+
89
+
90
+ # Create a formatter and set it for the handler
91
+ if json_logs:
92
+ handler.setFormatter(JsonFormatter())
93
+ _setup_json_exception_handlers(JsonFormatter())
94
+ else:
95
+ formatter = logging.Formatter(
96
+ "\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
97
+ datefmt="%H:%M:%S",
98
+ )
99
+
100
+ handler.setFormatter(formatter)
101
+
102
+ verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
103
+ verbose_router_logger = logging.getLogger("LiteLLM Router")
104
+ verbose_logger = logging.getLogger("LiteLLM")
105
+
106
+ # Add the handler to the logger
107
+ verbose_router_logger.addHandler(handler)
108
+ verbose_proxy_logger.addHandler(handler)
109
+ verbose_logger.addHandler(handler)
110
+
111
+
112
+ def _turn_on_json():
113
+ handler = logging.StreamHandler()
114
+ handler.setFormatter(JsonFormatter())
115
+
116
+ # Define all loggers to update, including root logger
117
+ loggers = [logging.getLogger()] + [
118
+ verbose_router_logger,
119
+ verbose_proxy_logger,
120
+ verbose_logger,
121
+ ]
122
+
123
+ # Iterate through each logger and update its handlers
124
+ for logger in loggers:
125
+ # Remove all existing handlers
126
+ for h in logger.handlers[:]:
127
+ logger.removeHandler(h)
128
+ # Add the new handler
129
+ logger.addHandler(handler)
130
+
131
+ # Set up exception handlers
132
+ _setup_json_exception_handlers(JsonFormatter())
133
+
134
+
135
+ def _turn_on_debug():
136
+ verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
137
+ verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
138
+ verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
139
+
140
+
141
+ def _disable_debugging():
142
+ verbose_logger.disabled = True
143
+ verbose_router_logger.disabled = True
144
+ verbose_proxy_logger.disabled = True
145
+
146
+
147
+ def _enable_debugging():
148
+ verbose_logger.disabled = False
149
+ verbose_router_logger.disabled = False
150
+ verbose_proxy_logger.disabled = False
151
+
152
+
153
+ def print_verbose(print_statement):
154
+ try:
155
+ if set_verbose:
156
+ print(print_statement) # noqa
157
+ except Exception:
158
+ pass
159
+
160
+
161
+ def _is_debugging_on() -> bool:
162
+ """
163
+ Returns True if debugging is on
164
+ """
165
+ if verbose_logger.isEnabledFor(logging.DEBUG) or set_verbose is True:
166
+ return True
167
+ return False
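The logging helpers above are self-contained, so a quick sketch helps show how they fit together. This is an illustrative example only (not part of the commit), and calling these private helpers directly is an assumption made for demonstration; in practice they are driven by litellm's own configuration flags.

```python
# Illustrative sketch only: exercising the helpers defined in litellm/_logging.py above.
# Directly calling these private helpers is an assumption for demonstration purposes.
from litellm._logging import _turn_on_debug, _turn_on_json, verbose_logger

_turn_on_debug()   # raise the package, router, and proxy loggers to DEBUG
_turn_on_json()    # replace every handler with a JSON-formatting StreamHandler

verbose_logger.debug("hello from litellm")  # emitted as a single JSON log record
```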
litellm/_redis.py ADDED
@@ -0,0 +1,333 @@
1
+ # +-----------------------------------------------+
2
+ # | |
3
+ # | Give Feedback / Get Help |
4
+ # | https://github.com/BerriAI/litellm/issues/new |
5
+ # | |
6
+ # +-----------------------------------------------+
7
+ #
8
+ # Thank you users! We ❤️ you! - Krrish & Ishaan
9
+
10
+ import inspect
11
+ import json
12
+
13
+ # s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
14
+ import os
15
+ from typing import List, Optional, Union
16
+
17
+ import redis # type: ignore
18
+ import redis.asyncio as async_redis # type: ignore
19
+
20
+ from litellm import get_secret, get_secret_str
21
+ from litellm.constants import REDIS_CONNECTION_POOL_TIMEOUT, REDIS_SOCKET_TIMEOUT
22
+
23
+ from ._logging import verbose_logger
24
+
25
+
26
+ def _get_redis_kwargs():
27
+ arg_spec = inspect.getfullargspec(redis.Redis)
28
+
29
+ # Only allow primitive arguments
30
+ exclude_args = {
31
+ "self",
32
+ "connection_pool",
33
+ "retry",
34
+ }
35
+
36
+ include_args = ["url"]
37
+
38
+ available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
39
+
40
+ return available_args
41
+
42
+
43
+ def _get_redis_url_kwargs(client=None):
44
+ if client is None:
45
+ client = redis.Redis.from_url
46
+ arg_spec = inspect.getfullargspec(redis.Redis.from_url)
47
+
48
+ # Only allow primitive arguments
49
+ exclude_args = {
50
+ "self",
51
+ "connection_pool",
52
+ "retry",
53
+ }
54
+
55
+ include_args = ["url"]
56
+
57
+ available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
58
+
59
+ return available_args
60
+
61
+
62
+ def _get_redis_cluster_kwargs(client=None):
63
+ if client is None:
64
+ client = redis.Redis.from_url
65
+ arg_spec = inspect.getfullargspec(redis.RedisCluster)
66
+
67
+ # Only allow primitive arguments
68
+ exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
69
+
70
+ available_args = [x for x in arg_spec.args if x not in exclude_args]
71
+ available_args.append("password")
72
+ available_args.append("username")
73
+ available_args.append("ssl")
74
+
75
+ return available_args
76
+
77
+
78
+ def _get_redis_env_kwarg_mapping():
79
+ PREFIX = "REDIS_"
80
+
81
+ return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}
82
+
83
+
84
+ def _redis_kwargs_from_environment():
85
+ mapping = _get_redis_env_kwarg_mapping()
86
+
87
+ return_dict = {}
88
+ for k, v in mapping.items():
89
+ value = get_secret(k, default_value=None) # type: ignore
90
+ if value is not None:
91
+ return_dict[v] = value
92
+ return return_dict
93
+
94
+
95
+ def get_redis_url_from_environment():
96
+ if "REDIS_URL" in os.environ:
97
+ return os.environ["REDIS_URL"]
98
+
99
+ if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
100
+ raise ValueError(
101
+ "Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
102
+ )
103
+
104
+ if "REDIS_PASSWORD" in os.environ:
105
+ redis_password = f":{os.environ['REDIS_PASSWORD']}@"
106
+ else:
107
+ redis_password = ""
108
+
109
+ return (
110
+ f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
111
+ )
112
+
113
+
114
+ def _get_redis_client_logic(**env_overrides):
115
+ """
116
+ Common functionality across sync + async redis client implementations
117
+ """
118
+ ### check if "os.environ/<key-name>" passed in
119
+ for k, v in env_overrides.items():
120
+ if isinstance(v, str) and v.startswith("os.environ/"):
121
+ v = v.replace("os.environ/", "")
122
+ value = get_secret(v) # type: ignore
123
+ env_overrides[k] = value
124
+
125
+ redis_kwargs = {
126
+ **_redis_kwargs_from_environment(),
127
+ **env_overrides,
128
+ }
129
+
130
+ _startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret( # type: ignore
131
+ "REDIS_CLUSTER_NODES"
132
+ )
133
+
134
+ if _startup_nodes is not None and isinstance(_startup_nodes, str):
135
+ redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
136
+
137
+ _sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret( # type: ignore
138
+ "REDIS_SENTINEL_NODES"
139
+ )
140
+
141
+ if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str):
142
+ redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes)
143
+
144
+ _sentinel_password: Optional[str] = redis_kwargs.get(
145
+ "sentinel_password", None
146
+ ) or get_secret_str("REDIS_SENTINEL_PASSWORD")
147
+
148
+ if _sentinel_password is not None:
149
+ redis_kwargs["sentinel_password"] = _sentinel_password
150
+
151
+ _service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret( # type: ignore
152
+ "REDIS_SERVICE_NAME"
153
+ )
154
+
155
+ if _service_name is not None:
156
+ redis_kwargs["service_name"] = _service_name
157
+
158
+ if "url" in redis_kwargs and redis_kwargs["url"] is not None:
159
+ redis_kwargs.pop("host", None)
160
+ redis_kwargs.pop("port", None)
161
+ redis_kwargs.pop("db", None)
162
+ redis_kwargs.pop("password", None)
163
+ elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
164
+ pass
165
+ elif (
166
+ "sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None
167
+ ):
168
+ pass
169
+ elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
170
+ raise ValueError("Either 'host' or 'url' must be specified for redis.")
171
+
172
+ # litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
173
+ return redis_kwargs
174
+
175
+
176
+ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
177
+ _redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES") # type: ignore
178
+ if _redis_cluster_nodes_in_env is not None:
179
+ try:
180
+ redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
181
+ except json.JSONDecodeError:
182
+ raise ValueError(
183
+ "REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
184
+ )
185
+
186
+ verbose_logger.debug("init_redis_cluster: startup nodes are being initialized.")
187
+ from redis.cluster import ClusterNode
188
+
189
+ args = _get_redis_cluster_kwargs()
190
+ cluster_kwargs = {}
191
+ for arg in redis_kwargs:
192
+ if arg in args:
193
+ cluster_kwargs[arg] = redis_kwargs[arg]
194
+
195
+ new_startup_nodes: List[ClusterNode] = []
196
+
197
+ for item in redis_kwargs["startup_nodes"]:
198
+ new_startup_nodes.append(ClusterNode(**item))
199
+
200
+ redis_kwargs.pop("startup_nodes")
201
+ return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) # type: ignore
202
+
203
+
204
+ def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
205
+ sentinel_nodes = redis_kwargs.get("sentinel_nodes")
206
+ sentinel_password = redis_kwargs.get("sentinel_password")
207
+ service_name = redis_kwargs.get("service_name")
208
+
209
+ if not sentinel_nodes or not service_name:
210
+ raise ValueError(
211
+ "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
212
+ )
213
+
214
+ verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
215
+
216
+ # Set up the Sentinel client
217
+ sentinel = redis.Sentinel(
218
+ sentinel_nodes,
219
+ socket_timeout=REDIS_SOCKET_TIMEOUT,
220
+ password=sentinel_password,
221
+ )
222
+
223
+ # Return the master instance for the given service
224
+
225
+ return sentinel.master_for(service_name)
226
+
227
+
228
+ def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
229
+ sentinel_nodes = redis_kwargs.get("sentinel_nodes")
230
+ sentinel_password = redis_kwargs.get("sentinel_password")
231
+ service_name = redis_kwargs.get("service_name")
232
+
233
+ if not sentinel_nodes or not service_name:
234
+ raise ValueError(
235
+ "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
236
+ )
237
+
238
+ verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
239
+
240
+ # Set up the Sentinel client
241
+ sentinel = async_redis.Sentinel(
242
+ sentinel_nodes,
243
+ socket_timeout=REDIS_SOCKET_TIMEOUT,
244
+ password=sentinel_password,
245
+ )
246
+
247
+ # Return the master instance for the given service
248
+
249
+ return sentinel.master_for(service_name)
250
+
251
+
252
+ def get_redis_client(**env_overrides):
253
+ redis_kwargs = _get_redis_client_logic(**env_overrides)
254
+ if "url" in redis_kwargs and redis_kwargs["url"] is not None:
255
+ args = _get_redis_url_kwargs()
256
+ url_kwargs = {}
257
+ for arg in redis_kwargs:
258
+ if arg in args:
259
+ url_kwargs[arg] = redis_kwargs[arg]
260
+
261
+ return redis.Redis.from_url(**url_kwargs)
262
+
263
+ if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None: # type: ignore
264
+ return init_redis_cluster(redis_kwargs)
265
+
266
+ # Check for Redis Sentinel
267
+ if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
268
+ return _init_redis_sentinel(redis_kwargs)
269
+
270
+ return redis.Redis(**redis_kwargs)
271
+
272
+
273
+ def get_redis_async_client(
274
+ **env_overrides,
275
+ ) -> async_redis.Redis:
276
+ redis_kwargs = _get_redis_client_logic(**env_overrides)
277
+ if "url" in redis_kwargs and redis_kwargs["url"] is not None:
278
+ args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
279
+ url_kwargs = {}
280
+ for arg in redis_kwargs:
281
+ if arg in args:
282
+ url_kwargs[arg] = redis_kwargs[arg]
283
+ else:
284
+ verbose_logger.debug(
285
+ "REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
286
+ arg
287
+ )
288
+ )
289
+ return async_redis.Redis.from_url(**url_kwargs)
290
+
291
+ if "startup_nodes" in redis_kwargs:
292
+ from redis.cluster import ClusterNode
293
+
294
+ args = _get_redis_cluster_kwargs()
295
+ cluster_kwargs = {}
296
+ for arg in redis_kwargs:
297
+ if arg in args:
298
+ cluster_kwargs[arg] = redis_kwargs[arg]
299
+
300
+ new_startup_nodes: List[ClusterNode] = []
301
+
302
+ for item in redis_kwargs["startup_nodes"]:
303
+ new_startup_nodes.append(ClusterNode(**item))
304
+ redis_kwargs.pop("startup_nodes")
305
+ return async_redis.RedisCluster(
306
+ startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore
307
+ )
308
+
309
+ # Check for Redis Sentinel
310
+ if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
311
+ return _init_async_redis_sentinel(redis_kwargs)
312
+
313
+ return async_redis.Redis(
314
+ **redis_kwargs,
315
+ )
316
+
317
+
318
+ def get_redis_connection_pool(**env_overrides):
319
+ redis_kwargs = _get_redis_client_logic(**env_overrides)
320
+ verbose_logger.debug("get_redis_connection_pool: redis_kwargs: %s", redis_kwargs)
321
+ if "url" in redis_kwargs and redis_kwargs["url"] is not None:
322
+ return async_redis.BlockingConnectionPool.from_url(
323
+ timeout=REDIS_CONNECTION_POOL_TIMEOUT, url=redis_kwargs["url"]
324
+ )
325
+ connection_class = async_redis.Connection
326
+ if "ssl" in redis_kwargs:
327
+ connection_class = async_redis.SSLConnection
328
+ redis_kwargs.pop("ssl", None)
329
+ redis_kwargs["connection_class"] = connection_class
330
+ redis_kwargs.pop("startup_nodes", None)
331
+ return async_redis.BlockingConnectionPool(
332
+ timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs
333
+ )
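Since `litellm/_redis.py` resolves most of its configuration from `REDIS_*` environment variables, a short sketch makes the flow easier to follow. This is an illustrative example only (not part of the commit); the host and port values are placeholders, and a reachable Redis server is assumed for the final `ping()`.

```python
# Illustrative sketch only: building a Redis client from environment variables
# with the helpers defined in litellm/_redis.py above. Host/port are placeholders.
import os

from litellm._redis import get_redis_client, get_redis_url_from_environment

os.environ["REDIS_HOST"] = "localhost"  # placeholder
os.environ["REDIS_PORT"] = "6379"       # placeholder

url = get_redis_url_from_environment()
print(url)  # redis://localhost:6379 (a password is prepended if REDIS_PASSWORD is set)

# With a url, _get_redis_client_logic drops host/port/db/password and the client
# is created via redis.Redis.from_url under the hood.
client = get_redis_client(url=url)
client.ping()  # assumes a reachable Redis server
```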
litellm/_service_logger.py ADDED
@@ -0,0 +1,311 @@
1
+ import asyncio
2
+ from datetime import datetime, timedelta
3
+ from typing import TYPE_CHECKING, Any, Optional, Union
4
+
5
+ import litellm
6
+ from litellm._logging import verbose_logger
7
+ from litellm.proxy._types import UserAPIKeyAuth
8
+
9
+ from .integrations.custom_logger import CustomLogger
10
+ from .integrations.datadog.datadog import DataDogLogger
11
+ from .integrations.opentelemetry import OpenTelemetry
12
+ from .integrations.prometheus_services import PrometheusServicesLogger
13
+ from .types.services import ServiceLoggerPayload, ServiceTypes
14
+
15
+ if TYPE_CHECKING:
16
+ from opentelemetry.trace import Span as _Span
17
+
18
+ Span = Union[_Span, Any]
19
+ OTELClass = OpenTelemetry
20
+ else:
21
+ Span = Any
22
+ OTELClass = Any
23
+
24
+
25
+ class ServiceLogging(CustomLogger):
26
+ """
27
+ Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
28
+ """
29
+
30
+ def __init__(self, mock_testing: bool = False) -> None:
31
+ self.mock_testing = mock_testing
32
+ self.mock_testing_sync_success_hook = 0
33
+ self.mock_testing_async_success_hook = 0
34
+ self.mock_testing_sync_failure_hook = 0
35
+ self.mock_testing_async_failure_hook = 0
36
+ if "prometheus_system" in litellm.service_callback:
37
+ self.prometheusServicesLogger = PrometheusServicesLogger()
38
+
39
+ def service_success_hook(
40
+ self,
41
+ service: ServiceTypes,
42
+ duration: float,
43
+ call_type: str,
44
+ parent_otel_span: Optional[Span] = None,
45
+ start_time: Optional[Union[datetime, float]] = None,
46
+ end_time: Optional[Union[float, datetime]] = None,
47
+ ):
48
+ """
49
+ Handles both sync and async monitoring by checking for existing event loop.
50
+ """
51
+
52
+ if self.mock_testing:
53
+ self.mock_testing_sync_success_hook += 1
54
+
55
+ try:
56
+ # Try to get the current event loop
57
+ loop = asyncio.get_event_loop()
58
+ # Check if the loop is running
59
+ if loop.is_running():
60
+ # If we're in a running loop, create a task
61
+ loop.create_task(
62
+ self.async_service_success_hook(
63
+ service=service,
64
+ duration=duration,
65
+ call_type=call_type,
66
+ parent_otel_span=parent_otel_span,
67
+ start_time=start_time,
68
+ end_time=end_time,
69
+ )
70
+ )
71
+ else:
72
+ # Loop exists but not running, we can use run_until_complete
73
+ loop.run_until_complete(
74
+ self.async_service_success_hook(
75
+ service=service,
76
+ duration=duration,
77
+ call_type=call_type,
78
+ parent_otel_span=parent_otel_span,
79
+ start_time=start_time,
80
+ end_time=end_time,
81
+ )
82
+ )
83
+ except RuntimeError:
84
+ # No event loop exists, create a new one and run
85
+ asyncio.run(
86
+ self.async_service_success_hook(
87
+ service=service,
88
+ duration=duration,
89
+ call_type=call_type,
90
+ parent_otel_span=parent_otel_span,
91
+ start_time=start_time,
92
+ end_time=end_time,
93
+ )
94
+ )
95
+
96
+ def service_failure_hook(
97
+ self, service: ServiceTypes, duration: float, error: Exception, call_type: str
98
+ ):
99
+ """
100
+ [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
101
+ """
102
+ if self.mock_testing:
103
+ self.mock_testing_sync_failure_hook += 1
104
+
105
+ async def async_service_success_hook(
106
+ self,
107
+ service: ServiceTypes,
108
+ call_type: str,
109
+ duration: float,
110
+ parent_otel_span: Optional[Span] = None,
111
+ start_time: Optional[Union[datetime, float]] = None,
112
+ end_time: Optional[Union[datetime, float]] = None,
113
+ event_metadata: Optional[dict] = None,
114
+ ):
115
+ """
116
+ - For counting if the redis, postgres call is successful
117
+ """
118
+ if self.mock_testing:
119
+ self.mock_testing_async_success_hook += 1
120
+
121
+ payload = ServiceLoggerPayload(
122
+ is_error=False,
123
+ error=None,
124
+ service=service,
125
+ duration=duration,
126
+ call_type=call_type,
127
+ event_metadata=event_metadata,
128
+ )
129
+
130
+ for callback in litellm.service_callback:
131
+ if callback == "prometheus_system":
132
+ await self.init_prometheus_services_logger_if_none()
133
+ await self.prometheusServicesLogger.async_service_success_hook(
134
+ payload=payload
135
+ )
136
+ elif callback == "datadog" or isinstance(callback, DataDogLogger):
137
+ await self.init_datadog_logger_if_none()
138
+ await self.dd_logger.async_service_success_hook(
139
+ payload=payload,
140
+ parent_otel_span=parent_otel_span,
141
+ start_time=start_time,
142
+ end_time=end_time,
143
+ event_metadata=event_metadata,
144
+ )
145
+ elif callback == "otel" or isinstance(callback, OpenTelemetry):
146
+ from litellm.proxy.proxy_server import open_telemetry_logger
147
+
148
+ await self.init_otel_logger_if_none()
149
+
150
+ if (
151
+ parent_otel_span is not None
152
+ and open_telemetry_logger is not None
153
+ and isinstance(open_telemetry_logger, OpenTelemetry)
154
+ ):
155
+ await self.otel_logger.async_service_success_hook(
156
+ payload=payload,
157
+ parent_otel_span=parent_otel_span,
158
+ start_time=start_time,
159
+ end_time=end_time,
160
+ event_metadata=event_metadata,
161
+ )
162
+
163
+ async def init_prometheus_services_logger_if_none(self):
164
+ """
165
+ initializes prometheusServicesLogger if it is None or no attribute exists on ServiceLogging Object
166
+
167
+ """
168
+ if not hasattr(self, "prometheusServicesLogger"):
169
+ self.prometheusServicesLogger = PrometheusServicesLogger()
170
+ elif self.prometheusServicesLogger is None:
171
+ self.prometheusServicesLogger = PrometheusServicesLogger()
172
+ return
173
+
174
+ async def init_datadog_logger_if_none(self):
175
+ """
176
+ initializes dd_logger if it is None or no attribute exists on ServiceLogging Object
177
+
178
+ """
179
+ from litellm.integrations.datadog.datadog import DataDogLogger
180
+
181
+ if not hasattr(self, "dd_logger"):
182
+ self.dd_logger: DataDogLogger = DataDogLogger()
183
+
184
+ return
185
+
186
+ async def init_otel_logger_if_none(self):
187
+ """
188
+ initializes otel_logger if it is None or no attribute exists on ServiceLogging Object
189
+
190
+ """
191
+ from litellm.proxy.proxy_server import open_telemetry_logger
192
+
193
+ if not hasattr(self, "otel_logger"):
194
+ if open_telemetry_logger is not None and isinstance(
195
+ open_telemetry_logger, OpenTelemetry
196
+ ):
197
+ self.otel_logger: OpenTelemetry = open_telemetry_logger
198
+ else:
199
+ verbose_logger.warning(
200
+ "ServiceLogger: open_telemetry_logger is None or not an instance of OpenTelemetry"
201
+ )
202
+ return
203
+
204
+ async def async_service_failure_hook(
205
+ self,
206
+ service: ServiceTypes,
207
+ duration: float,
208
+ error: Union[str, Exception],
209
+ call_type: str,
210
+ parent_otel_span: Optional[Span] = None,
211
+ start_time: Optional[Union[datetime, float]] = None,
212
+ end_time: Optional[Union[float, datetime]] = None,
213
+ event_metadata: Optional[dict] = None,
214
+ ):
215
+ """
216
+ - For counting if the redis, postgres call is unsuccessful
217
+ """
218
+ if self.mock_testing:
219
+ self.mock_testing_async_failure_hook += 1
220
+
221
+ error_message = ""
222
+ if isinstance(error, Exception):
223
+ error_message = str(error)
224
+ elif isinstance(error, str):
225
+ error_message = error
226
+
227
+ payload = ServiceLoggerPayload(
228
+ is_error=True,
229
+ error=error_message,
230
+ service=service,
231
+ duration=duration,
232
+ call_type=call_type,
233
+ event_metadata=event_metadata,
234
+ )
235
+
236
+ for callback in litellm.service_callback:
237
+ if callback == "prometheus_system":
238
+ await self.init_prometheus_services_logger_if_none()
239
+ await self.prometheusServicesLogger.async_service_failure_hook(
240
+ payload=payload,
241
+ error=error,
242
+ )
243
+ elif callback == "datadog" or isinstance(callback, DataDogLogger):
244
+ await self.init_datadog_logger_if_none()
245
+ await self.dd_logger.async_service_failure_hook(
246
+ payload=payload,
247
+ error=error_message,
248
+ parent_otel_span=parent_otel_span,
249
+ start_time=start_time,
250
+ end_time=end_time,
251
+ event_metadata=event_metadata,
252
+ )
253
+ elif callback == "otel" or isinstance(callback, OpenTelemetry):
254
+ from litellm.proxy.proxy_server import open_telemetry_logger
255
+
256
+ await self.init_otel_logger_if_none()
257
+
258
+ if not isinstance(error, str):
259
+ error = str(error)
260
+
261
+ if (
262
+ parent_otel_span is not None
263
+ and open_telemetry_logger is not None
264
+ and isinstance(open_telemetry_logger, OpenTelemetry)
265
+ ):
266
+ await self.otel_logger.async_service_success_hook(
267
+ payload=payload,
268
+ parent_otel_span=parent_otel_span,
269
+ start_time=start_time,
270
+ end_time=end_time,
271
+ event_metadata=event_metadata,
272
+ )
273
+
274
+ async def async_post_call_failure_hook(
275
+ self,
276
+ request_data: dict,
277
+ original_exception: Exception,
278
+ user_api_key_dict: UserAPIKeyAuth,
279
+ ):
280
+ """
281
+ Hook to track failed litellm-service calls
282
+ """
283
+ return await super().async_post_call_failure_hook(
284
+ request_data,
285
+ original_exception,
286
+ user_api_key_dict,
287
+ )
288
+
289
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
290
+ """
291
+ Hook to track latency for litellm proxy llm api calls
292
+ """
293
+ try:
294
+ _duration = end_time - start_time
295
+ if isinstance(_duration, timedelta):
296
+ _duration = _duration.total_seconds()
297
+ elif isinstance(_duration, float):
298
+ pass
299
+ else:
300
+ raise Exception(
301
+ "Duration={} is not a float or timedelta object. type={}".format(
302
+ _duration, type(_duration)
303
+ )
304
+ ) # invalid _duration value
305
+ await self.async_service_success_hook(
306
+ service=ServiceTypes.LITELLM,
307
+ duration=_duration,
308
+ call_type=kwargs["call_type"],
309
+ )
310
+ except Exception as e:
311
+ raise e
litellm/_version.py ADDED
@@ -0,0 +1,6 @@
1
+ import importlib_metadata
2
+
3
+ try:
4
+ version = importlib_metadata.version("litellm")
5
+ except Exception:
6
+ version = "unknown"
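For completeness, a tiny sketch (illustrative only, not part of the commit) of reading back the version string that `litellm/_version.py` resolves; it falls back to `"unknown"` when litellm is not installed as a distribution.

```python
# Illustrative sketch only: reading the version resolved in litellm/_version.py.
from litellm._version import version

print(version)  # e.g. "1.x.y", or "unknown" if the distribution metadata is missing
```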
litellm/anthropic_interface/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """
2
+ Anthropic module for LiteLLM
3
+ """
4
+ from .messages import acreate, create
5
+
6
+ __all__ = ["acreate", "create"]
litellm/anthropic_interface/messages/__init__.py ADDED
@@ -0,0 +1,117 @@
1
+ """
2
+ Interface for Anthropic's messages API
3
+
4
+ Use this to call LLMs in Anthropic /messages Request/Response format
5
+
6
+ This is an __init__.py file to allow the following interface
7
+
8
+ - litellm.messages.acreate
9
+ - litellm.messages.create
10
+
11
+ """
12
+
13
+ from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
14
+
15
+ from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
16
+ anthropic_messages as _async_anthropic_messages,
17
+ )
18
+ from litellm.types.llms.anthropic_messages.anthropic_response import (
19
+ AnthropicMessagesResponse,
20
+ )
21
+
22
+
23
+ async def acreate(
24
+ max_tokens: int,
25
+ messages: List[Dict],
26
+ model: str,
27
+ metadata: Optional[Dict] = None,
28
+ stop_sequences: Optional[List[str]] = None,
29
+ stream: Optional[bool] = False,
30
+ system: Optional[str] = None,
31
+ temperature: Optional[float] = 1.0,
32
+ thinking: Optional[Dict] = None,
33
+ tool_choice: Optional[Dict] = None,
34
+ tools: Optional[List[Dict]] = None,
35
+ top_k: Optional[int] = None,
36
+ top_p: Optional[float] = None,
37
+ **kwargs
38
+ ) -> Union[AnthropicMessagesResponse, AsyncIterator]:
39
+ """
40
+ Async wrapper for Anthropic's messages API
41
+
42
+ Args:
43
+ max_tokens (int): Maximum tokens to generate (required)
44
+ messages (List[Dict]): List of message objects with role and content (required)
45
+ model (str): Model name to use (required)
46
+ metadata (Dict, optional): Request metadata
47
+ stop_sequences (List[str], optional): Custom stop sequences
48
+ stream (bool, optional): Whether to stream the response
49
+ system (str, optional): System prompt
50
+ temperature (float, optional): Sampling temperature (0.0 to 1.0)
51
+ thinking (Dict, optional): Extended thinking configuration
52
+ tool_choice (Dict, optional): Tool choice configuration
53
+ tools (List[Dict], optional): List of tool definitions
54
+ top_k (int, optional): Top K sampling parameter
55
+ top_p (float, optional): Nucleus sampling parameter
56
+ **kwargs: Additional arguments
57
+
58
+ Returns:
59
+ Dict: Response from the API
60
+ """
61
+ return await _async_anthropic_messages(
62
+ max_tokens=max_tokens,
63
+ messages=messages,
64
+ model=model,
65
+ metadata=metadata,
66
+ stop_sequences=stop_sequences,
67
+ stream=stream,
68
+ system=system,
69
+ temperature=temperature,
70
+ thinking=thinking,
71
+ tool_choice=tool_choice,
72
+ tools=tools,
73
+ top_k=top_k,
74
+ top_p=top_p,
75
+ **kwargs,
76
+ )
77
+
78
+
79
+ async def create(
80
+ max_tokens: int,
81
+ messages: List[Dict],
82
+ model: str,
83
+ metadata: Optional[Dict] = None,
84
+ stop_sequences: Optional[List[str]] = None,
85
+ stream: Optional[bool] = False,
86
+ system: Optional[str] = None,
87
+ temperature: Optional[float] = 1.0,
88
+ thinking: Optional[Dict] = None,
89
+ tool_choice: Optional[Dict] = None,
90
+ tools: Optional[List[Dict]] = None,
91
+ top_k: Optional[int] = None,
92
+ top_p: Optional[float] = None,
93
+ **kwargs
94
+ ) -> Union[AnthropicMessagesResponse, Iterator]:
95
+ """
96
+ Async wrapper for Anthropic's messages API
97
+
98
+ Args:
99
+ max_tokens (int): Maximum tokens to generate (required)
100
+ messages (List[Dict]): List of message objects with role and content (required)
101
+ model (str): Model name to use (required)
102
+ metadata (Dict, optional): Request metadata
103
+ stop_sequences (List[str], optional): Custom stop sequences
104
+ stream (bool, optional): Whether to stream the response
105
+ system (str, optional): System prompt
106
+ temperature (float, optional): Sampling temperature (0.0 to 1.0)
107
+ thinking (Dict, optional): Extended thinking configuration
108
+ tool_choice (Dict, optional): Tool choice configuration
109
+ tools (List[Dict], optional): List of tool definitions
110
+ top_k (int, optional): Top K sampling parameter
111
+ top_p (float, optional): Nucleus sampling parameter
112
+ **kwargs: Additional arguments
113
+
114
+ Returns:
115
+ Dict: Response from the API
116
+ """
117
+ raise NotImplementedError("This function is not implemented")
litellm/anthropic_interface/readme.md ADDED
@@ -0,0 +1,116 @@
1
+ ## Use LLM API endpoints in Anthropic Interface
2
+
3
+ Note: This module is named `anthropic_interface` because `anthropic` is an existing Python package name, and reusing it was breaking mypy type checking.
4
+
5
+
6
+ ## Usage
7
+ ---
8
+
9
+ ### LiteLLM Python SDK
10
+
11
+ #### Non-streaming example
12
+ ```python showLineNumbers title="Example using LiteLLM Python SDK"
13
+ import litellm
14
+ response = await litellm.anthropic.messages.acreate(
15
+ messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
16
+ api_key=api_key,
17
+ model="anthropic/claude-3-haiku-20240307",
18
+ max_tokens=100,
19
+ )
20
+ ```
21
+
22
+ Example response:
23
+ ```json
24
+ {
25
+ "content": [
26
+ {
27
+ "text": "Hi! this is a very short joke",
28
+ "type": "text"
29
+ }
30
+ ],
31
+ "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
32
+ "model": "claude-3-7-sonnet-20250219",
33
+ "role": "assistant",
34
+ "stop_reason": "end_turn",
35
+ "stop_sequence": null,
36
+ "type": "message",
37
+ "usage": {
38
+ "input_tokens": 2095,
39
+ "output_tokens": 503,
40
+ "cache_creation_input_tokens": 2095,
41
+ "cache_read_input_tokens": 0
42
+ }
43
+ }
44
+ ```
45
+
46
+ #### Streaming example
47
+ ```python showLineNumbers title="Example using LiteLLM Python SDK"
48
+ import litellm
49
+ response = await litellm.anthropic.messages.acreate(
50
+ messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
51
+ api_key=api_key,
52
+ model="anthropic/claude-3-haiku-20240307",
53
+ max_tokens=100,
54
+ stream=True,
55
+ )
56
+ async for chunk in response:
57
+ print(chunk)
58
+ ```
59
+
60
+ ### LiteLLM Proxy Server
61
+
62
+
63
+ 1. Setup config.yaml
64
+
65
+ ```yaml
66
+ model_list:
67
+ - model_name: anthropic-claude
68
+ litellm_params:
69
+ model: claude-3-7-sonnet-latest
70
+ ```
71
+
72
+ 2. Start proxy
73
+
74
+ ```bash
75
+ litellm --config /path/to/config.yaml
76
+ ```
77
+
78
+ 3. Test it!
79
+
80
+ <Tabs>
81
+ <TabItem label="Anthropic Python SDK" value="python">
82
+
83
+ ```python showLineNumbers title="Example using LiteLLM Proxy Server"
84
+ import anthropic
85
+
86
+ # point anthropic sdk to litellm proxy
87
+ client = anthropic.Anthropic(
88
+ base_url="http://0.0.0.0:4000",
89
+ api_key="sk-1234",
90
+ )
91
+
92
+ response = client.messages.create(
93
+ messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
94
+ model="anthropic-claude",  # model_name from the proxy config above
95
+ max_tokens=100,
96
+ )
97
+ ```
98
+ </TabItem>
99
+ <TabItem label="curl" value="curl">
100
+
101
+ ```bash showLineNumbers title="Example using LiteLLM Proxy Server"
102
+ curl -L -X POST 'http://0.0.0.0:4000/v1/messages' \
103
+ -H 'content-type: application/json' \
104
+ -H 'x-api-key: $LITELLM_API_KEY' \
105
+ -H 'anthropic-version: 2023-06-01' \
106
+ -d '{
107
+ "model": "anthropic-claude",
108
+ "messages": [
109
+ {
110
+ "role": "user",
111
+ "content": "Hello, can you tell me a short joke?"
112
+ }
113
+ ],
114
+ "max_tokens": 100
115
+ }'
116
+ ```
litellm/assistants/main.py ADDED
@@ -0,0 +1,1484 @@
1
+ # What is this?
2
+ ## Main file for assistants API logic
3
+ import asyncio
4
+ import contextvars
5
+ import os
6
+ from functools import partial
7
+ from typing import Any, Coroutine, Dict, Iterable, List, Literal, Optional, Union
8
+
9
+ import httpx
10
+ from openai import AsyncOpenAI, OpenAI
11
+ from openai.types.beta.assistant import Assistant
12
+ from openai.types.beta.assistant_deleted import AssistantDeleted
13
+
14
+ import litellm
15
+ from litellm.types.router import GenericLiteLLMParams
16
+ from litellm.utils import (
17
+ exception_type,
18
+ get_litellm_params,
19
+ get_llm_provider,
20
+ get_secret,
21
+ supports_httpx_timeout,
22
+ )
23
+
24
+ from ..llms.azure.assistants import AzureAssistantsAPI
25
+ from ..llms.openai.openai import OpenAIAssistantsAPI
26
+ from ..types.llms.openai import *
27
+ from ..types.router import *
28
+ from .utils import get_optional_params_add_message
29
+
30
+ ####### ENVIRONMENT VARIABLES ###################
31
+ openai_assistants_api = OpenAIAssistantsAPI()
32
+ azure_assistants_api = AzureAssistantsAPI()
33
+
34
+ ### ASSISTANTS ###
35
+
36
+
37
+ async def aget_assistants(
38
+ custom_llm_provider: Literal["openai", "azure"],
39
+ client: Optional[AsyncOpenAI] = None,
40
+ **kwargs,
41
+ ) -> AsyncCursorPage[Assistant]:
42
+ loop = asyncio.get_event_loop()
43
+ ### PASS ARGS TO GET ASSISTANTS ###
44
+ kwargs["aget_assistants"] = True
45
+ try:
46
+ # Use a partial function to pass your keyword arguments
47
+ func = partial(get_assistants, custom_llm_provider, client, **kwargs)
48
+
49
+ # Add the context to the function
50
+ ctx = contextvars.copy_context()
51
+ func_with_context = partial(ctx.run, func)
52
+
53
+ _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
54
+ model="", custom_llm_provider=custom_llm_provider
55
+ ) # type: ignore
56
+
57
+ # Await normally
58
+ init_response = await loop.run_in_executor(None, func_with_context)
59
+ if asyncio.iscoroutine(init_response):
60
+ response = await init_response
61
+ else:
62
+ response = init_response
63
+ return response # type: ignore
64
+ except Exception as e:
65
+ raise exception_type(
66
+ model="",
67
+ custom_llm_provider=custom_llm_provider,
68
+ original_exception=e,
69
+ completion_kwargs={},
70
+ extra_kwargs=kwargs,
71
+ )
72
+
73
+
74
+ def get_assistants(
75
+ custom_llm_provider: Literal["openai", "azure"],
76
+ client: Optional[Any] = None,
77
+ api_key: Optional[str] = None,
78
+ api_base: Optional[str] = None,
79
+ api_version: Optional[str] = None,
80
+ **kwargs,
81
+ ) -> SyncCursorPage[Assistant]:
82
+ aget_assistants: Optional[bool] = kwargs.pop("aget_assistants", None)
83
+ if aget_assistants is not None and not isinstance(aget_assistants, bool):
84
+ raise Exception(
85
+ "Invalid value passed in for aget_assistants. Only bool or None allowed"
86
+ )
87
+ optional_params = GenericLiteLLMParams(
88
+ api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
89
+ )
90
+ litellm_params_dict = get_litellm_params(**kwargs)
91
+
92
+ ### TIMEOUT LOGIC ###
93
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
94
+ # set timeout for 10 minutes by default
95
+
96
+ if (
97
+ timeout is not None
98
+ and isinstance(timeout, httpx.Timeout)
99
+ and supports_httpx_timeout(custom_llm_provider) is False
100
+ ):
101
+ read_timeout = timeout.read or 600
102
+ timeout = read_timeout # default 10 min timeout
103
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
104
+ timeout = float(timeout) # type: ignore
105
+ elif timeout is None:
106
+ timeout = 600.0
107
+
108
+ response: Optional[SyncCursorPage[Assistant]] = None
109
+ if custom_llm_provider == "openai":
110
+ api_base = (
111
+ optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
112
+ or litellm.api_base
113
+ or os.getenv("OPENAI_BASE_URL")
114
+ or os.getenv("OPENAI_API_BASE")
115
+ or "https://api.openai.com/v1"
116
+ )
117
+ organization = (
118
+ optional_params.organization
119
+ or litellm.organization
120
+ or os.getenv("OPENAI_ORGANIZATION", None)
121
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
122
+ )
123
+ # set API KEY
124
+ api_key = (
125
+ optional_params.api_key
126
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
127
+ or litellm.openai_key
128
+ or os.getenv("OPENAI_API_KEY")
129
+ )
130
+
131
+ response = openai_assistants_api.get_assistants(
132
+ api_base=api_base,
133
+ api_key=api_key,
134
+ timeout=timeout,
135
+ max_retries=optional_params.max_retries,
136
+ organization=organization,
137
+ client=client,
138
+ aget_assistants=aget_assistants, # type: ignore
139
+ ) # type: ignore
140
+ elif custom_llm_provider == "azure":
141
+ api_base = (
142
+ optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
143
+ ) # type: ignore
144
+
145
+ api_version = (
146
+ optional_params.api_version
147
+ or litellm.api_version
148
+ or get_secret("AZURE_API_VERSION")
149
+ ) # type: ignore
150
+
151
+ api_key = (
152
+ optional_params.api_key
153
+ or litellm.api_key
154
+ or litellm.azure_key
155
+ or get_secret("AZURE_OPENAI_API_KEY")
156
+ or get_secret("AZURE_API_KEY")
157
+ ) # type: ignore
158
+
159
+ extra_body = optional_params.get("extra_body", {})
160
+ azure_ad_token: Optional[str] = None
161
+ if extra_body is not None:
162
+ azure_ad_token = extra_body.pop("azure_ad_token", None)
163
+ else:
164
+ azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
165
+
166
+ response = azure_assistants_api.get_assistants(
167
+ api_base=api_base,
168
+ api_key=api_key,
169
+ api_version=api_version,
170
+ azure_ad_token=azure_ad_token,
171
+ timeout=timeout,
172
+ max_retries=optional_params.max_retries,
173
+ client=client,
174
+ aget_assistants=aget_assistants, # type: ignore
175
+ litellm_params=litellm_params_dict,
176
+ )
177
+ else:
178
+ raise litellm.exceptions.BadRequestError(
179
+ message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' and 'azure' are supported.".format(
180
+ custom_llm_provider
181
+ ),
182
+ model="n/a",
183
+ llm_provider=custom_llm_provider,
184
+ response=httpx.Response(
185
+ status_code=400,
186
+ content="Unsupported provider",
187
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
188
+ ),
189
+ )
190
+
191
+ if response is None:
192
+ raise litellm.exceptions.BadRequestError(
193
+ message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' and 'azure' are supported.".format(
194
+ custom_llm_provider
195
+ ),
196
+ model="n/a",
197
+ llm_provider=custom_llm_provider,
198
+ response=httpx.Response(
199
+ status_code=400,
200
+ content="Unsupported provider",
201
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
202
+ ),
203
+ )
204
+
205
+ return response
206
+
207
+
208
+ async def acreate_assistants(
209
+ custom_llm_provider: Literal["openai", "azure"],
210
+ client: Optional[AsyncOpenAI] = None,
211
+ **kwargs,
212
+ ) -> Assistant:
213
+ loop = asyncio.get_event_loop()
214
+ ### PASS ARGS TO GET ASSISTANTS ###
215
+ kwargs["async_create_assistants"] = True
216
+ model = kwargs.pop("model", None)
217
+ try:
218
+ kwargs["client"] = client
219
+ # Use a partial function to pass your keyword arguments
220
+ func = partial(create_assistants, custom_llm_provider, model, **kwargs)
221
+
222
+ # Add the context to the function
223
+ ctx = contextvars.copy_context()
224
+ func_with_context = partial(ctx.run, func)
225
+
226
+ _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
227
+ model=model, custom_llm_provider=custom_llm_provider
228
+ ) # type: ignore
229
+
230
+ # Await normally
231
+ init_response = await loop.run_in_executor(None, func_with_context)
232
+ if asyncio.iscoroutine(init_response):
233
+ response = await init_response
234
+ else:
235
+ response = init_response
236
+ return response # type: ignore
237
+ except Exception as e:
238
+ raise exception_type(
239
+ model=model,
240
+ custom_llm_provider=custom_llm_provider,
241
+ original_exception=e,
242
+ completion_kwargs={},
243
+ extra_kwargs=kwargs,
244
+ )
245
+
246
+
247
+ def create_assistants(
248
+ custom_llm_provider: Literal["openai", "azure"],
249
+ model: str,
250
+ name: Optional[str] = None,
251
+ description: Optional[str] = None,
252
+ instructions: Optional[str] = None,
253
+ tools: Optional[List[Dict[str, Any]]] = None,
254
+ tool_resources: Optional[Dict[str, Any]] = None,
255
+ metadata: Optional[Dict[str, str]] = None,
256
+ temperature: Optional[float] = None,
257
+ top_p: Optional[float] = None,
258
+ response_format: Optional[Union[str, Dict[str, str]]] = None,
259
+ client: Optional[Any] = None,
260
+ api_key: Optional[str] = None,
261
+ api_base: Optional[str] = None,
262
+ api_version: Optional[str] = None,
263
+ **kwargs,
264
+ ) -> Union[Assistant, Coroutine[Any, Any, Assistant]]:
265
+ async_create_assistants: Optional[bool] = kwargs.pop(
266
+ "async_create_assistants", None
267
+ )
268
+ if async_create_assistants is not None and not isinstance(
269
+ async_create_assistants, bool
270
+ ):
271
+ raise ValueError(
272
+ "Invalid value passed in for async_create_assistants. Only bool or None allowed"
273
+ )
274
+ optional_params = GenericLiteLLMParams(
275
+ api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
276
+ )
277
+ litellm_params_dict = get_litellm_params(**kwargs)
278
+
279
+ ### TIMEOUT LOGIC ###
280
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
281
+ # set timeout for 10 minutes by default
282
+
283
+ if (
284
+ timeout is not None
285
+ and isinstance(timeout, httpx.Timeout)
286
+ and supports_httpx_timeout(custom_llm_provider) is False
287
+ ):
288
+ read_timeout = timeout.read or 600
289
+ timeout = read_timeout # default 10 min timeout
290
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
291
+ timeout = float(timeout) # type: ignore
292
+ elif timeout is None:
293
+ timeout = 600.0
294
+
295
+ create_assistant_data = {
296
+ "model": model,
297
+ "name": name,
298
+ "description": description,
299
+ "instructions": instructions,
300
+ "tools": tools,
301
+ "tool_resources": tool_resources,
302
+ "metadata": metadata,
303
+ "temperature": temperature,
304
+ "top_p": top_p,
305
+ "response_format": response_format,
306
+ }
307
+
308
+ # only send params that are not None
309
+ create_assistant_data = {
310
+ k: v for k, v in create_assistant_data.items() if v is not None
311
+ }
312
+
313
+ response: Optional[Union[Coroutine[Any, Any, Assistant], Assistant]] = None
314
+ if custom_llm_provider == "openai":
315
+ api_base = (
316
+ optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
317
+ or litellm.api_base
318
+ or os.getenv("OPENAI_BASE_URL")
319
+ or os.getenv("OPENAI_API_BASE")
320
+ or "https://api.openai.com/v1"
321
+ )
322
+ organization = (
323
+ optional_params.organization
324
+ or litellm.organization
325
+ or os.getenv("OPENAI_ORGANIZATION", None)
326
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
327
+ )
328
+ # set API KEY
329
+ api_key = (
330
+ optional_params.api_key
331
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
332
+ or litellm.openai_key
333
+ or os.getenv("OPENAI_API_KEY")
334
+ )
335
+
336
+ response = openai_assistants_api.create_assistants(
337
+ api_base=api_base,
338
+ api_key=api_key,
339
+ timeout=timeout,
340
+ max_retries=optional_params.max_retries,
341
+ organization=organization,
342
+ create_assistant_data=create_assistant_data,
343
+ client=client,
344
+ async_create_assistants=async_create_assistants, # type: ignore
345
+ ) # type: ignore
346
+ elif custom_llm_provider == "azure":
347
+ api_base = (
348
+ optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
349
+ ) # type: ignore
350
+
351
+ api_version = (
352
+ optional_params.api_version
353
+ or litellm.api_version
354
+ or get_secret("AZURE_API_VERSION")
355
+ ) # type: ignore
356
+
357
+ api_key = (
358
+ optional_params.api_key
359
+ or litellm.api_key
360
+ or litellm.azure_key
361
+ or get_secret("AZURE_OPENAI_API_KEY")
362
+ or get_secret("AZURE_API_KEY")
363
+ ) # type: ignore
364
+
365
+ extra_body = optional_params.get("extra_body", {})
366
+ azure_ad_token: Optional[str] = None
367
+ if extra_body is not None:
368
+ azure_ad_token = extra_body.pop("azure_ad_token", None)
369
+ else:
370
+ azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
371
+
372
+ if isinstance(client, OpenAI):
373
+ client = None # only pass client if it's AzureOpenAI
374
+
375
+ response = azure_assistants_api.create_assistants(
376
+ api_base=api_base,
377
+ api_key=api_key,
378
+ azure_ad_token=azure_ad_token,
379
+ api_version=api_version,
380
+ timeout=timeout,
381
+ max_retries=optional_params.max_retries,
382
+ client=client,
383
+ async_create_assistants=async_create_assistants,
384
+ create_assistant_data=create_assistant_data,
385
+ litellm_params=litellm_params_dict,
386
+ )
387
+ else:
388
+ raise litellm.exceptions.BadRequestError(
389
+ message="LiteLLM doesn't support {} for 'create_assistants'. Only 'openai' and 'azure' are supported.".format(
390
+ custom_llm_provider
391
+ ),
392
+ model="n/a",
393
+ llm_provider=custom_llm_provider,
394
+ response=httpx.Response(
395
+ status_code=400,
396
+ content="Unsupported provider",
397
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
398
+ ),
399
+ )
400
+ if response is None:
401
+ raise litellm.exceptions.InternalServerError(
402
+ message="No response returned from 'create_assistants'",
403
+ model=model,
404
+ llm_provider=custom_llm_provider,
405
+ )
406
+ return response
407
+
408
+
409
+ async def adelete_assistant(
410
+ custom_llm_provider: Literal["openai", "azure"],
411
+ client: Optional[AsyncOpenAI] = None,
412
+ **kwargs,
413
+ ) -> AssistantDeleted:
414
+ loop = asyncio.get_event_loop()
415
+ ### PASS ARGS TO GET ASSISTANTS ###
416
+ kwargs["async_delete_assistants"] = True
417
+ try:
418
+ kwargs["client"] = client
419
+ # Use a partial function to pass your keyword arguments
420
+ func = partial(delete_assistant, custom_llm_provider, **kwargs)
421
+
422
+ # Add the context to the function
423
+ ctx = contextvars.copy_context()
424
+ func_with_context = partial(ctx.run, func)
425
+
426
+ _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
427
+ model="", custom_llm_provider=custom_llm_provider
428
+ ) # type: ignore
429
+
430
+ # Await normally
431
+ init_response = await loop.run_in_executor(None, func_with_context)
432
+ if asyncio.iscoroutine(init_response):
433
+ response = await init_response
434
+ else:
435
+ response = init_response
436
+ return response # type: ignore
437
+ except Exception as e:
438
+ raise exception_type(
439
+ model="",
440
+ custom_llm_provider=custom_llm_provider,
441
+ original_exception=e,
442
+ completion_kwargs={},
443
+ extra_kwargs=kwargs,
444
+ )
445
+
446
+
447
+ def delete_assistant(
448
+ custom_llm_provider: Literal["openai", "azure"],
449
+ assistant_id: str,
450
+ client: Optional[Any] = None,
451
+ api_key: Optional[str] = None,
452
+ api_base: Optional[str] = None,
453
+ api_version: Optional[str] = None,
454
+ **kwargs,
455
+ ) -> Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]:
456
+ optional_params = GenericLiteLLMParams(
457
+ api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
458
+ )
459
+
460
+ litellm_params_dict = get_litellm_params(**kwargs)
461
+
462
+ async_delete_assistants: Optional[bool] = kwargs.pop(
463
+ "async_delete_assistants", None
464
+ )
465
+ if async_delete_assistants is not None and not isinstance(
466
+ async_delete_assistants, bool
467
+ ):
468
+ raise ValueError(
469
+ "Invalid value passed in for async_delete_assistants. Only bool or None allowed"
470
+ )
471
+
472
+ ### TIMEOUT LOGIC ###
473
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
474
+ # set timeout for 10 minutes by default
475
+
476
+ if (
477
+ timeout is not None
478
+ and isinstance(timeout, httpx.Timeout)
479
+ and supports_httpx_timeout(custom_llm_provider) is False
480
+ ):
481
+ read_timeout = timeout.read or 600
482
+ timeout = read_timeout # default 10 min timeout
483
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
484
+ timeout = float(timeout) # type: ignore
485
+ elif timeout is None:
486
+ timeout = 600.0
487
+
488
+ response: Optional[
489
+ Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]
490
+ ] = None
491
+ if custom_llm_provider == "openai":
492
+ api_base = (
493
+ optional_params.api_base
494
+ or litellm.api_base
495
+ or os.getenv("OPENAI_BASE_URL")
496
+ or os.getenv("OPENAI_API_BASE")
497
+ or "https://api.openai.com/v1"
498
+ )
499
+ organization = (
500
+ optional_params.organization
501
+ or litellm.organization
502
+ or os.getenv("OPENAI_ORGANIZATION", None)
503
+ or None
504
+ )
505
+ # set API KEY
506
+ api_key = (
507
+ optional_params.api_key
508
+ or litellm.api_key
509
+ or litellm.openai_key
510
+ or os.getenv("OPENAI_API_KEY")
511
+ )
512
+
513
+ response = openai_assistants_api.delete_assistant(
514
+ api_base=api_base,
515
+ api_key=api_key,
516
+ timeout=timeout,
517
+ max_retries=optional_params.max_retries,
518
+ organization=organization,
519
+ assistant_id=assistant_id,
520
+ client=client,
521
+ async_delete_assistants=async_delete_assistants,
522
+ )
523
+ elif custom_llm_provider == "azure":
524
+ api_base = (
525
+ optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
526
+ ) # type: ignore
527
+
528
+ api_version = (
529
+ optional_params.api_version
530
+ or litellm.api_version
531
+ or get_secret("AZURE_API_VERSION")
532
+ ) # type: ignore
533
+
534
+ api_key = (
535
+ optional_params.api_key
536
+ or litellm.api_key
537
+ or litellm.azure_key
538
+ or get_secret("AZURE_OPENAI_API_KEY")
539
+ or get_secret("AZURE_API_KEY")
540
+ ) # type: ignore
541
+
542
+ extra_body = optional_params.get("extra_body", {})
543
+ azure_ad_token: Optional[str] = None
544
+ if extra_body is not None:
545
+ azure_ad_token = extra_body.pop("azure_ad_token", None)
546
+ else:
547
+ azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
548
+
549
+ if isinstance(client, OpenAI):
550
+ client = None # only pass client if it's AzureOpenAI
551
+
552
+ response = azure_assistants_api.delete_assistant(
553
+ assistant_id=assistant_id,
554
+ api_base=api_base,
555
+ api_key=api_key,
556
+ azure_ad_token=azure_ad_token,
557
+ api_version=api_version,
558
+ timeout=timeout,
559
+ max_retries=optional_params.max_retries,
560
+ client=client,
561
+ async_delete_assistants=async_delete_assistants,
562
+ litellm_params=litellm_params_dict,
563
+ )
564
+ else:
565
+ raise litellm.exceptions.BadRequestError(
566
+ message="LiteLLM doesn't support {} for 'delete_assistant'. Only 'openai' and 'azure' are supported.".format(
567
+ custom_llm_provider
568
+ ),
569
+ model="n/a",
570
+ llm_provider=custom_llm_provider,
571
+ response=httpx.Response(
572
+ status_code=400,
573
+ content="Unsupported provider",
574
+ request=httpx.Request(
575
+ method="delete_assistant", url="https://github.com/BerriAI/litellm"
576
+ ),
577
+ ),
578
+ )
579
+ if response is None:
580
+ raise litellm.exceptions.InternalServerError(
581
+ message="No response returned from 'delete_assistant'",
582
+ model="n/a",
583
+ llm_provider=custom_llm_provider,
584
+ )
585
+ return response
586
+
587
+
588
+ ### THREADS ###
589
+
590
+
591
+ async def acreate_thread(
592
+ custom_llm_provider: Literal["openai", "azure"], **kwargs
593
+ ) -> Thread:
594
+ loop = asyncio.get_event_loop()
595
+ ### PASS ARGS TO GET ASSISTANTS ###
596
+ kwargs["acreate_thread"] = True
597
+ try:
598
+ # Use a partial function to pass your keyword arguments
599
+ func = partial(create_thread, custom_llm_provider, **kwargs)
600
+
601
+ # Add the context to the function
602
+ ctx = contextvars.copy_context()
603
+ func_with_context = partial(ctx.run, func)
604
+
605
+ _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
606
+ model="", custom_llm_provider=custom_llm_provider
607
+ ) # type: ignore
608
+
609
+ # Await normally
610
+ init_response = await loop.run_in_executor(None, func_with_context)
611
+ if asyncio.iscoroutine(init_response):
612
+ response = await init_response
613
+ else:
614
+ response = init_response
615
+ return response # type: ignore
616
+ except Exception as e:
617
+ raise exception_type(
618
+ model="",
619
+ custom_llm_provider=custom_llm_provider,
620
+ original_exception=e,
621
+ completion_kwargs={},
622
+ extra_kwargs=kwargs,
623
+ )
624
+
625
+
626
+ def create_thread(
627
+ custom_llm_provider: Literal["openai", "azure"],
628
+ messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
629
+ metadata: Optional[dict] = None,
630
+ tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None,
631
+ client: Optional[OpenAI] = None,
632
+ **kwargs,
633
+ ) -> Thread:
634
+ """
635
+ - get the llm provider
636
+ - if openai - route it there
637
+ - pass through relevant params
638
+
639
+ ```
640
+ from litellm import create_thread
641
+
642
+ create_thread(
643
+ custom_llm_provider="openai",
644
+ ### OPTIONAL ###
645
+ messages = [{
646
+ "role": "user",
647
+ "content": "Hello, what is AI?"
648
+ },
649
+ {
650
+ "role": "user",
651
+ "content": "How does AI work? Explain it in simple terms."
652
+ }]
653
+ )
654
+ ```
655
+ """
656
+ acreate_thread = kwargs.get("acreate_thread", None)
657
+ optional_params = GenericLiteLLMParams(**kwargs)
658
+ litellm_params_dict = get_litellm_params(**kwargs)
659
+
660
+ ### TIMEOUT LOGIC ###
661
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
662
+ # set timeout for 10 minutes by default
663
+
664
+ if (
665
+ timeout is not None
666
+ and isinstance(timeout, httpx.Timeout)
667
+ and supports_httpx_timeout(custom_llm_provider) is False
668
+ ):
669
+ read_timeout = timeout.read or 600
670
+ timeout = read_timeout # default 10 min timeout
671
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
672
+ timeout = float(timeout) # type: ignore
673
+ elif timeout is None:
674
+ timeout = 600.0
675
+
676
+ api_base: Optional[str] = None
677
+ api_key: Optional[str] = None
678
+
679
+ response: Optional[Thread] = None
680
+ if custom_llm_provider == "openai":
681
+ api_base = (
682
+ optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
683
+ or litellm.api_base
684
+ or os.getenv("OPENAI_BASE_URL")
685
+ or os.getenv("OPENAI_API_BASE")
686
+ or "https://api.openai.com/v1"
687
+ )
688
+ organization = (
689
+ optional_params.organization
690
+ or litellm.organization
691
+ or os.getenv("OPENAI_ORGANIZATION", None)
692
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
693
+ )
694
+ # set API KEY
695
+ api_key = (
696
+ optional_params.api_key
697
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
698
+ or litellm.openai_key
699
+ or os.getenv("OPENAI_API_KEY")
700
+ )
701
+ response = openai_assistants_api.create_thread(
702
+ messages=messages,
703
+ metadata=metadata,
704
+ api_base=api_base,
705
+ api_key=api_key,
706
+ timeout=timeout,
707
+ max_retries=optional_params.max_retries,
708
+ organization=organization,
709
+ client=client,
710
+ acreate_thread=acreate_thread,
711
+ )
712
+ elif custom_llm_provider == "azure":
713
+ api_base = (
714
+ optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
715
+ ) # type: ignore
716
+
717
+ api_key = (
718
+ optional_params.api_key
719
+ or litellm.api_key
720
+ or litellm.azure_key
721
+ or get_secret("AZURE_OPENAI_API_KEY")
722
+ or get_secret("AZURE_API_KEY")
723
+ ) # type: ignore
724
+
725
+ api_version: Optional[str] = (
726
+ optional_params.api_version
727
+ or litellm.api_version
728
+ or get_secret("AZURE_API_VERSION")
729
+ ) # type: ignore
730
+
731
+ extra_body = optional_params.get("extra_body", {})
732
+ azure_ad_token: Optional[str] = None
733
+ if extra_body is not None:
734
+ azure_ad_token = extra_body.pop("azure_ad_token", None)
735
+ else:
736
+ azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
737
+
738
+ if isinstance(client, OpenAI):
739
+ client = None # only pass client if it's AzureOpenAI
740
+
741
+ response = azure_assistants_api.create_thread(
742
+ messages=messages,
743
+ metadata=metadata,
744
+ api_base=api_base,
745
+ api_key=api_key,
746
+ azure_ad_token=azure_ad_token,
747
+ api_version=api_version,
748
+ timeout=timeout,
749
+ max_retries=optional_params.max_retries,
750
+ client=client,
751
+ acreate_thread=acreate_thread,
752
+ litellm_params=litellm_params_dict,
753
+ )
754
+ else:
755
+ raise litellm.exceptions.BadRequestError(
756
+ message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
757
+ custom_llm_provider
758
+ ),
759
+ model="n/a",
760
+ llm_provider=custom_llm_provider,
761
+ response=httpx.Response(
762
+ status_code=400,
763
+ content="Unsupported provider",
764
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
765
+ ),
766
+ )
767
+ return response # type: ignore
768
+
769
+
770
+ async def aget_thread(
771
+ custom_llm_provider: Literal["openai", "azure"],
772
+ thread_id: str,
773
+ client: Optional[AsyncOpenAI] = None,
774
+ **kwargs,
775
+ ) -> Thread:
776
+ loop = asyncio.get_event_loop()
777
+ ### PASS ARGS TO GET THREAD ###
778
+ kwargs["aget_thread"] = True
779
+ try:
780
+ # Use a partial function to pass your keyword arguments
781
+ func = partial(get_thread, custom_llm_provider, thread_id, client, **kwargs)
782
+
783
+ # Add the context to the function
784
+ ctx = contextvars.copy_context()
785
+ func_with_context = partial(ctx.run, func)
786
+
787
+ _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
788
+ model="", custom_llm_provider=custom_llm_provider
789
+ ) # type: ignore
790
+
791
+ # Await normally
792
+ init_response = await loop.run_in_executor(None, func_with_context)
793
+ if asyncio.iscoroutine(init_response):
794
+ response = await init_response
795
+ else:
796
+ response = init_response
797
+ return response # type: ignore
798
+ except Exception as e:
799
+ raise exception_type(
800
+ model="",
801
+ custom_llm_provider=custom_llm_provider,
802
+ original_exception=e,
803
+ completion_kwargs={},
804
+ extra_kwargs=kwargs,
805
+ )
806
+
807
+
808
+ def get_thread(
809
+ custom_llm_provider: Literal["openai", "azure"],
810
+ thread_id: str,
811
+ client=None,
812
+ **kwargs,
813
+ ) -> Thread:
814
+ """Get the thread object, given a thread_id"""
815
+ aget_thread = kwargs.pop("aget_thread", None)
816
+ optional_params = GenericLiteLLMParams(**kwargs)
817
+ litellm_params_dict = get_litellm_params(**kwargs)
818
+ ### TIMEOUT LOGIC ###
819
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
820
+ # set timeout for 10 minutes by default
821
+
822
+ if (
823
+ timeout is not None
824
+ and isinstance(timeout, httpx.Timeout)
825
+ and supports_httpx_timeout(custom_llm_provider) is False
826
+ ):
827
+ read_timeout = timeout.read or 600
828
+ timeout = read_timeout # default 10 min timeout
829
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
830
+ timeout = float(timeout) # type: ignore
831
+ elif timeout is None:
832
+ timeout = 600.0
833
+ api_base: Optional[str] = None
834
+ api_key: Optional[str] = None
835
+ response: Optional[Thread] = None
836
+ if custom_llm_provider == "openai":
837
+ api_base = (
838
+ optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
839
+ or litellm.api_base
840
+ or os.getenv("OPENAI_BASE_URL")
841
+ or os.getenv("OPENAI_API_BASE")
842
+ or "https://api.openai.com/v1"
843
+ )
844
+ organization = (
845
+ optional_params.organization
846
+ or litellm.organization
847
+ or os.getenv("OPENAI_ORGANIZATION", None)
848
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
849
+ )
850
+ # set API KEY
851
+ api_key = (
852
+ optional_params.api_key
853
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
854
+ or litellm.openai_key
855
+ or os.getenv("OPENAI_API_KEY")
856
+ )
857
+
858
+ response = openai_assistants_api.get_thread(
859
+ thread_id=thread_id,
860
+ api_base=api_base,
861
+ api_key=api_key,
862
+ timeout=timeout,
863
+ max_retries=optional_params.max_retries,
864
+ organization=organization,
865
+ client=client,
866
+ aget_thread=aget_thread,
867
+ )
868
+ elif custom_llm_provider == "azure":
869
+ api_base = (
870
+ optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
871
+ ) # type: ignore
872
+
873
+ api_version: Optional[str] = (
874
+ optional_params.api_version
875
+ or litellm.api_version
876
+ or get_secret("AZURE_API_VERSION")
877
+ ) # type: ignore
878
+
879
+ api_key = (
880
+ optional_params.api_key
881
+ or litellm.api_key
882
+ or litellm.azure_key
883
+ or get_secret("AZURE_OPENAI_API_KEY")
884
+ or get_secret("AZURE_API_KEY")
885
+ ) # type: ignore
886
+
887
+ extra_body = optional_params.get("extra_body", {})
888
+ azure_ad_token: Optional[str] = None
889
+ if extra_body is not None:
890
+ azure_ad_token = extra_body.pop("azure_ad_token", None)
891
+ else:
892
+ azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
893
+
894
+ if isinstance(client, OpenAI):
895
+ client = None # only pass client if it's AzureOpenAI
896
+
897
+ response = azure_assistants_api.get_thread(
898
+ thread_id=thread_id,
899
+ api_base=api_base,
900
+ api_key=api_key,
901
+ azure_ad_token=azure_ad_token,
902
+ api_version=api_version,
903
+ timeout=timeout,
904
+ max_retries=optional_params.max_retries,
905
+ client=client,
906
+ aget_thread=aget_thread,
907
+ litellm_params=litellm_params_dict,
908
+ )
909
+ else:
910
+ raise litellm.exceptions.BadRequestError(
911
+ message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format(
912
+ custom_llm_provider
913
+ ),
914
+ model="n/a",
915
+ llm_provider=custom_llm_provider,
916
+ response=httpx.Response(
917
+ status_code=400,
918
+ content="Unsupported provider",
919
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
920
+ ),
921
+ )
922
+ return response # type: ignore
923
+
924
+
925
+ ### MESSAGES ###
926
+
927
+
928
+ async def a_add_message(
929
+ custom_llm_provider: Literal["openai", "azure"],
930
+ thread_id: str,
931
+ role: Literal["user", "assistant"],
932
+ content: str,
933
+ attachments: Optional[List[Attachment]] = None,
934
+ metadata: Optional[dict] = None,
935
+ client=None,
936
+ **kwargs,
937
+ ) -> OpenAIMessage:
938
+ loop = asyncio.get_event_loop()
939
+ ### PASS ARGS TO ADD MESSAGE ###
940
+ kwargs["a_add_message"] = True
941
+ try:
942
+ # Use a partial function to pass your keyword arguments
943
+ func = partial(
944
+ add_message,
945
+ custom_llm_provider,
946
+ thread_id,
947
+ role,
948
+ content,
949
+ attachments,
950
+ metadata,
951
+ client,
952
+ **kwargs,
953
+ )
954
+
955
+ # Add the context to the function
956
+ ctx = contextvars.copy_context()
957
+ func_with_context = partial(ctx.run, func)
958
+
959
+ _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
960
+ model="", custom_llm_provider=custom_llm_provider
961
+ ) # type: ignore
962
+
963
+ # Await normally
964
+ init_response = await loop.run_in_executor(None, func_with_context)
965
+ if asyncio.iscoroutine(init_response):
966
+ response = await init_response
967
+ else:
968
+ # Call the synchronous function using run_in_executor
969
+ response = init_response
970
+ return response # type: ignore
971
+ except Exception as e:
972
+ raise exception_type(
973
+ model="",
974
+ custom_llm_provider=custom_llm_provider,
975
+ original_exception=e,
976
+ completion_kwargs={},
977
+ extra_kwargs=kwargs,
978
+ )
979
+
980
+
981
+ def add_message(
982
+ custom_llm_provider: Literal["openai", "azure"],
983
+ thread_id: str,
984
+ role: Literal["user", "assistant"],
985
+ content: str,
986
+ attachments: Optional[List[Attachment]] = None,
987
+ metadata: Optional[dict] = None,
988
+ client=None,
989
+ **kwargs,
990
+ ) -> OpenAIMessage:
991
+ ### COMMON OBJECTS ###
992
+ a_add_message = kwargs.pop("a_add_message", None)
993
+ _message_data = MessageData(
994
+ role=role, content=content, attachments=attachments, metadata=metadata
995
+ )
996
+ litellm_params_dict = get_litellm_params(**kwargs)
997
+ optional_params = GenericLiteLLMParams(**kwargs)
998
+
999
+ message_data = get_optional_params_add_message(
1000
+ role=_message_data["role"],
1001
+ content=_message_data["content"],
1002
+ attachments=_message_data["attachments"],
1003
+ metadata=_message_data["metadata"],
1004
+ custom_llm_provider=custom_llm_provider,
1005
+ )
1006
+
1007
+ ### TIMEOUT LOGIC ###
1008
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
1009
+ # set timeout for 10 minutes by default
1010
+
1011
+ if (
1012
+ timeout is not None
1013
+ and isinstance(timeout, httpx.Timeout)
1014
+ and supports_httpx_timeout(custom_llm_provider) is False
1015
+ ):
1016
+ read_timeout = timeout.read or 600
1017
+ timeout = read_timeout # default 10 min timeout
1018
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
1019
+ timeout = float(timeout) # type: ignore
1020
+ elif timeout is None:
1021
+ timeout = 600.0
1022
+ api_key: Optional[str] = None
1023
+ api_base: Optional[str] = None
1024
+ response: Optional[OpenAIMessage] = None
1025
+ if custom_llm_provider == "openai":
1026
+ api_base = (
1027
+ optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
1028
+ or litellm.api_base
1029
+ or os.getenv("OPENAI_BASE_URL")
1030
+ or os.getenv("OPENAI_API_BASE")
1031
+ or "https://api.openai.com/v1"
1032
+ )
1033
+ organization = (
1034
+ optional_params.organization
1035
+ or litellm.organization
1036
+ or os.getenv("OPENAI_ORGANIZATION", None)
1037
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
1038
+ )
1039
+ # set API KEY
1040
+ api_key = (
1041
+ optional_params.api_key
1042
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
1043
+ or litellm.openai_key
1044
+ or os.getenv("OPENAI_API_KEY")
1045
+ )
1046
+ response = openai_assistants_api.add_message(
1047
+ thread_id=thread_id,
1048
+ message_data=message_data,
1049
+ api_base=api_base,
1050
+ api_key=api_key,
1051
+ timeout=timeout,
1052
+ max_retries=optional_params.max_retries,
1053
+ organization=organization,
1054
+ client=client,
1055
+ a_add_message=a_add_message,
1056
+ )
1057
+ elif custom_llm_provider == "azure":
1058
+ api_base = (
1059
+ optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
1060
+ ) # type: ignore
1061
+
1062
+ api_version: Optional[str] = (
1063
+ optional_params.api_version
1064
+ or litellm.api_version
1065
+ or get_secret("AZURE_API_VERSION")
1066
+ ) # type: ignore
1067
+
1068
+ api_key = (
1069
+ optional_params.api_key
1070
+ or litellm.api_key
1071
+ or litellm.azure_key
1072
+ or get_secret("AZURE_OPENAI_API_KEY")
1073
+ or get_secret("AZURE_API_KEY")
1074
+ ) # type: ignore
1075
+
1076
+ extra_body = optional_params.get("extra_body", {})
1077
+ azure_ad_token: Optional[str] = None
1078
+ if extra_body is not None:
1079
+ azure_ad_token = extra_body.pop("azure_ad_token", None)
1080
+ else:
1081
+ azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
1082
+
1083
+ response = azure_assistants_api.add_message(
1084
+ thread_id=thread_id,
1085
+ message_data=message_data,
1086
+ api_base=api_base,
1087
+ api_key=api_key,
1088
+ api_version=api_version,
1089
+ azure_ad_token=azure_ad_token,
1090
+ timeout=timeout,
1091
+ max_retries=optional_params.max_retries,
1092
+ client=client,
1093
+ a_add_message=a_add_message,
1094
+ litellm_params=litellm_params_dict,
1095
+ )
1096
+ else:
1097
+ raise litellm.exceptions.BadRequestError(
1098
+ message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
1099
+ custom_llm_provider
1100
+ ),
1101
+ model="n/a",
1102
+ llm_provider=custom_llm_provider,
1103
+ response=httpx.Response(
1104
+ status_code=400,
1105
+ content="Unsupported provider",
1106
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
1107
+ ),
1108
+ )
1109
+
1110
+ return response # type: ignore
1111
+
1112
+
1113
+ async def aget_messages(
1114
+ custom_llm_provider: Literal["openai", "azure"],
1115
+ thread_id: str,
1116
+ client: Optional[AsyncOpenAI] = None,
1117
+ **kwargs,
1118
+ ) -> AsyncCursorPage[OpenAIMessage]:
1119
+ loop = asyncio.get_event_loop()
1120
+ ### PASS ARGS TO GET MESSAGES ###
1121
+ kwargs["aget_messages"] = True
1122
+ try:
1123
+ # Use a partial function to pass your keyword arguments
1124
+ func = partial(
1125
+ get_messages,
1126
+ custom_llm_provider,
1127
+ thread_id,
1128
+ client,
1129
+ **kwargs,
1130
+ )
1131
+
1132
+ # Add the context to the function
1133
+ ctx = contextvars.copy_context()
1134
+ func_with_context = partial(ctx.run, func)
1135
+
1136
+ _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
1137
+ model="", custom_llm_provider=custom_llm_provider
1138
+ ) # type: ignore
1139
+
1140
+ # Await normally
1141
+ init_response = await loop.run_in_executor(None, func_with_context)
1142
+ if asyncio.iscoroutine(init_response):
1143
+ response = await init_response
1144
+ else:
1145
+ # Call the synchronous function using run_in_executor
1146
+ response = init_response
1147
+ return response # type: ignore
1148
+ except Exception as e:
1149
+ raise exception_type(
1150
+ model="",
1151
+ custom_llm_provider=custom_llm_provider,
1152
+ original_exception=e,
1153
+ completion_kwargs={},
1154
+ extra_kwargs=kwargs,
1155
+ )
1156
+
1157
+
1158
+ def get_messages(
1159
+ custom_llm_provider: Literal["openai", "azure"],
1160
+ thread_id: str,
1161
+ client: Optional[Any] = None,
1162
+ **kwargs,
1163
+ ) -> SyncCursorPage[OpenAIMessage]:
1164
+ aget_messages = kwargs.pop("aget_messages", None)
1165
+ optional_params = GenericLiteLLMParams(**kwargs)
1166
+ litellm_params_dict = get_litellm_params(**kwargs)
1167
+
1168
+ ### TIMEOUT LOGIC ###
1169
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
1170
+ # set timeout for 10 minutes by default
1171
+
1172
+ if (
1173
+ timeout is not None
1174
+ and isinstance(timeout, httpx.Timeout)
1175
+ and supports_httpx_timeout(custom_llm_provider) is False
1176
+ ):
1177
+ read_timeout = timeout.read or 600
1178
+ timeout = read_timeout # default 10 min timeout
1179
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
1180
+ timeout = float(timeout) # type: ignore
1181
+ elif timeout is None:
1182
+ timeout = 600.0
1183
+
1184
+ response: Optional[SyncCursorPage[OpenAIMessage]] = None
1185
+ api_key: Optional[str] = None
1186
+ api_base: Optional[str] = None
1187
+ if custom_llm_provider == "openai":
1188
+ api_base = (
1189
+ optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
1190
+ or litellm.api_base
1191
+ or os.getenv("OPENAI_BASE_URL")
1192
+ or os.getenv("OPENAI_API_BASE")
1193
+ or "https://api.openai.com/v1"
1194
+ )
1195
+ organization = (
1196
+ optional_params.organization
1197
+ or litellm.organization
1198
+ or os.getenv("OPENAI_ORGANIZATION", None)
1199
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
1200
+ )
1201
+ # set API KEY
1202
+ api_key = (
1203
+ optional_params.api_key
1204
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
1205
+ or litellm.openai_key
1206
+ or os.getenv("OPENAI_API_KEY")
1207
+ )
1208
+ response = openai_assistants_api.get_messages(
1209
+ thread_id=thread_id,
1210
+ api_base=api_base,
1211
+ api_key=api_key,
1212
+ timeout=timeout,
1213
+ max_retries=optional_params.max_retries,
1214
+ organization=organization,
1215
+ client=client,
1216
+ aget_messages=aget_messages,
1217
+ )
1218
+ elif custom_llm_provider == "azure":
1219
+ api_base = (
1220
+ optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
1221
+ ) # type: ignore
1222
+
1223
+ api_version: Optional[str] = (
1224
+ optional_params.api_version
1225
+ or litellm.api_version
1226
+ or get_secret("AZURE_API_VERSION")
1227
+ ) # type: ignore
1228
+
1229
+ api_key = (
1230
+ optional_params.api_key
1231
+ or litellm.api_key
1232
+ or litellm.azure_key
1233
+ or get_secret("AZURE_OPENAI_API_KEY")
1234
+ or get_secret("AZURE_API_KEY")
1235
+ ) # type: ignore
1236
+
1237
+ extra_body = optional_params.get("extra_body", {})
1238
+ azure_ad_token: Optional[str] = None
1239
+ if extra_body is not None:
1240
+ azure_ad_token = extra_body.pop("azure_ad_token", None)
1241
+ else:
1242
+ azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
1243
+
1244
+ response = azure_assistants_api.get_messages(
1245
+ thread_id=thread_id,
1246
+ api_base=api_base,
1247
+ api_key=api_key,
1248
+ api_version=api_version,
1249
+ azure_ad_token=azure_ad_token,
1250
+ timeout=timeout,
1251
+ max_retries=optional_params.max_retries,
1252
+ client=client,
1253
+ aget_messages=aget_messages,
1254
+ litellm_params=litellm_params_dict,
1255
+ )
1256
+ else:
1257
+ raise litellm.exceptions.BadRequestError(
1258
+ message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format(
1259
+ custom_llm_provider
1260
+ ),
1261
+ model="n/a",
1262
+ llm_provider=custom_llm_provider,
1263
+ response=httpx.Response(
1264
+ status_code=400,
1265
+ content="Unsupported provider",
1266
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
1267
+ ),
1268
+ )
1269
+
1270
+ return response # type: ignore
1271
+
1272
+
1273
+ ### RUNS ###
1274
+ def arun_thread_stream(
1275
+ *,
1276
+ event_handler: Optional[AssistantEventHandler] = None,
1277
+ **kwargs,
1278
+ ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
1279
+ kwargs["arun_thread"] = True
1280
+ return run_thread(stream=True, event_handler=event_handler, **kwargs) # type: ignore
1281
+
1282
+
1283
+ async def arun_thread(
1284
+ custom_llm_provider: Literal["openai", "azure"],
1285
+ thread_id: str,
1286
+ assistant_id: str,
1287
+ additional_instructions: Optional[str] = None,
1288
+ instructions: Optional[str] = None,
1289
+ metadata: Optional[dict] = None,
1290
+ model: Optional[str] = None,
1291
+ stream: Optional[bool] = None,
1292
+ tools: Optional[Iterable[AssistantToolParam]] = None,
1293
+ client: Optional[Any] = None,
1294
+ **kwargs,
1295
+ ) -> Run:
1296
+ loop = asyncio.get_event_loop()
1297
+ ### PASS ARGS TO RUN THREAD ###
1298
+ kwargs["arun_thread"] = True
1299
+ try:
1300
+ # Use a partial function to pass your keyword arguments
1301
+ func = partial(
1302
+ run_thread,
1303
+ custom_llm_provider,
1304
+ thread_id,
1305
+ assistant_id,
1306
+ additional_instructions,
1307
+ instructions,
1308
+ metadata,
1309
+ model,
1310
+ stream,
1311
+ tools,
1312
+ client,
1313
+ **kwargs,
1314
+ )
1315
+
1316
+ # Add the context to the function
1317
+ ctx = contextvars.copy_context()
1318
+ func_with_context = partial(ctx.run, func)
1319
+
1320
+ _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
1321
+ model="", custom_llm_provider=custom_llm_provider
1322
+ ) # type: ignore
1323
+
1324
+ # Await normally
1325
+ init_response = await loop.run_in_executor(None, func_with_context)
1326
+ if asyncio.iscoroutine(init_response):
1327
+ response = await init_response
1328
+ else:
1329
+ # Call the synchronous function using run_in_executor
1330
+ response = init_response
1331
+ return response # type: ignore
1332
+ except Exception as e:
1333
+ raise exception_type(
1334
+ model="",
1335
+ custom_llm_provider=custom_llm_provider,
1336
+ original_exception=e,
1337
+ completion_kwargs={},
1338
+ extra_kwargs=kwargs,
1339
+ )
1340
+
1341
+
1342
+ def run_thread_stream(
1343
+ *,
1344
+ event_handler: Optional[AssistantEventHandler] = None,
1345
+ **kwargs,
1346
+ ) -> AssistantStreamManager[AssistantEventHandler]:
1347
+ return run_thread(stream=True, event_handler=event_handler, **kwargs) # type: ignore
1348
+
1349
+
1350
+ def run_thread(
1351
+ custom_llm_provider: Literal["openai", "azure"],
1352
+ thread_id: str,
1353
+ assistant_id: str,
1354
+ additional_instructions: Optional[str] = None,
1355
+ instructions: Optional[str] = None,
1356
+ metadata: Optional[dict] = None,
1357
+ model: Optional[str] = None,
1358
+ stream: Optional[bool] = None,
1359
+ tools: Optional[Iterable[AssistantToolParam]] = None,
1360
+ client: Optional[Any] = None,
1361
+ event_handler: Optional[AssistantEventHandler] = None, # for stream=True calls
1362
+ **kwargs,
1363
+ ) -> Run:
1364
+ """Run a given thread + assistant."""
1365
+ arun_thread = kwargs.pop("arun_thread", None)
1366
+ optional_params = GenericLiteLLMParams(**kwargs)
1367
+ litellm_params_dict = get_litellm_params(**kwargs)
1368
+
1369
+ ### TIMEOUT LOGIC ###
1370
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
1371
+ # set timeout for 10 minutes by default
1372
+
1373
+ if (
1374
+ timeout is not None
1375
+ and isinstance(timeout, httpx.Timeout)
1376
+ and supports_httpx_timeout(custom_llm_provider) is False
1377
+ ):
1378
+ read_timeout = timeout.read or 600
1379
+ timeout = read_timeout # default 10 min timeout
1380
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
1381
+ timeout = float(timeout) # type: ignore
1382
+ elif timeout is None:
1383
+ timeout = 600.0
1384
+
1385
+ response: Optional[Run] = None
1386
+ if custom_llm_provider == "openai":
1387
+ api_base = (
1388
+ optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
1389
+ or litellm.api_base
1390
+ or os.getenv("OPENAI_BASE_URL")
1391
+ or os.getenv("OPENAI_API_BASE")
1392
+ or "https://api.openai.com/v1"
1393
+ )
1394
+ organization = (
1395
+ optional_params.organization
1396
+ or litellm.organization
1397
+ or os.getenv("OPENAI_ORGANIZATION", None)
1398
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
1399
+ )
1400
+ # set API KEY
1401
+ api_key = (
1402
+ optional_params.api_key
1403
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
1404
+ or litellm.openai_key
1405
+ or os.getenv("OPENAI_API_KEY")
1406
+ )
1407
+
1408
+ response = openai_assistants_api.run_thread(
1409
+ thread_id=thread_id,
1410
+ assistant_id=assistant_id,
1411
+ additional_instructions=additional_instructions,
1412
+ instructions=instructions,
1413
+ metadata=metadata,
1414
+ model=model,
1415
+ stream=stream,
1416
+ tools=tools,
1417
+ api_base=api_base,
1418
+ api_key=api_key,
1419
+ timeout=timeout,
1420
+ max_retries=optional_params.max_retries,
1421
+ organization=organization,
1422
+ client=client,
1423
+ arun_thread=arun_thread,
1424
+ event_handler=event_handler,
1425
+ )
1426
+ elif custom_llm_provider == "azure":
1427
+ api_base = (
1428
+ optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
1429
+ ) # type: ignore
1430
+
1431
+ api_version = (
1432
+ optional_params.api_version
1433
+ or litellm.api_version
1434
+ or get_secret("AZURE_API_VERSION")
1435
+ ) # type: ignore
1436
+
1437
+ api_key = (
1438
+ optional_params.api_key
1439
+ or litellm.api_key
1440
+ or litellm.azure_key
1441
+ or get_secret("AZURE_OPENAI_API_KEY")
1442
+ or get_secret("AZURE_API_KEY")
1443
+ ) # type: ignore
1444
+
1445
+ extra_body = optional_params.get("extra_body", {})
1446
+ azure_ad_token = None
1447
+ if extra_body is not None:
1448
+ azure_ad_token = extra_body.pop("azure_ad_token", None)
1449
+ else:
1450
+ azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
1451
+
1452
+ response = azure_assistants_api.run_thread(
1453
+ thread_id=thread_id,
1454
+ assistant_id=assistant_id,
1455
+ additional_instructions=additional_instructions,
1456
+ instructions=instructions,
1457
+ metadata=metadata,
1458
+ model=model,
1459
+ stream=stream,
1460
+ tools=tools,
1461
+ api_base=str(api_base) if api_base is not None else None,
1462
+ api_key=str(api_key) if api_key is not None else None,
1463
+ api_version=str(api_version) if api_version is not None else None,
1464
+ azure_ad_token=str(azure_ad_token) if azure_ad_token is not None else None,
1465
+ timeout=timeout,
1466
+ max_retries=optional_params.max_retries,
1467
+ client=client,
1468
+ arun_thread=arun_thread,
1469
+ litellm_params=litellm_params_dict,
1470
+ ) # type: ignore
1471
+ else:
1472
+ raise litellm.exceptions.BadRequestError(
1473
+ message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format(
1474
+ custom_llm_provider
1475
+ ),
1476
+ model="n/a",
1477
+ llm_provider=custom_llm_provider,
1478
+ response=httpx.Response(
1479
+ status_code=400,
1480
+ content="Unsupported provider",
1481
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
1482
+ ),
1483
+ )
1484
+ return response # type: ignore
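A minimal end-to-end sketch of the thread APIs added above, assuming an OpenAI key in `OPENAI_API_KEY` and an already-created assistant (the `asst_...` id below is a placeholder, not a real value):

```python
# Sketch only: create a thread, add a message, then run it against an assistant.
# Substitute an assistant id created in your own account.
import litellm

thread = litellm.create_thread(
    custom_llm_provider="openai",
    messages=[{"role": "user", "content": "Hello, what is AI?"}],
)

litellm.add_message(
    custom_llm_provider="openai",
    thread_id=thread.id,
    role="user",
    content="How does AI work? Explain it in simple terms.",
)

run = litellm.run_thread(
    custom_llm_provider="openai",
    thread_id=thread.id,
    assistant_id="asst_placeholder",  # placeholder id
)
print(run.status)
```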
litellm/assistants/utils.py ADDED
@@ -0,0 +1,161 @@
1
+ from typing import Optional, Union
2
+
3
+ import litellm
4
+
5
+ from ..exceptions import UnsupportedParamsError
6
+ from ..types.llms.openai import *
7
+
8
+
9
+ def get_optional_params_add_message(
10
+ role: Optional[str],
11
+ content: Optional[
12
+ Union[
13
+ str,
14
+ List[
15
+ Union[
16
+ MessageContentTextObject,
17
+ MessageContentImageFileObject,
18
+ MessageContentImageURLObject,
19
+ ]
20
+ ],
21
+ ]
22
+ ],
23
+ attachments: Optional[List[Attachment]],
24
+ metadata: Optional[dict],
25
+ custom_llm_provider: str,
26
+ **kwargs,
27
+ ):
28
+ """
29
+ Azure doesn't support 'attachments' for creating a message
30
+
31
+ Reference - https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
32
+ """
33
+ passed_params = locals()
34
+ custom_llm_provider = passed_params.pop("custom_llm_provider")
35
+ special_params = passed_params.pop("kwargs")
36
+ for k, v in special_params.items():
37
+ passed_params[k] = v
38
+
39
+ default_params = {
40
+ "role": None,
41
+ "content": None,
42
+ "attachments": None,
43
+ "metadata": None,
44
+ }
45
+
46
+ non_default_params = {
47
+ k: v
48
+ for k, v in passed_params.items()
49
+ if (k in default_params and v != default_params[k])
50
+ }
51
+ optional_params = {}
52
+
53
+ ## raise exception if non-default value passed for non-openai/azure message calls
54
+ def _check_valid_arg(supported_params):
55
+ if len(non_default_params.keys()) > 0:
56
+ keys = list(non_default_params.keys())
57
+ for k in keys:
58
+ if (
59
+ litellm.drop_params is True and k not in supported_params
60
+ ): # drop the unsupported non-default values
61
+ non_default_params.pop(k, None)
62
+ elif k not in supported_params:
63
+ raise litellm.utils.UnsupportedParamsError(
64
+ status_code=500,
65
+ message="k={}, not supported by {}. Supported params={}. To drop it from the call, set `litellm.drop_params = True`.".format(
66
+ k, custom_llm_provider, supported_params
67
+ ),
68
+ )
69
+ return non_default_params
70
+
71
+ if custom_llm_provider == "openai":
72
+ optional_params = non_default_params
73
+ elif custom_llm_provider == "azure":
74
+ supported_params = (
75
+ litellm.AzureOpenAIAssistantsAPIConfig().get_supported_openai_create_message_params()
76
+ )
77
+ _check_valid_arg(supported_params=supported_params)
78
+ optional_params = litellm.AzureOpenAIAssistantsAPIConfig().map_openai_params_create_message_params(
79
+ non_default_params=non_default_params, optional_params=optional_params
80
+ )
81
+ for k in passed_params.keys():
82
+ if k not in default_params.keys():
83
+ optional_params[k] = passed_params[k]
84
+ return optional_params
85
+
86
+
87
+ def get_optional_params_image_gen(
88
+ n: Optional[int] = None,
89
+ quality: Optional[str] = None,
90
+ response_format: Optional[str] = None,
91
+ size: Optional[str] = None,
92
+ style: Optional[str] = None,
93
+ user: Optional[str] = None,
94
+ custom_llm_provider: Optional[str] = None,
95
+ **kwargs,
96
+ ):
97
+ # retrieve all parameters passed to the function
98
+ passed_params = locals()
99
+ custom_llm_provider = passed_params.pop("custom_llm_provider")
100
+ special_params = passed_params.pop("kwargs")
101
+ for k, v in special_params.items():
102
+ passed_params[k] = v
103
+
104
+ default_params = {
105
+ "n": None,
106
+ "quality": None,
107
+ "response_format": None,
108
+ "size": None,
109
+ "style": None,
110
+ "user": None,
111
+ }
112
+
113
+ non_default_params = {
114
+ k: v
115
+ for k, v in passed_params.items()
116
+ if (k in default_params and v != default_params[k])
117
+ }
118
+ optional_params = {}
119
+
120
+ ## raise exception if non-default value passed for non-openai/azure image generation calls
121
+ def _check_valid_arg(supported_params):
122
+ if len(non_default_params.keys()) > 0:
123
+ keys = list(non_default_params.keys())
124
+ for k in keys:
125
+ if (
126
+ litellm.drop_params is True and k not in supported_params
127
+ ): # drop the unsupported non-default values
128
+ non_default_params.pop(k, None)
129
+ elif k not in supported_params:
130
+ raise UnsupportedParamsError(
131
+ status_code=500,
132
+ message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
133
+ )
134
+ return non_default_params
135
+
136
+ if (
137
+ custom_llm_provider == "openai"
138
+ or custom_llm_provider == "azure"
139
+ or custom_llm_provider in litellm.openai_compatible_providers
140
+ ):
141
+ optional_params = non_default_params
142
+ elif custom_llm_provider == "bedrock":
143
+ supported_params = ["size"]
144
+ _check_valid_arg(supported_params=supported_params)
145
+ if size is not None:
146
+ width, height = size.split("x")
147
+ optional_params["width"] = int(width)
148
+ optional_params["height"] = int(height)
149
+ elif custom_llm_provider == "vertex_ai":
150
+ supported_params = ["n"]
151
+ """
152
+ All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
153
+ """
154
+ _check_valid_arg(supported_params=supported_params)
155
+ if n is not None:
156
+ optional_params["sampleCount"] = int(n)
157
+
158
+ for k in passed_params.keys():
159
+ if k not in default_params.keys():
160
+ optional_params[k] = passed_params[k]
161
+ return optional_params
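Both helpers above follow the same pattern: collect the non-default params, then either pass them through (OpenAI) or validate them against the provider's supported list, dropping or rejecting the rest. A small sketch of the `drop_params` behaviour, using only the helper defined in this file; the metadata value is illustrative:

```python
# Sketch: with litellm.drop_params enabled, params outside the provider's
# supported list are dropped instead of raising UnsupportedParamsError.
import litellm
from litellm.assistants.utils import get_optional_params_add_message

litellm.drop_params = True
params = get_optional_params_add_message(
    role="user",
    content="Hello",
    attachments=None,
    metadata={"source": "docs-example"},  # illustrative metadata
    custom_llm_provider="azure",
)
print(params)
```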
litellm/batch_completion/Readme.md ADDED
@@ -0,0 +1,11 @@
1
+ # Implementation of `litellm.batch_completion`, `litellm.batch_completion_models`, `litellm.batch_completion_models_all_responses`
2
+
3
+ Doc: https://docs.litellm.ai/docs/completion/batching
4
+
5
+
6
+ The LiteLLM Python SDK allows you to:
7
+ 1. `litellm.batch_completion`: batch the `litellm.completion` call for a given model over a list of message lists.
8
+ 2. `litellm.batch_completion_models`: send a request to multiple language models concurrently and return the response
9
+ as soon as one of the models responds.
10
+ 3. `litellm.batch_completion_models_all_responses`: send a request to multiple language models concurrently and return a list of responses
11
+ from all models that respond.
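A usage sketch of the three helpers listed above; the model names are illustrative and assume the corresponding provider keys are set in the environment:

```python
import litellm

# 1. One model, many conversations
responses = litellm.batch_completion(
    model="gpt-3.5-turbo",
    messages=[
        [{"role": "user", "content": "Hi"}],
        [{"role": "user", "content": "Tell me a joke"}],
    ],
)

# 2. Many models, return the first response (checked in the order given)
fastest = litellm.batch_completion_models(
    models=["gpt-3.5-turbo", "claude-3-haiku-20240307"],
    messages=[{"role": "user", "content": "Hi"}],
)

# 3. Many models, collect every response
all_responses = litellm.batch_completion_models_all_responses(
    models=["gpt-3.5-turbo", "claude-3-haiku-20240307"],
    messages=[{"role": "user", "content": "Hi"}],
)
```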
litellm/batch_completion/main.py ADDED
@@ -0,0 +1,253 @@
1
+ from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
2
+ from typing import List, Optional
3
+
4
+ import litellm
5
+ from litellm._logging import print_verbose
6
+ from litellm.utils import get_optional_params
7
+
8
+ from ..llms.vllm.completion import handler as vllm_handler
9
+
10
+
11
+ def batch_completion(
12
+ model: str,
13
+ # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
14
+ messages: List = [],
15
+ functions: Optional[List] = None,
16
+ function_call: Optional[str] = None,
17
+ temperature: Optional[float] = None,
18
+ top_p: Optional[float] = None,
19
+ n: Optional[int] = None,
20
+ stream: Optional[bool] = None,
21
+ stop=None,
22
+ max_tokens: Optional[int] = None,
23
+ presence_penalty: Optional[float] = None,
24
+ frequency_penalty: Optional[float] = None,
25
+ logit_bias: Optional[dict] = None,
26
+ user: Optional[str] = None,
27
+ deployment_id=None,
28
+ request_timeout: Optional[int] = None,
29
+ timeout: Optional[int] = 600,
30
+ max_workers: Optional[int] = 100,
31
+ # Optional liteLLM function params
32
+ **kwargs,
33
+ ):
34
+ """
35
+ Batch litellm.completion function for a given model.
36
+
37
+ Args:
38
+ model (str): The model to use for generating completions.
39
+ messages (List, optional): List of messages to use as input for generating completions. Defaults to [].
40
+ functions (List, optional): List of functions to use as input for generating completions. Defaults to [].
41
+ function_call (str, optional): The function call to use as input for generating completions. Defaults to "".
42
+ temperature (float, optional): The temperature parameter for generating completions. Defaults to None.
43
+ top_p (float, optional): The top-p parameter for generating completions. Defaults to None.
44
+ n (int, optional): The number of completions to generate. Defaults to None.
45
+ stream (bool, optional): Whether to stream completions or not. Defaults to None.
46
+ stop (optional): The stop parameter for generating completions. Defaults to None.
47
+ max_tokens (float, optional): The maximum number of tokens to generate. Defaults to None.
48
+ presence_penalty (float, optional): The presence penalty for generating completions. Defaults to None.
49
+ frequency_penalty (float, optional): The frequency penalty for generating completions. Defaults to None.
50
+ logit_bias (dict, optional): The logit bias for generating completions. Defaults to {}.
51
+ user (str, optional): The user string for generating completions. Defaults to "".
52
+ deployment_id (optional): The deployment ID for generating completions. Defaults to None.
53
+ request_timeout (int, optional): The request timeout for generating completions. Defaults to None.
54
+ max_workers (int,optional): The maximum number of threads to use for parallel processing.
55
+
56
+ Returns:
57
+ list: A list of completion results.
58
+ """
59
+ args = locals()
60
+
61
+ batch_messages = messages
62
+ completions = []
63
+ model = model
64
+ custom_llm_provider = None
65
+ if model.split("/", 1)[0] in litellm.provider_list:
66
+ custom_llm_provider = model.split("/", 1)[0]
67
+ model = model.split("/", 1)[1]
68
+ if custom_llm_provider == "vllm":
69
+ optional_params = get_optional_params(
70
+ functions=functions,
71
+ function_call=function_call,
72
+ temperature=temperature,
73
+ top_p=top_p,
74
+ n=n,
75
+ stream=stream or False,
76
+ stop=stop,
77
+ max_tokens=max_tokens,
78
+ presence_penalty=presence_penalty,
79
+ frequency_penalty=frequency_penalty,
80
+ logit_bias=logit_bias,
81
+ user=user,
82
+ # params to identify the model
83
+ model=model,
84
+ custom_llm_provider=custom_llm_provider,
85
+ )
86
+ results = vllm_handler.batch_completions(
87
+ model=model,
88
+ messages=batch_messages,
89
+ custom_prompt_dict=litellm.custom_prompt_dict,
90
+ optional_params=optional_params,
91
+ )
92
+ # all non VLLM models for batch completion models
93
+ else:
94
+
95
+ def chunks(lst, n):
96
+ """Yield successive n-sized chunks from lst."""
97
+ for i in range(0, len(lst), n):
98
+ yield lst[i : i + n]
99
+
100
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
101
+ for sub_batch in chunks(batch_messages, 100):
102
+ for message_list in sub_batch:
103
+ kwargs_modified = args.copy()
104
+ kwargs_modified.pop("max_workers")
105
+ kwargs_modified["messages"] = message_list
106
+ original_kwargs = {}
107
+ if "kwargs" in kwargs_modified:
108
+ original_kwargs = kwargs_modified.pop("kwargs")
109
+ future = executor.submit(
110
+ litellm.completion, **kwargs_modified, **original_kwargs
111
+ )
112
+ completions.append(future)
113
+
114
+ # Retrieve the results from the futures
115
+ # results = [future.result() for future in completions]
116
+ # return exceptions if any
117
+ results = []
118
+ for future in completions:
119
+ try:
120
+ results.append(future.result())
121
+ except Exception as exc:
122
+ results.append(exc)
123
+
124
+ return results
125
+
126
+
127
+ # send one request to multiple models
128
+ # return as soon as one of the llms responds
129
+ def batch_completion_models(*args, **kwargs):
130
+ """
131
+ Send a request to multiple language models concurrently and return the response
132
+ as soon as one of the models responds.
133
+
134
+ Args:
135
+ *args: Variable-length positional arguments passed to the completion function.
136
+ **kwargs: Additional keyword arguments:
137
+ - models (str or list of str): The language models to send requests to.
138
+ - Other keyword arguments to be passed to the completion function.
139
+
140
+ Returns:
141
+ str or None: The response from one of the language models, or None if no response is received.
142
+
143
+ Note:
144
+ This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
145
+ It sends requests concurrently and returns the response from the first model that responds.
146
+ """
147
+
148
+ if "model" in kwargs:
149
+ kwargs.pop("model")
150
+ if "models" in kwargs:
151
+ models = kwargs["models"]
152
+ kwargs.pop("models")
153
+ futures = {}
154
+ with ThreadPoolExecutor(max_workers=len(models)) as executor:
155
+ for model in models:
156
+ futures[model] = executor.submit(
157
+ litellm.completion, *args, model=model, **kwargs
158
+ )
159
+
160
+ for model, future in sorted(
161
+ futures.items(), key=lambda x: models.index(x[0])
162
+ ):
163
+ if future.result() is not None:
164
+ return future.result()
165
+ elif "deployments" in kwargs:
166
+ deployments = kwargs["deployments"]
167
+ kwargs.pop("deployments")
168
+ kwargs.pop("model_list")
169
+ nested_kwargs = kwargs.pop("kwargs", {})
170
+ futures = {}
171
+ with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
172
+ for deployment in deployments:
173
+ for key in kwargs.keys():
174
+ if (
175
+ key not in deployment
176
+ ): # don't override deployment values e.g. model name, api base, etc.
177
+ deployment[key] = kwargs[key]
178
+ kwargs = {**deployment, **nested_kwargs}
179
+ futures[deployment["model"]] = executor.submit(
180
+ litellm.completion, **kwargs
181
+ )
182
+
183
+ while futures:
184
+ # wait for the first returned future
185
+ print_verbose("\n\n waiting for next result\n\n")
186
+ done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
187
+ print_verbose(f"done list\n{done}")
188
+ for future in done:
189
+ try:
190
+ result = future.result()
191
+ return result
192
+ except Exception:
193
+ # if model 1 fails, continue with response from model 2, model3
194
+ print_verbose(
195
+ "\n\ngot an exception, ignoring, removing from futures"
196
+ )
197
+ print_verbose(futures)
198
+ new_futures = {}
199
+ for key, value in futures.items():
200
+ if future == value:
201
+ print_verbose(f"removing key{key}")
202
+ continue
203
+ else:
204
+ new_futures[key] = value
205
+ futures = new_futures
206
+ print_verbose(f"new futures{futures}")
207
+ continue
208
+
209
+ print_verbose("\n\ndone looping through futures\n\n")
210
+ print_verbose(futures)
211
+
212
+ return None # If no response is received from any model
213
+
214
+
215
+ def batch_completion_models_all_responses(*args, **kwargs):
216
+ """
217
+ Send a request to multiple language models concurrently and return a list of responses
218
+ from all models that respond.
219
+
220
+ Args:
221
+ *args: Variable-length positional arguments passed to the completion function.
222
+ **kwargs: Additional keyword arguments:
223
+ - models (str or list of str): The language models to send requests to.
224
+ - Other keyword arguments to be passed to the completion function.
225
+
226
+ Returns:
227
+ list: A list of responses from the language models that responded.
228
+
229
+ Note:
230
+ This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
231
+ It sends requests concurrently and collects responses from all models that respond.
232
+ """
233
+ import concurrent.futures
234
+
235
+ # ANSI escape codes for colored output
236
+
237
+ if "model" in kwargs:
238
+ kwargs.pop("model")
239
+ if "models" in kwargs:
240
+ models = kwargs["models"]
241
+ kwargs.pop("models")
242
+ else:
243
+ raise Exception("'models' param not in kwargs")
244
+
245
+ responses = []
246
+
247
+ with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
248
+ # submit every model first, then gather, so the calls actually run concurrently
+ futures = [executor.submit(litellm.completion, *args, model=model, **kwargs) for model in models]
249
+ for future in futures:
250
+ result = future.result()
251
+ if result is not None:
+ responses.append(result)
252
+
253
+ return responses
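The non-vLLM branch of `batch_completion` above reduces to a fan-out/collect pattern: submit one completion call per message list to a thread pool, then gather results in order, keeping exceptions in place of failed calls. A standalone sketch of that pattern, with a stand-in `fake_completion` so it runs without any provider credentials:

```python
from concurrent.futures import ThreadPoolExecutor

def fake_completion(messages):
    """Stand-in for litellm.completion so the sketch runs offline."""
    if not messages:
        raise ValueError("empty message list")
    return {"echo": messages[-1]["content"]}

batch_messages = [
    [{"role": "user", "content": "Hi"}],
    [],  # produces an exception entry, mirroring batch_completion's behaviour
]

futures = []
with ThreadPoolExecutor(max_workers=4) as executor:
    for message_list in batch_messages:
        futures.append(executor.submit(fake_completion, message_list))

# Results are collected in submission order; exceptions are returned, not raised.
results = []
for future in futures:
    try:
        results.append(future.result())
    except Exception as exc:
        results.append(exc)

print(results)
```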
litellm/batches/batch_utils.py ADDED
@@ -0,0 +1,182 @@
1
+ import json
2
+ from typing import Any, List, Literal, Tuple
3
+
4
+ import litellm
5
+ from litellm._logging import verbose_logger
6
+ from litellm.types.llms.openai import Batch
7
+ from litellm.types.utils import CallTypes, Usage
8
+
9
+
10
+ async def _handle_completed_batch(
11
+ batch: Batch,
12
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
13
+ ) -> Tuple[float, Usage, List[str]]:
14
+ """Helper function to process a completed batch and handle logging"""
15
+ # Get batch results
16
+ file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
17
+ batch, custom_llm_provider
18
+ )
19
+
20
+ # Calculate costs and usage
21
+ batch_cost = await _batch_cost_calculator(
22
+ custom_llm_provider=custom_llm_provider,
23
+ file_content_dictionary=file_content_dictionary,
24
+ )
25
+ batch_usage = _get_batch_job_total_usage_from_file_content(
26
+ file_content_dictionary=file_content_dictionary,
27
+ custom_llm_provider=custom_llm_provider,
28
+ )
29
+
30
+ batch_models = _get_batch_models_from_file_content(file_content_dictionary)
31
+
32
+ return batch_cost, batch_usage, batch_models
33
+
34
+
35
+ def _get_batch_models_from_file_content(
36
+ file_content_dictionary: List[dict],
37
+ ) -> List[str]:
38
+ """
39
+ Get the models from the file content
40
+ """
41
+ batch_models = []
42
+ for _item in file_content_dictionary:
43
+ if _batch_response_was_successful(_item):
44
+ _response_body = _get_response_from_batch_job_output_file(_item)
45
+ _model = _response_body.get("model")
46
+ if _model:
47
+ batch_models.append(_model)
48
+ return batch_models
49
+
50
+
51
+ async def _batch_cost_calculator(
52
+ file_content_dictionary: List[dict],
53
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
54
+ ) -> float:
55
+ """
56
+ Calculate the cost of a batch based on the output file id
57
+ """
58
+ if custom_llm_provider == "vertex_ai":
59
+ raise ValueError("Vertex AI does not support file content retrieval")
60
+ total_cost = _get_batch_job_cost_from_file_content(
61
+ file_content_dictionary=file_content_dictionary,
62
+ custom_llm_provider=custom_llm_provider,
63
+ )
64
+ verbose_logger.debug("total_cost=%s", total_cost)
65
+ return total_cost
66
+
67
+
68
+ async def _get_batch_output_file_content_as_dictionary(
69
+ batch: Batch,
70
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
71
+ ) -> List[dict]:
72
+ """
73
+ Get the batch output file content as a list of dictionaries
74
+ """
75
+ from litellm.files.main import afile_content
76
+
77
+ if custom_llm_provider == "vertex_ai":
78
+ raise ValueError("Vertex AI does not support file content retrieval")
79
+
80
+ if batch.output_file_id is None:
81
+ raise ValueError("Output file id is None cannot retrieve file content")
82
+
83
+ _file_content = await afile_content(
84
+ file_id=batch.output_file_id,
85
+ custom_llm_provider=custom_llm_provider,
86
+ )
87
+ return _get_file_content_as_dictionary(_file_content.content)
88
+
89
+
90
+ def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
91
+ """
92
+ Get the file content as a list of dictionaries from JSON Lines format
93
+ """
94
+ try:
95
+ _file_content_str = file_content.decode("utf-8")
96
+ # Split by newlines and parse each line as a separate JSON object
97
+ json_objects = []
98
+ for line in _file_content_str.strip().split("\n"):
99
+ if line: # Skip empty lines
100
+ json_objects.append(json.loads(line))
101
+ verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
102
+ return json_objects
103
+ except Exception as e:
104
+ raise e
105
+
106
+
107
+ def _get_batch_job_cost_from_file_content(
108
+ file_content_dictionary: List[dict],
109
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
110
+ ) -> float:
111
+ """
112
+ Get the cost of a batch job from the file content
113
+ """
114
+ try:
115
+ total_cost: float = 0.0
116
+ # parse the file content as json
117
+ verbose_logger.debug(
118
+ "file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
119
+ )
120
+ for _item in file_content_dictionary:
121
+ if _batch_response_was_successful(_item):
122
+ _response_body = _get_response_from_batch_job_output_file(_item)
123
+ total_cost += litellm.completion_cost(
124
+ completion_response=_response_body,
125
+ custom_llm_provider=custom_llm_provider,
126
+ call_type=CallTypes.aretrieve_batch.value,
127
+ )
128
+ verbose_logger.debug("total_cost=%s", total_cost)
129
+ return total_cost
130
+ except Exception as e:
131
+ verbose_logger.error("error in _get_batch_job_cost_from_file_content", e)
132
+ raise e
133
+
134
+
135
+ def _get_batch_job_total_usage_from_file_content(
136
+ file_content_dictionary: List[dict],
137
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
138
+ ) -> Usage:
139
+ """
140
+ Get the tokens of a batch job from the file content
141
+ """
142
+ total_tokens: int = 0
143
+ prompt_tokens: int = 0
144
+ completion_tokens: int = 0
145
+ for _item in file_content_dictionary:
146
+ if _batch_response_was_successful(_item):
147
+ _response_body = _get_response_from_batch_job_output_file(_item)
148
+ usage: Usage = _get_batch_job_usage_from_response_body(_response_body)
149
+ total_tokens += usage.total_tokens
150
+ prompt_tokens += usage.prompt_tokens
151
+ completion_tokens += usage.completion_tokens
152
+ return Usage(
153
+ total_tokens=total_tokens,
154
+ prompt_tokens=prompt_tokens,
155
+ completion_tokens=completion_tokens,
156
+ )
157
+
158
+
159
+ def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
160
+ """
161
+ Get the tokens of a batch job from the response body
162
+ """
163
+ _usage_dict = response_body.get("usage", None) or {}
164
+ usage: Usage = Usage(**_usage_dict)
165
+ return usage
166
+
167
+
168
+ def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
169
+ """
170
+ Get the response from the batch job output file
171
+ """
172
+ _response: dict = batch_job_output_file.get("response", None) or {}
173
+ _response_body = _response.get("body", None) or {}
174
+ return _response_body
175
+
176
+
177
+ def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
178
+ """
179
+ Check if the batch job response status == 200
180
+ """
181
+ _response: dict = batch_job_output_file.get("response", None) or {}
182
+ return _response.get("status_code", None) == 200
litellm/batches/main.py ADDED
@@ -0,0 +1,796 @@
1
+ """
2
+ Main File for Batches API implementation
3
+
4
+ https://platform.openai.com/docs/api-reference/batch
5
+
6
+ - create_batch()
7
+ - retrieve_batch()
8
+ - cancel_batch()
9
+ - list_batch()
10
+
11
+ """
12
+
13
+ import asyncio
14
+ import contextvars
15
+ import os
16
+ from functools import partial
17
+ from typing import Any, Coroutine, Dict, Literal, Optional, Union
18
+
19
+ import httpx
20
+
21
+ import litellm
22
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
23
+ from litellm.llms.azure.batches.handler import AzureBatchesAPI
24
+ from litellm.llms.openai.openai import OpenAIBatchesAPI
25
+ from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
26
+ from litellm.secret_managers.main import get_secret_str
27
+ from litellm.types.llms.openai import (
28
+ Batch,
29
+ CancelBatchRequest,
30
+ CreateBatchRequest,
31
+ RetrieveBatchRequest,
32
+ )
33
+ from litellm.types.router import GenericLiteLLMParams
34
+ from litellm.types.utils import LiteLLMBatch
35
+ from litellm.utils import client, get_litellm_params, supports_httpx_timeout
36
+
37
+ ####### ENVIRONMENT VARIABLES ###################
38
+ openai_batches_instance = OpenAIBatchesAPI()
39
+ azure_batches_instance = AzureBatchesAPI()
40
+ vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="")
41
+ #################################################
42
+
43
+
44
+ @client
45
+ async def acreate_batch(
46
+ completion_window: Literal["24h"],
47
+ endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
48
+ input_file_id: str,
49
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
50
+ metadata: Optional[Dict[str, str]] = None,
51
+ extra_headers: Optional[Dict[str, str]] = None,
52
+ extra_body: Optional[Dict[str, str]] = None,
53
+ **kwargs,
54
+ ) -> Batch:
55
+ """
56
+ Async: Creates and executes a batch from an uploaded file of requests
57
+
58
+ LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
59
+ """
60
+ try:
61
+ loop = asyncio.get_event_loop()
62
+ kwargs["acreate_batch"] = True
63
+
64
+ # Use a partial function to pass your keyword arguments
65
+ func = partial(
66
+ create_batch,
67
+ completion_window,
68
+ endpoint,
69
+ input_file_id,
70
+ custom_llm_provider,
71
+ metadata,
72
+ extra_headers,
73
+ extra_body,
74
+ **kwargs,
75
+ )
76
+
77
+ # Add the context to the function
78
+ ctx = contextvars.copy_context()
79
+ func_with_context = partial(ctx.run, func)
80
+ init_response = await loop.run_in_executor(None, func_with_context)
81
+
82
+ if asyncio.iscoroutine(init_response):
83
+ response = await init_response
84
+ else:
85
+ response = init_response
86
+
87
+ return response
88
+ except Exception as e:
89
+ raise e
90
+
91
+
92
+ @client
93
+ def create_batch(
94
+ completion_window: Literal["24h"],
95
+ endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
96
+ input_file_id: str,
97
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
98
+ metadata: Optional[Dict[str, str]] = None,
99
+ extra_headers: Optional[Dict[str, str]] = None,
100
+ extra_body: Optional[Dict[str, str]] = None,
101
+ **kwargs,
102
+ ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
103
+ """
104
+ Creates and executes a batch from an uploaded file of requests
105
+
106
+ LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
107
+ """
108
+ try:
109
+ optional_params = GenericLiteLLMParams(**kwargs)
110
+ litellm_call_id = kwargs.get("litellm_call_id", None)
111
+ proxy_server_request = kwargs.get("proxy_server_request", None)
112
+ model_info = kwargs.get("model_info", None)
113
+ _is_async = kwargs.pop("acreate_batch", False) is True
114
+ litellm_params = get_litellm_params(**kwargs)
115
+ litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
116
+ ### TIMEOUT LOGIC ###
117
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
118
+ litellm_logging_obj.update_environment_variables(
119
+ model=None,
120
+ user=None,
121
+ optional_params=optional_params.model_dump(),
122
+ litellm_params={
123
+ "litellm_call_id": litellm_call_id,
124
+ "proxy_server_request": proxy_server_request,
125
+ "model_info": model_info,
126
+ "metadata": metadata,
127
+ "preset_cache_key": None,
128
+ "stream_response": {},
129
+ **optional_params.model_dump(exclude_unset=True),
130
+ },
131
+ custom_llm_provider=custom_llm_provider,
132
+ )
133
+
134
+ if (
135
+ timeout is not None
136
+ and isinstance(timeout, httpx.Timeout)
137
+ and supports_httpx_timeout(custom_llm_provider) is False
138
+ ):
139
+ read_timeout = timeout.read or 600
140
+ timeout = read_timeout # default 10 min timeout
141
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
142
+ timeout = float(timeout) # type: ignore
143
+ elif timeout is None:
144
+ timeout = 600.0
145
+
146
+ _create_batch_request = CreateBatchRequest(
147
+ completion_window=completion_window,
148
+ endpoint=endpoint,
149
+ input_file_id=input_file_id,
150
+ metadata=metadata,
151
+ extra_headers=extra_headers,
152
+ extra_body=extra_body,
153
+ )
154
+ api_base: Optional[str] = None
155
+ if custom_llm_provider == "openai":
156
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
157
+ api_base = (
158
+ optional_params.api_base
159
+ or litellm.api_base
160
+ or os.getenv("OPENAI_BASE_URL")
161
+ or os.getenv("OPENAI_API_BASE")
162
+ or "https://api.openai.com/v1"
163
+ )
164
+ organization = (
165
+ optional_params.organization
166
+ or litellm.organization
167
+ or os.getenv("OPENAI_ORGANIZATION", None)
168
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
169
+ )
170
+ # set API KEY
171
+ api_key = (
172
+ optional_params.api_key
173
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
174
+ or litellm.openai_key
175
+ or os.getenv("OPENAI_API_KEY")
176
+ )
177
+
178
+ response = openai_batches_instance.create_batch(
179
+ api_base=api_base,
180
+ api_key=api_key,
181
+ organization=organization,
182
+ create_batch_data=_create_batch_request,
183
+ timeout=timeout,
184
+ max_retries=optional_params.max_retries,
185
+ _is_async=_is_async,
186
+ )
187
+ elif custom_llm_provider == "azure":
188
+ api_base = (
189
+ optional_params.api_base
190
+ or litellm.api_base
191
+ or get_secret_str("AZURE_API_BASE")
192
+ )
193
+ api_version = (
194
+ optional_params.api_version
195
+ or litellm.api_version
196
+ or get_secret_str("AZURE_API_VERSION")
197
+ )
198
+
199
+ api_key = (
200
+ optional_params.api_key
201
+ or litellm.api_key
202
+ or litellm.azure_key
203
+ or get_secret_str("AZURE_OPENAI_API_KEY")
204
+ or get_secret_str("AZURE_API_KEY")
205
+ )
206
+
207
+ extra_body = optional_params.get("extra_body", {})
208
+ if extra_body is not None:
209
+ extra_body.pop("azure_ad_token", None)
210
+ else:
211
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
212
+
213
+ response = azure_batches_instance.create_batch(
214
+ _is_async=_is_async,
215
+ api_base=api_base,
216
+ api_key=api_key,
217
+ api_version=api_version,
218
+ timeout=timeout,
219
+ max_retries=optional_params.max_retries,
220
+ create_batch_data=_create_batch_request,
221
+ litellm_params=litellm_params,
222
+ )
223
+ elif custom_llm_provider == "vertex_ai":
224
+ api_base = optional_params.api_base or ""
225
+ vertex_ai_project = (
226
+ optional_params.vertex_project
227
+ or litellm.vertex_project
228
+ or get_secret_str("VERTEXAI_PROJECT")
229
+ )
230
+ vertex_ai_location = (
231
+ optional_params.vertex_location
232
+ or litellm.vertex_location
233
+ or get_secret_str("VERTEXAI_LOCATION")
234
+ )
235
+ vertex_credentials = optional_params.vertex_credentials or get_secret_str(
236
+ "VERTEXAI_CREDENTIALS"
237
+ )
238
+
239
+ response = vertex_ai_batches_instance.create_batch(
240
+ _is_async=_is_async,
241
+ api_base=api_base,
242
+ vertex_project=vertex_ai_project,
243
+ vertex_location=vertex_ai_location,
244
+ vertex_credentials=vertex_credentials,
245
+ timeout=timeout,
246
+ max_retries=optional_params.max_retries,
247
+ create_batch_data=_create_batch_request,
248
+ )
249
+ else:
250
+ raise litellm.exceptions.BadRequestError(
251
+ message="LiteLLM doesn't support custom_llm_provider={} for 'create_batch'".format(
252
+ custom_llm_provider
253
+ ),
254
+ model="n/a",
255
+ llm_provider=custom_llm_provider,
256
+ response=httpx.Response(
257
+ status_code=400,
258
+ content="Unsupported provider",
259
+ request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
260
+ ),
261
+ )
262
+ return response
263
+ except Exception as e:
264
+ raise e
265
+
266
+
267
+ @client
268
+ async def aretrieve_batch(
269
+ batch_id: str,
270
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
271
+ metadata: Optional[Dict[str, str]] = None,
272
+ extra_headers: Optional[Dict[str, str]] = None,
273
+ extra_body: Optional[Dict[str, str]] = None,
274
+ **kwargs,
275
+ ) -> LiteLLMBatch:
276
+ """
277
+ Async: Retrieves a batch.
278
+
279
+ LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id}
280
+ """
281
+ try:
282
+ loop = asyncio.get_event_loop()
283
+ kwargs["aretrieve_batch"] = True
284
+
285
+ # Use a partial function to pass your keyword arguments
286
+ func = partial(
287
+ retrieve_batch,
288
+ batch_id,
289
+ custom_llm_provider,
290
+ metadata,
291
+ extra_headers,
292
+ extra_body,
293
+ **kwargs,
294
+ )
295
+ # Add the context to the function
296
+ ctx = contextvars.copy_context()
297
+ func_with_context = partial(ctx.run, func)
298
+ init_response = await loop.run_in_executor(None, func_with_context)
299
+ if asyncio.iscoroutine(init_response):
300
+ response = await init_response
301
+ else:
302
+ response = init_response # type: ignore
303
+
304
+ return response
305
+ except Exception as e:
306
+ raise e
307
+
308
+
309
+ @client
310
+ def retrieve_batch(
311
+ batch_id: str,
312
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
313
+ metadata: Optional[Dict[str, str]] = None,
314
+ extra_headers: Optional[Dict[str, str]] = None,
315
+ extra_body: Optional[Dict[str, str]] = None,
316
+ **kwargs,
317
+ ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
318
+ """
319
+ Retrieves a batch.
320
+
321
+ LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id}
322
+ """
323
+ try:
324
+ optional_params = GenericLiteLLMParams(**kwargs)
325
+ litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
326
+ ### TIMEOUT LOGIC ###
327
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
328
+ litellm_params = get_litellm_params(
329
+ custom_llm_provider=custom_llm_provider,
330
+ **kwargs,
331
+ )
332
+ litellm_logging_obj.update_environment_variables(
333
+ model=None,
334
+ user=None,
335
+ optional_params=optional_params.model_dump(),
336
+ litellm_params=litellm_params,
337
+ custom_llm_provider=custom_llm_provider,
338
+ )
339
+
340
+ if (
341
+ timeout is not None
342
+ and isinstance(timeout, httpx.Timeout)
343
+ and supports_httpx_timeout(custom_llm_provider) is False
344
+ ):
345
+ read_timeout = timeout.read or 600
346
+ timeout = read_timeout # default 10 min timeout
347
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
348
+ timeout = float(timeout) # type: ignore
349
+ elif timeout is None:
350
+ timeout = 600.0
351
+
352
+ _retrieve_batch_request = RetrieveBatchRequest(
353
+ batch_id=batch_id,
354
+ extra_headers=extra_headers,
355
+ extra_body=extra_body,
356
+ )
357
+
358
+ _is_async = kwargs.pop("aretrieve_batch", False) is True
359
+ api_base: Optional[str] = None
360
+ if custom_llm_provider == "openai":
361
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
362
+ api_base = (
363
+ optional_params.api_base
364
+ or litellm.api_base
365
+ or os.getenv("OPENAI_BASE_URL")
366
+ or os.getenv("OPENAI_API_BASE")
367
+ or "https://api.openai.com/v1"
368
+ )
369
+ organization = (
370
+ optional_params.organization
371
+ or litellm.organization
372
+ or os.getenv("OPENAI_ORGANIZATION", None)
373
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
374
+ )
375
+ # set API KEY
376
+ api_key = (
377
+ optional_params.api_key
378
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
379
+ or litellm.openai_key
380
+ or os.getenv("OPENAI_API_KEY")
381
+ )
382
+
383
+ response = openai_batches_instance.retrieve_batch(
384
+ _is_async=_is_async,
385
+ retrieve_batch_data=_retrieve_batch_request,
386
+ api_base=api_base,
387
+ api_key=api_key,
388
+ organization=organization,
389
+ timeout=timeout,
390
+ max_retries=optional_params.max_retries,
391
+ )
392
+ elif custom_llm_provider == "azure":
393
+ api_base = (
394
+ optional_params.api_base
395
+ or litellm.api_base
396
+ or get_secret_str("AZURE_API_BASE")
397
+ )
398
+ api_version = (
399
+ optional_params.api_version
400
+ or litellm.api_version
401
+ or get_secret_str("AZURE_API_VERSION")
402
+ )
403
+
404
+ api_key = (
405
+ optional_params.api_key
406
+ or litellm.api_key
407
+ or litellm.azure_key
408
+ or get_secret_str("AZURE_OPENAI_API_KEY")
409
+ or get_secret_str("AZURE_API_KEY")
410
+ )
411
+
412
+ extra_body = optional_params.get("extra_body", {})
413
+ if extra_body is not None:
414
+ extra_body.pop("azure_ad_token", None)
415
+ else:
416
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
417
+
418
+ response = azure_batches_instance.retrieve_batch(
419
+ _is_async=_is_async,
420
+ api_base=api_base,
421
+ api_key=api_key,
422
+ api_version=api_version,
423
+ timeout=timeout,
424
+ max_retries=optional_params.max_retries,
425
+ retrieve_batch_data=_retrieve_batch_request,
426
+ litellm_params=litellm_params,
427
+ )
428
+ elif custom_llm_provider == "vertex_ai":
429
+ api_base = optional_params.api_base or ""
430
+ vertex_ai_project = (
431
+ optional_params.vertex_project
432
+ or litellm.vertex_project
433
+ or get_secret_str("VERTEXAI_PROJECT")
434
+ )
435
+ vertex_ai_location = (
436
+ optional_params.vertex_location
437
+ or litellm.vertex_location
438
+ or get_secret_str("VERTEXAI_LOCATION")
439
+ )
440
+ vertex_credentials = optional_params.vertex_credentials or get_secret_str(
441
+ "VERTEXAI_CREDENTIALS"
442
+ )
443
+
444
+ response = vertex_ai_batches_instance.retrieve_batch(
445
+ _is_async=_is_async,
446
+ batch_id=batch_id,
447
+ api_base=api_base,
448
+ vertex_project=vertex_ai_project,
449
+ vertex_location=vertex_ai_location,
450
+ vertex_credentials=vertex_credentials,
451
+ timeout=timeout,
452
+ max_retries=optional_params.max_retries,
453
+ )
454
+ else:
455
+ raise litellm.exceptions.BadRequestError(
456
+ message="LiteLLM doesn't support {} for 'retrieve_batch'. Only 'openai', 'azure' and 'vertex_ai' are supported.".format(
457
+ custom_llm_provider
458
+ ),
459
+ model="n/a",
460
+ llm_provider=custom_llm_provider,
461
+ response=httpx.Response(
462
+ status_code=400,
463
+ content="Unsupported provider",
464
+ request=httpx.Request(method="retrieve_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
465
+ ),
466
+ )
467
+ return response
468
+ except Exception as e:
469
+ raise e
470
+
471
+
472
+ async def alist_batches(
473
+ after: Optional[str] = None,
474
+ limit: Optional[int] = None,
475
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
476
+ metadata: Optional[Dict[str, str]] = None,
477
+ extra_headers: Optional[Dict[str, str]] = None,
478
+ extra_body: Optional[Dict[str, str]] = None,
479
+ **kwargs,
480
+ ):
481
+ """
482
+ Async: List your organization's batches.
483
+ """
484
+ try:
485
+ loop = asyncio.get_event_loop()
486
+ kwargs["alist_batches"] = True
487
+
488
+ # Use a partial function to pass your keyword arguments
489
+ func = partial(
490
+ list_batches,
491
+ after,
492
+ limit,
493
+ custom_llm_provider,
494
+ extra_headers,
495
+ extra_body,
496
+ **kwargs,
497
+ )
498
+
499
+ # Add the context to the function
500
+ ctx = contextvars.copy_context()
501
+ func_with_context = partial(ctx.run, func)
502
+ init_response = await loop.run_in_executor(None, func_with_context)
503
+ if asyncio.iscoroutine(init_response):
504
+ response = await init_response
505
+ else:
506
+ response = init_response # type: ignore
507
+
508
+ return response
509
+ except Exception as e:
510
+ raise e
511
+
512
+
513
+ def list_batches(
514
+ after: Optional[str] = None,
515
+ limit: Optional[int] = None,
516
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
517
+ extra_headers: Optional[Dict[str, str]] = None,
518
+ extra_body: Optional[Dict[str, str]] = None,
519
+ **kwargs,
520
+ ):
521
+ """
522
+ Lists batches
523
+
524
+ List your organization's batches.
525
+ """
526
+ try:
527
+ # set API KEY
528
+ optional_params = GenericLiteLLMParams(**kwargs)
529
+ litellm_params = get_litellm_params(
530
+ custom_llm_provider=custom_llm_provider,
531
+ **kwargs,
532
+ )
533
+ api_key = (
534
+ optional_params.api_key
535
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
536
+ or litellm.openai_key
537
+ or os.getenv("OPENAI_API_KEY")
538
+ )
539
+ ### TIMEOUT LOGIC ###
540
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
541
+ # set timeout for 10 minutes by default
542
+
543
+ if (
544
+ timeout is not None
545
+ and isinstance(timeout, httpx.Timeout)
546
+ and supports_httpx_timeout(custom_llm_provider) is False
547
+ ):
548
+ read_timeout = timeout.read or 600
549
+ timeout = read_timeout # default 10 min timeout
550
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
551
+ timeout = float(timeout) # type: ignore
552
+ elif timeout is None:
553
+ timeout = 600.0
554
+
555
+ _is_async = kwargs.pop("alist_batches", False) is True
556
+ if custom_llm_provider == "openai":
557
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
558
+ api_base = (
559
+ optional_params.api_base
560
+ or litellm.api_base
561
+ or os.getenv("OPENAI_BASE_URL")
562
+ or os.getenv("OPENAI_API_BASE")
563
+ or "https://api.openai.com/v1"
564
+ )
565
+ organization = (
566
+ optional_params.organization
567
+ or litellm.organization
568
+ or os.getenv("OPENAI_ORGANIZATION", None)
569
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
570
+ )
571
+
572
+ response = openai_batches_instance.list_batches(
573
+ _is_async=_is_async,
574
+ after=after,
575
+ limit=limit,
576
+ api_base=api_base,
577
+ api_key=api_key,
578
+ organization=organization,
579
+ timeout=timeout,
580
+ max_retries=optional_params.max_retries,
581
+ )
582
+ elif custom_llm_provider == "azure":
583
+ api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
584
+ api_version = (
585
+ optional_params.api_version
586
+ or litellm.api_version
587
+ or get_secret_str("AZURE_API_VERSION")
588
+ )
589
+
590
+ api_key = (
591
+ optional_params.api_key
592
+ or litellm.api_key
593
+ or litellm.azure_key
594
+ or get_secret_str("AZURE_OPENAI_API_KEY")
595
+ or get_secret_str("AZURE_API_KEY")
596
+ )
597
+
598
+ extra_body = optional_params.get("extra_body", {})
599
+ if extra_body is not None:
600
+ extra_body.pop("azure_ad_token", None)
601
+ else:
602
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
603
+
604
+ response = azure_batches_instance.list_batches(
605
+ _is_async=_is_async,
606
+ api_base=api_base,
607
+ api_key=api_key,
608
+ api_version=api_version,
609
+ timeout=timeout,
610
+ max_retries=optional_params.max_retries,
611
+ litellm_params=litellm_params,
612
+ )
613
+ else:
614
+ raise litellm.exceptions.BadRequestError(
615
+ message="LiteLLM doesn't support {} for 'list_batches'. Only 'openai' and 'azure' are supported.".format(
616
+ custom_llm_provider
617
+ ),
618
+ model="n/a",
619
+ llm_provider=custom_llm_provider,
620
+ response=httpx.Response(
621
+ status_code=400,
622
+ content="Unsupported provider",
623
+ request=httpx.Request(method="list_batches", url="https://github.com/BerriAI/litellm"), # type: ignore
624
+ ),
625
+ )
626
+ return response
627
+ except Exception as e:
628
+ raise e
629
+
630
+
631
+ async def acancel_batch(
632
+ batch_id: str,
633
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
634
+ metadata: Optional[Dict[str, str]] = None,
635
+ extra_headers: Optional[Dict[str, str]] = None,
636
+ extra_body: Optional[Dict[str, str]] = None,
637
+ **kwargs,
638
+ ) -> Batch:
639
+ """
640
+ Async: Cancels a batch.
641
+
642
+ LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
643
+ """
644
+ try:
645
+ loop = asyncio.get_event_loop()
646
+ kwargs["acancel_batch"] = True
647
+
648
+ # Use a partial function to pass your keyword arguments
649
+ func = partial(
650
+ cancel_batch,
651
+ batch_id,
652
+ custom_llm_provider,
653
+ metadata,
654
+ extra_headers,
655
+ extra_body,
656
+ **kwargs,
657
+ )
658
+ # Add the context to the function
659
+ ctx = contextvars.copy_context()
660
+ func_with_context = partial(ctx.run, func)
661
+ init_response = await loop.run_in_executor(None, func_with_context)
662
+ if asyncio.iscoroutine(init_response):
663
+ response = await init_response
664
+ else:
665
+ response = init_response
666
+
667
+ return response
668
+ except Exception as e:
669
+ raise e
670
+
671
+
672
+ def cancel_batch(
673
+ batch_id: str,
674
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
675
+ metadata: Optional[Dict[str, str]] = None,
676
+ extra_headers: Optional[Dict[str, str]] = None,
677
+ extra_body: Optional[Dict[str, str]] = None,
678
+ **kwargs,
679
+ ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
680
+ """
681
+ Cancels a batch.
682
+
683
+ LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
684
+ """
685
+ try:
686
+ optional_params = GenericLiteLLMParams(**kwargs)
687
+ litellm_params = get_litellm_params(
688
+ custom_llm_provider=custom_llm_provider,
689
+ **kwargs,
690
+ )
691
+ ### TIMEOUT LOGIC ###
692
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
693
+ # set timeout for 10 minutes by default
694
+
695
+ if (
696
+ timeout is not None
697
+ and isinstance(timeout, httpx.Timeout)
698
+ and supports_httpx_timeout(custom_llm_provider) is False
699
+ ):
700
+ read_timeout = timeout.read or 600
701
+ timeout = read_timeout # default 10 min timeout
702
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
703
+ timeout = float(timeout) # type: ignore
704
+ elif timeout is None:
705
+ timeout = 600.0
706
+
707
+ _cancel_batch_request = CancelBatchRequest(
708
+ batch_id=batch_id,
709
+ extra_headers=extra_headers,
710
+ extra_body=extra_body,
711
+ )
712
+
713
+ _is_async = kwargs.pop("acancel_batch", False) is True
714
+ api_base: Optional[str] = None
715
+ if custom_llm_provider == "openai":
716
+ api_base = (
717
+ optional_params.api_base
718
+ or litellm.api_base
719
+ or os.getenv("OPENAI_BASE_URL")
720
+ or os.getenv("OPENAI_API_BASE")
721
+ or "https://api.openai.com/v1"
722
+ )
723
+ organization = (
724
+ optional_params.organization
725
+ or litellm.organization
726
+ or os.getenv("OPENAI_ORGANIZATION", None)
727
+ or None
728
+ )
729
+ api_key = (
730
+ optional_params.api_key
731
+ or litellm.api_key
732
+ or litellm.openai_key
733
+ or os.getenv("OPENAI_API_KEY")
734
+ )
735
+
736
+ response = openai_batches_instance.cancel_batch(
737
+ _is_async=_is_async,
738
+ cancel_batch_data=_cancel_batch_request,
739
+ api_base=api_base,
740
+ api_key=api_key,
741
+ organization=organization,
742
+ timeout=timeout,
743
+ max_retries=optional_params.max_retries,
744
+ )
745
+ elif custom_llm_provider == "azure":
746
+ api_base = (
747
+ optional_params.api_base
748
+ or litellm.api_base
749
+ or get_secret_str("AZURE_API_BASE")
750
+ )
751
+ api_version = (
752
+ optional_params.api_version
753
+ or litellm.api_version
754
+ or get_secret_str("AZURE_API_VERSION")
755
+ )
756
+
757
+ api_key = (
758
+ optional_params.api_key
759
+ or litellm.api_key
760
+ or litellm.azure_key
761
+ or get_secret_str("AZURE_OPENAI_API_KEY")
762
+ or get_secret_str("AZURE_API_KEY")
763
+ )
764
+
765
+ extra_body = optional_params.get("extra_body", {})
766
+ if extra_body is not None:
767
+ extra_body.pop("azure_ad_token", None)
768
+ else:
769
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
770
+
771
+ response = azure_batches_instance.cancel_batch(
772
+ _is_async=_is_async,
773
+ api_base=api_base,
774
+ api_key=api_key,
775
+ api_version=api_version,
776
+ timeout=timeout,
777
+ max_retries=optional_params.max_retries,
778
+ cancel_batch_data=_cancel_batch_request,
779
+ litellm_params=litellm_params,
780
+ )
781
+ else:
782
+ raise litellm.exceptions.BadRequestError(
783
+ message="LiteLLM doesn't support {} for 'cancel_batch'. Only 'openai' and 'azure' are supported.".format(
784
+ custom_llm_provider
785
+ ),
786
+ model="n/a",
787
+ llm_provider=custom_llm_provider,
788
+ response=httpx.Response(
789
+ status_code=400,
790
+ content="Unsupported provider",
791
+ request=httpx.Request(method="cancel_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
792
+ ),
793
+ )
794
+ return response
795
+ except Exception as e:
796
+ raise e
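
For reference, a minimal usage sketch of the batch API above. It assumes an OpenAI key is configured, that a batch input `.jsonl` file was already uploaded (the file id and metadata below are hypothetical), and that `acreate_batch` / `aretrieve_batch` are re-exported at the package root, as LiteLLM does for its other top-level APIs:

```python
import asyncio

import litellm  # assumes the batch helpers are re-exported at the package root


async def run_batch() -> None:
    # Start a batch job against the chat completions endpoint.
    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id="file-abc123",  # hypothetical id from a prior file upload
        custom_llm_provider="openai",
        metadata={"purpose": "nightly-eval"},  # hypothetical metadata
    )

    # Fetch the job once; in practice you would poll until batch.status is terminal.
    retrieved = await litellm.aretrieve_batch(
        batch_id=batch.id, custom_llm_provider="openai"
    )
    print(retrieved.status)


asyncio.run(run_batch())
```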
litellm/budget_manager.py ADDED
@@ -0,0 +1,230 @@
1
+ # +-----------------------------------------------+
2
+ # | |
3
+ # | NOT PROXY BUDGET MANAGER |
4
+ # | proxy budget manager is in proxy_server.py |
5
+ # | |
6
+ # +-----------------------------------------------+
7
+ #
8
+ # Thank you users! We ❤️ you! - Krrish & Ishaan
9
+
10
+ import json
11
+ import os
12
+ import threading
13
+ import time
14
+ from typing import Literal, Optional
15
+
16
+ import litellm
17
+ from litellm.constants import (
18
+ DAYS_IN_A_MONTH,
19
+ DAYS_IN_A_WEEK,
20
+ DAYS_IN_A_YEAR,
21
+ HOURS_IN_A_DAY,
22
+ )
23
+ from litellm.utils import ModelResponse
24
+
25
+
26
+ class BudgetManager:
27
+ def __init__(
28
+ self,
29
+ project_name: str,
30
+ client_type: str = "local",
31
+ api_base: Optional[str] = None,
32
+ headers: Optional[dict] = None,
33
+ ):
34
+ self.client_type = client_type
35
+ self.project_name = project_name
36
+ self.api_base = api_base or "https://api.litellm.ai"
37
+ self.headers = headers or {"Content-Type": "application/json"}
38
+ ## load the data or init the initial dictionaries
39
+ self.load_data()
40
+
41
+ def print_verbose(self, print_statement):
42
+ try:
43
+ if litellm.set_verbose:
44
+ import logging
45
+
46
+ logging.info(print_statement)
47
+ except Exception:
48
+ pass
49
+
50
+ def load_data(self):
51
+ if self.client_type == "local":
52
+ # Check if user dict file exists
53
+ if os.path.isfile("user_cost.json"):
54
+ # Load the user dict
55
+ with open("user_cost.json", "r") as json_file:
56
+ self.user_dict = json.load(json_file)
57
+ else:
58
+ self.print_verbose("User Dictionary not found!")
59
+ self.user_dict = {}
60
+ self.print_verbose(f"user dict from local: {self.user_dict}")
61
+ elif self.client_type == "hosted":
62
+ # Load the user_dict from hosted db
63
+ url = self.api_base + "/get_budget"
64
+ data = {"project_name": self.project_name}
65
+ response = litellm.module_level_client.post(
66
+ url, headers=self.headers, json=data
67
+ )
68
+ response = response.json()
69
+ if response["status"] == "error":
70
+ self.user_dict = (
71
+ {}
72
+ ) # assume this means the user dict hasn't been stored yet
73
+ else:
74
+ self.user_dict = response["data"]
75
+
76
+ def create_budget(
77
+ self,
78
+ total_budget: float,
79
+ user: str,
80
+ duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
81
+ created_at: float = time.time(),
82
+ ):
83
+ self.user_dict[user] = {"total_budget": total_budget}
84
+ if duration is None:
85
+ return self.user_dict[user]
86
+
87
+ if duration == "daily":
88
+ duration_in_days = 1
89
+ elif duration == "weekly":
90
+ duration_in_days = DAYS_IN_A_WEEK
91
+ elif duration == "monthly":
92
+ duration_in_days = DAYS_IN_A_MONTH
93
+ elif duration == "yearly":
94
+ duration_in_days = DAYS_IN_A_YEAR
95
+ else:
96
+ raise ValueError(
97
+ """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
98
+ )
99
+ self.user_dict[user] = {
100
+ "total_budget": total_budget,
101
+ "duration": duration_in_days,
102
+ "created_at": created_at,
103
+ "last_updated_at": created_at,
104
+ }
105
+ self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
106
+ return self.user_dict[user]
107
+
108
+ def projected_cost(self, model: str, messages: list, user: str):
109
+ text = "".join(message["content"] for message in messages)
110
+ prompt_tokens = litellm.token_counter(model=model, text=text)
111
+ prompt_cost, _ = litellm.cost_per_token(
112
+ model=model, prompt_tokens=prompt_tokens, completion_tokens=0
113
+ )
114
+ current_cost = self.user_dict[user].get("current_cost", 0)
115
+ projected_cost = prompt_cost + current_cost
116
+ return projected_cost
117
+
118
+ def get_total_budget(self, user: str):
119
+ return self.user_dict[user]["total_budget"]
120
+
121
+ def update_cost(
122
+ self,
123
+ user: str,
124
+ completion_obj: Optional[ModelResponse] = None,
125
+ model: Optional[str] = None,
126
+ input_text: Optional[str] = None,
127
+ output_text: Optional[str] = None,
128
+ ):
129
+ if model and input_text and output_text:
130
+ prompt_tokens = litellm.token_counter(
131
+ model=model, messages=[{"role": "user", "content": input_text}]
132
+ )
133
+ completion_tokens = litellm.token_counter(
134
+ model=model, messages=[{"role": "user", "content": output_text}]
135
+ )
136
+ (
137
+ prompt_tokens_cost_usd_dollar,
138
+ completion_tokens_cost_usd_dollar,
139
+ ) = litellm.cost_per_token(
140
+ model=model,
141
+ prompt_tokens=prompt_tokens,
142
+ completion_tokens=completion_tokens,
143
+ )
144
+ cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
145
+ elif completion_obj:
146
+ cost = litellm.completion_cost(completion_response=completion_obj)
147
+ model = completion_obj[
148
+ "model"
149
+ ] # if this throws an error try, model = completion_obj['model']
150
+ else:
151
+ raise ValueError(
152
+ "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
153
+ )
154
+
155
+ self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
156
+ "current_cost", 0
157
+ )
158
+ if "model_cost" in self.user_dict[user]:
159
+ self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
160
+ "model_cost"
161
+ ].get(model, 0)
162
+ else:
163
+ self.user_dict[user]["model_cost"] = {model: cost}
164
+
165
+ self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
166
+ return {"user": self.user_dict[user]}
167
+
168
+ def get_current_cost(self, user):
169
+ return self.user_dict[user].get("current_cost", 0)
170
+
171
+ def get_model_cost(self, user):
172
+ return self.user_dict[user].get("model_cost", 0)
173
+
174
+ def is_valid_user(self, user: str) -> bool:
175
+ return user in self.user_dict
176
+
177
+ def get_users(self):
178
+ return list(self.user_dict.keys())
179
+
180
+ def reset_cost(self, user):
181
+ self.user_dict[user]["current_cost"] = 0
182
+ self.user_dict[user]["model_cost"] = {}
183
+ return {"user": self.user_dict[user]}
184
+
185
+ def reset_on_duration(self, user: str):
186
+ # Get current and creation time
187
+ last_updated_at = self.user_dict[user]["last_updated_at"]
188
+ current_time = time.time()
189
+
190
+ # Convert duration from days to seconds
191
+ duration_in_seconds = (
192
+ self.user_dict[user]["duration"] * HOURS_IN_A_DAY * 60 * 60
193
+ )
194
+
195
+ # Check if duration has elapsed
196
+ if current_time - last_updated_at >= duration_in_seconds:
197
+ # Reset cost if duration has elapsed and update the creation time
198
+ self.reset_cost(user)
199
+ self.user_dict[user]["last_updated_at"] = current_time
200
+ self._save_data_thread() # Save the data
201
+
202
+ def update_budget_all_users(self):
203
+ for user in self.get_users():
204
+ if "duration" in self.user_dict[user]:
205
+ self.reset_on_duration(user)
206
+
207
+ def _save_data_thread(self):
208
+ thread = threading.Thread(
209
+ target=self.save_data
210
+ ) # [Non-Blocking]: saves data without blocking execution
211
+ thread.start()
212
+
213
+ def save_data(self):
214
+ if self.client_type == "local":
215
+ import json
216
+
217
+ # save the user dict
218
+ with open("user_cost.json", "w") as json_file:
219
+ json.dump(
220
+ self.user_dict, json_file, indent=4
221
+ ) # Indent for pretty formatting
222
+ return {"status": "success"}
223
+ elif self.client_type == "hosted":
224
+ url = self.api_base + "/set_budget"
225
+ data = {"project_name": self.project_name, "user_dict": self.user_dict}
226
+ response = litellm.module_level_client.post(
227
+ url, headers=self.headers, json=data
228
+ )
229
+ response = response.json()
230
+ return response
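
For reference, a short sketch of how `BudgetManager` is meant to be used (assumes an OpenAI key is configured; the project and user names are hypothetical):

```python
import litellm
from litellm import BudgetManager, completion  # assumes BudgetManager is re-exported at the package root

budget_manager = BudgetManager(project_name="demo_project")

user = "user_1234"
if not budget_manager.is_valid_user(user):
    # Give this user a $10 budget that resets monthly.
    budget_manager.create_budget(total_budget=10.0, user=user, duration="monthly")

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello!"}],
)

# Record the spend for this call and compare against the budget.
budget_manager.update_cost(user=user, completion_obj=response)
print(budget_manager.get_current_cost(user), "/", budget_manager.get_total_budget(user))
```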
litellm/caching/Readme.md ADDED
@@ -0,0 +1,40 @@
1
+ # Caching on LiteLLM
2
+
3
+ LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case.
4
+
5
+ The following caching mechanisms are supported:
6
+
7
+ 1. **RedisCache**
8
+ 2. **RedisSemanticCache**
9
+ 3. **QdrantSemanticCache**
10
+ 4. **InMemoryCache**
11
+ 5. **DiskCache**
12
+ 6. **S3Cache**
13
+ 7. **DualCache** (updates both Redis and an in-memory cache simultaneously)
14
+
15
+ ## Folder Structure
16
+
17
+ ```
18
+ litellm/caching/
19
+ ├── base_cache.py
20
+ ├── caching.py
21
+ ├── caching_handler.py
22
+ ├── disk_cache.py
23
+ ├── dual_cache.py
24
+ ├── in_memory_cache.py
25
+ ├── qdrant_semantic_cache.py
26
+ ├── redis_cache.py
27
+ ├── redis_semantic_cache.py
28
+ ├── s3_cache.py
29
+ ```
30
+
31
+ ## Documentation
32
+ - [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching)
33
+ - [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches)
34
+
35
+
36
+
37
+
38
+
39
+
40
+
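
For reference, a minimal sketch of enabling the local in-memory cache described in this Readme (assumes an OpenAI key is configured; other backends are selected by passing the corresponding `type` and connection kwargs):

```python
import litellm
from litellm import completion
from litellm.caching import Cache, LiteLLMCacheType

# Register a process-wide cache; LiteLLMCacheType.REDIS etc. select other backends.
litellm.cache = Cache(type=LiteLLMCacheType.LOCAL)

messages = [{"role": "user", "content": "What is 2 + 2?"}]

# The second identical call should be served from the cache.
first = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
second = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
```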
litellm/caching/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .caching import Cache, LiteLLMCacheType
2
+ from .disk_cache import DiskCache
3
+ from .dual_cache import DualCache
4
+ from .in_memory_cache import InMemoryCache
5
+ from .qdrant_semantic_cache import QdrantSemanticCache
6
+ from .redis_cache import RedisCache
7
+ from .redis_cluster_cache import RedisClusterCache
8
+ from .redis_semantic_cache import RedisSemanticCache
9
+ from .s3_cache import S3Cache
litellm/caching/_internal_lru_cache.py ADDED
@@ -0,0 +1,30 @@
1
+ from functools import lru_cache
2
+ from typing import Callable, Optional, TypeVar
3
+
4
+ T = TypeVar("T")
5
+
6
+
7
+ def lru_cache_wrapper(
8
+ maxsize: Optional[int] = None,
9
+ ) -> Callable[[Callable[..., T]], Callable[..., T]]:
10
+ """
11
+ Wrapper for lru_cache that caches success and exceptions
12
+ """
13
+
14
+ def decorator(f: Callable[..., T]) -> Callable[..., T]:
15
+ @lru_cache(maxsize=maxsize)
16
+ def wrapper(*args, **kwargs):
17
+ try:
18
+ return ("success", f(*args, **kwargs))
19
+ except Exception as e:
20
+ return ("error", e)
21
+
22
+ def wrapped(*args, **kwargs):
23
+ result = wrapper(*args, **kwargs)
24
+ if result[0] == "error":
25
+ raise result[1]
26
+ return result[1]
27
+
28
+ return wrapped
29
+
30
+ return decorator
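
For reference, a small sketch of what `lru_cache_wrapper` adds over `functools.lru_cache`: the exception raised by a failing call is cached too, so the wrapped function is not re-invoked for the same arguments (the function below is hypothetical):

```python
from litellm.caching._internal_lru_cache import lru_cache_wrapper


@lru_cache_wrapper(maxsize=16)
def parse_positive(value: str) -> int:
    parsed = int(value)  # raises ValueError for non-numeric input
    if parsed <= 0:
        raise ValueError("expected a positive integer")
    return parsed


print(parse_positive("42"))  # computed once
print(parse_positive("42"))  # served from the cache
try:
    parse_positive("oops")   # the ValueError itself is cached and re-raised on repeat calls
except ValueError as err:
    print(err)
```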
litellm/caching/base_cache.py ADDED
@@ -0,0 +1,55 @@
1
+ """
2
+ Base Cache implementation. All cache implementations should inherit from this class.
3
+
4
+ Has 4 methods:
5
+ - set_cache
6
+ - get_cache
7
+ - async_set_cache
8
+ - async_get_cache
9
+ """
10
+
11
+ from abc import ABC, abstractmethod
12
+ from typing import TYPE_CHECKING, Any, Optional, Union
13
+
14
+ if TYPE_CHECKING:
15
+ from opentelemetry.trace import Span as _Span
16
+
17
+ Span = Union[_Span, Any]
18
+ else:
19
+ Span = Any
20
+
21
+
22
+ class BaseCache(ABC):
23
+ def __init__(self, default_ttl: int = 60):
24
+ self.default_ttl = default_ttl
25
+
26
+ def get_ttl(self, **kwargs) -> Optional[int]:
27
+ kwargs_ttl: Optional[int] = kwargs.get("ttl")
28
+ if kwargs_ttl is not None:
29
+ try:
30
+ return int(kwargs_ttl)
31
+ except ValueError:
32
+ return self.default_ttl
33
+ return self.default_ttl
34
+
35
+ def set_cache(self, key, value, **kwargs):
36
+ raise NotImplementedError
37
+
38
+ async def async_set_cache(self, key, value, **kwargs):
39
+ raise NotImplementedError
40
+
41
+ @abstractmethod
42
+ async def async_set_cache_pipeline(self, cache_list, **kwargs):
43
+ pass
44
+
45
+ def get_cache(self, key, **kwargs):
46
+ raise NotImplementedError
47
+
48
+ async def async_get_cache(self, key, **kwargs):
49
+ raise NotImplementedError
50
+
51
+ async def batch_cache_write(self, key, value, **kwargs):
52
+ raise NotImplementedError
53
+
54
+ async def disconnect(self):
55
+ raise NotImplementedError
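
For reference, a toy sketch of a backend implementing this interface; it ignores TTLs and exists only to show which methods a real backend such as `RedisCache` or `DiskCache` fills in:

```python
from typing import Any, Dict, Optional

from litellm.caching.base_cache import BaseCache


class TrivialDictCache(BaseCache):
    """Toy in-process cache used purely to illustrate the BaseCache interface."""

    def __init__(self, default_ttl: int = 60):
        super().__init__(default_ttl=default_ttl)
        self._store: Dict[str, Any] = {}

    def set_cache(self, key, value, **kwargs):
        self._store[key] = value

    def get_cache(self, key, **kwargs) -> Optional[Any]:
        return self._store.get(key)

    async def async_set_cache(self, key, value, **kwargs):
        self.set_cache(key, value, **kwargs)

    async def async_get_cache(self, key, **kwargs) -> Optional[Any]:
        return self.get_cache(key, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        # Bulk write: cache_list is a list of (key, value) tuples.
        for key, value in cache_list:
            self.set_cache(key, value, **kwargs)
```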
litellm/caching/caching.py ADDED
@@ -0,0 +1,818 @@
1
+ # +-----------------------------------------------+
2
+ # | |
3
+ # | Give Feedback / Get Help |
4
+ # | https://github.com/BerriAI/litellm/issues/new |
5
+ # | |
6
+ # +-----------------------------------------------+
7
+ #
8
+ # Thank you users! We ❤️ you! - Krrish & Ishaan
9
+
10
+ import ast
11
+ import hashlib
12
+ import json
13
+ import time
14
+ import traceback
15
+ from enum import Enum
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ from pydantic import BaseModel
19
+
20
+ import litellm
21
+ from litellm._logging import verbose_logger
22
+ from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
23
+ from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
24
+ from litellm.types.caching import *
25
+ from litellm.types.utils import EmbeddingResponse, all_litellm_params
26
+
27
+ from .base_cache import BaseCache
28
+ from .disk_cache import DiskCache
29
+ from .dual_cache import DualCache # noqa
30
+ from .in_memory_cache import InMemoryCache
31
+ from .qdrant_semantic_cache import QdrantSemanticCache
32
+ from .redis_cache import RedisCache
33
+ from .redis_cluster_cache import RedisClusterCache
34
+ from .redis_semantic_cache import RedisSemanticCache
35
+ from .s3_cache import S3Cache
36
+
37
+
38
+ def print_verbose(print_statement):
39
+ try:
40
+ verbose_logger.debug(print_statement)
41
+ if litellm.set_verbose:
42
+ print(print_statement) # noqa
43
+ except Exception:
44
+ pass
45
+
46
+
47
+ class CacheMode(str, Enum):
48
+ default_on = "default_on"
49
+ default_off = "default_off"
50
+
51
+
52
+ #### LiteLLM.Completion / Embedding Cache ####
53
+ class Cache:
54
+ def __init__(
55
+ self,
56
+ type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
57
+ mode: Optional[
58
+ CacheMode
59
+ ] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
60
+ host: Optional[str] = None,
61
+ port: Optional[str] = None,
62
+ password: Optional[str] = None,
63
+ namespace: Optional[str] = None,
64
+ ttl: Optional[float] = None,
65
+ default_in_memory_ttl: Optional[float] = None,
66
+ default_in_redis_ttl: Optional[float] = None,
67
+ similarity_threshold: Optional[float] = None,
68
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
69
+ "completion",
70
+ "acompletion",
71
+ "embedding",
72
+ "aembedding",
73
+ "atranscription",
74
+ "transcription",
75
+ "atext_completion",
76
+ "text_completion",
77
+ "arerank",
78
+ "rerank",
79
+ ],
80
+ # s3 Bucket, boto3 configuration
81
+ s3_bucket_name: Optional[str] = None,
82
+ s3_region_name: Optional[str] = None,
83
+ s3_api_version: Optional[str] = None,
84
+ s3_use_ssl: Optional[bool] = True,
85
+ s3_verify: Optional[Union[bool, str]] = None,
86
+ s3_endpoint_url: Optional[str] = None,
87
+ s3_aws_access_key_id: Optional[str] = None,
88
+ s3_aws_secret_access_key: Optional[str] = None,
89
+ s3_aws_session_token: Optional[str] = None,
90
+ s3_config: Optional[Any] = None,
91
+ s3_path: Optional[str] = None,
92
+ redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
93
+ redis_semantic_cache_index_name: Optional[str] = None,
94
+ redis_flush_size: Optional[int] = None,
95
+ redis_startup_nodes: Optional[List] = None,
96
+ disk_cache_dir: Optional[str] = None,
97
+ qdrant_api_base: Optional[str] = None,
98
+ qdrant_api_key: Optional[str] = None,
99
+ qdrant_collection_name: Optional[str] = None,
100
+ qdrant_quantization_config: Optional[str] = None,
101
+ qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
102
+ **kwargs,
103
+ ):
104
+ """
105
+ Initializes the cache based on the given type.
106
+
107
+ Args:
108
+ type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
109
+
110
+ # Redis Cache Args
111
+ host (str, optional): The host address for the Redis cache. Required if type is "redis".
112
+ port (int, optional): The port number for the Redis cache. Required if type is "redis".
113
+ password (str, optional): The password for the Redis cache. Required if type is "redis".
114
+ namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
115
+ ttl (float, optional): The ttl for the Redis cache
116
+ redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
117
+ redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
118
+
119
+ # Qdrant Cache Args
120
+ qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
121
+ qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
122
+ qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
123
+ similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
124
+
125
+ # Disk Cache Args
126
+ disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
127
+
128
+ # S3 Cache Args
129
+ s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
130
+ s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
131
+ s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
132
+ s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
133
+ s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
134
+ s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
135
+ s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
136
+ s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
137
+ s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
138
+ s3_config (dict, optional): The config for the s3 cache. Defaults to None.
139
+
140
+ # Common Cache Args
141
+ supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
142
+ **kwargs: Additional keyword arguments for redis.Redis() cache
143
+
144
+ Raises:
145
+ ValueError: If an invalid cache type is provided.
146
+
147
+ Returns:
148
+ None. Cache is set as a litellm param
149
+ """
150
+ if type == LiteLLMCacheType.REDIS:
151
+ if redis_startup_nodes:
152
+ self.cache: BaseCache = RedisClusterCache(
153
+ host=host,
154
+ port=port,
155
+ password=password,
156
+ redis_flush_size=redis_flush_size,
157
+ startup_nodes=redis_startup_nodes,
158
+ **kwargs,
159
+ )
160
+ else:
161
+ self.cache = RedisCache(
162
+ host=host,
163
+ port=port,
164
+ password=password,
165
+ redis_flush_size=redis_flush_size,
166
+ **kwargs,
167
+ )
168
+ elif type == LiteLLMCacheType.REDIS_SEMANTIC:
169
+ self.cache = RedisSemanticCache(
170
+ host=host,
171
+ port=port,
172
+ password=password,
173
+ similarity_threshold=similarity_threshold,
174
+ embedding_model=redis_semantic_cache_embedding_model,
175
+ index_name=redis_semantic_cache_index_name,
176
+ **kwargs,
177
+ )
178
+ elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
179
+ self.cache = QdrantSemanticCache(
180
+ qdrant_api_base=qdrant_api_base,
181
+ qdrant_api_key=qdrant_api_key,
182
+ collection_name=qdrant_collection_name,
183
+ similarity_threshold=similarity_threshold,
184
+ quantization_config=qdrant_quantization_config,
185
+ embedding_model=qdrant_semantic_cache_embedding_model,
186
+ )
187
+ elif type == LiteLLMCacheType.LOCAL:
188
+ self.cache = InMemoryCache()
189
+ elif type == LiteLLMCacheType.S3:
190
+ self.cache = S3Cache(
191
+ s3_bucket_name=s3_bucket_name,
192
+ s3_region_name=s3_region_name,
193
+ s3_api_version=s3_api_version,
194
+ s3_use_ssl=s3_use_ssl,
195
+ s3_verify=s3_verify,
196
+ s3_endpoint_url=s3_endpoint_url,
197
+ s3_aws_access_key_id=s3_aws_access_key_id,
198
+ s3_aws_secret_access_key=s3_aws_secret_access_key,
199
+ s3_aws_session_token=s3_aws_session_token,
200
+ s3_config=s3_config,
201
+ s3_path=s3_path,
202
+ **kwargs,
203
+ )
204
+ elif type == LiteLLMCacheType.DISK:
205
+ self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
206
+ if "cache" not in litellm.input_callback:
207
+ litellm.input_callback.append("cache")
208
+ if "cache" not in litellm.success_callback:
209
+ litellm.logging_callback_manager.add_litellm_success_callback("cache")
210
+ if "cache" not in litellm._async_success_callback:
211
+ litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
212
+ self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
213
+ self.type = type
214
+ self.namespace = namespace
215
+ self.redis_flush_size = redis_flush_size
216
+ self.ttl = ttl
217
+ self.mode: CacheMode = mode or CacheMode.default_on
218
+
219
+ if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
220
+ self.ttl = default_in_memory_ttl
221
+
222
+ if (
223
+ self.type == LiteLLMCacheType.REDIS
224
+ or self.type == LiteLLMCacheType.REDIS_SEMANTIC
225
+ ) and default_in_redis_ttl is not None:
226
+ self.ttl = default_in_redis_ttl
227
+
228
+ if self.namespace is not None and isinstance(self.cache, RedisCache):
229
+ self.cache.namespace = self.namespace
230
+
231
+ def get_cache_key(self, **kwargs) -> str:
232
+ """
233
+ Get the cache key for the given arguments.
234
+
235
+ Args:
236
+ **kwargs: kwargs to litellm.completion() or embedding()
237
+
238
+ Returns:
239
+ str: The cache key generated from the arguments, or None if no cache key could be generated.
240
+ """
241
+ cache_key = ""
242
+ # verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
243
+
244
+ preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
245
+ if preset_cache_key is not None:
246
+ verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
247
+ return preset_cache_key
248
+
249
+ combined_kwargs = ModelParamHelper._get_all_llm_api_params()
250
+ litellm_param_kwargs = all_litellm_params
251
+ for param in kwargs:
252
+ if param in combined_kwargs:
253
+ param_value: Optional[str] = self._get_param_value(param, kwargs)
254
+ if param_value is not None:
255
+ cache_key += f"{str(param)}: {str(param_value)}"
256
+ elif (
257
+ param not in litellm_param_kwargs
258
+ ): # check if user passed in optional param - e.g. top_k
259
+ if (
260
+ litellm.enable_caching_on_provider_specific_optional_params is True
261
+ ): # feature flagged for now
262
+ if kwargs[param] is None:
263
+ continue # ignore None params
264
+ param_value = kwargs[param]
265
+ cache_key += f"{str(param)}: {str(param_value)}"
266
+
267
+ verbose_logger.debug("\nCreated cache key: %s", cache_key)
268
+ hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
269
+ hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
270
+ self._set_preset_cache_key_in_kwargs(
271
+ preset_cache_key=hashed_cache_key, **kwargs
272
+ )
273
+ return hashed_cache_key
274
+
275
+ def _get_param_value(
276
+ self,
277
+ param: str,
278
+ kwargs: dict,
279
+ ) -> Optional[str]:
280
+ """
281
+ Get the value for the given param from kwargs
282
+ """
283
+ if param == "model":
284
+ return self._get_model_param_value(kwargs)
285
+ elif param == "file":
286
+ return self._get_file_param_value(kwargs)
287
+ return kwargs[param]
288
+
289
+ def _get_model_param_value(self, kwargs: dict) -> str:
290
+ """
291
+ Handles getting the value for the 'model' param from kwargs
292
+
293
+ 1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
294
+ 2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
295
+ 3. Else use the `model` passed in kwargs
296
+ """
297
+ metadata: Dict = kwargs.get("metadata", {}) or {}
298
+ litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
299
+ metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
300
+ model_group: Optional[str] = metadata.get(
301
+ "model_group"
302
+ ) or metadata_in_litellm_params.get("model_group")
303
+ caching_group = self._get_caching_group(metadata, model_group)
304
+ return caching_group or model_group or kwargs["model"]
305
+
306
+ def _get_caching_group(
307
+ self, metadata: dict, model_group: Optional[str]
308
+ ) -> Optional[str]:
309
+ caching_groups: Optional[List] = metadata.get("caching_groups", [])
310
+ if caching_groups:
311
+ for group in caching_groups:
312
+ if model_group in group:
313
+ return str(group)
314
+ return None
315
+
316
+ def _get_file_param_value(self, kwargs: dict) -> str:
317
+ """
318
+ Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
319
+ """
320
+ file = kwargs.get("file")
321
+ metadata = kwargs.get("metadata", {})
322
+ litellm_params = kwargs.get("litellm_params", {})
323
+ return (
324
+ metadata.get("file_checksum")
325
+ or getattr(file, "name", None)
326
+ or metadata.get("file_name")
327
+ or litellm_params.get("file_name")
328
+ )
329
+
330
+ def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
331
+ """
332
+ Get the preset cache key from kwargs["litellm_params"]
333
+
334
+ We use _get_preset_cache_keys for two reasons
335
+
336
+ 1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
337
+ 2. avoid doing duplicate / repeated work
338
+ """
339
+ if kwargs:
340
+ if "litellm_params" in kwargs:
341
+ return kwargs["litellm_params"].get("preset_cache_key", None)
342
+ return None
343
+
344
+ def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
345
+ """
346
+ Set the calculated cache key in kwargs
347
+
348
+ This is used to avoid doing duplicate / repeated work
349
+
350
+ Placed in kwargs["litellm_params"]
351
+ """
352
+ if kwargs:
353
+ if "litellm_params" in kwargs:
354
+ kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
355
+
356
+ @staticmethod
357
+ def _get_hashed_cache_key(cache_key: str) -> str:
358
+ """
359
+ Get the hashed cache key for the given cache key.
360
+
361
+ Use hashlib to create a sha256 hash of the cache key
362
+
363
+ Args:
364
+ cache_key (str): The cache key to hash.
365
+
366
+ Returns:
367
+ str: The hashed cache key.
368
+ """
369
+ hash_object = hashlib.sha256(cache_key.encode())
370
+ # Hexadecimal representation of the hash
371
+ hash_hex = hash_object.hexdigest()
372
+ verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
373
+ return hash_hex
374
+
375
+ def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
376
+ """
377
+ If a redis namespace is provided, add it to the cache key
378
+
379
+ Args:
380
+ hash_hex (str): The hashed cache key.
381
+ **kwargs: Additional keyword arguments.
382
+
383
+ Returns:
384
+ str: The final hashed cache key with the redis namespace.
385
+ """
386
+ dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
387
+ namespace = (
388
+ dynamic_cache_control.get("namespace")
389
+ or kwargs.get("metadata", {}).get("redis_namespace")
390
+ or self.namespace
391
+ )
392
+ if namespace:
393
+ hash_hex = f"{namespace}:{hash_hex}"
394
+ verbose_logger.debug("Final hashed key: %s", hash_hex)
395
+ return hash_hex
396
+
397
+ def generate_streaming_content(self, content):
398
+ chunk_size = 5 # Adjust the chunk size as needed
399
+ for i in range(0, len(content), chunk_size):
400
+ yield {
401
+ "choices": [
402
+ {
403
+ "delta": {
404
+ "role": "assistant",
405
+ "content": content[i : i + chunk_size],
406
+ }
407
+ }
408
+ ]
409
+ }
410
+ time.sleep(CACHED_STREAMING_CHUNK_DELAY)
411
+
412
+ def _get_cache_logic(
413
+ self,
414
+ cached_result: Optional[Any],
415
+ max_age: Optional[float],
416
+ ):
417
+ """
418
+ Common get cache logic across sync + async implementations
419
+ """
420
+ # Check if a timestamp was stored with the cached response
421
+ if (
422
+ cached_result is not None
423
+ and isinstance(cached_result, dict)
424
+ and "timestamp" in cached_result
425
+ ):
426
+ timestamp = cached_result["timestamp"]
427
+ current_time = time.time()
428
+
429
+ # Calculate age of the cached response
430
+ response_age = current_time - timestamp
431
+
432
+ # Check if the cached response is older than the max-age
433
+ if max_age is not None and response_age > max_age:
434
+ return None # Cached response is too old
435
+
436
+ # If the response is fresh, or there's no max-age requirement, return the cached response
437
+ # cached_response may be a serialized string (e.g. b"{...}"); convert it back to a dict
438
+ cached_response = cached_result.get("response")
439
+ try:
440
+ if isinstance(cached_response, dict):
441
+ pass
442
+ else:
443
+ cached_response = json.loads(
444
+ cached_response # type: ignore
445
+ ) # Convert string to dictionary
446
+ except Exception:
447
+ cached_response = ast.literal_eval(cached_response) # type: ignore
448
+ return cached_response
449
+ return cached_result
450
+
451
+ def get_cache(self, **kwargs):
452
+ """
453
+ Retrieves the cached result for the given arguments.
454
+
455
+ Args:
456
+ *args: args to litellm.completion() or embedding()
457
+ **kwargs: kwargs to litellm.completion() or embedding()
458
+
459
+ Returns:
460
+ The cached result if it exists, otherwise None.
461
+ """
462
+ try: # never block execution
463
+ if self.should_use_cache(**kwargs) is not True:
464
+ return
465
+ messages = kwargs.get("messages", [])
466
+ if "cache_key" in kwargs:
467
+ cache_key = kwargs["cache_key"]
468
+ else:
469
+ cache_key = self.get_cache_key(**kwargs)
470
+ if cache_key is not None:
471
+ cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
472
+ max_age = (
473
+ cache_control_args.get("s-maxage")
474
+ or cache_control_args.get("s-max-age")
475
+ or float("inf")
476
+ )
477
+ cached_result = self.cache.get_cache(cache_key, messages=messages)
479
+ return self._get_cache_logic(
480
+ cached_result=cached_result, max_age=max_age
481
+ )
482
+ except Exception:
483
+ print_verbose(f"An exception occurred: {traceback.format_exc()}")
484
+ return None
485
+
486
+ async def async_get_cache(self, **kwargs):
487
+ """
488
+ Async get cache implementation.
489
+
490
+ Used for embedding calls in async wrapper
491
+ """
492
+
493
+ try: # never block execution
494
+ if self.should_use_cache(**kwargs) is not True:
495
+ return
496
+
497
+ kwargs.get("messages", [])
498
+ if "cache_key" in kwargs:
499
+ cache_key = kwargs["cache_key"]
500
+ else:
501
+ cache_key = self.get_cache_key(**kwargs)
502
+ if cache_key is not None:
503
+ cache_control_args = kwargs.get("cache", {})
504
+ max_age = cache_control_args.get(
505
+ "s-max-age", cache_control_args.get("s-maxage", float("inf"))
506
+ )
507
+ cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
508
+ return self._get_cache_logic(
509
+ cached_result=cached_result, max_age=max_age
510
+ )
511
+ except Exception:
512
+ print_verbose(f"An exception occurred: {traceback.format_exc()}")
513
+ return None
514
+
515
+ def _add_cache_logic(self, result, **kwargs):
516
+ """
517
+ Common implementation across sync + async add_cache functions
518
+ """
519
+ try:
520
+ if "cache_key" in kwargs:
521
+ cache_key = kwargs["cache_key"]
522
+ else:
523
+ cache_key = self.get_cache_key(**kwargs)
524
+ if cache_key is not None:
525
+ if isinstance(result, BaseModel):
526
+ result = result.model_dump_json()
527
+
528
+ ## DEFAULT TTL ##
529
+ if self.ttl is not None:
530
+ kwargs["ttl"] = self.ttl
531
+ ## Get Cache-Controls ##
532
+ _cache_kwargs = kwargs.get("cache", None)
533
+ if isinstance(_cache_kwargs, dict):
534
+ for k, v in _cache_kwargs.items():
535
+ if k == "ttl":
536
+ kwargs["ttl"] = v
537
+
538
+ cached_data = {"timestamp": time.time(), "response": result}
539
+ return cache_key, cached_data, kwargs
540
+ else:
541
+ raise Exception("cache key is None")
542
+ except Exception as e:
543
+ raise e
544
+
545
+ def add_cache(self, result, **kwargs):
546
+ """
547
+ Adds a result to the cache.
548
+
549
+ Args:
550
+ *args: args to litellm.completion() or embedding()
551
+ **kwargs: kwargs to litellm.completion() or embedding()
552
+
553
+ Returns:
554
+ None
555
+ """
556
+ try:
557
+ if self.should_use_cache(**kwargs) is not True:
558
+ return
559
+ cache_key, cached_data, kwargs = self._add_cache_logic(
560
+ result=result, **kwargs
561
+ )
562
+ self.cache.set_cache(cache_key, cached_data, **kwargs)
563
+ except Exception as e:
564
+ verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
565
+
566
+ async def async_add_cache(self, result, **kwargs):
567
+ """
568
+ Async implementation of add_cache
569
+ """
570
+ try:
571
+ if self.should_use_cache(**kwargs) is not True:
572
+ return
573
+ if self.type == "redis" and self.redis_flush_size is not None:
574
+ # high traffic - fill in results in memory and then flush
575
+ await self.batch_cache_write(result, **kwargs)
576
+ else:
577
+ cache_key, cached_data, kwargs = self._add_cache_logic(
578
+ result=result, **kwargs
579
+ )
580
+
581
+ await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
582
+ except Exception as e:
583
+ verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
584
+
585
+ def add_embedding_response_to_cache(
586
+ self,
587
+ result: EmbeddingResponse,
588
+ input: str,
589
+ kwargs: dict,
590
+ idx_in_result_data: int = 0,
591
+ ) -> Tuple[str, dict, dict]:
592
+ preset_cache_key = self.get_cache_key(**{**kwargs, "input": input})
593
+ kwargs["cache_key"] = preset_cache_key
594
+ embedding_response = result.data[idx_in_result_data]
595
+ cache_key, cached_data, kwargs = self._add_cache_logic(
596
+ result=embedding_response,
597
+ **kwargs,
598
+ )
599
+ return cache_key, cached_data, kwargs
600
+
601
+ async def async_add_cache_pipeline(self, result, **kwargs):
602
+ """
603
+ Async implementation of add_cache for Embedding calls
604
+
605
+ Does a bulk write, to prevent using too many clients
606
+ """
607
+ try:
608
+ if self.should_use_cache(**kwargs) is not True:
609
+ return
610
+
611
+ # set default ttl if not set
612
+ if self.ttl is not None:
613
+ kwargs["ttl"] = self.ttl
614
+
615
+ cache_list = []
616
+ if isinstance(kwargs["input"], list):
617
+ for idx, i in enumerate(kwargs["input"]):
618
+ (
619
+ cache_key,
620
+ cached_data,
621
+ kwargs,
622
+ ) = self.add_embedding_response_to_cache(result, i, kwargs, idx)
623
+ cache_list.append((cache_key, cached_data))
624
+ elif isinstance(kwargs["input"], str):
625
+ cache_key, cached_data, kwargs = self.add_embedding_response_to_cache(
626
+ result, kwargs["input"], kwargs
627
+ )
628
+ cache_list.append((cache_key, cached_data))
629
+
630
+ await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
631
+ # if async_set_cache_pipeline:
632
+ # await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
633
+ # else:
634
+ # tasks = []
635
+ # for val in cache_list:
636
+ # tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
637
+ # await asyncio.gather(*tasks)
638
+ except Exception as e:
639
+ verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
640
+
641
+ def should_use_cache(self, **kwargs):
642
+ """
643
+ Returns true if we should use the cache for LLM API calls
644
+
645
+ If cache is default_on then this is True
646
+ If cache is default_off then this is only true when user has opted in to use cache
647
+ """
648
+ if self.mode == CacheMode.default_on:
649
+ return True
650
+
651
+ # when mode == default_off -> Cache is opt in only
652
+ _cache = kwargs.get("cache", None)
653
+ verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
654
+ if _cache and isinstance(_cache, dict):
655
+ if _cache.get("use-cache", False) is True:
656
+ return True
657
+ return False
658
+
659
+ async def batch_cache_write(self, result, **kwargs):
660
+ cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
661
+ await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
662
+
663
+ async def ping(self):
664
+ cache_ping = getattr(self.cache, "ping", None)
665
+ if cache_ping:
666
+ return await cache_ping()
667
+ return None
668
+
669
+ async def delete_cache_keys(self, keys):
670
+ cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
671
+ if cache_delete_cache_keys:
672
+ return await cache_delete_cache_keys(keys)
673
+ return None
674
+
675
+ async def disconnect(self):
676
+ if hasattr(self.cache, "disconnect"):
677
+ await self.cache.disconnect()
678
+
679
+ def _supports_async(self) -> bool:
680
+ """
681
+ Internal method to check if the cache type supports async get/set operations
682
+
683
+ Only the S3 cache does NOT support async operations.
684
+
685
+ """
686
+ if self.type and self.type == LiteLLMCacheType.S3:
687
+ return False
688
+ return True
689
+
690
+
691
+ def enable_cache(
692
+ type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
693
+ host: Optional[str] = None,
694
+ port: Optional[str] = None,
695
+ password: Optional[str] = None,
696
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
697
+ "completion",
698
+ "acompletion",
699
+ "embedding",
700
+ "aembedding",
701
+ "atranscription",
702
+ "transcription",
703
+ "atext_completion",
704
+ "text_completion",
705
+ "arerank",
706
+ "rerank",
707
+ ],
708
+ **kwargs,
709
+ ):
710
+ """
711
+ Enable cache with the specified configuration.
712
+
713
+ Args:
714
+ type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
715
+ host (Optional[str]): The host address of the cache server. Defaults to None.
716
+ port (Optional[str]): The port number of the cache server. Defaults to None.
717
+ password (Optional[str]): The password for the cache server. Defaults to None.
718
+ supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
719
+ The call types the cache applies to. Defaults to all supported call types (completion, embedding, transcription, text_completion, rerank and their async variants).
720
+ **kwargs: Additional keyword arguments.
721
+
722
+ Returns:
723
+ None
724
+
725
+ Raises:
726
+ None
727
+ """
728
+ print_verbose("LiteLLM: Enabling Cache")
729
+ if "cache" not in litellm.input_callback:
730
+ litellm.input_callback.append("cache")
731
+ if "cache" not in litellm.success_callback:
732
+ litellm.logging_callback_manager.add_litellm_success_callback("cache")
733
+ if "cache" not in litellm._async_success_callback:
734
+ litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
735
+
736
+ if litellm.cache is None:
737
+ litellm.cache = Cache(
738
+ type=type,
739
+ host=host,
740
+ port=port,
741
+ password=password,
742
+ supported_call_types=supported_call_types,
743
+ **kwargs,
744
+ )
745
+ print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
746
+ print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
747
+
748
+
749
+ def update_cache(
750
+ type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
751
+ host: Optional[str] = None,
752
+ port: Optional[str] = None,
753
+ password: Optional[str] = None,
754
+ supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
755
+ "completion",
756
+ "acompletion",
757
+ "embedding",
758
+ "aembedding",
759
+ "atranscription",
760
+ "transcription",
761
+ "atext_completion",
762
+ "text_completion",
763
+ "arerank",
764
+ "rerank",
765
+ ],
766
+ **kwargs,
767
+ ):
768
+ """
769
+ Update the cache for LiteLLM.
770
+
771
+ Args:
772
+ type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
773
+ host (Optional[str]): The host of the cache. Defaults to None.
774
+ port (Optional[str]): The port of the cache. Defaults to None.
775
+ password (Optional[str]): The password for the cache. Defaults to None.
776
+ supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
777
+ The call types the cache applies to. Defaults to all supported call types (completion, embedding, transcription, text_completion, rerank and their async variants).
778
+ **kwargs: Additional keyword arguments for the cache.
779
+
780
+ Returns:
781
+ None
782
+
783
+ """
784
+ print_verbose("LiteLLM: Updating Cache")
785
+ litellm.cache = Cache(
786
+ type=type,
787
+ host=host,
788
+ port=port,
789
+ password=password,
790
+ supported_call_types=supported_call_types,
791
+ **kwargs,
792
+ )
793
+ print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
794
+ print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
795
+
796
+
797
+ def disable_cache():
798
+ """
799
+ Disable the cache used by LiteLLM.
800
+
801
+ This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.
802
+
803
+ Parameters:
804
+ None
805
+
806
+ Returns:
807
+ None
808
+ """
809
+ from contextlib import suppress
810
+
811
+ print_verbose("LiteLLM: Disabling Cache")
812
+ with suppress(ValueError):
813
+ litellm.input_callback.remove("cache")
814
+ litellm.success_callback.remove("cache")
815
+ litellm._async_success_callback.remove("cache")
816
+
817
+ litellm.cache = None
818
+ print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
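A minimal usage sketch of the helpers above (illustrative only; it assumes enable_cache / disable_cache are re-exported at the package level as litellm.enable_cache / litellm.disable_cache, and that a provider API key such as OPENAI_API_KEY is configured; the model name is an example):

import litellm

litellm.enable_cache()  # registers the "cache" callbacks; defaults to the local in-memory cache

messages = [{"role": "user", "content": "Hello"}]
first = litellm.completion(model="gpt-4o-mini", messages=messages)   # cache miss -> provider call
second = litellm.completion(model="gpt-4o-mini", messages=messages)  # identical request -> served from cache
third = litellm.completion(
    model="gpt-4o-mini",
    messages=messages,
    cache={"no-cache": True},  # per-call opt-out, honored by the get_cache logic above
)

litellm.disable_cache()  # removes the callbacks and sets litellm.cache = None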
litellm/caching/caching_handler.py ADDED
@@ -0,0 +1,938 @@
1
+ """
2
+ This contains LLMCachingHandler
3
+
4
+ This exposes two methods:
5
+ - async_get_cache
6
+ - async_set_cache
7
+
8
+ This file is a wrapper around caching.py
9
+
10
+ This class handles caching logic specific to LLM API requests (completion / embedding / text_completion / transcription, etc.)
11
+
12
+ It uses the configured cache backend (RedisCache, S3Cache, RedisSemanticCache, QdrantSemanticCache, InMemoryCache, DiskCache) based on what the user has set up
13
+
14
+ Each method calls the corresponding method from caching.py
15
+ """
16
+
17
+ import asyncio
18
+ import datetime
19
+ import inspect
20
+ import threading
21
+ from typing import (
22
+ TYPE_CHECKING,
23
+ Any,
24
+ AsyncGenerator,
25
+ Callable,
26
+ Dict,
27
+ Generator,
28
+ List,
29
+ Optional,
30
+ Tuple,
31
+ Union,
32
+ )
33
+
34
+ from pydantic import BaseModel
35
+
36
+ import litellm
37
+ from litellm._logging import print_verbose, verbose_logger
38
+ from litellm.caching.caching import S3Cache
39
+ from litellm.litellm_core_utils.logging_utils import (
40
+ _assemble_complete_response_from_streaming_chunks,
41
+ )
42
+ from litellm.types.rerank import RerankResponse
43
+ from litellm.types.utils import (
44
+ CallTypes,
45
+ Embedding,
46
+ EmbeddingResponse,
47
+ ModelResponse,
48
+ TextCompletionResponse,
49
+ TranscriptionResponse,
50
+ Usage,
51
+ )
52
+
53
+ if TYPE_CHECKING:
54
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
55
+ from litellm.utils import CustomStreamWrapper
56
+ else:
57
+ LiteLLMLoggingObj = Any
58
+ CustomStreamWrapper = Any
59
+
60
+
61
+ class CachingHandlerResponse(BaseModel):
62
+ """
63
+ This is the response object for the caching handler. We need to separate embedding cached responses and (completion / text_completion / transcription) cached responses
64
+
65
+ For embeddings there can be a cache hit for some of the inputs in the list and a cache miss for others
66
+ """
67
+
68
+ cached_result: Optional[Any] = None
69
+ final_embedding_cached_response: Optional[EmbeddingResponse] = None
70
+ embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
71
+
72
+
73
+ class LLMCachingHandler:
74
+ def __init__(
75
+ self,
76
+ original_function: Callable,
77
+ request_kwargs: Dict[str, Any],
78
+ start_time: datetime.datetime,
79
+ ):
80
+ self.async_streaming_chunks: List[ModelResponse] = []
81
+ self.sync_streaming_chunks: List[ModelResponse] = []
82
+ self.request_kwargs = request_kwargs
83
+ self.original_function = original_function
84
+ self.start_time = start_time
85
+ pass
86
+
87
+ async def _async_get_cache(
88
+ self,
89
+ model: str,
90
+ original_function: Callable,
91
+ logging_obj: LiteLLMLoggingObj,
92
+ start_time: datetime.datetime,
93
+ call_type: str,
94
+ kwargs: Dict[str, Any],
95
+ args: Optional[Tuple[Any, ...]] = None,
96
+ ) -> CachingHandlerResponse:
97
+ """
98
+ Internal method to get from the cache.
99
+ Handles different call types (embeddings, chat/completions, text_completion, transcription)
100
+ and accordingly returns the cached response
101
+
102
+ Args:
103
+ model: str:
104
+ original_function: Callable:
105
+ logging_obj: LiteLLMLoggingObj:
106
+ start_time: datetime.datetime:
107
+ call_type: str:
108
+ kwargs: Dict[str, Any]:
109
+ args: Optional[Tuple[Any, ...]] = None:
110
+
111
+
112
+ Returns:
113
+ CachingHandlerResponse:
114
+ Raises:
115
+ None
116
+ """
117
+ from litellm.utils import CustomStreamWrapper
118
+
119
+ args = args or ()
120
+
121
+ final_embedding_cached_response: Optional[EmbeddingResponse] = None
122
+ embedding_all_elements_cache_hit: bool = False
123
+ cached_result: Optional[Any] = None
124
+ if (
125
+ (kwargs.get("caching", None) is None and litellm.cache is not None)
126
+ or kwargs.get("caching", False) is True
127
+ ) and (
128
+ kwargs.get("cache", {}).get("no-cache", False) is not True
129
+ ): # allow users to control returning cached responses from the completion function
130
+ if litellm.cache is not None and self._is_call_type_supported_by_cache(
131
+ original_function=original_function
132
+ ):
133
+ verbose_logger.debug("Checking Async Cache")
134
+ cached_result = await self._retrieve_from_cache(
135
+ call_type=call_type,
136
+ kwargs=kwargs,
137
+ args=args,
138
+ )
139
+
140
+ if cached_result is not None and not isinstance(cached_result, list):
141
+ verbose_logger.debug("Cache Hit!")
142
+ cache_hit = True
143
+ end_time = datetime.datetime.now()
144
+ model, _, _, _ = litellm.get_llm_provider(
145
+ model=model,
146
+ custom_llm_provider=kwargs.get("custom_llm_provider", None),
147
+ api_base=kwargs.get("api_base", None),
148
+ api_key=kwargs.get("api_key", None),
149
+ )
150
+ self._update_litellm_logging_obj_environment(
151
+ logging_obj=logging_obj,
152
+ model=model,
153
+ kwargs=kwargs,
154
+ cached_result=cached_result,
155
+ is_async=True,
156
+ )
157
+
158
+ call_type = original_function.__name__
159
+
160
+ cached_result = self._convert_cached_result_to_model_response(
161
+ cached_result=cached_result,
162
+ call_type=call_type,
163
+ kwargs=kwargs,
164
+ logging_obj=logging_obj,
165
+ model=model,
166
+ custom_llm_provider=kwargs.get("custom_llm_provider", None),
167
+ args=args,
168
+ )
169
+ if kwargs.get("stream", False) is False:
170
+ # LOG SUCCESS
171
+ self._async_log_cache_hit_on_callbacks(
172
+ logging_obj=logging_obj,
173
+ cached_result=cached_result,
174
+ start_time=start_time,
175
+ end_time=end_time,
176
+ cache_hit=cache_hit,
177
+ )
178
+ cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
179
+ **kwargs
180
+ )
181
+ if (
182
+ isinstance(cached_result, BaseModel)
183
+ or isinstance(cached_result, CustomStreamWrapper)
184
+ ) and hasattr(cached_result, "_hidden_params"):
185
+ cached_result._hidden_params["cache_key"] = cache_key # type: ignore
186
+ return CachingHandlerResponse(cached_result=cached_result)
187
+ elif (
188
+ call_type == CallTypes.aembedding.value
189
+ and cached_result is not None
190
+ and isinstance(cached_result, list)
191
+ and litellm.cache is not None
192
+ and not isinstance(
193
+ litellm.cache.cache, S3Cache
194
+ ) # s3 doesn't support bulk writing. Exclude.
195
+ ):
196
+ (
197
+ final_embedding_cached_response,
198
+ embedding_all_elements_cache_hit,
199
+ ) = self._process_async_embedding_cached_response(
200
+ final_embedding_cached_response=final_embedding_cached_response,
201
+ cached_result=cached_result,
202
+ kwargs=kwargs,
203
+ logging_obj=logging_obj,
204
+ start_time=start_time,
205
+ model=model,
206
+ )
207
+ return CachingHandlerResponse(
208
+ final_embedding_cached_response=final_embedding_cached_response,
209
+ embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
210
+ )
211
+ verbose_logger.debug(f"CACHE RESULT: {cached_result}")
212
+ return CachingHandlerResponse(
213
+ cached_result=cached_result,
214
+ final_embedding_cached_response=final_embedding_cached_response,
215
+ )
216
+
217
+ def _sync_get_cache(
218
+ self,
219
+ model: str,
220
+ original_function: Callable,
221
+ logging_obj: LiteLLMLoggingObj,
222
+ start_time: datetime.datetime,
223
+ call_type: str,
224
+ kwargs: Dict[str, Any],
225
+ args: Optional[Tuple[Any, ...]] = None,
226
+ ) -> CachingHandlerResponse:
227
+ from litellm.utils import CustomStreamWrapper
228
+
229
+ args = args or ()
230
+ new_kwargs = kwargs.copy()
231
+ new_kwargs.update(
232
+ convert_args_to_kwargs(
233
+ self.original_function,
234
+ args,
235
+ )
236
+ )
237
+ cached_result: Optional[Any] = None
238
+ if litellm.cache is not None and self._is_call_type_supported_by_cache(
239
+ original_function=original_function
240
+ ):
241
+ print_verbose("Checking Sync Cache")
242
+ cached_result = litellm.cache.get_cache(**new_kwargs)
243
+ if cached_result is not None:
244
+ if "detail" in cached_result:
245
+ # implies an error occurred
246
+ pass
247
+ else:
248
+ call_type = original_function.__name__
249
+ cached_result = self._convert_cached_result_to_model_response(
250
+ cached_result=cached_result,
251
+ call_type=call_type,
252
+ kwargs=kwargs,
253
+ logging_obj=logging_obj,
254
+ model=model,
255
+ custom_llm_provider=kwargs.get("custom_llm_provider", None),
256
+ args=args,
257
+ )
258
+
259
+ # LOG SUCCESS
260
+ cache_hit = True
261
+ end_time = datetime.datetime.now()
262
+ (
263
+ model,
264
+ custom_llm_provider,
265
+ dynamic_api_key,
266
+ api_base,
267
+ ) = litellm.get_llm_provider(
268
+ model=model or "",
269
+ custom_llm_provider=kwargs.get("custom_llm_provider", None),
270
+ api_base=kwargs.get("api_base", None),
271
+ api_key=kwargs.get("api_key", None),
272
+ )
273
+ self._update_litellm_logging_obj_environment(
274
+ logging_obj=logging_obj,
275
+ model=model,
276
+ kwargs=kwargs,
277
+ cached_result=cached_result,
278
+ is_async=False,
279
+ )
280
+
281
+ threading.Thread(
282
+ target=logging_obj.success_handler,
283
+ args=(cached_result, start_time, end_time, cache_hit),
284
+ ).start()
285
+ cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
286
+ **kwargs
287
+ )
288
+ if (
289
+ isinstance(cached_result, BaseModel)
290
+ or isinstance(cached_result, CustomStreamWrapper)
291
+ ) and hasattr(cached_result, "_hidden_params"):
292
+ cached_result._hidden_params["cache_key"] = cache_key # type: ignore
293
+ return CachingHandlerResponse(cached_result=cached_result)
294
+ return CachingHandlerResponse(cached_result=cached_result)
295
+
296
+ def _process_async_embedding_cached_response(
297
+ self,
298
+ final_embedding_cached_response: Optional[EmbeddingResponse],
299
+ cached_result: List[Optional[Dict[str, Any]]],
300
+ kwargs: Dict[str, Any],
301
+ logging_obj: LiteLLMLoggingObj,
302
+ start_time: datetime.datetime,
303
+ model: str,
304
+ ) -> Tuple[Optional[EmbeddingResponse], bool]:
305
+ """
306
+ Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
307
+
308
+ For embedding responses, there can be a cache hit for some of the inputs in the list and a cache miss for others
309
+ This function processes the cached embedding responses and returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
310
+
311
+ Args:
312
+ final_embedding_cached_response: Optional[EmbeddingResponse]:
313
+ cached_result: List[Optional[Dict[str, Any]]]:
314
+ kwargs: Dict[str, Any]:
315
+ logging_obj: LiteLLMLoggingObj:
316
+ start_time: datetime.datetime:
317
+ model: str:
318
+
319
+ Returns:
320
+ Tuple[Optional[EmbeddingResponse], bool]:
321
+ Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
322
+
323
+
324
+ """
325
+ embedding_all_elements_cache_hit: bool = False
326
+ remaining_list = []
327
+ non_null_list = []
328
+ for idx, cr in enumerate(cached_result):
329
+ if cr is None:
330
+ remaining_list.append(kwargs["input"][idx])
331
+ else:
332
+ non_null_list.append((idx, cr))
333
+ original_kwargs_input = kwargs["input"]
334
+ kwargs["input"] = remaining_list
335
+ if len(non_null_list) > 0:
336
+ print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}")
337
+ final_embedding_cached_response = EmbeddingResponse(
338
+ model=kwargs.get("model"),
339
+ data=[None] * len(original_kwargs_input),
340
+ )
341
+ final_embedding_cached_response._hidden_params["cache_hit"] = True
342
+
343
+ prompt_tokens = 0
344
+ for val in non_null_list:
345
+ idx, cr = val # (idx, cr) tuple
346
+ if cr is not None:
347
+ final_embedding_cached_response.data[idx] = Embedding(
348
+ embedding=cr["embedding"],
349
+ index=idx,
350
+ object="embedding",
351
+ )
352
+ if isinstance(original_kwargs_input[idx], str):
353
+ from litellm.utils import token_counter
354
+
355
+ prompt_tokens += token_counter(
356
+ text=original_kwargs_input[idx], count_response_tokens=True
357
+ )
358
+ ## USAGE
359
+ usage = Usage(
360
+ prompt_tokens=prompt_tokens,
361
+ completion_tokens=0,
362
+ total_tokens=prompt_tokens,
363
+ )
364
+ final_embedding_cached_response.usage = usage
365
+ if len(remaining_list) == 0:
366
+ # LOG SUCCESS
367
+ cache_hit = True
368
+ embedding_all_elements_cache_hit = True
369
+ end_time = datetime.datetime.now()
370
+ (
371
+ model,
372
+ custom_llm_provider,
373
+ dynamic_api_key,
374
+ api_base,
375
+ ) = litellm.get_llm_provider(
376
+ model=model,
377
+ custom_llm_provider=kwargs.get("custom_llm_provider", None),
378
+ api_base=kwargs.get("api_base", None),
379
+ api_key=kwargs.get("api_key", None),
380
+ )
381
+
382
+ self._update_litellm_logging_obj_environment(
383
+ logging_obj=logging_obj,
384
+ model=model,
385
+ kwargs=kwargs,
386
+ cached_result=final_embedding_cached_response,
387
+ is_async=True,
388
+ is_embedding=True,
389
+ )
390
+ self._async_log_cache_hit_on_callbacks(
391
+ logging_obj=logging_obj,
392
+ cached_result=final_embedding_cached_response,
393
+ start_time=start_time,
394
+ end_time=end_time,
395
+ cache_hit=cache_hit,
396
+ )
397
+ return final_embedding_cached_response, embedding_all_elements_cache_hit
398
+ return final_embedding_cached_response, embedding_all_elements_cache_hit
399
+
400
+ def combine_usage(self, usage1: Usage, usage2: Usage) -> Usage:
401
+ return Usage(
402
+ prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
403
+ completion_tokens=usage1.completion_tokens + usage2.completion_tokens,
404
+ total_tokens=usage1.total_tokens + usage2.total_tokens,
405
+ )
406
+
407
+ def _combine_cached_embedding_response_with_api_result(
408
+ self,
409
+ _caching_handler_response: CachingHandlerResponse,
410
+ embedding_response: EmbeddingResponse,
411
+ start_time: datetime.datetime,
412
+ end_time: datetime.datetime,
413
+ ) -> EmbeddingResponse:
414
+ """
415
+ Combines the cached embedding response with the API EmbeddingResponse
416
+
417
+ For caching there can be a cache hit for some of the inputs in the list and a cache miss for others
418
+ This function combines the cached embedding response with the API EmbeddingResponse
419
+
420
+ Args:
421
+ caching_handler_response: CachingHandlerResponse:
422
+ embedding_response: EmbeddingResponse:
423
+
424
+ Returns:
425
+ EmbeddingResponse:
426
+ """
427
+ if _caching_handler_response.final_embedding_cached_response is None:
428
+ return embedding_response
429
+
430
+ idx = 0
431
+ final_data_list = []
432
+ for item in _caching_handler_response.final_embedding_cached_response.data:
433
+ if item is None and embedding_response.data is not None:
434
+ final_data_list.append(embedding_response.data[idx])
435
+ idx += 1
436
+ else:
437
+ final_data_list.append(item)
438
+
439
+ _caching_handler_response.final_embedding_cached_response.data = final_data_list
440
+ _caching_handler_response.final_embedding_cached_response._hidden_params[
441
+ "cache_hit"
442
+ ] = True
443
+ _caching_handler_response.final_embedding_cached_response._response_ms = (
444
+ end_time - start_time
445
+ ).total_seconds() * 1000
446
+
447
+ ## USAGE
448
+ if (
449
+ _caching_handler_response.final_embedding_cached_response.usage is not None
450
+ and embedding_response.usage is not None
451
+ ):
452
+ _caching_handler_response.final_embedding_cached_response.usage = self.combine_usage(
453
+ usage1=_caching_handler_response.final_embedding_cached_response.usage,
454
+ usage2=embedding_response.usage,
455
+ )
456
+
457
+ return _caching_handler_response.final_embedding_cached_response
458
+
459
+ def _async_log_cache_hit_on_callbacks(
460
+ self,
461
+ logging_obj: LiteLLMLoggingObj,
462
+ cached_result: Any,
463
+ start_time: datetime.datetime,
464
+ end_time: datetime.datetime,
465
+ cache_hit: bool,
466
+ ):
467
+ """
468
+ Helper function to log the success of a cached result on callbacks
469
+
470
+ Args:
471
+ logging_obj (LiteLLMLoggingObj): The logging object.
472
+ cached_result: The cached result.
473
+ start_time (datetime): The start time of the operation.
474
+ end_time (datetime): The end time of the operation.
475
+ cache_hit (bool): Whether it was a cache hit.
476
+ """
477
+ asyncio.create_task(
478
+ logging_obj.async_success_handler(
479
+ cached_result, start_time, end_time, cache_hit
480
+ )
481
+ )
482
+ threading.Thread(
483
+ target=logging_obj.success_handler,
484
+ args=(cached_result, start_time, end_time, cache_hit),
485
+ ).start()
486
+
487
+ async def _retrieve_from_cache(
488
+ self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
489
+ ) -> Optional[Any]:
490
+ """
491
+ Internal method to
492
+ - get cache key
493
+ - check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
494
+ - async get cache value
495
+ - return the cached value
496
+
497
+ Args:
498
+ call_type: str:
499
+ kwargs: Dict[str, Any]:
500
+ args: Optional[Tuple[Any, ...]] = None:
501
+
502
+ Returns:
503
+ Optional[Any]:
504
+ Raises:
505
+ None
506
+ """
507
+ if litellm.cache is None:
508
+ return None
509
+
510
+ new_kwargs = kwargs.copy()
511
+ new_kwargs.update(
512
+ convert_args_to_kwargs(
513
+ self.original_function,
514
+ args,
515
+ )
516
+ )
517
+ cached_result: Optional[Any] = None
518
+ if call_type == CallTypes.aembedding.value and isinstance(
519
+ new_kwargs["input"], list
520
+ ):
521
+ tasks = []
522
+ for idx, i in enumerate(new_kwargs["input"]):
523
+ preset_cache_key = litellm.cache.get_cache_key(
524
+ **{**new_kwargs, "input": i}
525
+ )
526
+ tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
527
+ cached_result = await asyncio.gather(*tasks)
528
+ ## check if cached result is None ##
529
+ if cached_result is not None and isinstance(cached_result, list):
530
+ # set cached_result to None if all elements are None
531
+ if all(result is None for result in cached_result):
532
+ cached_result = None
533
+ else:
534
+ if litellm.cache._supports_async() is True:
535
+ cached_result = await litellm.cache.async_get_cache(**new_kwargs)
536
+ else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
537
+ cached_result = litellm.cache.get_cache(**new_kwargs)
538
+ return cached_result
539
+
540
+ def _convert_cached_result_to_model_response(
541
+ self,
542
+ cached_result: Any,
543
+ call_type: str,
544
+ kwargs: Dict[str, Any],
545
+ logging_obj: LiteLLMLoggingObj,
546
+ model: str,
547
+ args: Tuple[Any, ...],
548
+ custom_llm_provider: Optional[str] = None,
549
+ ) -> Optional[
550
+ Union[
551
+ ModelResponse,
552
+ TextCompletionResponse,
553
+ EmbeddingResponse,
554
+ RerankResponse,
555
+ TranscriptionResponse,
556
+ CustomStreamWrapper,
557
+ ]
558
+ ]:
559
+ """
560
+ Internal method to process the cached result
561
+
562
+ Checks the call type and converts the cached result to the appropriate model response object
563
+ example if call type is text_completion -> returns TextCompletionResponse object
564
+
565
+ Args:
566
+ cached_result: Any:
567
+ call_type: str:
568
+ kwargs: Dict[str, Any]:
569
+ logging_obj: LiteLLMLoggingObj:
570
+ model: str:
571
+ custom_llm_provider: Optional[str] = None:
572
+ args: Optional[Tuple[Any, ...]] = None:
573
+
574
+ Returns:
575
+ Optional[Any]:
576
+ """
577
+ from litellm.utils import convert_to_model_response_object
578
+
579
+ if (
580
+ call_type == CallTypes.acompletion.value
581
+ or call_type == CallTypes.completion.value
582
+ ) and isinstance(cached_result, dict):
583
+ if kwargs.get("stream", False) is True:
584
+ cached_result = self._convert_cached_stream_response(
585
+ cached_result=cached_result,
586
+ call_type=call_type,
587
+ logging_obj=logging_obj,
588
+ model=model,
589
+ )
590
+ else:
591
+ cached_result = convert_to_model_response_object(
592
+ response_object=cached_result,
593
+ model_response_object=ModelResponse(),
594
+ )
595
+ if (
596
+ call_type == CallTypes.atext_completion.value
597
+ or call_type == CallTypes.text_completion.value
598
+ ) and isinstance(cached_result, dict):
599
+ if kwargs.get("stream", False) is True:
600
+ cached_result = self._convert_cached_stream_response(
601
+ cached_result=cached_result,
602
+ call_type=call_type,
603
+ logging_obj=logging_obj,
604
+ model=model,
605
+ )
606
+ else:
607
+ cached_result = TextCompletionResponse(**cached_result)
608
+ elif (
609
+ call_type == CallTypes.aembedding.value
610
+ or call_type == CallTypes.embedding.value
611
+ ) and isinstance(cached_result, dict):
612
+ cached_result = convert_to_model_response_object(
613
+ response_object=cached_result,
614
+ model_response_object=EmbeddingResponse(),
615
+ response_type="embedding",
616
+ )
617
+
618
+ elif (
619
+ call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
620
+ ) and isinstance(cached_result, dict):
621
+ cached_result = convert_to_model_response_object(
622
+ response_object=cached_result,
623
+ model_response_object=None,
624
+ response_type="rerank",
625
+ )
626
+ elif (
627
+ call_type == CallTypes.atranscription.value
628
+ or call_type == CallTypes.transcription.value
629
+ ) and isinstance(cached_result, dict):
630
+ hidden_params = {
631
+ "model": "whisper-1",
632
+ "custom_llm_provider": custom_llm_provider,
633
+ "cache_hit": True,
634
+ }
635
+ cached_result = convert_to_model_response_object(
636
+ response_object=cached_result,
637
+ model_response_object=TranscriptionResponse(),
638
+ response_type="audio_transcription",
639
+ hidden_params=hidden_params,
640
+ )
641
+
642
+ if (
643
+ hasattr(cached_result, "_hidden_params")
644
+ and cached_result._hidden_params is not None
645
+ and isinstance(cached_result._hidden_params, dict)
646
+ ):
647
+ cached_result._hidden_params["cache_hit"] = True
648
+ return cached_result
649
+
650
+ def _convert_cached_stream_response(
651
+ self,
652
+ cached_result: Any,
653
+ call_type: str,
654
+ logging_obj: LiteLLMLoggingObj,
655
+ model: str,
656
+ ) -> CustomStreamWrapper:
657
+ from litellm.utils import (
658
+ CustomStreamWrapper,
659
+ convert_to_streaming_response,
660
+ convert_to_streaming_response_async,
661
+ )
662
+
663
+ _stream_cached_result: Union[AsyncGenerator, Generator]
664
+ if (
665
+ call_type == CallTypes.acompletion.value
666
+ or call_type == CallTypes.atext_completion.value
667
+ ):
668
+ _stream_cached_result = convert_to_streaming_response_async(
669
+ response_object=cached_result,
670
+ )
671
+ else:
672
+ _stream_cached_result = convert_to_streaming_response(
673
+ response_object=cached_result,
674
+ )
675
+ return CustomStreamWrapper(
676
+ completion_stream=_stream_cached_result,
677
+ model=model,
678
+ custom_llm_provider="cached_response",
679
+ logging_obj=logging_obj,
680
+ )
681
+
682
+ async def async_set_cache(
683
+ self,
684
+ result: Any,
685
+ original_function: Callable,
686
+ kwargs: Dict[str, Any],
687
+ args: Optional[Tuple[Any, ...]] = None,
688
+ ):
689
+ """
690
+ Internal method to check the type of the result & cache used and adds the result to the cache accordingly
691
+
692
+ Args:
693
+ result: Any:
694
+ original_function: Callable:
695
+ kwargs: Dict[str, Any]:
696
+ args: Optional[Tuple[Any, ...]] = None:
697
+
698
+ Returns:
699
+ None
700
+ Raises:
701
+ None
702
+ """
703
+ if litellm.cache is None:
704
+ return
705
+
706
+ new_kwargs = kwargs.copy()
707
+ new_kwargs.update(
708
+ convert_args_to_kwargs(
709
+ original_function,
710
+ args,
711
+ )
712
+ )
713
+ # [OPTIONAL] ADD TO CACHE
714
+ if self._should_store_result_in_cache(
715
+ original_function=original_function, kwargs=new_kwargs
716
+ ):
717
+ if (
718
+ isinstance(result, litellm.ModelResponse)
719
+ or isinstance(result, litellm.EmbeddingResponse)
720
+ or isinstance(result, TranscriptionResponse)
721
+ or isinstance(result, RerankResponse)
722
+ ):
723
+ if (
724
+ isinstance(result, EmbeddingResponse)
725
+ and litellm.cache is not None
726
+ and not isinstance(
727
+ litellm.cache.cache, S3Cache
728
+ ) # s3 doesn't support bulk writing. Exclude.
729
+ ):
730
+ asyncio.create_task(
731
+ litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
732
+ )
733
+ elif isinstance(litellm.cache.cache, S3Cache):
734
+ threading.Thread(
735
+ target=litellm.cache.add_cache,
736
+ args=(result,),
737
+ kwargs=new_kwargs,
738
+ ).start()
739
+ else:
740
+ asyncio.create_task(
741
+ litellm.cache.async_add_cache(
742
+ result.model_dump_json(), **new_kwargs
743
+ )
744
+ )
745
+ else:
746
+ asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))
747
+
748
+ def sync_set_cache(
749
+ self,
750
+ result: Any,
751
+ kwargs: Dict[str, Any],
752
+ args: Optional[Tuple[Any, ...]] = None,
753
+ ):
754
+ """
755
+ Sync internal method to add the result to the cache
756
+ """
757
+
758
+ new_kwargs = kwargs.copy()
759
+ new_kwargs.update(
760
+ convert_args_to_kwargs(
761
+ self.original_function,
762
+ args,
763
+ )
764
+ )
765
+ if litellm.cache is None:
766
+ return
767
+
768
+ if self._should_store_result_in_cache(
769
+ original_function=self.original_function, kwargs=new_kwargs
770
+ ):
771
+ litellm.cache.add_cache(result, **new_kwargs)
772
+
773
+ return
774
+
775
+ def _should_store_result_in_cache(
776
+ self, original_function: Callable, kwargs: Dict[str, Any]
777
+ ) -> bool:
778
+ """
779
+ Helper function to determine if the result should be stored in the cache.
780
+
781
+ Returns:
782
+ bool: True if the result should be stored in the cache, False otherwise.
783
+ """
784
+ return (
785
+ (litellm.cache is not None)
786
+ and litellm.cache.supported_call_types is not None
787
+ and (str(original_function.__name__) in litellm.cache.supported_call_types)
788
+ and (kwargs.get("cache", {}).get("no-store", False) is not True)
789
+ )
790
+
791
+ def _is_call_type_supported_by_cache(
792
+ self,
793
+ original_function: Callable,
794
+ ) -> bool:
795
+ """
796
+ Helper function to determine if the call type is supported by the cache.
797
+
798
+ call types are acompletion, aembedding, atext_completion, atranscription, arerank
799
+
800
+ Defined on `litellm.types.utils.CallTypes`
801
+
802
+ Returns:
803
+ bool: True if the call type is supported by the cache, False otherwise.
804
+ """
805
+ if (
806
+ litellm.cache is not None
807
+ and litellm.cache.supported_call_types is not None
808
+ and str(original_function.__name__) in litellm.cache.supported_call_types
809
+ ):
810
+ return True
811
+ return False
812
+
813
+ async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
814
+ """
815
+ Internal method to add the streaming response to the cache
816
+
817
+
818
+ - If 'streaming_chunk' has a 'finish_reason' then assemble a litellm.ModelResponse object
819
+ - Else append the chunk to self.async_streaming_chunks
820
+
821
+ """
822
+
823
+ complete_streaming_response: Optional[
824
+ Union[ModelResponse, TextCompletionResponse]
825
+ ] = _assemble_complete_response_from_streaming_chunks(
826
+ result=processed_chunk,
827
+ start_time=self.start_time,
828
+ end_time=datetime.datetime.now(),
829
+ request_kwargs=self.request_kwargs,
830
+ streaming_chunks=self.async_streaming_chunks,
831
+ is_async=True,
832
+ )
833
+ # if a complete_streaming_response is assembled, add it to the cache
834
+ if complete_streaming_response is not None:
835
+ await self.async_set_cache(
836
+ result=complete_streaming_response,
837
+ original_function=self.original_function,
838
+ kwargs=self.request_kwargs,
839
+ )
840
+
841
+ def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
842
+ """
843
+ Sync internal method to add the streaming response to the cache
844
+ """
845
+ complete_streaming_response: Optional[
846
+ Union[ModelResponse, TextCompletionResponse]
847
+ ] = _assemble_complete_response_from_streaming_chunks(
848
+ result=processed_chunk,
849
+ start_time=self.start_time,
850
+ end_time=datetime.datetime.now(),
851
+ request_kwargs=self.request_kwargs,
852
+ streaming_chunks=self.sync_streaming_chunks,
853
+ is_async=False,
854
+ )
855
+
856
+ # if a complete_streaming_response is assembled, add it to the cache
857
+ if complete_streaming_response is not None:
858
+ self.sync_set_cache(
859
+ result=complete_streaming_response,
860
+ kwargs=self.request_kwargs,
861
+ )
862
+
863
+ def _update_litellm_logging_obj_environment(
864
+ self,
865
+ logging_obj: LiteLLMLoggingObj,
866
+ model: str,
867
+ kwargs: Dict[str, Any],
868
+ cached_result: Any,
869
+ is_async: bool,
870
+ is_embedding: bool = False,
871
+ ):
872
+ """
873
+ Helper function to update the LiteLLMLoggingObj environment variables.
874
+
875
+ Args:
876
+ logging_obj (LiteLLMLoggingObj): The logging object to update.
877
+ model (str): The model being used.
878
+ kwargs (Dict[str, Any]): The keyword arguments from the original function call.
879
+ cached_result (Any): The cached result to log.
880
+ is_async (bool): Whether the call is asynchronous or not.
881
+ is_embedding (bool): Whether the call is for embeddings or not.
882
+
883
+ Returns:
884
+ None
885
+ """
886
+ litellm_params = {
887
+ "logger_fn": kwargs.get("logger_fn", None),
888
+ "acompletion": is_async,
889
+ "api_base": kwargs.get("api_base", ""),
890
+ "metadata": kwargs.get("metadata", {}),
891
+ "model_info": kwargs.get("model_info", {}),
892
+ "proxy_server_request": kwargs.get("proxy_server_request", None),
893
+ "stream_response": kwargs.get("stream_response", {}),
894
+ }
895
+
896
+ if litellm.cache is not None:
897
+ litellm_params[
898
+ "preset_cache_key"
899
+ ] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
900
+ else:
901
+ litellm_params["preset_cache_key"] = None
902
+
903
+ logging_obj.update_environment_variables(
904
+ model=model,
905
+ user=kwargs.get("user", None),
906
+ optional_params={},
907
+ litellm_params=litellm_params,
908
+ input=(
909
+ kwargs.get("messages", "")
910
+ if not is_embedding
911
+ else kwargs.get("input", "")
912
+ ),
913
+ api_key=kwargs.get("api_key", None),
914
+ original_response=str(cached_result),
915
+ additional_args=None,
916
+ stream=kwargs.get("stream", False),
917
+ )
918
+
919
+
920
+ def convert_args_to_kwargs(
921
+ original_function: Callable,
922
+ args: Optional[Tuple[Any, ...]] = None,
923
+ ) -> Dict[str, Any]:
924
+ # Get the signature of the original function
925
+ signature = inspect.signature(original_function)
926
+
927
+ # Get parameter names in the order they appear in the original function
928
+ param_names = list(signature.parameters.keys())
929
+
930
+ # Create a mapping of positional arguments to parameter names
931
+ args_to_kwargs = {}
932
+ if args:
933
+ for index, arg in enumerate(args):
934
+ if index < len(param_names):
935
+ param_name = param_names[index]
936
+ args_to_kwargs[param_name] = arg
937
+
938
+ return args_to_kwargs
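A small illustration of convert_args_to_kwargs, which backs the sync/async cache setters above by mapping positional arguments onto the wrapped function's parameter names (fake_completion is a hypothetical function used only for this sketch):

def fake_completion(model, messages, stream=False):  # hypothetical signature
    ...

mapped = convert_args_to_kwargs(
    fake_completion,
    ("gpt-4o-mini", [{"role": "user", "content": "hi"}]),
)
# mapped == {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]}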
litellm/caching/disk_cache.py ADDED
@@ -0,0 +1,88 @@
1
+ import json
2
+ from typing import TYPE_CHECKING, Any, Optional, Union
3
+
4
+ from .base_cache import BaseCache
5
+
6
+ if TYPE_CHECKING:
7
+ from opentelemetry.trace import Span as _Span
8
+
9
+ Span = Union[_Span, Any]
10
+ else:
11
+ Span = Any
12
+
13
+
14
+ class DiskCache(BaseCache):
15
+ def __init__(self, disk_cache_dir: Optional[str] = None):
16
+ import diskcache as dc
17
+
18
+ # if users don't provide one, use the default litellm cache directory
19
+ if disk_cache_dir is None:
20
+ self.disk_cache = dc.Cache(".litellm_cache")
21
+ else:
22
+ self.disk_cache = dc.Cache(disk_cache_dir)
23
+
24
+ def set_cache(self, key, value, **kwargs):
25
+ if "ttl" in kwargs:
26
+ self.disk_cache.set(key, value, expire=kwargs["ttl"])
27
+ else:
28
+ self.disk_cache.set(key, value)
29
+
30
+ async def async_set_cache(self, key, value, **kwargs):
31
+ self.set_cache(key=key, value=value, **kwargs)
32
+
33
+ async def async_set_cache_pipeline(self, cache_list, **kwargs):
34
+ for cache_key, cache_value in cache_list:
35
+ if "ttl" in kwargs:
36
+ self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
37
+ else:
38
+ self.set_cache(key=cache_key, value=cache_value)
39
+
40
+ def get_cache(self, key, **kwargs):
41
+ original_cached_response = self.disk_cache.get(key)
42
+ if original_cached_response:
43
+ try:
44
+ cached_response = json.loads(original_cached_response) # type: ignore
45
+ except Exception:
46
+ cached_response = original_cached_response
47
+ return cached_response
48
+ return None
49
+
50
+ def batch_get_cache(self, keys: list, **kwargs):
51
+ return_val = []
52
+ for k in keys:
53
+ val = self.get_cache(key=k, **kwargs)
54
+ return_val.append(val)
55
+ return return_val
56
+
57
+ def increment_cache(self, key, value: int, **kwargs) -> int:
58
+ # get the value
59
+ init_value = self.get_cache(key=key) or 0
60
+ value = init_value + value # type: ignore
61
+ self.set_cache(key, value, **kwargs)
62
+ return value
63
+
64
+ async def async_get_cache(self, key, **kwargs):
65
+ return self.get_cache(key=key, **kwargs)
66
+
67
+ async def async_batch_get_cache(self, keys: list, **kwargs):
68
+ return_val = []
69
+ for k in keys:
70
+ val = self.get_cache(key=k, **kwargs)
71
+ return_val.append(val)
72
+ return return_val
73
+
74
+ async def async_increment(self, key, value: int, **kwargs) -> int:
75
+ # get the value
76
+ init_value = await self.async_get_cache(key=key) or 0
77
+ value = init_value + value # type: ignore
78
+ await self.async_set_cache(key, value, **kwargs)
79
+ return value
80
+
81
+ def flush_cache(self):
82
+ self.disk_cache.clear()
83
+
84
+ async def disconnect(self):
85
+ pass
86
+
87
+ def delete_cache(self, key):
88
+ self.disk_cache.pop(key)
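A minimal sketch of DiskCache used directly (requires the diskcache package to be installed; the cache directory path is illustrative):

from litellm.caching.disk_cache import DiskCache

cache = DiskCache(disk_cache_dir="/tmp/litellm_disk_cache")
cache.set_cache("greeting", {"text": "hello"}, ttl=60)  # entry expires after 60 seconds
print(cache.get_cache("greeting"))                      # -> {'text': 'hello'}
print(cache.increment_cache("hits", 1))                 # -> 1 (simple integer counter)
cache.flush_cache()                                     # clears everything on disk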
litellm/caching/dual_cache.py ADDED
@@ -0,0 +1,434 @@
1
+ """
2
+ Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.
3
+
4
+ Has 4 primary methods:
5
+ - set_cache
6
+ - get_cache
7
+ - async_set_cache
8
+ - async_get_cache
9
+ """
10
+
11
+ import asyncio
12
+ import time
13
+ import traceback
14
+ from concurrent.futures import ThreadPoolExecutor
15
+ from typing import TYPE_CHECKING, Any, List, Optional, Union
16
+
17
+ import litellm
18
+ from litellm._logging import print_verbose, verbose_logger
19
+
20
+ from .base_cache import BaseCache
21
+ from .in_memory_cache import InMemoryCache
22
+ from .redis_cache import RedisCache
23
+
24
+ if TYPE_CHECKING:
25
+ from opentelemetry.trace import Span as _Span
26
+
27
+ Span = Union[_Span, Any]
28
+ else:
29
+ Span = Any
30
+
31
+ from collections import OrderedDict
32
+
33
+
34
+ class LimitedSizeOrderedDict(OrderedDict):
35
+ def __init__(self, *args, max_size=100, **kwargs):
36
+ super().__init__(*args, **kwargs)
37
+ self.max_size = max_size
38
+
39
+ def __setitem__(self, key, value):
40
+ # If inserting a new key exceeds max size, remove the oldest item
41
+ if len(self) >= self.max_size:
42
+ self.popitem(last=False)
43
+ super().__setitem__(key, value)
44
+
45
+
46
+ class DualCache(BaseCache):
47
+ """
48
+ DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
49
+ When data is updated or inserted, it is written to both the in-memory cache + Redis.
50
+ This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
51
+ """
52
+
53
+ def __init__(
54
+ self,
55
+ in_memory_cache: Optional[InMemoryCache] = None,
56
+ redis_cache: Optional[RedisCache] = None,
57
+ default_in_memory_ttl: Optional[float] = None,
58
+ default_redis_ttl: Optional[float] = None,
59
+ default_redis_batch_cache_expiry: Optional[float] = None,
60
+ default_max_redis_batch_cache_size: int = 100,
61
+ ) -> None:
62
+ super().__init__()
63
+ # If in_memory_cache is not provided, use the default InMemoryCache
64
+ self.in_memory_cache = in_memory_cache or InMemoryCache()
65
+ # If redis_cache is not provided, use the default RedisCache
66
+ self.redis_cache = redis_cache
67
+ self.last_redis_batch_access_time = LimitedSizeOrderedDict(
68
+ max_size=default_max_redis_batch_cache_size
69
+ )
70
+ self.redis_batch_cache_expiry = (
71
+ default_redis_batch_cache_expiry
72
+ or litellm.default_redis_batch_cache_expiry
73
+ or 10
74
+ )
75
+ self.default_in_memory_ttl = (
76
+ default_in_memory_ttl or litellm.default_in_memory_ttl
77
+ )
78
+ self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
79
+
80
+ def update_cache_ttl(
81
+ self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
82
+ ):
83
+ if default_in_memory_ttl is not None:
84
+ self.default_in_memory_ttl = default_in_memory_ttl
85
+
86
+ if default_redis_ttl is not None:
87
+ self.default_redis_ttl = default_redis_ttl
88
+
89
+ def set_cache(self, key, value, local_only: bool = False, **kwargs):
90
+ # Update both Redis and in-memory cache
91
+ try:
92
+ if self.in_memory_cache is not None:
93
+ if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
94
+ kwargs["ttl"] = self.default_in_memory_ttl
95
+
96
+ self.in_memory_cache.set_cache(key, value, **kwargs)
97
+
98
+ if self.redis_cache is not None and local_only is False:
99
+ self.redis_cache.set_cache(key, value, **kwargs)
100
+ except Exception as e:
101
+ print_verbose(e)
102
+
103
+ def increment_cache(
104
+ self, key, value: int, local_only: bool = False, **kwargs
105
+ ) -> int:
106
+ """
107
+ Key - the key in cache
108
+
109
+ Value - int - the value you want to increment by
110
+
111
+ Returns - int - the incremented value
112
+ """
113
+ try:
114
+ result: int = value
115
+ if self.in_memory_cache is not None:
116
+ result = self.in_memory_cache.increment_cache(key, value, **kwargs)
117
+
118
+ if self.redis_cache is not None and local_only is False:
119
+ result = self.redis_cache.increment_cache(key, value, **kwargs)
120
+
121
+ return result
122
+ except Exception as e:
123
+ verbose_logger.error(f"LiteLLM Cache: Exception async add_cache: {str(e)}")
124
+ raise e
125
+
126
+ def get_cache(
127
+ self,
128
+ key,
129
+ parent_otel_span: Optional[Span] = None,
130
+ local_only: bool = False,
131
+ **kwargs,
132
+ ):
133
+ # Try to fetch from in-memory cache first
134
+ try:
135
+ result = None
136
+ if self.in_memory_cache is not None:
137
+ in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
138
+
139
+ if in_memory_result is not None:
140
+ result = in_memory_result
141
+
142
+ if result is None and self.redis_cache is not None and local_only is False:
143
+ # If not found in in-memory cache, try fetching from Redis
144
+ redis_result = self.redis_cache.get_cache(
145
+ key, parent_otel_span=parent_otel_span
146
+ )
147
+
148
+ if redis_result is not None:
149
+ # Update in-memory cache with the value from Redis
150
+ self.in_memory_cache.set_cache(key, redis_result, **kwargs)
151
+
152
+ result = redis_result
153
+
154
+ print_verbose(f"get cache: cache result: {result}")
155
+ return result
156
+ except Exception:
157
+ verbose_logger.error(traceback.format_exc())
158
+
159
+ def batch_get_cache(
160
+ self,
161
+ keys: list,
162
+ parent_otel_span: Optional[Span] = None,
163
+ local_only: bool = False,
164
+ **kwargs,
165
+ ):
166
+ received_args = locals()
167
+ received_args.pop("self")
168
+
169
+ def run_in_new_loop():
170
+ """Run the coroutine in a new event loop within this thread."""
171
+ new_loop = asyncio.new_event_loop()
172
+ try:
173
+ asyncio.set_event_loop(new_loop)
174
+ return new_loop.run_until_complete(
175
+ self.async_batch_get_cache(**received_args)
176
+ )
177
+ finally:
178
+ new_loop.close()
179
+ asyncio.set_event_loop(None)
180
+
181
+ try:
182
+ # First, try to get the current event loop
183
+ _ = asyncio.get_running_loop()
184
+ # If we're already in an event loop, run in a separate thread
185
+ # to avoid nested event loop issues
186
+ with ThreadPoolExecutor(max_workers=1) as executor:
187
+ future = executor.submit(run_in_new_loop)
188
+ return future.result()
189
+
190
+ except RuntimeError:
191
+ # No running event loop, we can safely run in this thread
192
+ return run_in_new_loop()
193
+
194
+ async def async_get_cache(
195
+ self,
196
+ key,
197
+ parent_otel_span: Optional[Span] = None,
198
+ local_only: bool = False,
199
+ **kwargs,
200
+ ):
201
+ # Try to fetch from in-memory cache first
202
+ try:
203
+ print_verbose(
204
+ f"async get cache: cache key: {key}; local_only: {local_only}"
205
+ )
206
+ result = None
207
+ if self.in_memory_cache is not None:
208
+ in_memory_result = await self.in_memory_cache.async_get_cache(
209
+ key, **kwargs
210
+ )
211
+
212
+ print_verbose(f"in_memory_result: {in_memory_result}")
213
+ if in_memory_result is not None:
214
+ result = in_memory_result
215
+
216
+ if result is None and self.redis_cache is not None and local_only is False:
217
+ # If not found in in-memory cache, try fetching from Redis
218
+ redis_result = await self.redis_cache.async_get_cache(
219
+ key, parent_otel_span=parent_otel_span
220
+ )
221
+
222
+ if redis_result is not None:
223
+ # Update in-memory cache with the value from Redis
224
+ await self.in_memory_cache.async_set_cache(
225
+ key, redis_result, **kwargs
226
+ )
227
+
228
+ result = redis_result
229
+
230
+ print_verbose(f"get cache: cache result: {result}")
231
+ return result
232
+ except Exception:
233
+ verbose_logger.error(traceback.format_exc())
234
+
235
+ def get_redis_batch_keys(
236
+ self,
237
+ current_time: float,
238
+ keys: List[str],
239
+ result: List[Any],
240
+ ) -> List[str]:
241
+ sublist_keys = []
242
+ for key, value in zip(keys, result):
243
+ if value is None:
244
+ if (
245
+ key not in self.last_redis_batch_access_time
246
+ or current_time - self.last_redis_batch_access_time[key]
247
+ >= self.redis_batch_cache_expiry
248
+ ):
249
+ sublist_keys.append(key)
250
+ return sublist_keys
251
+
252
+ async def async_batch_get_cache(
253
+ self,
254
+ keys: list,
255
+ parent_otel_span: Optional[Span] = None,
256
+ local_only: bool = False,
257
+ **kwargs,
258
+ ):
259
+ try:
260
+ result = [None for _ in range(len(keys))]
261
+ if self.in_memory_cache is not None:
262
+ in_memory_result = await self.in_memory_cache.async_batch_get_cache(
263
+ keys, **kwargs
264
+ )
265
+
266
+ if in_memory_result is not None:
267
+ result = in_memory_result
268
+
269
+ if None in result and self.redis_cache is not None and local_only is False:
270
+ """
271
+ - for the none values in the result
272
+ - check the redis cache
273
+ """
274
+ current_time = time.time()
275
+ sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
276
+
277
+ # Only hit Redis for keys whose last batch access was more than redis_batch_cache_expiry seconds ago
278
+ if len(sublist_keys) > 0:
279
+ # If not found in in-memory cache, try fetching from Redis
280
+ redis_result = await self.redis_cache.async_batch_get_cache(
281
+ sublist_keys, parent_otel_span=parent_otel_span
282
+ )
283
+
284
+ if redis_result is not None:
285
+ # Update in-memory cache with the value from Redis
286
+ for key, value in redis_result.items():
287
+ if value is not None:
288
+ await self.in_memory_cache.async_set_cache(
289
+ key, redis_result[key], **kwargs
290
+ )
291
+ # Update the last access time for each key fetched from Redis
292
+ self.last_redis_batch_access_time[key] = current_time
293
+
294
+ for key, value in redis_result.items():
295
+ index = keys.index(key)
296
+ result[index] = value
297
+
298
+ return result
299
+ except Exception:
300
+ verbose_logger.error(traceback.format_exc())
301
+
302
+ async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
303
+ print_verbose(
304
+ f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
305
+ )
306
+ try:
307
+ if self.in_memory_cache is not None:
308
+ await self.in_memory_cache.async_set_cache(key, value, **kwargs)
309
+
310
+ if self.redis_cache is not None and local_only is False:
311
+ await self.redis_cache.async_set_cache(key, value, **kwargs)
312
+ except Exception as e:
313
+ verbose_logger.exception(
314
+ f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
315
+ )
316
+
317
+ # async_batch_set_cache
318
+ async def async_set_cache_pipeline(
319
+ self, cache_list: list, local_only: bool = False, **kwargs
320
+ ):
321
+ """
322
+ Batch write values to the cache
323
+ """
324
+ print_verbose(
325
+ f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
326
+ )
327
+ try:
328
+ if self.in_memory_cache is not None:
329
+ await self.in_memory_cache.async_set_cache_pipeline(
330
+ cache_list=cache_list, **kwargs
331
+ )
332
+
333
+ if self.redis_cache is not None and local_only is False:
334
+ await self.redis_cache.async_set_cache_pipeline(
335
+ cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
336
+ )
337
+ except Exception as e:
338
+ verbose_logger.exception(
339
+ f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
340
+ )
341
+
342
+ async def async_increment_cache(
343
+ self,
344
+ key,
345
+ value: float,
346
+ parent_otel_span: Optional[Span] = None,
347
+ local_only: bool = False,
348
+ **kwargs,
349
+ ) -> float:
350
+ """
351
+ Key - the key in cache
352
+
353
+ Value - float - the value you want to increment by
354
+
355
+ Returns - float - the incremented value
356
+ """
357
+ try:
358
+ result: float = value
359
+ if self.in_memory_cache is not None:
360
+ result = await self.in_memory_cache.async_increment(
361
+ key, value, **kwargs
362
+ )
363
+
364
+ if self.redis_cache is not None and local_only is False:
365
+ result = await self.redis_cache.async_increment(
366
+ key,
367
+ value,
368
+ parent_otel_span=parent_otel_span,
369
+ ttl=kwargs.get("ttl", None),
370
+ )
371
+
372
+ return result
373
+ except Exception as e:
374
+ raise e # don't log if exception is raised
375
+
376
+ async def async_set_cache_sadd(
377
+ self, key, value: List, local_only: bool = False, **kwargs
378
+ ) -> None:
379
+ """
380
+ Add value to a set
381
+
382
+ Key - the key in cache
383
+
384
+ Value - List - the values you want to add to the set
385
+
386
+ Returns - None
387
+ """
388
+ try:
389
+ if self.in_memory_cache is not None:
390
+ _ = await self.in_memory_cache.async_set_cache_sadd(
391
+ key, value, ttl=kwargs.get("ttl", None)
392
+ )
393
+
394
+ if self.redis_cache is not None and local_only is False:
395
+ _ = await self.redis_cache.async_set_cache_sadd(
396
+ key, value, ttl=kwargs.get("ttl", None)
397
+ )
398
+
399
+ return None
400
+ except Exception as e:
401
+ raise e # don't log, if exception is raised
402
+
403
+ def flush_cache(self):
404
+ if self.in_memory_cache is not None:
405
+ self.in_memory_cache.flush_cache()
406
+ if self.redis_cache is not None:
407
+ self.redis_cache.flush_cache()
408
+
409
+ def delete_cache(self, key):
410
+ """
411
+ Delete a key from the cache
412
+ """
413
+ if self.in_memory_cache is not None:
414
+ self.in_memory_cache.delete_cache(key)
415
+ if self.redis_cache is not None:
416
+ self.redis_cache.delete_cache(key)
417
+
418
+ async def async_delete_cache(self, key: str):
419
+ """
420
+ Delete a key from the cache
421
+ """
422
+ if self.in_memory_cache is not None:
423
+ self.in_memory_cache.delete_cache(key)
424
+ if self.redis_cache is not None:
425
+ await self.redis_cache.async_delete_cache(key)
426
+
427
+ async def async_get_ttl(self, key: str) -> Optional[int]:
428
+ """
429
+ Get the remaining TTL of a key in in-memory cache or redis
430
+ """
431
+ ttl = await self.in_memory_cache.async_get_ttl(key)
432
+ if ttl is None and self.redis_cache is not None:
433
+ ttl = await self.redis_cache.async_get_ttl(key)
434
+ return ttl
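A minimal usage sketch of the dual (in-memory + Redis) cache above. This is a hedged example: the `DualCache` name and the `in_memory_cache=` keyword are assumptions based on the attribute names used in this file, and no Redis layer is configured here, so batch misses simply stay `None` in the returned list.

    import asyncio

    from litellm.caching.dual_cache import DualCache
    from litellm.caching.in_memory_cache import InMemoryCache

    async def main():
        # local-only dual cache: without a redis_cache, only the in-memory layer is consulted
        cache = DualCache(in_memory_cache=InMemoryCache())
        await cache.async_set_cache("user:42", {"plan": "pro"}, ttl=30)
        hits = await cache.async_batch_get_cache(["user:42", "user:99"])
        print(hits)  # [{'plan': 'pro'}, None] - the miss for "user:99" is left as None

    asyncio.run(main())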
litellm/caching/in_memory_cache.py ADDED
@@ -0,0 +1,203 @@
1
+ """
2
+ In-Memory Cache implementation
3
+
4
+ Has 4 methods:
5
+ - set_cache
6
+ - get_cache
7
+ - async_set_cache
8
+ - async_get_cache
9
+ """
10
+
11
+ import json
12
+ import sys
13
+ import time
14
+ from typing import Any, List, Optional
15
+
16
+ from pydantic import BaseModel
17
+
18
+ from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
19
+
20
+ from .base_cache import BaseCache
21
+
22
+
23
+ class InMemoryCache(BaseCache):
24
+ def __init__(
25
+ self,
26
+ max_size_in_memory: Optional[int] = 200,
27
+ default_ttl: Optional[
28
+ int
29
+ ] = 600, # default ttl is 10 minutes. At maximum litellm rate limiting logic requires objects to be in memory for 1 minute
30
+ max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB
31
+ ):
32
+ """
33
+ max_size_in_memory [int]: Maximum number of items in cache, bounded to prevent memory leaks. Defaults to 200 items.
34
+ """
35
+ self.max_size_in_memory = (
36
+ max_size_in_memory or 200
37
+ ) # set an upper bound of 200 items in-memory
38
+ self.default_ttl = default_ttl or 600
39
+ self.max_size_per_item = (
40
+ max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
41
+ ) # 1MB = 1024KB
42
+
43
+ # in-memory cache
44
+ self.cache_dict: dict = {}
45
+ self.ttl_dict: dict = {}
46
+
47
+ def check_value_size(self, value: Any):
48
+ """
49
+ Check if value size exceeds max_size_per_item (1MB)
50
+ Returns True if value size is acceptable, False otherwise
51
+ """
52
+ try:
53
+ # Fast path for common primitive types that are typically small
54
+ if (
55
+ isinstance(value, (bool, int, float, str))
56
+ and len(str(value))
57
+ < self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
58
+ ): # Conservative estimate
59
+ return True
60
+
61
+ # Direct size check for bytes objects
62
+ if isinstance(value, bytes):
63
+ return sys.getsizeof(value) / 1024 <= self.max_size_per_item
64
+
65
+ # Handle special types without full conversion when possible
66
+ if hasattr(value, "__sizeof__"): # Use __sizeof__ if available
67
+ size = value.__sizeof__() / 1024
68
+ return size <= self.max_size_per_item
69
+
70
+ # Fallback for complex types
71
+ if isinstance(value, BaseModel) and hasattr(
72
+ value, "model_dump"
73
+ ): # Pydantic v2
74
+ value = value.model_dump()
75
+ elif hasattr(value, "isoformat"): # datetime objects
76
+ return True # datetime strings are always small
77
+
78
+ # Only convert to JSON if absolutely necessary
79
+ if not isinstance(value, (str, bytes)):
80
+ value = json.dumps(value, default=str)
81
+
82
+ return sys.getsizeof(value) / 1024 <= self.max_size_per_item
83
+
84
+ except Exception:
85
+ return False
86
+
87
+ def evict_cache(self):
88
+ """
89
+ Eviction policy:
90
+ - check if any items in ttl_dict are expired -> remove them from ttl_dict and cache_dict
91
+
92
+
93
+ This guarantees the following:
94
+ - 1. When item ttl is not set: at minimum each item will remain in memory for 5 minutes
95
+ - 2. When ttl is set: the item will remain in memory for at least that amount of time
96
+ - 3. the size of in-memory cache is bounded
97
+
98
+ """
99
+ for key in list(self.ttl_dict.keys()):
100
+ if time.time() > self.ttl_dict[key]:
101
+ self.cache_dict.pop(key, None)
102
+ self.ttl_dict.pop(key, None)
103
+
104
+ # de-reference the removed item
105
+ # https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
106
+ # One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
107
+ # This can occur when an object is referenced by another object, but the reference is never removed.
108
+
109
+ def set_cache(self, key, value, **kwargs):
110
+ if len(self.cache_dict) >= self.max_size_in_memory:
111
+ # only evict when cache is full
112
+ self.evict_cache()
113
+ if not self.check_value_size(value):
114
+ return
115
+
116
+ self.cache_dict[key] = value
117
+ if "ttl" in kwargs and kwargs["ttl"] is not None:
118
+ self.ttl_dict[key] = time.time() + kwargs["ttl"]
119
+ else:
120
+ self.ttl_dict[key] = time.time() + self.default_ttl
121
+
122
+ async def async_set_cache(self, key, value, **kwargs):
123
+ self.set_cache(key=key, value=value, **kwargs)
124
+
125
+ async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
126
+ for cache_key, cache_value in cache_list:
127
+ if ttl is not None:
128
+ self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
129
+ else:
130
+ self.set_cache(key=cache_key, value=cache_value)
131
+
132
+ async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
133
+ """
134
+ Add value to set
135
+ """
136
+ # get the value
137
+ init_value = self.get_cache(key=key) or set()
138
+ for val in value:
139
+ init_value.add(val)
140
+ self.set_cache(key, init_value, ttl=ttl)
141
+ return value
142
+
143
+ def get_cache(self, key, **kwargs):
144
+ if key in self.cache_dict:
145
+ if key in self.ttl_dict:
146
+ if time.time() > self.ttl_dict[key]:
147
+ self.cache_dict.pop(key, None)
148
+ return None
149
+ original_cached_response = self.cache_dict[key]
150
+ try:
151
+ cached_response = json.loads(original_cached_response)
152
+ except Exception:
153
+ cached_response = original_cached_response
154
+ return cached_response
155
+ return None
156
+
157
+ def batch_get_cache(self, keys: list, **kwargs):
158
+ return_val = []
159
+ for k in keys:
160
+ val = self.get_cache(key=k, **kwargs)
161
+ return_val.append(val)
162
+ return return_val
163
+
164
+ def increment_cache(self, key, value: int, **kwargs) -> int:
165
+ # get the value
166
+ init_value = self.get_cache(key=key) or 0
167
+ value = init_value + value
168
+ self.set_cache(key, value, **kwargs)
169
+ return value
170
+
171
+ async def async_get_cache(self, key, **kwargs):
172
+ return self.get_cache(key=key, **kwargs)
173
+
174
+ async def async_batch_get_cache(self, keys: list, **kwargs):
175
+ return_val = []
176
+ for k in keys:
177
+ val = self.get_cache(key=k, **kwargs)
178
+ return_val.append(val)
179
+ return return_val
180
+
181
+ async def async_increment(self, key, value: float, **kwargs) -> float:
182
+ # get the value
183
+ init_value = await self.async_get_cache(key=key) or 0
184
+ value = init_value + value
185
+ await self.async_set_cache(key, value, **kwargs)
186
+ return value
187
+
188
+ def flush_cache(self):
189
+ self.cache_dict.clear()
190
+ self.ttl_dict.clear()
191
+
192
+ async def disconnect(self):
193
+ pass
194
+
195
+ def delete_cache(self, key):
196
+ self.cache_dict.pop(key, None)
197
+ self.ttl_dict.pop(key, None)
198
+
199
+ async def async_get_ttl(self, key: str) -> Optional[int]:
200
+ """
201
+ Get the remaining TTL of a key in in-memory cache
202
+ """
203
+ return self.ttl_dict.get(key, None)
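A short sketch exercising the `InMemoryCache` defined above. TTLs are enforced lazily on read, and bulk eviction only runs once the cache holds `max_size_in_memory` items.

    import time

    from litellm.caching.in_memory_cache import InMemoryCache

    cache = InMemoryCache(max_size_in_memory=200, default_ttl=600)
    cache.set_cache("a", {"hits": 1}, ttl=1)  # explicit 1-second TTL
    cache.set_cache("b", "value-b")           # falls back to default_ttl (600s)

    print(cache.get_cache("a"))               # {'hits': 1}
    time.sleep(1.1)
    print(cache.get_cache("a"))               # None - the expired entry is dropped on read
    print(cache.batch_get_cache(["a", "b"]))  # [None, 'value-b']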
litellm/caching/llm_caching_handler.py ADDED
@@ -0,0 +1,39 @@
1
+ """
2
+ Add the event loop to the cache key, to prevent event loop closed errors.
3
+ """
4
+
5
+ import asyncio
6
+
7
+ from .in_memory_cache import InMemoryCache
8
+
9
+
10
+ class LLMClientCache(InMemoryCache):
11
+ def update_cache_key_with_event_loop(self, key):
12
+ """
13
+ Add the event loop to the cache key, to prevent event loop closed errors.
14
+ If none, use the key as is.
15
+ """
16
+ try:
17
+ event_loop = asyncio.get_event_loop()
18
+ stringified_event_loop = str(id(event_loop))
19
+ return f"{key}-{stringified_event_loop}"
20
+ except Exception: # handle no current event loop
21
+ return key
22
+
23
+ def set_cache(self, key, value, **kwargs):
24
+ key = self.update_cache_key_with_event_loop(key)
25
+ return super().set_cache(key, value, **kwargs)
26
+
27
+ async def async_set_cache(self, key, value, **kwargs):
28
+ key = self.update_cache_key_with_event_loop(key)
29
+ return await super().async_set_cache(key, value, **kwargs)
30
+
31
+ def get_cache(self, key, **kwargs):
32
+ key = self.update_cache_key_with_event_loop(key)
33
+
34
+ return super().get_cache(key, **kwargs)
35
+
36
+ async def async_get_cache(self, key, **kwargs):
37
+ key = self.update_cache_key_with_event_loop(key)
38
+
39
+ return await super().async_get_cache(key, **kwargs)
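A sketch of the event-loop scoping implemented above: every key is suffixed with the id of the current event loop, so a client object cached on one loop is never handed back on a different loop. The key name used here is illustrative.

    import asyncio

    from litellm.caching.llm_caching_handler import LLMClientCache

    cache = LLMClientCache()

    async def round_trip():
        # set + get happen on the same loop, so the suffixed keys match
        await cache.async_set_cache("openai-client", "client-object")
        return await cache.async_get_cache("openai-client")

    async def loop_scoped_key():
        # shows the physical key actually used while a loop is running
        return cache.update_cache_key_with_event_loop("openai-client")

    print(asyncio.run(round_trip()))       # "client-object"
    print(asyncio.run(loop_scoped_key()))  # e.g. "openai-client-140201..." (suffix = id of that run's loop)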
litellm/caching/qdrant_semantic_cache.py ADDED
@@ -0,0 +1,442 @@
1
+ """
2
+ Qdrant Semantic Cache implementation
3
+
4
+ Has 4 methods:
5
+ - set_cache
6
+ - get_cache
7
+ - async_set_cache
8
+ - async_get_cache
9
+ """
10
+
11
+ import ast
12
+ import asyncio
13
+ import json
14
+ from typing import Any, cast
15
+
16
+ import litellm
17
+ from litellm._logging import print_verbose
18
+ from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
19
+ from litellm.types.utils import EmbeddingResponse
20
+
21
+ from .base_cache import BaseCache
22
+
23
+
24
+ class QdrantSemanticCache(BaseCache):
25
+ def __init__( # noqa: PLR0915
26
+ self,
27
+ qdrant_api_base=None,
28
+ qdrant_api_key=None,
29
+ collection_name=None,
30
+ similarity_threshold=None,
31
+ quantization_config=None,
32
+ embedding_model="text-embedding-ada-002",
33
+ host_type=None,
34
+ ):
35
+ import os
36
+
37
+ from litellm.llms.custom_httpx.http_handler import (
38
+ _get_httpx_client,
39
+ get_async_httpx_client,
40
+ httpxSpecialProvider,
41
+ )
42
+ from litellm.secret_managers.main import get_secret_str
43
+
44
+ if collection_name is None:
45
+ raise Exception("collection_name must be provided, passed None")
46
+
47
+ self.collection_name = collection_name
48
+ print_verbose(
49
+ f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
50
+ )
51
+
52
+ if similarity_threshold is None:
53
+ raise Exception("similarity_threshold must be provided, passed None")
54
+ self.similarity_threshold = similarity_threshold
55
+ self.embedding_model = embedding_model
56
+ headers = {}
57
+
58
+ # check if defined as os.environ/ variable
59
+ if qdrant_api_base:
60
+ if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
61
+ "os.environ/"
62
+ ):
63
+ qdrant_api_base = get_secret_str(qdrant_api_base)
64
+ if qdrant_api_key:
65
+ if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
66
+ "os.environ/"
67
+ ):
68
+ qdrant_api_key = get_secret_str(qdrant_api_key)
69
+
70
+ qdrant_api_base = (
71
+ qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
72
+ )
73
+ qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
74
+ headers = {"Content-Type": "application/json"}
75
+ if qdrant_api_key:
76
+ headers["api-key"] = qdrant_api_key
77
+
78
+ if qdrant_api_base is None:
79
+ raise ValueError("Qdrant url must be provided")
80
+
81
+ self.qdrant_api_base = qdrant_api_base
82
+ self.qdrant_api_key = qdrant_api_key
83
+ print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
84
+
85
+ self.headers = headers
86
+
87
+ self.sync_client = _get_httpx_client()
88
+ self.async_client = get_async_httpx_client(
89
+ llm_provider=httpxSpecialProvider.Caching
90
+ )
91
+
92
+ if quantization_config is None:
93
+ print_verbose(
94
+ "Quantization config is not provided. Default binary quantization will be used."
95
+ )
96
+ collection_exists = self.sync_client.get(
97
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
98
+ headers=self.headers,
99
+ )
100
+ if collection_exists.status_code != 200:
101
+ raise ValueError(
102
+ f"Error from qdrant checking if /collections exist {collection_exists.text}"
103
+ )
104
+
105
+ if collection_exists.json()["result"]["exists"]:
106
+ collection_details = self.sync_client.get(
107
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
108
+ headers=self.headers,
109
+ )
110
+ self.collection_info = collection_details.json()
111
+ print_verbose(
112
+ f"Collection already exists.\nCollection details:{self.collection_info}"
113
+ )
114
+ else:
115
+ if quantization_config is None or quantization_config == "binary":
116
+ quantization_params = {
117
+ "binary": {
118
+ "always_ram": False,
119
+ }
120
+ }
121
+ elif quantization_config == "scalar":
122
+ quantization_params = {
123
+ "scalar": {
124
+ "type": "int8",
125
+ "quantile": QDRANT_SCALAR_QUANTILE,
126
+ "always_ram": False,
127
+ }
128
+ }
129
+ elif quantization_config == "product":
130
+ quantization_params = {
131
+ "product": {"compression": "x16", "always_ram": False}
132
+ }
133
+ else:
134
+ raise Exception(
135
+ "Quantization config must be one of 'scalar', 'binary' or 'product'"
136
+ )
137
+
138
+ new_collection_status = self.sync_client.put(
139
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
140
+ json={
141
+ "vectors": {"size": QDRANT_VECTOR_SIZE, "distance": "Cosine"},
142
+ "quantization_config": quantization_params,
143
+ },
144
+ headers=self.headers,
145
+ )
146
+ if new_collection_status.json()["result"]:
147
+ collection_details = self.sync_client.get(
148
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
149
+ headers=self.headers,
150
+ )
151
+ self.collection_info = collection_details.json()
152
+ print_verbose(
153
+ f"New collection created.\nCollection details:{self.collection_info}"
154
+ )
155
+ else:
156
+ raise Exception("Error while creating new collection")
157
+
158
+ def _get_cache_logic(self, cached_response: Any):
159
+ if cached_response is None:
160
+ return cached_response
161
+ try:
162
+ cached_response = json.loads(
163
+ cached_response
164
+ ) # Convert string to dictionary
165
+ except Exception:
166
+ cached_response = ast.literal_eval(cached_response)
167
+ return cached_response
168
+
169
+ def set_cache(self, key, value, **kwargs):
170
+ print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
171
+ import uuid
172
+
173
+ # get the prompt
174
+ messages = kwargs["messages"]
175
+ prompt = ""
176
+ for message in messages:
177
+ prompt += message["content"]
178
+
179
+ # create an embedding for prompt
180
+ embedding_response = cast(
181
+ EmbeddingResponse,
182
+ litellm.embedding(
183
+ model=self.embedding_model,
184
+ input=prompt,
185
+ cache={"no-store": True, "no-cache": True},
186
+ ),
187
+ )
188
+
189
+ # get the embedding
190
+ embedding = embedding_response["data"][0]["embedding"]
191
+
192
+ value = str(value)
193
+ assert isinstance(value, str)
194
+
195
+ data = {
196
+ "points": [
197
+ {
198
+ "id": str(uuid.uuid4()),
199
+ "vector": embedding,
200
+ "payload": {
201
+ "text": prompt,
202
+ "response": value,
203
+ },
204
+ },
205
+ ]
206
+ }
207
+ self.sync_client.put(
208
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
209
+ headers=self.headers,
210
+ json=data,
211
+ )
212
+ return
213
+
214
+ def get_cache(self, key, **kwargs):
215
+ print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
216
+
217
+ # get the messages
218
+ messages = kwargs["messages"]
219
+ prompt = ""
220
+ for message in messages:
221
+ prompt += message["content"]
222
+
223
+ # convert to embedding
224
+ embedding_response = cast(
225
+ EmbeddingResponse,
226
+ litellm.embedding(
227
+ model=self.embedding_model,
228
+ input=prompt,
229
+ cache={"no-store": True, "no-cache": True},
230
+ ),
231
+ )
232
+
233
+ # get the embedding
234
+ embedding = embedding_response["data"][0]["embedding"]
235
+
236
+ data = {
237
+ "vector": embedding,
238
+ "params": {
239
+ "quantization": {
240
+ "ignore": False,
241
+ "rescore": True,
242
+ "oversampling": 3.0,
243
+ }
244
+ },
245
+ "limit": 1,
246
+ "with_payload": True,
247
+ }
248
+
249
+ search_response = self.sync_client.post(
250
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
251
+ headers=self.headers,
252
+ json=data,
253
+ )
254
+ results = search_response.json()["result"]
255
+
256
+ if results is None:
257
+ return None
258
+ if isinstance(results, list):
259
+ if len(results) == 0:
260
+ return None
261
+
262
+ similarity = results[0]["score"]
263
+ cached_prompt = results[0]["payload"]["text"]
264
+
265
+ # check similarity, if more than self.similarity_threshold, return results
266
+ print_verbose(
267
+ f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
268
+ )
269
+ if similarity >= self.similarity_threshold:
270
+ # cache hit !
271
+ cached_value = results[0]["payload"]["response"]
272
+ print_verbose(
273
+ f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
274
+ )
275
+ return self._get_cache_logic(cached_response=cached_value)
276
+ else:
277
+ # cache miss !
278
+ return None
279
+ pass
280
+
281
+ async def async_set_cache(self, key, value, **kwargs):
282
+ import uuid
283
+
284
+ from litellm.proxy.proxy_server import llm_model_list, llm_router
285
+
286
+ print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
287
+
288
+ # get the prompt
289
+ messages = kwargs["messages"]
290
+ prompt = ""
291
+ for message in messages:
292
+ prompt += message["content"]
293
+ # create an embedding for prompt
294
+ router_model_names = (
295
+ [m["model_name"] for m in llm_model_list]
296
+ if llm_model_list is not None
297
+ else []
298
+ )
299
+ if llm_router is not None and self.embedding_model in router_model_names:
300
+ user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
301
+ embedding_response = await llm_router.aembedding(
302
+ model=self.embedding_model,
303
+ input=prompt,
304
+ cache={"no-store": True, "no-cache": True},
305
+ metadata={
306
+ "user_api_key": user_api_key,
307
+ "semantic-cache-embedding": True,
308
+ "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
309
+ },
310
+ )
311
+ else:
312
+ # convert to embedding
313
+ embedding_response = await litellm.aembedding(
314
+ model=self.embedding_model,
315
+ input=prompt,
316
+ cache={"no-store": True, "no-cache": True},
317
+ )
318
+
319
+ # get the embedding
320
+ embedding = embedding_response["data"][0]["embedding"]
321
+
322
+ value = str(value)
323
+ assert isinstance(value, str)
324
+
325
+ data = {
326
+ "points": [
327
+ {
328
+ "id": str(uuid.uuid4()),
329
+ "vector": embedding,
330
+ "payload": {
331
+ "text": prompt,
332
+ "response": value,
333
+ },
334
+ },
335
+ ]
336
+ }
337
+
338
+ await self.async_client.put(
339
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
340
+ headers=self.headers,
341
+ json=data,
342
+ )
343
+ return
344
+
345
+ async def async_get_cache(self, key, **kwargs):
346
+ print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
347
+ from litellm.proxy.proxy_server import llm_model_list, llm_router
348
+
349
+ # get the messages
350
+ messages = kwargs["messages"]
351
+ prompt = ""
352
+ for message in messages:
353
+ prompt += message["content"]
354
+
355
+ router_model_names = (
356
+ [m["model_name"] for m in llm_model_list]
357
+ if llm_model_list is not None
358
+ else []
359
+ )
360
+ if llm_router is not None and self.embedding_model in router_model_names:
361
+ user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
362
+ embedding_response = await llm_router.aembedding(
363
+ model=self.embedding_model,
364
+ input=prompt,
365
+ cache={"no-store": True, "no-cache": True},
366
+ metadata={
367
+ "user_api_key": user_api_key,
368
+ "semantic-cache-embedding": True,
369
+ "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
370
+ },
371
+ )
372
+ else:
373
+ # convert to embedding
374
+ embedding_response = await litellm.aembedding(
375
+ model=self.embedding_model,
376
+ input=prompt,
377
+ cache={"no-store": True, "no-cache": True},
378
+ )
379
+
380
+ # get the embedding
381
+ embedding = embedding_response["data"][0]["embedding"]
382
+
383
+ data = {
384
+ "vector": embedding,
385
+ "params": {
386
+ "quantization": {
387
+ "ignore": False,
388
+ "rescore": True,
389
+ "oversampling": 3.0,
390
+ }
391
+ },
392
+ "limit": 1,
393
+ "with_payload": True,
394
+ }
395
+
396
+ search_response = await self.async_client.post(
397
+ url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
398
+ headers=self.headers,
399
+ json=data,
400
+ )
401
+
402
+ results = search_response.json()["result"]
403
+
404
+ if results is None:
405
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
406
+ return None
407
+ if isinstance(results, list):
408
+ if len(results) == 0:
409
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
410
+ return None
411
+
412
+ similarity = results[0]["score"]
413
+ cached_prompt = results[0]["payload"]["text"]
414
+
415
+ # check similarity, if more than self.similarity_threshold, return results
416
+ print_verbose(
417
+ f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
418
+ )
419
+
420
+ # update kwargs["metadata"] with similarity, don't rewrite the original metadata
421
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
422
+
423
+ if similarity >= self.similarity_threshold:
424
+ # cache hit !
425
+ cached_value = results[0]["payload"]["response"]
426
+ print_verbose(
427
+ f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
428
+ )
429
+ return self._get_cache_logic(cached_response=cached_value)
430
+ else:
431
+ # cache miss !
432
+ return None
433
+ pass
434
+
435
+ async def _collection_info(self):
436
+ return self.collection_info
437
+
438
+ async def async_set_cache_pipeline(self, cache_list, **kwargs):
439
+ tasks = []
440
+ for val in cache_list:
441
+ tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
442
+ await asyncio.gather(*tasks)
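The hit/miss decision above reduces to a cosine-similarity threshold applied to the top-1 result of Qdrant's `/points/search` endpoint. Below is a standalone sketch of that gate; the response shape mirrors the payload written by `set_cache` above, and the threshold value is only an example.

    SIMILARITY_THRESHOLD = 0.8  # example value, configured per cache instance

    def resolve_semantic_hit(search_result: dict, threshold: float = SIMILARITY_THRESHOLD):
        """Return the cached response on a hit, else None (a miss)."""
        points = search_result.get("result") or []
        if not points:
            return None
        top = points[0]
        if top["score"] >= threshold:
            return top["payload"]["response"]  # cache hit: reuse the stored completion
        return None                            # cache miss: similarity below the threshold

    # example top-1 shape returned by POST /collections/<name>/points/search
    example = {"result": [{"score": 0.93, "payload": {"text": "hi there", "response": "'hello!'"}}]}
    print(resolve_semantic_hit(example))  # "'hello!'"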
litellm/caching/redis_cache.py ADDED
@@ -0,0 +1,1162 @@
1
+ """
2
+ Redis Cache implementation
3
+
4
+ Has 4 primary methods:
5
+ - set_cache
6
+ - get_cache
7
+ - async_set_cache
8
+ - async_get_cache
9
+ """
10
+
11
+ import ast
12
+ import asyncio
13
+ import inspect
14
+ import json
15
+ import time
16
+ from datetime import timedelta
17
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
18
+
19
+ import litellm
20
+ from litellm._logging import print_verbose, verbose_logger
21
+ from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
22
+ from litellm.types.caching import RedisPipelineIncrementOperation
23
+ from litellm.types.services import ServiceTypes
24
+
25
+ from .base_cache import BaseCache
26
+
27
+ if TYPE_CHECKING:
28
+ from opentelemetry.trace import Span as _Span
29
+ from redis.asyncio import Redis, RedisCluster
30
+ from redis.asyncio.client import Pipeline
31
+ from redis.asyncio.cluster import ClusterPipeline
32
+
33
+ pipeline = Pipeline
34
+ cluster_pipeline = ClusterPipeline
35
+ async_redis_client = Redis
36
+ async_redis_cluster_client = RedisCluster
37
+ Span = Union[_Span, Any]
38
+ else:
39
+ pipeline = Any
40
+ cluster_pipeline = Any
41
+ async_redis_client = Any
42
+ async_redis_cluster_client = Any
43
+ Span = Any
44
+
45
+
46
+ class RedisCache(BaseCache):
47
+ # if users don't provide one, use the default litellm cache
48
+
49
+ def __init__(
50
+ self,
51
+ host=None,
52
+ port=None,
53
+ password=None,
54
+ redis_flush_size: Optional[int] = 100,
55
+ namespace: Optional[str] = None,
56
+ startup_nodes: Optional[List] = None, # for redis-cluster
57
+ socket_timeout: Optional[float] = 5.0, # default 5 second timeout
58
+ **kwargs,
59
+ ):
60
+ from litellm._service_logger import ServiceLogging
61
+
62
+ from .._redis import get_redis_client, get_redis_connection_pool
63
+
64
+ redis_kwargs = {}
65
+ if host is not None:
66
+ redis_kwargs["host"] = host
67
+ if port is not None:
68
+ redis_kwargs["port"] = port
69
+ if password is not None:
70
+ redis_kwargs["password"] = password
71
+ if startup_nodes is not None:
72
+ redis_kwargs["startup_nodes"] = startup_nodes
73
+ if socket_timeout is not None:
74
+ redis_kwargs["socket_timeout"] = socket_timeout
75
+
76
+ ### HEALTH MONITORING OBJECT ###
77
+ if kwargs.get("service_logger_obj", None) is not None and isinstance(
78
+ kwargs["service_logger_obj"], ServiceLogging
79
+ ):
80
+ self.service_logger_obj = kwargs.pop("service_logger_obj")
81
+ else:
82
+ self.service_logger_obj = ServiceLogging()
83
+
84
+ redis_kwargs.update(kwargs)
85
+ self.redis_client = get_redis_client(**redis_kwargs)
86
+ self.redis_async_client: Optional[async_redis_client] = None
87
+ self.redis_kwargs = redis_kwargs
88
+ self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
89
+
90
+ # redis namespaces
91
+ self.namespace = namespace
92
+ # for high traffic, we store the redis results in memory and then batch write to redis
93
+ self.redis_batch_writing_buffer: list = []
94
+ if redis_flush_size is None:
95
+ self.redis_flush_size: int = 100
96
+ else:
97
+ self.redis_flush_size = redis_flush_size
98
+ self.redis_version = "Unknown"
99
+ try:
100
+ if not inspect.iscoroutinefunction(self.redis_client):
101
+ self.redis_version = self.redis_client.info()["redis_version"] # type: ignore
102
+ except Exception:
103
+ pass
104
+
105
+ ### ASYNC HEALTH PING ###
106
+ try:
107
+ # asyncio.get_running_loop().create_task(self.ping())
108
+ _ = asyncio.get_running_loop().create_task(self.ping())
109
+ except Exception as e:
110
+ if "no running event loop" in str(e):
111
+ verbose_logger.debug(
112
+ "Ignoring async redis ping. No running event loop."
113
+ )
114
+ else:
115
+ verbose_logger.error(
116
+ "Error connecting to Async Redis client - {}".format(str(e)),
117
+ extra={"error": str(e)},
118
+ )
119
+
120
+ ### SYNC HEALTH PING ###
121
+ try:
122
+ if hasattr(self.redis_client, "ping"):
123
+ self.redis_client.ping() # type: ignore
124
+ except Exception as e:
125
+ verbose_logger.error(
126
+ "Error connecting to Sync Redis client", extra={"error": str(e)}
127
+ )
128
+
129
+ if litellm.default_redis_ttl is not None:
130
+ super().__init__(default_ttl=int(litellm.default_redis_ttl))
131
+ else:
132
+ super().__init__() # defaults to 60s
133
+
134
+ def init_async_client(
135
+ self,
136
+ ) -> Union[async_redis_client, async_redis_cluster_client]:
137
+ from .._redis import get_redis_async_client
138
+
139
+ if self.redis_async_client is None:
140
+ self.redis_async_client = get_redis_async_client(
141
+ connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
142
+ )
143
+ return self.redis_async_client
144
+
145
+ def check_and_fix_namespace(self, key: str) -> str:
146
+ """
147
+ Make sure each key starts with the given namespace
148
+ """
149
+ if self.namespace is not None and not key.startswith(self.namespace):
150
+ key = self.namespace + ":" + key
151
+
152
+ return key
153
+
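# --- illustrative annotation (not part of the committed file) ----------------
# check_and_fix_namespace above prefixes each key with the configured namespace
# exactly once. With namespace="litellm":
#     "user:42"          -> "litellm:user:42"
#     "litellm:user:42"  -> "litellm:user:42"   (already prefixed, left unchanged)
# ------------------------------------------------------------------------------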
154
+ def set_cache(self, key, value, **kwargs):
155
+ ttl = self.get_ttl(**kwargs)
156
+ print_verbose(
157
+ f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
158
+ )
159
+ key = self.check_and_fix_namespace(key=key)
160
+ try:
161
+ start_time = time.time()
162
+ self.redis_client.set(name=key, value=str(value), ex=ttl)
163
+ end_time = time.time()
164
+ _duration = end_time - start_time
165
+ self.service_logger_obj.service_success_hook(
166
+ service=ServiceTypes.REDIS,
167
+ duration=_duration,
168
+ call_type="set_cache",
169
+ start_time=start_time,
170
+ end_time=end_time,
171
+ )
172
+ except Exception as e:
173
+ # NON blocking - notify users Redis is throwing an exception
174
+ print_verbose(
175
+ f"litellm.caching.caching: set() - Got exception from REDIS : {str(e)}"
176
+ )
177
+
178
+ def increment_cache(
179
+ self, key, value: int, ttl: Optional[float] = None, **kwargs
180
+ ) -> int:
181
+ _redis_client = self.redis_client
182
+ start_time = time.time()
183
+ set_ttl = self.get_ttl(ttl=ttl)
184
+ try:
185
+ start_time = time.time()
186
+ result: int = _redis_client.incr(name=key, amount=value) # type: ignore
187
+ end_time = time.time()
188
+ _duration = end_time - start_time
189
+ self.service_logger_obj.service_success_hook(
190
+ service=ServiceTypes.REDIS,
191
+ duration=_duration,
192
+ call_type="increment_cache",
193
+ start_time=start_time,
194
+ end_time=end_time,
195
+ )
196
+
197
+ if set_ttl is not None:
198
+ # check if key already has ttl, if not -> set ttl
199
+ start_time = time.time()
200
+ current_ttl = _redis_client.ttl(key)
201
+ end_time = time.time()
202
+ _duration = end_time - start_time
203
+ self.service_logger_obj.service_success_hook(
204
+ service=ServiceTypes.REDIS,
205
+ duration=_duration,
206
+ call_type="increment_cache_ttl",
207
+ start_time=start_time,
208
+ end_time=end_time,
209
+ )
210
+ if current_ttl == -1:
211
+ # Key has no expiration
212
+ start_time = time.time()
213
+ _redis_client.expire(key, set_ttl) # type: ignore
214
+ end_time = time.time()
215
+ _duration = end_time - start_time
216
+ self.service_logger_obj.service_success_hook(
217
+ service=ServiceTypes.REDIS,
218
+ duration=_duration,
219
+ call_type="increment_cache_expire",
220
+ start_time=start_time,
221
+ end_time=end_time,
222
+ )
223
+ return result
224
+ except Exception as e:
225
+ ## LOGGING ##
226
+ end_time = time.time()
227
+ _duration = end_time - start_time
228
+ verbose_logger.error(
229
+ "LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
230
+ str(e),
231
+ value,
232
+ )
233
+ raise e
234
+
235
+ async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
236
+ start_time = time.time()
237
+ try:
238
+ keys = []
239
+ _redis_client = self.init_async_client()
240
+ if not hasattr(_redis_client, "scan_iter"):
241
+ verbose_logger.debug(
242
+ "Redis client does not support scan_iter, potentially using Redis Cluster. Returning empty list."
243
+ )
244
+ return []
245
+
246
+ async for key in _redis_client.scan_iter(match=pattern + "*", count=count): # type: ignore
247
+ keys.append(key)
248
+ if len(keys) >= count:
249
+ break
250
+
251
+ ## LOGGING ##
252
+ end_time = time.time()
253
+ _duration = end_time - start_time
254
+ asyncio.create_task(
255
+ self.service_logger_obj.async_service_success_hook(
256
+ service=ServiceTypes.REDIS,
257
+ duration=_duration,
258
+ call_type="async_scan_iter",
259
+ start_time=start_time,
260
+ end_time=end_time,
261
+ )
262
+ ) # DO NOT SLOW DOWN CALL B/C OF THIS
263
+ return keys
264
+ except Exception as e:
265
+ # NON blocking - notify users Redis is throwing an exception
266
+ ## LOGGING ##
267
+ end_time = time.time()
268
+ _duration = end_time - start_time
269
+ asyncio.create_task(
270
+ self.service_logger_obj.async_service_failure_hook(
271
+ service=ServiceTypes.REDIS,
272
+ duration=_duration,
273
+ error=e,
274
+ call_type="async_scan_iter",
275
+ start_time=start_time,
276
+ end_time=end_time,
277
+ )
278
+ )
279
+ raise e
280
+
281
+ async def async_set_cache(self, key, value, **kwargs):
282
+ from redis.asyncio import Redis
283
+
284
+ start_time = time.time()
285
+ try:
286
+ _redis_client: Redis = self.init_async_client() # type: ignore
287
+ except Exception as e:
288
+ end_time = time.time()
289
+ _duration = end_time - start_time
290
+ asyncio.create_task(
291
+ self.service_logger_obj.async_service_failure_hook(
292
+ service=ServiceTypes.REDIS,
293
+ duration=_duration,
294
+ error=e,
295
+ start_time=start_time,
296
+ end_time=end_time,
297
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
298
+ call_type="async_set_cache",
299
+ )
300
+ )
301
+ verbose_logger.error(
302
+ "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
303
+ str(e),
304
+ value,
305
+ )
306
+ raise e
307
+
308
+ key = self.check_and_fix_namespace(key=key)
309
+ ttl = self.get_ttl(**kwargs)
310
+ nx = kwargs.get("nx", False)
311
+ print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
312
+
313
+ try:
314
+ if not hasattr(_redis_client, "set"):
315
+ raise Exception("Redis client cannot set cache. Attribute not found.")
316
+ result = await _redis_client.set(
317
+ name=key,
318
+ value=json.dumps(value),
319
+ nx=nx,
320
+ ex=ttl,
321
+ )
322
+ print_verbose(
323
+ f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
324
+ )
325
+ end_time = time.time()
326
+ _duration = end_time - start_time
327
+ asyncio.create_task(
328
+ self.service_logger_obj.async_service_success_hook(
329
+ service=ServiceTypes.REDIS,
330
+ duration=_duration,
331
+ call_type="async_set_cache",
332
+ start_time=start_time,
333
+ end_time=end_time,
334
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
335
+ event_metadata={"key": key},
336
+ )
337
+ )
338
+ return result
339
+ except Exception as e:
340
+ end_time = time.time()
341
+ _duration = end_time - start_time
342
+ asyncio.create_task(
343
+ self.service_logger_obj.async_service_failure_hook(
344
+ service=ServiceTypes.REDIS,
345
+ duration=_duration,
346
+ error=e,
347
+ call_type="async_set_cache",
348
+ start_time=start_time,
349
+ end_time=end_time,
350
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
351
+ event_metadata={"key": key},
352
+ )
353
+ )
354
+ verbose_logger.error(
355
+ "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
356
+ str(e),
357
+ value,
358
+ )
359
+
360
+ async def _pipeline_helper(
361
+ self,
362
+ pipe: Union[pipeline, cluster_pipeline],
363
+ cache_list: List[Tuple[Any, Any]],
364
+ ttl: Optional[float],
365
+ ) -> List:
366
+ """
367
+ Helper function for executing a pipeline of set operations on Redis
368
+ """
369
+ ttl = self.get_ttl(ttl=ttl)
370
+ # Iterate through each key-value pair in the cache_list and set them in the pipeline.
371
+ for cache_key, cache_value in cache_list:
372
+ cache_key = self.check_and_fix_namespace(key=cache_key)
373
+ print_verbose(
374
+ f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
375
+ )
376
+ json_cache_value = json.dumps(cache_value)
377
+ # Set the value with a TTL if it's provided.
378
+ _td: Optional[timedelta] = None
379
+ if ttl is not None:
380
+ _td = timedelta(seconds=ttl)
381
+ pipe.set( # type: ignore
382
+ name=cache_key,
383
+ value=json_cache_value,
384
+ ex=_td,
385
+ )
386
+ # Execute the pipeline and return the results.
387
+ results = await pipe.execute()
388
+ return results
389
+
390
+ async def async_set_cache_pipeline(
391
+ self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
392
+ ):
393
+ """
394
+ Use Redis Pipelines for bulk write operations
395
+ """
396
+ # don't waste a network request if there's nothing to set
397
+ if len(cache_list) == 0:
398
+ return
399
+
400
+ _redis_client = self.init_async_client()
401
+ start_time = time.time()
402
+
403
+ print_verbose(
404
+ f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
405
+ )
406
+ cache_value: Any = None
407
+ try:
408
+ async with _redis_client.pipeline(transaction=False) as pipe:
409
+ results = await self._pipeline_helper(pipe, cache_list, ttl)
410
+
411
+ print_verbose(f"pipeline results: {results}")
412
+ # Optionally, you could process 'results' to make sure that all set operations were successful.
413
+ ## LOGGING ##
414
+ end_time = time.time()
415
+ _duration = end_time - start_time
416
+ asyncio.create_task(
417
+ self.service_logger_obj.async_service_success_hook(
418
+ service=ServiceTypes.REDIS,
419
+ duration=_duration,
420
+ call_type="async_set_cache_pipeline",
421
+ start_time=start_time,
422
+ end_time=end_time,
423
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
424
+ )
425
+ )
426
+ return None
427
+ except Exception as e:
428
+ ## LOGGING ##
429
+ end_time = time.time()
430
+ _duration = end_time - start_time
431
+ asyncio.create_task(
432
+ self.service_logger_obj.async_service_failure_hook(
433
+ service=ServiceTypes.REDIS,
434
+ duration=_duration,
435
+ error=e,
436
+ call_type="async_set_cache_pipeline",
437
+ start_time=start_time,
438
+ end_time=end_time,
439
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
440
+ )
441
+ )
442
+
443
+ verbose_logger.error(
444
+ "LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s",
445
+ str(e),
446
+ cache_value,
447
+ )
448
+
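# --- illustrative sketch (annotation, not part of the committed file) --------
# async_set_cache_pipeline above batches all SET commands into a single Redis
# round trip. A hedged standalone equivalent using redis.asyncio (assumes a
# reachable Redis on localhost:6379):
#
#     import asyncio, json
#     from redis.asyncio import Redis
#
#     async def bulk_set(pairs, ttl_seconds=60):
#         client = Redis(host="localhost", port=6379)
#         async with client.pipeline(transaction=False) as pipe:
#             for key, value in pairs:
#                 pipe.set(name=key, value=json.dumps(value), ex=ttl_seconds)
#             return await pipe.execute()  # one network round trip for N keys
#
#     asyncio.run(bulk_set([("k1", {"a": 1}), ("k2", {"b": 2})]))
# ------------------------------------------------------------------------------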
449
+ async def _set_cache_sadd_helper(
450
+ self,
451
+ redis_client: async_redis_client,
452
+ key: str,
453
+ value: List,
454
+ ttl: Optional[float],
455
+ ) -> None:
456
+ """Helper function for async_set_cache_sadd. Separated for testing."""
457
+ ttl = self.get_ttl(ttl=ttl)
458
+ try:
459
+ await redis_client.sadd(key, *value) # type: ignore
460
+ if ttl is not None:
461
+ _td = timedelta(seconds=ttl)
462
+ await redis_client.expire(key, _td)
463
+ except Exception:
464
+ raise
465
+
466
+ async def async_set_cache_sadd(
467
+ self, key, value: List, ttl: Optional[float], **kwargs
468
+ ):
469
+ from redis.asyncio import Redis
470
+
471
+ start_time = time.time()
472
+ try:
473
+ _redis_client: Redis = self.init_async_client() # type: ignore
474
+ except Exception as e:
475
+ end_time = time.time()
476
+ _duration = end_time - start_time
477
+ asyncio.create_task(
478
+ self.service_logger_obj.async_service_failure_hook(
479
+ service=ServiceTypes.REDIS,
480
+ duration=_duration,
481
+ error=e,
482
+ start_time=start_time,
483
+ end_time=end_time,
484
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
485
+ call_type="async_set_cache_sadd",
486
+ )
487
+ )
488
+ # NON blocking - notify users Redis is throwing an exception
489
+ verbose_logger.error(
490
+ "LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
491
+ str(e),
492
+ value,
493
+ )
494
+ raise e
495
+
496
+ key = self.check_and_fix_namespace(key=key)
497
+ print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
498
+ try:
499
+ await self._set_cache_sadd_helper(
500
+ redis_client=_redis_client, key=key, value=value, ttl=ttl
501
+ )
502
+ print_verbose(
503
+ f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
504
+ )
505
+ end_time = time.time()
506
+ _duration = end_time - start_time
507
+ asyncio.create_task(
508
+ self.service_logger_obj.async_service_success_hook(
509
+ service=ServiceTypes.REDIS,
510
+ duration=_duration,
511
+ call_type="async_set_cache_sadd",
512
+ start_time=start_time,
513
+ end_time=end_time,
514
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
515
+ )
516
+ )
517
+ except Exception as e:
518
+ end_time = time.time()
519
+ _duration = end_time - start_time
520
+ asyncio.create_task(
521
+ self.service_logger_obj.async_service_failure_hook(
522
+ service=ServiceTypes.REDIS,
523
+ duration=_duration,
524
+ error=e,
525
+ call_type="async_set_cache_sadd",
526
+ start_time=start_time,
527
+ end_time=end_time,
528
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
529
+ )
530
+ )
531
+ # NON blocking - notify users Redis is throwing an exception
532
+ verbose_logger.error(
533
+ "LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s",
534
+ str(e),
535
+ value,
536
+ )
537
+
538
+ async def batch_cache_write(self, key, value, **kwargs):
539
+ print_verbose(
540
+ f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
541
+ )
542
+ key = self.check_and_fix_namespace(key=key)
543
+ self.redis_batch_writing_buffer.append((key, value))
544
+ if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
545
+ await self.flush_cache_buffer() # logging done in here
546
+
547
+ async def async_increment(
548
+ self,
549
+ key,
550
+ value: float,
551
+ ttl: Optional[int] = None,
552
+ parent_otel_span: Optional[Span] = None,
553
+ ) -> float:
554
+ from redis.asyncio import Redis
555
+
556
+ _redis_client: Redis = self.init_async_client() # type: ignore
557
+ start_time = time.time()
558
+ _used_ttl = self.get_ttl(ttl=ttl)
559
+ key = self.check_and_fix_namespace(key=key)
560
+ try:
561
+ result = await _redis_client.incrbyfloat(name=key, amount=value)
562
+ if _used_ttl is not None:
563
+ # check if key already has ttl, if not -> set ttl
564
+ current_ttl = await _redis_client.ttl(key)
565
+ if current_ttl == -1:
566
+ # Key has no expiration
567
+ await _redis_client.expire(key, _used_ttl)
568
+
569
+ ## LOGGING ##
570
+ end_time = time.time()
571
+ _duration = end_time - start_time
572
+
573
+ asyncio.create_task(
574
+ self.service_logger_obj.async_service_success_hook(
575
+ service=ServiceTypes.REDIS,
576
+ duration=_duration,
577
+ call_type="async_increment",
578
+ start_time=start_time,
579
+ end_time=end_time,
580
+ parent_otel_span=parent_otel_span,
581
+ )
582
+ )
583
+ return result
584
+ except Exception as e:
585
+ ## LOGGING ##
586
+ end_time = time.time()
587
+ _duration = end_time - start_time
588
+ asyncio.create_task(
589
+ self.service_logger_obj.async_service_failure_hook(
590
+ service=ServiceTypes.REDIS,
591
+ duration=_duration,
592
+ error=e,
593
+ call_type="async_increment",
594
+ start_time=start_time,
595
+ end_time=end_time,
596
+ parent_otel_span=parent_otel_span,
597
+ )
598
+ )
599
+ verbose_logger.error(
600
+ "LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
601
+ str(e),
602
+ value,
603
+ )
604
+ raise e
605
+
606
+ async def flush_cache_buffer(self):
607
+ print_verbose(
608
+ f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
609
+ )
610
+ await self.async_set_cache_pipeline(self.redis_batch_writing_buffer)
611
+ self.redis_batch_writing_buffer = []
612
+
613
+ def _get_cache_logic(self, cached_response: Any):
614
+ """
615
+ Common 'get_cache_logic' across sync + async redis client implementations
616
+ """
617
+ if cached_response is None:
618
+ return cached_response
619
+ # cached_response is bytes (e.g. b"{...}") - decode it, then parse it back into a dict/ModelResponse
620
+ cached_response = cached_response.decode("utf-8") # Convert bytes to string
621
+ try:
622
+ cached_response = json.loads(
623
+ cached_response
624
+ ) # Convert string to dictionary
625
+ except Exception:
626
+ cached_response = ast.literal_eval(cached_response)
627
+ return cached_response
628
+
629
+ def get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs):
630
+ try:
631
+ key = self.check_and_fix_namespace(key=key)
632
+ print_verbose(f"Get Redis Cache: key: {key}")
633
+ start_time = time.time()
634
+ cached_response = self.redis_client.get(key)
635
+ end_time = time.time()
636
+ _duration = end_time - start_time
637
+ self.service_logger_obj.service_success_hook(
638
+ service=ServiceTypes.REDIS,
639
+ duration=_duration,
640
+ call_type="get_cache",
641
+ start_time=start_time,
642
+ end_time=end_time,
643
+ parent_otel_span=parent_otel_span,
644
+ )
645
+ print_verbose(
646
+ f"Got Redis Cache: key: {key}, cached_response {cached_response}"
647
+ )
648
+ return self._get_cache_logic(cached_response=cached_response)
649
+ except Exception as e:
650
+ # NON blocking - notify users Redis is throwing an exception
651
+ verbose_logger.error(
652
+ "litellm.caching.caching: get() - Got exception from REDIS: ", e
653
+ )
654
+
655
+ def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
656
+ """
657
+ Wrapper to call `mget` on the redis client
658
+
659
+ We use a wrapper so RedisCluster can override this method
660
+ """
661
+ return self.redis_client.mget(keys=keys) # type: ignore
662
+
663
+ async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
664
+ """
665
+ Wrapper to call `mget` on the redis client
666
+
667
+ We use a wrapper so RedisCluster can override this method
668
+ """
669
+ async_redis_client = self.init_async_client()
670
+ return await async_redis_client.mget(keys=keys) # type: ignore
671
+
672
+ def batch_get_cache(
673
+ self,
674
+ key_list: Union[List[str], List[Optional[str]]],
675
+ parent_otel_span: Optional[Span] = None,
676
+ ) -> dict:
677
+ """
678
+ Use Redis for bulk read operations
679
+
680
+ Args:
681
+ key_list: List of keys to get from Redis
682
+ parent_otel_span: Optional parent OpenTelemetry span
683
+
684
+ Returns:
685
+ dict: A dictionary mapping keys to their cached values
686
+ """
687
+ key_value_dict = {}
688
+ _key_list = [key for key in key_list if key is not None]
689
+
690
+ try:
691
+ _keys = []
692
+ for cache_key in _key_list:
693
+ cache_key = self.check_and_fix_namespace(key=cache_key or "")
694
+ _keys.append(cache_key)
695
+ start_time = time.time()
696
+ results: List = self._run_redis_mget_operation(keys=_keys)
697
+ end_time = time.time()
698
+ _duration = end_time - start_time
699
+ self.service_logger_obj.service_success_hook(
700
+ service=ServiceTypes.REDIS,
701
+ duration=_duration,
702
+ call_type="batch_get_cache",
703
+ start_time=start_time,
704
+ end_time=end_time,
705
+ parent_otel_span=parent_otel_span,
706
+ )
707
+
708
+ # Associate the results back with their keys.
709
+ # 'results' is a list of values corresponding to the order of keys in '_key_list'.
710
+ key_value_dict = dict(zip(_key_list, results))
711
+
712
+ decoded_results = {}
713
+ for k, v in key_value_dict.items():
714
+ if isinstance(k, bytes):
715
+ k = k.decode("utf-8")
716
+ v = self._get_cache_logic(v)
717
+ decoded_results[k] = v
718
+
719
+ return decoded_results
720
+ except Exception as e:
721
+ verbose_logger.error(f"Error occurred in batch get cache - {str(e)}")
722
+ return key_value_dict
723
+
724
+ async def async_get_cache(
725
+ self, key, parent_otel_span: Optional[Span] = None, **kwargs
726
+ ):
727
+ from redis.asyncio import Redis
728
+
729
+ _redis_client: Redis = self.init_async_client() # type: ignore
730
+ key = self.check_and_fix_namespace(key=key)
731
+ start_time = time.time()
732
+
733
+ try:
734
+ print_verbose(f"Get Async Redis Cache: key: {key}")
735
+ cached_response = await _redis_client.get(key)
736
+ print_verbose(
737
+ f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
738
+ )
739
+ response = self._get_cache_logic(cached_response=cached_response)
740
+
741
+ end_time = time.time()
742
+ _duration = end_time - start_time
743
+ asyncio.create_task(
744
+ self.service_logger_obj.async_service_success_hook(
745
+ service=ServiceTypes.REDIS,
746
+ duration=_duration,
747
+ call_type="async_get_cache",
748
+ start_time=start_time,
749
+ end_time=end_time,
750
+ parent_otel_span=parent_otel_span,
751
+ event_metadata={"key": key},
752
+ )
753
+ )
754
+ return response
755
+ except Exception as e:
756
+ end_time = time.time()
757
+ _duration = end_time - start_time
758
+ asyncio.create_task(
759
+ self.service_logger_obj.async_service_failure_hook(
760
+ service=ServiceTypes.REDIS,
761
+ duration=_duration,
762
+ error=e,
763
+ call_type="async_get_cache",
764
+ start_time=start_time,
765
+ end_time=end_time,
766
+ parent_otel_span=parent_otel_span,
767
+ event_metadata={"key": key},
768
+ )
769
+ )
770
+ print_verbose(
771
+ f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}"
772
+ )
773
+
774
+ async def async_batch_get_cache(
775
+ self,
776
+ key_list: Union[List[str], List[Optional[str]]],
777
+ parent_otel_span: Optional[Span] = None,
778
+ ) -> dict:
779
+ """
780
+ Use Redis for bulk read operations
781
+
782
+ Args:
783
+ key_list: List of keys to get from Redis
784
+ parent_otel_span: Optional parent OpenTelemetry span
785
+
786
+ Returns:
787
+ dict: A dictionary mapping keys to their cached values
788
+
789
+ `.mget` does not support None keys. This will filter out None keys.
790
+ """
791
+ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
792
+ key_value_dict = {}
793
+ start_time = time.time()
794
+ _key_list = [key for key in key_list if key is not None]
795
+ try:
796
+ _keys = []
797
+ for cache_key in _key_list:
798
+ cache_key = self.check_and_fix_namespace(key=cache_key)
799
+ _keys.append(cache_key)
800
+ results = await self._async_run_redis_mget_operation(keys=_keys)
801
+ ## LOGGING ##
802
+ end_time = time.time()
803
+ _duration = end_time - start_time
804
+ asyncio.create_task(
805
+ self.service_logger_obj.async_service_success_hook(
806
+ service=ServiceTypes.REDIS,
807
+ duration=_duration,
808
+ call_type="async_batch_get_cache",
809
+ start_time=start_time,
810
+ end_time=end_time,
811
+ parent_otel_span=parent_otel_span,
812
+ )
813
+ )
814
+
815
+ # Associate the results back with their keys.
816
+ # 'results' is a list of values corresponding to the order of keys in 'key_list'.
817
+ key_value_dict = dict(zip(_key_list, results))
818
+
819
+ decoded_results = {}
820
+ for k, v in key_value_dict.items():
821
+ if isinstance(k, bytes):
822
+ k = k.decode("utf-8")
823
+ v = self._get_cache_logic(v)
824
+ decoded_results[k] = v
825
+
826
+ return decoded_results
827
+ except Exception as e:
828
+ ## LOGGING ##
829
+ end_time = time.time()
830
+ _duration = end_time - start_time
831
+ asyncio.create_task(
832
+ self.service_logger_obj.async_service_failure_hook(
833
+ service=ServiceTypes.REDIS,
834
+ duration=_duration,
835
+ error=e,
836
+ call_type="async_batch_get_cache",
837
+ start_time=start_time,
838
+ end_time=end_time,
839
+ parent_otel_span=parent_otel_span,
840
+ )
841
+ )
842
+ verbose_logger.error(f"Error occurred in async batch get cache - {str(e)}")
843
+ return key_value_dict
844
+
845
+ def sync_ping(self) -> bool:
846
+ """
847
+ Tests if the sync redis client is correctly setup.
848
+ """
849
+ print_verbose("Pinging Sync Redis Cache")
850
+ start_time = time.time()
851
+ try:
852
+ response: bool = self.redis_client.ping() # type: ignore
853
+ print_verbose(f"Redis Cache PING: {response}")
854
+ ## LOGGING ##
855
+ end_time = time.time()
856
+ _duration = end_time - start_time
857
+ self.service_logger_obj.service_success_hook(
858
+ service=ServiceTypes.REDIS,
859
+ duration=_duration,
860
+ call_type="sync_ping",
861
+ start_time=start_time,
862
+ end_time=end_time,
863
+ )
864
+ return response
865
+ except Exception as e:
866
+ # NON blocking - notify users Redis is throwing an exception
867
+ ## LOGGING ##
868
+ end_time = time.time()
869
+ _duration = end_time - start_time
870
+ self.service_logger_obj.service_failure_hook(
871
+ service=ServiceTypes.REDIS,
872
+ duration=_duration,
873
+ error=e,
874
+ call_type="sync_ping",
875
+ )
876
+ verbose_logger.error(
877
+ f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
878
+ )
879
+ raise e
880
+
881
+ async def ping(self) -> bool:
882
+ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping`
883
+ _redis_client: Any = self.init_async_client()
884
+ start_time = time.time()
885
+ print_verbose("Pinging Async Redis Cache")
886
+ try:
887
+ response = await _redis_client.ping()
888
+ ## LOGGING ##
889
+ end_time = time.time()
890
+ _duration = end_time - start_time
891
+ asyncio.create_task(
892
+ self.service_logger_obj.async_service_success_hook(
893
+ service=ServiceTypes.REDIS,
894
+ duration=_duration,
895
+ call_type="async_ping",
896
+ )
897
+ )
898
+ return response
899
+ except Exception as e:
900
+ # NON blocking - notify users Redis is throwing an exception
901
+ ## LOGGING ##
902
+ end_time = time.time()
903
+ _duration = end_time - start_time
904
+ asyncio.create_task(
905
+ self.service_logger_obj.async_service_failure_hook(
906
+ service=ServiceTypes.REDIS,
907
+ duration=_duration,
908
+ error=e,
909
+ call_type="async_ping",
910
+ )
911
+ )
912
+ verbose_logger.error(
913
+ f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
914
+ )
915
+ raise e
916
+
917
+ async def delete_cache_keys(self, keys):
918
+ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
919
+ _redis_client: Any = self.init_async_client()
920
+ # keys is a list, unpack it so it gets passed as individual elements to delete
921
+ await _redis_client.delete(*keys)
922
+
923
+ def client_list(self) -> List:
924
+ client_list: List = self.redis_client.client_list() # type: ignore
925
+ return client_list
926
+
927
+ def info(self):
928
+ info = self.redis_client.info()
929
+ return info
930
+
931
+ def flush_cache(self):
932
+ self.redis_client.flushall()
933
+
934
+ def flushall(self):
935
+ self.redis_client.flushall()
936
+
937
+ async def disconnect(self):
938
+ await self.async_redis_conn_pool.disconnect(inuse_connections=True)
939
+
940
+ async def async_delete_cache(self, key: str):
941
+ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
942
+ _redis_client: Any = self.init_async_client()
943
+ # keys is str
944
+ return await _redis_client.delete(key)
945
+
946
+ def delete_cache(self, key):
947
+ self.redis_client.delete(key)
948
+
949
+ async def _pipeline_increment_helper(
950
+ self,
951
+ pipe: pipeline,
952
+ increment_list: List[RedisPipelineIncrementOperation],
953
+ ) -> Optional[List[float]]:
954
+ """Helper function for pipeline increment operations"""
955
+ # Iterate through each increment operation and add commands to pipeline
956
+ for increment_op in increment_list:
957
+ cache_key = self.check_and_fix_namespace(key=increment_op["key"])
958
+ print_verbose(
959
+ f"Increment ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {increment_op['increment_value']}\nttl={increment_op['ttl']}"
960
+ )
961
+ pipe.incrbyfloat(cache_key, increment_op["increment_value"])
962
+ if increment_op["ttl"] is not None:
963
+ _td = timedelta(seconds=increment_op["ttl"])
964
+ pipe.expire(cache_key, _td)
965
+ # Execute the pipeline and return results
966
+ results = await pipe.execute()
967
+ print_verbose(f"Increment ASYNC Redis Cache PIPELINE: results: {results}")
968
+ return results
969
+
970
+ async def async_increment_pipeline(
971
+ self, increment_list: List[RedisPipelineIncrementOperation], **kwargs
972
+ ) -> Optional[List[float]]:
973
+ """
974
+ Use Redis Pipelines for bulk increment operations
975
+ Args:
976
+ increment_list: List of RedisPipelineIncrementOperation dicts containing:
977
+ - key: str
978
+ - increment_value: float
979
+ - ttl_seconds: int
980
+ """
981
+ # don't waste a network request if there's nothing to increment
982
+ if len(increment_list) == 0:
983
+ return None
984
+
985
+ from redis.asyncio import Redis
986
+
987
+ _redis_client: Redis = self.init_async_client() # type: ignore
988
+ start_time = time.time()
989
+
990
+ print_verbose(
991
+ f"Increment Async Redis Cache Pipeline: increment list: {increment_list}"
992
+ )
993
+
994
+ try:
995
+ async with _redis_client.pipeline(transaction=False) as pipe:
996
+ results = await self._pipeline_increment_helper(pipe, increment_list)
997
+
998
+ print_verbose(f"pipeline increment results: {results}")
999
+
1000
+ ## LOGGING ##
1001
+ end_time = time.time()
1002
+ _duration = end_time - start_time
1003
+ asyncio.create_task(
1004
+ self.service_logger_obj.async_service_success_hook(
1005
+ service=ServiceTypes.REDIS,
1006
+ duration=_duration,
1007
+ call_type="async_increment_pipeline",
1008
+ start_time=start_time,
1009
+ end_time=end_time,
1010
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
1011
+ )
1012
+ )
1013
+ return results
1014
+ except Exception as e:
1015
+ ## LOGGING ##
1016
+ end_time = time.time()
1017
+ _duration = end_time - start_time
1018
+ asyncio.create_task(
1019
+ self.service_logger_obj.async_service_failure_hook(
1020
+ service=ServiceTypes.REDIS,
1021
+ duration=_duration,
1022
+ error=e,
1023
+ call_type="async_increment_pipeline",
1024
+ start_time=start_time,
1025
+ end_time=end_time,
1026
+ parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
1027
+ )
1028
+ )
1029
+ verbose_logger.error(
1030
+ "LiteLLM Redis Caching: async increment_pipeline() - Got exception from REDIS %s",
1031
+ str(e),
1032
+ )
1033
+ raise e
1034
+
1035
+ async def async_get_ttl(self, key: str) -> Optional[int]:
1036
+ """
1037
+ Get the remaining TTL of a key in Redis
1038
+
1039
+ Args:
1040
+ key (str): The key to get TTL for
1041
+
1042
+ Returns:
1043
+ Optional[int]: The remaining TTL in seconds, or None if key doesn't exist
1044
+
1045
+ Redis ref: https://redis.io/docs/latest/commands/ttl/
1046
+ """
1047
+ try:
1048
+ # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl`
1049
+ _redis_client: Any = self.init_async_client()
1050
+ ttl = await _redis_client.ttl(key)
1051
+ if ttl <= -1: # -1 means the key exists but has no expiry, -2 means the key does not exist
1052
+ return None
1053
+ return ttl
1054
+ except Exception as e:
1055
+ verbose_logger.debug(f"Redis TTL Error: {e}")
1056
+ return None
1057
+
1058
+ async def async_rpush(
1059
+ self,
1060
+ key: str,
1061
+ values: List[Any],
1062
+ parent_otel_span: Optional[Span] = None,
1063
+ **kwargs,
1064
+ ) -> int:
1065
+ """
1066
+ Append one or multiple values to a list stored at key
1067
+
1068
+ Args:
1069
+ key: The Redis key of the list
1070
+ values: One or more values to append to the list
1071
+ parent_otel_span: Optional parent OpenTelemetry span
1072
+
1073
+ Returns:
1074
+ int: The length of the list after the push operation
1075
+ """
1076
+ _redis_client: Any = self.init_async_client()
1077
+ start_time = time.time()
1078
+ try:
1079
+ response = await _redis_client.rpush(key, *values)
1080
+ ## LOGGING ##
1081
+ end_time = time.time()
1082
+ _duration = end_time - start_time
1083
+ asyncio.create_task(
1084
+ self.service_logger_obj.async_service_success_hook(
1085
+ service=ServiceTypes.REDIS,
1086
+ duration=_duration,
1087
+ call_type="async_rpush",
1088
+ )
1089
+ )
1090
+ return response
1091
+ except Exception as e:
1092
+ # NON blocking - notify users Redis is throwing an exception
1093
+ ## LOGGING ##
1094
+ end_time = time.time()
1095
+ _duration = end_time - start_time
1096
+ asyncio.create_task(
1097
+ self.service_logger_obj.async_service_failure_hook(
1098
+ service=ServiceTypes.REDIS,
1099
+ duration=_duration,
1100
+ error=e,
1101
+ call_type="async_rpush",
1102
+ )
1103
+ )
1104
+ verbose_logger.error(
1105
+ f"LiteLLM Redis Cache RPUSH: - Got exception from REDIS : {str(e)}"
1106
+ )
1107
+ raise e
1108
+
1109
+ async def async_lpop(
1110
+ self,
1111
+ key: str,
1112
+ count: Optional[int] = None,
1113
+ parent_otel_span: Optional[Span] = None,
1114
+ **kwargs,
1115
+ ) -> Union[Any, List[Any]]:
1116
+ _redis_client: Any = self.init_async_client()
1117
+ start_time = time.time()
1118
+ print_verbose(f"LPOP from Redis list: key: {key}, count: {count}")
1119
+ try:
1120
+ result = await _redis_client.lpop(key, count)
1121
+ ## LOGGING ##
1122
+ end_time = time.time()
1123
+ _duration = end_time - start_time
1124
+ asyncio.create_task(
1125
+ self.service_logger_obj.async_service_success_hook(
1126
+ service=ServiceTypes.REDIS,
1127
+ duration=_duration,
1128
+ call_type="async_lpop",
1129
+ )
1130
+ )
1131
+
1132
+ # Handle result parsing if needed
1133
+ if isinstance(result, bytes):
1134
+ try:
1135
+ return result.decode("utf-8")
1136
+ except Exception:
1137
+ return result
1138
+ elif isinstance(result, list) and all(
1139
+ isinstance(item, bytes) for item in result
1140
+ ):
1141
+ try:
1142
+ return [item.decode("utf-8") for item in result]
1143
+ except Exception:
1144
+ return result
1145
+ return result
1146
+ except Exception as e:
1147
+ # NON blocking - notify users Redis is throwing an exception
1148
+ ## LOGGING ##
1149
+ end_time = time.time()
1150
+ _duration = end_time - start_time
1151
+ asyncio.create_task(
1152
+ self.service_logger_obj.async_service_failure_hook(
1153
+ service=ServiceTypes.REDIS,
1154
+ duration=_duration,
1155
+ error=e,
1156
+ call_type="async_lpop",
1157
+ )
1158
+ )
1159
+ verbose_logger.error(
1160
+ f"LiteLLM Redis Cache LPOP: - Got exception from REDIS : {str(e)}"
1161
+ )
1162
+ raise e
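A hedged usage sketch (not part of this diff) for the RedisCache helpers added above. It assumes a local Redis at `localhost:6379` and uses only the constructor arguments and method names visible in this file; adjust the connection settings for a real deployment.

```python
import asyncio

from litellm.caching.redis_cache import RedisCache


async def main():
    cache = RedisCache(host="localhost", port=6379)  # assumed connection kwargs

    # queue-style helpers added in this file
    await cache.async_rpush(key="my-queue", values=["job-1", "job-2"])
    job = await cache.async_lpop(key="my-queue")
    print("popped:", job)

    # TTL helper: returns None when the key is missing or has no expiry
    print("ttl:", await cache.async_get_ttl(key="my-queue"))

    # health check against the async client
    print("ping:", await cache.ping())


asyncio.run(main())
```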
litellm/caching/redis_cluster_cache.py ADDED
@@ -0,0 +1,59 @@
1
+ """
2
+ Redis Cluster Cache implementation
3
+
4
+ Key differences:
5
+ - The Redis client NEEDS to be re-used across requests; re-creating it per request adds ~3000ms of latency
6
+ """
7
+
8
+ from typing import TYPE_CHECKING, Any, List, Optional, Union
9
+
10
+ from litellm.caching.redis_cache import RedisCache
11
+
12
+ if TYPE_CHECKING:
13
+ from opentelemetry.trace import Span as _Span
14
+ from redis.asyncio import Redis, RedisCluster
15
+ from redis.asyncio.client import Pipeline
16
+
17
+ pipeline = Pipeline
18
+ async_redis_client = Redis
19
+ Span = Union[_Span, Any]
20
+ else:
21
+ pipeline = Any
22
+ async_redis_client = Any
23
+ Span = Any
24
+
25
+
26
+ class RedisClusterCache(RedisCache):
27
+ def __init__(self, *args, **kwargs):
28
+ super().__init__(*args, **kwargs)
29
+ self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
30
+ self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None
31
+
32
+ def init_async_client(self):
33
+ from redis.asyncio import RedisCluster
34
+
35
+ from .._redis import get_redis_async_client
36
+
37
+ if self.redis_async_redis_cluster_client:
38
+ return self.redis_async_redis_cluster_client
39
+
40
+ _redis_client = get_redis_async_client(
41
+ connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
42
+ )
43
+ if isinstance(_redis_client, RedisCluster):
44
+ self.redis_async_redis_cluster_client = _redis_client
45
+
46
+ return _redis_client
47
+
48
+ def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
49
+ """
50
+ Overrides `_run_redis_mget_operation` in redis_cache.py
51
+ """
52
+ return self.redis_client.mget_nonatomic(keys=keys) # type: ignore
53
+
54
+ async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
55
+ """
56
+ Overrides `_async_run_redis_mget_operation` in redis_cache.py
57
+ """
58
+ async_redis_cluster_client = self.init_async_client()
59
+ return await async_redis_cluster_client.mget_nonatomic(keys=keys) # type: ignore
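For context, a short sketch of why the cluster subclass overrides the mget helpers: in cluster mode a plain `MGET` can fail with a `CROSSSLOT` error when keys hash to different slots, so `mget_nonatomic` (which groups keys per hash slot and issues one `MGET` per node) is used instead. The snippet below is an illustration with a hypothetical cluster endpoint, not code from this diff.

```python
import asyncio

from redis.asyncio import RedisCluster


async def main():
    # hypothetical cluster entry point; any node of the cluster works
    rc = RedisCluster(host="localhost", port=7000)

    await rc.set("litellm:a", "1")
    await rc.set("litellm:b", "2")

    # mget_nonatomic tolerates keys living in different hash slots
    print(await rc.mget_nonatomic(["litellm:a", "litellm:b"]))


asyncio.run(main())
```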
litellm/caching/redis_semantic_cache.py ADDED
@@ -0,0 +1,450 @@
1
+ """
2
+ Redis Semantic Cache implementation for LiteLLM
3
+
4
+ The RedisSemanticCache provides semantic caching functionality using Redis as a backend.
5
+ This cache stores responses based on the semantic similarity of prompts rather than
6
+ exact matching, allowing for more flexible caching of LLM responses.
7
+
8
+ This implementation uses RedisVL's SemanticCache to find semantically similar prompts
9
+ and their cached responses.
10
+ """
11
+
12
+ import ast
13
+ import asyncio
14
+ import json
15
+ import os
16
+ from typing import Any, Dict, List, Optional, Tuple, cast
17
+
18
+ import litellm
19
+ from litellm._logging import print_verbose
20
+ from litellm.litellm_core_utils.prompt_templates.common_utils import (
21
+ get_str_from_messages,
22
+ )
23
+ from litellm.types.utils import EmbeddingResponse
24
+
25
+ from .base_cache import BaseCache
26
+
27
+
28
+ class RedisSemanticCache(BaseCache):
29
+ """
30
+ Redis-backed semantic cache for LLM responses.
31
+
32
+ This cache uses vector similarity to find semantically similar prompts that have been
33
+ previously sent to the LLM, allowing for cache hits even when prompts are not identical
34
+ but carry similar meaning.
35
+ """
36
+
37
+ DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
38
+
39
+ def __init__(
40
+ self,
41
+ host: Optional[str] = None,
42
+ port: Optional[str] = None,
43
+ password: Optional[str] = None,
44
+ redis_url: Optional[str] = None,
45
+ similarity_threshold: Optional[float] = None,
46
+ embedding_model: str = "text-embedding-ada-002",
47
+ index_name: Optional[str] = None,
48
+ **kwargs,
49
+ ):
50
+ """
51
+ Initialize the Redis Semantic Cache.
52
+
53
+ Args:
54
+ host: Redis host address
55
+ port: Redis port
56
+ password: Redis password
57
+ redis_url: Full Redis URL (alternative to separate host/port/password)
58
+ similarity_threshold: Threshold for semantic similarity (0.0 to 1.0)
59
+ where 1.0 requires exact matches and 0.0 accepts any match
60
+ embedding_model: Model to use for generating embeddings
61
+ index_name: Name for the Redis index
62
+ ttl: Default time-to-live for cache entries in seconds
63
+ **kwargs: Additional arguments passed to the Redis client
64
+
65
+ Raises:
66
+ Exception: If similarity_threshold is not provided or required Redis
67
+ connection information is missing
68
+ """
69
+ from redisvl.extensions.llmcache import SemanticCache
70
+ from redisvl.utils.vectorize import CustomTextVectorizer
71
+
72
+ if index_name is None:
73
+ index_name = self.DEFAULT_REDIS_INDEX_NAME
74
+
75
+ print_verbose(f"Redis semantic-cache initializing index - {index_name}")
76
+
77
+ # Validate similarity threshold
78
+ if similarity_threshold is None:
79
+ raise ValueError("similarity_threshold must be provided, passed None")
80
+
81
+ # Store configuration
82
+ self.similarity_threshold = similarity_threshold
83
+
84
+ # Convert similarity threshold [0,1] to distance threshold [0,2]
85
+ # For cosine distance: 0 = most similar, 2 = least similar
86
+ # While similarity: 1 = most similar, 0 = least similar
87
+ self.distance_threshold = 1 - similarity_threshold
88
+ self.embedding_model = embedding_model
89
+
90
+ # Set up Redis connection
91
+ if redis_url is None:
92
+ try:
93
+ # Attempt to use provided parameters or fallback to environment variables
94
+ host = host or os.environ["REDIS_HOST"]
95
+ port = port or os.environ["REDIS_PORT"]
96
+ password = password or os.environ["REDIS_PASSWORD"]
97
+ except KeyError as e:
98
+ # Raise a more informative exception if any of the required keys are missing
99
+ missing_var = e.args[0]
100
+ raise ValueError(
101
+ f"Missing required Redis configuration: {missing_var}. "
102
+ f"Provide {missing_var} or redis_url."
103
+ ) from e
104
+
105
+ redis_url = f"redis://:{password}@{host}:{port}"
106
+
107
+ print_verbose(f"Redis semantic-cache redis_url: {redis_url}")
108
+
109
+ # Initialize the Redis vectorizer and cache
110
+ cache_vectorizer = CustomTextVectorizer(self._get_embedding)
111
+
112
+ self.llmcache = SemanticCache(
113
+ name=index_name,
114
+ redis_url=redis_url,
115
+ vectorizer=cache_vectorizer,
116
+ distance_threshold=self.distance_threshold,
117
+ overwrite=False,
118
+ )
119
+
120
+ def _get_ttl(self, **kwargs) -> Optional[int]:
121
+ """
122
+ Get the TTL (time-to-live) value for cache entries.
123
+
124
+ Args:
125
+ **kwargs: Keyword arguments that may contain a custom TTL
126
+
127
+ Returns:
128
+ Optional[int]: The TTL value in seconds, or None if no TTL should be applied
129
+ """
130
+ ttl = kwargs.get("ttl")
131
+ if ttl is not None:
132
+ ttl = int(ttl)
133
+ return ttl
134
+
135
+ def _get_embedding(self, prompt: str) -> List[float]:
136
+ """
137
+ Generate an embedding vector for the given prompt using the configured embedding model.
138
+
139
+ Args:
140
+ prompt: The text to generate an embedding for
141
+
142
+ Returns:
143
+ List[float]: The embedding vector
144
+ """
145
+ # Create an embedding from prompt
146
+ embedding_response = cast(
147
+ EmbeddingResponse,
148
+ litellm.embedding(
149
+ model=self.embedding_model,
150
+ input=prompt,
151
+ cache={"no-store": True, "no-cache": True},
152
+ ),
153
+ )
154
+ embedding = embedding_response["data"][0]["embedding"]
155
+ return embedding
156
+
157
+ def _get_cache_logic(self, cached_response: Any) -> Any:
158
+ """
159
+ Process the cached response to prepare it for use.
160
+
161
+ Args:
162
+ cached_response: The raw cached response
163
+
164
+ Returns:
165
+ The processed cache response, or None if input was None
166
+ """
167
+ if cached_response is None:
168
+ return cached_response
169
+
170
+ # Convert bytes to string if needed
171
+ if isinstance(cached_response, bytes):
172
+ cached_response = cached_response.decode("utf-8")
173
+
174
+ # Convert string representation to Python object
175
+ try:
176
+ cached_response = json.loads(cached_response)
177
+ except json.JSONDecodeError:
178
+ try:
179
+ cached_response = ast.literal_eval(cached_response)
180
+ except (ValueError, SyntaxError) as e:
181
+ print_verbose(f"Error parsing cached response: {str(e)}")
182
+ return None
183
+
184
+ return cached_response
185
+
186
+ def set_cache(self, key: str, value: Any, **kwargs) -> None:
187
+ """
188
+ Store a value in the semantic cache.
189
+
190
+ Args:
191
+ key: The cache key (not directly used in semantic caching)
192
+ value: The response value to cache
193
+ **kwargs: Additional arguments including 'messages' for the prompt
194
+ and optional 'ttl' for time-to-live
195
+ """
196
+ print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
197
+
198
+ value_str: Optional[str] = None
199
+ try:
200
+ # Extract the prompt from messages
201
+ messages = kwargs.get("messages", [])
202
+ if not messages:
203
+ print_verbose("No messages provided for semantic caching")
204
+ return
205
+
206
+ prompt = get_str_from_messages(messages)
207
+ value_str = str(value)
208
+
209
+ # Get TTL and store in Redis semantic cache
210
+ ttl = self._get_ttl(**kwargs)
211
+ if ttl is not None:
212
+ self.llmcache.store(prompt, value_str, ttl=int(ttl))
213
+ else:
214
+ self.llmcache.store(prompt, value_str)
215
+ except Exception as e:
216
+ print_verbose(
217
+ f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
218
+ )
219
+
220
+ def get_cache(self, key: str, **kwargs) -> Any:
221
+ """
222
+ Retrieve a semantically similar cached response.
223
+
224
+ Args:
225
+ key: The cache key (not directly used in semantic caching)
226
+ **kwargs: Additional arguments including 'messages' for the prompt
227
+
228
+ Returns:
229
+ The cached response if a semantically similar prompt is found, else None
230
+ """
231
+ print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}")
232
+
233
+ try:
234
+ # Extract the prompt from messages
235
+ messages = kwargs.get("messages", [])
236
+ if not messages:
237
+ print_verbose("No messages provided for semantic cache lookup")
238
+ return None
239
+
240
+ prompt = get_str_from_messages(messages)
241
+ # Check the cache for semantically similar prompts
242
+ results = self.llmcache.check(prompt=prompt)
243
+
244
+ # Return None if no similar prompts found
245
+ if not results:
246
+ return None
247
+
248
+ # Process the best matching result
249
+ cache_hit = results[0]
250
+ vector_distance = float(cache_hit["vector_distance"])
251
+
252
+ # Convert vector distance back to similarity score
253
+ # For cosine distance: 0 = most similar, 2 = least similar
254
+ # While similarity: 1 = most similar, 0 = least similar
255
+ similarity = 1 - vector_distance
256
+
257
+ cached_prompt = cache_hit["prompt"]
258
+ cached_response = cache_hit["response"]
259
+
260
+ print_verbose(
261
+ f"Cache hit: similarity threshold: {self.similarity_threshold}, "
262
+ f"actual similarity: {similarity}, "
263
+ f"current prompt: {prompt}, "
264
+ f"cached prompt: {cached_prompt}"
265
+ )
266
+
267
+ return self._get_cache_logic(cached_response=cached_response)
268
+ except Exception as e:
269
+ print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
270
+
271
+ async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
272
+ """
273
+ Asynchronously generate an embedding for the given prompt.
274
+
275
+ Args:
276
+ prompt: The text to generate an embedding for
277
+ **kwargs: Additional arguments that may contain metadata
278
+
279
+ Returns:
280
+ List[float]: The embedding vector
281
+ """
282
+ from litellm.proxy.proxy_server import llm_model_list, llm_router
283
+
284
+ # Route the embedding request through the proxy if appropriate
285
+ router_model_names = (
286
+ [m["model_name"] for m in llm_model_list]
287
+ if llm_model_list is not None
288
+ else []
289
+ )
290
+
291
+ try:
292
+ if llm_router is not None and self.embedding_model in router_model_names:
293
+ # Use the router for embedding generation
294
+ user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
295
+ embedding_response = await llm_router.aembedding(
296
+ model=self.embedding_model,
297
+ input=prompt,
298
+ cache={"no-store": True, "no-cache": True},
299
+ metadata={
300
+ "user_api_key": user_api_key,
301
+ "semantic-cache-embedding": True,
302
+ "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
303
+ },
304
+ )
305
+ else:
306
+ # Generate embedding directly
307
+ embedding_response = await litellm.aembedding(
308
+ model=self.embedding_model,
309
+ input=prompt,
310
+ cache={"no-store": True, "no-cache": True},
311
+ )
312
+
313
+ # Extract and return the embedding vector
314
+ return embedding_response["data"][0]["embedding"]
315
+ except Exception as e:
316
+ print_verbose(f"Error generating async embedding: {str(e)}")
317
+ raise ValueError(f"Failed to generate embedding: {str(e)}") from e
318
+
319
+ async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
320
+ """
321
+ Asynchronously store a value in the semantic cache.
322
+
323
+ Args:
324
+ key: The cache key (not directly used in semantic caching)
325
+ value: The response value to cache
326
+ **kwargs: Additional arguments including 'messages' for the prompt
327
+ and optional 'ttl' for time-to-live
328
+ """
329
+ print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}")
330
+
331
+ try:
332
+ # Extract the prompt from messages
333
+ messages = kwargs.get("messages", [])
334
+ if not messages:
335
+ print_verbose("No messages provided for semantic caching")
336
+ return
337
+
338
+ prompt = get_str_from_messages(messages)
339
+ value_str = str(value)
340
+
341
+ # Generate embedding for the value (response) to cache
342
+ prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
343
+
344
+ # Get TTL and store in Redis semantic cache
345
+ ttl = self._get_ttl(**kwargs)
346
+ if ttl is not None:
347
+ await self.llmcache.astore(
348
+ prompt,
349
+ value_str,
350
+ vector=prompt_embedding, # Pass through custom embedding
351
+ ttl=ttl,
352
+ )
353
+ else:
354
+ await self.llmcache.astore(
355
+ prompt,
356
+ value_str,
357
+ vector=prompt_embedding, # Pass through custom embedding
358
+ )
359
+ except Exception as e:
360
+ print_verbose(f"Error in async_set_cache: {str(e)}")
361
+
362
+ async def async_get_cache(self, key: str, **kwargs) -> Any:
363
+ """
364
+ Asynchronously retrieve a semantically similar cached response.
365
+
366
+ Args:
367
+ key: The cache key (not directly used in semantic caching)
368
+ **kwargs: Additional arguments including 'messages' for the prompt
369
+
370
+ Returns:
371
+ The cached response if a semantically similar prompt is found, else None
372
+ """
373
+ print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}")
374
+
375
+ try:
376
+ # Extract the prompt from messages
377
+ messages = kwargs.get("messages", [])
378
+ if not messages:
379
+ print_verbose("No messages provided for semantic cache lookup")
380
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
381
+ return None
382
+
383
+ prompt = get_str_from_messages(messages)
384
+
385
+ # Generate embedding for the prompt
386
+ prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
387
+
388
+ # Check the cache for semantically similar prompts
389
+ results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
390
+
391
+ # handle results / cache hit
392
+ if not results:
393
+ kwargs.setdefault("metadata", {})[
394
+ "semantic-similarity"
395
+ ] = 0.0 # record zero similarity when no semantically similar prompt is found
396
+ return None
397
+
398
+ cache_hit = results[0]
399
+ vector_distance = float(cache_hit["vector_distance"])
400
+
401
+ # Convert vector distance back to similarity
402
+ # For cosine distance: 0 = most similar, 2 = least similar
403
+ # While similarity: 1 = most similar, 0 = least similar
404
+ similarity = 1 - vector_distance
405
+
406
+ cached_prompt = cache_hit["prompt"]
407
+ cached_response = cache_hit["response"]
408
+
409
+ # update kwargs["metadata"] with similarity, don't rewrite the original metadata
410
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
411
+
412
+ print_verbose(
413
+ f"Cache hit: similarity threshold: {self.similarity_threshold}, "
414
+ f"actual similarity: {similarity}, "
415
+ f"current prompt: {prompt}, "
416
+ f"cached prompt: {cached_prompt}"
417
+ )
418
+
419
+ return self._get_cache_logic(cached_response=cached_response)
420
+ except Exception as e:
421
+ print_verbose(f"Error in async_get_cache: {str(e)}")
422
+ kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
423
+
424
+ async def _index_info(self) -> Dict[str, Any]:
425
+ """
426
+ Get information about the Redis index.
427
+
428
+ Returns:
429
+ Dict[str, Any]: Information about the Redis index
430
+ """
431
+ aindex = await self.llmcache._get_async_index()
432
+ return await aindex.info()
433
+
434
+ async def async_set_cache_pipeline(
435
+ self, cache_list: List[Tuple[str, Any]], **kwargs
436
+ ) -> None:
437
+ """
438
+ Asynchronously store multiple values in the semantic cache.
439
+
440
+ Args:
441
+ cache_list: List of (key, value) tuples to cache
442
+ **kwargs: Additional arguments
443
+ """
444
+ try:
445
+ tasks = []
446
+ for val in cache_list:
447
+ tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
448
+ await asyncio.gather(*tasks)
449
+ except Exception as e:
450
+ print_verbose(f"Error in async_set_cache_pipeline: {str(e)}")
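A hedged usage sketch for the semantic cache above (requires `redisvl` and a reachable Redis instance; the connection string below is a placeholder). Note how `similarity_threshold` is converted to a cosine distance cutoff of `1 - similarity_threshold`, so a threshold of 0.8 accepts hits within a distance of 0.2.

```python
from litellm.caching.redis_semantic_cache import RedisSemanticCache

semantic_cache = RedisSemanticCache(
    redis_url="redis://:password@localhost:6379",  # placeholder connection string
    similarity_threshold=0.8,  # internally: distance_threshold = 1 - 0.8 = 0.2
    embedding_model="text-embedding-ada-002",
)

messages = [{"role": "user", "content": "What is the capital of France?"}]

# store a response keyed by the meaning of the prompt, not its exact text
semantic_cache.set_cache(key="unused", value={"answer": "Paris"}, messages=messages)

# a semantically similar rephrasing can still produce a cache hit
rephrased = [{"role": "user", "content": "Which city is France's capital?"}]
print(semantic_cache.get_cache(key="unused", messages=rephrased))
```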
litellm/caching/s3_cache.py ADDED
@@ -0,0 +1,159 @@
1
+ """
2
+ S3 Cache implementation
3
+ WARNING: DO NOT USE THIS IN PRODUCTION - This is not ASYNC
4
+
5
+ Has 4 methods:
6
+ - set_cache
7
+ - get_cache
8
+ - async_set_cache
9
+ - async_get_cache
10
+ """
11
+
12
+ import ast
13
+ import asyncio
14
+ import json
15
+ from typing import Optional
16
+
17
+ from litellm._logging import print_verbose, verbose_logger
18
+
19
+ from .base_cache import BaseCache
20
+
21
+
22
+ class S3Cache(BaseCache):
23
+ def __init__(
24
+ self,
25
+ s3_bucket_name,
26
+ s3_region_name=None,
27
+ s3_api_version=None,
28
+ s3_use_ssl: Optional[bool] = True,
29
+ s3_verify=None,
30
+ s3_endpoint_url=None,
31
+ s3_aws_access_key_id=None,
32
+ s3_aws_secret_access_key=None,
33
+ s3_aws_session_token=None,
34
+ s3_config=None,
35
+ s3_path=None,
36
+ **kwargs,
37
+ ):
38
+ import boto3
39
+
40
+ self.bucket_name = s3_bucket_name
41
+ self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
42
+ # Create an S3 client with custom endpoint URL
43
+
44
+ self.s3_client = boto3.client(
45
+ "s3",
46
+ region_name=s3_region_name,
47
+ endpoint_url=s3_endpoint_url,
48
+ api_version=s3_api_version,
49
+ use_ssl=s3_use_ssl,
50
+ verify=s3_verify,
51
+ aws_access_key_id=s3_aws_access_key_id,
52
+ aws_secret_access_key=s3_aws_secret_access_key,
53
+ aws_session_token=s3_aws_session_token,
54
+ config=s3_config,
55
+ **kwargs,
56
+ )
57
+
58
+ def set_cache(self, key, value, **kwargs):
59
+ try:
60
+ print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
61
+ ttl = kwargs.get("ttl", None)
62
+ # Convert value to JSON before storing in S3
63
+ serialized_value = json.dumps(value)
64
+ key = self.key_prefix + key
65
+
66
+ if ttl is not None:
67
+ cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"
68
+ import datetime
69
+
70
+ # Calculate expiration time
71
+ expiration_time = datetime.datetime.now() + datetime.timedelta(seconds=ttl)
72
+
73
+ # Upload the data to S3 with the calculated expiration time
74
+ self.s3_client.put_object(
75
+ Bucket=self.bucket_name,
76
+ Key=key,
77
+ Body=serialized_value,
78
+ Expires=expiration_time,
79
+ CacheControl=cache_control,
80
+ ContentType="application/json",
81
+ ContentLanguage="en",
82
+ ContentDisposition=f'inline; filename="{key}.json"',
83
+ )
84
+ else:
85
+ cache_control = "immutable, max-age=31536000, s-maxage=31536000"
86
+ # Upload the data to S3 without specifying Expires
87
+ self.s3_client.put_object(
88
+ Bucket=self.bucket_name,
89
+ Key=key,
90
+ Body=serialized_value,
91
+ CacheControl=cache_control,
92
+ ContentType="application/json",
93
+ ContentLanguage="en",
94
+ ContentDisposition=f'inline; filename="{key}.json"',
95
+ )
96
+ except Exception as e:
97
+ # NON blocking - notify users S3 is throwing an exception
98
+ print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
99
+
100
+ async def async_set_cache(self, key, value, **kwargs):
101
+ self.set_cache(key=key, value=value, **kwargs)
102
+
103
+ def get_cache(self, key, **kwargs):
104
+ import botocore
105
+
106
+ try:
107
+ key = self.key_prefix + key
108
+
109
+ print_verbose(f"Get S3 Cache: key: {key}")
110
+ # Download the data from S3
111
+ cached_response = self.s3_client.get_object(
112
+ Bucket=self.bucket_name, Key=key
113
+ )
114
+
115
+ if cached_response is not None:
116
+ # cached_response["Body"] is raw bytes (b"{...}"); decode it and parse it back into a dict
117
+ cached_response = (
118
+ cached_response["Body"].read().decode("utf-8")
119
+ ) # Convert bytes to string
120
+ try:
121
+ cached_response = json.loads(
122
+ cached_response
123
+ ) # Convert string to dictionary
124
+ except Exception:
125
+ cached_response = ast.literal_eval(cached_response)
126
+ if not isinstance(cached_response, dict):
127
+ cached_response = dict(cached_response)
128
+ verbose_logger.debug(
129
+ f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
130
+ )
131
+
132
+ return cached_response
133
+ except botocore.exceptions.ClientError as e: # type: ignore
134
+ if e.response["Error"]["Code"] == "NoSuchKey":
135
+ verbose_logger.debug(
136
+ f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
137
+ )
138
+ return None
139
+
140
+ except Exception as e:
141
+ # NON blocking - notify users S3 is throwing an exception
142
+ verbose_logger.error(
143
+ f"S3 Caching: get_cache() - Got exception from S3: {e}"
144
+ )
145
+
146
+ async def async_get_cache(self, key, **kwargs):
147
+ return self.get_cache(key=key, **kwargs)
148
+
149
+ def flush_cache(self):
150
+ pass
151
+
152
+ async def disconnect(self):
153
+ pass
154
+
155
+ async def async_set_cache_pipeline(self, cache_list, **kwargs):
156
+ tasks = []
157
+ for val in cache_list:
158
+ tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
159
+ await asyncio.gather(*tasks)
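A minimal usage sketch for the S3 cache above, which is synchronous under the hood (per the module warning). The bucket name and prefix are hypothetical; boto3 credentials resolve through the usual AWS credential chain.

```python
from litellm.caching.s3_cache import S3Cache

s3_cache = S3Cache(
    s3_bucket_name="my-litellm-cache-bucket",  # hypothetical bucket
    s3_region_name="us-east-1",
    s3_path="litellm/cache",  # objects land under the "litellm/cache/" prefix
)

# values are JSON-serialized before upload; ttl sets the Expires/Cache-Control headers
s3_cache.set_cache("greeting", {"text": "hello"}, ttl=3600)
print(s3_cache.get_cache("greeting"))
```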
litellm/constants.py ADDED
@@ -0,0 +1,543 @@
1
+ from typing import List, Literal
2
+
3
+ ROUTER_MAX_FALLBACKS = 5
4
+ DEFAULT_BATCH_SIZE = 512
5
+ DEFAULT_FLUSH_INTERVAL_SECONDS = 5
6
+ DEFAULT_MAX_RETRIES = 2
7
+ DEFAULT_MAX_RECURSE_DEPTH = 10
8
+ DEFAULT_FAILURE_THRESHOLD_PERCENT = (
9
+ 0.5 # by default, cool down a deployment if 50% of its requests fail in a given minute
10
+ )
11
+ DEFAULT_MAX_TOKENS = 4096
12
+ DEFAULT_ALLOWED_FAILS = 3
13
+ DEFAULT_REDIS_SYNC_INTERVAL = 1
14
+ DEFAULT_COOLDOWN_TIME_SECONDS = 5
15
+ DEFAULT_REPLICATE_POLLING_RETRIES = 5
16
+ DEFAULT_REPLICATE_POLLING_DELAY_SECONDS = 1
17
+ DEFAULT_IMAGE_TOKEN_COUNT = 250
18
+ DEFAULT_IMAGE_WIDTH = 300
19
+ DEFAULT_IMAGE_HEIGHT = 300
20
+ DEFAULT_MAX_TOKENS = 256 # used when providers need a default
21
+ MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 1024 # 1MB = 1024KB
22
+ SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests to consider "reasonable traffic". Used for single-deployment cooldown logic.
23
+
24
+ DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET = 1024
25
+ DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET = 2048
26
+ DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET = 4096
27
+
28
+ ########## Networking constants ##############################################################
29
+ _DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
30
+
31
+ ########### v2 Architecture constants for managing writing updates to the database ###########
32
+ REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer"
33
+ REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer"
34
+ REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_team_spend_update_buffer"
35
+ REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_tag_spend_update_buffer"
36
+ MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100
37
+ MAX_SIZE_IN_MEMORY_QUEUE = 10000
38
+ MAX_IN_MEMORY_QUEUE_FLUSH_COUNT = 1000
39
+ ###############################################################################################
40
+ MINIMUM_PROMPT_CACHE_TOKEN_COUNT = (
41
+ 1024 # minimum number of tokens to cache a prompt by Anthropic
42
+ )
43
+ DEFAULT_TRIM_RATIO = 0.75 # default ratio of tokens to trim from the end of a prompt
44
+ HOURS_IN_A_DAY = 24
45
+ DAYS_IN_A_WEEK = 7
46
+ DAYS_IN_A_MONTH = 28
47
+ DAYS_IN_A_YEAR = 365
48
+ REPLICATE_MODEL_NAME_WITH_ID_LENGTH = 64
49
+ #### TOKEN COUNTING ####
50
+ FUNCTION_DEFINITION_TOKEN_COUNT = 9
51
+ SYSTEM_MESSAGE_TOKEN_COUNT = 4
52
+ TOOL_CHOICE_OBJECT_TOKEN_COUNT = 4
53
+ DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT = 10
54
+ DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT = 20
55
+ MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES = 768
56
+ MAX_LONG_SIDE_FOR_IMAGE_HIGH_RES = 2000
57
+ MAX_TILE_WIDTH = 512
58
+ MAX_TILE_HEIGHT = 512
59
+ OPENAI_FILE_SEARCH_COST_PER_1K_CALLS = 2.5 / 1000
60
+ MIN_NON_ZERO_TEMPERATURE = 0.0001
61
+ #### RELIABILITY ####
62
+ REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
63
+ DEFAULT_MAX_LRU_CACHE_SIZE = 16
64
+ INITIAL_RETRY_DELAY = 0.5
65
+ MAX_RETRY_DELAY = 8.0
66
+ JITTER = 0.75
67
+ DEFAULT_IN_MEMORY_TTL = 5 # default time to live for the in-memory cache
68
+ DEFAULT_POLLING_INTERVAL = 0.03 # default polling interval for the scheduler
69
+ AZURE_OPERATION_POLLING_TIMEOUT = 120
70
+ REDIS_SOCKET_TIMEOUT = 0.1
71
+ REDIS_CONNECTION_POOL_TIMEOUT = 5
72
+ NON_LLM_CONNECTION_TIMEOUT = 15 # timeout for adjacent services (e.g. jwt auth)
73
+ MAX_EXCEPTION_MESSAGE_LENGTH = 2000
74
+ BEDROCK_MAX_POLICY_SIZE = 75
75
+ REPLICATE_POLLING_DELAY_SECONDS = 0.5
76
+ DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS = 4096
77
+ TOGETHER_AI_4_B = 4
78
+ TOGETHER_AI_8_B = 8
79
+ TOGETHER_AI_21_B = 21
80
+ TOGETHER_AI_41_B = 41
81
+ TOGETHER_AI_80_B = 80
82
+ TOGETHER_AI_110_B = 110
83
+ TOGETHER_AI_EMBEDDING_150_M = 150
84
+ TOGETHER_AI_EMBEDDING_350_M = 350
85
+ QDRANT_SCALAR_QUANTILE = 0.99
86
+ QDRANT_VECTOR_SIZE = 1536
87
+ CACHED_STREAMING_CHUNK_DELAY = 0.02
88
+ MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 512
89
+ DEFAULT_MAX_TOKENS_FOR_TRITON = 2000
90
+ #### Networking settings ####
91
+ request_timeout: float = 6000 # time in seconds
92
+ STREAM_SSE_DONE_STRING: str = "[DONE]"
93
+ ### SPEND TRACKING ###
94
+ DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND = 0.001400 # price per second for a100 80GB
95
+ FIREWORKS_AI_56_B_MOE = 56
96
+ FIREWORKS_AI_176_B_MOE = 176
97
+ FIREWORKS_AI_4_B = 4
98
+ FIREWORKS_AI_16_B = 16
99
+ FIREWORKS_AI_80_B = 80
100
+
101
+ LITELLM_CHAT_PROVIDERS = [
102
+ "openai",
103
+ "openai_like",
104
+ "xai",
105
+ "custom_openai",
106
+ "text-completion-openai",
107
+ "cohere",
108
+ "cohere_chat",
109
+ "clarifai",
110
+ "anthropic",
111
+ "anthropic_text",
112
+ "replicate",
113
+ "huggingface",
114
+ "together_ai",
115
+ "openrouter",
116
+ "vertex_ai",
117
+ "vertex_ai_beta",
118
+ "gemini",
119
+ "ai21",
120
+ "baseten",
121
+ "azure",
122
+ "azure_text",
123
+ "azure_ai",
124
+ "sagemaker",
125
+ "sagemaker_chat",
126
+ "bedrock",
127
+ "vllm",
128
+ "nlp_cloud",
129
+ "petals",
130
+ "oobabooga",
131
+ "ollama",
132
+ "ollama_chat",
133
+ "deepinfra",
134
+ "perplexity",
135
+ "mistral",
136
+ "groq",
137
+ "nvidia_nim",
138
+ "cerebras",
139
+ "ai21_chat",
140
+ "volcengine",
141
+ "codestral",
142
+ "text-completion-codestral",
143
+ "deepseek",
144
+ "sambanova",
145
+ "maritalk",
146
+ "cloudflare",
147
+ "fireworks_ai",
148
+ "friendliai",
149
+ "watsonx",
150
+ "watsonx_text",
151
+ "triton",
152
+ "predibase",
153
+ "databricks",
154
+ "empower",
155
+ "github",
156
+ "custom",
157
+ "litellm_proxy",
158
+ "hosted_vllm",
159
+ "llamafile",
160
+ "lm_studio",
161
+ "galadriel",
162
+ ]
163
+
164
+
165
+ OPENAI_CHAT_COMPLETION_PARAMS = [
166
+ "functions",
167
+ "function_call",
168
+ "temperature",
170
+ "top_p",
171
+ "n",
172
+ "stream",
173
+ "stream_options",
174
+ "stop",
175
+ "max_completion_tokens",
176
+ "modalities",
177
+ "prediction",
178
+ "audio",
179
+ "max_tokens",
180
+ "presence_penalty",
181
+ "frequency_penalty",
182
+ "logit_bias",
183
+ "user",
184
+ "request_timeout",
185
+ "api_base",
186
+ "api_version",
187
+ "api_key",
188
+ "deployment_id",
189
+ "organization",
190
+ "base_url",
191
+ "default_headers",
192
+ "timeout",
193
+ "response_format",
194
+ "seed",
195
+ "tools",
196
+ "tool_choice",
197
+ "max_retries",
198
+ "parallel_tool_calls",
199
+ "logprobs",
200
+ "top_logprobs",
201
+ "reasoning_effort",
202
+ "extra_headers",
203
+ "thinking",
204
+ ]
205
+
206
+ openai_compatible_endpoints: List = [
207
+ "api.perplexity.ai",
208
+ "api.endpoints.anyscale.com/v1",
209
+ "api.deepinfra.com/v1/openai",
210
+ "api.mistral.ai/v1",
211
+ "codestral.mistral.ai/v1/chat/completions",
212
+ "codestral.mistral.ai/v1/fim/completions",
213
+ "api.groq.com/openai/v1",
214
+ "https://integrate.api.nvidia.com/v1",
215
+ "api.deepseek.com/v1",
216
+ "api.together.xyz/v1",
217
+ "app.empower.dev/api/v1",
218
+ "https://api.friendli.ai/serverless/v1",
219
+ "api.sambanova.ai/v1",
220
+ "api.x.ai/v1",
221
+ "api.galadriel.ai/v1",
222
+ ]
223
+
224
+
225
+ openai_compatible_providers: List = [
226
+ "anyscale",
227
+ "mistral",
228
+ "groq",
229
+ "nvidia_nim",
230
+ "cerebras",
231
+ "sambanova",
232
+ "ai21_chat",
233
+ "ai21",
234
+ "volcengine",
235
+ "codestral",
236
+ "deepseek",
237
+ "deepinfra",
238
+ "perplexity",
239
+ "xinference",
240
+ "xai",
241
+ "together_ai",
242
+ "fireworks_ai",
243
+ "empower",
244
+ "friendliai",
245
+ "azure_ai",
246
+ "github",
247
+ "litellm_proxy",
248
+ "hosted_vllm",
249
+ "llamafile",
250
+ "lm_studio",
251
+ "galadriel",
252
+ ]
253
+ openai_text_completion_compatible_providers: List = (
254
+ [ # providers that support `/v1/completions`
255
+ "together_ai",
256
+ "fireworks_ai",
257
+ "hosted_vllm",
258
+ "llamafile",
259
+ ]
260
+ )
261
+ _openai_like_providers: List = [
262
+ "predibase",
263
+ "databricks",
264
+ "watsonx",
265
+ ] # private helper. similar to openai but require some custom auth / endpoint handling, so can't use the openai sdk
266
+ # well supported replicate llms
267
+ replicate_models: List = [
268
+ # llama replicate supported LLMs
269
+ "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
270
+ "a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52",
271
+ "meta/codellama-13b:1c914d844307b0588599b8393480a3ba917b660c7e9dfae681542b5325f228db",
272
+ # Vicuna
273
+ "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b",
274
+ "joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe",
275
+ # Flan T-5
276
+ "daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f",
277
+ # Others
278
+ "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5",
279
+ "replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
280
+ ]
281
+
282
+ clarifai_models: List = [
283
+ "clarifai/meta.Llama-3.Llama-3-8B-Instruct",
284
+ "clarifai/gcp.generate.gemma-1_1-7b-it",
285
+ "clarifai/mistralai.completion.mixtral-8x22B",
286
+ "clarifai/cohere.generate.command-r-plus",
287
+ "clarifai/databricks.drbx.dbrx-instruct",
288
+ "clarifai/mistralai.completion.mistral-large",
289
+ "clarifai/mistralai.completion.mistral-medium",
290
+ "clarifai/mistralai.completion.mistral-small",
291
+ "clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1",
292
+ "clarifai/gcp.generate.gemma-2b-it",
293
+ "clarifai/gcp.generate.gemma-7b-it",
294
+ "clarifai/deci.decilm.deciLM-7B-instruct",
295
+ "clarifai/mistralai.completion.mistral-7B-Instruct",
296
+ "clarifai/gcp.generate.gemini-pro",
297
+ "clarifai/anthropic.completion.claude-v1",
298
+ "clarifai/anthropic.completion.claude-instant-1_2",
299
+ "clarifai/anthropic.completion.claude-instant",
300
+ "clarifai/anthropic.completion.claude-v2",
301
+ "clarifai/anthropic.completion.claude-2_1",
302
+ "clarifai/meta.Llama-2.codeLlama-70b-Python",
303
+ "clarifai/meta.Llama-2.codeLlama-70b-Instruct",
304
+ "clarifai/openai.completion.gpt-3_5-turbo-instruct",
305
+ "clarifai/meta.Llama-2.llama2-7b-chat",
306
+ "clarifai/meta.Llama-2.llama2-13b-chat",
307
+ "clarifai/meta.Llama-2.llama2-70b-chat",
308
+ "clarifai/openai.chat-completion.gpt-4-turbo",
309
+ "clarifai/microsoft.text-generation.phi-2",
310
+ "clarifai/meta.Llama-2.llama2-7b-chat-vllm",
311
+ "clarifai/upstage.solar.solar-10_7b-instruct",
312
+ "clarifai/openchat.openchat.openchat-3_5-1210",
313
+ "clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B",
314
+ "clarifai/gcp.generate.text-bison",
315
+ "clarifai/meta.Llama-2.llamaGuard-7b",
316
+ "clarifai/fblgit.una-cybertron.una-cybertron-7b-v2",
317
+ "clarifai/openai.chat-completion.GPT-4",
318
+ "clarifai/openai.chat-completion.GPT-3_5-turbo",
319
+ "clarifai/ai21.complete.Jurassic2-Grande",
320
+ "clarifai/ai21.complete.Jurassic2-Grande-Instruct",
321
+ "clarifai/ai21.complete.Jurassic2-Jumbo-Instruct",
322
+ "clarifai/ai21.complete.Jurassic2-Jumbo",
323
+ "clarifai/ai21.complete.Jurassic2-Large",
324
+ "clarifai/cohere.generate.cohere-generate-command",
325
+ "clarifai/wizardlm.generate.wizardCoder-Python-34B",
326
+ "clarifai/wizardlm.generate.wizardLM-70B",
327
+ "clarifai/tiiuae.falcon.falcon-40b-instruct",
328
+ "clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat",
329
+ "clarifai/gcp.generate.code-gecko",
330
+ "clarifai/gcp.generate.code-bison",
331
+ "clarifai/mistralai.completion.mistral-7B-OpenOrca",
332
+ "clarifai/mistralai.completion.openHermes-2-mistral-7B",
333
+ "clarifai/wizardlm.generate.wizardLM-13B",
334
+ "clarifai/huggingface-research.zephyr.zephyr-7B-alpha",
335
+ "clarifai/wizardlm.generate.wizardCoder-15B",
336
+ "clarifai/microsoft.text-generation.phi-1_5",
337
+ "clarifai/databricks.Dolly-v2.dolly-v2-12b",
338
+ "clarifai/bigcode.code.StarCoder",
339
+ "clarifai/salesforce.xgen.xgen-7b-8k-instruct",
340
+ "clarifai/mosaicml.mpt.mpt-7b-instruct",
341
+ "clarifai/anthropic.completion.claude-3-opus",
342
+ "clarifai/anthropic.completion.claude-3-sonnet",
343
+ "clarifai/gcp.generate.gemini-1_5-pro",
344
+ "clarifai/gcp.generate.imagen-2",
345
+ "clarifai/salesforce.blip.general-english-image-caption-blip-2",
346
+ ]
347
+
348
+
349
+ huggingface_models: List = [
350
+ "meta-llama/Llama-2-7b-hf",
351
+ "meta-llama/Llama-2-7b-chat-hf",
352
+ "meta-llama/Llama-2-13b-hf",
353
+ "meta-llama/Llama-2-13b-chat-hf",
354
+ "meta-llama/Llama-2-70b-hf",
355
+ "meta-llama/Llama-2-70b-chat-hf",
356
+ "meta-llama/Llama-2-7b",
357
+ "meta-llama/Llama-2-7b-chat",
358
+ "meta-llama/Llama-2-13b",
359
+ "meta-llama/Llama-2-13b-chat",
360
+ "meta-llama/Llama-2-70b",
361
+ "meta-llama/Llama-2-70b-chat",
362
+ ] # these have been tested on extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/providers
363
+ empower_models = [
364
+ "empower/empower-functions",
365
+ "empower/empower-functions-small",
366
+ ]
367
+
368
+ together_ai_models: List = [
369
+ # llama llms - chat
370
+ "togethercomputer/llama-2-70b-chat",
371
+ # llama llms - language / instruct
372
+ "togethercomputer/llama-2-70b",
373
+ "togethercomputer/LLaMA-2-7B-32K",
374
+ "togethercomputer/Llama-2-7B-32K-Instruct",
375
+ "togethercomputer/llama-2-7b",
376
+ # falcon llms
377
+ "togethercomputer/falcon-40b-instruct",
378
+ "togethercomputer/falcon-7b-instruct",
379
+ # alpaca
380
+ "togethercomputer/alpaca-7b",
381
+ # chat llms
382
+ "HuggingFaceH4/starchat-alpha",
383
+ # code llms
384
+ "togethercomputer/CodeLlama-34b",
385
+ "togethercomputer/CodeLlama-34b-Instruct",
386
+ "togethercomputer/CodeLlama-34b-Python",
387
+ "defog/sqlcoder",
388
+ "NumbersStation/nsql-llama-2-7B",
389
+ "WizardLM/WizardCoder-15B-V1.0",
390
+ "WizardLM/WizardCoder-Python-34B-V1.0",
391
+ # language llms
392
+ "NousResearch/Nous-Hermes-Llama2-13b",
393
+ "Austism/chronos-hermes-13b",
394
+ "upstage/SOLAR-0-70b-16bit",
395
+ "WizardLM/WizardLM-70B-V1.0",
396
+ ] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
397
+
398
+
399
+ baseten_models: List = [
400
+ "qvv0xeq",
401
+ "q841o8w",
402
+ "31dxrj3",
403
+ ] # FALCON 7B # WizardLM # Mosaic ML
404
+
405
+ BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
406
+ "cohere",
407
+ "anthropic",
408
+ "mistral",
409
+ "amazon",
410
+ "meta",
411
+ "llama",
412
+ "ai21",
413
+ "nova",
414
+ "deepseek_r1",
415
+ ]
416
+
417
+ open_ai_embedding_models: List = ["text-embedding-ada-002"]
418
+ cohere_embedding_models: List = [
419
+ "embed-english-v3.0",
420
+ "embed-english-light-v3.0",
421
+ "embed-multilingual-v3.0",
422
+ "embed-english-v2.0",
423
+ "embed-english-light-v2.0",
424
+ "embed-multilingual-v2.0",
425
+ ]
426
+ bedrock_embedding_models: List = [
427
+ "amazon.titan-embed-text-v1",
428
+ "cohere.embed-english-v3",
429
+ "cohere.embed-multilingual-v3",
430
+ ]
431
+
432
+ known_tokenizer_config = {
433
+ "mistralai/Mistral-7B-Instruct-v0.1": {
434
+ "tokenizer": {
435
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
436
+ "bos_token": "<s>",
437
+ "eos_token": "</s>",
438
+ },
439
+ "status": "success",
440
+ },
441
+ "meta-llama/Meta-Llama-3-8B-Instruct": {
442
+ "tokenizer": {
443
+ "chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
444
+ "bos_token": "<|begin_of_text|>",
445
+ "eos_token": "",
446
+ },
447
+ "status": "success",
448
+ },
449
+ "deepseek-r1/deepseek-r1-7b-instruct": {
450
+ "tokenizer": {
451
+ "add_bos_token": True,
452
+ "add_eos_token": False,
453
+ "bos_token": {
454
+ "__type": "AddedToken",
455
+ "content": "<|begin▁of▁sentence|>",
456
+ "lstrip": False,
457
+ "normalized": True,
458
+ "rstrip": False,
459
+ "single_word": False,
460
+ },
461
+ "clean_up_tokenization_spaces": False,
462
+ "eos_token": {
463
+ "__type": "AddedToken",
464
+ "content": "<|end▁of▁sentence|>",
465
+ "lstrip": False,
466
+ "normalized": True,
467
+ "rstrip": False,
468
+ "single_word": False,
469
+ },
470
+ "legacy": True,
471
+ "model_max_length": 16384,
472
+ "pad_token": {
473
+ "__type": "AddedToken",
474
+ "content": "<|end▁of▁sentence|>",
475
+ "lstrip": False,
476
+ "normalized": True,
477
+ "rstrip": False,
478
+ "single_word": False,
479
+ },
480
+ "sp_model_kwargs": {},
481
+ "unk_token": None,
482
+ "tokenizer_class": "LlamaTokenizerFast",
483
+ "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}",
484
+ },
485
+ "status": "success",
486
+ },
487
+ }
488
+
489
+
490
+ OPENAI_FINISH_REASONS = ["stop", "length", "function_call", "content_filter", "null"]
491
+ HUMANLOOP_PROMPT_CACHE_TTL_SECONDS = 60 # 1 minute
492
+ RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when converting response format to tool call
493
+
494
+ ########################### Logging Callback Constants ###########################
495
+ AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
496
+ PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
497
+ MCP_TOOL_NAME_PREFIX = "mcp_tool"
498
+
499
+ ########################### LiteLLM Proxy Specific Constants ###########################
500
+ ########################################################################################
501
+ MAX_SPENDLOG_ROWS_TO_QUERY = (
502
+ 1_000_000 # if spendLogs has more than 1M rows, do not query the DB
503
+ )
504
+ DEFAULT_SOFT_BUDGET = (
505
+ 50.0 # by default all litellm proxy keys have a soft budget of 50.0
506
+ )
507
+ # makes it clear this is a rate limit error for a litellm virtual key
508
+ RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
509
+
510
+ # pass through route constants
511
+ BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
512
+ "agents/",
513
+ "knowledgebases/",
514
+ "flows/",
515
+ "retrieveAndGenerate/",
516
+ "rerank/",
517
+ "generateQuery/",
518
+ "optimize-prompt/",
519
+ ]
520
+
521
+ BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600 # 1 hour
522
+ BATCH_STATUS_POLL_MAX_ATTEMPTS = 24 # for 24 hours
523
+
524
+ HEALTH_CHECK_TIMEOUT_SECONDS = 60 # 60 seconds
525
+
526
+ UI_SESSION_TOKEN_TEAM_ID = "litellm-dashboard"
527
+ LITELLM_PROXY_ADMIN_NAME = "default_user_id"
528
+
529
+ ########################### DB CRON JOB NAMES ###########################
530
+ DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
531
+ PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job"
532
+ DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
533
+ PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
534
+ PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
535
+ PROXY_BATCH_WRITE_AT = 10 # in seconds
536
+ DEFAULT_HEALTH_CHECK_INTERVAL = 300 # 5 minutes
537
+ PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS = 9
538
+ DEFAULT_MODEL_CREATED_AT_TIME = 1677610602 # returns on `/models` endpoint
539
+ DEFAULT_SLACK_ALERTING_THRESHOLD = 300
540
+ MAX_TEAM_LIST_LIMIT = 20
541
+ DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD = 0.7
542
+ LENGTH_OF_LITELLM_GENERATED_KEY = 16
543
+ SECRET_MANAGER_REFRESH_INTERVAL = 86400
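The retry constants defined earlier in this file (INITIAL_RETRY_DELAY, MAX_RETRY_DELAY, JITTER) suggest an exponential backoff with jitter. The snippet below is only an illustration of how such values typically combine; it is not litellm's actual retry code.

```python
import random

INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0
JITTER = 0.75


def backoff_delay(attempt: int) -> float:
    """Delay in seconds before retry number `attempt` (0-indexed)."""
    base = min(INITIAL_RETRY_DELAY * (2**attempt), MAX_RETRY_DELAY)
    # scale by a random factor in [JITTER, 1.0] so clients don't retry in lockstep
    return base * random.uniform(JITTER, 1.0)


for attempt in range(5):
    print(attempt, round(backoff_delay(attempt), 3))
```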
litellm/cost.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "gpt-3.5-turbo-0613": 0.00015000000000000001,
3
+ "claude-2": 0.00016454,
4
+ "gpt-4-0613": 0.015408
5
+ }
litellm/cost_calculator.py ADDED
@@ -0,0 +1,1378 @@
1
+ # What is this?
2
+ ## File for 'response_cost' calculation in Logging
3
+ import time
4
+ from functools import lru_cache
5
+ from typing import Any, List, Literal, Optional, Tuple, Union, cast
6
+
7
+ from pydantic import BaseModel
8
+
9
+ import litellm
10
+ import litellm._logging
11
+ from litellm import verbose_logger
12
+ from litellm.constants import (
13
+ DEFAULT_MAX_LRU_CACHE_SIZE,
14
+ DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND,
15
+ )
16
+ from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
17
+ StandardBuiltInToolCostTracking,
18
+ )
19
+ from litellm.litellm_core_utils.llm_cost_calc.utils import (
20
+ _generic_cost_per_character,
21
+ generic_cost_per_token,
22
+ select_cost_metric_for_model,
23
+ )
24
+ from litellm.llms.anthropic.cost_calculation import (
25
+ cost_per_token as anthropic_cost_per_token,
26
+ )
27
+ from litellm.llms.azure.cost_calculation import (
28
+ cost_per_token as azure_openai_cost_per_token,
29
+ )
30
+ from litellm.llms.bedrock.image.cost_calculator import (
31
+ cost_calculator as bedrock_image_cost_calculator,
32
+ )
33
+ from litellm.llms.databricks.cost_calculator import (
34
+ cost_per_token as databricks_cost_per_token,
35
+ )
36
+ from litellm.llms.deepseek.cost_calculator import (
37
+ cost_per_token as deepseek_cost_per_token,
38
+ )
39
+ from litellm.llms.fireworks_ai.cost_calculator import (
40
+ cost_per_token as fireworks_ai_cost_per_token,
41
+ )
42
+ from litellm.llms.gemini.cost_calculator import cost_per_token as gemini_cost_per_token
43
+ from litellm.llms.openai.cost_calculation import (
44
+ cost_per_second as openai_cost_per_second,
45
+ )
46
+ from litellm.llms.openai.cost_calculation import cost_per_token as openai_cost_per_token
47
+ from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
48
+ from litellm.llms.vertex_ai.cost_calculator import (
49
+ cost_per_character as google_cost_per_character,
50
+ )
51
+ from litellm.llms.vertex_ai.cost_calculator import (
52
+ cost_per_token as google_cost_per_token,
53
+ )
54
+ from litellm.llms.vertex_ai.cost_calculator import cost_router as google_cost_router
55
+ from litellm.llms.vertex_ai.image_generation.cost_calculator import (
56
+ cost_calculator as vertex_ai_image_cost_calculator,
57
+ )
58
+ from litellm.responses.utils import ResponseAPILoggingUtils
59
+ from litellm.types.llms.openai import (
60
+ HttpxBinaryResponseContent,
61
+ ImageGenerationRequestQuality,
62
+ OpenAIModerationResponse,
63
+ OpenAIRealtimeStreamList,
64
+ OpenAIRealtimeStreamResponseBaseObject,
65
+ OpenAIRealtimeStreamSessionEvents,
66
+ ResponseAPIUsage,
67
+ ResponsesAPIResponse,
68
+ )
69
+ from litellm.types.rerank import RerankBilledUnits, RerankResponse
70
+ from litellm.types.utils import (
71
+ CallTypesLiteral,
72
+ LiteLLMRealtimeStreamLoggingObject,
73
+ LlmProviders,
74
+ LlmProvidersSet,
75
+ ModelInfo,
76
+ PassthroughCallTypes,
77
+ StandardBuiltInToolsParams,
78
+ Usage,
79
+ )
80
+ from litellm.utils import (
81
+ CallTypes,
82
+ CostPerToken,
83
+ EmbeddingResponse,
84
+ ImageResponse,
85
+ ModelResponse,
86
+ ProviderConfigManager,
87
+ TextCompletionResponse,
88
+ TranscriptionResponse,
89
+ _cached_get_model_info_helper,
90
+ token_counter,
91
+ )
92
+
93
+
94
+ def _cost_per_token_custom_pricing_helper(
95
+ prompt_tokens: float = 0,
96
+ completion_tokens: float = 0,
97
+ response_time_ms: Optional[float] = 0.0,
98
+ ### CUSTOM PRICING ###
99
+ custom_cost_per_token: Optional[CostPerToken] = None,
100
+ custom_cost_per_second: Optional[float] = None,
101
+ ) -> Optional[Tuple[float, float]]:
102
+ """Internal helper function for calculating cost, if custom pricing given"""
103
+ if custom_cost_per_token is None and custom_cost_per_second is None:
104
+ return None
105
+
106
+ if custom_cost_per_token is not None:
107
+ input_cost = custom_cost_per_token["input_cost_per_token"] * prompt_tokens
108
+ output_cost = custom_cost_per_token["output_cost_per_token"] * completion_tokens
109
+ return input_cost, output_cost
110
+ elif custom_cost_per_second is not None:
111
+ output_cost = custom_cost_per_second * response_time_ms / 1000 # type: ignore
112
+ return 0, output_cost
113
+
114
+ return None
115
+
116
+
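A minimal usage sketch of the helper above, assuming litellm is installed and using illustrative per-token rates; the dict mirrors the input_cost_per_token / output_cost_per_token keys the helper reads:

from litellm.cost_calculator import _cost_per_token_custom_pricing_helper

# illustrative rates only -- $0.50 / $1.50 per million tokens
custom_rates = {
    "input_cost_per_token": 0.5 / 1_000_000,
    "output_cost_per_token": 1.5 / 1_000_000,
}
costs = _cost_per_token_custom_pricing_helper(
    prompt_tokens=1_000,
    completion_tokens=200,
    custom_cost_per_token=custom_rates,  # type: ignore[arg-type]
)
print(costs)  # ~ (0.0005, 0.0003): input and output cost in USD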
117
+ def cost_per_token( # noqa: PLR0915
118
+ model: str = "",
119
+ prompt_tokens: int = 0,
120
+ completion_tokens: int = 0,
121
+ response_time_ms: Optional[float] = 0.0,
122
+ custom_llm_provider: Optional[str] = None,
123
+ region_name=None,
124
+ ### CHARACTER PRICING ###
125
+ prompt_characters: Optional[int] = None,
126
+ completion_characters: Optional[int] = None,
127
+ ### PROMPT CACHING PRICING ### - used for anthropic
128
+ cache_creation_input_tokens: Optional[int] = 0,
129
+ cache_read_input_tokens: Optional[int] = 0,
130
+ ### CUSTOM PRICING ###
131
+ custom_cost_per_token: Optional[CostPerToken] = None,
132
+ custom_cost_per_second: Optional[float] = None,
133
+ ### NUMBER OF QUERIES ###
134
+ number_of_queries: Optional[int] = None,
135
+ ### USAGE OBJECT ###
136
+ usage_object: Optional[Usage] = None, # just read the usage object if provided
137
+ ### BILLED UNITS ###
138
+ rerank_billed_units: Optional[RerankBilledUnits] = None,
139
+ ### CALL TYPE ###
140
+ call_type: CallTypesLiteral = "completion",
141
+ audio_transcription_file_duration: float = 0.0, # for audio transcription calls - the file time in seconds
142
+ ) -> Tuple[float, float]: # type: ignore
143
+ """
144
+ Calculates the cost per token for a given model, prompt tokens, and completion tokens.
145
+
146
+ Parameters:
147
+ model (str): The name of the model to use. Default is ""
148
+ prompt_tokens (int): The number of tokens in the prompt.
149
+ completion_tokens (int): The number of tokens in the completion.
150
+ response_time (float): The amount of time, in milliseconds, it took the call to complete.
151
+ prompt_characters (float): The number of characters in the prompt. Used for vertex ai cost calculation.
152
+ completion_characters (float): The number of characters in the completion response. Used for vertex ai cost calculation.
153
+ custom_llm_provider (str): The llm provider to whom the call was made (see init.py for full list)
154
+ custom_cost_per_token: Optional[CostPerToken]: the cost per input + output token for the llm api call.
155
+ custom_cost_per_second: Optional[float]: the cost per second for the llm api call.
156
+ call_type: Optional[str]: the call type
157
+
158
+ Returns:
159
+ tuple: A tuple containing the cost in USD dollars for prompt tokens and completion tokens, respectively.
160
+ """
161
+ if model is None:
162
+ raise Exception("Invalid arg. Model cannot be none.")
163
+
164
+ ## RECONSTRUCT USAGE BLOCK ##
165
+ if usage_object is not None:
166
+ usage_block = usage_object
167
+ else:
168
+ usage_block = Usage(
169
+ prompt_tokens=prompt_tokens,
170
+ completion_tokens=completion_tokens,
171
+ total_tokens=prompt_tokens + completion_tokens,
172
+ cache_creation_input_tokens=cache_creation_input_tokens,
173
+ cache_read_input_tokens=cache_read_input_tokens,
174
+ )
175
+
176
+ ## CUSTOM PRICING ##
177
+ response_cost = _cost_per_token_custom_pricing_helper(
178
+ prompt_tokens=prompt_tokens,
179
+ completion_tokens=completion_tokens,
180
+ response_time_ms=response_time_ms,
181
+ custom_cost_per_second=custom_cost_per_second,
182
+ custom_cost_per_token=custom_cost_per_token,
183
+ )
184
+
185
+ if response_cost is not None:
186
+ return response_cost[0], response_cost[1]
187
+
188
+ # given
189
+ prompt_tokens_cost_usd_dollar: float = 0
190
+ completion_tokens_cost_usd_dollar: float = 0
191
+ model_cost_ref = litellm.model_cost
192
+ model_with_provider = model
193
+ if custom_llm_provider is not None:
194
+ model_with_provider = custom_llm_provider + "/" + model
195
+ if region_name is not None:
196
+ model_with_provider_and_region = (
197
+ f"{custom_llm_provider}/{region_name}/{model}"
198
+ )
199
+ if (
200
+ model_with_provider_and_region in model_cost_ref
201
+ ): # use region based pricing, if it's available
202
+ model_with_provider = model_with_provider_and_region
203
+ else:
204
+ _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
205
+ model_without_prefix = model
206
+ model_parts = model.split("/", 1)
207
+ if len(model_parts) > 1:
208
+ model_without_prefix = model_parts[1]
209
+ else:
210
+ model_without_prefix = model
211
+ """
212
+ Code block that formats model to lookup in litellm.model_cost
213
+ Option1. model = "bedrock/ap-northeast-1/anthropic.claude-instant-v1". This is the most accurate since it is region based. Should always be option 1
214
+ Option2. model = "openai/gpt-4" - model = provider/model
215
+ Option3. model = "anthropic.claude-3" - model = model
216
+ """
217
+ if (
218
+ model_with_provider in model_cost_ref
219
+ ): # Option 2. use model with provider, model = "openai/gpt-4"
220
+ model = model_with_provider
221
+ elif model in model_cost_ref: # Option 1. use model passed, model="gpt-4"
222
+ model = model
223
+ elif (
224
+ model_without_prefix in model_cost_ref
225
+ ): # Option 3. if user passed model="bedrock/anthropic.claude-3", use model="anthropic.claude-3"
226
+ model = model_without_prefix
227
+
228
+ # see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
229
+ if call_type == "speech" or call_type == "aspeech":
230
+ speech_model_info = litellm.get_model_info(
231
+ model=model_without_prefix, custom_llm_provider=custom_llm_provider
232
+ )
233
+ cost_metric = select_cost_metric_for_model(speech_model_info)
234
+ prompt_cost: float = 0.0
235
+ completion_cost: float = 0.0
236
+ if cost_metric == "cost_per_character":
237
+ if prompt_characters is None:
238
+ raise ValueError(
239
+ "prompt_characters must be provided for tts calls. prompt_characters={}, model={}, custom_llm_provider={}, call_type={}".format(
240
+ prompt_characters,
241
+ model,
242
+ custom_llm_provider,
243
+ call_type,
244
+ )
245
+ )
246
+ _prompt_cost, _completion_cost = _generic_cost_per_character(
247
+ model=model_without_prefix,
248
+ custom_llm_provider=custom_llm_provider,
249
+ prompt_characters=prompt_characters,
250
+ completion_characters=0,
251
+ custom_prompt_cost=None,
252
+ custom_completion_cost=0,
253
+ )
254
+ if _prompt_cost is None or _completion_cost is None:
255
+ raise ValueError(
256
+ "cost for tts call is None. prompt_cost={}, completion_cost={}, model={}, custom_llm_provider={}, prompt_characters={}, completion_characters={}".format(
257
+ _prompt_cost,
258
+ _completion_cost,
259
+ model_without_prefix,
260
+ custom_llm_provider,
261
+ prompt_characters,
262
+ completion_characters,
263
+ )
264
+ )
265
+ prompt_cost = _prompt_cost
266
+ completion_cost = _completion_cost
267
+ elif cost_metric == "cost_per_token":
268
+ prompt_cost, completion_cost = generic_cost_per_token(
269
+ model=model_without_prefix,
270
+ usage=usage_block,
271
+ custom_llm_provider=custom_llm_provider,
272
+ )
273
+
274
+ return prompt_cost, completion_cost
275
+ elif call_type == "arerank" or call_type == "rerank":
276
+ return rerank_cost(
277
+ model=model,
278
+ custom_llm_provider=custom_llm_provider,
279
+ billed_units=rerank_billed_units,
280
+ )
281
+ elif (
282
+ call_type == "aretrieve_batch"
283
+ or call_type == "retrieve_batch"
284
+ or call_type == CallTypes.aretrieve_batch
285
+ or call_type == CallTypes.retrieve_batch
286
+ ):
287
+ return batch_cost_calculator(
288
+ usage=usage_block, model=model, custom_llm_provider=custom_llm_provider
289
+ )
290
+ elif call_type == "atranscription" or call_type == "transcription":
291
+ return openai_cost_per_second(
292
+ model=model,
293
+ custom_llm_provider=custom_llm_provider,
294
+ duration=audio_transcription_file_duration,
295
+ )
296
+ elif custom_llm_provider == "vertex_ai":
297
+ cost_router = google_cost_router(
298
+ model=model_without_prefix,
299
+ custom_llm_provider=custom_llm_provider,
300
+ call_type=call_type,
301
+ )
302
+ if cost_router == "cost_per_character":
303
+ return google_cost_per_character(
304
+ model=model_without_prefix,
305
+ custom_llm_provider=custom_llm_provider,
306
+ prompt_characters=prompt_characters,
307
+ completion_characters=completion_characters,
308
+ usage=usage_block,
309
+ )
310
+ elif cost_router == "cost_per_token":
311
+ return google_cost_per_token(
312
+ model=model_without_prefix,
313
+ custom_llm_provider=custom_llm_provider,
314
+ usage=usage_block,
315
+ )
316
+ elif custom_llm_provider == "anthropic":
317
+ return anthropic_cost_per_token(model=model, usage=usage_block)
318
+ elif custom_llm_provider == "openai":
319
+ return openai_cost_per_token(model=model, usage=usage_block)
320
+ elif custom_llm_provider == "databricks":
321
+ return databricks_cost_per_token(model=model, usage=usage_block)
322
+ elif custom_llm_provider == "fireworks_ai":
323
+ return fireworks_ai_cost_per_token(model=model, usage=usage_block)
324
+ elif custom_llm_provider == "azure":
325
+ return azure_openai_cost_per_token(
326
+ model=model, usage=usage_block, response_time_ms=response_time_ms
327
+ )
328
+ elif custom_llm_provider == "gemini":
329
+ return gemini_cost_per_token(model=model, usage=usage_block)
330
+ elif custom_llm_provider == "deepseek":
331
+ return deepseek_cost_per_token(model=model, usage=usage_block)
332
+ else:
333
+ model_info = _cached_get_model_info_helper(
334
+ model=model, custom_llm_provider=custom_llm_provider
335
+ )
336
+
337
+ if model_info["input_cost_per_token"] > 0:
338
+ ## COST PER TOKEN ##
339
+ prompt_tokens_cost_usd_dollar = (
340
+ model_info["input_cost_per_token"] * prompt_tokens
341
+ )
342
+ elif (
343
+ model_info.get("input_cost_per_second", None) is not None
344
+ and response_time_ms is not None
345
+ ):
346
+ verbose_logger.debug(
347
+ "For model=%s - input_cost_per_second: %s; response time: %s",
348
+ model,
349
+ model_info.get("input_cost_per_second", None),
350
+ response_time_ms,
351
+ )
352
+ ## COST PER SECOND ##
353
+ prompt_tokens_cost_usd_dollar = (
354
+ model_info["input_cost_per_second"] * response_time_ms / 1000 # type: ignore
355
+ )
356
+
357
+ if model_info["output_cost_per_token"] > 0:
358
+ completion_tokens_cost_usd_dollar = (
359
+ model_info["output_cost_per_token"] * completion_tokens
360
+ )
361
+ elif (
362
+ model_info.get("output_cost_per_second", None) is not None
363
+ and response_time_ms is not None
364
+ ):
365
+ verbose_logger.debug(
366
+ "For model=%s - output_cost_per_second: %s; response time: %s",
367
+ model,
368
+ model_info.get("output_cost_per_second", None),
369
+ response_time_ms,
370
+ )
371
+ ## COST PER SECOND ##
372
+ completion_tokens_cost_usd_dollar = (
373
+ model_info["output_cost_per_second"] * response_time_ms / 1000 # type: ignore
374
+ )
375
+
376
+ verbose_logger.debug(
377
+ "Returned custom cost for model=%s - prompt_tokens_cost_usd_dollar: %s, completion_tokens_cost_usd_dollar: %s",
378
+ model,
379
+ prompt_tokens_cost_usd_dollar,
380
+ completion_tokens_cost_usd_dollar,
381
+ )
382
+ return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
383
+
384
+
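A usage sketch of cost_per_token, assuming the model name resolves in litellm.model_cost (the rates, and therefore the printed total, depend on the live cost map):

from litellm.cost_calculator import cost_per_token

prompt_cost, completion_cost = cost_per_token(
    model="gpt-3.5-turbo",  # assumed to be present in litellm.model_cost
    custom_llm_provider="openai",
    prompt_tokens=1_000,
    completion_tokens=200,
)
print(f"total: ${prompt_cost + completion_cost:.6f}")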
385
+ def get_replicate_completion_pricing(completion_response: dict, total_time=0.0):
386
+ # see https://replicate.com/pricing
387
+ # for all litellm currently supported LLMs, almost all requests go to a100_80gb
388
+ a100_80gb_price_per_second_public = DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND # assume all calls sent to A100 80GB for now
389
+ if total_time == 0.0: # total time is in ms
390
+ start_time = completion_response.get("created", time.time())
391
+ end_time = getattr(completion_response, "ended", time.time())
392
+ total_time = end_time - start_time
393
+
394
+ return a100_80gb_price_per_second_public * total_time / 1000
395
+
396
+
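The Replicate path above bills on wall-clock time rather than tokens; a rough worked example with a placeholder per-second rate (the real value is DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND from litellm.constants):

# placeholder A100 80GB rate, for illustration only
assumed_price_per_second = 0.0014
total_time_ms = 8_500  # request ran for 8.5 seconds

# mirrors the return statement above: price * total_time / 1000
estimated_cost = assumed_price_per_second * total_time_ms / 1000
print(f"~${estimated_cost:.4f}")  # ~$0.0119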
397
+ def has_hidden_params(obj: Any) -> bool:
398
+ return hasattr(obj, "_hidden_params")
399
+
400
+
401
+ def _get_provider_for_cost_calc(
402
+ model: Optional[str],
403
+ custom_llm_provider: Optional[str] = None,
404
+ ) -> Optional[str]:
405
+ if custom_llm_provider is not None:
406
+ return custom_llm_provider
407
+ if model is None:
408
+ return None
409
+ try:
410
+ _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
411
+ except Exception as e:
412
+ verbose_logger.debug(
413
+ f"litellm.cost_calculator.py::_get_provider_for_cost_calc() - Error inferring custom_llm_provider - {str(e)}"
414
+ )
415
+ return None
416
+
417
+ return custom_llm_provider
418
+
419
+
420
+ def _select_model_name_for_cost_calc(
421
+ model: Optional[str],
422
+ completion_response: Optional[Any],
423
+ base_model: Optional[str] = None,
424
+ custom_pricing: Optional[bool] = None,
425
+ custom_llm_provider: Optional[str] = None,
426
+ router_model_id: Optional[str] = None,
427
+ ) -> Optional[str]:
428
+ """
429
+ 1. If custom pricing is true, return received model name
430
+ 2. If base_model is set (e.g. for azure models), return that
431
+ 3. If completion response has model set return that
432
+ 4. Check if model is passed in return that
433
+ """
434
+
435
+ return_model: Optional[str] = None
436
+ region_name: Optional[str] = None
437
+ custom_llm_provider = _get_provider_for_cost_calc(
438
+ model=model, custom_llm_provider=custom_llm_provider
439
+ )
440
+
441
+ completion_response_model: Optional[str] = None
442
+ if completion_response is not None:
443
+ if isinstance(completion_response, BaseModel):
444
+ completion_response_model = getattr(completion_response, "model", None)
445
+ elif isinstance(completion_response, dict):
446
+ completion_response_model = completion_response.get("model", None)
447
+ hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
448
+
449
+ if custom_pricing is True:
450
+ if router_model_id is not None and router_model_id in litellm.model_cost:
451
+ return_model = router_model_id
452
+ else:
453
+ return_model = model
454
+
455
+ if base_model is not None:
456
+ return_model = base_model
457
+
458
+ if completion_response_model is None and hidden_params is not None:
459
+ if (
460
+ hidden_params.get("model", None) is not None
461
+ and len(hidden_params["model"]) > 0
462
+ ):
463
+ return_model = hidden_params.get("model", model)
464
+ if hidden_params is not None and hidden_params.get("region_name", None) is not None:
465
+ region_name = hidden_params.get("region_name", None)
466
+
467
+ if return_model is None and completion_response_model is not None:
468
+ return_model = completion_response_model
469
+
470
+ if return_model is None and model is not None:
471
+ return_model = model
472
+
473
+ if (
474
+ return_model is not None
475
+ and custom_llm_provider is not None
476
+ and not _model_contains_known_llm_provider(return_model)
477
+ ): # add provider prefix if not already present, to match model_cost
478
+ if region_name is not None:
479
+ return_model = f"{custom_llm_provider}/{region_name}/{return_model}"
480
+ else:
481
+ return_model = f"{custom_llm_provider}/{return_model}"
482
+
483
+ return return_model
484
+
485
+
486
+ @lru_cache(maxsize=DEFAULT_MAX_LRU_CACHE_SIZE)
487
+ def _model_contains_known_llm_provider(model: str) -> bool:
488
+ """
489
+ Check if the model contains a known llm provider
490
+ """
491
+ _provider_prefix = model.split("/")[0]
492
+ return _provider_prefix in LlmProvidersSet
493
+
494
+
495
+ def _get_usage_object(
496
+ completion_response: Any,
497
+ ) -> Optional[Usage]:
498
+ usage_obj = cast(
499
+ Union[Usage, ResponseAPIUsage, dict, BaseModel],
500
+ (
501
+ completion_response.get("usage")
502
+ if isinstance(completion_response, dict)
503
+ else getattr(completion_response, "get", lambda x: None)("usage")
504
+ ),
505
+ )
506
+
507
+ if usage_obj is None:
508
+ return None
509
+ if isinstance(usage_obj, Usage):
510
+ return usage_obj
511
+ elif (
512
+ usage_obj is not None
513
+ and (isinstance(usage_obj, dict) or isinstance(usage_obj, ResponseAPIUsage))
514
+ and ResponseAPILoggingUtils._is_response_api_usage(usage_obj)
515
+ ):
516
+ return ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
517
+ usage_obj
518
+ )
519
+ elif isinstance(usage_obj, dict):
520
+ return Usage(**usage_obj)
521
+ elif isinstance(usage_obj, BaseModel):
522
+ return Usage(**usage_obj.model_dump())
523
+ else:
524
+ verbose_logger.debug(
525
+ f"Unknown usage object type: {type(usage_obj)}, usage_obj: {usage_obj}"
526
+ )
527
+ return None
528
+
529
+
530
+ def _is_known_usage_objects(usage_obj):
531
+ """Returns True if the usage obj is a known Usage type"""
532
+ return isinstance(usage_obj, litellm.Usage) or isinstance(
533
+ usage_obj, ResponseAPIUsage
534
+ )
535
+
536
+
537
+ def _infer_call_type(
538
+ call_type: Optional[CallTypesLiteral], completion_response: Any
539
+ ) -> Optional[CallTypesLiteral]:
540
+ if call_type is not None:
541
+ return call_type
542
+
543
+ if completion_response is None:
544
+ return None
545
+
546
+ if isinstance(completion_response, ModelResponse):
547
+ return "completion"
548
+ elif isinstance(completion_response, EmbeddingResponse):
549
+ return "embedding"
550
+ elif isinstance(completion_response, TranscriptionResponse):
551
+ return "transcription"
552
+ elif isinstance(completion_response, HttpxBinaryResponseContent):
553
+ return "speech"
554
+ elif isinstance(completion_response, RerankResponse):
555
+ return "rerank"
556
+ elif isinstance(completion_response, ImageResponse):
557
+ return "image_generation"
558
+ elif isinstance(completion_response, TextCompletionResponse):
559
+ return "text_completion"
560
+
561
+ return call_type
562
+
563
+
564
+ def completion_cost( # noqa: PLR0915
565
+ completion_response=None,
566
+ model: Optional[str] = None,
567
+ prompt="",
568
+ messages: List = [],
569
+ completion="",
570
+ total_time: Optional[float] = 0.0, # used for replicate, sagemaker
571
+ call_type: Optional[CallTypesLiteral] = None,
572
+ ### REGION ###
573
+ custom_llm_provider=None,
574
+ region_name=None, # used for bedrock pricing
575
+ ### IMAGE GEN ###
576
+ size: Optional[str] = None,
577
+ quality: Optional[str] = None,
578
+ n: Optional[int] = None, # number of images
579
+ ### CUSTOM PRICING ###
580
+ custom_cost_per_token: Optional[CostPerToken] = None,
581
+ custom_cost_per_second: Optional[float] = None,
582
+ optional_params: Optional[dict] = None,
583
+ custom_pricing: Optional[bool] = None,
584
+ base_model: Optional[str] = None,
585
+ standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
586
+ litellm_model_name: Optional[str] = None,
587
+ router_model_id: Optional[str] = None,
588
+ ) -> float:
589
+ """
590
+ Calculate the cost of a given completion call for GPT-3.5-turbo, llama2, or any litellm-supported llm.
591
+
592
+ Parameters:
593
+ completion_response (litellm.ModelResponses): [Required] The response received from a LiteLLM completion request.
594
+
595
+ [OPTIONAL PARAMS]
596
+ model (str): Optional. The name of the language model used in the completion calls
597
+ prompt (str): Optional. The input prompt passed to the llm
598
+ completion (str): Optional. The output completion text from the llm
599
+ total_time (float, int): Optional. (Only used for Replicate LLMs) The total time used for the request in seconds
600
+ custom_cost_per_token: Optional[CostPerToken]: the cost per input + output token for the llm api call.
601
+ custom_cost_per_second: Optional[float]: the cost per second for the llm api call.
602
+
603
+ Returns:
604
+ float: The cost in USD dollars for the completion based on the provided parameters.
605
+
606
+ Exceptions:
607
+ Raises exception if model not in the litellm model cost map. Register model, via custom pricing or PR - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
608
+
609
+
610
+ Note:
611
+ - If completion_response is provided, the function extracts token information and the model name from it.
612
+ - If completion_response is not provided, the function calculates token counts based on the model and input text.
613
+ - The cost is calculated based on the model, prompt tokens, and completion tokens.
614
+ - For certain models containing "togethercomputer" in the name, prices are based on the model size.
615
+ - For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
616
+ """
617
+ try:
618
+ call_type = _infer_call_type(call_type, completion_response) or "completion"
619
+
620
+ if (
621
+ (call_type == "aimage_generation" or call_type == "image_generation")
622
+ and model is not None
623
+ and isinstance(model, str)
624
+ and len(model) == 0
625
+ and custom_llm_provider == "azure"
626
+ ):
627
+ model = "dall-e-2" # for dall-e-2, azure expects an empty model name
628
+ # Handle Inputs to completion_cost
629
+ prompt_tokens = 0
630
+ prompt_characters: Optional[int] = None
631
+ completion_tokens = 0
632
+ completion_characters: Optional[int] = None
633
+ cache_creation_input_tokens: Optional[int] = None
634
+ cache_read_input_tokens: Optional[int] = None
635
+ audio_transcription_file_duration: float = 0.0
636
+ cost_per_token_usage_object: Optional[Usage] = _get_usage_object(
637
+ completion_response=completion_response
638
+ )
639
+ rerank_billed_units: Optional[RerankBilledUnits] = None
640
+
641
+ selected_model = _select_model_name_for_cost_calc(
642
+ model=model,
643
+ completion_response=completion_response,
644
+ custom_llm_provider=custom_llm_provider,
645
+ custom_pricing=custom_pricing,
646
+ base_model=base_model,
647
+ router_model_id=router_model_id,
648
+ )
649
+
650
+ potential_model_names = [selected_model]
651
+ if model is not None:
652
+ potential_model_names.append(model)
653
+ for idx, model in enumerate(potential_model_names):
654
+ try:
655
+ verbose_logger.info(
656
+ f"selected model name for cost calculation: {model}"
657
+ )
658
+
659
+ if completion_response is not None and (
660
+ isinstance(completion_response, BaseModel)
661
+ or isinstance(completion_response, dict)
662
+ ): # tts returns a custom class
663
+ if isinstance(completion_response, dict):
664
+ usage_obj: Optional[
665
+ Union[dict, Usage]
666
+ ] = completion_response.get("usage", {})
667
+ else:
668
+ usage_obj = getattr(completion_response, "usage", {})
669
+ if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects(
670
+ usage_obj=usage_obj
671
+ ):
672
+ setattr(
673
+ completion_response,
674
+ "usage",
675
+ litellm.Usage(**usage_obj.model_dump()),
676
+ )
677
+ if usage_obj is None:
678
+ _usage = {}
679
+ elif isinstance(usage_obj, BaseModel):
680
+ _usage = usage_obj.model_dump()
681
+ else:
682
+ _usage = usage_obj
683
+
684
+ if ResponseAPILoggingUtils._is_response_api_usage(_usage):
685
+ _usage = ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
686
+ _usage
687
+ ).model_dump()
688
+
689
+ # get input/output tokens from completion_response
690
+ prompt_tokens = _usage.get("prompt_tokens", 0)
691
+ completion_tokens = _usage.get("completion_tokens", 0)
692
+ cache_creation_input_tokens = _usage.get(
693
+ "cache_creation_input_tokens", 0
694
+ )
695
+ cache_read_input_tokens = _usage.get("cache_read_input_tokens", 0)
696
+ if (
697
+ "prompt_tokens_details" in _usage
698
+ and _usage["prompt_tokens_details"] != {}
699
+ and _usage["prompt_tokens_details"]
700
+ ):
701
+ prompt_tokens_details = _usage.get("prompt_tokens_details", {})
702
+ cache_read_input_tokens = prompt_tokens_details.get(
703
+ "cached_tokens", 0
704
+ )
705
+
706
+ total_time = getattr(completion_response, "_response_ms", 0)
707
+
708
+ hidden_params = getattr(completion_response, "_hidden_params", None)
709
+ if hidden_params is not None:
710
+ custom_llm_provider = hidden_params.get(
711
+ "custom_llm_provider", custom_llm_provider or None
712
+ )
713
+ region_name = hidden_params.get("region_name", region_name)
714
+ size = hidden_params.get("optional_params", {}).get(
715
+ "size", "1024-x-1024"
716
+ ) # openai default
717
+ quality = hidden_params.get("optional_params", {}).get(
718
+ "quality", "standard"
719
+ ) # openai default
720
+ n = hidden_params.get("optional_params", {}).get(
721
+ "n", 1
722
+ ) # openai default
723
+ else:
724
+ if model is None:
725
+ raise ValueError(
726
+ f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
727
+ )
728
+ if len(messages) > 0:
729
+ prompt_tokens = token_counter(model=model, messages=messages)
730
+ elif len(prompt) > 0:
731
+ prompt_tokens = token_counter(model=model, text=prompt)
732
+ completion_tokens = token_counter(model=model, text=completion)
733
+
734
+ if model is None:
735
+ raise ValueError(
736
+ f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
737
+ )
738
+ if custom_llm_provider is None:
739
+ try:
740
+ model, custom_llm_provider, _, _ = litellm.get_llm_provider(
741
+ model=model
742
+ ) # strip the llm provider from the model name -> for image gen cost calculation
743
+ except Exception as e:
744
+ verbose_logger.debug(
745
+ "litellm.cost_calculator.py::completion_cost() - Error inferring custom_llm_provider - {}".format(
746
+ str(e)
747
+ )
748
+ )
749
+ if (
750
+ call_type == CallTypes.image_generation.value
751
+ or call_type == CallTypes.aimage_generation.value
752
+ or call_type
753
+ == PassthroughCallTypes.passthrough_image_generation.value
754
+ ):
755
+ ### IMAGE GENERATION COST CALCULATION ###
756
+ if custom_llm_provider == "vertex_ai":
757
+ if isinstance(completion_response, ImageResponse):
758
+ return vertex_ai_image_cost_calculator(
759
+ model=model,
760
+ image_response=completion_response,
761
+ )
762
+ elif custom_llm_provider == "bedrock":
763
+ if isinstance(completion_response, ImageResponse):
764
+ return bedrock_image_cost_calculator(
765
+ model=model,
766
+ size=size,
767
+ image_response=completion_response,
768
+ optional_params=optional_params,
769
+ )
770
+ raise TypeError(
771
+ "completion_response must be of type ImageResponse for bedrock image cost calculation"
772
+ )
773
+ else:
774
+ return default_image_cost_calculator(
775
+ model=model,
776
+ quality=quality,
777
+ custom_llm_provider=custom_llm_provider,
778
+ n=n,
779
+ size=size,
780
+ optional_params=optional_params,
781
+ )
782
+ elif (
783
+ call_type == CallTypes.speech.value
784
+ or call_type == CallTypes.aspeech.value
785
+ ):
786
+ prompt_characters = litellm.utils._count_characters(text=prompt)
787
+ elif (
788
+ call_type == CallTypes.atranscription.value
789
+ or call_type == CallTypes.transcription.value
790
+ ):
791
+ audio_transcription_file_duration = getattr(
792
+ completion_response, "duration", 0.0
793
+ )
794
+ elif (
795
+ call_type == CallTypes.rerank.value
796
+ or call_type == CallTypes.arerank.value
797
+ ):
798
+ if completion_response is not None and isinstance(
799
+ completion_response, RerankResponse
800
+ ):
801
+ meta_obj = completion_response.meta
802
+ if meta_obj is not None:
803
+ billed_units = meta_obj.get("billed_units", {}) or {}
804
+ else:
805
+ billed_units = {}
806
+
807
+ rerank_billed_units = RerankBilledUnits(
808
+ search_units=billed_units.get("search_units"),
809
+ total_tokens=billed_units.get("total_tokens"),
810
+ )
811
+
812
+ search_units = (
813
+ billed_units.get("search_units") or 1
814
+ ) # cohere charges per request by default.
815
+ completion_tokens = search_units
816
+ elif call_type == CallTypes.arealtime.value and isinstance(
817
+ completion_response, LiteLLMRealtimeStreamLoggingObject
818
+ ):
819
+ if (
820
+ cost_per_token_usage_object is None
821
+ or custom_llm_provider is None
822
+ ):
823
+ raise ValueError(
824
+ "usage object and custom_llm_provider must be provided for realtime stream cost calculation. Got cost_per_token_usage_object={}, custom_llm_provider={}".format(
825
+ cost_per_token_usage_object,
826
+ custom_llm_provider,
827
+ )
828
+ )
829
+ return handle_realtime_stream_cost_calculation(
830
+ results=completion_response.results,
831
+ combined_usage_object=cost_per_token_usage_object,
832
+ custom_llm_provider=custom_llm_provider,
833
+ litellm_model_name=model,
834
+ )
835
+ # Calculate cost based on prompt_tokens, completion_tokens
836
+ if (
837
+ "togethercomputer" in model
838
+ or "together_ai" in model
839
+ or custom_llm_provider == "together_ai"
840
+ ):
841
+ # together ai prices based on size of llm
842
+ # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json
843
+
844
+ model = get_model_params_and_category(
845
+ model, call_type=CallTypes(call_type)
846
+ )
847
+
848
+ # replicate llms are calculate based on time for request running
849
+ # see https://replicate.com/pricing
850
+ elif (
851
+ model in litellm.replicate_models or "replicate" in model
852
+ ) and model not in litellm.model_cost:
853
+ # for unmapped replicate model, default to replicate's time tracking logic
854
+ return get_replicate_completion_pricing(completion_response, total_time) # type: ignore
855
+
856
+ if model is None:
857
+ raise ValueError(
858
+ f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
859
+ )
860
+
861
+ if (
862
+ custom_llm_provider is not None
863
+ and custom_llm_provider == "vertex_ai"
864
+ ):
865
+ # Calculate the prompt characters + response characters
866
+ if len(messages) > 0:
867
+ prompt_string = litellm.utils.get_formatted_prompt(
868
+ data={"messages": messages}, call_type="completion"
869
+ )
870
+
871
+ prompt_characters = litellm.utils._count_characters(
872
+ text=prompt_string
873
+ )
874
+ if completion_response is not None and isinstance(
875
+ completion_response, ModelResponse
876
+ ):
877
+ completion_string = litellm.utils.get_response_string(
878
+ response_obj=completion_response
879
+ )
880
+ completion_characters = litellm.utils._count_characters(
881
+ text=completion_string
882
+ )
883
+
884
+ (
885
+ prompt_tokens_cost_usd_dollar,
886
+ completion_tokens_cost_usd_dollar,
887
+ ) = cost_per_token(
888
+ model=model,
889
+ prompt_tokens=prompt_tokens,
890
+ completion_tokens=completion_tokens,
891
+ custom_llm_provider=custom_llm_provider,
892
+ response_time_ms=total_time,
893
+ region_name=region_name,
894
+ custom_cost_per_second=custom_cost_per_second,
895
+ custom_cost_per_token=custom_cost_per_token,
896
+ prompt_characters=prompt_characters,
897
+ completion_characters=completion_characters,
898
+ cache_creation_input_tokens=cache_creation_input_tokens,
899
+ cache_read_input_tokens=cache_read_input_tokens,
900
+ usage_object=cost_per_token_usage_object,
901
+ call_type=cast(CallTypesLiteral, call_type),
902
+ audio_transcription_file_duration=audio_transcription_file_duration,
903
+ rerank_billed_units=rerank_billed_units,
904
+ )
905
+ _final_cost = (
906
+ prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
907
+ )
908
+ _final_cost += (
909
+ StandardBuiltInToolCostTracking.get_cost_for_built_in_tools(
910
+ model=model,
911
+ response_object=completion_response,
912
+ standard_built_in_tools_params=standard_built_in_tools_params,
913
+ custom_llm_provider=custom_llm_provider,
914
+ )
915
+ )
916
+ return _final_cost
917
+ except Exception as e:
918
+ verbose_logger.debug(
919
+ "litellm.cost_calculator.py::completion_cost() - Error calculating cost for model={} - {}".format(
920
+ model, str(e)
921
+ )
922
+ )
923
+ if idx == len(potential_model_names) - 1:
924
+ raise e
925
+ raise Exception(
926
+ "Unable to calculate cost for received potential model names - {}".format(
927
+ potential_model_names
928
+ )
929
+ )
930
+ except Exception as e:
931
+ raise e
932
+
933
+
934
+ def get_response_cost_from_hidden_params(
935
+ hidden_params: Union[dict, BaseModel],
936
+ ) -> Optional[float]:
937
+ if isinstance(hidden_params, BaseModel):
938
+ _hidden_params_dict = hidden_params.model_dump()
939
+ else:
940
+ _hidden_params_dict = hidden_params
941
+
942
+ additional_headers = _hidden_params_dict.get("additional_headers", {})
943
+ if (
944
+ additional_headers
945
+ and "llm_provider-x-litellm-response-cost" in additional_headers
946
+ ):
947
+ response_cost = additional_headers["llm_provider-x-litellm-response-cost"]
948
+ if response_cost is None:
949
+ return None
950
+ return float(additional_headers["llm_provider-x-litellm-response-cost"])
951
+ return None
952
+
953
+
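A small sketch of the header-based short-circuit above, assuming an upstream provider (or LiteLLM proxy) already attached its computed cost to the response headers; the header value here is illustrative:

from litellm.cost_calculator import get_response_cost_from_hidden_params

hidden_params = {
    "additional_headers": {
        # header name checked by the function above; the value is made up
        "llm_provider-x-litellm-response-cost": "0.00042",
    }
}
print(get_response_cost_from_hidden_params(hidden_params))  # 0.00042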
954
+ def response_cost_calculator(
955
+ response_object: Union[
956
+ ModelResponse,
957
+ EmbeddingResponse,
958
+ ImageResponse,
959
+ TranscriptionResponse,
960
+ TextCompletionResponse,
961
+ HttpxBinaryResponseContent,
962
+ RerankResponse,
963
+ ResponsesAPIResponse,
964
+ LiteLLMRealtimeStreamLoggingObject,
965
+ OpenAIModerationResponse,
966
+ ],
967
+ model: str,
968
+ custom_llm_provider: Optional[str],
969
+ call_type: Literal[
970
+ "embedding",
971
+ "aembedding",
972
+ "completion",
973
+ "acompletion",
974
+ "atext_completion",
975
+ "text_completion",
976
+ "image_generation",
977
+ "aimage_generation",
978
+ "moderation",
979
+ "amoderation",
980
+ "atranscription",
981
+ "transcription",
982
+ "aspeech",
983
+ "speech",
984
+ "rerank",
985
+ "arerank",
986
+ ],
987
+ optional_params: dict,
988
+ cache_hit: Optional[bool] = None,
989
+ base_model: Optional[str] = None,
990
+ custom_pricing: Optional[bool] = None,
991
+ prompt: str = "",
992
+ standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
993
+ litellm_model_name: Optional[str] = None,
994
+ router_model_id: Optional[str] = None,
995
+ ) -> float:
996
+ """
997
+ Returns
998
+ - float: cost of response
999
+ """
1000
+ try:
1001
+ response_cost: float = 0.0
1002
+ if cache_hit is not None and cache_hit is True:
1003
+ response_cost = 0.0
1004
+ else:
1005
+ if isinstance(response_object, BaseModel):
1006
+ response_object._hidden_params["optional_params"] = optional_params
1007
+
1008
+ if hasattr(response_object, "_hidden_params"):
1009
+ provider_response_cost = get_response_cost_from_hidden_params(
1010
+ response_object._hidden_params
1011
+ )
1012
+ if provider_response_cost is not None:
1013
+ return provider_response_cost
1014
+
1015
+ response_cost = completion_cost(
1016
+ completion_response=response_object,
1017
+ model=model,
1018
+ call_type=call_type,
1019
+ custom_llm_provider=custom_llm_provider,
1020
+ optional_params=optional_params,
1021
+ custom_pricing=custom_pricing,
1022
+ base_model=base_model,
1023
+ prompt=prompt,
1024
+ standard_built_in_tools_params=standard_built_in_tools_params,
1025
+ litellm_model_name=litellm_model_name,
1026
+ router_model_id=router_model_id,
1027
+ )
1028
+ return response_cost
1029
+ except Exception as e:
1030
+ raise e
1031
+
1032
+
1033
+ def rerank_cost(
1034
+ model: str,
1035
+ custom_llm_provider: Optional[str],
1036
+ billed_units: Optional[RerankBilledUnits] = None,
1037
+ ) -> Tuple[float, float]:
1038
+ """
1039
+ Returns
1040
+ - Tuple[float, float]: cost of the rerank call as (prompt_cost, completion_cost); raises on error.
1041
+ """
1042
+ _, custom_llm_provider, _, _ = litellm.get_llm_provider(
1043
+ model=model, custom_llm_provider=custom_llm_provider
1044
+ )
1045
+
1046
+ try:
1047
+ config = ProviderConfigManager.get_provider_rerank_config(
1048
+ model=model,
1049
+ api_base=None,
1050
+ present_version_params=[],
1051
+ provider=LlmProviders(custom_llm_provider),
1052
+ )
1053
+
1054
+ try:
1055
+ model_info: Optional[ModelInfo] = litellm.get_model_info(
1056
+ model=model, custom_llm_provider=custom_llm_provider
1057
+ )
1058
+ except Exception:
1059
+ model_info = None
1060
+
1061
+ return config.calculate_rerank_cost(
1062
+ model=model,
1063
+ custom_llm_provider=custom_llm_provider,
1064
+ billed_units=billed_units,
1065
+ model_info=model_info,
1066
+ )
1067
+ except Exception as e:
1068
+ raise e
1069
+
1070
+
1071
+ def transcription_cost(
1072
+ model: str, custom_llm_provider: Optional[str], duration: float
1073
+ ) -> Tuple[float, float]:
1074
+ return openai_cost_per_second(
1075
+ model=model, custom_llm_provider=custom_llm_provider, duration=duration
1076
+ )
1077
+
1078
+
1079
+ def default_image_cost_calculator(
1080
+ model: str,
1081
+ custom_llm_provider: Optional[str] = None,
1082
+ quality: Optional[str] = None,
1083
+ n: Optional[int] = 1, # Default to 1 image
1084
+ size: Optional[str] = "1024-x-1024", # OpenAI default
1085
+ optional_params: Optional[dict] = None,
1086
+ ) -> float:
1087
+ """
1088
+ Default image cost calculator for image generation
1089
+
1090
+ Args:
1091
+ model (str): Model name
1092
+ image_response (ImageResponse): Response from image generation
1093
+ quality (Optional[str]): Image quality setting
1094
+ n (Optional[int]): Number of images generated
1095
+ size (Optional[str]): Image size (e.g. "1024x1024" or "1024-x-1024")
1096
+
1097
+ Returns:
1098
+ float: Cost in USD for the image generation
1099
+
1100
+ Raises:
1101
+ Exception: If model pricing not found in cost map
1102
+ """
1103
+ # Standardize size format to use "-x-"
1104
+ size_str: str = size or "1024-x-1024"
1105
+ size_str = (
1106
+ size_str.replace("x", "-x-")
1107
+ if "x" in size_str and "-x-" not in size_str
1108
+ else size_str
1109
+ )
1110
+
1111
+ # Parse dimensions
1112
+ height, width = map(int, size_str.split("-x-"))
1113
+
1114
+ # Build model names for cost lookup
1115
+ base_model_name = f"{size_str}/{model}"
1116
+ if custom_llm_provider and model.startswith(custom_llm_provider):
1117
+ base_model_name = (
1118
+ f"{custom_llm_provider}/{size_str}/{model.replace(custom_llm_provider, '')}"
1119
+ )
1120
+ model_name_with_quality = (
1121
+ f"{quality}/{base_model_name}" if quality else base_model_name
1122
+ )
1123
+
1124
+ # gpt-image-1 models use low, medium, high quality. If the user did not specify quality, use medium for the gpt-image-1 model family
1125
+ model_name_with_v2_quality = (
1126
+ f"{ImageGenerationRequestQuality.MEDIUM.value}/{base_model_name}"
1127
+ )
1128
+
1129
+ verbose_logger.debug(
1130
+ f"Looking up cost for models: {model_name_with_quality}, {base_model_name}"
1131
+ )
1132
+
1133
+ model_without_provider = f"{size_str}/{model.split('/')[-1]}"
1134
+ model_with_quality_without_provider = (
1135
+ f"{quality}/{model_without_provider}" if quality else model_without_provider
1136
+ )
1137
+
1138
+ # Try model with quality first, fall back to base model name
1139
+ cost_info: Optional[dict] = None
1140
+ models_to_check = [
1141
+ model_name_with_quality,
1142
+ base_model_name,
1143
+ model_name_with_v2_quality,
1144
+ model_with_quality_without_provider,
1145
+ model_without_provider,
1146
+ model,
1147
+ ]
1148
+ for model in models_to_check:
1149
+ if model in litellm.model_cost:
1150
+ cost_info = litellm.model_cost[model]
1151
+ break
1152
+ if cost_info is None:
1153
+ raise Exception(
1154
+ f"Model not found in cost map. Tried checking {models_to_check}"
1155
+ )
1156
+
1157
+ return cost_info["input_cost_per_pixel"] * height * width * n
1158
+
1159
+
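The last line above prices an image per pixel; a worked example with an assumed per-pixel rate (the real rate comes from the matched litellm.model_cost entry):

# assumed per-pixel rate, for illustration only
input_cost_per_pixel = 1e-8
height, width, n = 1024, 1024, 1

# mirrors the return statement above: rate * height * width * n
cost = input_cost_per_pixel * height * width * n
print(f"${cost:.4f}")  # 1,048,576 pixels -> ~$0.0105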
1160
+ def batch_cost_calculator(
1161
+ usage: Usage,
1162
+ model: str,
1163
+ custom_llm_provider: Optional[str] = None,
1164
+ ) -> Tuple[float, float]:
1165
+ """
1166
+ Calculate the cost of a batch job
1167
+ """
1168
+
1169
+ _, custom_llm_provider, _, _ = litellm.get_llm_provider(
1170
+ model=model, custom_llm_provider=custom_llm_provider
1171
+ )
1172
+
1173
+ verbose_logger.info(
1174
+ "Calculating batch cost per token. model=%s, custom_llm_provider=%s",
1175
+ model,
1176
+ custom_llm_provider,
1177
+ )
1178
+
1179
+ try:
1180
+ model_info: Optional[ModelInfo] = litellm.get_model_info(
1181
+ model=model, custom_llm_provider=custom_llm_provider
1182
+ )
1183
+ except Exception:
1184
+ model_info = None
1185
+
1186
+ if not model_info:
1187
+ return 0.0, 0.0
1188
+
1189
+ input_cost_per_token_batches = model_info.get("input_cost_per_token_batches")
1190
+ input_cost_per_token = model_info.get("input_cost_per_token")
1191
+ output_cost_per_token_batches = model_info.get("output_cost_per_token_batches")
1192
+ output_cost_per_token = model_info.get("output_cost_per_token")
1193
+ total_prompt_cost = 0.0
1194
+ total_completion_cost = 0.0
1195
+ if input_cost_per_token_batches:
1196
+ total_prompt_cost = usage.prompt_tokens * input_cost_per_token_batches
1197
+ elif input_cost_per_token:
1198
+ total_prompt_cost = (
1199
+ usage.prompt_tokens * (input_cost_per_token) / 2
1200
+ ) # batch cost is usually half of the regular token cost
1201
+ if output_cost_per_token_batches:
1202
+ total_completion_cost = usage.completion_tokens * output_cost_per_token_batches
1203
+ elif output_cost_per_token:
1204
+ total_completion_cost = (
1205
+ usage.completion_tokens * (output_cost_per_token) / 2
1206
+ ) # batch cost is usually half of the regular token cost
1207
+
1208
+ return total_prompt_cost, total_completion_cost
1209
+
1210
+
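When no explicit *_batches rate exists, the fallback above halves the regular per-token price; a quick worked example with assumed regular rates:

# assumed non-batch rates, for illustration only: $2 / $6 per million tokens
input_cost_per_token = 2.0 / 1_000_000
output_cost_per_token = 6.0 / 1_000_000
prompt_tokens, completion_tokens = 10_000, 2_000

# batch cost is usually half of the regular token cost
total_prompt_cost = prompt_tokens * input_cost_per_token / 2  # ~$0.01
total_completion_cost = completion_tokens * output_cost_per_token / 2  # ~$0.006
print(f"${total_prompt_cost + total_completion_cost:.3f}")  # $0.016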
1211
+ class RealtimeAPITokenUsageProcessor:
1212
+ @staticmethod
1213
+ def collect_usage_from_realtime_stream_results(
1214
+ results: OpenAIRealtimeStreamList,
1215
+ ) -> List[Usage]:
1216
+ """
1217
+ Collect usage from realtime stream results
1218
+ """
1219
+ response_done_events: List[OpenAIRealtimeStreamResponseBaseObject] = cast(
1220
+ List[OpenAIRealtimeStreamResponseBaseObject],
1221
+ [result for result in results if result["type"] == "response.done"],
1222
+ )
1223
+ usage_objects: List[Usage] = []
1224
+ for result in response_done_events:
1225
+ usage_object = (
1226
+ ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
1227
+ result["response"].get("usage", {})
1228
+ )
1229
+ )
1230
+ usage_objects.append(usage_object)
1231
+ return usage_objects
1232
+
1233
+ @staticmethod
1234
+ def combine_usage_objects(usage_objects: List[Usage]) -> Usage:
1235
+ """
1236
+ Combine multiple Usage objects into a single Usage object, checking model keys for nested values.
1237
+ """
1238
+ from litellm.types.utils import (
1239
+ CompletionTokensDetails,
1240
+ PromptTokensDetailsWrapper,
1241
+ Usage,
1242
+ )
1243
+
1244
+ combined = Usage()
1245
+
1246
+ # Sum basic token counts
1247
+ for usage in usage_objects:
1248
+ # Handle direct attributes by checking what exists in the model
1249
+ for attr in dir(usage):
1250
+ if not attr.startswith("_") and not callable(getattr(usage, attr)):
1251
+ current_val = getattr(combined, attr, 0)
1252
+ new_val = getattr(usage, attr, 0)
1253
+ if (
1254
+ new_val is not None
1255
+ and isinstance(new_val, (int, float))
1256
+ and isinstance(current_val, (int, float))
1257
+ ):
1258
+ setattr(combined, attr, current_val + new_val)
1259
+ # Handle nested prompt_tokens_details
1260
+ if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
1261
+ if (
1262
+ not hasattr(combined, "prompt_tokens_details")
1263
+ or not combined.prompt_tokens_details
1264
+ ):
1265
+ combined.prompt_tokens_details = PromptTokensDetailsWrapper()
1266
+
1267
+ # Check what keys exist in the model's prompt_tokens_details
1268
+ for attr in dir(usage.prompt_tokens_details):
1269
+ if not attr.startswith("_") and not callable(
1270
+ getattr(usage.prompt_tokens_details, attr)
1271
+ ):
1272
+ current_val = getattr(combined.prompt_tokens_details, attr, 0)
1273
+ new_val = getattr(usage.prompt_tokens_details, attr, 0)
1274
+ if new_val is not None:
1275
+ setattr(
1276
+ combined.prompt_tokens_details,
1277
+ attr,
1278
+ current_val + new_val,
1279
+ )
1280
+
1281
+ # Handle nested completion_tokens_details
1282
+ if (
1283
+ hasattr(usage, "completion_tokens_details")
1284
+ and usage.completion_tokens_details
1285
+ ):
1286
+ if (
1287
+ not hasattr(combined, "completion_tokens_details")
1288
+ or not combined.completion_tokens_details
1289
+ ):
1290
+ combined.completion_tokens_details = CompletionTokensDetails()
1291
+
1292
+ # Check what keys exist in the model's completion_tokens_details
1293
+ for attr in dir(usage.completion_tokens_details):
1294
+ if not attr.startswith("_") and not callable(
1295
+ getattr(usage.completion_tokens_details, attr)
1296
+ ):
1297
+ current_val = getattr(
1298
+ combined.completion_tokens_details, attr, 0
1299
+ )
1300
+ new_val = getattr(usage.completion_tokens_details, attr, 0)
1301
+ if new_val is not None:
1302
+ setattr(
1303
+ combined.completion_tokens_details,
1304
+ attr,
1305
+ current_val + new_val,
1306
+ )
1307
+
1308
+ return combined
1309
+
1310
+ @staticmethod
1311
+ def collect_and_combine_usage_from_realtime_stream_results(
1312
+ results: OpenAIRealtimeStreamList,
1313
+ ) -> Usage:
1314
+ """
1315
+ Collect and combine usage from realtime stream results
1316
+ """
1317
+ collected_usage_objects = (
1318
+ RealtimeAPITokenUsageProcessor.collect_usage_from_realtime_stream_results(
1319
+ results
1320
+ )
1321
+ )
1322
+ combined_usage_object = RealtimeAPITokenUsageProcessor.combine_usage_objects(
1323
+ collected_usage_objects
1324
+ )
1325
+ return combined_usage_object
1326
+
1327
+ @staticmethod
1328
+ def create_logging_realtime_object(
1329
+ usage: Usage, results: OpenAIRealtimeStreamList
1330
+ ) -> LiteLLMRealtimeStreamLoggingObject:
1331
+ return LiteLLMRealtimeStreamLoggingObject(
1332
+ usage=usage,
1333
+ results=results,
1334
+ )
1335
+
1336
+
1337
+ def handle_realtime_stream_cost_calculation(
1338
+ results: OpenAIRealtimeStreamList,
1339
+ combined_usage_object: Usage,
1340
+ custom_llm_provider: str,
1341
+ litellm_model_name: str,
1342
+ ) -> float:
1343
+ """
1344
+ Handles the cost calculation for realtime stream responses.
1345
+
1346
+ Pick the 'response.done' events. Calculate total cost across all 'response.done' events.
1347
+
1348
+ Args:
1349
+ results: A list of OpenAIRealtimeStreamBaseObject objects
1350
+ """
1351
+ received_model = None
1352
+ potential_model_names = []
1353
+ for result in results:
1354
+ if result["type"] == "session.created":
1355
+ received_model = cast(OpenAIRealtimeStreamSessionEvents, result)["session"][
1356
+ "model"
1357
+ ]
1358
+ potential_model_names.append(received_model)
1359
+
1360
+ potential_model_names.append(litellm_model_name)
1361
+ input_cost_per_token = 0.0
1362
+ output_cost_per_token = 0.0
1363
+
1364
+ for model_name in potential_model_names:
1365
+ try:
1366
+ _input_cost_per_token, _output_cost_per_token = generic_cost_per_token(
1367
+ model=model_name,
1368
+ usage=combined_usage_object,
1369
+ custom_llm_provider=custom_llm_provider,
1370
+ )
1371
+ except Exception:
1372
+ continue
1373
+ input_cost_per_token += _input_cost_per_token
1374
+ output_cost_per_token += _output_cost_per_token
1375
+ break # exit if we find a valid model
1376
+ total_cost = input_cost_per_token + output_cost_per_token
1377
+
1378
+ return total_cost
litellm/exceptions.py ADDED
@@ -0,0 +1,809 @@
1
+ # +-----------------------------------------------+
2
+ # | |
3
+ # | Give Feedback / Get Help |
4
+ # | https://github.com/BerriAI/litellm/issues/new |
5
+ # | |
6
+ # +-----------------------------------------------+
7
+ #
8
+ # Thank you users! We ❤️ you! - Krrish & Ishaan
9
+
10
+ ## LiteLLM versions of the OpenAI Exception Types
11
+
12
+ from typing import Optional
13
+
14
+ import httpx
15
+ import openai
16
+
17
+ from litellm.types.utils import LiteLLMCommonStrings
18
+
19
+
20
+ class AuthenticationError(openai.AuthenticationError): # type: ignore
21
+ def __init__(
22
+ self,
23
+ message,
24
+ llm_provider,
25
+ model,
26
+ response: Optional[httpx.Response] = None,
27
+ litellm_debug_info: Optional[str] = None,
28
+ max_retries: Optional[int] = None,
29
+ num_retries: Optional[int] = None,
30
+ ):
31
+ self.status_code = 401
32
+ self.message = "litellm.AuthenticationError: {}".format(message)
33
+ self.llm_provider = llm_provider
34
+ self.model = model
35
+ self.litellm_debug_info = litellm_debug_info
36
+ self.max_retries = max_retries
37
+ self.num_retries = num_retries
38
+ self.response = response or httpx.Response(
39
+ status_code=self.status_code,
40
+ request=httpx.Request(
41
+ method="GET", url="https://litellm.ai"
42
+ ), # mock request object
43
+ )
44
+ super().__init__(
45
+ self.message, response=self.response, body=None
46
+ ) # Call the base class constructor with the parameters it needs
47
+
48
+ def __str__(self):
49
+ _message = self.message
50
+ if self.num_retries:
51
+ _message += f" LiteLLM Retried: {self.num_retries} times"
52
+ if self.max_retries:
53
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
54
+ return _message
55
+
56
+ def __repr__(self):
57
+ _message = self.message
58
+ if self.num_retries:
59
+ _message += f" LiteLLM Retried: {self.num_retries} times"
60
+ if self.max_retries:
61
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
62
+ return _message
63
+
64
+
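These wrappers keep the OpenAI exception hierarchy and set status_code, llm_provider and model before delegating to the parent constructor, so callers can catch them by type; a minimal sketch, assuming the classes are re-exported at the top-level litellm namespace and that the call fails authentication:

import litellm

try:
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        api_key="sk-invalid",  # deliberately bad key, for illustration
    )
except litellm.AuthenticationError as e:
    # attributes set by the wrapper class above
    print(e.status_code, e.llm_provider, e.model)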
65
+ # raise when invalid models passed, example gpt-8
66
+ class NotFoundError(openai.NotFoundError): # type: ignore
67
+ def __init__(
68
+ self,
69
+ message,
70
+ model,
71
+ llm_provider,
72
+ response: Optional[httpx.Response] = None,
73
+ litellm_debug_info: Optional[str] = None,
74
+ max_retries: Optional[int] = None,
75
+ num_retries: Optional[int] = None,
76
+ ):
77
+ self.status_code = 404
78
+ self.message = "litellm.NotFoundError: {}".format(message)
79
+ self.model = model
80
+ self.llm_provider = llm_provider
81
+ self.litellm_debug_info = litellm_debug_info
82
+ self.max_retries = max_retries
83
+ self.num_retries = num_retries
84
+ self.response = response or httpx.Response(
85
+ status_code=self.status_code,
86
+ request=httpx.Request(
87
+ method="GET", url="https://litellm.ai"
88
+ ), # mock request object
89
+ )
90
+ super().__init__(
91
+ self.message, response=self.response, body=None
92
+ ) # Call the base class constructor with the parameters it needs
93
+
94
+ def __str__(self):
95
+ _message = self.message
96
+ if self.num_retries:
97
+ _message += f" LiteLLM Retried: {self.num_retries} times"
98
+ if self.max_retries:
99
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
100
+ return _message
101
+
102
+ def __repr__(self):
103
+ _message = self.message
104
+ if self.num_retries:
105
+ _message += f" LiteLLM Retried: {self.num_retries} times"
106
+ if self.max_retries:
107
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
108
+ return _message
109
+
110
+
111
+ class BadRequestError(openai.BadRequestError): # type: ignore
112
+ def __init__(
113
+ self,
114
+ message,
115
+ model,
116
+ llm_provider,
117
+ response: Optional[httpx.Response] = None,
118
+ litellm_debug_info: Optional[str] = None,
119
+ max_retries: Optional[int] = None,
120
+ num_retries: Optional[int] = None,
121
+ body: Optional[dict] = None,
122
+ ):
123
+ self.status_code = 400
124
+ self.message = "litellm.BadRequestError: {}".format(message)
125
+ self.model = model
126
+ self.llm_provider = llm_provider
127
+ self.litellm_debug_info = litellm_debug_info
128
+ response = httpx.Response(
129
+ status_code=self.status_code,
130
+ request=httpx.Request(
131
+ method="GET", url="https://litellm.ai"
132
+ ), # mock request object
133
+ )
134
+ self.max_retries = max_retries
135
+ self.num_retries = num_retries
136
+ super().__init__(
137
+ self.message, response=response, body=body
138
+ ) # Call the base class constructor with the parameters it needs
139
+
140
+ def __str__(self):
141
+ _message = self.message
142
+ if self.num_retries:
143
+ _message += f" LiteLLM Retried: {self.num_retries} times"
144
+ if self.max_retries:
145
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
146
+ return _message
147
+
148
+ def __repr__(self):
149
+ _message = self.message
150
+ if self.num_retries:
151
+ _message += f" LiteLLM Retried: {self.num_retries} times"
152
+ if self.max_retries:
153
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
154
+ return _message
155
+
156
+
157
+ class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
158
+ def __init__(
159
+ self,
160
+ message,
161
+ model,
162
+ llm_provider,
163
+ response: httpx.Response,
164
+ litellm_debug_info: Optional[str] = None,
165
+ max_retries: Optional[int] = None,
166
+ num_retries: Optional[int] = None,
167
+ ):
168
+ self.status_code = 422
169
+ self.message = "litellm.UnprocessableEntityError: {}".format(message)
170
+ self.model = model
171
+ self.llm_provider = llm_provider
172
+ self.litellm_debug_info = litellm_debug_info
173
+ self.max_retries = max_retries
174
+ self.num_retries = num_retries
175
+ super().__init__(
176
+ self.message, response=response, body=None
177
+ ) # Call the base class constructor with the parameters it needs
178
+
179
+ def __str__(self):
180
+ _message = self.message
181
+ if self.num_retries:
182
+ _message += f" LiteLLM Retried: {self.num_retries} times"
183
+ if self.max_retries:
184
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
185
+ return _message
186
+
187
+ def __repr__(self):
188
+ _message = self.message
189
+ if self.num_retries:
190
+ _message += f" LiteLLM Retried: {self.num_retries} times"
191
+ if self.max_retries:
192
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
193
+ return _message
194
+
195
+
196
+ class Timeout(openai.APITimeoutError): # type: ignore
197
+ def __init__(
198
+ self,
199
+ message,
200
+ model,
201
+ llm_provider,
202
+ litellm_debug_info: Optional[str] = None,
203
+ max_retries: Optional[int] = None,
204
+ num_retries: Optional[int] = None,
205
+ headers: Optional[dict] = None,
206
+ exception_status_code: Optional[int] = None,
207
+ ):
208
+ request = httpx.Request(
209
+ method="POST",
210
+ url="https://api.openai.com/v1",
211
+ )
212
+ super().__init__(
213
+ request=request
214
+ ) # Call the base class constructor with the parameters it needs
215
+ self.status_code = exception_status_code or 408
216
+ self.message = "litellm.Timeout: {}".format(message)
217
+ self.model = model
218
+ self.llm_provider = llm_provider
219
+ self.litellm_debug_info = litellm_debug_info
220
+ self.max_retries = max_retries
221
+ self.num_retries = num_retries
222
+ self.headers = headers
223
+
224
+ # custom function to convert to str
225
+ def __str__(self):
226
+ _message = self.message
227
+ if self.num_retries:
228
+ _message += f" LiteLLM Retried: {self.num_retries} times"
229
+ if self.max_retries:
230
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
231
+ return _message
232
+
233
+ def __repr__(self):
234
+ _message = self.message
235
+ if self.num_retries:
236
+ _message += f" LiteLLM Retried: {self.num_retries} times"
237
+ if self.max_retries:
238
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
239
+ return _message
240
+
241
+
242
+ class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
243
+ def __init__(
244
+ self,
245
+ message,
246
+ llm_provider,
247
+ model,
248
+ response: httpx.Response,
249
+ litellm_debug_info: Optional[str] = None,
250
+ max_retries: Optional[int] = None,
251
+ num_retries: Optional[int] = None,
252
+ ):
253
+ self.status_code = 403
254
+ self.message = "litellm.PermissionDeniedError: {}".format(message)
255
+ self.llm_provider = llm_provider
256
+ self.model = model
257
+ self.litellm_debug_info = litellm_debug_info
258
+ self.max_retries = max_retries
259
+ self.num_retries = num_retries
260
+ super().__init__(
261
+ self.message, response=response, body=None
262
+ ) # Call the base class constructor with the parameters it needs
263
+
264
+ def __str__(self):
265
+ _message = self.message
266
+ if self.num_retries:
267
+ _message += f" LiteLLM Retried: {self.num_retries} times"
268
+ if self.max_retries:
269
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
270
+ return _message
271
+
272
+ def __repr__(self):
273
+ _message = self.message
274
+ if self.num_retries:
275
+ _message += f" LiteLLM Retried: {self.num_retries} times"
276
+ if self.max_retries:
277
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
278
+ return _message
279
+
280
+
281
+ class RateLimitError(openai.RateLimitError): # type: ignore
282
+ def __init__(
283
+ self,
284
+ message,
285
+ llm_provider,
286
+ model,
287
+ response: Optional[httpx.Response] = None,
288
+ litellm_debug_info: Optional[str] = None,
289
+ max_retries: Optional[int] = None,
290
+ num_retries: Optional[int] = None,
291
+ ):
292
+ self.status_code = 429
293
+ self.message = "litellm.RateLimitError: {}".format(message)
294
+ self.llm_provider = llm_provider
295
+ self.model = model
296
+ self.litellm_debug_info = litellm_debug_info
297
+ self.max_retries = max_retries
298
+ self.num_retries = num_retries
299
+ _response_headers = (
300
+ getattr(response, "headers", None) if response is not None else None
301
+ )
302
+ self.response = httpx.Response(
303
+ status_code=429,
304
+ headers=_response_headers,
305
+ request=httpx.Request(
306
+ method="POST",
307
+ url="https://cloud.google.com/vertex-ai/",
308
+ ),
309
+ )
310
+ super().__init__(
311
+ self.message, response=self.response, body=None
312
+ ) # Call the base class constructor with the parameters it needs
313
+ self.code = "429"
314
+ self.type = "throttling_error"
315
+
316
+ def __str__(self):
317
+ _message = self.message
318
+ if self.num_retries:
319
+ _message += f" LiteLLM Retried: {self.num_retries} times"
320
+ if self.max_retries:
321
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
322
+ return _message
323
+
324
+ def __repr__(self):
325
+ _message = self.message
326
+ if self.num_retries:
327
+ _message += f" LiteLLM Retried: {self.num_retries} times"
328
+ if self.max_retries:
329
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
330
+ return _message
331
+
332
+
333
+ # sub class of bad request error - meant to give more granularity for error handling context window exceeded errors
334
+ class ContextWindowExceededError(BadRequestError): # type: ignore
335
+ def __init__(
336
+ self,
337
+ message,
338
+ model,
339
+ llm_provider,
340
+ response: Optional[httpx.Response] = None,
341
+ litellm_debug_info: Optional[str] = None,
342
+ ):
343
+ self.status_code = 400
344
+ self.model = model
345
+ self.llm_provider = llm_provider
346
+ self.litellm_debug_info = litellm_debug_info
347
+ request = httpx.Request(method="POST", url="https://api.openai.com/v1")
348
+ self.response = httpx.Response(status_code=400, request=request)
349
+ super().__init__(
350
+ message=message,
351
+ model=self.model, # type: ignore
352
+ llm_provider=self.llm_provider, # type: ignore
353
+ response=self.response,
354
+ litellm_debug_info=self.litellm_debug_info,
355
+ ) # Call the base class constructor with the parameters it needs
356
+
357
+ # set after, to make it clear the raised error is a context window exceeded error
358
+ self.message = "litellm.ContextWindowExceededError: {}".format(self.message)
359
+
360
+ def __str__(self):
361
+ _message = self.message
362
+ if self.num_retries:
363
+ _message += f" LiteLLM Retried: {self.num_retries} times"
364
+ if self.max_retries:
365
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
366
+ return _message
367
+
368
+ def __repr__(self):
369
+ _message = self.message
370
+ if self.num_retries:
371
+ _message += f" LiteLLM Retried: {self.num_retries} times"
372
+ if self.max_retries:
373
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
374
+ return _message
375
+
376
+
377
+ # sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
378
+ class RejectedRequestError(BadRequestError): # type: ignore
379
+ def __init__(
380
+ self,
381
+ message,
382
+ model,
383
+ llm_provider,
384
+ request_data: dict,
385
+ litellm_debug_info: Optional[str] = None,
386
+ ):
387
+ self.status_code = 400
388
+ self.message = "litellm.RejectedRequestError: {}".format(message)
389
+ self.model = model
390
+ self.llm_provider = llm_provider
391
+ self.litellm_debug_info = litellm_debug_info
392
+ self.request_data = request_data
393
+ request = httpx.Request(method="POST", url="https://api.openai.com/v1")
394
+ response = httpx.Response(status_code=400, request=request)
395
+ super().__init__(
396
+ message=self.message,
397
+ model=self.model, # type: ignore
398
+ llm_provider=self.llm_provider, # type: ignore
399
+ response=response,
400
+ litellm_debug_info=self.litellm_debug_info,
401
+ ) # Call the base class constructor with the parameters it needs
402
+
403
+ def __str__(self):
404
+ _message = self.message
405
+ if self.num_retries:
406
+ _message += f" LiteLLM Retried: {self.num_retries} times"
407
+ if self.max_retries:
408
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
409
+ return _message
410
+
411
+ def __repr__(self):
412
+ _message = self.message
413
+ if self.num_retries:
414
+ _message += f" LiteLLM Retried: {self.num_retries} times"
415
+ if self.max_retries:
416
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
417
+ return _message
418
+
419
+
420
+ class ContentPolicyViolationError(BadRequestError): # type: ignore
421
+ # Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
422
+ def __init__(
423
+ self,
424
+ message,
425
+ model,
426
+ llm_provider,
427
+ response: Optional[httpx.Response] = None,
428
+ litellm_debug_info: Optional[str] = None,
429
+ ):
430
+ self.status_code = 400
431
+ self.message = "litellm.ContentPolicyViolationError: {}".format(message)
432
+ self.model = model
433
+ self.llm_provider = llm_provider
434
+ self.litellm_debug_info = litellm_debug_info
435
+ request = httpx.Request(method="POST", url="https://api.openai.com/v1")
436
+ self.response = httpx.Response(status_code=400, request=request)
437
+ super().__init__(
438
+ message=self.message,
439
+ model=self.model, # type: ignore
440
+ llm_provider=self.llm_provider, # type: ignore
441
+ response=self.response,
442
+ litellm_debug_info=self.litellm_debug_info,
443
+ ) # Call the base class constructor with the parameters it needs
444
+
445
+ def __str__(self):
446
+ _message = self.message
447
+ if self.num_retries:
448
+ _message += f" LiteLLM Retried: {self.num_retries} times"
449
+ if self.max_retries:
450
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
451
+ return _message
452
+
453
+ def __repr__(self):
454
+ _message = self.message
455
+ if self.num_retries:
456
+ _message += f" LiteLLM Retried: {self.num_retries} times"
457
+ if self.max_retries:
458
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
459
+ return _message
460
+
461
+
462
+ class ServiceUnavailableError(openai.APIStatusError): # type: ignore
463
+ def __init__(
464
+ self,
465
+ message,
466
+ llm_provider,
467
+ model,
468
+ response: Optional[httpx.Response] = None,
469
+ litellm_debug_info: Optional[str] = None,
470
+ max_retries: Optional[int] = None,
471
+ num_retries: Optional[int] = None,
472
+ ):
473
+ self.status_code = 503
474
+ self.message = "litellm.ServiceUnavailableError: {}".format(message)
475
+ self.llm_provider = llm_provider
476
+ self.model = model
477
+ self.litellm_debug_info = litellm_debug_info
478
+ self.max_retries = max_retries
479
+ self.num_retries = num_retries
480
+ self.response = httpx.Response(
481
+ status_code=self.status_code,
482
+ request=httpx.Request(
483
+ method="POST",
484
+ url="https://cloud.google.com/vertex-ai/",
485
+ ),
486
+ )
487
+ super().__init__(
488
+ self.message, response=self.response, body=None
489
+ ) # Call the base class constructor with the parameters it needs
490
+
491
+ def __str__(self):
492
+ _message = self.message
493
+ if self.num_retries:
494
+ _message += f" LiteLLM Retried: {self.num_retries} times"
495
+ if self.max_retries:
496
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
497
+ return _message
498
+
499
+ def __repr__(self):
500
+ _message = self.message
501
+ if self.num_retries:
502
+ _message += f" LiteLLM Retried: {self.num_retries} times"
503
+ if self.max_retries:
504
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
505
+ return _message
506
+
507
+
508
+ class InternalServerError(openai.InternalServerError): # type: ignore
509
+ def __init__(
510
+ self,
511
+ message,
512
+ llm_provider,
513
+ model,
514
+ response: Optional[httpx.Response] = None,
515
+ litellm_debug_info: Optional[str] = None,
516
+ max_retries: Optional[int] = None,
517
+ num_retries: Optional[int] = None,
518
+ ):
519
+ self.status_code = 500
520
+ self.message = "litellm.InternalServerError: {}".format(message)
521
+ self.llm_provider = llm_provider
522
+ self.model = model
523
+ self.litellm_debug_info = litellm_debug_info
524
+ self.max_retries = max_retries
525
+ self.num_retries = num_retries
526
+ self.response = httpx.Response(
527
+ status_code=self.status_code,
528
+ request=httpx.Request(
529
+ method="POST",
530
+ url="https://cloud.google.com/vertex-ai/",
531
+ ),
532
+ )
533
+ super().__init__(
534
+ self.message, response=self.response, body=None
535
+ ) # Call the base class constructor with the parameters it needs
536
+
537
+ def __str__(self):
538
+ _message = self.message
539
+ if self.num_retries:
540
+ _message += f" LiteLLM Retried: {self.num_retries} times"
541
+ if self.max_retries:
542
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
543
+ return _message
544
+
545
+ def __repr__(self):
546
+ _message = self.message
547
+ if self.num_retries:
548
+ _message += f" LiteLLM Retried: {self.num_retries} times"
549
+ if self.max_retries:
550
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
551
+ return _message
552
+
553
+
554
+ # raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
555
+ class APIError(openai.APIError): # type: ignore
556
+ def __init__(
557
+ self,
558
+ status_code: int,
559
+ message,
560
+ llm_provider,
561
+ model,
562
+ request: Optional[httpx.Request] = None,
563
+ litellm_debug_info: Optional[str] = None,
564
+ max_retries: Optional[int] = None,
565
+ num_retries: Optional[int] = None,
566
+ ):
567
+ self.status_code = status_code
568
+ self.message = "litellm.APIError: {}".format(message)
569
+ self.llm_provider = llm_provider
570
+ self.model = model
571
+ self.litellm_debug_info = litellm_debug_info
572
+ self.max_retries = max_retries
573
+ self.num_retries = num_retries
574
+ if request is None:
575
+ request = httpx.Request(method="POST", url="https://api.openai.com/v1")
576
+ super().__init__(self.message, request=request, body=None) # type: ignore
577
+
578
+ def __str__(self):
579
+ _message = self.message
580
+ if self.num_retries:
581
+ _message += f" LiteLLM Retried: {self.num_retries} times"
582
+ if self.max_retries:
583
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
584
+ return _message
585
+
586
+ def __repr__(self):
587
+ _message = self.message
588
+ if self.num_retries:
589
+ _message += f" LiteLLM Retried: {self.num_retries} times"
590
+ if self.max_retries:
591
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
592
+ return _message
593
+
594
+
595
+ # raised if an invalid request (not get, delete, put, post) is made
596
+ class APIConnectionError(openai.APIConnectionError): # type: ignore
597
+ def __init__(
598
+ self,
599
+ message,
600
+ llm_provider,
601
+ model,
602
+ request: Optional[httpx.Request] = None,
603
+ litellm_debug_info: Optional[str] = None,
604
+ max_retries: Optional[int] = None,
605
+ num_retries: Optional[int] = None,
606
+ ):
607
+ self.message = "litellm.APIConnectionError: {}".format(message)
608
+ self.llm_provider = llm_provider
609
+ self.model = model
610
+ self.status_code = 500
611
+ self.litellm_debug_info = litellm_debug_info
612
+ self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
613
+ self.max_retries = max_retries
614
+ self.num_retries = num_retries
615
+ super().__init__(message=self.message, request=self.request)
616
+
617
+ def __str__(self):
618
+ _message = self.message
619
+ if self.num_retries:
620
+ _message += f" LiteLLM Retried: {self.num_retries} times"
621
+ if self.max_retries:
622
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
623
+ return _message
624
+
625
+ def __repr__(self):
626
+ _message = self.message
627
+ if self.num_retries:
628
+ _message += f" LiteLLM Retried: {self.num_retries} times"
629
+ if self.max_retries:
630
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
631
+ return _message
632
+
633
+
634
+ # raised if an invalid request (not get, delete, put, post) is made
635
+ class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
636
+ def __init__(
637
+ self,
638
+ message,
639
+ llm_provider,
640
+ model,
641
+ litellm_debug_info: Optional[str] = None,
642
+ max_retries: Optional[int] = None,
643
+ num_retries: Optional[int] = None,
644
+ ):
645
+ self.message = "litellm.APIResponseValidationError: {}".format(message)
646
+ self.llm_provider = llm_provider
647
+ self.model = model
648
+ request = httpx.Request(method="POST", url="https://api.openai.com/v1")
649
+ response = httpx.Response(status_code=500, request=request)
650
+ self.litellm_debug_info = litellm_debug_info
651
+ self.max_retries = max_retries
652
+ self.num_retries = num_retries
653
+ super().__init__(response=response, body=None, message=message)
654
+
655
+ def __str__(self):
656
+ _message = self.message
657
+ if self.num_retries:
658
+ _message += f" LiteLLM Retried: {self.num_retries} times"
659
+ if self.max_retries:
660
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
661
+ return _message
662
+
663
+ def __repr__(self):
664
+ _message = self.message
665
+ if self.num_retries:
666
+ _message += f" LiteLLM Retried: {self.num_retries} times"
667
+ if self.max_retries:
668
+ _message += f", LiteLLM Max Retries: {self.max_retries}"
669
+ return _message
670
+
671
+
672
+ class JSONSchemaValidationError(APIResponseValidationError):
673
+ def __init__(
674
+ self, model: str, llm_provider: str, raw_response: str, schema: str
675
+ ) -> None:
676
+ self.raw_response = raw_response
677
+ self.schema = schema
678
+ self.model = model
679
+ message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
680
+ model, raw_response, schema
681
+ )
682
+ self.message = message
683
+ super().__init__(model=model, message=message, llm_provider=llm_provider)
684
+
685
+
686
+ class OpenAIError(openai.OpenAIError): # type: ignore
687
+ def __init__(self, original_exception=None):
688
+ super().__init__()
689
+ self.llm_provider = "openai"
690
+
691
+
692
+ class UnsupportedParamsError(BadRequestError):
693
+ def __init__(
694
+ self,
695
+ message,
696
+ llm_provider: Optional[str] = None,
697
+ model: Optional[str] = None,
698
+ status_code: int = 400,
699
+ response: Optional[httpx.Response] = None,
700
+ litellm_debug_info: Optional[str] = None,
701
+ max_retries: Optional[int] = None,
702
+ num_retries: Optional[int] = None,
703
+ ):
704
+ self.status_code = 400
705
+ self.message = "litellm.UnsupportedParamsError: {}".format(message)
706
+ self.model = model
707
+ self.llm_provider = llm_provider
708
+ self.litellm_debug_info = litellm_debug_info
709
+ response = response or httpx.Response(
710
+ status_code=self.status_code,
711
+ request=httpx.Request(
712
+ method="GET", url="https://litellm.ai"
713
+ ), # mock request object
714
+ )
715
+ self.max_retries = max_retries
716
+ self.num_retries = num_retries
717
+
718
+
719
+ LITELLM_EXCEPTION_TYPES = [
720
+ AuthenticationError,
721
+ NotFoundError,
722
+ BadRequestError,
723
+ UnprocessableEntityError,
724
+ UnsupportedParamsError,
725
+ Timeout,
726
+ PermissionDeniedError,
727
+ RateLimitError,
728
+ ContextWindowExceededError,
729
+ RejectedRequestError,
730
+ ContentPolicyViolationError,
731
+ InternalServerError,
732
+ ServiceUnavailableError,
733
+ APIError,
734
+ APIConnectionError,
735
+ APIResponseValidationError,
736
+ OpenAIError,
737
+ InternalServerError,
738
+ JSONSchemaValidationError,
739
+ ]
740
+
741
+
742
+ class BudgetExceededError(Exception):
743
+ def __init__(
744
+ self, current_cost: float, max_budget: float, message: Optional[str] = None
745
+ ):
746
+ self.current_cost = current_cost
747
+ self.max_budget = max_budget
748
+ message = (
749
+ message
750
+ or f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
751
+ )
752
+ self.message = message
753
+ super().__init__(message)
754
+
755
+
756
+ ## DEPRECATED ##
757
+ class InvalidRequestError(openai.BadRequestError): # type: ignore
758
+ def __init__(self, message, model, llm_provider):
759
+ self.status_code = 400
760
+ self.message = message
761
+ self.model = model
762
+ self.llm_provider = llm_provider
763
+ self.response = httpx.Response(
764
+ status_code=400,
765
+ request=httpx.Request(
766
+ method="GET", url="https://litellm.ai"
767
+ ), # mock request object
768
+ )
769
+ super().__init__(
770
+ message=self.message, response=self.response, body=None
771
+ ) # Call the base class constructor with the parameters it needs
772
+
773
+
774
+ class MockException(openai.APIError):
775
+ # used for testing
776
+ def __init__(
777
+ self,
778
+ status_code: int,
779
+ message,
780
+ llm_provider,
781
+ model,
782
+ request: Optional[httpx.Request] = None,
783
+ litellm_debug_info: Optional[str] = None,
784
+ max_retries: Optional[int] = None,
785
+ num_retries: Optional[int] = None,
786
+ ):
787
+ self.status_code = status_code
788
+ self.message = "litellm.MockException: {}".format(message)
789
+ self.llm_provider = llm_provider
790
+ self.model = model
791
+ self.litellm_debug_info = litellm_debug_info
792
+ self.max_retries = max_retries
793
+ self.num_retries = num_retries
794
+ if request is None:
795
+ request = httpx.Request(method="POST", url="https://api.openai.com/v1")
796
+ super().__init__(self.message, request=request, body=None) # type: ignore
797
+
798
+
799
+ class LiteLLMUnknownProvider(BadRequestError):
800
+ def __init__(self, model: str, custom_llm_provider: Optional[str] = None):
801
+ self.message = LiteLLMCommonStrings.llm_provider_not_provided.value.format(
802
+ model=model, custom_llm_provider=custom_llm_provider
803
+ )
804
+ super().__init__(
805
+ self.message, model=model, llm_provider=custom_llm_provider, response=None
806
+ )
807
+
808
+ def __str__(self):
809
+ return self.message
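Each class above wraps its OpenAI counterpart, so callers can handle provider failures with a single set of exception types regardless of backend. A minimal sketch of that pattern (the model name and the handling choices are illustrative assumptions, not part of this commit):

```python
import litellm
from litellm.exceptions import (
    APIConnectionError,
    ContextWindowExceededError,
    RateLimitError,
)


def ask(prompt: str) -> str:
    try:
        resp = litellm.completion(
            model="gpt-4o-mini",  # illustrative model choice
            messages=[{"role": "user", "content": prompt}],
        )
        return resp.choices[0].message.content or ""
    except ContextWindowExceededError:
        # BadRequestError subclass raised when the prompt exceeds the model's window
        return "Prompt too long for the selected model."
    except RateLimitError as e:
        # 429-style error; str(e) includes any LiteLLM retry counts appended above
        return f"Rate limited: {e}"
    except APIConnectionError as e:
        # network-level failure reaching the provider
        return f"Connection error: {e}"
```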
litellm/experimental_mcp_client/Readme.md ADDED
@@ -0,0 +1,6 @@
1
+ # LiteLLM MCP Client
2
+
3
+ LiteLLM MCP Client lets you load MCP tools, pass them to LiteLLM completion calls, and execute the tool calls the model returns.
4
+
5
+
6
+
litellm/experimental_mcp_client/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .tools import call_openai_tool, load_mcp_tools
2
+
3
+ __all__ = ["load_mcp_tools", "call_openai_tool"]
litellm/experimental_mcp_client/client.py ADDED
File without changes
litellm/experimental_mcp_client/tools.py ADDED
@@ -0,0 +1,111 @@
1
+ import json
2
+ from typing import Dict, List, Literal, Union
3
+
4
+ from mcp import ClientSession
5
+ from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
6
+ from mcp.types import CallToolResult as MCPCallToolResult
7
+ from mcp.types import Tool as MCPTool
8
+ from openai.types.chat import ChatCompletionToolParam
9
+ from openai.types.shared_params.function_definition import FunctionDefinition
10
+
11
+ from litellm.types.utils import ChatCompletionMessageToolCall
12
+
13
+
14
+ ########################################################
15
+ # List MCP Tool functions
16
+ ########################################################
17
+ def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
18
+ """Convert an MCP tool to an OpenAI tool."""
19
+ return ChatCompletionToolParam(
20
+ type="function",
21
+ function=FunctionDefinition(
22
+ name=mcp_tool.name,
23
+ description=mcp_tool.description or "",
24
+ parameters=mcp_tool.inputSchema,
25
+ strict=False,
26
+ ),
27
+ )
28
+
29
+
30
+ async def load_mcp_tools(
31
+ session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
32
+ ) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
33
+ """
34
+ Load all available MCP tools
35
+
36
+ Args:
37
+ session: The MCP session to use
38
+ format: The format to convert the tools to
39
+ By default, the tools are returned in MCP format.
40
+
41
+ If format is set to "openai", the tools are converted to OpenAI API compatible tools.
42
+ """
43
+ tools = await session.list_tools()
44
+ if format == "openai":
45
+ return [
46
+ transform_mcp_tool_to_openai_tool(mcp_tool=tool) for tool in tools.tools
47
+ ]
48
+ return tools.tools
49
+
50
+
51
+ ########################################################
52
+ # Call MCP Tool functions
53
+ ########################################################
54
+
55
+
56
+ async def call_mcp_tool(
57
+ session: ClientSession,
58
+ call_tool_request_params: MCPCallToolRequestParams,
59
+ ) -> MCPCallToolResult:
60
+ """Call an MCP tool."""
61
+ tool_result = await session.call_tool(
62
+ name=call_tool_request_params.name,
63
+ arguments=call_tool_request_params.arguments,
64
+ )
65
+ return tool_result
66
+
67
+
68
+ def _get_function_arguments(function: FunctionDefinition) -> dict:
69
+ """Helper to safely get and parse function arguments."""
70
+ arguments = function.get("arguments", {})
71
+ if isinstance(arguments, str):
72
+ try:
73
+ arguments = json.loads(arguments)
74
+ except json.JSONDecodeError:
75
+ arguments = {}
76
+ return arguments if isinstance(arguments, dict) else {}
77
+
78
+
79
+ def transform_openai_tool_call_request_to_mcp_tool_call_request(
80
+ openai_tool: Union[ChatCompletionMessageToolCall, Dict],
81
+ ) -> MCPCallToolRequestParams:
82
+ """Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
83
+ function = openai_tool["function"]
84
+ return MCPCallToolRequestParams(
85
+ name=function["name"],
86
+ arguments=_get_function_arguments(function),
87
+ )
88
+
89
+
90
+ async def call_openai_tool(
91
+ session: ClientSession,
92
+ openai_tool: ChatCompletionMessageToolCall,
93
+ ) -> MCPCallToolResult:
94
+ """
95
+ Call an OpenAI tool using MCP client.
96
+
97
+ Args:
98
+ session: The MCP session to use
99
+ openai_tool: The OpenAI tool to call. You can get this from the `choices[0].message.tool_calls[0]` of the response from the OpenAI API.
100
+ Returns:
101
+ The result of the MCP tool call.
102
+ """
103
+ mcp_tool_call_request_params = (
104
+ transform_openai_tool_call_request_to_mcp_tool_call_request(
105
+ openai_tool=openai_tool,
106
+ )
107
+ )
108
+ return await call_mcp_tool(
109
+ session=session,
110
+ call_tool_request_params=mcp_tool_call_request_params,
111
+ )
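The docstrings above describe the intended round trip: list MCP tools in OpenAI format, hand them to a completion call, then execute whichever tool the model picked with `call_openai_tool`. A hedged sketch of that flow (the stdio server command, model, and prompt are placeholders; session setup follows the standard `mcp` Python SDK client pattern):

```python
import asyncio

import litellm
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

from litellm.experimental_mcp_client import call_openai_tool, load_mcp_tools


async def main() -> None:
    # hypothetical local MCP server started over stdio
    server = StdioServerParameters(command="python", args=["./my_mcp_server.py"])
    async with stdio_client(server) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # MCP tools converted to OpenAI `tools=[...]` entries
            tools = await load_mcp_tools(session=session, format="openai")
            resp = await litellm.acompletion(
                model="gpt-4o",  # illustrative model choice
                messages=[{"role": "user", "content": "What's 7 * 6?"}],
                tools=tools,
            )
            tool_call = resp.choices[0].message.tool_calls[0]
            result = await call_openai_tool(session=session, openai_tool=tool_call)
            print(result.content)


asyncio.run(main())
```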
litellm/files/main.py ADDED
@@ -0,0 +1,891 @@
1
+ """
2
+ Main File for Files API implementation
3
+
4
+ https://platform.openai.com/docs/api-reference/files
5
+
6
+ """
7
+
8
+ import asyncio
9
+ import contextvars
10
+ import os
11
+ from functools import partial
12
+ from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
13
+
14
+ import httpx
15
+
16
+ import litellm
17
+ from litellm import get_secret_str
18
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
19
+ from litellm.llms.azure.files.handler import AzureOpenAIFilesAPI
20
+ from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
21
+ from litellm.llms.openai.openai import FileDeleted, FileObject, OpenAIFilesAPI
22
+ from litellm.llms.vertex_ai.files.handler import VertexAIFilesHandler
23
+ from litellm.types.llms.openai import (
24
+ CreateFileRequest,
25
+ FileContentRequest,
26
+ FileTypes,
27
+ HttpxBinaryResponseContent,
28
+ OpenAIFileObject,
29
+ )
30
+ from litellm.types.router import *
31
+ from litellm.types.utils import LlmProviders
32
+ from litellm.utils import (
33
+ ProviderConfigManager,
34
+ client,
35
+ get_litellm_params,
36
+ supports_httpx_timeout,
37
+ )
38
+
39
+ base_llm_http_handler = BaseLLMHTTPHandler()
40
+
41
+ ####### ENVIRONMENT VARIABLES ###################
42
+ openai_files_instance = OpenAIFilesAPI()
43
+ azure_files_instance = AzureOpenAIFilesAPI()
44
+ vertex_ai_files_instance = VertexAIFilesHandler()
45
+ #################################################
46
+
47
+
48
+ @client
49
+ async def acreate_file(
50
+ file: FileTypes,
51
+ purpose: Literal["assistants", "batch", "fine-tune"],
52
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
53
+ extra_headers: Optional[Dict[str, str]] = None,
54
+ extra_body: Optional[Dict[str, str]] = None,
55
+ **kwargs,
56
+ ) -> OpenAIFileObject:
57
+ """
58
+ Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
59
+
60
+ LiteLLM Equivalent of POST https://api.openai.com/v1/files
61
+ """
62
+ try:
63
+ loop = asyncio.get_event_loop()
64
+ kwargs["acreate_file"] = True
65
+
66
+ call_args = {
67
+ "file": file,
68
+ "purpose": purpose,
69
+ "custom_llm_provider": custom_llm_provider,
70
+ "extra_headers": extra_headers,
71
+ "extra_body": extra_body,
72
+ **kwargs,
73
+ }
74
+
75
+ # Use a partial function to pass your keyword arguments
76
+ func = partial(create_file, **call_args)
77
+
78
+ # Add the context to the function
79
+ ctx = contextvars.copy_context()
80
+ func_with_context = partial(ctx.run, func)
81
+ init_response = await loop.run_in_executor(None, func_with_context)
82
+ if asyncio.iscoroutine(init_response):
83
+ response = await init_response
84
+ else:
85
+ response = init_response # type: ignore
86
+
87
+ return response
88
+ except Exception as e:
89
+ raise e
90
+
91
+
92
+ @client
93
+ def create_file(
94
+ file: FileTypes,
95
+ purpose: Literal["assistants", "batch", "fine-tune"],
96
+ custom_llm_provider: Optional[Literal["openai", "azure", "vertex_ai"]] = None,
97
+ extra_headers: Optional[Dict[str, str]] = None,
98
+ extra_body: Optional[Dict[str, str]] = None,
99
+ **kwargs,
100
+ ) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
101
+ """
102
+ Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
103
+
104
+ LiteLLM Equivalent of POST https://api.openai.com/v1/files
105
+
106
+ Specify either provider_list or custom_llm_provider.
107
+ """
108
+ try:
109
+ _is_async = kwargs.pop("acreate_file", False) is True
110
+ optional_params = GenericLiteLLMParams(**kwargs)
111
+ litellm_params_dict = get_litellm_params(**kwargs)
112
+ logging_obj = cast(
113
+ Optional[LiteLLMLoggingObj], kwargs.get("litellm_logging_obj")
114
+ )
115
+ if logging_obj is None:
116
+ raise ValueError("logging_obj is required")
117
+ client = kwargs.get("client")
118
+
119
+ ### TIMEOUT LOGIC ###
120
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
121
+ # set timeout for 10 minutes by default
122
+
123
+ if (
124
+ timeout is not None
125
+ and isinstance(timeout, httpx.Timeout)
126
+ and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
127
+ ):
128
+ read_timeout = timeout.read or 600
129
+ timeout = read_timeout # default 10 min timeout
130
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
131
+ timeout = float(timeout) # type: ignore
132
+ elif timeout is None:
133
+ timeout = 600.0
134
+
135
+ _create_file_request = CreateFileRequest(
136
+ file=file,
137
+ purpose=purpose,
138
+ extra_headers=extra_headers,
139
+ extra_body=extra_body,
140
+ )
141
+
142
+ provider_config = ProviderConfigManager.get_provider_files_config(
143
+ model="",
144
+ provider=LlmProviders(custom_llm_provider),
145
+ )
146
+ if provider_config is not None:
147
+ response = base_llm_http_handler.create_file(
148
+ provider_config=provider_config,
149
+ litellm_params=litellm_params_dict,
150
+ create_file_data=_create_file_request,
151
+ headers=extra_headers or {},
152
+ api_base=optional_params.api_base,
153
+ api_key=optional_params.api_key,
154
+ logging_obj=logging_obj,
155
+ _is_async=_is_async,
156
+ client=client
157
+ if client is not None
158
+ and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
159
+ else None,
160
+ timeout=timeout,
161
+ )
162
+ elif custom_llm_provider == "openai":
163
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
164
+ api_base = (
165
+ optional_params.api_base
166
+ or litellm.api_base
167
+ or os.getenv("OPENAI_BASE_URL")
168
+ or os.getenv("OPENAI_API_BASE")
169
+ or "https://api.openai.com/v1"
170
+ )
171
+ organization = (
172
+ optional_params.organization
173
+ or litellm.organization
174
+ or os.getenv("OPENAI_ORGANIZATION", None)
175
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
176
+ )
177
+ # set API KEY
178
+ api_key = (
179
+ optional_params.api_key
180
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
181
+ or litellm.openai_key
182
+ or os.getenv("OPENAI_API_KEY")
183
+ )
184
+
185
+ response = openai_files_instance.create_file(
186
+ _is_async=_is_async,
187
+ api_base=api_base,
188
+ api_key=api_key,
189
+ timeout=timeout,
190
+ max_retries=optional_params.max_retries,
191
+ organization=organization,
192
+ create_file_data=_create_file_request,
193
+ )
194
+ elif custom_llm_provider == "azure":
195
+ api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
196
+ api_version = (
197
+ optional_params.api_version
198
+ or litellm.api_version
199
+ or get_secret_str("AZURE_API_VERSION")
200
+ ) # type: ignore
201
+
202
+ api_key = (
203
+ optional_params.api_key
204
+ or litellm.api_key
205
+ or litellm.azure_key
206
+ or get_secret_str("AZURE_OPENAI_API_KEY")
207
+ or get_secret_str("AZURE_API_KEY")
208
+ ) # type: ignore
209
+
210
+ extra_body = optional_params.get("extra_body", {})
211
+ if extra_body is not None:
212
+ extra_body.pop("azure_ad_token", None)
213
+ else:
214
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
215
+
216
+ response = azure_files_instance.create_file(
217
+ _is_async=_is_async,
218
+ api_base=api_base,
219
+ api_key=api_key,
220
+ api_version=api_version,
221
+ timeout=timeout,
222
+ max_retries=optional_params.max_retries,
223
+ create_file_data=_create_file_request,
224
+ litellm_params=litellm_params_dict,
225
+ )
226
+ elif custom_llm_provider == "vertex_ai":
227
+ api_base = optional_params.api_base or ""
228
+ vertex_ai_project = (
229
+ optional_params.vertex_project
230
+ or litellm.vertex_project
231
+ or get_secret_str("VERTEXAI_PROJECT")
232
+ )
233
+ vertex_ai_location = (
234
+ optional_params.vertex_location
235
+ or litellm.vertex_location
236
+ or get_secret_str("VERTEXAI_LOCATION")
237
+ )
238
+ vertex_credentials = optional_params.vertex_credentials or get_secret_str(
239
+ "VERTEXAI_CREDENTIALS"
240
+ )
241
+
242
+ response = vertex_ai_files_instance.create_file(
243
+ _is_async=_is_async,
244
+ api_base=api_base,
245
+ vertex_project=vertex_ai_project,
246
+ vertex_location=vertex_ai_location,
247
+ vertex_credentials=vertex_credentials,
248
+ timeout=timeout,
249
+ max_retries=optional_params.max_retries,
250
+ create_file_data=_create_file_request,
251
+ )
252
+ else:
253
+ raise litellm.exceptions.BadRequestError(
254
+ message="LiteLLM doesn't support {} for 'create_file'. Only ['openai', 'azure', 'vertex_ai'] are supported.".format(
255
+ custom_llm_provider
256
+ ),
257
+ model="n/a",
258
+ llm_provider=custom_llm_provider,
259
+ response=httpx.Response(
260
+ status_code=400,
261
+ content="Unsupported provider",
262
+ request=httpx.Request(method="create_file", url="https://github.com/BerriAI/litellm"), # type: ignore
263
+ ),
264
+ )
265
+ return response
266
+ except Exception as e:
267
+ raise e
268
+
269
+
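As the docstring notes, `create_file`/`acreate_file` mirror OpenAI's `POST /v1/files`. A minimal sketch of uploading a JSONL batch-input file (the path, purpose, and provider are illustrative assumptions):

```python
import asyncio

import litellm


async def upload_batch_input() -> None:
    with open("batch_input.jsonl", "rb") as f:  # hypothetical local file
        file_obj = await litellm.acreate_file(
            file=f,
            purpose="batch",
            custom_llm_provider="openai",  # routed to OpenAIFilesAPI above
        )
    print(file_obj.id, file_obj.status)


asyncio.run(upload_batch_input())
```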
270
+ async def afile_retrieve(
271
+ file_id: str,
272
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
273
+ extra_headers: Optional[Dict[str, str]] = None,
274
+ extra_body: Optional[Dict[str, str]] = None,
275
+ **kwargs,
276
+ ):
277
+ """
278
+ Async: Retrieve a file's metadata
279
+
280
+ LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}
281
+ """
282
+ try:
283
+ loop = asyncio.get_event_loop()
284
+ kwargs["is_async"] = True
285
+
286
+ # Use a partial function to pass your keyword arguments
287
+ func = partial(
288
+ file_retrieve,
289
+ file_id,
290
+ custom_llm_provider,
291
+ extra_headers,
292
+ extra_body,
293
+ **kwargs,
294
+ )
295
+
296
+ # Add the context to the function
297
+ ctx = contextvars.copy_context()
298
+ func_with_context = partial(ctx.run, func)
299
+ init_response = await loop.run_in_executor(None, func_with_context)
300
+ if asyncio.iscoroutine(init_response):
301
+ response = await init_response
302
+ else:
303
+ response = init_response
304
+
305
+ return response
306
+ except Exception as e:
307
+ raise e
308
+
309
+
310
+ def file_retrieve(
311
+ file_id: str,
312
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
313
+ extra_headers: Optional[Dict[str, str]] = None,
314
+ extra_body: Optional[Dict[str, str]] = None,
315
+ **kwargs,
316
+ ) -> FileObject:
317
+ """
318
+ Returns metadata for the specified file.
319
+
320
+ LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}
321
+ """
322
+ try:
323
+ optional_params = GenericLiteLLMParams(**kwargs)
324
+ ### TIMEOUT LOGIC ###
325
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
326
+ # set timeout for 10 minutes by default
327
+
328
+ if (
329
+ timeout is not None
330
+ and isinstance(timeout, httpx.Timeout)
331
+ and supports_httpx_timeout(custom_llm_provider) is False
332
+ ):
333
+ read_timeout = timeout.read or 600
334
+ timeout = read_timeout # default 10 min timeout
335
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
336
+ timeout = float(timeout) # type: ignore
337
+ elif timeout is None:
338
+ timeout = 600.0
339
+
340
+ _is_async = kwargs.pop("is_async", False) is True
341
+
342
+ if custom_llm_provider == "openai":
343
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
344
+ api_base = (
345
+ optional_params.api_base
346
+ or litellm.api_base
347
+ or os.getenv("OPENAI_BASE_URL")
348
+ or os.getenv("OPENAI_API_BASE")
349
+ or "https://api.openai.com/v1"
350
+ )
351
+ organization = (
352
+ optional_params.organization
353
+ or litellm.organization
354
+ or os.getenv("OPENAI_ORGANIZATION", None)
355
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
356
+ )
357
+ # set API KEY
358
+ api_key = (
359
+ optional_params.api_key
360
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
361
+ or litellm.openai_key
362
+ or os.getenv("OPENAI_API_KEY")
363
+ )
364
+
365
+ response = openai_files_instance.retrieve_file(
366
+ file_id=file_id,
367
+ _is_async=_is_async,
368
+ api_base=api_base,
369
+ api_key=api_key,
370
+ timeout=timeout,
371
+ max_retries=optional_params.max_retries,
372
+ organization=organization,
373
+ )
374
+ elif custom_llm_provider == "azure":
375
+ api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
376
+ api_version = (
377
+ optional_params.api_version
378
+ or litellm.api_version
379
+ or get_secret_str("AZURE_API_VERSION")
380
+ ) # type: ignore
381
+
382
+ api_key = (
383
+ optional_params.api_key
384
+ or litellm.api_key
385
+ or litellm.azure_key
386
+ or get_secret_str("AZURE_OPENAI_API_KEY")
387
+ or get_secret_str("AZURE_API_KEY")
388
+ ) # type: ignore
389
+
390
+ extra_body = optional_params.get("extra_body", {})
391
+ if extra_body is not None:
392
+ extra_body.pop("azure_ad_token", None)
393
+ else:
394
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
395
+
396
+ response = azure_files_instance.retrieve_file(
397
+ _is_async=_is_async,
398
+ api_base=api_base,
399
+ api_key=api_key,
400
+ api_version=api_version,
401
+ timeout=timeout,
402
+ max_retries=optional_params.max_retries,
403
+ file_id=file_id,
404
+ )
405
+ else:
406
+ raise litellm.exceptions.BadRequestError(
407
+ message="LiteLLM doesn't support {} for 'file_retrieve'. Only 'openai' and 'azure' are supported.".format(
408
+ custom_llm_provider
409
+ ),
410
+ model="n/a",
411
+ llm_provider=custom_llm_provider,
412
+ response=httpx.Response(
413
+ status_code=400,
414
+ content="Unsupported provider",
415
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
416
+ ),
417
+ )
418
+ return cast(FileObject, response)
419
+ except Exception as e:
420
+ raise e
421
+
422
+
423
+ # Delete file
424
+ async def afile_delete(
425
+ file_id: str,
426
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
427
+ extra_headers: Optional[Dict[str, str]] = None,
428
+ extra_body: Optional[Dict[str, str]] = None,
429
+ **kwargs,
430
+ ) -> FileDeleted:
431
+ """
432
+ Async: Delete file
433
+
434
+ LiteLLM Equivalent of DELETE https://api.openai.com/v1/files
435
+ """
436
+ try:
437
+ loop = asyncio.get_event_loop()
438
+ kwargs["is_async"] = True
439
+
440
+ # Use a partial function to pass your keyword arguments
441
+ func = partial(
442
+ file_delete,
443
+ file_id,
444
+ custom_llm_provider,
445
+ extra_headers,
446
+ extra_body,
447
+ **kwargs,
448
+ )
449
+
450
+ # Add the context to the function
451
+ ctx = contextvars.copy_context()
452
+ func_with_context = partial(ctx.run, func)
453
+ init_response = await loop.run_in_executor(None, func_with_context)
454
+ if asyncio.iscoroutine(init_response):
455
+ response = await init_response
456
+ else:
457
+ response = init_response # type: ignore
458
+
459
+ return cast(FileDeleted, response) # type: ignore
460
+ except Exception as e:
461
+ raise e
462
+
463
+
464
+ def file_delete(
465
+ file_id: str,
466
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
467
+ extra_headers: Optional[Dict[str, str]] = None,
468
+ extra_body: Optional[Dict[str, str]] = None,
469
+ **kwargs,
470
+ ) -> FileDeleted:
471
+ """
472
+ Delete file
473
+
474
+ LiteLLM Equivalent of DELETE https://api.openai.com/v1/files
475
+ """
476
+ try:
477
+ optional_params = GenericLiteLLMParams(**kwargs)
478
+ litellm_params_dict = get_litellm_params(**kwargs)
479
+ ### TIMEOUT LOGIC ###
480
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
481
+ # set timeout for 10 minutes by default
482
+ client = kwargs.get("client")
483
+
484
+ if (
485
+ timeout is not None
486
+ and isinstance(timeout, httpx.Timeout)
487
+ and supports_httpx_timeout(custom_llm_provider) is False
488
+ ):
489
+ read_timeout = timeout.read or 600
490
+ timeout = read_timeout # default 10 min timeout
491
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
492
+ timeout = float(timeout) # type: ignore
493
+ elif timeout is None:
494
+ timeout = 600.0
495
+ _is_async = kwargs.pop("is_async", False) is True
496
+ if custom_llm_provider == "openai":
497
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
498
+ api_base = (
499
+ optional_params.api_base
500
+ or litellm.api_base
501
+ or os.getenv("OPENAI_BASE_URL")
502
+ or os.getenv("OPENAI_API_BASE")
503
+ or "https://api.openai.com/v1"
504
+ )
505
+ organization = (
506
+ optional_params.organization
507
+ or litellm.organization
508
+ or os.getenv("OPENAI_ORGANIZATION", None)
509
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
510
+ )
511
+ # set API KEY
512
+ api_key = (
513
+ optional_params.api_key
514
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
515
+ or litellm.openai_key
516
+ or os.getenv("OPENAI_API_KEY")
517
+ )
518
+ response = openai_files_instance.delete_file(
519
+ file_id=file_id,
520
+ _is_async=_is_async,
521
+ api_base=api_base,
522
+ api_key=api_key,
523
+ timeout=timeout,
524
+ max_retries=optional_params.max_retries,
525
+ organization=organization,
526
+ )
527
+ elif custom_llm_provider == "azure":
528
+ api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
529
+ api_version = (
530
+ optional_params.api_version
531
+ or litellm.api_version
532
+ or get_secret_str("AZURE_API_VERSION")
533
+ ) # type: ignore
534
+
535
+ api_key = (
536
+ optional_params.api_key
537
+ or litellm.api_key
538
+ or litellm.azure_key
539
+ or get_secret_str("AZURE_OPENAI_API_KEY")
540
+ or get_secret_str("AZURE_API_KEY")
541
+ ) # type: ignore
542
+
543
+ extra_body = optional_params.get("extra_body", {})
544
+ if extra_body is not None:
545
+ extra_body.pop("azure_ad_token", None)
546
+ else:
547
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
548
+
549
+ response = azure_files_instance.delete_file(
550
+ _is_async=_is_async,
551
+ api_base=api_base,
552
+ api_key=api_key,
553
+ api_version=api_version,
554
+ timeout=timeout,
555
+ max_retries=optional_params.max_retries,
556
+ file_id=file_id,
557
+ client=client,
558
+ litellm_params=litellm_params_dict,
559
+ )
560
+ else:
561
+ raise litellm.exceptions.BadRequestError(
562
+ message="LiteLLM doesn't support {} for 'file_delete'. Only 'openai' and 'azure' are supported.".format(
563
+ custom_llm_provider
564
+ ),
565
+ model="n/a",
566
+ llm_provider=custom_llm_provider,
567
+ response=httpx.Response(
568
+ status_code=400,
569
+ content="Unsupported provider",
570
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
571
+ ),
572
+ )
573
+ return cast(FileDeleted, response)
574
+ except Exception as e:
575
+ raise e
576
+
577
+
578
+ # List files
579
+ async def afile_list(
580
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
581
+ purpose: Optional[str] = None,
582
+ extra_headers: Optional[Dict[str, str]] = None,
583
+ extra_body: Optional[Dict[str, str]] = None,
584
+ **kwargs,
585
+ ):
586
+ """
587
+ Async: List files
588
+
589
+ LiteLLM Equivalent of GET https://api.openai.com/v1/files
590
+ """
591
+ try:
592
+ loop = asyncio.get_event_loop()
593
+ kwargs["is_async"] = True
594
+
595
+ # Use a partial function to pass your keyword arguments
596
+ func = partial(
597
+ file_list,
598
+ custom_llm_provider,
599
+ purpose,
600
+ extra_headers,
601
+ extra_body,
602
+ **kwargs,
603
+ )
604
+
605
+ # Add the context to the function
606
+ ctx = contextvars.copy_context()
607
+ func_with_context = partial(ctx.run, func)
608
+ init_response = await loop.run_in_executor(None, func_with_context)
609
+ if asyncio.iscoroutine(init_response):
610
+ response = await init_response
611
+ else:
612
+ response = init_response # type: ignore
613
+
614
+ return response
615
+ except Exception as e:
616
+ raise e
617
+
618
+
619
+ def file_list(
620
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
621
+ purpose: Optional[str] = None,
622
+ extra_headers: Optional[Dict[str, str]] = None,
623
+ extra_body: Optional[Dict[str, str]] = None,
624
+ **kwargs,
625
+ ):
626
+ """
627
+ List files
628
+
629
+ LiteLLM Equivalent of GET https://api.openai.com/v1/files
630
+ """
631
+ try:
632
+ optional_params = GenericLiteLLMParams(**kwargs)
633
+ ### TIMEOUT LOGIC ###
634
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
635
+ # set timeout for 10 minutes by default
636
+
637
+ if (
638
+ timeout is not None
639
+ and isinstance(timeout, httpx.Timeout)
640
+ and supports_httpx_timeout(custom_llm_provider) is False
641
+ ):
642
+ read_timeout = timeout.read or 600
643
+ timeout = read_timeout # default 10 min timeout
644
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
645
+ timeout = float(timeout) # type: ignore
646
+ elif timeout is None:
647
+ timeout = 600.0
648
+
649
+ _is_async = kwargs.pop("is_async", False) is True
650
+ if custom_llm_provider == "openai":
651
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
652
+ api_base = (
653
+ optional_params.api_base
654
+ or litellm.api_base
655
+ or os.getenv("OPENAI_BASE_URL")
656
+ or os.getenv("OPENAI_API_BASE")
657
+ or "https://api.openai.com/v1"
658
+ )
659
+ organization = (
660
+ optional_params.organization
661
+ or litellm.organization
662
+ or os.getenv("OPENAI_ORGANIZATION", None)
663
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
664
+ )
665
+ # set API KEY
666
+ api_key = (
667
+ optional_params.api_key
668
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
669
+ or litellm.openai_key
670
+ or os.getenv("OPENAI_API_KEY")
671
+ )
672
+
673
+ response = openai_files_instance.list_files(
674
+ purpose=purpose,
675
+ _is_async=_is_async,
676
+ api_base=api_base,
677
+ api_key=api_key,
678
+ timeout=timeout,
679
+ max_retries=optional_params.max_retries,
680
+ organization=organization,
681
+ )
682
+ elif custom_llm_provider == "azure":
683
+ api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
684
+ api_version = (
685
+ optional_params.api_version
686
+ or litellm.api_version
687
+ or get_secret_str("AZURE_API_VERSION")
688
+ ) # type: ignore
689
+
690
+ api_key = (
691
+ optional_params.api_key
692
+ or litellm.api_key
693
+ or litellm.azure_key
694
+ or get_secret_str("AZURE_OPENAI_API_KEY")
695
+ or get_secret_str("AZURE_API_KEY")
696
+ ) # type: ignore
697
+
698
+ extra_body = optional_params.get("extra_body", {})
699
+ if extra_body is not None:
700
+ extra_body.pop("azure_ad_token", None)
701
+ else:
702
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
703
+
704
+ response = azure_files_instance.list_files(
705
+ _is_async=_is_async,
706
+ api_base=api_base,
707
+ api_key=api_key,
708
+ api_version=api_version,
709
+ timeout=timeout,
710
+ max_retries=optional_params.max_retries,
711
+ purpose=purpose,
712
+ )
713
+ else:
714
+ raise litellm.exceptions.BadRequestError(
715
+ message="LiteLLM doesn't support {} for 'file_list'. Only 'openai' and 'azure' are supported.".format(
716
+ custom_llm_provider
717
+ ),
718
+ model="n/a",
719
+ llm_provider=custom_llm_provider,
720
+ response=httpx.Response(
721
+ status_code=400,
722
+ content="Unsupported provider",
723
+ request=httpx.Request(method="file_list", url="https://github.com/BerriAI/litellm"), # type: ignore
724
+ ),
725
+ )
726
+ return response
727
+ except Exception as e:
728
+ raise e
729
+
730
+
731
+ async def afile_content(
732
+ file_id: str,
733
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
734
+ extra_headers: Optional[Dict[str, str]] = None,
735
+ extra_body: Optional[Dict[str, str]] = None,
736
+ **kwargs,
737
+ ) -> HttpxBinaryResponseContent:
738
+ """
739
+ Async: Get file contents
740
+
741
+ LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}/content
742
+ """
743
+ try:
744
+ loop = asyncio.get_event_loop()
745
+ kwargs["afile_content"] = True
746
+
747
+ # Use a partial function to pass your keyword arguments
748
+ func = partial(
749
+ file_content,
750
+ file_id,
751
+ custom_llm_provider,
752
+ extra_headers,
753
+ extra_body,
754
+ **kwargs,
755
+ )
756
+
757
+ # Add the context to the function
758
+ ctx = contextvars.copy_context()
759
+ func_with_context = partial(ctx.run, func)
760
+ init_response = await loop.run_in_executor(None, func_with_context)
761
+ if asyncio.iscoroutine(init_response):
762
+ response = await init_response
763
+ else:
764
+ response = init_response # type: ignore
765
+
766
+ return response
767
+ except Exception as e:
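For completeness, a hedged sketch of pulling a file's bytes back down with `afile_content` (e.g. fetching batch output). The file id is a placeholder, and reading the payload via `.content` assumes the returned `HttpxBinaryResponseContent` mirrors the OpenAI SDK's binary-response helper:

```python
import asyncio

import litellm


async def download(file_id: str) -> bytes:
    file_content_response = await litellm.afile_content(
        file_id=file_id,
        custom_llm_provider="openai",
    )
    return file_content_response.content  # assumed bytes payload


print(len(asyncio.run(download("file-abc123"))))
```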
768
+ raise e
769
+
770
+
771
+ def file_content(
772
+ file_id: str,
773
+ custom_llm_provider: Literal["openai", "azure"] = "openai",
774
+ extra_headers: Optional[Dict[str, str]] = None,
775
+ extra_body: Optional[Dict[str, str]] = None,
776
+ **kwargs,
777
+ ) -> Union[HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]]:
778
+ """
779
+ Returns the contents of the specified file.
780
+
781
+ LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}/content
782
+ """
783
+ try:
784
+ optional_params = GenericLiteLLMParams(**kwargs)
785
+ litellm_params_dict = get_litellm_params(**kwargs)
786
+ ### TIMEOUT LOGIC ###
787
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
788
+ client = kwargs.get("client")
789
+ # set timeout for 10 minutes by default
790
+
791
+ if (
792
+ timeout is not None
793
+ and isinstance(timeout, httpx.Timeout)
794
+ and supports_httpx_timeout(custom_llm_provider) is False
795
+ ):
796
+ read_timeout = timeout.read or 600
797
+ timeout = read_timeout # default 10 min timeout
798
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
799
+ timeout = float(timeout) # type: ignore
800
+ elif timeout is None:
801
+ timeout = 600.0
802
+
803
+ _file_content_request = FileContentRequest(
804
+ file_id=file_id,
805
+ extra_headers=extra_headers,
806
+ extra_body=extra_body,
807
+ )
808
+
809
+ _is_async = kwargs.pop("afile_content", False) is True
810
+
811
+ if custom_llm_provider == "openai":
812
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
813
+ api_base = (
814
+ optional_params.api_base
815
+ or litellm.api_base
816
+ or os.getenv("OPENAI_BASE_URL")
817
+ or os.getenv("OPENAI_API_BASE")
818
+ or "https://api.openai.com/v1"
819
+ )
820
+ organization = (
821
+ optional_params.organization
822
+ or litellm.organization
823
+ or os.getenv("OPENAI_ORGANIZATION", None)
824
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
825
+ )
826
+ # set API KEY
827
+ api_key = (
828
+ optional_params.api_key
829
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
830
+ or litellm.openai_key
831
+ or os.getenv("OPENAI_API_KEY")
832
+ )
833
+
834
+ response = openai_files_instance.file_content(
835
+ _is_async=_is_async,
836
+ file_content_request=_file_content_request,
837
+ api_base=api_base,
838
+ api_key=api_key,
839
+ timeout=timeout,
840
+ max_retries=optional_params.max_retries,
841
+ organization=organization,
842
+ )
843
+ elif custom_llm_provider == "azure":
844
+ api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
845
+ api_version = (
846
+ optional_params.api_version
847
+ or litellm.api_version
848
+ or get_secret_str("AZURE_API_VERSION")
849
+ ) # type: ignore
850
+
851
+ api_key = (
852
+ optional_params.api_key
853
+ or litellm.api_key
854
+ or litellm.azure_key
855
+ or get_secret_str("AZURE_OPENAI_API_KEY")
856
+ or get_secret_str("AZURE_API_KEY")
857
+ ) # type: ignore
858
+
859
+ extra_body = optional_params.get("extra_body", {})
860
+ if extra_body is not None:
861
+ extra_body.pop("azure_ad_token", None)
862
+ else:
863
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
864
+
865
+ response = azure_files_instance.file_content(
866
+ _is_async=_is_async,
867
+ api_base=api_base,
868
+ api_key=api_key,
869
+ api_version=api_version,
870
+ timeout=timeout,
871
+ max_retries=optional_params.max_retries,
872
+ file_content_request=_file_content_request,
873
+ client=client,
874
+ litellm_params=litellm_params_dict,
875
+ )
876
+ else:
877
+ raise litellm.exceptions.BadRequestError(
878
+ message="LiteLLM doesn't support {} for 'custom_llm_provider'. Supported providers are 'openai', 'azure', 'vertex_ai'.".format(
879
+ custom_llm_provider
880
+ ),
881
+ model="n/a",
882
+ llm_provider=custom_llm_provider,
883
+ response=httpx.Response(
884
+ status_code=400,
885
+ content="Unsupported provider",
886
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
887
+ ),
888
+ )
889
+ return response
890
+ except Exception as e:
891
+ raise e
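
A minimal usage sketch of the file-content helpers added above. This is an illustrative example, not part of the commit: it assumes `file_content` / `afile_content` are re-exported at the `litellm` package level (as the other file APIs are), that `OPENAI_API_KEY` is set in the environment, and that `file-abc123` is a placeholder file id.

```python
import asyncio

import litellm


def sync_example() -> None:
    # file_content() returns an HttpxBinaryResponseContent wrapper around the raw bytes
    content = litellm.file_content(
        file_id="file-abc123",  # placeholder id of a previously uploaded file
        custom_llm_provider="openai",
    )
    print(content.content[:100])


async def async_example() -> None:
    # afile_content() runs the same call without blocking the event loop
    content = await litellm.afile_content(
        file_id="file-abc123",
        custom_llm_provider="openai",
    )
    print(content.content[:100])


if __name__ == "__main__":
    sync_example()
    asyncio.run(async_example())
```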
litellm/fine_tuning/main.py ADDED
@@ -0,0 +1,761 @@
1
+ """
2
+ Main File for Fine Tuning API implementation
3
+
4
+ https://platform.openai.com/docs/api-reference/fine-tuning
5
+
6
+ - fine_tuning.jobs.create()
7
+ - fine_tuning.jobs.list()
8
+ - client.fine_tuning.jobs.list_events()
9
+ """
10
+
11
+ import asyncio
12
+ import contextvars
13
+ import os
14
+ from functools import partial
15
+ from typing import Any, Coroutine, Dict, Literal, Optional, Union
16
+
17
+ import httpx
18
+
19
+ import litellm
20
+ from litellm._logging import verbose_logger
21
+ from litellm.llms.azure.fine_tuning.handler import AzureOpenAIFineTuningAPI
22
+ from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
23
+ from litellm.llms.vertex_ai.fine_tuning.handler import VertexFineTuningAPI
24
+ from litellm.secret_managers.main import get_secret_str
25
+ from litellm.types.llms.openai import (
26
+ FineTuningJob,
27
+ FineTuningJobCreate,
28
+ Hyperparameters,
29
+ )
30
+ from litellm.types.router import *
31
+ from litellm.utils import client, supports_httpx_timeout
32
+
33
+ ####### ENVIRONMENT VARIABLES ###################
34
+ openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
35
+ azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
36
+ vertex_fine_tuning_apis_instance = VertexFineTuningAPI()
37
+ #################################################
38
+
39
+
40
+ @client
41
+ async def acreate_fine_tuning_job(
42
+ model: str,
43
+ training_file: str,
44
+ hyperparameters: Optional[dict] = {},
45
+ suffix: Optional[str] = None,
46
+ validation_file: Optional[str] = None,
47
+ integrations: Optional[List[str]] = None,
48
+ seed: Optional[int] = None,
49
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
50
+ extra_headers: Optional[Dict[str, str]] = None,
51
+ extra_body: Optional[Dict[str, str]] = None,
52
+ **kwargs,
53
+ ) -> FineTuningJob:
54
+ """
55
+ Async: Creates a fine-tuning job, which begins the process of creating a new model from a given dataset
56
+
57
+ """
58
+ verbose_logger.debug(
59
+ "inside acreate_fine_tuning_job model=%s and kwargs=%s", model, kwargs
60
+ )
61
+ try:
62
+ loop = asyncio.get_event_loop()
63
+ kwargs["acreate_fine_tuning_job"] = True
64
+
65
+ # Use a partial function to pass your keyword arguments
66
+ func = partial(
67
+ create_fine_tuning_job,
68
+ model,
69
+ training_file,
70
+ hyperparameters,
71
+ suffix,
72
+ validation_file,
73
+ integrations,
74
+ seed,
75
+ custom_llm_provider,
76
+ extra_headers,
77
+ extra_body,
78
+ **kwargs,
79
+ )
80
+
81
+ # Add the context to the function
82
+ ctx = contextvars.copy_context()
83
+ func_with_context = partial(ctx.run, func)
84
+ init_response = await loop.run_in_executor(None, func_with_context)
85
+ if asyncio.iscoroutine(init_response):
86
+ response = await init_response
87
+ else:
88
+ response = init_response # type: ignore
89
+ return response
90
+ except Exception as e:
91
+ raise e
92
+
93
+
94
+ @client
95
+ def create_fine_tuning_job(
96
+ model: str,
97
+ training_file: str,
98
+ hyperparameters: Optional[dict] = {},
99
+ suffix: Optional[str] = None,
100
+ validation_file: Optional[str] = None,
101
+ integrations: Optional[List[str]] = None,
102
+ seed: Optional[int] = None,
103
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
104
+ extra_headers: Optional[Dict[str, str]] = None,
105
+ extra_body: Optional[Dict[str, str]] = None,
106
+ **kwargs,
107
+ ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
108
+ """
109
+ Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
110
+
111
+ Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete
112
+
113
+ """
114
+ try:
115
+ _is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
116
+ optional_params = GenericLiteLLMParams(**kwargs)
117
+
118
+ # handle hyperparameters
119
+ hyperparameters = hyperparameters or {} # original hyperparameters
120
+ _oai_hyperparameters: Hyperparameters = Hyperparameters(
121
+ **hyperparameters
122
+ ) # Typed Hyperparameters for OpenAI Spec
123
+ ### TIMEOUT LOGIC ###
124
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
125
+ # set timeout for 10 minutes by default
126
+
127
+ if (
128
+ timeout is not None
129
+ and isinstance(timeout, httpx.Timeout)
130
+ and supports_httpx_timeout(custom_llm_provider) is False
131
+ ):
132
+ read_timeout = timeout.read or 600
133
+ timeout = read_timeout # default 10 min timeout
134
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
135
+ timeout = float(timeout) # type: ignore
136
+ elif timeout is None:
137
+ timeout = 600.0
138
+
139
+ # OpenAI
140
+ if custom_llm_provider == "openai":
141
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
142
+ api_base = (
143
+ optional_params.api_base
144
+ or litellm.api_base
145
+ or os.getenv("OPENAI_BASE_URL")
146
+ or os.getenv("OPENAI_API_BASE")
147
+ or "https://api.openai.com/v1"
148
+ )
149
+ organization = (
150
+ optional_params.organization
151
+ or litellm.organization
152
+ or os.getenv("OPENAI_ORGANIZATION", None)
153
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
154
+ )
155
+ # set API KEY
156
+ api_key = (
157
+ optional_params.api_key
158
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
159
+ or litellm.openai_key
160
+ or os.getenv("OPENAI_API_KEY")
161
+ )
162
+
163
+ create_fine_tuning_job_data = FineTuningJobCreate(
164
+ model=model,
165
+ training_file=training_file,
166
+ hyperparameters=_oai_hyperparameters,
167
+ suffix=suffix,
168
+ validation_file=validation_file,
169
+ integrations=integrations,
170
+ seed=seed,
171
+ )
172
+
173
+ create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
174
+ exclude_none=True
175
+ )
176
+
177
+ response = openai_fine_tuning_apis_instance.create_fine_tuning_job(
178
+ api_base=api_base,
179
+ api_key=api_key,
180
+ api_version=optional_params.api_version,
181
+ organization=organization,
182
+ create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
183
+ timeout=timeout,
184
+ max_retries=optional_params.max_retries,
185
+ _is_async=_is_async,
186
+ client=kwargs.get(
187
+ "client", None
188
+ ), # note, when we add this to `GenericLiteLLMParams` it impacts a lot of other tests + linting
189
+ )
190
+ # Azure OpenAI
191
+ elif custom_llm_provider == "azure":
192
+ api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
193
+
194
+ api_version = (
195
+ optional_params.api_version
196
+ or litellm.api_version
197
+ or get_secret_str("AZURE_API_VERSION")
198
+ ) # type: ignore
199
+
200
+ api_key = (
201
+ optional_params.api_key
202
+ or litellm.api_key
203
+ or litellm.azure_key
204
+ or get_secret_str("AZURE_OPENAI_API_KEY")
205
+ or get_secret_str("AZURE_API_KEY")
206
+ ) # type: ignore
207
+
208
+ extra_body = optional_params.get("extra_body", {})
209
+ if extra_body is not None:
210
+ extra_body.pop("azure_ad_token", None)
211
+ else:
212
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
213
+ create_fine_tuning_job_data = FineTuningJobCreate(
214
+ model=model,
215
+ training_file=training_file,
216
+ hyperparameters=_oai_hyperparameters,
217
+ suffix=suffix,
218
+ validation_file=validation_file,
219
+ integrations=integrations,
220
+ seed=seed,
221
+ )
222
+
223
+ create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
224
+ exclude_none=True
225
+ )
226
+
227
+ response = azure_fine_tuning_apis_instance.create_fine_tuning_job(
228
+ api_base=api_base,
229
+ api_key=api_key,
230
+ api_version=api_version,
231
+ create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
232
+ timeout=timeout,
233
+ max_retries=optional_params.max_retries,
234
+ _is_async=_is_async,
235
+ organization=optional_params.organization,
236
+ )
237
+ elif custom_llm_provider == "vertex_ai":
238
+ api_base = optional_params.api_base or ""
239
+ vertex_ai_project = (
240
+ optional_params.vertex_project
241
+ or litellm.vertex_project
242
+ or get_secret_str("VERTEXAI_PROJECT")
243
+ )
244
+ vertex_ai_location = (
245
+ optional_params.vertex_location
246
+ or litellm.vertex_location
247
+ or get_secret_str("VERTEXAI_LOCATION")
248
+ )
249
+ vertex_credentials = optional_params.vertex_credentials or get_secret_str(
250
+ "VERTEXAI_CREDENTIALS"
251
+ )
252
+ create_fine_tuning_job_data = FineTuningJobCreate(
253
+ model=model,
254
+ training_file=training_file,
255
+ hyperparameters=_oai_hyperparameters,
256
+ suffix=suffix,
257
+ validation_file=validation_file,
258
+ integrations=integrations,
259
+ seed=seed,
260
+ )
261
+ response = vertex_fine_tuning_apis_instance.create_fine_tuning_job(
262
+ _is_async=_is_async,
263
+ create_fine_tuning_job_data=create_fine_tuning_job_data,
264
+ vertex_credentials=vertex_credentials,
265
+ vertex_project=vertex_ai_project,
266
+ vertex_location=vertex_ai_location,
267
+ timeout=timeout,
268
+ api_base=api_base,
269
+ kwargs=kwargs,
270
+ original_hyperparameters=hyperparameters,
271
+ )
272
+ else:
273
+ raise litellm.exceptions.BadRequestError(
274
+ message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
275
+ custom_llm_provider
276
+ ),
277
+ model="n/a",
278
+ llm_provider=custom_llm_provider,
279
+ response=httpx.Response(
280
+ status_code=400,
281
+ content="Unsupported provider",
282
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
283
+ ),
284
+ )
285
+ return response
286
+ except Exception as e:
287
+ verbose_logger.error("got exception in create_fine_tuning_job=%s", str(e))
288
+ raise e
289
+
290
+
291
+ async def acancel_fine_tuning_job(
292
+ fine_tuning_job_id: str,
293
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
294
+ extra_headers: Optional[Dict[str, str]] = None,
295
+ extra_body: Optional[Dict[str, str]] = None,
296
+ **kwargs,
297
+ ) -> FineTuningJob:
298
+ """
299
+ Async: Immediately cancel a fine-tune job.
300
+ """
301
+ try:
302
+ loop = asyncio.get_event_loop()
303
+ kwargs["acancel_fine_tuning_job"] = True
304
+
305
+ # Use a partial function to pass your keyword arguments
306
+ func = partial(
307
+ cancel_fine_tuning_job,
308
+ fine_tuning_job_id,
309
+ custom_llm_provider,
310
+ extra_headers,
311
+ extra_body,
312
+ **kwargs,
313
+ )
314
+
315
+ # Add the context to the function
316
+ ctx = contextvars.copy_context()
317
+ func_with_context = partial(ctx.run, func)
318
+ init_response = await loop.run_in_executor(None, func_with_context)
319
+ if asyncio.iscoroutine(init_response):
320
+ response = await init_response
321
+ else:
322
+ response = init_response # type: ignore
323
+ return response
324
+ except Exception as e:
325
+ raise e
326
+
327
+
328
+ def cancel_fine_tuning_job(
329
+ fine_tuning_job_id: str,
330
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
331
+ extra_headers: Optional[Dict[str, str]] = None,
332
+ extra_body: Optional[Dict[str, str]] = None,
333
+ **kwargs,
334
+ ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
335
+ """
336
+ Immediately cancel a fine-tune job.
337
+
338
+ Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete
339
+
340
+ """
341
+ try:
342
+ optional_params = GenericLiteLLMParams(**kwargs)
343
+ ### TIMEOUT LOGIC ###
344
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
345
+ # set timeout for 10 minutes by default
346
+
347
+ if (
348
+ timeout is not None
349
+ and isinstance(timeout, httpx.Timeout)
350
+ and supports_httpx_timeout(custom_llm_provider) is False
351
+ ):
352
+ read_timeout = timeout.read or 600
353
+ timeout = read_timeout # default 10 min timeout
354
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
355
+ timeout = float(timeout) # type: ignore
356
+ elif timeout is None:
357
+ timeout = 600.0
358
+
359
+ _is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
360
+
361
+ # OpenAI
362
+ if custom_llm_provider == "openai":
363
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
364
+ api_base = (
365
+ optional_params.api_base
366
+ or litellm.api_base
367
+ or os.getenv("OPENAI_BASE_URL")
368
+ or os.getenv("OPENAI_API_BASE")
369
+ or "https://api.openai.com/v1"
370
+ )
371
+ organization = (
372
+ optional_params.organization
373
+ or litellm.organization
374
+ or os.getenv("OPENAI_ORGANIZATION", None)
375
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
376
+ )
377
+ # set API KEY
378
+ api_key = (
379
+ optional_params.api_key
380
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
381
+ or litellm.openai_key
382
+ or os.getenv("OPENAI_API_KEY")
383
+ )
384
+
385
+ response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
386
+ api_base=api_base,
387
+ api_key=api_key,
388
+ api_version=optional_params.api_version,
389
+ organization=organization,
390
+ fine_tuning_job_id=fine_tuning_job_id,
391
+ timeout=timeout,
392
+ max_retries=optional_params.max_retries,
393
+ _is_async=_is_async,
394
+ client=kwargs.get("client", None),
395
+ )
396
+ # Azure OpenAI
397
+ elif custom_llm_provider == "azure":
398
+ api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
399
+
400
+ api_version = (
401
+ optional_params.api_version
402
+ or litellm.api_version
403
+ or get_secret_str("AZURE_API_VERSION")
404
+ ) # type: ignore
405
+
406
+ api_key = (
407
+ optional_params.api_key
408
+ or litellm.api_key
409
+ or litellm.azure_key
410
+ or get_secret_str("AZURE_OPENAI_API_KEY")
411
+ or get_secret_str("AZURE_API_KEY")
412
+ ) # type: ignore
413
+
414
+ extra_body = optional_params.get("extra_body", {})
415
+ if extra_body is not None:
416
+ extra_body.pop("azure_ad_token", None)
417
+ else:
418
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
419
+
420
+ response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
421
+ api_base=api_base,
422
+ api_key=api_key,
423
+ api_version=api_version,
424
+ fine_tuning_job_id=fine_tuning_job_id,
425
+ timeout=timeout,
426
+ max_retries=optional_params.max_retries,
427
+ _is_async=_is_async,
428
+ organization=optional_params.organization,
429
+ )
430
+ else:
431
+ raise litellm.exceptions.BadRequestError(
432
+ message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
433
+ custom_llm_provider
434
+ ),
435
+ model="n/a",
436
+ llm_provider=custom_llm_provider,
437
+ response=httpx.Response(
438
+ status_code=400,
439
+ content="Unsupported provider",
440
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
441
+ ),
442
+ )
443
+ return response
444
+ except Exception as e:
445
+ raise e
446
+
447
+
448
+ async def alist_fine_tuning_jobs(
449
+ after: Optional[str] = None,
450
+ limit: Optional[int] = None,
451
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
452
+ extra_headers: Optional[Dict[str, str]] = None,
453
+ extra_body: Optional[Dict[str, str]] = None,
454
+ **kwargs,
455
+ ):
456
+ """
457
+ Async: List your organization's fine-tuning jobs
458
+ """
459
+ try:
460
+ loop = asyncio.get_event_loop()
461
+ kwargs["alist_fine_tuning_jobs"] = True
462
+
463
+ # Use a partial function to pass your keyword arguments
464
+ func = partial(
465
+ list_fine_tuning_jobs,
466
+ after,
467
+ limit,
468
+ custom_llm_provider,
469
+ extra_headers,
470
+ extra_body,
471
+ **kwargs,
472
+ )
473
+
474
+ # Add the context to the function
475
+ ctx = contextvars.copy_context()
476
+ func_with_context = partial(ctx.run, func)
477
+ init_response = await loop.run_in_executor(None, func_with_context)
478
+ if asyncio.iscoroutine(init_response):
479
+ response = await init_response
480
+ else:
481
+ response = init_response # type: ignore
482
+ return response
483
+ except Exception as e:
484
+ raise e
485
+
486
+
487
+ def list_fine_tuning_jobs(
488
+ after: Optional[str] = None,
489
+ limit: Optional[int] = None,
490
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
491
+ extra_headers: Optional[Dict[str, str]] = None,
492
+ extra_body: Optional[Dict[str, str]] = None,
493
+ **kwargs,
494
+ ):
495
+ """
496
+ List your organization's fine-tuning jobs
497
+
498
+ Params:
499
+
500
+ - after: Optional[str] = None, Identifier for the last job from the previous pagination request.
501
+ - limit: Optional[int] = None, Number of fine-tuning jobs to retrieve. Defaults to 20
502
+ """
503
+ try:
504
+ optional_params = GenericLiteLLMParams(**kwargs)
505
+ ### TIMEOUT LOGIC ###
506
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
507
+ # set timeout for 10 minutes by default
508
+
509
+ if (
510
+ timeout is not None
511
+ and isinstance(timeout, httpx.Timeout)
512
+ and supports_httpx_timeout(custom_llm_provider) is False
513
+ ):
514
+ read_timeout = timeout.read or 600
515
+ timeout = read_timeout # default 10 min timeout
516
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
517
+ timeout = float(timeout) # type: ignore
518
+ elif timeout is None:
519
+ timeout = 600.0
520
+
521
+ _is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
522
+
523
+ # OpenAI
524
+ if custom_llm_provider == "openai":
525
+ # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
526
+ api_base = (
527
+ optional_params.api_base
528
+ or litellm.api_base
529
+ or os.getenv("OPENAI_BASE_URL")
530
+ or os.getenv("OPENAI_API_BASE")
531
+ or "https://api.openai.com/v1"
532
+ )
533
+ organization = (
534
+ optional_params.organization
535
+ or litellm.organization
536
+ or os.getenv("OPENAI_ORGANIZATION", None)
537
+ or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
538
+ )
539
+ # set API KEY
540
+ api_key = (
541
+ optional_params.api_key
542
+ or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
543
+ or litellm.openai_key
544
+ or os.getenv("OPENAI_API_KEY")
545
+ )
546
+
547
+ response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
548
+ api_base=api_base,
549
+ api_key=api_key,
550
+ api_version=optional_params.api_version,
551
+ organization=organization,
552
+ after=after,
553
+ limit=limit,
554
+ timeout=timeout,
555
+ max_retries=optional_params.max_retries,
556
+ _is_async=_is_async,
557
+ client=kwargs.get("client", None),
558
+ )
559
+ # Azure OpenAI
560
+ elif custom_llm_provider == "azure":
561
+ api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
562
+
563
+ api_version = (
564
+ optional_params.api_version
565
+ or litellm.api_version
566
+ or get_secret_str("AZURE_API_VERSION")
567
+ ) # type: ignore
568
+
569
+ api_key = (
570
+ optional_params.api_key
571
+ or litellm.api_key
572
+ or litellm.azure_key
573
+ or get_secret_str("AZURE_OPENAI_API_KEY")
574
+ or get_secret_str("AZURE_API_KEY")
575
+ ) # type: ignore
576
+
577
+ extra_body = optional_params.get("extra_body", {})
578
+ if extra_body is not None:
579
+ extra_body.pop("azure_ad_token", None)
580
+ else:
581
+ get_secret("AZURE_AD_TOKEN") # type: ignore
582
+
583
+ response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
584
+ api_base=api_base,
585
+ api_key=api_key,
586
+ api_version=api_version,
587
+ after=after,
588
+ limit=limit,
589
+ timeout=timeout,
590
+ max_retries=optional_params.max_retries,
591
+ _is_async=_is_async,
592
+ organization=optional_params.organization,
593
+ )
594
+ else:
595
+ raise litellm.exceptions.BadRequestError(
596
+ message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
597
+ custom_llm_provider
598
+ ),
599
+ model="n/a",
600
+ llm_provider=custom_llm_provider,
601
+ response=httpx.Response(
602
+ status_code=400,
603
+ content="Unsupported provider",
604
+ request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
605
+ ),
606
+ )
607
+ return response
608
+ except Exception as e:
609
+ raise e
610
+
611
+
612
+ async def aretrieve_fine_tuning_job(
613
+ fine_tuning_job_id: str,
614
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
615
+ extra_headers: Optional[Dict[str, str]] = None,
616
+ extra_body: Optional[Dict[str, str]] = None,
617
+ **kwargs,
618
+ ) -> FineTuningJob:
619
+ """
620
+ Async: Get info about a fine-tuning job.
621
+ """
622
+ try:
623
+ loop = asyncio.get_event_loop()
624
+ kwargs["aretrieve_fine_tuning_job"] = True
625
+
626
+ # Use a partial function to pass your keyword arguments
627
+ func = partial(
628
+ retrieve_fine_tuning_job,
629
+ fine_tuning_job_id,
630
+ custom_llm_provider,
631
+ extra_headers,
632
+ extra_body,
633
+ **kwargs,
634
+ )
635
+
636
+ # Add the context to the function
637
+ ctx = contextvars.copy_context()
638
+ func_with_context = partial(ctx.run, func)
639
+ init_response = await loop.run_in_executor(None, func_with_context)
640
+ if asyncio.iscoroutine(init_response):
641
+ response = await init_response
642
+ else:
643
+ response = init_response # type: ignore
644
+ return response
645
+ except Exception as e:
646
+ raise e
647
+
648
+
649
+ def retrieve_fine_tuning_job(
650
+ fine_tuning_job_id: str,
651
+ custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
652
+ extra_headers: Optional[Dict[str, str]] = None,
653
+ extra_body: Optional[Dict[str, str]] = None,
654
+ **kwargs,
655
+ ) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
656
+ """
657
+ Get info about a fine-tuning job.
658
+ """
659
+ try:
660
+ optional_params = GenericLiteLLMParams(**kwargs)
661
+ ### TIMEOUT LOGIC ###
662
+ timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
663
+ # set timeout for 10 minutes by default
664
+
665
+ if (
666
+ timeout is not None
667
+ and isinstance(timeout, httpx.Timeout)
668
+ and supports_httpx_timeout(custom_llm_provider) is False
669
+ ):
670
+ read_timeout = timeout.read or 600
671
+ timeout = read_timeout # default 10 min timeout
672
+ elif timeout is not None and not isinstance(timeout, httpx.Timeout):
673
+ timeout = float(timeout) # type: ignore
674
+ elif timeout is None:
675
+ timeout = 600.0
676
+
677
+ _is_async = kwargs.pop("aretrieve_fine_tuning_job", False) is True
678
+
679
+ # OpenAI
680
+ if custom_llm_provider == "openai":
681
+ api_base = (
682
+ optional_params.api_base
683
+ or litellm.api_base
684
+ or os.getenv("OPENAI_BASE_URL")
685
+ or os.getenv("OPENAI_API_BASE")
686
+ or "https://api.openai.com/v1"
687
+ )
688
+ organization = (
689
+ optional_params.organization
690
+ or litellm.organization
691
+ or os.getenv("OPENAI_ORGANIZATION", None)
692
+ or None
693
+ )
694
+ api_key = (
695
+ optional_params.api_key
696
+ or litellm.api_key
697
+ or litellm.openai_key
698
+ or os.getenv("OPENAI_API_KEY")
699
+ )
700
+
701
+ response = openai_fine_tuning_apis_instance.retrieve_fine_tuning_job(
702
+ api_base=api_base,
703
+ api_key=api_key,
704
+ api_version=optional_params.api_version,
705
+ organization=organization,
706
+ fine_tuning_job_id=fine_tuning_job_id,
707
+ timeout=timeout,
708
+ max_retries=optional_params.max_retries,
709
+ _is_async=_is_async,
710
+ client=kwargs.get("client", None),
711
+ )
712
+ # Azure OpenAI
713
+ elif custom_llm_provider == "azure":
714
+ api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
715
+
716
+ api_version = (
717
+ optional_params.api_version
718
+ or litellm.api_version
719
+ or get_secret_str("AZURE_API_VERSION")
720
+ ) # type: ignore
721
+
722
+ api_key = (
723
+ optional_params.api_key
724
+ or litellm.api_key
725
+ or litellm.azure_key
726
+ or get_secret_str("AZURE_OPENAI_API_KEY")
727
+ or get_secret_str("AZURE_API_KEY")
728
+ ) # type: ignore
729
+
730
+ extra_body = optional_params.get("extra_body", {})
731
+ if extra_body is not None:
732
+ extra_body.pop("azure_ad_token", None)
733
+ else:
734
+ get_secret_str("AZURE_AD_TOKEN") # type: ignore
735
+
736
+ response = azure_fine_tuning_apis_instance.retrieve_fine_tuning_job(
737
+ api_base=api_base,
738
+ api_key=api_key,
739
+ api_version=api_version,
740
+ fine_tuning_job_id=fine_tuning_job_id,
741
+ timeout=timeout,
742
+ max_retries=optional_params.max_retries,
743
+ _is_async=_is_async,
744
+ organization=optional_params.organization,
745
+ )
746
+ else:
747
+ raise litellm.exceptions.BadRequestError(
748
+ message="LiteLLM doesn't support {} for 'retrieve_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
749
+ custom_llm_provider
750
+ ),
751
+ model="n/a",
752
+ llm_provider=custom_llm_provider,
753
+ response=httpx.Response(
754
+ status_code=400,
755
+ content="Unsupported provider",
756
+ request=httpx.Request(method="retrieve_fine_tuning_job", url="https://github.com/BerriAI/litellm"), # type: ignore
757
+ ),
758
+ )
759
+ return response
760
+ except Exception as e:
761
+ raise e
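
To make the flow of `litellm/fine_tuning/main.py` concrete, here is a hedged end-to-end sketch of the main helpers it defines (create, retrieve, list, cancel). It assumes these functions are re-exported at the `litellm` package level, that `OPENAI_API_KEY` is set, and that `file-abc123` is a placeholder id of an uploaded training file; the model name and `n_epochs` value are arbitrary.

```python
import litellm

# create a job (the async variant is acreate_fine_tuning_job)
job = litellm.create_fine_tuning_job(
    model="gpt-3.5-turbo",
    training_file="file-abc123",
    hyperparameters={"n_epochs": 2},  # converted to the typed Hyperparameters object internally
    custom_llm_provider="openai",
)
print(job.id, job.status)

# poll the job
job = litellm.retrieve_fine_tuning_job(
    fine_tuning_job_id=job.id,
    custom_llm_provider="openai",
)

# list recent jobs, then cancel the one created above
jobs = litellm.list_fine_tuning_jobs(limit=5, custom_llm_provider="openai")
cancelled = litellm.cancel_fine_tuning_job(
    fine_tuning_job_id=job.id,
    custom_llm_provider="openai",
)
print(cancelled.status)
```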
litellm/integrations/Readme.md ADDED
@@ -0,0 +1,5 @@
1
+ # Integrations
2
+
3
+ This folder contains logging integrations for litellm.
4
+
5
+ For example: logging to Datadog, Langfuse, Prometheus, S3, GCS Buckets, etc.
litellm/integrations/SlackAlerting/Readme.md ADDED
@@ -0,0 +1,13 @@
1
+ # Slack Alerting on LiteLLM Gateway
2
+
3
+ This folder contains the Slack Alerting integration for LiteLLM Gateway.
4
+
5
+ ## Folder Structure
6
+
7
+ - `slack_alerting.py`: This is the main file that handles sending different types of alerts
8
+ - `batching_handler.py`: Handles batching and sending HTTPX POST requests to Slack. Slack alerts are sent every 10s or once the number of queued events exceeds a threshold, to keep LiteLLM performant under high traffic
9
+ - `types.py`: This file contains the AlertType enum which is used to define the different types of alerts that can be sent to Slack.
10
+ - `utils.py`: This file contains common utils used specifically for slack alerting
11
+
12
+ ## Further Reading
13
+ - [Doc setting up Alerting on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/alerting)
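
As a rough illustration of how this folder is wired together (not part of the commit), the sketch below constructs `SlackAlerting` directly and fires a single alert. The webhook URL is a placeholder, and the constructor and `send_alert` arguments are taken from the class added in `slack_alerting.py` further down in this diff.

```python
import asyncio

from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.proxy._types import AlertType

slack_alerting = SlackAlerting(
    alerting=["slack"],                  # enable slack as an alerting destination
    alerting_threshold=300,              # seconds before a request counts as slow / hanging
    alert_types=[AlertType.llm_too_slow],
    default_webhook_url="https://hooks.slack.com/services/XXX/YYY/ZZZ",  # placeholder URL
)

# send_alert is async; level / alert_type mirror the values used inside slack_alerting.py
asyncio.run(
    slack_alerting.send_alert(
        message="test alert from litellm",
        level="Low",
        alert_type=AlertType.llm_too_slow,
        alerting_metadata={},
    )
)
```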
litellm/integrations/SlackAlerting/batching_handler.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ Handles Batching + sending Httpx Post requests to slack
3
+
4
+ Slack alerts are sent every 10s or when events are greater than X events
5
+
6
+ see custom_batch_logger.py for more details / defaults
7
+ """
8
+
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ from litellm._logging import verbose_proxy_logger
12
+
13
+ if TYPE_CHECKING:
14
+ from .slack_alerting import SlackAlerting as _SlackAlerting
15
+
16
+ SlackAlertingType = _SlackAlerting
17
+ else:
18
+ SlackAlertingType = Any
19
+
20
+
21
+ def squash_payloads(queue):
22
+ squashed = {}
23
+ if len(queue) == 0:
24
+ return squashed
25
+ if len(queue) == 1:
26
+ return {"key": {"item": queue[0], "count": 1}}
27
+
28
+ for item in queue:
29
+ url = item["url"]
30
+ alert_type = item["alert_type"]
31
+ _key = (url, alert_type)
32
+
33
+ if _key in squashed:
34
+ squashed[_key]["count"] += 1
35
+ # Merge the payloads
36
+
37
+ else:
38
+ squashed[_key] = {"item": item, "count": 1}
39
+
40
+ return squashed
41
+
42
+
43
+ def _print_alerting_payload_warning(
44
+ payload: dict, slackAlertingInstance: SlackAlertingType
45
+ ):
46
+ """
47
+ Print the payload to the console when
48
+ slackAlertingInstance.alerting_args.log_to_console is True
49
+
50
+ Relevant issue: https://github.com/BerriAI/litellm/issues/7372
51
+ """
52
+ if slackAlertingInstance.alerting_args.log_to_console is True:
53
+ verbose_proxy_logger.warning(payload)
54
+
55
+
56
+ async def send_to_webhook(slackAlertingInstance: SlackAlertingType, item, count):
57
+ """
58
+ Send a single slack alert to the webhook
59
+ """
60
+ import json
61
+
62
+ payload = item.get("payload", {})
63
+ try:
64
+ if count > 1:
65
+ payload["text"] = f"[Num Alerts: {count}]\n\n{payload['text']}"
66
+
67
+ response = await slackAlertingInstance.async_http_handler.post(
68
+ url=item["url"],
69
+ headers=item["headers"],
70
+ data=json.dumps(payload),
71
+ )
72
+ if response.status_code != 200:
73
+ verbose_proxy_logger.debug(
74
+ f"Error sending slack alert to url={item['url']}. Error={response.text}"
75
+ )
76
+ except Exception as e:
77
+ verbose_proxy_logger.debug(f"Error sending slack alert: {str(e)}")
78
+ finally:
79
+ _print_alerting_payload_warning(
80
+ payload, slackAlertingInstance=slackAlertingInstance
81
+ )
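
A small, self-contained sketch of how `squash_payloads()` groups a queue of pending alerts by `(url, alert_type)` before `send_to_webhook()` posts one message per group; the queue items and webhook URLs below are made up for demonstration.

```python
from litellm.integrations.SlackAlerting.batching_handler import squash_payloads

queue = [
    {"url": "https://hooks.slack.com/services/A", "alert_type": "llm_too_slow", "payload": {"text": "slow #1"}},
    {"url": "https://hooks.slack.com/services/A", "alert_type": "llm_too_slow", "payload": {"text": "slow #2"}},
    {"url": "https://hooks.slack.com/services/B", "alert_type": "budget_alerts", "payload": {"text": "budget"}},
]

squashed = squash_payloads(queue)
for (url, alert_type), entry in squashed.items():
    # send_to_webhook() would post entry["item"] once, prefixing the text with the count
    print(url, alert_type, entry["count"])
# https://hooks.slack.com/services/A llm_too_slow 2
# https://hooks.slack.com/services/B budget_alerts 1
```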
litellm/integrations/SlackAlerting/slack_alerting.py ADDED
@@ -0,0 +1,1825 @@
1
+ #### What this does ####
2
+ # Class for sending Slack Alerts #
3
+ import asyncio
4
+ import datetime
5
+ import os
6
+ import random
7
+ import time
8
+ from datetime import timedelta
9
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
10
+
11
+ from openai import APIError
12
+
13
+ import litellm
14
+ import litellm.litellm_core_utils
15
+ import litellm.litellm_core_utils.litellm_logging
16
+ import litellm.types
17
+ from litellm._logging import verbose_logger, verbose_proxy_logger
18
+ from litellm.caching.caching import DualCache
19
+ from litellm.constants import HOURS_IN_A_DAY
20
+ from litellm.integrations.custom_batch_logger import CustomBatchLogger
21
+ from litellm.litellm_core_utils.duration_parser import duration_in_seconds
22
+ from litellm.litellm_core_utils.exception_mapping_utils import (
23
+ _add_key_name_and_team_to_alert,
24
+ )
25
+ from litellm.llms.custom_httpx.http_handler import (
26
+ get_async_httpx_client,
27
+ httpxSpecialProvider,
28
+ )
29
+ from litellm.proxy._types import AlertType, CallInfo, VirtualKeyEvent, WebhookEvent
30
+ from litellm.types.integrations.slack_alerting import *
31
+
32
+ from ..email_templates.templates import *
33
+ from .batching_handler import send_to_webhook, squash_payloads
34
+ from .utils import _add_langfuse_trace_id_to_alert, process_slack_alerting_variables
35
+
36
+ if TYPE_CHECKING:
37
+ from litellm.router import Router as _Router
38
+
39
+ Router = _Router
40
+ else:
41
+ Router = Any
42
+
43
+
44
+ class SlackAlerting(CustomBatchLogger):
45
+ """
46
+ Class for sending Slack Alerts
47
+ """
48
+
49
+ # Class variables or attributes
50
+ def __init__(
51
+ self,
52
+ internal_usage_cache: Optional[DualCache] = None,
53
+ alerting_threshold: Optional[
54
+ float
55
+ ] = None, # threshold for slow / hanging llm responses (in seconds)
56
+ alerting: Optional[List] = [],
57
+ alert_types: List[AlertType] = DEFAULT_ALERT_TYPES,
58
+ alert_to_webhook_url: Optional[
59
+ Dict[AlertType, Union[List[str], str]]
60
+ ] = None, # if user wants to separate alerts to diff channels
61
+ alerting_args={},
62
+ default_webhook_url: Optional[str] = None,
63
+ **kwargs,
64
+ ):
65
+ if alerting_threshold is None:
66
+ alerting_threshold = 300
67
+ self.alerting_threshold = alerting_threshold
68
+ self.alerting = alerting
69
+ self.alert_types = alert_types
70
+ self.internal_usage_cache = internal_usage_cache or DualCache()
71
+ self.async_http_handler = get_async_httpx_client(
72
+ llm_provider=httpxSpecialProvider.LoggingCallback
73
+ )
74
+ self.alert_to_webhook_url = process_slack_alerting_variables(
75
+ alert_to_webhook_url=alert_to_webhook_url
76
+ )
77
+ self.is_running = False
78
+ self.alerting_args = SlackAlertingArgs(**alerting_args)
79
+ self.default_webhook_url = default_webhook_url
80
+ self.flush_lock = asyncio.Lock()
81
+ super().__init__(**kwargs, flush_lock=self.flush_lock)
82
+
83
+ def update_values(
84
+ self,
85
+ alerting: Optional[List] = None,
86
+ alerting_threshold: Optional[float] = None,
87
+ alert_types: Optional[List[AlertType]] = None,
88
+ alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]] = None,
89
+ alerting_args: Optional[Dict] = None,
90
+ llm_router: Optional[Router] = None,
91
+ ):
92
+ if alerting is not None:
93
+ self.alerting = alerting
94
+ asyncio.create_task(self.periodic_flush())
95
+ if alerting_threshold is not None:
96
+ self.alerting_threshold = alerting_threshold
97
+ if alert_types is not None:
98
+ self.alert_types = alert_types
99
+ if alerting_args is not None:
100
+ self.alerting_args = SlackAlertingArgs(**alerting_args)
101
+ if alert_to_webhook_url is not None:
102
+ # update the dict
103
+ if self.alert_to_webhook_url is None:
104
+ self.alert_to_webhook_url = process_slack_alerting_variables(
105
+ alert_to_webhook_url=alert_to_webhook_url
106
+ )
107
+ else:
108
+ _new_values = (
109
+ process_slack_alerting_variables(
110
+ alert_to_webhook_url=alert_to_webhook_url
111
+ )
112
+ or {}
113
+ )
114
+ self.alert_to_webhook_url.update(_new_values)
115
+ if llm_router is not None:
116
+ self.llm_router = llm_router
117
+
118
+ async def deployment_in_cooldown(self):
119
+ pass
120
+
121
+ async def deployment_removed_from_cooldown(self):
122
+ pass
123
+
124
+ def _all_possible_alert_types(self):
125
+ # used by the UI to show all supported alert types
126
+ # Note: This is not the alerts the user has configured, instead it's all possible alert types a user can select
127
+ # return list of all values AlertType enum
128
+ return list(AlertType)
129
+
130
+ def _response_taking_too_long_callback_helper(
131
+ self,
132
+ kwargs, # kwargs to completion
133
+ start_time,
134
+ end_time, # start/end time
135
+ ):
136
+ try:
137
+ time_difference = end_time - start_time
138
+ # Convert the timedelta to float (in seconds)
139
+ time_difference_float = time_difference.total_seconds()
140
+ litellm_params = kwargs.get("litellm_params", {})
141
+ model = kwargs.get("model", "")
142
+ api_base = litellm.get_api_base(model=model, optional_params=litellm_params)
143
+ messages = kwargs.get("messages", None)
144
+ # if messages does not exist fallback to "input"
145
+ if messages is None:
146
+ messages = kwargs.get("input", None)
147
+
148
+ # only use first 100 chars for alerting
149
+ _messages = str(messages)[:100]
150
+
151
+ return time_difference_float, model, api_base, _messages
152
+ except Exception as e:
153
+ raise e
154
+
155
+ def _get_deployment_latencies_to_alert(self, metadata=None):
156
+ if metadata is None:
157
+ return None
158
+
159
+ if "_latency_per_deployment" in metadata:
160
+ # Translate model_id to -> api_base
161
+ # _latency_per_deployment is a dictionary that looks like this:
162
+ """
163
+ _latency_per_deployment: {
164
+ api_base: 0.01336697916666667
165
+ }
166
+ """
167
+ _message_to_send = ""
168
+ _deployment_latencies = metadata["_latency_per_deployment"]
169
+ if len(_deployment_latencies) == 0:
170
+ return None
171
+ _deployment_latency_map: Optional[dict] = None
172
+ try:
173
+ # try sorting deployments by latency
174
+ _deployment_latencies = sorted(
175
+ _deployment_latencies.items(), key=lambda x: x[1]
176
+ )
177
+ _deployment_latency_map = dict(_deployment_latencies)
178
+ except Exception:
179
+ pass
180
+
181
+ if _deployment_latency_map is None:
182
+ return
183
+
184
+ for api_base, latency in _deployment_latency_map.items():
185
+ _message_to_send += f"\n{api_base}: {round(latency,2)}s"
186
+ _message_to_send = "```" + _message_to_send + "```"
187
+ return _message_to_send
188
+
189
+ async def response_taking_too_long_callback(
190
+ self,
191
+ kwargs, # kwargs to completion
192
+ completion_response, # response from completion
193
+ start_time,
194
+ end_time, # start/end time
195
+ ):
196
+ if self.alerting is None or self.alert_types is None:
197
+ return
198
+
199
+ (
200
+ time_difference_float,
201
+ model,
202
+ api_base,
203
+ messages,
204
+ ) = self._response_taking_too_long_callback_helper(
205
+ kwargs=kwargs,
206
+ start_time=start_time,
207
+ end_time=end_time,
208
+ )
209
+ if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
210
+ messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
211
+ request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
212
+ slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
213
+ alerting_metadata: dict = {}
214
+ if time_difference_float > self.alerting_threshold:
215
+ # add deployment latencies to alert
216
+ if (
217
+ kwargs is not None
218
+ and "litellm_params" in kwargs
219
+ and "metadata" in kwargs["litellm_params"]
220
+ ):
221
+ _metadata: dict = kwargs["litellm_params"]["metadata"]
222
+ request_info = _add_key_name_and_team_to_alert(
223
+ request_info=request_info, metadata=_metadata
224
+ )
225
+
226
+ _deployment_latency_map = self._get_deployment_latencies_to_alert(
227
+ metadata=_metadata
228
+ )
229
+ if _deployment_latency_map is not None:
230
+ request_info += (
231
+ f"\nAvailable Deployment Latencies\n{_deployment_latency_map}"
232
+ )
233
+
234
+ if "alerting_metadata" in _metadata:
235
+ alerting_metadata = _metadata["alerting_metadata"]
236
+ await self.send_alert(
237
+ message=slow_message + request_info,
238
+ level="Low",
239
+ alert_type=AlertType.llm_too_slow,
240
+ alerting_metadata=alerting_metadata,
241
+ )
242
+
243
+ async def async_update_daily_reports(
244
+ self, deployment_metrics: DeploymentMetrics
245
+ ) -> int:
246
+ """
247
+ Store the perf by deployment in cache
248
+ - Number of failed requests per deployment
249
+ - Latency / output tokens per deployment
250
+
251
+ 'deployment_id:daily_metrics:failed_requests'
252
+ 'deployment_id:daily_metrics:latency_per_output_token'
253
+
254
+ Returns
255
+ int - count of metrics set (1 - if just latency, 2 - if failed + latency)
256
+ """
257
+
258
+ return_val = 0
259
+ try:
260
+ ## FAILED REQUESTS ##
261
+ if deployment_metrics.failed_request:
262
+ await self.internal_usage_cache.async_increment_cache(
263
+ key="{}:{}".format(
264
+ deployment_metrics.id,
265
+ SlackAlertingCacheKeys.failed_requests_key.value,
266
+ ),
267
+ value=1,
268
+ parent_otel_span=None, # no attached request, this is a background operation
269
+ )
270
+
271
+ return_val += 1
272
+
273
+ ## LATENCY ##
274
+ if deployment_metrics.latency_per_output_token is not None:
275
+ await self.internal_usage_cache.async_increment_cache(
276
+ key="{}:{}".format(
277
+ deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value
278
+ ),
279
+ value=deployment_metrics.latency_per_output_token,
280
+ parent_otel_span=None, # no attached request, this is a background operation
281
+ )
282
+
283
+ return_val += 1
284
+
285
+ return return_val
286
+ except Exception:
287
+ return 0
288
+
289
+ async def send_daily_reports(self, router) -> bool: # noqa: PLR0915
290
+ """
291
+ Send a daily report on:
292
+ - Top 5 deployments with most failed requests
293
+ - Top 5 slowest deployments (normalized by latency/output tokens)
294
+
295
+ Get the value from redis cache (if available) or in-memory and send it
296
+
297
+ Cleanup:
298
+ - reset values in cache -> prevent memory leak
299
+
300
+ Returns:
301
+ True -> if successfully sent
302
+ False -> if not sent
303
+ """
304
+
305
+ ids = router.get_model_ids()
306
+
307
+ # get keys
308
+ failed_request_keys = [
309
+ "{}:{}".format(id, SlackAlertingCacheKeys.failed_requests_key.value)
310
+ for id in ids
311
+ ]
312
+ latency_keys = [
313
+ "{}:{}".format(id, SlackAlertingCacheKeys.latency_key.value) for id in ids
314
+ ]
315
+
316
+ combined_metrics_keys = failed_request_keys + latency_keys # reduce cache calls
317
+
318
+ combined_metrics_values = await self.internal_usage_cache.async_batch_get_cache(
319
+ keys=combined_metrics_keys
320
+ ) # [1, 2, None, ..]
321
+
322
+ if combined_metrics_values is None:
323
+ return False
324
+
325
+ all_none = True
326
+ for val in combined_metrics_values:
327
+ if val is not None and val > 0:
328
+ all_none = False
329
+ break
330
+
331
+ if all_none:
332
+ return False
333
+
334
+ failed_request_values = combined_metrics_values[
335
+ : len(failed_request_keys)
336
+ ] # # [1, 2, None, ..]
337
+ latency_values = combined_metrics_values[len(failed_request_keys) :]
338
+
339
+ # find top 5 failed
340
+ ## Replace None values with a placeholder value (0 in this case)
341
+ placeholder_value = 0
342
+ replaced_failed_values = [
343
+ value if value is not None else placeholder_value
344
+ for value in failed_request_values
345
+ ]
346
+
347
+ ## Get the indices of top 5 keys with the highest numerical values (ignoring None and 0 values)
348
+ top_5_failed = sorted(
349
+ range(len(replaced_failed_values)),
350
+ key=lambda i: replaced_failed_values[i],
351
+ reverse=True,
352
+ )[:5]
353
+ top_5_failed = [
354
+ index for index in top_5_failed if replaced_failed_values[index] > 0
355
+ ]
356
+
357
+ # find top 5 slowest
358
+ # Replace None values with a placeholder value (0 in this case)
359
+ placeholder_value = 0
360
+ replaced_slowest_values = [
361
+ value if value is not None else placeholder_value
362
+ for value in latency_values
363
+ ]
364
+
365
+ # Get the indices of top 5 values with the highest numerical values (ignoring None and 0 values)
366
+ top_5_slowest = sorted(
367
+ range(len(replaced_slowest_values)),
368
+ key=lambda i: replaced_slowest_values[i],
369
+ reverse=True,
370
+ )[:5]
371
+ top_5_slowest = [
372
+ index for index in top_5_slowest if replaced_slowest_values[index] > 0
373
+ ]
374
+
375
+ # format alert -> return the litellm model name + api base
376
+ message = f"\n\nTime: `{time.time()}`s\nHere are today's key metrics 📈: \n\n"
377
+
378
+ message += "\n\n*❗️ Top Deployments with Most Failed Requests:*\n\n"
379
+ if not top_5_failed:
380
+ message += "\tNone\n"
381
+ for i in range(len(top_5_failed)):
382
+ key = failed_request_keys[top_5_failed[i]].split(":")[0]
383
+ _deployment = router.get_model_info(key)
384
+ if isinstance(_deployment, dict):
385
+ deployment_name = _deployment["litellm_params"].get("model", "")
386
+ else:
387
+ return False
388
+
389
+ api_base = litellm.get_api_base(
390
+ model=deployment_name,
391
+ optional_params=(
392
+ _deployment["litellm_params"] if _deployment is not None else {}
393
+ ),
394
+ )
395
+ if api_base is None:
396
+ api_base = ""
397
+ value = replaced_failed_values[top_5_failed[i]]
398
+ message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n"
399
+
400
+ message += "\n\n*😅 Top Slowest Deployments:*\n\n"
401
+ if not top_5_slowest:
402
+ message += "\tNone\n"
403
+ for i in range(len(top_5_slowest)):
404
+ key = latency_keys[top_5_slowest[i]].split(":")[0]
405
+ _deployment = router.get_model_info(key)
406
+ if _deployment is not None:
407
+ deployment_name = _deployment["litellm_params"].get("model", "")
408
+ else:
409
+ deployment_name = ""
410
+ api_base = litellm.get_api_base(
411
+ model=deployment_name,
412
+ optional_params=(
413
+ _deployment["litellm_params"] if _deployment is not None else {}
414
+ ),
415
+ )
416
+ value = round(replaced_slowest_values[top_5_slowest[i]], 3)
417
+ message += f"\t{i+1}. Deployment: `{deployment_name}`, Latency per output token: `{value}s/token`, API Base: `{api_base}`\n\n"
418
+
419
+ # cache cleanup -> reset values to 0
420
+ latency_cache_keys = [(key, 0) for key in latency_keys]
421
+ failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
422
+ combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
423
+ await self.internal_usage_cache.async_set_cache_pipeline(
424
+ cache_list=combined_metrics_cache_keys
425
+ )
426
+
427
+ message += f"\n\nNext Run is at: `{time.time() + self.alerting_args.daily_report_frequency}`s"
428
+
429
+ # send alert
430
+ await self.send_alert(
431
+ message=message,
432
+ level="Low",
433
+ alert_type=AlertType.daily_reports,
434
+ alerting_metadata={},
435
+ )
436
+
437
+ return True
438
+
439
+ async def response_taking_too_long(
440
+ self,
441
+ start_time: Optional[datetime.datetime] = None,
442
+ end_time: Optional[datetime.datetime] = None,
443
+ type: Literal["hanging_request", "slow_response"] = "hanging_request",
444
+ request_data: Optional[dict] = None,
445
+ ):
446
+ if self.alerting is None or self.alert_types is None:
447
+ return
448
+ model: str = ""
449
+ if request_data is not None:
450
+ model = request_data.get("model", "")
451
+ messages = request_data.get("messages", None)
452
+ if messages is None:
453
+ # if messages does not exist fallback to "input"
454
+ messages = request_data.get("input", None)
455
+
456
+ # try casting messages to str and get the first 100 characters, else mark as None
457
+ try:
458
+ messages = str(messages)
459
+ messages = messages[:100]
460
+ except Exception:
461
+ messages = ""
462
+
463
+ if (
464
+ litellm.turn_off_message_logging
465
+ or litellm.redact_messages_in_exceptions
466
+ ):
467
+ messages = (
468
+ "Message not logged. litellm.redact_messages_in_exceptions=True"
469
+ )
470
+ request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
471
+ else:
472
+ request_info = ""
473
+
474
+ if type == "hanging_request":
475
+ await asyncio.sleep(
476
+ self.alerting_threshold
477
+ ) # wait for the alerting threshold (default 5 minutes) - this may need to differ for streaming, non-streaming, and non-completion (embedding + image) requests
478
+ alerting_metadata: dict = {}
479
+ if await self._request_is_completed(request_data=request_data) is True:
480
+ return
481
+
482
+ if request_data is not None:
483
+ if request_data.get("deployment", None) is not None and isinstance(
484
+ request_data["deployment"], dict
485
+ ):
486
+ _api_base = litellm.get_api_base(
487
+ model=model,
488
+ optional_params=request_data["deployment"].get(
489
+ "litellm_params", {}
490
+ ),
491
+ )
492
+
493
+ if _api_base is None:
494
+ _api_base = ""
495
+
496
+ request_info += f"\nAPI Base: {_api_base}"
497
+ elif request_data.get("metadata", None) is not None and isinstance(
498
+ request_data["metadata"], dict
499
+ ):
500
+ # In hanging requests sometime it has not made it to the point where the deployment is passed to the `request_data``
501
+ # in that case we fallback to the api base set in the request metadata
502
+ _metadata: dict = request_data["metadata"]
503
+ _api_base = _metadata.get("api_base", "")
504
+
505
+ request_info = _add_key_name_and_team_to_alert(
506
+ request_info=request_info, metadata=_metadata
507
+ )
508
+
509
+ if _api_base is None:
510
+ _api_base = ""
511
+
512
+ if "alerting_metadata" in _metadata:
513
+ alerting_metadata = _metadata["alerting_metadata"]
514
+ request_info += f"\nAPI Base: `{_api_base}`"
515
+ # only alert hanging responses if they have not been marked as success
516
+ alerting_message = (
517
+ f"`Requests are hanging - {self.alerting_threshold}s+ request time`"
518
+ )
519
+
520
+ if "langfuse" in litellm.success_callback:
521
+ langfuse_url = await _add_langfuse_trace_id_to_alert(
522
+ request_data=request_data,
523
+ )
524
+
525
+ if langfuse_url is not None:
526
+ request_info += "\n🪢 Langfuse Trace: {}".format(langfuse_url)
527
+
528
+ # add deployment latencies to alert
529
+ _deployment_latency_map = self._get_deployment_latencies_to_alert(
530
+ metadata=request_data.get("metadata", {})
531
+ )
532
+ if _deployment_latency_map is not None:
533
+ request_info += f"\nDeployment Latencies\n{_deployment_latency_map}"
534
+
535
+ await self.send_alert(
536
+ message=alerting_message + request_info,
537
+ level="Medium",
538
+ alert_type=AlertType.llm_requests_hanging,
539
+ alerting_metadata=alerting_metadata,
540
+ )
541
+
542
+ async def failed_tracking_alert(self, error_message: str, failing_model: str):
543
+ """
544
+ Raise alert when tracking failed for specific model
545
+
546
+ Args:
547
+ error_message (str): Error message
548
+ failing_model (str): Model that failed tracking
549
+ """
550
+ if self.alerting is None or self.alert_types is None:
551
+ # do nothing if alerting is not switched on
552
+ return
553
+ if "failed_tracking_spend" not in self.alert_types:
554
+ return
555
+
556
+ _cache: DualCache = self.internal_usage_cache
557
+ message = "Failed Tracking Cost for " + error_message
558
+ _cache_key = "budget_alerts:failed_tracking:{}".format(failing_model)
559
+ result = await _cache.async_get_cache(key=_cache_key)
560
+ if result is None:
561
+ await self.send_alert(
562
+ message=message,
563
+ level="High",
564
+ alert_type=AlertType.failed_tracking_spend,
565
+ alerting_metadata={},
566
+ )
567
+ await _cache.async_set_cache(
568
+ key=_cache_key,
569
+ value="SENT",
570
+ ttl=self.alerting_args.budget_alert_ttl,
571
+ )
572
+
573
+ async def budget_alerts( # noqa: PLR0915
574
+ self,
575
+ type: Literal[
576
+ "token_budget",
577
+ "soft_budget",
578
+ "user_budget",
579
+ "team_budget",
580
+ "proxy_budget",
581
+ "projected_limit_exceeded",
582
+ ],
583
+ user_info: CallInfo,
584
+ ):
585
+ ## PREVENTIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
586
+ # - Alert once within 24hr period
587
+ # - Cache this information
588
+ # - Don't re-alert, if alert already sent
589
+ _cache: DualCache = self.internal_usage_cache
590
+
591
+ if self.alerting is None or self.alert_types is None:
592
+ # do nothing if alerting is not switched on
593
+ return
594
+ if "budget_alerts" not in self.alert_types:
595
+ return
596
+ _id: Optional[str] = "default_id" # used for caching
597
+ user_info_json = user_info.model_dump(exclude_none=True)
598
+ user_info_str = self._get_user_info_str(user_info)
599
+ event: Optional[
600
+ Literal[
601
+ "budget_crossed",
602
+ "threshold_crossed",
603
+ "projected_limit_exceeded",
604
+ "soft_budget_crossed",
605
+ ]
606
+ ] = None
607
+ event_group: Optional[
608
+ Literal["internal_user", "team", "key", "proxy", "customer"]
609
+ ] = None
610
+ event_message: str = ""
611
+ webhook_event: Optional[WebhookEvent] = None
612
+ if type == "proxy_budget":
613
+ event_group = "proxy"
614
+ event_message += "Proxy Budget: "
615
+ elif type == "soft_budget":
616
+ event_group = "proxy"
617
+ event_message += "Soft Budget Crossed: "
618
+ elif type == "user_budget":
619
+ event_group = "internal_user"
620
+ event_message += "User Budget: "
621
+ _id = user_info.user_id or _id
622
+ elif type == "team_budget":
623
+ event_group = "team"
624
+ event_message += "Team Budget: "
625
+ _id = user_info.team_id or _id
626
+ elif type == "token_budget":
627
+ event_group = "key"
628
+ event_message += "Key Budget: "
629
+ _id = user_info.token
630
+ elif type == "projected_limit_exceeded":
631
+ event_group = "key"
632
+ event_message += "Key Budget: Projected Limit Exceeded"
633
+ event = "projected_limit_exceeded"
634
+ _id = user_info.token
635
+
636
+ # percent of max_budget left to spend
637
+ if user_info.max_budget is None and user_info.soft_budget is None:
638
+ return
639
+ percent_left: float = 0
640
+ if user_info.max_budget is not None:
641
+ if user_info.max_budget > 0:
642
+ percent_left = (
643
+ user_info.max_budget - user_info.spend
644
+ ) / user_info.max_budget
645
+
646
+ # check if crossed budget
647
+ if user_info.max_budget is not None:
648
+ if user_info.spend >= user_info.max_budget:
649
+ event = "budget_crossed"
650
+ event_message += (
651
+ f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
652
+ )
653
+ elif percent_left <= SLACK_ALERTING_THRESHOLD_5_PERCENT:
654
+ event = "threshold_crossed"
655
+ event_message += "5% Threshold Crossed "
656
+ elif percent_left <= SLACK_ALERTING_THRESHOLD_15_PERCENT:
657
+ event = "threshold_crossed"
658
+ event_message += "15% Threshold Crossed"
659
+ elif user_info.soft_budget is not None:
660
+ if user_info.spend >= user_info.soft_budget:
661
+ event = "soft_budget_crossed"
662
+ if event is not None and event_group is not None:
663
+ _cache_key = "budget_alerts:{}:{}".format(event, _id)
664
+ result = await _cache.async_get_cache(key=_cache_key)
665
+ if result is None:
666
+ webhook_event = WebhookEvent(
667
+ event=event,
668
+ event_group=event_group,
669
+ event_message=event_message,
670
+ **user_info_json,
671
+ )
672
+ await self.send_alert(
673
+ message=event_message + "\n\n" + user_info_str,
674
+ level="High",
675
+ alert_type=AlertType.budget_alerts,
676
+ user_info=webhook_event,
677
+ alerting_metadata={},
678
+ )
679
+ await _cache.async_set_cache(
680
+ key=_cache_key,
681
+ value="SENT",
682
+ ttl=self.alerting_args.budget_alert_ttl,
683
+ )
684
+
685
+ return
686
+ return
687
+
688
+ def _get_user_info_str(self, user_info: CallInfo) -> str:
689
+ """
690
+ Create a standard message for a budget alert
691
+ """
692
+ _all_fields_as_dict = user_info.model_dump(exclude_none=True)
693
+ _all_fields_as_dict.pop("token")
694
+ msg = ""
695
+ for k, v in _all_fields_as_dict.items():
696
+ msg += f"*{k}:* `{v}`\n"
697
+
698
+ return msg
699
+
700
+ async def customer_spend_alert(
701
+ self,
702
+ token: Optional[str],
703
+ key_alias: Optional[str],
704
+ end_user_id: Optional[str],
705
+ response_cost: Optional[float],
706
+ max_budget: Optional[float],
707
+ ):
708
+ if (
709
+ self.alerting is not None
710
+ and "webhook" in self.alerting
711
+ and end_user_id is not None
712
+ and token is not None
713
+ and response_cost is not None
714
+ ):
715
+ # log customer spend
716
+ event = WebhookEvent(
717
+ spend=response_cost,
718
+ max_budget=max_budget,
719
+ token=token,
720
+ customer_id=end_user_id,
721
+ user_id=None,
722
+ team_id=None,
723
+ user_email=None,
724
+ key_alias=key_alias,
725
+ projected_exceeded_date=None,
726
+ projected_spend=None,
727
+ event="spend_tracked",
728
+ event_group="customer",
729
+ event_message="Customer spend tracked. Customer={}, spend={}".format(
730
+ end_user_id, response_cost
731
+ ),
732
+ )
733
+
734
+ await self.send_webhook_alert(webhook_event=event)
735
+
736
+ def _count_outage_alerts(self, alerts: List[int]) -> str:
737
+ """
738
+ Parameters:
739
+ - alerts: List[int] -> list of error codes (either 408 or 500+)
740
+
741
+ Returns:
742
+ - str -> formatted string. This is an alert message, giving a human-friendly description of the errors.
743
+ """
744
+ error_breakdown = {"Timeout Errors": 0, "API Errors": 0, "Unknown Errors": 0}
745
+ for alert in alerts:
746
+ if alert == 408:
747
+ error_breakdown["Timeout Errors"] += 1
748
+ elif alert >= 500:
749
+ error_breakdown["API Errors"] += 1
750
+ else:
751
+ error_breakdown["Unknown Errors"] += 1
752
+
753
+ error_msg = ""
754
+ for key, value in error_breakdown.items():
755
+ if value > 0:
756
+ error_msg += "\n{}: {}\n".format(key, value)
757
+
758
+ return error_msg
759
+
760
+ def _outage_alert_msg_factory(
761
+ self,
762
+ alert_type: Literal["Major", "Minor"],
763
+ key: Literal["Model", "Region"],
764
+ key_val: str,
765
+ provider: str,
766
+ api_base: Optional[str],
767
+ outage_value: BaseOutageModel,
768
+ ) -> str:
769
+ """Format an alert message for slack"""
770
+ headers = {f"{key} Name": key_val, "Provider": provider}
771
+ if api_base is not None:
772
+ headers["API Base"] = api_base # type: ignore
773
+
774
+ headers_str = "\n"
775
+ for k, v in headers.items():
776
+ headers_str += f"*{k}:* `{v}`\n"
777
+ return f"""\n\n
778
+ *⚠️ {alert_type} Service Outage*
779
+
780
+ {headers_str}
781
+
782
+ *Errors:*
783
+ {self._count_outage_alerts(alerts=outage_value["alerts"])}
784
+
785
+ *Last Check:* `{round(time.time() - outage_value["last_updated_at"], 4)}s ago`\n\n
786
+ """
787
+
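For reference, a minimal standalone sketch of the error-code bucketing that `_count_outage_alerts` applies when building the message above; the function name below is illustrative and not part of this diff:

```python
# Illustrative restatement of the 408 / >=500 bucketing used in the outage alert message.
from typing import List


def count_outage_alerts(alerts: List[int]) -> str:
    breakdown = {"Timeout Errors": 0, "API Errors": 0, "Unknown Errors": 0}
    for code in alerts:
        if code == 408:
            breakdown["Timeout Errors"] += 1
        elif code >= 500:
            breakdown["API Errors"] += 1
        else:
            breakdown["Unknown Errors"] += 1
    return "".join(f"\n{k}: {v}\n" for k, v in breakdown.items() if v > 0)


print(count_outage_alerts([408, 500, 502]))  # -> Timeout Errors: 1 / API Errors: 2
```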
788
+ async def region_outage_alerts(
789
+ self,
790
+ exception: APIError,
791
+ deployment_id: str,
792
+ ) -> None:
793
+ """
794
+ Send slack alert if specific provider region is having an outage.
795
+
796
+ Track for 408 (Timeout) and >=500 Error codes
797
+ """
798
+ ## CREATE (PROVIDER+REGION) ID ##
799
+ if self.llm_router is None:
800
+ return
801
+
802
+ deployment = self.llm_router.get_deployment(model_id=deployment_id)
803
+
804
+ if deployment is None:
805
+ return
806
+
807
+ model = deployment.litellm_params.model
808
+ ### GET PROVIDER ###
809
+ provider = deployment.litellm_params.custom_llm_provider
810
+ if provider is None:
811
+ model, provider, _, _ = litellm.get_llm_provider(model=model)
812
+
813
+ ### GET REGION ###
814
+ region_name = deployment.litellm_params.region_name
815
+ if region_name is None:
816
+ region_name = litellm.utils._get_model_region(
817
+ custom_llm_provider=provider, litellm_params=deployment.litellm_params
818
+ )
819
+
820
+ if region_name is None:
821
+ return
822
+
823
+ ### UNIQUE CACHE KEY ###
824
+ cache_key = provider + region_name
825
+
826
+ outage_value: Optional[
827
+ ProviderRegionOutageModel
828
+ ] = await self.internal_usage_cache.async_get_cache(key=cache_key)
829
+
830
+ if (
831
+ getattr(exception, "status_code", None) is None
832
+ or (
833
+ exception.status_code != 408 # type: ignore
834
+ and exception.status_code < 500 # type: ignore
835
+ )
836
+ or self.llm_router is None
837
+ ):
838
+ return
839
+
840
+ if outage_value is None:
841
+ _deployment_set = set()
842
+ _deployment_set.add(deployment_id)
843
+ outage_value = ProviderRegionOutageModel(
844
+ provider_region_id=cache_key,
845
+ alerts=[exception.status_code], # type: ignore
846
+ minor_alert_sent=False,
847
+ major_alert_sent=False,
848
+ last_updated_at=time.time(),
849
+ deployment_ids=_deployment_set,
850
+ )
851
+
852
+ ## add to cache ##
853
+ await self.internal_usage_cache.async_set_cache(
854
+ key=cache_key,
855
+ value=outage_value,
856
+ ttl=self.alerting_args.region_outage_alert_ttl,
857
+ )
858
+ return
859
+
860
+ if len(outage_value["alerts"]) < self.alerting_args.max_outage_alert_list_size:
861
+ outage_value["alerts"].append(exception.status_code) # type: ignore
862
+ else: # prevent memory leaks
863
+ pass
864
+ _deployment_set = outage_value["deployment_ids"]
865
+ _deployment_set.add(deployment_id)
866
+ outage_value["deployment_ids"] = _deployment_set
867
+ outage_value["last_updated_at"] = time.time()
868
+
869
+ ## MINOR OUTAGE ALERT SENT ##
870
+ if (
871
+ outage_value["minor_alert_sent"] is False
872
+ and len(outage_value["alerts"])
873
+ >= self.alerting_args.minor_outage_alert_threshold
874
+ and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
875
+ ):
876
+ msg = self._outage_alert_msg_factory(
877
+ alert_type="Minor",
878
+ key="Region",
879
+ key_val=region_name,
880
+ api_base=None,
881
+ outage_value=outage_value,
882
+ provider=provider,
883
+ )
884
+ # send minor alert
885
+ await self.send_alert(
886
+ message=msg,
887
+ level="Medium",
888
+ alert_type=AlertType.outage_alerts,
889
+ alerting_metadata={},
890
+ )
891
+ # set to true
892
+ outage_value["minor_alert_sent"] = True
893
+
894
+ ## MAJOR OUTAGE ALERT SENT ##
895
+ elif (
896
+ outage_value["major_alert_sent"] is False
897
+ and len(outage_value["alerts"])
898
+ >= self.alerting_args.major_outage_alert_threshold
899
+ and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
900
+ ):
901
+ msg = self._outage_alert_msg_factory(
902
+ alert_type="Major",
903
+ key="Region",
904
+ key_val=region_name,
905
+ api_base=None,
906
+ outage_value=outage_value,
907
+ provider=provider,
908
+ )
909
+
910
+ # send major alert
911
+ await self.send_alert(
912
+ message=msg,
913
+ level="High",
914
+ alert_type=AlertType.outage_alerts,
915
+ alerting_metadata={},
916
+ )
917
+ # set to true
918
+ outage_value["major_alert_sent"] = True
919
+
920
+ ## update cache ##
921
+ await self.internal_usage_cache.async_set_cache(
922
+ key=cache_key, value=outage_value
923
+ )
924
+
925
+ async def outage_alerts(
926
+ self,
927
+ exception: APIError,
928
+ deployment_id: str,
929
+ ) -> None:
930
+ """
931
+ Send slack alert if model is badly configured / having an outage (408 or >=500 error codes).
932
+
933
+ key = model_id
934
+
935
+ value = {
936
+ - model_id
937
+ - threshold
938
+ - alerts []
939
+ }
940
+
941
+ ttl = 1hr
942
+ max_alerts_size = 10
943
+ """
944
+ try:
945
+ outage_value: Optional[OutageModel] = await self.internal_usage_cache.async_get_cache(key=deployment_id) # type: ignore
946
+ if (
947
+ getattr(exception, "status_code", None) is None
948
+ or (
949
+ exception.status_code != 408 # type: ignore
950
+ and exception.status_code < 500 # type: ignore
951
+ )
952
+ or self.llm_router is None
953
+ ):
954
+ return
955
+
956
+ ### EXTRACT MODEL DETAILS ###
957
+ deployment = self.llm_router.get_deployment(model_id=deployment_id)
958
+ if deployment is None:
959
+ return
960
+
961
+ model = deployment.litellm_params.model
962
+ provider = deployment.litellm_params.custom_llm_provider
963
+ if provider is None:
964
+ try:
965
+ model, provider, _, _ = litellm.get_llm_provider(model=model)
966
+ except Exception:
967
+ provider = ""
968
+ api_base = litellm.get_api_base(
969
+ model=model, optional_params=deployment.litellm_params
970
+ )
971
+
972
+ if outage_value is None:
973
+ outage_value = OutageModel(
974
+ model_id=deployment_id,
975
+ alerts=[exception.status_code], # type: ignore
976
+ minor_alert_sent=False,
977
+ major_alert_sent=False,
978
+ last_updated_at=time.time(),
979
+ )
980
+
981
+ ## add to cache ##
982
+ await self.internal_usage_cache.async_set_cache(
983
+ key=deployment_id,
984
+ value=outage_value,
985
+ ttl=self.alerting_args.outage_alert_ttl,
986
+ )
987
+ return
988
+
989
+ if (
990
+ len(outage_value["alerts"])
991
+ < self.alerting_args.max_outage_alert_list_size
992
+ ):
993
+ outage_value["alerts"].append(exception.status_code) # type: ignore
994
+ else: # prevent memory leaks
995
+ pass
996
+
997
+ outage_value["last_updated_at"] = time.time()
998
+
999
+ ## MINOR OUTAGE ALERT SENT ##
1000
+ if (
1001
+ outage_value["minor_alert_sent"] is False
1002
+ and len(outage_value["alerts"])
1003
+ >= self.alerting_args.minor_outage_alert_threshold
1004
+ ):
1005
+ msg = self._outage_alert_msg_factory(
1006
+ alert_type="Minor",
1007
+ key="Model",
1008
+ key_val=model,
1009
+ api_base=api_base,
1010
+ outage_value=outage_value,
1011
+ provider=provider,
1012
+ )
1013
+ # send minor alert
1014
+ await self.send_alert(
1015
+ message=msg,
1016
+ level="Medium",
1017
+ alert_type=AlertType.outage_alerts,
1018
+ alerting_metadata={},
1019
+ )
1020
+ # set to true
1021
+ outage_value["minor_alert_sent"] = True
1022
+ elif (
1023
+ outage_value["major_alert_sent"] is False
1024
+ and len(outage_value["alerts"])
1025
+ >= self.alerting_args.major_outage_alert_threshold
1026
+ ):
1027
+ msg = self._outage_alert_msg_factory(
1028
+ alert_type="Major",
1029
+ key="Model",
1030
+ key_val=model,
1031
+ api_base=api_base,
1032
+ outage_value=outage_value,
1033
+ provider=provider,
1034
+ )
1035
+ # send major alert
1036
+ await self.send_alert(
1037
+ message=msg,
1038
+ level="High",
1039
+ alert_type=AlertType.outage_alerts,
1040
+ alerting_metadata={},
1041
+ )
1042
+ # set to true
1043
+ outage_value["major_alert_sent"] = True
1044
+
1045
+ ## update cache ##
1046
+ await self.internal_usage_cache.async_set_cache(
1047
+ key=deployment_id, value=outage_value
1048
+ )
1049
+ except Exception:
1050
+ pass
1051
+
1052
+ async def model_added_alert(
1053
+ self, model_name: str, litellm_model_name: str, passed_model_info: Any
1054
+ ):
1055
+ base_model_from_user = getattr(passed_model_info, "base_model", None)
1056
+ model_info = {}
1057
+ base_model = ""
1058
+ if base_model_from_user is not None:
1059
+ model_info = litellm.model_cost.get(base_model_from_user, {})
1060
+ base_model = f"Base Model: `{base_model_from_user}`\n"
1061
+ else:
1062
+ model_info = litellm.model_cost.get(litellm_model_name, {})
1063
+ model_info_str = ""
1064
+ for k, v in model_info.items():
1065
+ if k == "input_cost_per_token" or k == "output_cost_per_token":
1066
+ # format as fixed-point so the value is not rendered in scientific notation (e.g. 1.63e-06)
1067
+ v = "{:.8f}".format(v)
1068
+
1069
+ model_info_str += f"{k}: {v}\n"
1070
+
1071
+ message = f"""
1072
+ *🚅 New Model Added*
1073
+ Model Name: `{model_name}`
1074
+ {base_model}
1075
+
1076
+ Usage OpenAI Python SDK:
1077
+ ```
1078
+ import openai
1079
+ client = openai.OpenAI(
1080
+ api_key="your_api_key",
1081
+ base_url="{os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")}"
1082
+ )
1083
+
1084
+ response = client.chat.completions.create(
1085
+ model="{model_name}", # model to send to the proxy
1086
+ messages = [
1087
+ {{
1088
+ "role": "user",
1089
+ "content": "this is a test request, write a short poem"
1090
+ }}
1091
+ ]
1092
+ )
1093
+ ```
1094
+
1095
+ Model Info:
1096
+ ```
1097
+ {model_info_str}
1098
+ ```
1099
+ """
1100
+
1101
+ alert_val = self.send_alert(
1102
+ message=message,
1103
+ level="Low",
1104
+ alert_type=AlertType.new_model_added,
1105
+ alerting_metadata={},
1106
+ )
1107
+
1108
+ if alert_val is not None and asyncio.iscoroutine(alert_val):
1109
+ await alert_val
1110
+
1111
+ async def model_removed_alert(self, model_name: str):
1112
+ pass
1113
+
1114
+ async def send_webhook_alert(self, webhook_event: WebhookEvent) -> bool:
1115
+ """
1116
+ Sends structured alert to webhook, if set.
1117
+
1118
+ Currently only implemented for budget alerts
1119
+
1120
+ Returns -> True if sent, False if not.
1121
+
1122
+ Raises Exception
1123
+ - if WEBHOOK_URL is not set
1124
+ """
1125
+
1126
+ webhook_url = os.getenv("WEBHOOK_URL", None)
1127
+ if webhook_url is None:
1128
+ raise Exception("Missing webhook_url from environment")
1129
+
1130
+ payload = webhook_event.model_dump_json()
1131
+ headers = {"Content-type": "application/json"}
1132
+
1133
+ response = await self.async_http_handler.post(
1134
+ url=webhook_url,
1135
+ headers=headers,
1136
+ data=payload,
1137
+ )
1138
+ if response.status_code == 200:
1139
+ return True
1140
+ else:
1141
+ print("Error sending webhook alert. Error=", response.text) # noqa
1142
+
1143
+ return False
1144
+
1145
+ async def _check_if_using_premium_email_feature(
1146
+ self,
1147
+ premium_user: bool,
1148
+ email_logo_url: Optional[str] = None,
1149
+ email_support_contact: Optional[str] = None,
1150
+ ):
1151
+ from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
1152
+
1153
+ if premium_user is not True:
1154
+ if email_logo_url is not None or email_support_contact is not None:
1155
+ raise ValueError(
1156
+ f"Trying to Customize Email Alerting\n {CommonProxyErrors.not_premium_user.value}"
1157
+ )
1158
+ return
1159
+
1160
+ async def send_key_created_or_user_invited_email(
1161
+ self, webhook_event: WebhookEvent
1162
+ ) -> bool:
1163
+ try:
1164
+ from litellm.proxy.utils import send_email
1165
+
1166
+ if self.alerting is None or "email" not in self.alerting:
1167
+ # do nothing if user does not want email alerts
1168
+ verbose_proxy_logger.error(
1169
+ "Error sending email alert - 'email' not in self.alerting %s",
1170
+ self.alerting,
1171
+ )
1172
+ return False
1173
+ from litellm.proxy.proxy_server import premium_user, prisma_client
1174
+
1175
+ email_logo_url = os.getenv(
1176
+ "SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None)
1177
+ )
1178
+ email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
1179
+ await self._check_if_using_premium_email_feature(
1180
+ premium_user, email_logo_url, email_support_contact
1181
+ )
1182
+ if email_logo_url is None:
1183
+ email_logo_url = LITELLM_LOGO_URL
1184
+ if email_support_contact is None:
1185
+ email_support_contact = LITELLM_SUPPORT_CONTACT
1186
+
1187
+ event_name = webhook_event.event_message
1188
+ recipient_email = webhook_event.user_email
1189
+ recipient_user_id = webhook_event.user_id
1190
+ if (
1191
+ recipient_email is None
1192
+ and recipient_user_id is not None
1193
+ and prisma_client is not None
1194
+ ):
1195
+ user_row = await prisma_client.db.litellm_usertable.find_unique(
1196
+ where={"user_id": recipient_user_id}
1197
+ )
1198
+
1199
+ if user_row is not None:
1200
+ recipient_email = user_row.user_email
1201
+
1202
+ key_token = webhook_event.token
1203
+ key_budget = webhook_event.max_budget
1204
+ base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")
1205
+
1206
+ email_html_content = "Alert from LiteLLM Server"
1207
+ if recipient_email is None:
1208
+ verbose_proxy_logger.error(
1209
+ "Trying to send email alert to no recipient",
1210
+ extra=webhook_event.dict(),
1211
+ )
1212
+
1213
+ if webhook_event.event == "key_created":
1214
+ email_html_content = KEY_CREATED_EMAIL_TEMPLATE.format(
1215
+ email_logo_url=email_logo_url,
1216
+ recipient_email=recipient_email,
1217
+ key_budget=key_budget,
1218
+ key_token=key_token,
1219
+ base_url=base_url,
1220
+ email_support_contact=email_support_contact,
1221
+ )
1222
+ elif webhook_event.event == "internal_user_created":
1223
+ # GET TEAM NAME
1224
+ team_id = webhook_event.team_id
1225
+ team_name = "Default Team"
1226
+ if team_id is not None and prisma_client is not None:
1227
+ team_row = await prisma_client.db.litellm_teamtable.find_unique(
1228
+ where={"team_id": team_id}
1229
+ )
1230
+ if team_row is not None:
1231
+ team_name = team_row.team_alias or "-"
1232
+ email_html_content = USER_INVITED_EMAIL_TEMPLATE.format(
1233
+ email_logo_url=email_logo_url,
1234
+ recipient_email=recipient_email,
1235
+ team_name=team_name,
1236
+ base_url=base_url,
1237
+ email_support_contact=email_support_contact,
1238
+ )
1239
+ else:
1240
+ verbose_proxy_logger.error(
1241
+ "Trying to send email alert on unknown webhook event",
1242
+ extra=webhook_event.model_dump(),
1243
+ )
1244
+
1245
+ webhook_event.model_dump_json()
1246
+ email_event = {
1247
+ "to": recipient_email,
1248
+ "subject": f"LiteLLM: {event_name}",
1249
+ "html": email_html_content,
1250
+ }
1251
+
1252
+ await send_email(
1253
+ receiver_email=email_event["to"],
1254
+ subject=email_event["subject"],
1255
+ html=email_event["html"],
1256
+ )
1257
+
1258
+ return True
1259
+
1260
+ except Exception as e:
1261
+ verbose_proxy_logger.error("Error sending email alert %s", str(e))
1262
+ return False
1263
+
1264
+ async def send_email_alert_using_smtp(
1265
+ self, webhook_event: WebhookEvent, alert_type: str
1266
+ ) -> bool:
1267
+ """
1268
+ Sends structured Email alert to an SMTP server
1269
+
1270
+ Currently only implemented for budget alerts
1271
+
1272
+ Returns -> True if sent, False if not.
1273
+ """
1274
+ from litellm.proxy.proxy_server import premium_user
1275
+ from litellm.proxy.utils import send_email
1276
+
1277
+ email_logo_url = os.getenv(
1278
+ "SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None)
1279
+ )
1280
+ email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
1281
+ await self._check_if_using_premium_email_feature(
1282
+ premium_user, email_logo_url, email_support_contact
1283
+ )
1284
+
1285
+ if email_logo_url is None:
1286
+ email_logo_url = LITELLM_LOGO_URL
1287
+ if email_support_contact is None:
1288
+ email_support_contact = LITELLM_SUPPORT_CONTACT
1289
+
1290
+ event_name = webhook_event.event_message
1291
+ recipient_email = webhook_event.user_email
1292
+ user_name = webhook_event.user_id
1293
+ max_budget = webhook_event.max_budget
1294
+ email_html_content = "Alert from LiteLLM Server"
1295
+ if recipient_email is None:
1296
+ verbose_proxy_logger.error(
1297
+ "Trying to send email alert to no recipient", extra=webhook_event.dict()
1298
+ )
1299
+
1300
+ if webhook_event.event == "budget_crossed":
1301
+ email_html_content = f"""
1302
+ <img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
1303
+
1304
+ <p> Hi {user_name}, <br/>
1305
+
1306
+ Your LLM API usage this month has reached your account's <b> monthly budget of ${max_budget} </b> <br /> <br />
1307
+
1308
+ API requests will be rejected until either (a) you increase your monthly budget or (b) your monthly usage resets at the beginning of the next calendar month. <br /> <br />
1309
+
1310
+ If you have any questions, please send an email to {email_support_contact} <br /> <br />
1311
+
1312
+ Best, <br />
1313
+ The LiteLLM team <br />
1314
+ """
1315
+
1316
+ webhook_event.model_dump_json()
1317
+ email_event = {
1318
+ "to": recipient_email,
1319
+ "subject": f"LiteLLM: {event_name}",
1320
+ "html": email_html_content,
1321
+ }
1322
+
1323
+ await send_email(
1324
+ receiver_email=email_event["to"],
1325
+ subject=email_event["subject"],
1326
+ html=email_event["html"],
1327
+ )
1328
+ if webhook_event.event_group == "team":
1329
+ from litellm.integrations.email_alerting import send_team_budget_alert
1330
+
1331
+ await send_team_budget_alert(webhook_event=webhook_event)
1332
+
1333
+ return False
1334
+
1335
+ async def send_alert(
1336
+ self,
1337
+ message: str,
1338
+ level: Literal["Low", "Medium", "High"],
1339
+ alert_type: AlertType,
1340
+ alerting_metadata: dict,
1341
+ user_info: Optional[WebhookEvent] = None,
1342
+ **kwargs,
1343
+ ):
1344
+ """
1345
+ Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
1346
+
1347
+ - Responses taking too long
1348
+ - Requests are hanging
1349
+ - Calls are failing
1350
+ - DB Read/Writes are failing
1351
+ - Proxy Close to max budget
1352
+ - Key Close to max budget
1353
+
1354
+ Parameters:
1355
+ level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
1356
+ message: str - what is the alert about
1357
+ """
1358
+ if self.alerting is None:
1359
+ return
1360
+
1361
+ if (
1362
+ "webhook" in self.alerting
1363
+ and alert_type == "budget_alerts"
1364
+ and user_info is not None
1365
+ ):
1366
+ await self.send_webhook_alert(webhook_event=user_info)
1367
+
1368
+ if (
1369
+ "email" in self.alerting
1370
+ and alert_type == "budget_alerts"
1371
+ and user_info is not None
1372
+ ):
1373
+ # only send budget alerts over Email
1374
+ await self.send_email_alert_using_smtp(
1375
+ webhook_event=user_info, alert_type=alert_type
1376
+ )
1377
+
1378
+ if "slack" not in self.alerting:
1379
+ return
1380
+ if alert_type not in self.alert_types:
1381
+ return
1382
+
1383
+ from datetime import datetime
1384
+
1385
+ # Get the current timestamp
1386
+ current_time = datetime.now().strftime("%H:%M:%S")
1387
+ _proxy_base_url = os.getenv("PROXY_BASE_URL", None)
1388
+ if alert_type == "daily_reports" or alert_type == "new_model_added":
1389
+ formatted_message = message
1390
+ else:
1391
+ formatted_message = (
1392
+ f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
1393
+ )
1394
+
1395
+ if kwargs:
1396
+ for key, value in kwargs.items():
1397
+ formatted_message += f"\n\n{key}: `{value}`\n\n"
1398
+ if alerting_metadata:
1399
+ for key, value in alerting_metadata.items():
1400
+ formatted_message += f"\n\n*Alerting Metadata*: \n{key}: `{value}`\n\n"
1401
+ if _proxy_base_url is not None:
1402
+ formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
1403
+
1404
+ # check if we find the slack webhook url in self.alert_to_webhook_url
1405
+ if (
1406
+ self.alert_to_webhook_url is not None
1407
+ and alert_type in self.alert_to_webhook_url
1408
+ ):
1409
+ slack_webhook_url: Optional[
1410
+ Union[str, List[str]]
1411
+ ] = self.alert_to_webhook_url[alert_type]
1412
+ elif self.default_webhook_url is not None:
1413
+ slack_webhook_url = self.default_webhook_url
1414
+ else:
1415
+ slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
1416
+
1417
+ if slack_webhook_url is None:
1418
+ raise ValueError("Missing SLACK_WEBHOOK_URL from environment")
1419
+ payload = {"text": formatted_message}
1420
+ headers = {"Content-type": "application/json"}
1421
+
1422
+ if isinstance(slack_webhook_url, list):
1423
+ for url in slack_webhook_url:
1424
+ self.log_queue.append(
1425
+ {
1426
+ "url": url,
1427
+ "headers": headers,
1428
+ "payload": payload,
1429
+ "alert_type": alert_type,
1430
+ }
1431
+ )
1432
+ else:
1433
+ self.log_queue.append(
1434
+ {
1435
+ "url": slack_webhook_url,
1436
+ "headers": headers,
1437
+ "payload": payload,
1438
+ "alert_type": alert_type,
1439
+ }
1440
+ )
1441
+
1442
+ if len(self.log_queue) >= self.batch_size:
1443
+ await self.flush_queue()
1444
+
1445
+ async def async_send_batch(self):
1446
+ if not self.log_queue:
1447
+ return
1448
+
1449
+ squashed_queue = squash_payloads(self.log_queue)
1450
+ tasks = [
1451
+ send_to_webhook(
1452
+ slackAlertingInstance=self, item=item["item"], count=item["count"]
1453
+ )
1454
+ for item in squashed_queue.values()
1455
+ ]
1456
+ await asyncio.gather(*tasks)
1457
+ self.log_queue.clear()
1458
+
1459
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
1460
+ """Log deployment latency"""
1461
+ try:
1462
+ if "daily_reports" in self.alert_types:
1463
+ litellm_params = kwargs.get("litellm_params", {}) or {}
1464
+ model_info = litellm_params.get("model_info", {}) or {}
1465
+ model_id = model_info.get("id", "") or ""
1466
+ response_s: timedelta = end_time - start_time
1467
+
1468
+ final_value = response_s
1469
+
1470
+ if isinstance(response_obj, litellm.ModelResponse) and (
1471
+ hasattr(response_obj, "usage")
1472
+ and response_obj.usage is not None # type: ignore
1473
+ and hasattr(response_obj.usage, "completion_tokens") # type: ignore
1474
+ ):
1475
+ completion_tokens = response_obj.usage.completion_tokens # type: ignore
1476
+ if completion_tokens is not None and completion_tokens > 0:
1477
+ final_value = float(
1478
+ response_s.total_seconds() / completion_tokens
1479
+ )
1480
+ if isinstance(final_value, timedelta):
1481
+ final_value = final_value.total_seconds()
1482
+
1483
+ await self.async_update_daily_reports(
1484
+ DeploymentMetrics(
1485
+ id=model_id,
1486
+ failed_request=False,
1487
+ latency_per_output_token=final_value,
1488
+ updated_at=litellm.utils.get_utc_datetime(),
1489
+ )
1490
+ )
1491
+ except Exception as e:
1492
+ verbose_proxy_logger.error(
1493
+ f"[Non-Blocking Error] Slack Alerting: Got error in logging LLM deployment latency: {str(e)}"
1494
+ )
1495
+ pass
1496
+
1497
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
1498
+ """Log failure + deployment latency"""
1499
+ _litellm_params = kwargs.get("litellm_params", {})
1500
+ _model_info = _litellm_params.get("model_info", {}) or {}
1501
+ model_id = _model_info.get("id", "")
1502
+ try:
1503
+ if "daily_reports" in self.alert_types:
1504
+ try:
1505
+ await self.async_update_daily_reports(
1506
+ DeploymentMetrics(
1507
+ id=model_id,
1508
+ failed_request=True,
1509
+ latency_per_output_token=None,
1510
+ updated_at=litellm.utils.get_utc_datetime(),
1511
+ )
1512
+ )
1513
+ except Exception as e:
1514
+ verbose_logger.debug(f"Exception raises -{str(e)}")
1515
+
1516
+ if isinstance(kwargs.get("exception", ""), APIError):
1517
+ if "outage_alerts" in self.alert_types:
1518
+ await self.outage_alerts(
1519
+ exception=kwargs["exception"],
1520
+ deployment_id=model_id,
1521
+ )
1522
+
1523
+ if "region_outage_alerts" in self.alert_types:
1524
+ await self.region_outage_alerts(
1525
+ exception=kwargs["exception"], deployment_id=model_id
1526
+ )
1527
+ except Exception:
1528
+ pass
1529
+
1530
+ async def _run_scheduler_helper(self, llm_router) -> bool:
1531
+ """
1532
+ Returns:
1533
+ - True -> report sent
1534
+ - False -> report not sent
1535
+ """
1536
+ report_sent_bool = False
1537
+
1538
+ report_sent = await self.internal_usage_cache.async_get_cache(
1539
+ key=SlackAlertingCacheKeys.report_sent_key.value,
1540
+ parent_otel_span=None,
1541
+ ) # None | float
1542
+
1543
+ current_time = time.time()
1544
+
1545
+ if report_sent is None:
1546
+ await self.internal_usage_cache.async_set_cache(
1547
+ key=SlackAlertingCacheKeys.report_sent_key.value,
1548
+ value=current_time,
1549
+ )
1550
+ elif isinstance(report_sent, float):
1551
+ # Check if current time - interval >= time last sent
1552
+ interval_seconds = self.alerting_args.daily_report_frequency
1553
+
1554
+ if current_time - report_sent >= interval_seconds:
1555
+ # the reporting interval has elapsed - send the daily report
1556
+ await self.send_daily_reports(router=llm_router)
1557
+ # update the report_sent timestamp after sending the report
1558
+ await self.internal_usage_cache.async_set_cache(
1559
+ key=SlackAlertingCacheKeys.report_sent_key.value,
1560
+ value=current_time,
1561
+ )
1562
+ report_sent_bool = True
1563
+
1564
+ return report_sent_bool
1565
+
1566
+ async def _run_scheduled_daily_report(self, llm_router: Optional[Any] = None):
1567
+ """
1568
+ If 'daily_reports' enabled
1569
+
1570
+ Poll the internal usage cache every `report_check_interval` seconds to check if we should send the report
1571
+
1572
+ If yes -> call send_daily_reports()
1573
+ """
1574
+ if llm_router is None or self.alert_types is None:
1575
+ return
1576
+
1577
+ if "daily_reports" in self.alert_types:
1578
+ while True:
1579
+ await self._run_scheduler_helper(llm_router=llm_router)
1580
+ interval = random.randint(
1581
+ self.alerting_args.report_check_interval - 3,
1582
+ self.alerting_args.report_check_interval + 3,
1583
+ ) # add jitter to prevent collisions
1584
+ await asyncio.sleep(interval)
1585
+ return
1586
+
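As a quick illustration of the scheduling pattern above (poll on a jittered interval, send at most once per reporting window), here is a minimal sketch that uses a plain dict in place of the internal usage cache; all names and defaults are illustrative:

```python
# Minimal sketch: poll with jitter, send the daily report at most once per window.
import asyncio
import random
import time

_cache: dict = {}  # stand-in for the internal usage cache


async def maybe_send_report(frequency_s: float) -> bool:
    last_sent = _cache.get("report_sent")
    now = time.time()
    if last_sent is None:
        _cache["report_sent"] = now  # first run only records the timestamp
        return False
    if now - last_sent >= frequency_s:
        print("sending daily report ...")  # send_daily_reports(...) would run here
        _cache["report_sent"] = now
        return True
    return False


async def scheduler(check_interval_s: int = 300, frequency_s: float = 12 * 60 * 60):
    while True:
        await maybe_send_report(frequency_s)
        # jitter the poll interval so multiple proxy instances don't fire together
        await asyncio.sleep(random.randint(check_interval_s - 3, check_interval_s + 3))
```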
1587
+ async def send_weekly_spend_report(
1588
+ self,
1589
+ time_range: str = "7d",
1590
+ ):
1591
+ """
1592
+ Send a spend report for a configurable time range.
1593
+
1594
+ Args:
1595
+ time_range: A string specifying the time range for the report, e.g., "1d", "7d", "30d"
1596
+ """
1597
+ if self.alerting is None or "spend_reports" not in self.alert_types:
1598
+ return
1599
+
1600
+ try:
1601
+ from litellm.proxy.spend_tracking.spend_management_endpoints import (
1602
+ _get_spend_report_for_time_range,
1603
+ )
1604
+
1605
+ # Parse the time range
1606
+ if time_range[-1].lower() != "d":
1607
+ raise ValueError("Time range must be specified in days, e.g., '7d'")
1608
+ days = int(time_range[:-1])
1609
+
1610
+ todays_date = datetime.datetime.now().date()
1611
+ start_date = todays_date - datetime.timedelta(days=days)
1612
+
1613
+ _event_cache_key = f"weekly_spend_report_sent_{start_date.strftime('%Y-%m-%d')}_{todays_date.strftime('%Y-%m-%d')}"
1614
+ if await self.internal_usage_cache.async_get_cache(key=_event_cache_key):
1615
+ return
1616
+
1617
+ _resp = await _get_spend_report_for_time_range(
1618
+ start_date=start_date.strftime("%Y-%m-%d"),
1619
+ end_date=todays_date.strftime("%Y-%m-%d"),
1620
+ )
1621
+ if _resp is None or _resp == ([], []):
1622
+ return
1623
+
1624
+ spend_per_team, spend_per_tag = _resp
1625
+
1626
+ _spend_message = f"*💸 Spend Report for `{start_date.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` ({days} days)*\n"
1627
+
1628
+ if spend_per_team is not None:
1629
+ _spend_message += "\n*Team Spend Report:*\n"
1630
+ for spend in spend_per_team:
1631
+ _team_spend = round(float(spend["total_spend"]), 4)
1632
+ _spend_message += (
1633
+ f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
1634
+ )
1635
+
1636
+ if spend_per_tag is not None:
1637
+ _spend_message += "\n*Tag Spend Report:*\n"
1638
+ for spend in spend_per_tag:
1639
+ _tag_spend = round(float(spend["total_spend"]), 4)
1640
+ _spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
1641
+
1642
+ await self.send_alert(
1643
+ message=_spend_message,
1644
+ level="Low",
1645
+ alert_type=AlertType.spend_reports,
1646
+ alerting_metadata={},
1647
+ )
1648
+
1649
+ await self.internal_usage_cache.async_set_cache(
1650
+ key=_event_cache_key,
1651
+ value="SENT",
1652
+ ttl=duration_in_seconds(time_range),
1653
+ )
1654
+
1655
+ except ValueError as ve:
1656
+ verbose_proxy_logger.error(f"Invalid time range format: {ve}")
1657
+ except Exception as e:
1658
+ verbose_proxy_logger.error(f"Error sending spend report: {e}")
1659
+
1660
+ async def send_monthly_spend_report(self):
1661
+ """ """
1662
+ try:
1663
+ from calendar import monthrange
1664
+
1665
+ from litellm.proxy.spend_tracking.spend_management_endpoints import (
1666
+ _get_spend_report_for_time_range,
1667
+ )
1668
+
1669
+ todays_date = datetime.datetime.now().date()
1670
+ first_day_of_month = todays_date.replace(day=1)
1671
+ _, last_day_of_month = monthrange(todays_date.year, todays_date.month)
1672
+ last_day_of_month = first_day_of_month + datetime.timedelta(
1673
+ days=last_day_of_month - 1
1674
+ )
1675
+
1676
+ _event_cache_key = f"monthly_spend_report_sent_{first_day_of_month.strftime('%Y-%m-%d')}_{last_day_of_month.strftime('%Y-%m-%d')}"
1677
+ if await self.internal_usage_cache.async_get_cache(key=_event_cache_key):
1678
+ return
1679
+
1680
+ _resp = await _get_spend_report_for_time_range(
1681
+ start_date=first_day_of_month.strftime("%Y-%m-%d"),
1682
+ end_date=last_day_of_month.strftime("%Y-%m-%d"),
1683
+ )
1684
+
1685
+ if _resp is None or _resp == ([], []):
1686
+ return
1687
+
1688
+ monthly_spend_per_team, monthly_spend_per_tag = _resp
1689
+
1690
+ _spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"
1691
+
1692
+ if monthly_spend_per_team is not None:
1693
+ _spend_message += "\n*Team Spend Report:*\n"
1694
+ for spend in monthly_spend_per_team:
1695
+ _team_spend = spend["total_spend"]
1696
+ _team_spend = float(_team_spend)
1697
+ # round to 4 decimal places
1698
+ _team_spend = round(_team_spend, 4)
1699
+ _spend_message += (
1700
+ f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
1701
+ )
1702
+
1703
+ if monthly_spend_per_tag is not None:
1704
+ _spend_message += "\n*Tag Spend Report:*\n"
1705
+ for spend in monthly_spend_per_tag:
1706
+ _tag_spend = spend["total_spend"]
1707
+ _tag_spend = float(_tag_spend)
1708
+ # round to 4 decimal places
1709
+ _tag_spend = round(_tag_spend, 4)
1710
+ _spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
1711
+
1712
+ await self.send_alert(
1713
+ message=_spend_message,
1714
+ level="Low",
1715
+ alert_type=AlertType.spend_reports,
1716
+ alerting_metadata={},
1717
+ )
1718
+
1719
+ await self.internal_usage_cache.async_set_cache(
1720
+ key=_event_cache_key,
1721
+ value="SENT",
1722
+ ttl=(30 * HOURS_IN_A_DAY * 60 * 60), # 1 month
1723
+ )
1724
+
1725
+ except Exception as e:
1726
+ verbose_proxy_logger.exception("Error sending weekly spend report %s", e)
1727
+
1728
+ async def send_fallback_stats_from_prometheus(self):
1729
+ """
1730
+ Helper to send fallback statistics from prometheus server -> to slack
1731
+
1732
+ This runs once per day and sends an overview of all the fallback statistics
1733
+ """
1734
+ try:
1735
+ from litellm.integrations.prometheus_helpers.prometheus_api import (
1736
+ get_fallback_metric_from_prometheus,
1737
+ )
1738
+
1739
+ # query prometheus for the fallback metrics
1740
+ fallback_success_info_prometheus = (
1741
+ await get_fallback_metric_from_prometheus()
1742
+ )
1743
+
1744
+ fallback_message = (
1745
+ f"*Fallback Statistics:*\n{falllback_success_info_prometheus}"
1746
+ )
1747
+
1748
+ await self.send_alert(
1749
+ message=fallback_message,
1750
+ level="Low",
1751
+ alert_type=AlertType.fallback_reports,
1752
+ alerting_metadata={},
1753
+ )
1754
+
1755
+ except Exception as e:
1756
+ verbose_proxy_logger.error("Error sending weekly spend report %s", e)
1757
+
1758
+ pass
1759
+
1760
+ async def send_virtual_key_event_slack(
1761
+ self,
1762
+ key_event: VirtualKeyEvent,
1763
+ alert_type: AlertType,
1764
+ event_name: str,
1765
+ ):
1766
+ """
1767
+ Handles sending Virtual Key related alerts
1768
+
1769
+ Example:
1770
+ - New Virtual Key Created
1771
+ - Internal User Updated
1772
+ - Team Created, Updated, Deleted
1773
+ """
1774
+ try:
1775
+ message = f"`{event_name}`\n"
1776
+
1777
+ key_event_dict = key_event.model_dump()
1778
+
1779
+ # Add Created by information first
1780
+ message += "*Action Done by:*\n"
1781
+ for key, value in key_event_dict.items():
1782
+ if "created_by" in key:
1783
+ message += f"{key}: `{value}`\n"
1784
+
1785
+ # Add args sent to function in the alert
1786
+ message += "\n*Arguments passed:*\n"
1787
+ request_kwargs = key_event.request_kwargs
1788
+ for key, value in request_kwargs.items():
1789
+ if key == "user_api_key_dict":
1790
+ continue
1791
+ message += f"{key}: `{value}`\n"
1792
+
1793
+ await self.send_alert(
1794
+ message=message,
1795
+ level="High",
1796
+ alert_type=alert_type,
1797
+ alerting_metadata={},
1798
+ )
1799
+
1800
+ except Exception as e:
1801
+ verbose_proxy_logger.error(
1802
+ "Error sending send_virtual_key_event_slack %s", e
1803
+ )
1804
+
1805
+ return
1806
+
1807
+ async def _request_is_completed(self, request_data: Optional[dict]) -> bool:
1808
+ """
1809
+ Returns True if the request is completed - either as a success or failure
1810
+ """
1811
+ if request_data is None:
1812
+ return False
1813
+
1814
+ if (
1815
+ request_data.get("litellm_status", "") != "success"
1816
+ and request_data.get("litellm_status", "") != "fail"
1817
+ ):
1818
+ ## CHECK IF CACHE IS UPDATED
1819
+ litellm_call_id = request_data.get("litellm_call_id", "")
1820
+ status: Optional[str] = await self.internal_usage_cache.async_get_cache(
1821
+ key="request_status:{}".format(litellm_call_id), local_only=True
1822
+ )
1823
+ if status is not None and (status == "success" or status == "fail"):
1824
+ return True
1825
+ return False
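Taken together, `send_alert` resolves a Slack webhook per alert type (falling back to the default webhook or the `SLACK_WEBHOOK_URL` env var), wraps the formatted message in Slack's `{"text": ...}` payload, and batches it onto `log_queue`. A minimal standalone sketch of that resolution-and-payload step, with illustrative names rather than the class API:

```python
# Illustrative sketch of the webhook resolution + payload construction done by send_alert.
import os
from typing import Dict, List, Optional, Union


def resolve_slack_webhook(
    alert_type: str,
    alert_to_webhook_url: Optional[Dict[str, Union[str, List[str]]]] = None,
    default_webhook_url: Optional[str] = None,
) -> Union[str, List[str]]:
    if alert_to_webhook_url and alert_type in alert_to_webhook_url:
        return alert_to_webhook_url[alert_type]
    if default_webhook_url is not None:
        return default_webhook_url
    url = os.getenv("SLACK_WEBHOOK_URL")
    if url is None:
        raise ValueError("Missing SLACK_WEBHOOK_URL from environment")
    return url


payload = {"text": "Level: `High`\nMessage: example alert"}  # Slack incoming-webhook body
headers = {"Content-type": "application/json"}
# each queued item looks like {"url": ..., "headers": headers, "payload": payload, "alert_type": ...}
```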
litellm/integrations/SlackAlerting/utils.py ADDED
@@ -0,0 +1,92 @@
1
+ """
2
+ Utils used for slack alerting
3
+ """
4
+
5
+ import asyncio
6
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
7
+
8
+ from litellm.proxy._types import AlertType
9
+ from litellm.secret_managers.main import get_secret
10
+
11
+ if TYPE_CHECKING:
12
+ from litellm.litellm_core_utils.litellm_logging import Logging as _Logging
13
+
14
+ Logging = _Logging
15
+ else:
16
+ Logging = Any
17
+
18
+
19
+ def process_slack_alerting_variables(
20
+ alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
21
+ ) -> Optional[Dict[AlertType, Union[List[str], str]]]:
22
+ """
23
+ process alert_to_webhook_url
24
+ - check if any urls are set as os.environ/SLACK_WEBHOOK_URL_1 read env var and set the correct value
25
+ """
26
+ if alert_to_webhook_url is None:
27
+ return None
28
+
29
+ for alert_type, webhook_urls in alert_to_webhook_url.items():
30
+ if isinstance(webhook_urls, list):
31
+ _webhook_values: List[str] = []
32
+ for webhook_url in webhook_urls:
33
+ if "os.environ/" in webhook_url:
34
+ _env_value = get_secret(secret_name=webhook_url)
35
+ if not isinstance(_env_value, str):
36
+ raise ValueError(
37
+ f"Invalid webhook url value for: {webhook_url}. Got type={type(_env_value)}"
38
+ )
39
+ _webhook_values.append(_env_value)
40
+ else:
41
+ _webhook_values.append(webhook_url)
42
+
43
+ alert_to_webhook_url[alert_type] = _webhook_values
44
+ else:
45
+ _webhook_value_str: str = webhook_urls
46
+ if "os.environ/" in webhook_urls:
47
+ _env_value = get_secret(secret_name=webhook_urls)
48
+ if not isinstance(_env_value, str):
49
+ raise ValueError(
50
+ f"Invalid webhook url value for: {webhook_urls}. Got type={type(_env_value)}"
51
+ )
52
+ _webhook_value_str = _env_value
53
+ else:
54
+ _webhook_value_str = webhook_urls
55
+
56
+ alert_to_webhook_url[alert_type] = _webhook_value_str
57
+
58
+ return alert_to_webhook_url
59
+
60
+
61
+ async def _add_langfuse_trace_id_to_alert(
62
+ request_data: Optional[dict] = None,
63
+ ) -> Optional[str]:
64
+ """
65
+ Returns langfuse trace url
66
+
67
+ - check:
68
+ -> existing_trace_id
69
+ -> trace_id
70
+ -> litellm_call_id
71
+ """
72
+ # do nothing for now
73
+ if (
74
+ request_data is not None
75
+ and request_data.get("litellm_logging_obj", None) is not None
76
+ ):
77
+ trace_id: Optional[str] = None
78
+ litellm_logging_obj: Logging = request_data["litellm_logging_obj"]
79
+
80
+ for _ in range(3):
81
+ trace_id = litellm_logging_obj._get_trace_id(service_name="langfuse")
82
+ if trace_id is not None:
83
+ break
84
+ await asyncio.sleep(3) # wait 3s before retrying for trace id
85
+
86
+ _langfuse_object = litellm_logging_obj._get_callback_object(
87
+ service_name="langfuse"
88
+ )
89
+ if _langfuse_object is not None:
90
+ base_url = _langfuse_object.Langfuse.base_url
91
+ return f"{base_url}/trace/{trace_id}"
92
+ return None
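A hedged usage example for `process_slack_alerting_variables`: values written as `os.environ/<VAR>` are resolved through `get_secret` (which reads the environment), so a mapping like the one below ends up holding the concrete webhook URL. The env var name and URL are placeholders, and the import path is inferred from this file's location:

```python
# Hypothetical usage of process_slack_alerting_variables (env var name / URL are examples).
import os

from litellm.integrations.SlackAlerting.utils import process_slack_alerting_variables
from litellm.proxy._types import AlertType

os.environ["SLACK_WEBHOOK_URL_1"] = "https://hooks.slack.com/services/T000/B000/XXXX"

resolved = process_slack_alerting_variables(
    alert_to_webhook_url={
        AlertType.budget_alerts: "os.environ/SLACK_WEBHOOK_URL_1",
        AlertType.llm_requests_hanging: ["os.environ/SLACK_WEBHOOK_URL_1"],
    }
)
print(resolved)  # both entries now hold the resolved https://hooks.slack.com/... url
```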
litellm/integrations/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from . import *
litellm/integrations/_types/open_inference.py ADDED
@@ -0,0 +1,389 @@
1
+ from enum import Enum
2
+
3
+
4
+ class SpanAttributes:
5
+ OUTPUT_VALUE = "output.value"
6
+ OUTPUT_MIME_TYPE = "output.mime_type"
7
+ """
8
+ The type of output.value. If unspecified, the type is plain text by default.
9
+ If type is JSON, the value is a string representing a JSON object.
10
+ """
11
+ INPUT_VALUE = "input.value"
12
+ INPUT_MIME_TYPE = "input.mime_type"
13
+ """
14
+ The type of input.value. If unspecified, the type is plain text by default.
15
+ If type is JSON, the value is a string representing a JSON object.
16
+ """
17
+
18
+ EMBEDDING_EMBEDDINGS = "embedding.embeddings"
19
+ """
20
+ A list of objects containing embedding data, including the vector and represented piece of text.
21
+ """
22
+ EMBEDDING_MODEL_NAME = "embedding.model_name"
23
+ """
24
+ The name of the embedding model.
25
+ """
26
+
27
+ LLM_FUNCTION_CALL = "llm.function_call"
28
+ """
29
+ For models and APIs that support function calling. Records attributes such as the function
30
+ name and arguments to the called function.
31
+ """
32
+ LLM_INVOCATION_PARAMETERS = "llm.invocation_parameters"
33
+ """
34
+ Invocation parameters passed to the LLM or API, such as the model name, temperature, etc.
35
+ """
36
+ LLM_INPUT_MESSAGES = "llm.input_messages"
37
+ """
38
+ Messages provided to a chat API.
39
+ """
40
+ LLM_OUTPUT_MESSAGES = "llm.output_messages"
41
+ """
42
+ Messages received from a chat API.
43
+ """
44
+ LLM_MODEL_NAME = "llm.model_name"
45
+ """
46
+ The name of the model being used.
47
+ """
48
+ LLM_PROVIDER = "llm.provider"
49
+ """
50
+ The provider of the model, such as OpenAI, Azure, Google, etc.
51
+ """
52
+ LLM_SYSTEM = "llm.system"
53
+ """
54
+ The AI product as identified by the client or server
55
+ """
56
+ LLM_PROMPTS = "llm.prompts"
57
+ """
58
+ Prompts provided to a completions API.
59
+ """
60
+ LLM_PROMPT_TEMPLATE = "llm.prompt_template.template"
61
+ """
62
+ The prompt template as a Python f-string.
63
+ """
64
+ LLM_PROMPT_TEMPLATE_VARIABLES = "llm.prompt_template.variables"
65
+ """
66
+ A list of input variables to the prompt template.
67
+ """
68
+ LLM_PROMPT_TEMPLATE_VERSION = "llm.prompt_template.version"
69
+ """
70
+ The version of the prompt template being used.
71
+ """
72
+ LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt"
73
+ """
74
+ Number of tokens in the prompt.
75
+ """
76
+ LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = "llm.token_count.prompt_details.cache_write"
77
+ """
78
+ Number of tokens in the prompt that were written to cache.
79
+ """
80
+ LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = "llm.token_count.prompt_details.cache_read"
81
+ """
82
+ Number of tokens in the prompt that were read from cache.
83
+ """
84
+ LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = "llm.token_count.prompt_details.audio"
85
+ """
86
+ The number of audio input tokens presented in the prompt
87
+ """
88
+ LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"
89
+ """
90
+ Number of tokens in the completion.
91
+ """
92
+ LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = "llm.token_count.completion_details.reasoning"
93
+ """
94
+ Number of tokens used for reasoning steps in the completion.
95
+ """
96
+ LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO = "llm.token_count.completion_details.audio"
97
+ """
98
+ The number of audio input tokens generated by the model
99
+ """
100
+ LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total"
101
+ """
102
+ Total number of tokens, including both prompt and completion.
103
+ """
104
+
105
+ LLM_TOOLS = "llm.tools"
106
+ """
107
+ List of tools that are advertised to the LLM to be able to call
108
+ """
109
+
110
+ TOOL_NAME = "tool.name"
111
+ """
112
+ Name of the tool being used.
113
+ """
114
+ TOOL_DESCRIPTION = "tool.description"
115
+ """
116
+ Description of the tool's purpose, typically used to select the tool.
117
+ """
118
+ TOOL_PARAMETERS = "tool.parameters"
119
+ """
120
+ Parameters of the tool represented as a dictionary serialized to a JSON string, e.g.
121
+ see https://platform.openai.com/docs/guides/gpt/function-calling
122
+ """
123
+
124
+ RETRIEVAL_DOCUMENTS = "retrieval.documents"
125
+
126
+ METADATA = "metadata"
127
+ """
128
+ Metadata attributes are used to store user-defined key-value pairs.
129
+ For example, LangChain uses metadata to store user-defined attributes for a chain.
130
+ """
131
+
132
+ TAG_TAGS = "tag.tags"
133
+ """
134
+ Custom categorical tags for the span.
135
+ """
136
+
137
+ OPENINFERENCE_SPAN_KIND = "openinference.span.kind"
138
+
139
+ SESSION_ID = "session.id"
140
+ """
141
+ The id of the session
142
+ """
143
+ USER_ID = "user.id"
144
+ """
145
+ The id of the user
146
+ """
147
+
148
+ PROMPT_VENDOR = "prompt.vendor"
149
+ """
150
+ The vendor or origin of the prompt, e.g. a prompt library, a specialized service, etc.
151
+ """
152
+ PROMPT_ID = "prompt.id"
153
+ """
154
+ A vendor-specific id used to locate the prompt.
155
+ """
156
+ PROMPT_URL = "prompt.url"
157
+ """
158
+ A vendor-specific url used to locate the prompt.
159
+ """
160
+
161
+
162
+ class MessageAttributes:
163
+ """
164
+ Attributes for a message sent to or from an LLM
165
+ """
166
+
167
+ MESSAGE_ROLE = "message.role"
168
+ """
169
+ The role of the message, such as "user", "agent", "function".
170
+ """
171
+ MESSAGE_CONTENT = "message.content"
172
+ """
173
+ The content of the message to or from the llm, must be a string.
174
+ """
175
+ MESSAGE_CONTENTS = "message.contents"
176
+ """
177
+ The message contents sent to the llm; an array of
178
+ `message_content`-prefixed attributes.
179
+ """
180
+ MESSAGE_NAME = "message.name"
181
+ """
182
+ The name of the message, often used to identify the function
183
+ that was used to generate the message.
184
+ """
185
+ MESSAGE_TOOL_CALLS = "message.tool_calls"
186
+ """
187
+ The tool calls generated by the model, such as function calls.
188
+ """
189
+ MESSAGE_FUNCTION_CALL_NAME = "message.function_call_name"
190
+ """
191
+ The function name that is a part of the message list.
192
+ This is populated for role 'function' or 'agent' as a mechanism to identify
193
+ the function that was called during the execution of a tool.
194
+ """
195
+ MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = "message.function_call_arguments_json"
196
+ """
197
+ The JSON string representing the arguments passed to the function
198
+ during a function call.
199
+ """
200
+ MESSAGE_TOOL_CALL_ID = "message.tool_call_id"
201
+ """
202
+ The id of the tool call.
203
+ """
204
+
205
+
206
+ class MessageContentAttributes:
207
+ """
208
+ Attributes for the contents of user messages sent to an LLM.
209
+ """
210
+
211
+ MESSAGE_CONTENT_TYPE = "message_content.type"
212
+ """
213
+ The type of the content, such as "text" or "image".
214
+ """
215
+ MESSAGE_CONTENT_TEXT = "message_content.text"
216
+ """
217
+ The text content of the message, if the type is "text".
218
+ """
219
+ MESSAGE_CONTENT_IMAGE = "message_content.image"
220
+ """
221
+ The image content of the message, if the type is "image".
222
+ An image can be made available to the model by passing a link to
223
+ the image or by passing the base64 encoded image directly in the
224
+ request.
225
+ """
226
+
227
+
228
+ class ImageAttributes:
229
+ """
230
+ Attributes for images
231
+ """
232
+
233
+ IMAGE_URL = "image.url"
234
+ """
235
+ An http or base64 image url
236
+ """
237
+
238
+
239
+ class AudioAttributes:
240
+ """
241
+ Attributes for audio
242
+ """
243
+
244
+ AUDIO_URL = "audio.url"
245
+ """
246
+ The url to an audio file
247
+ """
248
+ AUDIO_MIME_TYPE = "audio.mime_type"
249
+ """
250
+ The mime type of the audio file
251
+ """
252
+ AUDIO_TRANSCRIPT = "audio.transcript"
253
+ """
254
+ The transcript of the audio file
255
+ """
256
+
257
+
258
+ class DocumentAttributes:
259
+ """
260
+ Attributes for a document.
261
+ """
262
+
263
+ DOCUMENT_ID = "document.id"
264
+ """
265
+ The id of the document.
266
+ """
267
+ DOCUMENT_SCORE = "document.score"
268
+ """
269
+ The score of the document
270
+ """
271
+ DOCUMENT_CONTENT = "document.content"
272
+ """
273
+ The content of the document.
274
+ """
275
+ DOCUMENT_METADATA = "document.metadata"
276
+ """
277
+ The metadata of the document represented as a dictionary
278
+ JSON string, e.g. `"{ 'title': 'foo' }"`
279
+ """
280
+
281
+
282
+ class RerankerAttributes:
283
+ """
284
+ Attributes for a reranker
285
+ """
286
+
287
+ RERANKER_INPUT_DOCUMENTS = "reranker.input_documents"
288
+ """
289
+ List of documents as input to the reranker
290
+ """
291
+ RERANKER_OUTPUT_DOCUMENTS = "reranker.output_documents"
292
+ """
293
+ List of documents as output from the reranker
294
+ """
295
+ RERANKER_QUERY = "reranker.query"
296
+ """
297
+ Query string for the reranker
298
+ """
299
+ RERANKER_MODEL_NAME = "reranker.model_name"
300
+ """
301
+ Model name of the reranker
302
+ """
303
+ RERANKER_TOP_K = "reranker.top_k"
304
+ """
305
+ Top K parameter of the reranker
306
+ """
307
+
308
+
309
+ class EmbeddingAttributes:
310
+ """
311
+ Attributes for an embedding
312
+ """
313
+
314
+ EMBEDDING_TEXT = "embedding.text"
315
+ """
316
+ The text represented by the embedding.
317
+ """
318
+ EMBEDDING_VECTOR = "embedding.vector"
319
+ """
320
+ The embedding vector.
321
+ """
322
+
323
+
324
+ class ToolCallAttributes:
325
+ """
326
+ Attributes for a tool call
327
+ """
328
+
329
+ TOOL_CALL_ID = "tool_call.id"
330
+ """
331
+ The id of the tool call.
332
+ """
333
+ TOOL_CALL_FUNCTION_NAME = "tool_call.function.name"
334
+ """
335
+ The name of function that is being called during a tool call.
336
+ """
337
+ TOOL_CALL_FUNCTION_ARGUMENTS_JSON = "tool_call.function.arguments"
338
+ """
339
+ The JSON string representing the arguments passed to the function
340
+ during a tool call.
341
+ """
342
+
343
+
344
+ class ToolAttributes:
345
+ """
346
+ Attributes for tools
347
+ """
348
+
349
+ TOOL_JSON_SCHEMA = "tool.json_schema"
350
+ """
351
+ The json schema of a tool input, It is RECOMMENDED that this be in the
352
+ OpenAI tool calling format: https://platform.openai.com/docs/assistants/tools
353
+ """
354
+
355
+
356
+ class OpenInferenceSpanKindValues(Enum):
357
+ TOOL = "TOOL"
358
+ CHAIN = "CHAIN"
359
+ LLM = "LLM"
360
+ RETRIEVER = "RETRIEVER"
361
+ EMBEDDING = "EMBEDDING"
362
+ AGENT = "AGENT"
363
+ RERANKER = "RERANKER"
364
+ UNKNOWN = "UNKNOWN"
365
+ GUARDRAIL = "GUARDRAIL"
366
+ EVALUATOR = "EVALUATOR"
367
+
368
+
369
+ class OpenInferenceMimeTypeValues(Enum):
370
+ TEXT = "text/plain"
371
+ JSON = "application/json"
372
+
373
+
374
+ class OpenInferenceLLMSystemValues(Enum):
375
+ OPENAI = "openai"
376
+ ANTHROPIC = "anthropic"
377
+ COHERE = "cohere"
378
+ MISTRALAI = "mistralai"
379
+ VERTEXAI = "vertexai"
380
+
381
+
382
+ class OpenInferenceLLMProviderValues(Enum):
383
+ OPENAI = "openai"
384
+ ANTHROPIC = "anthropic"
385
+ COHERE = "cohere"
386
+ MISTRALAI = "mistralai"
387
+ GOOGLE = "google"
388
+ AZURE = "azure"
389
+ AWS = "aws"
litellm/integrations/additional_logging_utils.py ADDED
@@ -0,0 +1,36 @@
1
+ """
2
+ Base class for Additional Logging Utils for CustomLoggers
3
+
4
+ - Health Check for the logging util
5
+ - Get Request / Response Payload for the logging util
6
+ """
7
+
8
+ from abc import ABC, abstractmethod
9
+ from datetime import datetime
10
+ from typing import Optional
11
+
12
+ from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
13
+
14
+
15
+ class AdditionalLoggingUtils(ABC):
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ @abstractmethod
20
+ async def async_health_check(self) -> IntegrationHealthCheckStatus:
21
+ """
22
+ Check if the service is healthy
23
+ """
24
+ pass
25
+
26
+ @abstractmethod
27
+ async def get_request_response_payload(
28
+ self,
29
+ request_id: str,
30
+ start_time_utc: Optional[datetime],
31
+ end_time_utc: Optional[datetime],
32
+ ) -> Optional[dict]:
33
+ """
34
+ Get the request and response payload for a given `request_id`
35
+ """
36
+ return None
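A concrete logger would subclass `AdditionalLoggingUtils` and implement both abstract methods. A minimal sketch, assuming `IntegrationHealthCheckStatus` accepts `status` and `error_message` fields (not confirmed by this diff; check `litellm/types/integrations/base_health_check.py`):

```python
# Hedged sketch: a custom logger implementing the AdditionalLoggingUtils interface.
from datetime import datetime
from typing import Optional

from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus


class MyLogger(AdditionalLoggingUtils):
    async def async_health_check(self) -> IntegrationHealthCheckStatus:
        # assumed field names - verify against base_health_check.py
        return IntegrationHealthCheckStatus(status="healthy", error_message=None)

    async def get_request_response_payload(
        self,
        request_id: str,
        start_time_utc: Optional[datetime],
        end_time_utc: Optional[datetime],
    ) -> Optional[dict]:
        # look up the stored payload for request_id in your backing store
        return {"request_id": request_id}
```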
litellm/integrations/agentops/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .agentops import AgentOps
2
+
3
+ __all__ = ["AgentOps"]