Spaces:
Running
Running
Push core package and essential files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +26 -0
- litellm/__init__.py +1084 -0
- litellm/_logging.py +167 -0
- litellm/_redis.py +333 -0
- litellm/_service_logger.py +311 -0
- litellm/_version.py +6 -0
- litellm/anthropic_interface/__init__.py +6 -0
- litellm/anthropic_interface/messages/__init__.py +117 -0
- litellm/anthropic_interface/readme.md +116 -0
- litellm/assistants/main.py +1484 -0
- litellm/assistants/utils.py +161 -0
- litellm/batch_completion/Readme.md +11 -0
- litellm/batch_completion/main.py +253 -0
- litellm/batches/batch_utils.py +182 -0
- litellm/batches/main.py +796 -0
- litellm/budget_manager.py +230 -0
- litellm/caching/Readme.md +40 -0
- litellm/caching/__init__.py +9 -0
- litellm/caching/_internal_lru_cache.py +30 -0
- litellm/caching/base_cache.py +55 -0
- litellm/caching/caching.py +818 -0
- litellm/caching/caching_handler.py +938 -0
- litellm/caching/disk_cache.py +88 -0
- litellm/caching/dual_cache.py +434 -0
- litellm/caching/in_memory_cache.py +203 -0
- litellm/caching/llm_caching_handler.py +39 -0
- litellm/caching/qdrant_semantic_cache.py +442 -0
- litellm/caching/redis_cache.py +1162 -0
- litellm/caching/redis_cluster_cache.py +59 -0
- litellm/caching/redis_semantic_cache.py +450 -0
- litellm/caching/s3_cache.py +159 -0
- litellm/constants.py +543 -0
- litellm/cost.json +5 -0
- litellm/cost_calculator.py +1378 -0
- litellm/exceptions.py +809 -0
- litellm/experimental_mcp_client/Readme.md +6 -0
- litellm/experimental_mcp_client/__init__.py +3 -0
- litellm/experimental_mcp_client/client.py +0 -0
- litellm/experimental_mcp_client/tools.py +111 -0
- litellm/files/main.py +891 -0
- litellm/fine_tuning/main.py +761 -0
- litellm/integrations/Readme.md +5 -0
- litellm/integrations/SlackAlerting/Readme.md +13 -0
- litellm/integrations/SlackAlerting/batching_handler.py +81 -0
- litellm/integrations/SlackAlerting/slack_alerting.py +1825 -0
- litellm/integrations/SlackAlerting/utils.py +92 -0
- litellm/integrations/__init__.py +1 -0
- litellm/integrations/_types/open_inference.py +389 -0
- litellm/integrations/additional_logging_utils.py +36 -0
- litellm/integrations/agentops/__init__.py +3 -0
LICENSE
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Portions of this software are licensed as follows:
|
2 |
+
|
3 |
+
* All content that resides under the "enterprise/" directory of this repository, if that directory exists, is licensed under the license defined in "enterprise/LICENSE".
|
4 |
+
* Content outside of the above mentioned directories or restrictions above is available under the MIT license as defined below.
|
5 |
+
---
|
6 |
+
MIT License
|
7 |
+
|
8 |
+
Copyright (c) 2023 Berri AI
|
9 |
+
|
10 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
11 |
+
of this software and associated documentation files (the "Software"), to deal
|
12 |
+
in the Software without restriction, including without limitation the rights
|
13 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
14 |
+
copies of the Software, and to permit persons to whom the Software is
|
15 |
+
furnished to do so, subject to the following conditions:
|
16 |
+
|
17 |
+
The above copyright notice and this permission notice shall be included in all
|
18 |
+
copies or substantial portions of the Software.
|
19 |
+
|
20 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
21 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
22 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
23 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
24 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
25 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
26 |
+
SOFTWARE.
|
litellm/__init__.py
ADDED
@@ -0,0 +1,1084 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Hide pydantic namespace conflict warnings globally ###
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
|
5 |
+
### INIT VARIABLES ###########
|
6 |
+
import threading
|
7 |
+
import os
|
8 |
+
from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
|
9 |
+
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
10 |
+
from litellm.caching.caching import Cache, DualCache, RedisCache, InMemoryCache
|
11 |
+
from litellm.caching.llm_caching_handler import LLMClientCache
|
12 |
+
from litellm.types.llms.bedrock import COHERE_EMBEDDING_INPUT_TYPES
|
13 |
+
from litellm.types.utils import (
|
14 |
+
ImageObject,
|
15 |
+
BudgetConfig,
|
16 |
+
all_litellm_params,
|
17 |
+
all_litellm_params as _litellm_completion_params,
|
18 |
+
CredentialItem,
|
19 |
+
) # maintain backwards compatibility for root param
|
20 |
+
from litellm._logging import (
|
21 |
+
set_verbose,
|
22 |
+
_turn_on_debug,
|
23 |
+
verbose_logger,
|
24 |
+
json_logs,
|
25 |
+
_turn_on_json,
|
26 |
+
log_level,
|
27 |
+
)
|
28 |
+
import re
|
29 |
+
from litellm.constants import (
|
30 |
+
DEFAULT_BATCH_SIZE,
|
31 |
+
DEFAULT_FLUSH_INTERVAL_SECONDS,
|
32 |
+
ROUTER_MAX_FALLBACKS,
|
33 |
+
DEFAULT_MAX_RETRIES,
|
34 |
+
DEFAULT_REPLICATE_POLLING_RETRIES,
|
35 |
+
DEFAULT_REPLICATE_POLLING_DELAY_SECONDS,
|
36 |
+
LITELLM_CHAT_PROVIDERS,
|
37 |
+
HUMANLOOP_PROMPT_CACHE_TTL_SECONDS,
|
38 |
+
OPENAI_CHAT_COMPLETION_PARAMS,
|
39 |
+
OPENAI_CHAT_COMPLETION_PARAMS as _openai_completion_params, # backwards compatibility
|
40 |
+
OPENAI_FINISH_REASONS,
|
41 |
+
OPENAI_FINISH_REASONS as _openai_finish_reasons, # backwards compatibility
|
42 |
+
openai_compatible_endpoints,
|
43 |
+
openai_compatible_providers,
|
44 |
+
openai_text_completion_compatible_providers,
|
45 |
+
_openai_like_providers,
|
46 |
+
replicate_models,
|
47 |
+
clarifai_models,
|
48 |
+
huggingface_models,
|
49 |
+
empower_models,
|
50 |
+
together_ai_models,
|
51 |
+
baseten_models,
|
52 |
+
REPEATED_STREAMING_CHUNK_LIMIT,
|
53 |
+
request_timeout,
|
54 |
+
open_ai_embedding_models,
|
55 |
+
cohere_embedding_models,
|
56 |
+
bedrock_embedding_models,
|
57 |
+
known_tokenizer_config,
|
58 |
+
BEDROCK_INVOKE_PROVIDERS_LITERAL,
|
59 |
+
DEFAULT_MAX_TOKENS,
|
60 |
+
DEFAULT_SOFT_BUDGET,
|
61 |
+
DEFAULT_ALLOWED_FAILS,
|
62 |
+
)
|
63 |
+
from litellm.types.guardrails import GuardrailItem
|
64 |
+
from litellm.proxy._types import (
|
65 |
+
KeyManagementSystem,
|
66 |
+
KeyManagementSettings,
|
67 |
+
LiteLLM_UpperboundKeyGenerateParams,
|
68 |
+
)
|
69 |
+
from litellm.types.proxy.management_endpoints.ui_sso import DefaultTeamSSOParams
|
70 |
+
from litellm.types.utils import StandardKeyGenerationConfig, LlmProviders
|
71 |
+
from litellm.integrations.custom_logger import CustomLogger
|
72 |
+
from litellm.litellm_core_utils.logging_callback_manager import LoggingCallbackManager
|
73 |
+
import httpx
|
74 |
+
import dotenv
|
75 |
+
|
76 |
+
litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV"
|
77 |
+
if litellm_mode == "DEV":
|
78 |
+
dotenv.load_dotenv()
|
79 |
+
################################################
|
80 |
+
if set_verbose == True:
|
81 |
+
_turn_on_debug()
|
82 |
+
################################################
|
83 |
+
### Callbacks /Logging / Success / Failure Handlers #####
|
84 |
+
CALLBACK_TYPES = Union[str, Callable, CustomLogger]
|
85 |
+
input_callback: List[CALLBACK_TYPES] = []
|
86 |
+
success_callback: List[CALLBACK_TYPES] = []
|
87 |
+
failure_callback: List[CALLBACK_TYPES] = []
|
88 |
+
service_callback: List[CALLBACK_TYPES] = []
|
89 |
+
logging_callback_manager = LoggingCallbackManager()
|
90 |
+
_custom_logger_compatible_callbacks_literal = Literal[
|
91 |
+
"lago",
|
92 |
+
"openmeter",
|
93 |
+
"logfire",
|
94 |
+
"literalai",
|
95 |
+
"dynamic_rate_limiter",
|
96 |
+
"langsmith",
|
97 |
+
"prometheus",
|
98 |
+
"otel",
|
99 |
+
"datadog",
|
100 |
+
"datadog_llm_observability",
|
101 |
+
"galileo",
|
102 |
+
"braintrust",
|
103 |
+
"arize",
|
104 |
+
"arize_phoenix",
|
105 |
+
"langtrace",
|
106 |
+
"gcs_bucket",
|
107 |
+
"azure_storage",
|
108 |
+
"opik",
|
109 |
+
"argilla",
|
110 |
+
"mlflow",
|
111 |
+
"langfuse",
|
112 |
+
"pagerduty",
|
113 |
+
"humanloop",
|
114 |
+
"gcs_pubsub",
|
115 |
+
"agentops",
|
116 |
+
"anthropic_cache_control_hook",
|
117 |
+
"bedrock_knowledgebase_hook",
|
118 |
+
]
|
119 |
+
logged_real_time_event_types: Optional[Union[List[str], Literal["*"]]] = None
|
120 |
+
_known_custom_logger_compatible_callbacks: List = list(
|
121 |
+
get_args(_custom_logger_compatible_callbacks_literal)
|
122 |
+
)
|
123 |
+
callbacks: List[
|
124 |
+
Union[Callable, _custom_logger_compatible_callbacks_literal, CustomLogger]
|
125 |
+
] = []
|
126 |
+
langfuse_default_tags: Optional[List[str]] = None
|
127 |
+
langsmith_batch_size: Optional[int] = None
|
128 |
+
prometheus_initialize_budget_metrics: Optional[bool] = False
|
129 |
+
require_auth_for_metrics_endpoint: Optional[bool] = False
|
130 |
+
argilla_batch_size: Optional[int] = None
|
131 |
+
datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload
|
132 |
+
gcs_pub_sub_use_v1: Optional[bool] = (
|
133 |
+
False # if you want to use v1 gcs pubsub logged payload
|
134 |
+
)
|
135 |
+
argilla_transformation_object: Optional[Dict[str, Any]] = None
|
136 |
+
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
|
137 |
+
[]
|
138 |
+
) # internal variable - async custom callbacks are routed here.
|
139 |
+
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
|
140 |
+
[]
|
141 |
+
) # internal variable - async custom callbacks are routed here.
|
142 |
+
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
|
143 |
+
[]
|
144 |
+
) # internal variable - async custom callbacks are routed here.
|
145 |
+
pre_call_rules: List[Callable] = []
|
146 |
+
post_call_rules: List[Callable] = []
|
147 |
+
turn_off_message_logging: Optional[bool] = False
|
148 |
+
log_raw_request_response: bool = False
|
149 |
+
redact_messages_in_exceptions: Optional[bool] = False
|
150 |
+
redact_user_api_key_info: Optional[bool] = False
|
151 |
+
filter_invalid_headers: Optional[bool] = False
|
152 |
+
add_user_information_to_llm_headers: Optional[bool] = (
|
153 |
+
None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
|
154 |
+
)
|
155 |
+
store_audit_logs = False # Enterprise feature, allow users to see audit logs
|
156 |
+
### end of callbacks #############
|
157 |
+
|
158 |
+
email: Optional[str] = (
|
159 |
+
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
160 |
+
)
|
161 |
+
token: Optional[str] = (
|
162 |
+
None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
163 |
+
)
|
164 |
+
telemetry = True
|
165 |
+
max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults
|
166 |
+
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
|
167 |
+
modify_params = bool(os.getenv("LITELLM_MODIFY_PARAMS", False))
|
168 |
+
retry = True
|
169 |
+
### AUTH ###
|
170 |
+
api_key: Optional[str] = None
|
171 |
+
openai_key: Optional[str] = None
|
172 |
+
groq_key: Optional[str] = None
|
173 |
+
databricks_key: Optional[str] = None
|
174 |
+
openai_like_key: Optional[str] = None
|
175 |
+
azure_key: Optional[str] = None
|
176 |
+
anthropic_key: Optional[str] = None
|
177 |
+
replicate_key: Optional[str] = None
|
178 |
+
cohere_key: Optional[str] = None
|
179 |
+
infinity_key: Optional[str] = None
|
180 |
+
clarifai_key: Optional[str] = None
|
181 |
+
maritalk_key: Optional[str] = None
|
182 |
+
ai21_key: Optional[str] = None
|
183 |
+
ollama_key: Optional[str] = None
|
184 |
+
openrouter_key: Optional[str] = None
|
185 |
+
predibase_key: Optional[str] = None
|
186 |
+
huggingface_key: Optional[str] = None
|
187 |
+
vertex_project: Optional[str] = None
|
188 |
+
vertex_location: Optional[str] = None
|
189 |
+
predibase_tenant_id: Optional[str] = None
|
190 |
+
togetherai_api_key: Optional[str] = None
|
191 |
+
cloudflare_api_key: Optional[str] = None
|
192 |
+
baseten_key: Optional[str] = None
|
193 |
+
aleph_alpha_key: Optional[str] = None
|
194 |
+
nlp_cloud_key: Optional[str] = None
|
195 |
+
snowflake_key: Optional[str] = None
|
196 |
+
common_cloud_provider_auth_params: dict = {
|
197 |
+
"params": ["project", "region_name", "token"],
|
198 |
+
"providers": ["vertex_ai", "bedrock", "watsonx", "azure", "vertex_ai_beta"],
|
199 |
+
}
|
200 |
+
use_client: bool = False
|
201 |
+
ssl_verify: Union[str, bool] = True
|
202 |
+
ssl_certificate: Optional[str] = None
|
203 |
+
disable_streaming_logging: bool = False
|
204 |
+
disable_add_transform_inline_image_block: bool = False
|
205 |
+
in_memory_llm_clients_cache: LLMClientCache = LLMClientCache()
|
206 |
+
safe_memory_mode: bool = False
|
207 |
+
enable_azure_ad_token_refresh: Optional[bool] = False
|
208 |
+
### DEFAULT AZURE API VERSION ###
|
209 |
+
AZURE_DEFAULT_API_VERSION = "2025-02-01-preview" # this is updated to the latest
|
210 |
+
### DEFAULT WATSONX API VERSION ###
|
211 |
+
WATSONX_DEFAULT_API_VERSION = "2024-03-13"
|
212 |
+
### COHERE EMBEDDINGS DEFAULT TYPE ###
|
213 |
+
COHERE_DEFAULT_EMBEDDING_INPUT_TYPE: COHERE_EMBEDDING_INPUT_TYPES = "search_document"
|
214 |
+
### CREDENTIALS ###
|
215 |
+
credential_list: List[CredentialItem] = []
|
216 |
+
### GUARDRAILS ###
|
217 |
+
llamaguard_model_name: Optional[str] = None
|
218 |
+
openai_moderations_model_name: Optional[str] = None
|
219 |
+
presidio_ad_hoc_recognizers: Optional[str] = None
|
220 |
+
google_moderation_confidence_threshold: Optional[float] = None
|
221 |
+
llamaguard_unsafe_content_categories: Optional[str] = None
|
222 |
+
blocked_user_list: Optional[Union[str, List]] = None
|
223 |
+
banned_keywords_list: Optional[Union[str, List]] = None
|
224 |
+
llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all"
|
225 |
+
guardrail_name_config_map: Dict[str, GuardrailItem] = {}
|
226 |
+
##################
|
227 |
+
### PREVIEW FEATURES ###
|
228 |
+
enable_preview_features: bool = False
|
229 |
+
return_response_headers: bool = (
|
230 |
+
False # get response headers from LLM Api providers - example x-remaining-requests,
|
231 |
+
)
|
232 |
+
enable_json_schema_validation: bool = False
|
233 |
+
##################
|
234 |
+
logging: bool = True
|
235 |
+
enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
|
236 |
+
enable_caching_on_provider_specific_optional_params: bool = (
|
237 |
+
False # feature-flag for caching on optional params - e.g. 'top_k'
|
238 |
+
)
|
239 |
+
caching: bool = (
|
240 |
+
False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
241 |
+
)
|
242 |
+
caching_with_models: bool = (
|
243 |
+
False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
|
244 |
+
)
|
245 |
+
cache: Optional[Cache] = (
|
246 |
+
None # cache object <- use this - https://docs.litellm.ai/docs/caching
|
247 |
+
)
|
248 |
+
default_in_memory_ttl: Optional[float] = None
|
249 |
+
default_redis_ttl: Optional[float] = None
|
250 |
+
default_redis_batch_cache_expiry: Optional[float] = None
|
251 |
+
model_alias_map: Dict[str, str] = {}
|
252 |
+
model_group_alias_map: Dict[str, str] = {}
|
253 |
+
max_budget: float = 0.0 # set the max budget across all providers
|
254 |
+
budget_duration: Optional[str] = (
|
255 |
+
None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
256 |
+
)
|
257 |
+
default_soft_budget: float = (
|
258 |
+
DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0
|
259 |
+
)
|
260 |
+
forward_traceparent_to_llm_provider: bool = False
|
261 |
+
|
262 |
+
|
263 |
+
_current_cost = 0.0 # private variable, used if max budget is set
|
264 |
+
error_logs: Dict = {}
|
265 |
+
add_function_to_prompt: bool = (
|
266 |
+
False # if function calling not supported by api, append function call details to system prompt
|
267 |
+
)
|
268 |
+
client_session: Optional[httpx.Client] = None
|
269 |
+
aclient_session: Optional[httpx.AsyncClient] = None
|
270 |
+
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
271 |
+
model_cost_map_url: str = (
|
272 |
+
"https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
273 |
+
)
|
274 |
+
suppress_debug_info = False
|
275 |
+
dynamodb_table_name: Optional[str] = None
|
276 |
+
s3_callback_params: Optional[Dict] = None
|
277 |
+
generic_logger_headers: Optional[Dict] = None
|
278 |
+
default_key_generate_params: Optional[Dict] = None
|
279 |
+
upperbound_key_generate_params: Optional[LiteLLM_UpperboundKeyGenerateParams] = None
|
280 |
+
key_generation_settings: Optional[StandardKeyGenerationConfig] = None
|
281 |
+
default_internal_user_params: Optional[Dict] = None
|
282 |
+
default_team_params: Optional[Union[DefaultTeamSSOParams, Dict]] = None
|
283 |
+
default_team_settings: Optional[List] = None
|
284 |
+
max_user_budget: Optional[float] = None
|
285 |
+
default_max_internal_user_budget: Optional[float] = None
|
286 |
+
max_internal_user_budget: Optional[float] = None
|
287 |
+
max_ui_session_budget: Optional[float] = 10 # $10 USD budgets for UI Chat sessions
|
288 |
+
internal_user_budget_duration: Optional[str] = None
|
289 |
+
tag_budget_config: Optional[Dict[str, BudgetConfig]] = None
|
290 |
+
max_end_user_budget: Optional[float] = None
|
291 |
+
disable_end_user_cost_tracking: Optional[bool] = None
|
292 |
+
disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
|
293 |
+
custom_prometheus_metadata_labels: List[str] = []
|
294 |
+
#### REQUEST PRIORITIZATION ####
|
295 |
+
priority_reservation: Optional[Dict[str, float]] = None
|
296 |
+
force_ipv4: bool = (
|
297 |
+
False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
|
298 |
+
)
|
299 |
+
module_level_aclient = AsyncHTTPHandler(
|
300 |
+
timeout=request_timeout, client_alias="module level aclient"
|
301 |
+
)
|
302 |
+
module_level_client = HTTPHandler(timeout=request_timeout)
|
303 |
+
|
304 |
+
#### RETRIES ####
|
305 |
+
num_retries: Optional[int] = None # per model endpoint
|
306 |
+
max_fallbacks: Optional[int] = None
|
307 |
+
default_fallbacks: Optional[List] = None
|
308 |
+
fallbacks: Optional[List] = None
|
309 |
+
context_window_fallbacks: Optional[List] = None
|
310 |
+
content_policy_fallbacks: Optional[List] = None
|
311 |
+
allowed_fails: int = 3
|
312 |
+
num_retries_per_request: Optional[int] = (
|
313 |
+
None # for the request overall (incl. fallbacks + model retries)
|
314 |
+
)
|
315 |
+
####### SECRET MANAGERS #####################
|
316 |
+
secret_manager_client: Optional[Any] = (
|
317 |
+
None # list of instantiated key management clients - e.g. azure kv, infisical, etc.
|
318 |
+
)
|
319 |
+
_google_kms_resource_name: Optional[str] = None
|
320 |
+
_key_management_system: Optional[KeyManagementSystem] = None
|
321 |
+
_key_management_settings: KeyManagementSettings = KeyManagementSettings()
|
322 |
+
#### PII MASKING ####
|
323 |
+
output_parse_pii: bool = False
|
324 |
+
#############################################
|
325 |
+
from litellm.litellm_core_utils.get_model_cost_map import get_model_cost_map
|
326 |
+
|
327 |
+
model_cost = get_model_cost_map(url=model_cost_map_url)
|
328 |
+
custom_prompt_dict: Dict[str, dict] = {}
|
329 |
+
check_provider_endpoint = False
|
330 |
+
|
331 |
+
|
332 |
+
####### THREAD-SPECIFIC DATA ####################
|
333 |
+
class MyLocal(threading.local):
|
334 |
+
def __init__(self):
|
335 |
+
self.user = "Hello World"
|
336 |
+
|
337 |
+
|
338 |
+
_thread_context = MyLocal()
|
339 |
+
|
340 |
+
|
341 |
+
def identify(event_details):
|
342 |
+
# Store user in thread local data
|
343 |
+
if "user" in event_details:
|
344 |
+
_thread_context.user = event_details["user"]
|
345 |
+
|
346 |
+
|
347 |
+
####### ADDITIONAL PARAMS ################### configurable params if you use proxy models like Helicone, map spend to org id, etc.
|
348 |
+
api_base: Optional[str] = None
|
349 |
+
headers = None
|
350 |
+
api_version = None
|
351 |
+
organization = None
|
352 |
+
project = None
|
353 |
+
config_path = None
|
354 |
+
vertex_ai_safety_settings: Optional[dict] = None
|
355 |
+
BEDROCK_CONVERSE_MODELS = [
|
356 |
+
"anthropic.claude-3-7-sonnet-20250219-v1:0",
|
357 |
+
"anthropic.claude-3-5-haiku-20241022-v1:0",
|
358 |
+
"anthropic.claude-3-5-sonnet-20241022-v2:0",
|
359 |
+
"anthropic.claude-3-5-sonnet-20240620-v1:0",
|
360 |
+
"anthropic.claude-3-opus-20240229-v1:0",
|
361 |
+
"anthropic.claude-3-sonnet-20240229-v1:0",
|
362 |
+
"anthropic.claude-3-haiku-20240307-v1:0",
|
363 |
+
"anthropic.claude-v2",
|
364 |
+
"anthropic.claude-v2:1",
|
365 |
+
"anthropic.claude-v1",
|
366 |
+
"anthropic.claude-instant-v1",
|
367 |
+
"ai21.jamba-instruct-v1:0",
|
368 |
+
"meta.llama3-70b-instruct-v1:0",
|
369 |
+
"meta.llama3-8b-instruct-v1:0",
|
370 |
+
"meta.llama3-1-8b-instruct-v1:0",
|
371 |
+
"meta.llama3-1-70b-instruct-v1:0",
|
372 |
+
"meta.llama3-1-405b-instruct-v1:0",
|
373 |
+
"meta.llama3-70b-instruct-v1:0",
|
374 |
+
"mistral.mistral-large-2407-v1:0",
|
375 |
+
"mistral.mistral-large-2402-v1:0",
|
376 |
+
"meta.llama3-2-1b-instruct-v1:0",
|
377 |
+
"meta.llama3-2-3b-instruct-v1:0",
|
378 |
+
"meta.llama3-2-11b-instruct-v1:0",
|
379 |
+
"meta.llama3-2-90b-instruct-v1:0",
|
380 |
+
]
|
381 |
+
|
382 |
+
####### COMPLETION MODELS ###################
|
383 |
+
open_ai_chat_completion_models: List = []
|
384 |
+
open_ai_text_completion_models: List = []
|
385 |
+
cohere_models: List = []
|
386 |
+
cohere_chat_models: List = []
|
387 |
+
mistral_chat_models: List = []
|
388 |
+
text_completion_codestral_models: List = []
|
389 |
+
anthropic_models: List = []
|
390 |
+
openrouter_models: List = []
|
391 |
+
vertex_language_models: List = []
|
392 |
+
vertex_vision_models: List = []
|
393 |
+
vertex_chat_models: List = []
|
394 |
+
vertex_code_chat_models: List = []
|
395 |
+
vertex_ai_image_models: List = []
|
396 |
+
vertex_text_models: List = []
|
397 |
+
vertex_code_text_models: List = []
|
398 |
+
vertex_embedding_models: List = []
|
399 |
+
vertex_anthropic_models: List = []
|
400 |
+
vertex_llama3_models: List = []
|
401 |
+
vertex_ai_ai21_models: List = []
|
402 |
+
vertex_mistral_models: List = []
|
403 |
+
ai21_models: List = []
|
404 |
+
ai21_chat_models: List = []
|
405 |
+
nlp_cloud_models: List = []
|
406 |
+
aleph_alpha_models: List = []
|
407 |
+
bedrock_models: List = []
|
408 |
+
bedrock_converse_models: List = BEDROCK_CONVERSE_MODELS
|
409 |
+
fireworks_ai_models: List = []
|
410 |
+
fireworks_ai_embedding_models: List = []
|
411 |
+
deepinfra_models: List = []
|
412 |
+
perplexity_models: List = []
|
413 |
+
watsonx_models: List = []
|
414 |
+
gemini_models: List = []
|
415 |
+
xai_models: List = []
|
416 |
+
deepseek_models: List = []
|
417 |
+
azure_ai_models: List = []
|
418 |
+
jina_ai_models: List = []
|
419 |
+
voyage_models: List = []
|
420 |
+
infinity_models: List = []
|
421 |
+
databricks_models: List = []
|
422 |
+
cloudflare_models: List = []
|
423 |
+
codestral_models: List = []
|
424 |
+
friendliai_models: List = []
|
425 |
+
palm_models: List = []
|
426 |
+
groq_models: List = []
|
427 |
+
azure_models: List = []
|
428 |
+
azure_text_models: List = []
|
429 |
+
anyscale_models: List = []
|
430 |
+
cerebras_models: List = []
|
431 |
+
galadriel_models: List = []
|
432 |
+
sambanova_models: List = []
|
433 |
+
assemblyai_models: List = []
|
434 |
+
snowflake_models: List = []
|
435 |
+
|
436 |
+
|
437 |
+
def is_bedrock_pricing_only_model(key: str) -> bool:
|
438 |
+
"""
|
439 |
+
Excludes keys with the pattern 'bedrock/<region>/<model>'. These are in the model_prices_and_context_window.json file for pricing purposes only.
|
440 |
+
|
441 |
+
Args:
|
442 |
+
key (str): A key to filter.
|
443 |
+
|
444 |
+
Returns:
|
445 |
+
bool: True if the key matches the Bedrock pattern, False otherwise.
|
446 |
+
"""
|
447 |
+
# Regex to match 'bedrock/<region>/<model>'
|
448 |
+
bedrock_pattern = re.compile(r"^bedrock/[a-zA-Z0-9_-]+/.+$")
|
449 |
+
|
450 |
+
if "month-commitment" in key:
|
451 |
+
return True
|
452 |
+
|
453 |
+
is_match = bedrock_pattern.match(key)
|
454 |
+
return is_match is not None
|
455 |
+
|
456 |
+
|
457 |
+
def is_openai_finetune_model(key: str) -> bool:
|
458 |
+
"""
|
459 |
+
Excludes model cost keys with the pattern 'ft:<model>'. These are in the model_prices_and_context_window.json file for pricing purposes only.
|
460 |
+
|
461 |
+
Args:
|
462 |
+
key (str): A key to filter.
|
463 |
+
|
464 |
+
Returns:
|
465 |
+
bool: True if the key matches the OpenAI finetune pattern, False otherwise.
|
466 |
+
"""
|
467 |
+
return key.startswith("ft:") and not key.count(":") > 1
|
468 |
+
|
469 |
+
|
470 |
+
def add_known_models():
|
471 |
+
for key, value in model_cost.items():
|
472 |
+
if value.get("litellm_provider") == "openai" and not is_openai_finetune_model(
|
473 |
+
key
|
474 |
+
):
|
475 |
+
open_ai_chat_completion_models.append(key)
|
476 |
+
elif value.get("litellm_provider") == "text-completion-openai":
|
477 |
+
open_ai_text_completion_models.append(key)
|
478 |
+
elif value.get("litellm_provider") == "azure_text":
|
479 |
+
azure_text_models.append(key)
|
480 |
+
elif value.get("litellm_provider") == "cohere":
|
481 |
+
cohere_models.append(key)
|
482 |
+
elif value.get("litellm_provider") == "cohere_chat":
|
483 |
+
cohere_chat_models.append(key)
|
484 |
+
elif value.get("litellm_provider") == "mistral":
|
485 |
+
mistral_chat_models.append(key)
|
486 |
+
elif value.get("litellm_provider") == "anthropic":
|
487 |
+
anthropic_models.append(key)
|
488 |
+
elif value.get("litellm_provider") == "empower":
|
489 |
+
empower_models.append(key)
|
490 |
+
elif value.get("litellm_provider") == "openrouter":
|
491 |
+
openrouter_models.append(key)
|
492 |
+
elif value.get("litellm_provider") == "vertex_ai-text-models":
|
493 |
+
vertex_text_models.append(key)
|
494 |
+
elif value.get("litellm_provider") == "vertex_ai-code-text-models":
|
495 |
+
vertex_code_text_models.append(key)
|
496 |
+
elif value.get("litellm_provider") == "vertex_ai-language-models":
|
497 |
+
vertex_language_models.append(key)
|
498 |
+
elif value.get("litellm_provider") == "vertex_ai-vision-models":
|
499 |
+
vertex_vision_models.append(key)
|
500 |
+
elif value.get("litellm_provider") == "vertex_ai-chat-models":
|
501 |
+
vertex_chat_models.append(key)
|
502 |
+
elif value.get("litellm_provider") == "vertex_ai-code-chat-models":
|
503 |
+
vertex_code_chat_models.append(key)
|
504 |
+
elif value.get("litellm_provider") == "vertex_ai-embedding-models":
|
505 |
+
vertex_embedding_models.append(key)
|
506 |
+
elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
|
507 |
+
key = key.replace("vertex_ai/", "")
|
508 |
+
vertex_anthropic_models.append(key)
|
509 |
+
elif value.get("litellm_provider") == "vertex_ai-llama_models":
|
510 |
+
key = key.replace("vertex_ai/", "")
|
511 |
+
vertex_llama3_models.append(key)
|
512 |
+
elif value.get("litellm_provider") == "vertex_ai-mistral_models":
|
513 |
+
key = key.replace("vertex_ai/", "")
|
514 |
+
vertex_mistral_models.append(key)
|
515 |
+
elif value.get("litellm_provider") == "vertex_ai-ai21_models":
|
516 |
+
key = key.replace("vertex_ai/", "")
|
517 |
+
vertex_ai_ai21_models.append(key)
|
518 |
+
elif value.get("litellm_provider") == "vertex_ai-image-models":
|
519 |
+
key = key.replace("vertex_ai/", "")
|
520 |
+
vertex_ai_image_models.append(key)
|
521 |
+
elif value.get("litellm_provider") == "ai21":
|
522 |
+
if value.get("mode") == "chat":
|
523 |
+
ai21_chat_models.append(key)
|
524 |
+
else:
|
525 |
+
ai21_models.append(key)
|
526 |
+
elif value.get("litellm_provider") == "nlp_cloud":
|
527 |
+
nlp_cloud_models.append(key)
|
528 |
+
elif value.get("litellm_provider") == "aleph_alpha":
|
529 |
+
aleph_alpha_models.append(key)
|
530 |
+
elif value.get(
|
531 |
+
"litellm_provider"
|
532 |
+
) == "bedrock" and not is_bedrock_pricing_only_model(key):
|
533 |
+
bedrock_models.append(key)
|
534 |
+
elif value.get("litellm_provider") == "bedrock_converse":
|
535 |
+
bedrock_converse_models.append(key)
|
536 |
+
elif value.get("litellm_provider") == "deepinfra":
|
537 |
+
deepinfra_models.append(key)
|
538 |
+
elif value.get("litellm_provider") == "perplexity":
|
539 |
+
perplexity_models.append(key)
|
540 |
+
elif value.get("litellm_provider") == "watsonx":
|
541 |
+
watsonx_models.append(key)
|
542 |
+
elif value.get("litellm_provider") == "gemini":
|
543 |
+
gemini_models.append(key)
|
544 |
+
elif value.get("litellm_provider") == "fireworks_ai":
|
545 |
+
# ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
|
546 |
+
if "-to-" not in key and "fireworks-ai-default" not in key:
|
547 |
+
fireworks_ai_models.append(key)
|
548 |
+
elif value.get("litellm_provider") == "fireworks_ai-embedding-models":
|
549 |
+
# ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
|
550 |
+
if "-to-" not in key:
|
551 |
+
fireworks_ai_embedding_models.append(key)
|
552 |
+
elif value.get("litellm_provider") == "text-completion-codestral":
|
553 |
+
text_completion_codestral_models.append(key)
|
554 |
+
elif value.get("litellm_provider") == "xai":
|
555 |
+
xai_models.append(key)
|
556 |
+
elif value.get("litellm_provider") == "deepseek":
|
557 |
+
deepseek_models.append(key)
|
558 |
+
elif value.get("litellm_provider") == "azure_ai":
|
559 |
+
azure_ai_models.append(key)
|
560 |
+
elif value.get("litellm_provider") == "voyage":
|
561 |
+
voyage_models.append(key)
|
562 |
+
elif value.get("litellm_provider") == "infinity":
|
563 |
+
infinity_models.append(key)
|
564 |
+
elif value.get("litellm_provider") == "databricks":
|
565 |
+
databricks_models.append(key)
|
566 |
+
elif value.get("litellm_provider") == "cloudflare":
|
567 |
+
cloudflare_models.append(key)
|
568 |
+
elif value.get("litellm_provider") == "codestral":
|
569 |
+
codestral_models.append(key)
|
570 |
+
elif value.get("litellm_provider") == "friendliai":
|
571 |
+
friendliai_models.append(key)
|
572 |
+
elif value.get("litellm_provider") == "palm":
|
573 |
+
palm_models.append(key)
|
574 |
+
elif value.get("litellm_provider") == "groq":
|
575 |
+
groq_models.append(key)
|
576 |
+
elif value.get("litellm_provider") == "azure":
|
577 |
+
azure_models.append(key)
|
578 |
+
elif value.get("litellm_provider") == "anyscale":
|
579 |
+
anyscale_models.append(key)
|
580 |
+
elif value.get("litellm_provider") == "cerebras":
|
581 |
+
cerebras_models.append(key)
|
582 |
+
elif value.get("litellm_provider") == "galadriel":
|
583 |
+
galadriel_models.append(key)
|
584 |
+
elif value.get("litellm_provider") == "sambanova_models":
|
585 |
+
sambanova_models.append(key)
|
586 |
+
elif value.get("litellm_provider") == "assemblyai":
|
587 |
+
assemblyai_models.append(key)
|
588 |
+
elif value.get("litellm_provider") == "jina_ai":
|
589 |
+
jina_ai_models.append(key)
|
590 |
+
elif value.get("litellm_provider") == "snowflake":
|
591 |
+
snowflake_models.append(key)
|
592 |
+
|
593 |
+
|
594 |
+
add_known_models()
|
595 |
+
# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
|
596 |
+
|
597 |
+
# this is maintained for Exception Mapping
|
598 |
+
|
599 |
+
|
600 |
+
# used for Cost Tracking & Token counting
|
601 |
+
# https://azure.microsoft.com/en-in/pricing/details/cognitive-services/openai-service/
|
602 |
+
# Azure returns gpt-35-turbo in their responses, we need to map this to azure/gpt-3.5-turbo for token counting
|
603 |
+
azure_llms = {
|
604 |
+
"gpt-35-turbo": "azure/gpt-35-turbo",
|
605 |
+
"gpt-35-turbo-16k": "azure/gpt-35-turbo-16k",
|
606 |
+
"gpt-35-turbo-instruct": "azure/gpt-35-turbo-instruct",
|
607 |
+
}
|
608 |
+
|
609 |
+
azure_embedding_models = {
|
610 |
+
"ada": "azure/ada",
|
611 |
+
}
|
612 |
+
|
613 |
+
petals_models = [
|
614 |
+
"petals-team/StableBeluga2",
|
615 |
+
]
|
616 |
+
|
617 |
+
ollama_models = ["llama2"]
|
618 |
+
|
619 |
+
maritalk_models = ["maritalk"]
|
620 |
+
|
621 |
+
|
622 |
+
model_list = (
|
623 |
+
open_ai_chat_completion_models
|
624 |
+
+ open_ai_text_completion_models
|
625 |
+
+ cohere_models
|
626 |
+
+ cohere_chat_models
|
627 |
+
+ anthropic_models
|
628 |
+
+ replicate_models
|
629 |
+
+ openrouter_models
|
630 |
+
+ huggingface_models
|
631 |
+
+ vertex_chat_models
|
632 |
+
+ vertex_text_models
|
633 |
+
+ ai21_models
|
634 |
+
+ ai21_chat_models
|
635 |
+
+ together_ai_models
|
636 |
+
+ baseten_models
|
637 |
+
+ aleph_alpha_models
|
638 |
+
+ nlp_cloud_models
|
639 |
+
+ ollama_models
|
640 |
+
+ bedrock_models
|
641 |
+
+ deepinfra_models
|
642 |
+
+ perplexity_models
|
643 |
+
+ maritalk_models
|
644 |
+
+ vertex_language_models
|
645 |
+
+ watsonx_models
|
646 |
+
+ gemini_models
|
647 |
+
+ text_completion_codestral_models
|
648 |
+
+ xai_models
|
649 |
+
+ deepseek_models
|
650 |
+
+ azure_ai_models
|
651 |
+
+ voyage_models
|
652 |
+
+ infinity_models
|
653 |
+
+ databricks_models
|
654 |
+
+ cloudflare_models
|
655 |
+
+ codestral_models
|
656 |
+
+ friendliai_models
|
657 |
+
+ palm_models
|
658 |
+
+ groq_models
|
659 |
+
+ azure_models
|
660 |
+
+ anyscale_models
|
661 |
+
+ cerebras_models
|
662 |
+
+ galadriel_models
|
663 |
+
+ sambanova_models
|
664 |
+
+ azure_text_models
|
665 |
+
+ assemblyai_models
|
666 |
+
+ jina_ai_models
|
667 |
+
+ snowflake_models
|
668 |
+
)
|
669 |
+
|
670 |
+
model_list_set = set(model_list)
|
671 |
+
|
672 |
+
provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)
|
673 |
+
|
674 |
+
|
675 |
+
models_by_provider: dict = {
|
676 |
+
"openai": open_ai_chat_completion_models + open_ai_text_completion_models,
|
677 |
+
"text-completion-openai": open_ai_text_completion_models,
|
678 |
+
"cohere": cohere_models + cohere_chat_models,
|
679 |
+
"cohere_chat": cohere_chat_models,
|
680 |
+
"anthropic": anthropic_models,
|
681 |
+
"replicate": replicate_models,
|
682 |
+
"huggingface": huggingface_models,
|
683 |
+
"together_ai": together_ai_models,
|
684 |
+
"baseten": baseten_models,
|
685 |
+
"openrouter": openrouter_models,
|
686 |
+
"vertex_ai": vertex_chat_models
|
687 |
+
+ vertex_text_models
|
688 |
+
+ vertex_anthropic_models
|
689 |
+
+ vertex_vision_models
|
690 |
+
+ vertex_language_models,
|
691 |
+
"ai21": ai21_models,
|
692 |
+
"bedrock": bedrock_models + bedrock_converse_models,
|
693 |
+
"petals": petals_models,
|
694 |
+
"ollama": ollama_models,
|
695 |
+
"deepinfra": deepinfra_models,
|
696 |
+
"perplexity": perplexity_models,
|
697 |
+
"maritalk": maritalk_models,
|
698 |
+
"watsonx": watsonx_models,
|
699 |
+
"gemini": gemini_models,
|
700 |
+
"fireworks_ai": fireworks_ai_models + fireworks_ai_embedding_models,
|
701 |
+
"aleph_alpha": aleph_alpha_models,
|
702 |
+
"text-completion-codestral": text_completion_codestral_models,
|
703 |
+
"xai": xai_models,
|
704 |
+
"deepseek": deepseek_models,
|
705 |
+
"mistral": mistral_chat_models,
|
706 |
+
"azure_ai": azure_ai_models,
|
707 |
+
"voyage": voyage_models,
|
708 |
+
"infinity": infinity_models,
|
709 |
+
"databricks": databricks_models,
|
710 |
+
"cloudflare": cloudflare_models,
|
711 |
+
"codestral": codestral_models,
|
712 |
+
"nlp_cloud": nlp_cloud_models,
|
713 |
+
"friendliai": friendliai_models,
|
714 |
+
"palm": palm_models,
|
715 |
+
"groq": groq_models,
|
716 |
+
"azure": azure_models + azure_text_models,
|
717 |
+
"azure_text": azure_text_models,
|
718 |
+
"anyscale": anyscale_models,
|
719 |
+
"cerebras": cerebras_models,
|
720 |
+
"galadriel": galadriel_models,
|
721 |
+
"sambanova": sambanova_models,
|
722 |
+
"assemblyai": assemblyai_models,
|
723 |
+
"jina_ai": jina_ai_models,
|
724 |
+
"snowflake": snowflake_models,
|
725 |
+
}
|
726 |
+
|
727 |
+
# mapping for those models which have larger equivalents
|
728 |
+
longer_context_model_fallback_dict: dict = {
|
729 |
+
# openai chat completion models
|
730 |
+
"gpt-3.5-turbo": "gpt-3.5-turbo-16k",
|
731 |
+
"gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301",
|
732 |
+
"gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613",
|
733 |
+
"gpt-4": "gpt-4-32k",
|
734 |
+
"gpt-4-0314": "gpt-4-32k-0314",
|
735 |
+
"gpt-4-0613": "gpt-4-32k-0613",
|
736 |
+
# anthropic
|
737 |
+
"claude-instant-1": "claude-2",
|
738 |
+
"claude-instant-1.2": "claude-2",
|
739 |
+
# vertexai
|
740 |
+
"chat-bison": "chat-bison-32k",
|
741 |
+
"chat-bison@001": "chat-bison-32k",
|
742 |
+
"codechat-bison": "codechat-bison-32k",
|
743 |
+
"codechat-bison@001": "codechat-bison-32k",
|
744 |
+
# openrouter
|
745 |
+
"openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k",
|
746 |
+
"openrouter/anthropic/claude-instant-v1": "openrouter/anthropic/claude-2",
|
747 |
+
}
|
748 |
+
|
749 |
+
####### EMBEDDING MODELS ###################
|
750 |
+
|
751 |
+
all_embedding_models = (
|
752 |
+
open_ai_embedding_models
|
753 |
+
+ cohere_embedding_models
|
754 |
+
+ bedrock_embedding_models
|
755 |
+
+ vertex_embedding_models
|
756 |
+
+ fireworks_ai_embedding_models
|
757 |
+
)
|
758 |
+
|
759 |
+
####### IMAGE GENERATION MODELS ###################
|
760 |
+
openai_image_generation_models = ["dall-e-2", "dall-e-3"]
|
761 |
+
|
762 |
+
from .timeout import timeout
|
763 |
+
from .cost_calculator import completion_cost
|
764 |
+
from litellm.litellm_core_utils.litellm_logging import Logging, modify_integration
|
765 |
+
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
766 |
+
from litellm.litellm_core_utils.core_helpers import remove_index_from_tool_calls
|
767 |
+
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
|
768 |
+
from .utils import (
|
769 |
+
client,
|
770 |
+
exception_type,
|
771 |
+
get_optional_params,
|
772 |
+
get_response_string,
|
773 |
+
token_counter,
|
774 |
+
create_pretrained_tokenizer,
|
775 |
+
create_tokenizer,
|
776 |
+
supports_function_calling,
|
777 |
+
supports_web_search,
|
778 |
+
supports_response_schema,
|
779 |
+
supports_parallel_function_calling,
|
780 |
+
supports_vision,
|
781 |
+
supports_audio_input,
|
782 |
+
supports_audio_output,
|
783 |
+
supports_system_messages,
|
784 |
+
supports_reasoning,
|
785 |
+
get_litellm_params,
|
786 |
+
acreate,
|
787 |
+
get_max_tokens,
|
788 |
+
get_model_info,
|
789 |
+
register_prompt_template,
|
790 |
+
validate_environment,
|
791 |
+
check_valid_key,
|
792 |
+
register_model,
|
793 |
+
encode,
|
794 |
+
decode,
|
795 |
+
_calculate_retry_after,
|
796 |
+
_should_retry,
|
797 |
+
get_supported_openai_params,
|
798 |
+
get_api_base,
|
799 |
+
get_first_chars_messages,
|
800 |
+
ModelResponse,
|
801 |
+
ModelResponseStream,
|
802 |
+
EmbeddingResponse,
|
803 |
+
ImageResponse,
|
804 |
+
TranscriptionResponse,
|
805 |
+
TextCompletionResponse,
|
806 |
+
get_provider_fields,
|
807 |
+
ModelResponseListIterator,
|
808 |
+
)
|
809 |
+
|
810 |
+
ALL_LITELLM_RESPONSE_TYPES = [
|
811 |
+
ModelResponse,
|
812 |
+
EmbeddingResponse,
|
813 |
+
ImageResponse,
|
814 |
+
TranscriptionResponse,
|
815 |
+
TextCompletionResponse,
|
816 |
+
]
|
817 |
+
|
818 |
+
from .llms.custom_llm import CustomLLM
|
819 |
+
from .llms.bedrock.chat.converse_transformation import AmazonConverseConfig
|
820 |
+
from .llms.openai_like.chat.handler import OpenAILikeChatConfig
|
821 |
+
from .llms.aiohttp_openai.chat.transformation import AiohttpOpenAIChatConfig
|
822 |
+
from .llms.galadriel.chat.transformation import GaladrielChatConfig
|
823 |
+
from .llms.github.chat.transformation import GithubChatConfig
|
824 |
+
from .llms.empower.chat.transformation import EmpowerChatConfig
|
825 |
+
from .llms.huggingface.chat.transformation import HuggingFaceChatConfig
|
826 |
+
from .llms.huggingface.embedding.transformation import HuggingFaceEmbeddingConfig
|
827 |
+
from .llms.oobabooga.chat.transformation import OobaboogaConfig
|
828 |
+
from .llms.maritalk import MaritalkConfig
|
829 |
+
from .llms.openrouter.chat.transformation import OpenrouterConfig
|
830 |
+
from .llms.anthropic.chat.transformation import AnthropicConfig
|
831 |
+
from .llms.anthropic.common_utils import AnthropicModelInfo
|
832 |
+
from .llms.groq.stt.transformation import GroqSTTConfig
|
833 |
+
from .llms.anthropic.completion.transformation import AnthropicTextConfig
|
834 |
+
from .llms.triton.completion.transformation import TritonConfig
|
835 |
+
from .llms.triton.completion.transformation import TritonGenerateConfig
|
836 |
+
from .llms.triton.completion.transformation import TritonInferConfig
|
837 |
+
from .llms.triton.embedding.transformation import TritonEmbeddingConfig
|
838 |
+
from .llms.databricks.chat.transformation import DatabricksConfig
|
839 |
+
from .llms.databricks.embed.transformation import DatabricksEmbeddingConfig
|
840 |
+
from .llms.predibase.chat.transformation import PredibaseConfig
|
841 |
+
from .llms.replicate.chat.transformation import ReplicateConfig
|
842 |
+
from .llms.cohere.completion.transformation import CohereTextConfig as CohereConfig
|
843 |
+
from .llms.snowflake.chat.transformation import SnowflakeConfig
|
844 |
+
from .llms.cohere.rerank.transformation import CohereRerankConfig
|
845 |
+
from .llms.cohere.rerank_v2.transformation import CohereRerankV2Config
|
846 |
+
from .llms.azure_ai.rerank.transformation import AzureAIRerankConfig
|
847 |
+
from .llms.infinity.rerank.transformation import InfinityRerankConfig
|
848 |
+
from .llms.jina_ai.rerank.transformation import JinaAIRerankConfig
|
849 |
+
from .llms.clarifai.chat.transformation import ClarifaiConfig
|
850 |
+
from .llms.ai21.chat.transformation import AI21ChatConfig, AI21ChatConfig as AI21Config
|
851 |
+
from .llms.anthropic.experimental_pass_through.messages.transformation import (
|
852 |
+
AnthropicMessagesConfig,
|
853 |
+
)
|
854 |
+
from .llms.together_ai.chat import TogetherAIConfig
|
855 |
+
from .llms.together_ai.completion.transformation import TogetherAITextCompletionConfig
|
856 |
+
from .llms.cloudflare.chat.transformation import CloudflareChatConfig
|
857 |
+
from .llms.deprecated_providers.palm import (
|
858 |
+
PalmConfig,
|
859 |
+
) # here to prevent breaking changes
|
860 |
+
from .llms.nlp_cloud.chat.handler import NLPCloudConfig
|
861 |
+
from .llms.petals.completion.transformation import PetalsConfig
|
862 |
+
from .llms.deprecated_providers.aleph_alpha import AlephAlphaConfig
|
863 |
+
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
864 |
+
VertexGeminiConfig,
|
865 |
+
VertexGeminiConfig as VertexAIConfig,
|
866 |
+
)
|
867 |
+
from .llms.gemini.common_utils import GeminiModelInfo
|
868 |
+
from .llms.gemini.chat.transformation import (
|
869 |
+
GoogleAIStudioGeminiConfig,
|
870 |
+
GoogleAIStudioGeminiConfig as GeminiConfig, # aliased to maintain backwards compatibility
|
871 |
+
)
|
872 |
+
|
873 |
+
|
874 |
+
from .llms.vertex_ai.vertex_embeddings.transformation import (
|
875 |
+
VertexAITextEmbeddingConfig,
|
876 |
+
)
|
877 |
+
|
878 |
+
vertexAITextEmbeddingConfig = VertexAITextEmbeddingConfig()
|
879 |
+
|
880 |
+
from .llms.vertex_ai.vertex_ai_partner_models.anthropic.transformation import (
|
881 |
+
VertexAIAnthropicConfig,
|
882 |
+
)
|
883 |
+
from .llms.vertex_ai.vertex_ai_partner_models.llama3.transformation import (
|
884 |
+
VertexAILlama3Config,
|
885 |
+
)
|
886 |
+
from .llms.vertex_ai.vertex_ai_partner_models.ai21.transformation import (
|
887 |
+
VertexAIAi21Config,
|
888 |
+
)
|
889 |
+
|
890 |
+
from .llms.ollama.completion.transformation import OllamaConfig
|
891 |
+
from .llms.sagemaker.completion.transformation import SagemakerConfig
|
892 |
+
from .llms.sagemaker.chat.transformation import SagemakerChatConfig
|
893 |
+
from .llms.ollama_chat import OllamaChatConfig
|
894 |
+
from .llms.bedrock.chat.invoke_handler import (
|
895 |
+
AmazonCohereChatConfig,
|
896 |
+
bedrock_tool_name_mappings,
|
897 |
+
)
|
898 |
+
|
899 |
+
from .llms.bedrock.common_utils import (
|
900 |
+
AmazonBedrockGlobalConfig,
|
901 |
+
)
|
902 |
+
from .llms.bedrock.chat.invoke_transformations.amazon_ai21_transformation import (
|
903 |
+
AmazonAI21Config,
|
904 |
+
)
|
905 |
+
from .llms.bedrock.chat.invoke_transformations.amazon_nova_transformation import (
|
906 |
+
AmazonInvokeNovaConfig,
|
907 |
+
)
|
908 |
+
from .llms.bedrock.chat.invoke_transformations.anthropic_claude2_transformation import (
|
909 |
+
AmazonAnthropicConfig,
|
910 |
+
)
|
911 |
+
from .llms.bedrock.chat.invoke_transformations.anthropic_claude3_transformation import (
|
912 |
+
AmazonAnthropicClaude3Config,
|
913 |
+
)
|
914 |
+
from .llms.bedrock.chat.invoke_transformations.amazon_cohere_transformation import (
|
915 |
+
AmazonCohereConfig,
|
916 |
+
)
|
917 |
+
from .llms.bedrock.chat.invoke_transformations.amazon_llama_transformation import (
|
918 |
+
AmazonLlamaConfig,
|
919 |
+
)
|
920 |
+
from .llms.bedrock.chat.invoke_transformations.amazon_deepseek_transformation import (
|
921 |
+
AmazonDeepSeekR1Config,
|
922 |
+
)
|
923 |
+
from .llms.bedrock.chat.invoke_transformations.amazon_mistral_transformation import (
|
924 |
+
AmazonMistralConfig,
|
925 |
+
)
|
926 |
+
from .llms.bedrock.chat.invoke_transformations.amazon_titan_transformation import (
|
927 |
+
AmazonTitanConfig,
|
928 |
+
)
|
929 |
+
from .llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
930 |
+
AmazonInvokeConfig,
|
931 |
+
)
|
932 |
+
|
933 |
+
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
|
934 |
+
from .llms.bedrock.image.amazon_stability3_transformation import AmazonStability3Config
|
935 |
+
from .llms.bedrock.image.amazon_nova_canvas_transformation import AmazonNovaCanvasConfig
|
936 |
+
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
|
937 |
+
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
|
938 |
+
AmazonTitanMultimodalEmbeddingG1Config,
|
939 |
+
)
|
940 |
+
from .llms.bedrock.embed.amazon_titan_v2_transformation import (
|
941 |
+
AmazonTitanV2Config,
|
942 |
+
)
|
943 |
+
from .llms.cohere.chat.transformation import CohereChatConfig
|
944 |
+
from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
|
945 |
+
from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
|
946 |
+
from .llms.openai.image_variations.transformation import OpenAIImageVariationConfig
|
947 |
+
from .llms.deepinfra.chat.transformation import DeepInfraConfig
|
948 |
+
from .llms.deepgram.audio_transcription.transformation import (
|
949 |
+
DeepgramAudioTranscriptionConfig,
|
950 |
+
)
|
951 |
+
from .llms.topaz.common_utils import TopazModelInfo
|
952 |
+
from .llms.topaz.image_variations.transformation import TopazImageVariationConfig
|
953 |
+
from litellm.llms.openai.completion.transformation import OpenAITextCompletionConfig
|
954 |
+
from .llms.groq.chat.transformation import GroqChatConfig
|
955 |
+
from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
|
956 |
+
from .llms.infinity.embedding.transformation import InfinityEmbeddingConfig
|
957 |
+
from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
|
958 |
+
from .llms.mistral.mistral_chat_transformation import MistralConfig
|
959 |
+
from .llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
960 |
+
from .llms.azure.responses.transformation import AzureOpenAIResponsesAPIConfig
|
961 |
+
from .llms.openai.chat.o_series_transformation import (
|
962 |
+
OpenAIOSeriesConfig as OpenAIO1Config, # maintain backwards compatibility
|
963 |
+
OpenAIOSeriesConfig,
|
964 |
+
)
|
965 |
+
|
966 |
+
from .llms.snowflake.chat.transformation import SnowflakeConfig
|
967 |
+
|
968 |
+
openaiOSeriesConfig = OpenAIOSeriesConfig()
|
969 |
+
from .llms.openai.chat.gpt_transformation import (
|
970 |
+
OpenAIGPTConfig,
|
971 |
+
)
|
972 |
+
from .llms.openai.transcriptions.whisper_transformation import (
|
973 |
+
OpenAIWhisperAudioTranscriptionConfig,
|
974 |
+
)
|
975 |
+
from .llms.openai.transcriptions.gpt_transformation import (
|
976 |
+
OpenAIGPTAudioTranscriptionConfig,
|
977 |
+
)
|
978 |
+
|
979 |
+
openAIGPTConfig = OpenAIGPTConfig()
|
980 |
+
from .llms.openai.chat.gpt_audio_transformation import (
|
981 |
+
OpenAIGPTAudioConfig,
|
982 |
+
)
|
983 |
+
|
984 |
+
openAIGPTAudioConfig = OpenAIGPTAudioConfig()
|
985 |
+
|
986 |
+
from .llms.nvidia_nim.chat import NvidiaNimConfig
|
987 |
+
from .llms.nvidia_nim.embed import NvidiaNimEmbeddingConfig
|
988 |
+
|
989 |
+
nvidiaNimConfig = NvidiaNimConfig()
|
990 |
+
nvidiaNimEmbeddingConfig = NvidiaNimEmbeddingConfig()
|
991 |
+
|
992 |
+
from .llms.cerebras.chat import CerebrasConfig
|
993 |
+
from .llms.sambanova.chat import SambanovaConfig
|
994 |
+
from .llms.ai21.chat.transformation import AI21ChatConfig
|
995 |
+
from .llms.fireworks_ai.chat.transformation import FireworksAIConfig
|
996 |
+
from .llms.fireworks_ai.completion.transformation import FireworksAITextCompletionConfig
|
997 |
+
from .llms.fireworks_ai.audio_transcription.transformation import (
|
998 |
+
FireworksAIAudioTranscriptionConfig,
|
999 |
+
)
|
1000 |
+
from .llms.fireworks_ai.embed.fireworks_ai_transformation import (
|
1001 |
+
FireworksAIEmbeddingConfig,
|
1002 |
+
)
|
1003 |
+
from .llms.friendliai.chat.transformation import FriendliaiChatConfig
|
1004 |
+
from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
|
1005 |
+
from .llms.xai.chat.transformation import XAIChatConfig
|
1006 |
+
from .llms.xai.common_utils import XAIModelInfo
|
1007 |
+
from .llms.volcengine import VolcEngineConfig
|
1008 |
+
from .llms.codestral.completion.transformation import CodestralTextCompletionConfig
|
1009 |
+
from .llms.azure.azure import (
|
1010 |
+
AzureOpenAIError,
|
1011 |
+
AzureOpenAIAssistantsAPIConfig,
|
1012 |
+
)
|
1013 |
+
|
1014 |
+
from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
|
1015 |
+
from .llms.azure.completion.transformation import AzureOpenAITextConfig
|
1016 |
+
from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
|
1017 |
+
from .llms.llamafile.chat.transformation import LlamafileChatConfig
|
1018 |
+
from .llms.litellm_proxy.chat.transformation import LiteLLMProxyChatConfig
|
1019 |
+
from .llms.vllm.completion.transformation import VLLMConfig
|
1020 |
+
from .llms.deepseek.chat.transformation import DeepSeekChatConfig
|
1021 |
+
from .llms.lm_studio.chat.transformation import LMStudioChatConfig
|
1022 |
+
from .llms.lm_studio.embed.transformation import LmStudioEmbeddingConfig
|
1023 |
+
from .llms.perplexity.chat.transformation import PerplexityChatConfig
|
1024 |
+
from .llms.azure.chat.o_series_transformation import AzureOpenAIO1Config
|
1025 |
+
from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
|
1026 |
+
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
|
1027 |
+
from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
|
1028 |
+
from .main import * # type: ignore
|
1029 |
+
from .integrations import *
|
1030 |
+
from .exceptions import (
|
1031 |
+
AuthenticationError,
|
1032 |
+
InvalidRequestError,
|
1033 |
+
BadRequestError,
|
1034 |
+
NotFoundError,
|
1035 |
+
RateLimitError,
|
1036 |
+
ServiceUnavailableError,
|
1037 |
+
OpenAIError,
|
1038 |
+
ContextWindowExceededError,
|
1039 |
+
ContentPolicyViolationError,
|
1040 |
+
BudgetExceededError,
|
1041 |
+
APIError,
|
1042 |
+
Timeout,
|
1043 |
+
APIConnectionError,
|
1044 |
+
UnsupportedParamsError,
|
1045 |
+
APIResponseValidationError,
|
1046 |
+
UnprocessableEntityError,
|
1047 |
+
InternalServerError,
|
1048 |
+
JSONSchemaValidationError,
|
1049 |
+
LITELLM_EXCEPTION_TYPES,
|
1050 |
+
MockException,
|
1051 |
+
)
|
1052 |
+
from .budget_manager import BudgetManager
|
1053 |
+
from .proxy.proxy_cli import run_server
|
1054 |
+
from .router import Router
|
1055 |
+
from .assistants.main import *
|
1056 |
+
from .batches.main import *
|
1057 |
+
from .batch_completion.main import * # type: ignore
|
1058 |
+
from .rerank_api.main import *
|
1059 |
+
from .llms.anthropic.experimental_pass_through.messages.handler import *
|
1060 |
+
from .responses.main import *
|
1061 |
+
from .realtime_api.main import _arealtime
|
1062 |
+
from .fine_tuning.main import *
|
1063 |
+
from .files.main import *
|
1064 |
+
from .scheduler import *
|
1065 |
+
from .cost_calculator import response_cost_calculator, cost_per_token
|
1066 |
+
|
1067 |
+
### ADAPTERS ###
|
1068 |
+
from .types.adapter import AdapterItem
|
1069 |
+
import litellm.anthropic_interface as anthropic
|
1070 |
+
|
1071 |
+
adapters: List[AdapterItem] = []
|
1072 |
+
|
1073 |
+
### CUSTOM LLMs ###
|
1074 |
+
from .types.llms.custom_llm import CustomLLMItem
|
1075 |
+
from .types.utils import GenericStreamingChunk
|
1076 |
+
|
1077 |
+
custom_provider_map: List[CustomLLMItem] = []
|
1078 |
+
_custom_providers: List[str] = (
|
1079 |
+
[]
|
1080 |
+
) # internal helper util, used to track names of custom providers
|
1081 |
+
disable_hf_tokenizer_download: Optional[bool] = (
|
1082 |
+
None # disable huggingface tokenizer download. Defaults to openai clk100
|
1083 |
+
)
|
1084 |
+
global_disable_no_log_param: bool = False
|
litellm/_logging.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from datetime import datetime
|
6 |
+
from logging import Formatter
|
7 |
+
|
8 |
+
set_verbose = False
|
9 |
+
|
10 |
+
if set_verbose is True:
|
11 |
+
logging.warning(
|
12 |
+
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
|
13 |
+
)
|
14 |
+
json_logs = bool(os.getenv("JSON_LOGS", False))
|
15 |
+
# Create a handler for the logger (you may need to adapt this based on your needs)
|
16 |
+
log_level = os.getenv("LITELLM_LOG", "DEBUG")
|
17 |
+
numeric_level: str = getattr(logging, log_level.upper())
|
18 |
+
handler = logging.StreamHandler()
|
19 |
+
handler.setLevel(numeric_level)
|
20 |
+
|
21 |
+
|
22 |
+
class JsonFormatter(Formatter):
|
23 |
+
def __init__(self):
|
24 |
+
super(JsonFormatter, self).__init__()
|
25 |
+
|
26 |
+
def formatTime(self, record, datefmt=None):
|
27 |
+
# Use datetime to format the timestamp in ISO 8601 format
|
28 |
+
dt = datetime.fromtimestamp(record.created)
|
29 |
+
return dt.isoformat()
|
30 |
+
|
31 |
+
def format(self, record):
|
32 |
+
json_record = {
|
33 |
+
"message": record.getMessage(),
|
34 |
+
"level": record.levelname,
|
35 |
+
"timestamp": self.formatTime(record),
|
36 |
+
}
|
37 |
+
|
38 |
+
if record.exc_info:
|
39 |
+
json_record["stacktrace"] = self.formatException(record.exc_info)
|
40 |
+
|
41 |
+
return json.dumps(json_record)
|
42 |
+
|
43 |
+
|
44 |
+
# Function to set up exception handlers for JSON logging
|
45 |
+
def _setup_json_exception_handlers(formatter):
|
46 |
+
# Create a handler with JSON formatting for exceptions
|
47 |
+
error_handler = logging.StreamHandler()
|
48 |
+
error_handler.setFormatter(formatter)
|
49 |
+
|
50 |
+
# Setup excepthook for uncaught exceptions
|
51 |
+
def json_excepthook(exc_type, exc_value, exc_traceback):
|
52 |
+
record = logging.LogRecord(
|
53 |
+
name="LiteLLM",
|
54 |
+
level=logging.ERROR,
|
55 |
+
pathname="",
|
56 |
+
lineno=0,
|
57 |
+
msg=str(exc_value),
|
58 |
+
args=(),
|
59 |
+
exc_info=(exc_type, exc_value, exc_traceback),
|
60 |
+
)
|
61 |
+
error_handler.handle(record)
|
62 |
+
|
63 |
+
sys.excepthook = json_excepthook
|
64 |
+
|
65 |
+
# Configure asyncio exception handler if possible
|
66 |
+
try:
|
67 |
+
import asyncio
|
68 |
+
|
69 |
+
def async_json_exception_handler(loop, context):
|
70 |
+
exception = context.get("exception")
|
71 |
+
if exception:
|
72 |
+
record = logging.LogRecord(
|
73 |
+
name="LiteLLM",
|
74 |
+
level=logging.ERROR,
|
75 |
+
pathname="",
|
76 |
+
lineno=0,
|
77 |
+
msg=str(exception),
|
78 |
+
args=(),
|
79 |
+
exc_info=None,
|
80 |
+
)
|
81 |
+
error_handler.handle(record)
|
82 |
+
else:
|
83 |
+
loop.default_exception_handler(context)
|
84 |
+
|
85 |
+
asyncio.get_event_loop().set_exception_handler(async_json_exception_handler)
|
86 |
+
except Exception:
|
87 |
+
pass
|
88 |
+
|
89 |
+
|
90 |
+
# Create a formatter and set it for the handler
|
91 |
+
if json_logs:
|
92 |
+
handler.setFormatter(JsonFormatter())
|
93 |
+
_setup_json_exception_handlers(JsonFormatter())
|
94 |
+
else:
|
95 |
+
formatter = logging.Formatter(
|
96 |
+
"\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
|
97 |
+
datefmt="%H:%M:%S",
|
98 |
+
)
|
99 |
+
|
100 |
+
handler.setFormatter(formatter)
|
101 |
+
|
102 |
+
verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
|
103 |
+
verbose_router_logger = logging.getLogger("LiteLLM Router")
|
104 |
+
verbose_logger = logging.getLogger("LiteLLM")
|
105 |
+
|
106 |
+
# Add the handler to the logger
|
107 |
+
verbose_router_logger.addHandler(handler)
|
108 |
+
verbose_proxy_logger.addHandler(handler)
|
109 |
+
verbose_logger.addHandler(handler)
|
110 |
+
|
111 |
+
|
112 |
+
def _turn_on_json():
|
113 |
+
handler = logging.StreamHandler()
|
114 |
+
handler.setFormatter(JsonFormatter())
|
115 |
+
|
116 |
+
# Define all loggers to update, including root logger
|
117 |
+
loggers = [logging.getLogger()] + [
|
118 |
+
verbose_router_logger,
|
119 |
+
verbose_proxy_logger,
|
120 |
+
verbose_logger,
|
121 |
+
]
|
122 |
+
|
123 |
+
# Iterate through each logger and update its handlers
|
124 |
+
for logger in loggers:
|
125 |
+
# Remove all existing handlers
|
126 |
+
for h in logger.handlers[:]:
|
127 |
+
logger.removeHandler(h)
|
128 |
+
# Add the new handler
|
129 |
+
logger.addHandler(handler)
|
130 |
+
|
131 |
+
# Set up exception handlers
|
132 |
+
_setup_json_exception_handlers(JsonFormatter())
|
133 |
+
|
134 |
+
|
135 |
+
def _turn_on_debug():
|
136 |
+
verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
|
137 |
+
verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
|
138 |
+
verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
|
139 |
+
|
140 |
+
|
141 |
+
def _disable_debugging():
|
142 |
+
verbose_logger.disabled = True
|
143 |
+
verbose_router_logger.disabled = True
|
144 |
+
verbose_proxy_logger.disabled = True
|
145 |
+
|
146 |
+
|
147 |
+
def _enable_debugging():
|
148 |
+
verbose_logger.disabled = False
|
149 |
+
verbose_router_logger.disabled = False
|
150 |
+
verbose_proxy_logger.disabled = False
|
151 |
+
|
152 |
+
|
153 |
+
def print_verbose(print_statement):
|
154 |
+
try:
|
155 |
+
if set_verbose:
|
156 |
+
print(print_statement) # noqa
|
157 |
+
except Exception:
|
158 |
+
pass
|
159 |
+
|
160 |
+
|
161 |
+
def _is_debugging_on() -> bool:
|
162 |
+
"""
|
163 |
+
Returns True if debugging is on
|
164 |
+
"""
|
165 |
+
if verbose_logger.isEnabledFor(logging.DEBUG) or set_verbose is True:
|
166 |
+
return True
|
167 |
+
return False
|
litellm/_redis.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# +-----------------------------------------------+
|
2 |
+
# | |
|
3 |
+
# | Give Feedback / Get Help |
|
4 |
+
# | https://github.com/BerriAI/litellm/issues/new |
|
5 |
+
# | |
|
6 |
+
# +-----------------------------------------------+
|
7 |
+
#
|
8 |
+
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
9 |
+
|
10 |
+
import inspect
|
11 |
+
import json
|
12 |
+
|
13 |
+
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
14 |
+
import os
|
15 |
+
from typing import List, Optional, Union
|
16 |
+
|
17 |
+
import redis # type: ignore
|
18 |
+
import redis.asyncio as async_redis # type: ignore
|
19 |
+
|
20 |
+
from litellm import get_secret, get_secret_str
|
21 |
+
from litellm.constants import REDIS_CONNECTION_POOL_TIMEOUT, REDIS_SOCKET_TIMEOUT
|
22 |
+
|
23 |
+
from ._logging import verbose_logger
|
24 |
+
|
25 |
+
|
26 |
+
def _get_redis_kwargs():
|
27 |
+
arg_spec = inspect.getfullargspec(redis.Redis)
|
28 |
+
|
29 |
+
# Only allow primitive arguments
|
30 |
+
exclude_args = {
|
31 |
+
"self",
|
32 |
+
"connection_pool",
|
33 |
+
"retry",
|
34 |
+
}
|
35 |
+
|
36 |
+
include_args = ["url"]
|
37 |
+
|
38 |
+
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
|
39 |
+
|
40 |
+
return available_args
|
41 |
+
|
42 |
+
|
43 |
+
def _get_redis_url_kwargs(client=None):
|
44 |
+
if client is None:
|
45 |
+
client = redis.Redis.from_url
|
46 |
+
arg_spec = inspect.getfullargspec(redis.Redis.from_url)
|
47 |
+
|
48 |
+
# Only allow primitive arguments
|
49 |
+
exclude_args = {
|
50 |
+
"self",
|
51 |
+
"connection_pool",
|
52 |
+
"retry",
|
53 |
+
}
|
54 |
+
|
55 |
+
include_args = ["url"]
|
56 |
+
|
57 |
+
available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args
|
58 |
+
|
59 |
+
return available_args
|
60 |
+
|
61 |
+
|
62 |
+
def _get_redis_cluster_kwargs(client=None):
|
63 |
+
if client is None:
|
64 |
+
client = redis.Redis.from_url
|
65 |
+
arg_spec = inspect.getfullargspec(redis.RedisCluster)
|
66 |
+
|
67 |
+
# Only allow primitive arguments
|
68 |
+
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
|
69 |
+
|
70 |
+
available_args = [x for x in arg_spec.args if x not in exclude_args]
|
71 |
+
available_args.append("password")
|
72 |
+
available_args.append("username")
|
73 |
+
available_args.append("ssl")
|
74 |
+
|
75 |
+
return available_args
|
76 |
+
|
77 |
+
|
78 |
+
def _get_redis_env_kwarg_mapping():
|
79 |
+
PREFIX = "REDIS_"
|
80 |
+
|
81 |
+
return {f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs()}
|
82 |
+
|
83 |
+
|
84 |
+
def _redis_kwargs_from_environment():
|
85 |
+
mapping = _get_redis_env_kwarg_mapping()
|
86 |
+
|
87 |
+
return_dict = {}
|
88 |
+
for k, v in mapping.items():
|
89 |
+
value = get_secret(k, default_value=None) # type: ignore
|
90 |
+
if value is not None:
|
91 |
+
return_dict[v] = value
|
92 |
+
return return_dict
|
93 |
+
|
94 |
+
|
95 |
+
def get_redis_url_from_environment():
|
96 |
+
if "REDIS_URL" in os.environ:
|
97 |
+
return os.environ["REDIS_URL"]
|
98 |
+
|
99 |
+
if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
|
100 |
+
raise ValueError(
|
101 |
+
"Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
|
102 |
+
)
|
103 |
+
|
104 |
+
if "REDIS_PASSWORD" in os.environ:
|
105 |
+
redis_password = f":{os.environ['REDIS_PASSWORD']}@"
|
106 |
+
else:
|
107 |
+
redis_password = ""
|
108 |
+
|
109 |
+
return (
|
110 |
+
f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
|
111 |
+
)
|
112 |
+
|
113 |
+
|
114 |
+
def _get_redis_client_logic(**env_overrides):
|
115 |
+
"""
|
116 |
+
Common functionality across sync + async redis client implementations
|
117 |
+
"""
|
118 |
+
### check if "os.environ/<key-name>" passed in
|
119 |
+
for k, v in env_overrides.items():
|
120 |
+
if isinstance(v, str) and v.startswith("os.environ/"):
|
121 |
+
v = v.replace("os.environ/", "")
|
122 |
+
value = get_secret(v) # type: ignore
|
123 |
+
env_overrides[k] = value
|
124 |
+
|
125 |
+
redis_kwargs = {
|
126 |
+
**_redis_kwargs_from_environment(),
|
127 |
+
**env_overrides,
|
128 |
+
}
|
129 |
+
|
130 |
+
_startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret( # type: ignore
|
131 |
+
"REDIS_CLUSTER_NODES"
|
132 |
+
)
|
133 |
+
|
134 |
+
if _startup_nodes is not None and isinstance(_startup_nodes, str):
|
135 |
+
redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)
|
136 |
+
|
137 |
+
_sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret( # type: ignore
|
138 |
+
"REDIS_SENTINEL_NODES"
|
139 |
+
)
|
140 |
+
|
141 |
+
if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str):
|
142 |
+
redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes)
|
143 |
+
|
144 |
+
_sentinel_password: Optional[str] = redis_kwargs.get(
|
145 |
+
"sentinel_password", None
|
146 |
+
) or get_secret_str("REDIS_SENTINEL_PASSWORD")
|
147 |
+
|
148 |
+
if _sentinel_password is not None:
|
149 |
+
redis_kwargs["sentinel_password"] = _sentinel_password
|
150 |
+
|
151 |
+
_service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret( # type: ignore
|
152 |
+
"REDIS_SERVICE_NAME"
|
153 |
+
)
|
154 |
+
|
155 |
+
if _service_name is not None:
|
156 |
+
redis_kwargs["service_name"] = _service_name
|
157 |
+
|
158 |
+
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
159 |
+
redis_kwargs.pop("host", None)
|
160 |
+
redis_kwargs.pop("port", None)
|
161 |
+
redis_kwargs.pop("db", None)
|
162 |
+
redis_kwargs.pop("password", None)
|
163 |
+
elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
|
164 |
+
pass
|
165 |
+
elif (
|
166 |
+
"sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None
|
167 |
+
):
|
168 |
+
pass
|
169 |
+
elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
|
170 |
+
raise ValueError("Either 'host' or 'url' must be specified for redis.")
|
171 |
+
|
172 |
+
# litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
|
173 |
+
return redis_kwargs
|
174 |
+
|
175 |
+
|
176 |
+
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
|
177 |
+
_redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES") # type: ignore
|
178 |
+
if _redis_cluster_nodes_in_env is not None:
|
179 |
+
try:
|
180 |
+
redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
|
181 |
+
except json.JSONDecodeError:
|
182 |
+
raise ValueError(
|
183 |
+
"REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
|
184 |
+
)
|
185 |
+
|
186 |
+
verbose_logger.debug("init_redis_cluster: startup nodes are being initialized.")
|
187 |
+
from redis.cluster import ClusterNode
|
188 |
+
|
189 |
+
args = _get_redis_cluster_kwargs()
|
190 |
+
cluster_kwargs = {}
|
191 |
+
for arg in redis_kwargs:
|
192 |
+
if arg in args:
|
193 |
+
cluster_kwargs[arg] = redis_kwargs[arg]
|
194 |
+
|
195 |
+
new_startup_nodes: List[ClusterNode] = []
|
196 |
+
|
197 |
+
for item in redis_kwargs["startup_nodes"]:
|
198 |
+
new_startup_nodes.append(ClusterNode(**item))
|
199 |
+
|
200 |
+
redis_kwargs.pop("startup_nodes")
|
201 |
+
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) # type: ignore
|
202 |
+
|
203 |
+
|
204 |
+
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
|
205 |
+
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
|
206 |
+
sentinel_password = redis_kwargs.get("sentinel_password")
|
207 |
+
service_name = redis_kwargs.get("service_name")
|
208 |
+
|
209 |
+
if not sentinel_nodes or not service_name:
|
210 |
+
raise ValueError(
|
211 |
+
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
|
212 |
+
)
|
213 |
+
|
214 |
+
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
|
215 |
+
|
216 |
+
# Set up the Sentinel client
|
217 |
+
sentinel = redis.Sentinel(
|
218 |
+
sentinel_nodes,
|
219 |
+
socket_timeout=REDIS_SOCKET_TIMEOUT,
|
220 |
+
password=sentinel_password,
|
221 |
+
)
|
222 |
+
|
223 |
+
# Return the master instance for the given service
|
224 |
+
|
225 |
+
return sentinel.master_for(service_name)
|
226 |
+
|
227 |
+
|
228 |
+
def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
|
229 |
+
sentinel_nodes = redis_kwargs.get("sentinel_nodes")
|
230 |
+
sentinel_password = redis_kwargs.get("sentinel_password")
|
231 |
+
service_name = redis_kwargs.get("service_name")
|
232 |
+
|
233 |
+
if not sentinel_nodes or not service_name:
|
234 |
+
raise ValueError(
|
235 |
+
"Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
|
236 |
+
)
|
237 |
+
|
238 |
+
verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")
|
239 |
+
|
240 |
+
# Set up the Sentinel client
|
241 |
+
sentinel = async_redis.Sentinel(
|
242 |
+
sentinel_nodes,
|
243 |
+
socket_timeout=REDIS_SOCKET_TIMEOUT,
|
244 |
+
password=sentinel_password,
|
245 |
+
)
|
246 |
+
|
247 |
+
# Return the master instance for the given service
|
248 |
+
|
249 |
+
return sentinel.master_for(service_name)
|
250 |
+
|
251 |
+
|
252 |
+
def get_redis_client(**env_overrides):
|
253 |
+
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
254 |
+
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
255 |
+
args = _get_redis_url_kwargs()
|
256 |
+
url_kwargs = {}
|
257 |
+
for arg in redis_kwargs:
|
258 |
+
if arg in args:
|
259 |
+
url_kwargs[arg] = redis_kwargs[arg]
|
260 |
+
|
261 |
+
return redis.Redis.from_url(**url_kwargs)
|
262 |
+
|
263 |
+
if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None: # type: ignore
|
264 |
+
return init_redis_cluster(redis_kwargs)
|
265 |
+
|
266 |
+
# Check for Redis Sentinel
|
267 |
+
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
|
268 |
+
return _init_redis_sentinel(redis_kwargs)
|
269 |
+
|
270 |
+
return redis.Redis(**redis_kwargs)
|
271 |
+
|
272 |
+
|
273 |
+
def get_redis_async_client(
|
274 |
+
**env_overrides,
|
275 |
+
) -> async_redis.Redis:
|
276 |
+
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
277 |
+
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
278 |
+
args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
|
279 |
+
url_kwargs = {}
|
280 |
+
for arg in redis_kwargs:
|
281 |
+
if arg in args:
|
282 |
+
url_kwargs[arg] = redis_kwargs[arg]
|
283 |
+
else:
|
284 |
+
verbose_logger.debug(
|
285 |
+
"REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
|
286 |
+
arg
|
287 |
+
)
|
288 |
+
)
|
289 |
+
return async_redis.Redis.from_url(**url_kwargs)
|
290 |
+
|
291 |
+
if "startup_nodes" in redis_kwargs:
|
292 |
+
from redis.cluster import ClusterNode
|
293 |
+
|
294 |
+
args = _get_redis_cluster_kwargs()
|
295 |
+
cluster_kwargs = {}
|
296 |
+
for arg in redis_kwargs:
|
297 |
+
if arg in args:
|
298 |
+
cluster_kwargs[arg] = redis_kwargs[arg]
|
299 |
+
|
300 |
+
new_startup_nodes: List[ClusterNode] = []
|
301 |
+
|
302 |
+
for item in redis_kwargs["startup_nodes"]:
|
303 |
+
new_startup_nodes.append(ClusterNode(**item))
|
304 |
+
redis_kwargs.pop("startup_nodes")
|
305 |
+
return async_redis.RedisCluster(
|
306 |
+
startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore
|
307 |
+
)
|
308 |
+
|
309 |
+
# Check for Redis Sentinel
|
310 |
+
if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
|
311 |
+
return _init_async_redis_sentinel(redis_kwargs)
|
312 |
+
|
313 |
+
return async_redis.Redis(
|
314 |
+
**redis_kwargs,
|
315 |
+
)
|
316 |
+
|
317 |
+
|
318 |
+
def get_redis_connection_pool(**env_overrides):
|
319 |
+
redis_kwargs = _get_redis_client_logic(**env_overrides)
|
320 |
+
verbose_logger.debug("get_redis_connection_pool: redis_kwargs", redis_kwargs)
|
321 |
+
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
|
322 |
+
return async_redis.BlockingConnectionPool.from_url(
|
323 |
+
timeout=REDIS_CONNECTION_POOL_TIMEOUT, url=redis_kwargs["url"]
|
324 |
+
)
|
325 |
+
connection_class = async_redis.Connection
|
326 |
+
if "ssl" in redis_kwargs:
|
327 |
+
connection_class = async_redis.SSLConnection
|
328 |
+
redis_kwargs.pop("ssl", None)
|
329 |
+
redis_kwargs["connection_class"] = connection_class
|
330 |
+
redis_kwargs.pop("startup_nodes", None)
|
331 |
+
return async_redis.BlockingConnectionPool(
|
332 |
+
timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs
|
333 |
+
)
|
litellm/_service_logger.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from datetime import datetime, timedelta
|
3 |
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
4 |
+
|
5 |
+
import litellm
|
6 |
+
from litellm._logging import verbose_logger
|
7 |
+
from litellm.proxy._types import UserAPIKeyAuth
|
8 |
+
|
9 |
+
from .integrations.custom_logger import CustomLogger
|
10 |
+
from .integrations.datadog.datadog import DataDogLogger
|
11 |
+
from .integrations.opentelemetry import OpenTelemetry
|
12 |
+
from .integrations.prometheus_services import PrometheusServicesLogger
|
13 |
+
from .types.services import ServiceLoggerPayload, ServiceTypes
|
14 |
+
|
15 |
+
if TYPE_CHECKING:
|
16 |
+
from opentelemetry.trace import Span as _Span
|
17 |
+
|
18 |
+
Span = Union[_Span, Any]
|
19 |
+
OTELClass = OpenTelemetry
|
20 |
+
else:
|
21 |
+
Span = Any
|
22 |
+
OTELClass = Any
|
23 |
+
|
24 |
+
|
25 |
+
class ServiceLogging(CustomLogger):
|
26 |
+
"""
|
27 |
+
Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, mock_testing: bool = False) -> None:
|
31 |
+
self.mock_testing = mock_testing
|
32 |
+
self.mock_testing_sync_success_hook = 0
|
33 |
+
self.mock_testing_async_success_hook = 0
|
34 |
+
self.mock_testing_sync_failure_hook = 0
|
35 |
+
self.mock_testing_async_failure_hook = 0
|
36 |
+
if "prometheus_system" in litellm.service_callback:
|
37 |
+
self.prometheusServicesLogger = PrometheusServicesLogger()
|
38 |
+
|
39 |
+
def service_success_hook(
|
40 |
+
self,
|
41 |
+
service: ServiceTypes,
|
42 |
+
duration: float,
|
43 |
+
call_type: str,
|
44 |
+
parent_otel_span: Optional[Span] = None,
|
45 |
+
start_time: Optional[Union[datetime, float]] = None,
|
46 |
+
end_time: Optional[Union[float, datetime]] = None,
|
47 |
+
):
|
48 |
+
"""
|
49 |
+
Handles both sync and async monitoring by checking for existing event loop.
|
50 |
+
"""
|
51 |
+
|
52 |
+
if self.mock_testing:
|
53 |
+
self.mock_testing_sync_success_hook += 1
|
54 |
+
|
55 |
+
try:
|
56 |
+
# Try to get the current event loop
|
57 |
+
loop = asyncio.get_event_loop()
|
58 |
+
# Check if the loop is running
|
59 |
+
if loop.is_running():
|
60 |
+
# If we're in a running loop, create a task
|
61 |
+
loop.create_task(
|
62 |
+
self.async_service_success_hook(
|
63 |
+
service=service,
|
64 |
+
duration=duration,
|
65 |
+
call_type=call_type,
|
66 |
+
parent_otel_span=parent_otel_span,
|
67 |
+
start_time=start_time,
|
68 |
+
end_time=end_time,
|
69 |
+
)
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
# Loop exists but not running, we can use run_until_complete
|
73 |
+
loop.run_until_complete(
|
74 |
+
self.async_service_success_hook(
|
75 |
+
service=service,
|
76 |
+
duration=duration,
|
77 |
+
call_type=call_type,
|
78 |
+
parent_otel_span=parent_otel_span,
|
79 |
+
start_time=start_time,
|
80 |
+
end_time=end_time,
|
81 |
+
)
|
82 |
+
)
|
83 |
+
except RuntimeError:
|
84 |
+
# No event loop exists, create a new one and run
|
85 |
+
asyncio.run(
|
86 |
+
self.async_service_success_hook(
|
87 |
+
service=service,
|
88 |
+
duration=duration,
|
89 |
+
call_type=call_type,
|
90 |
+
parent_otel_span=parent_otel_span,
|
91 |
+
start_time=start_time,
|
92 |
+
end_time=end_time,
|
93 |
+
)
|
94 |
+
)
|
95 |
+
|
96 |
+
def service_failure_hook(
|
97 |
+
self, service: ServiceTypes, duration: float, error: Exception, call_type: str
|
98 |
+
):
|
99 |
+
"""
|
100 |
+
[TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
|
101 |
+
"""
|
102 |
+
if self.mock_testing:
|
103 |
+
self.mock_testing_sync_failure_hook += 1
|
104 |
+
|
105 |
+
async def async_service_success_hook(
|
106 |
+
self,
|
107 |
+
service: ServiceTypes,
|
108 |
+
call_type: str,
|
109 |
+
duration: float,
|
110 |
+
parent_otel_span: Optional[Span] = None,
|
111 |
+
start_time: Optional[Union[datetime, float]] = None,
|
112 |
+
end_time: Optional[Union[datetime, float]] = None,
|
113 |
+
event_metadata: Optional[dict] = None,
|
114 |
+
):
|
115 |
+
"""
|
116 |
+
- For counting if the redis, postgres call is successful
|
117 |
+
"""
|
118 |
+
if self.mock_testing:
|
119 |
+
self.mock_testing_async_success_hook += 1
|
120 |
+
|
121 |
+
payload = ServiceLoggerPayload(
|
122 |
+
is_error=False,
|
123 |
+
error=None,
|
124 |
+
service=service,
|
125 |
+
duration=duration,
|
126 |
+
call_type=call_type,
|
127 |
+
event_metadata=event_metadata,
|
128 |
+
)
|
129 |
+
|
130 |
+
for callback in litellm.service_callback:
|
131 |
+
if callback == "prometheus_system":
|
132 |
+
await self.init_prometheus_services_logger_if_none()
|
133 |
+
await self.prometheusServicesLogger.async_service_success_hook(
|
134 |
+
payload=payload
|
135 |
+
)
|
136 |
+
elif callback == "datadog" or isinstance(callback, DataDogLogger):
|
137 |
+
await self.init_datadog_logger_if_none()
|
138 |
+
await self.dd_logger.async_service_success_hook(
|
139 |
+
payload=payload,
|
140 |
+
parent_otel_span=parent_otel_span,
|
141 |
+
start_time=start_time,
|
142 |
+
end_time=end_time,
|
143 |
+
event_metadata=event_metadata,
|
144 |
+
)
|
145 |
+
elif callback == "otel" or isinstance(callback, OpenTelemetry):
|
146 |
+
from litellm.proxy.proxy_server import open_telemetry_logger
|
147 |
+
|
148 |
+
await self.init_otel_logger_if_none()
|
149 |
+
|
150 |
+
if (
|
151 |
+
parent_otel_span is not None
|
152 |
+
and open_telemetry_logger is not None
|
153 |
+
and isinstance(open_telemetry_logger, OpenTelemetry)
|
154 |
+
):
|
155 |
+
await self.otel_logger.async_service_success_hook(
|
156 |
+
payload=payload,
|
157 |
+
parent_otel_span=parent_otel_span,
|
158 |
+
start_time=start_time,
|
159 |
+
end_time=end_time,
|
160 |
+
event_metadata=event_metadata,
|
161 |
+
)
|
162 |
+
|
163 |
+
async def init_prometheus_services_logger_if_none(self):
|
164 |
+
"""
|
165 |
+
initializes prometheusServicesLogger if it is None or no attribute exists on ServiceLogging Object
|
166 |
+
|
167 |
+
"""
|
168 |
+
if not hasattr(self, "prometheusServicesLogger"):
|
169 |
+
self.prometheusServicesLogger = PrometheusServicesLogger()
|
170 |
+
elif self.prometheusServicesLogger is None:
|
171 |
+
self.prometheusServicesLogger = self.prometheusServicesLogger()
|
172 |
+
return
|
173 |
+
|
174 |
+
async def init_datadog_logger_if_none(self):
|
175 |
+
"""
|
176 |
+
initializes dd_logger if it is None or no attribute exists on ServiceLogging Object
|
177 |
+
|
178 |
+
"""
|
179 |
+
from litellm.integrations.datadog.datadog import DataDogLogger
|
180 |
+
|
181 |
+
if not hasattr(self, "dd_logger"):
|
182 |
+
self.dd_logger: DataDogLogger = DataDogLogger()
|
183 |
+
|
184 |
+
return
|
185 |
+
|
186 |
+
async def init_otel_logger_if_none(self):
|
187 |
+
"""
|
188 |
+
initializes otel_logger if it is None or no attribute exists on ServiceLogging Object
|
189 |
+
|
190 |
+
"""
|
191 |
+
from litellm.proxy.proxy_server import open_telemetry_logger
|
192 |
+
|
193 |
+
if not hasattr(self, "otel_logger"):
|
194 |
+
if open_telemetry_logger is not None and isinstance(
|
195 |
+
open_telemetry_logger, OpenTelemetry
|
196 |
+
):
|
197 |
+
self.otel_logger: OpenTelemetry = open_telemetry_logger
|
198 |
+
else:
|
199 |
+
verbose_logger.warning(
|
200 |
+
"ServiceLogger: open_telemetry_logger is None or not an instance of OpenTelemetry"
|
201 |
+
)
|
202 |
+
return
|
203 |
+
|
204 |
+
async def async_service_failure_hook(
|
205 |
+
self,
|
206 |
+
service: ServiceTypes,
|
207 |
+
duration: float,
|
208 |
+
error: Union[str, Exception],
|
209 |
+
call_type: str,
|
210 |
+
parent_otel_span: Optional[Span] = None,
|
211 |
+
start_time: Optional[Union[datetime, float]] = None,
|
212 |
+
end_time: Optional[Union[float, datetime]] = None,
|
213 |
+
event_metadata: Optional[dict] = None,
|
214 |
+
):
|
215 |
+
"""
|
216 |
+
- For counting if the redis, postgres call is unsuccessful
|
217 |
+
"""
|
218 |
+
if self.mock_testing:
|
219 |
+
self.mock_testing_async_failure_hook += 1
|
220 |
+
|
221 |
+
error_message = ""
|
222 |
+
if isinstance(error, Exception):
|
223 |
+
error_message = str(error)
|
224 |
+
elif isinstance(error, str):
|
225 |
+
error_message = error
|
226 |
+
|
227 |
+
payload = ServiceLoggerPayload(
|
228 |
+
is_error=True,
|
229 |
+
error=error_message,
|
230 |
+
service=service,
|
231 |
+
duration=duration,
|
232 |
+
call_type=call_type,
|
233 |
+
event_metadata=event_metadata,
|
234 |
+
)
|
235 |
+
|
236 |
+
for callback in litellm.service_callback:
|
237 |
+
if callback == "prometheus_system":
|
238 |
+
await self.init_prometheus_services_logger_if_none()
|
239 |
+
await self.prometheusServicesLogger.async_service_failure_hook(
|
240 |
+
payload=payload,
|
241 |
+
error=error,
|
242 |
+
)
|
243 |
+
elif callback == "datadog" or isinstance(callback, DataDogLogger):
|
244 |
+
await self.init_datadog_logger_if_none()
|
245 |
+
await self.dd_logger.async_service_failure_hook(
|
246 |
+
payload=payload,
|
247 |
+
error=error_message,
|
248 |
+
parent_otel_span=parent_otel_span,
|
249 |
+
start_time=start_time,
|
250 |
+
end_time=end_time,
|
251 |
+
event_metadata=event_metadata,
|
252 |
+
)
|
253 |
+
elif callback == "otel" or isinstance(callback, OpenTelemetry):
|
254 |
+
from litellm.proxy.proxy_server import open_telemetry_logger
|
255 |
+
|
256 |
+
await self.init_otel_logger_if_none()
|
257 |
+
|
258 |
+
if not isinstance(error, str):
|
259 |
+
error = str(error)
|
260 |
+
|
261 |
+
if (
|
262 |
+
parent_otel_span is not None
|
263 |
+
and open_telemetry_logger is not None
|
264 |
+
and isinstance(open_telemetry_logger, OpenTelemetry)
|
265 |
+
):
|
266 |
+
await self.otel_logger.async_service_success_hook(
|
267 |
+
payload=payload,
|
268 |
+
parent_otel_span=parent_otel_span,
|
269 |
+
start_time=start_time,
|
270 |
+
end_time=end_time,
|
271 |
+
event_metadata=event_metadata,
|
272 |
+
)
|
273 |
+
|
274 |
+
async def async_post_call_failure_hook(
|
275 |
+
self,
|
276 |
+
request_data: dict,
|
277 |
+
original_exception: Exception,
|
278 |
+
user_api_key_dict: UserAPIKeyAuth,
|
279 |
+
):
|
280 |
+
"""
|
281 |
+
Hook to track failed litellm-service calls
|
282 |
+
"""
|
283 |
+
return await super().async_post_call_failure_hook(
|
284 |
+
request_data,
|
285 |
+
original_exception,
|
286 |
+
user_api_key_dict,
|
287 |
+
)
|
288 |
+
|
289 |
+
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
290 |
+
"""
|
291 |
+
Hook to track latency for litellm proxy llm api calls
|
292 |
+
"""
|
293 |
+
try:
|
294 |
+
_duration = end_time - start_time
|
295 |
+
if isinstance(_duration, timedelta):
|
296 |
+
_duration = _duration.total_seconds()
|
297 |
+
elif isinstance(_duration, float):
|
298 |
+
pass
|
299 |
+
else:
|
300 |
+
raise Exception(
|
301 |
+
"Duration={} is not a float or timedelta object. type={}".format(
|
302 |
+
_duration, type(_duration)
|
303 |
+
)
|
304 |
+
) # invalid _duration value
|
305 |
+
await self.async_service_success_hook(
|
306 |
+
service=ServiceTypes.LITELLM,
|
307 |
+
duration=_duration,
|
308 |
+
call_type=kwargs["call_type"],
|
309 |
+
)
|
310 |
+
except Exception as e:
|
311 |
+
raise e
|
litellm/_version.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib_metadata
|
2 |
+
|
3 |
+
try:
|
4 |
+
version = importlib_metadata.version("litellm")
|
5 |
+
except Exception:
|
6 |
+
version = "unknown"
|
litellm/anthropic_interface/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Anthropic module for LiteLLM
|
3 |
+
"""
|
4 |
+
from .messages import acreate, create
|
5 |
+
|
6 |
+
__all__ = ["acreate", "create"]
|
litellm/anthropic_interface/messages/__init__.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Interface for Anthropic's messages API
|
3 |
+
|
4 |
+
Use this to call LLMs in Anthropic /messages Request/Response format
|
5 |
+
|
6 |
+
This is an __init__.py file to allow the following interface
|
7 |
+
|
8 |
+
- litellm.messages.acreate
|
9 |
+
- litellm.messages.create
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
|
14 |
+
|
15 |
+
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
|
16 |
+
anthropic_messages as _async_anthropic_messages,
|
17 |
+
)
|
18 |
+
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
19 |
+
AnthropicMessagesResponse,
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
async def acreate(
|
24 |
+
max_tokens: int,
|
25 |
+
messages: List[Dict],
|
26 |
+
model: str,
|
27 |
+
metadata: Optional[Dict] = None,
|
28 |
+
stop_sequences: Optional[List[str]] = None,
|
29 |
+
stream: Optional[bool] = False,
|
30 |
+
system: Optional[str] = None,
|
31 |
+
temperature: Optional[float] = 1.0,
|
32 |
+
thinking: Optional[Dict] = None,
|
33 |
+
tool_choice: Optional[Dict] = None,
|
34 |
+
tools: Optional[List[Dict]] = None,
|
35 |
+
top_k: Optional[int] = None,
|
36 |
+
top_p: Optional[float] = None,
|
37 |
+
**kwargs
|
38 |
+
) -> Union[AnthropicMessagesResponse, AsyncIterator]:
|
39 |
+
"""
|
40 |
+
Async wrapper for Anthropic's messages API
|
41 |
+
|
42 |
+
Args:
|
43 |
+
max_tokens (int): Maximum tokens to generate (required)
|
44 |
+
messages (List[Dict]): List of message objects with role and content (required)
|
45 |
+
model (str): Model name to use (required)
|
46 |
+
metadata (Dict, optional): Request metadata
|
47 |
+
stop_sequences (List[str], optional): Custom stop sequences
|
48 |
+
stream (bool, optional): Whether to stream the response
|
49 |
+
system (str, optional): System prompt
|
50 |
+
temperature (float, optional): Sampling temperature (0.0 to 1.0)
|
51 |
+
thinking (Dict, optional): Extended thinking configuration
|
52 |
+
tool_choice (Dict, optional): Tool choice configuration
|
53 |
+
tools (List[Dict], optional): List of tool definitions
|
54 |
+
top_k (int, optional): Top K sampling parameter
|
55 |
+
top_p (float, optional): Nucleus sampling parameter
|
56 |
+
**kwargs: Additional arguments
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
Dict: Response from the API
|
60 |
+
"""
|
61 |
+
return await _async_anthropic_messages(
|
62 |
+
max_tokens=max_tokens,
|
63 |
+
messages=messages,
|
64 |
+
model=model,
|
65 |
+
metadata=metadata,
|
66 |
+
stop_sequences=stop_sequences,
|
67 |
+
stream=stream,
|
68 |
+
system=system,
|
69 |
+
temperature=temperature,
|
70 |
+
thinking=thinking,
|
71 |
+
tool_choice=tool_choice,
|
72 |
+
tools=tools,
|
73 |
+
top_k=top_k,
|
74 |
+
top_p=top_p,
|
75 |
+
**kwargs,
|
76 |
+
)
|
77 |
+
|
78 |
+
|
79 |
+
async def create(
|
80 |
+
max_tokens: int,
|
81 |
+
messages: List[Dict],
|
82 |
+
model: str,
|
83 |
+
metadata: Optional[Dict] = None,
|
84 |
+
stop_sequences: Optional[List[str]] = None,
|
85 |
+
stream: Optional[bool] = False,
|
86 |
+
system: Optional[str] = None,
|
87 |
+
temperature: Optional[float] = 1.0,
|
88 |
+
thinking: Optional[Dict] = None,
|
89 |
+
tool_choice: Optional[Dict] = None,
|
90 |
+
tools: Optional[List[Dict]] = None,
|
91 |
+
top_k: Optional[int] = None,
|
92 |
+
top_p: Optional[float] = None,
|
93 |
+
**kwargs
|
94 |
+
) -> Union[AnthropicMessagesResponse, Iterator]:
|
95 |
+
"""
|
96 |
+
Async wrapper for Anthropic's messages API
|
97 |
+
|
98 |
+
Args:
|
99 |
+
max_tokens (int): Maximum tokens to generate (required)
|
100 |
+
messages (List[Dict]): List of message objects with role and content (required)
|
101 |
+
model (str): Model name to use (required)
|
102 |
+
metadata (Dict, optional): Request metadata
|
103 |
+
stop_sequences (List[str], optional): Custom stop sequences
|
104 |
+
stream (bool, optional): Whether to stream the response
|
105 |
+
system (str, optional): System prompt
|
106 |
+
temperature (float, optional): Sampling temperature (0.0 to 1.0)
|
107 |
+
thinking (Dict, optional): Extended thinking configuration
|
108 |
+
tool_choice (Dict, optional): Tool choice configuration
|
109 |
+
tools (List[Dict], optional): List of tool definitions
|
110 |
+
top_k (int, optional): Top K sampling parameter
|
111 |
+
top_p (float, optional): Nucleus sampling parameter
|
112 |
+
**kwargs: Additional arguments
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
Dict: Response from the API
|
116 |
+
"""
|
117 |
+
raise NotImplementedError("This function is not implemented")
|
litellm/anthropic_interface/readme.md
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Use LLM API endpoints in Anthropic Interface
|
2 |
+
|
3 |
+
Note: This is called `anthropic_interface` because `anthropic` is a known python package and was failing mypy type checking.
|
4 |
+
|
5 |
+
|
6 |
+
## Usage
|
7 |
+
---
|
8 |
+
|
9 |
+
### LiteLLM Python SDK
|
10 |
+
|
11 |
+
#### Non-streaming example
|
12 |
+
```python showLineNumbers title="Example using LiteLLM Python SDK"
|
13 |
+
import litellm
|
14 |
+
response = await litellm.anthropic.messages.acreate(
|
15 |
+
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
16 |
+
api_key=api_key,
|
17 |
+
model="anthropic/claude-3-haiku-20240307",
|
18 |
+
max_tokens=100,
|
19 |
+
)
|
20 |
+
```
|
21 |
+
|
22 |
+
Example response:
|
23 |
+
```json
|
24 |
+
{
|
25 |
+
"content": [
|
26 |
+
{
|
27 |
+
"text": "Hi! this is a very short joke",
|
28 |
+
"type": "text"
|
29 |
+
}
|
30 |
+
],
|
31 |
+
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
|
32 |
+
"model": "claude-3-7-sonnet-20250219",
|
33 |
+
"role": "assistant",
|
34 |
+
"stop_reason": "end_turn",
|
35 |
+
"stop_sequence": null,
|
36 |
+
"type": "message",
|
37 |
+
"usage": {
|
38 |
+
"input_tokens": 2095,
|
39 |
+
"output_tokens": 503,
|
40 |
+
"cache_creation_input_tokens": 2095,
|
41 |
+
"cache_read_input_tokens": 0
|
42 |
+
}
|
43 |
+
}
|
44 |
+
```
|
45 |
+
|
46 |
+
#### Streaming example
|
47 |
+
```python showLineNumbers title="Example using LiteLLM Python SDK"
|
48 |
+
import litellm
|
49 |
+
response = await litellm.anthropic.messages.acreate(
|
50 |
+
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
51 |
+
api_key=api_key,
|
52 |
+
model="anthropic/claude-3-haiku-20240307",
|
53 |
+
max_tokens=100,
|
54 |
+
stream=True,
|
55 |
+
)
|
56 |
+
async for chunk in response:
|
57 |
+
print(chunk)
|
58 |
+
```
|
59 |
+
|
60 |
+
### LiteLLM Proxy Server
|
61 |
+
|
62 |
+
|
63 |
+
1. Setup config.yaml
|
64 |
+
|
65 |
+
```yaml
|
66 |
+
model_list:
|
67 |
+
- model_name: anthropic-claude
|
68 |
+
litellm_params:
|
69 |
+
model: claude-3-7-sonnet-latest
|
70 |
+
```
|
71 |
+
|
72 |
+
2. Start proxy
|
73 |
+
|
74 |
+
```bash
|
75 |
+
litellm --config /path/to/config.yaml
|
76 |
+
```
|
77 |
+
|
78 |
+
3. Test it!
|
79 |
+
|
80 |
+
<Tabs>
|
81 |
+
<TabItem label="Anthropic Python SDK" value="python">
|
82 |
+
|
83 |
+
```python showLineNumbers title="Example using LiteLLM Proxy Server"
|
84 |
+
import anthropic
|
85 |
+
|
86 |
+
# point anthropic sdk to litellm proxy
|
87 |
+
client = anthropic.Anthropic(
|
88 |
+
base_url="http://0.0.0.0:4000",
|
89 |
+
api_key="sk-1234",
|
90 |
+
)
|
91 |
+
|
92 |
+
response = client.messages.create(
|
93 |
+
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
94 |
+
model="anthropic/claude-3-haiku-20240307",
|
95 |
+
max_tokens=100,
|
96 |
+
)
|
97 |
+
```
|
98 |
+
</TabItem>
|
99 |
+
<TabItem label="curl" value="curl">
|
100 |
+
|
101 |
+
```bash showLineNumbers title="Example using LiteLLM Proxy Server"
|
102 |
+
curl -L -X POST 'http://0.0.0.0:4000/v1/messages' \
|
103 |
+
-H 'content-type: application/json' \
|
104 |
+
-H 'x-api-key: $LITELLM_API_KEY' \
|
105 |
+
-H 'anthropic-version: 2023-06-01' \
|
106 |
+
-d '{
|
107 |
+
"model": "anthropic-claude",
|
108 |
+
"messages": [
|
109 |
+
{
|
110 |
+
"role": "user",
|
111 |
+
"content": "Hello, can you tell me a short joke?"
|
112 |
+
}
|
113 |
+
],
|
114 |
+
"max_tokens": 100
|
115 |
+
}'
|
116 |
+
```
|
litellm/assistants/main.py
ADDED
@@ -0,0 +1,1484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# What is this?
|
2 |
+
## Main file for assistants API logic
|
3 |
+
import asyncio
|
4 |
+
import contextvars
|
5 |
+
import os
|
6 |
+
from functools import partial
|
7 |
+
from typing import Any, Coroutine, Dict, Iterable, List, Literal, Optional, Union
|
8 |
+
|
9 |
+
import httpx
|
10 |
+
from openai import AsyncOpenAI, OpenAI
|
11 |
+
from openai.types.beta.assistant import Assistant
|
12 |
+
from openai.types.beta.assistant_deleted import AssistantDeleted
|
13 |
+
|
14 |
+
import litellm
|
15 |
+
from litellm.types.router import GenericLiteLLMParams
|
16 |
+
from litellm.utils import (
|
17 |
+
exception_type,
|
18 |
+
get_litellm_params,
|
19 |
+
get_llm_provider,
|
20 |
+
get_secret,
|
21 |
+
supports_httpx_timeout,
|
22 |
+
)
|
23 |
+
|
24 |
+
from ..llms.azure.assistants import AzureAssistantsAPI
|
25 |
+
from ..llms.openai.openai import OpenAIAssistantsAPI
|
26 |
+
from ..types.llms.openai import *
|
27 |
+
from ..types.router import *
|
28 |
+
from .utils import get_optional_params_add_message
|
29 |
+
|
30 |
+
####### ENVIRONMENT VARIABLES ###################
|
31 |
+
openai_assistants_api = OpenAIAssistantsAPI()
|
32 |
+
azure_assistants_api = AzureAssistantsAPI()
|
33 |
+
|
34 |
+
### ASSISTANTS ###
|
35 |
+
|
36 |
+
|
37 |
+
async def aget_assistants(
|
38 |
+
custom_llm_provider: Literal["openai", "azure"],
|
39 |
+
client: Optional[AsyncOpenAI] = None,
|
40 |
+
**kwargs,
|
41 |
+
) -> AsyncCursorPage[Assistant]:
|
42 |
+
loop = asyncio.get_event_loop()
|
43 |
+
### PASS ARGS TO GET ASSISTANTS ###
|
44 |
+
kwargs["aget_assistants"] = True
|
45 |
+
try:
|
46 |
+
# Use a partial function to pass your keyword arguments
|
47 |
+
func = partial(get_assistants, custom_llm_provider, client, **kwargs)
|
48 |
+
|
49 |
+
# Add the context to the function
|
50 |
+
ctx = contextvars.copy_context()
|
51 |
+
func_with_context = partial(ctx.run, func)
|
52 |
+
|
53 |
+
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
|
54 |
+
model="", custom_llm_provider=custom_llm_provider
|
55 |
+
) # type: ignore
|
56 |
+
|
57 |
+
# Await normally
|
58 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
59 |
+
if asyncio.iscoroutine(init_response):
|
60 |
+
response = await init_response
|
61 |
+
else:
|
62 |
+
response = init_response
|
63 |
+
return response # type: ignore
|
64 |
+
except Exception as e:
|
65 |
+
raise exception_type(
|
66 |
+
model="",
|
67 |
+
custom_llm_provider=custom_llm_provider,
|
68 |
+
original_exception=e,
|
69 |
+
completion_kwargs={},
|
70 |
+
extra_kwargs=kwargs,
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
def get_assistants(
|
75 |
+
custom_llm_provider: Literal["openai", "azure"],
|
76 |
+
client: Optional[Any] = None,
|
77 |
+
api_key: Optional[str] = None,
|
78 |
+
api_base: Optional[str] = None,
|
79 |
+
api_version: Optional[str] = None,
|
80 |
+
**kwargs,
|
81 |
+
) -> SyncCursorPage[Assistant]:
|
82 |
+
aget_assistants: Optional[bool] = kwargs.pop("aget_assistants", None)
|
83 |
+
if aget_assistants is not None and not isinstance(aget_assistants, bool):
|
84 |
+
raise Exception(
|
85 |
+
"Invalid value passed in for aget_assistants. Only bool or None allowed"
|
86 |
+
)
|
87 |
+
optional_params = GenericLiteLLMParams(
|
88 |
+
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
|
89 |
+
)
|
90 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
91 |
+
|
92 |
+
### TIMEOUT LOGIC ###
|
93 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
94 |
+
# set timeout for 10 minutes by default
|
95 |
+
|
96 |
+
if (
|
97 |
+
timeout is not None
|
98 |
+
and isinstance(timeout, httpx.Timeout)
|
99 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
100 |
+
):
|
101 |
+
read_timeout = timeout.read or 600
|
102 |
+
timeout = read_timeout # default 10 min timeout
|
103 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
104 |
+
timeout = float(timeout) # type: ignore
|
105 |
+
elif timeout is None:
|
106 |
+
timeout = 600.0
|
107 |
+
|
108 |
+
response: Optional[SyncCursorPage[Assistant]] = None
|
109 |
+
if custom_llm_provider == "openai":
|
110 |
+
api_base = (
|
111 |
+
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
112 |
+
or litellm.api_base
|
113 |
+
or os.getenv("OPENAI_BASE_URL")
|
114 |
+
or os.getenv("OPENAI_API_BASE")
|
115 |
+
or "https://api.openai.com/v1"
|
116 |
+
)
|
117 |
+
organization = (
|
118 |
+
optional_params.organization
|
119 |
+
or litellm.organization
|
120 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
121 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
122 |
+
)
|
123 |
+
# set API KEY
|
124 |
+
api_key = (
|
125 |
+
optional_params.api_key
|
126 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
127 |
+
or litellm.openai_key
|
128 |
+
or os.getenv("OPENAI_API_KEY")
|
129 |
+
)
|
130 |
+
|
131 |
+
response = openai_assistants_api.get_assistants(
|
132 |
+
api_base=api_base,
|
133 |
+
api_key=api_key,
|
134 |
+
timeout=timeout,
|
135 |
+
max_retries=optional_params.max_retries,
|
136 |
+
organization=organization,
|
137 |
+
client=client,
|
138 |
+
aget_assistants=aget_assistants, # type: ignore
|
139 |
+
) # type: ignore
|
140 |
+
elif custom_llm_provider == "azure":
|
141 |
+
api_base = (
|
142 |
+
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
143 |
+
) # type: ignore
|
144 |
+
|
145 |
+
api_version = (
|
146 |
+
optional_params.api_version
|
147 |
+
or litellm.api_version
|
148 |
+
or get_secret("AZURE_API_VERSION")
|
149 |
+
) # type: ignore
|
150 |
+
|
151 |
+
api_key = (
|
152 |
+
optional_params.api_key
|
153 |
+
or litellm.api_key
|
154 |
+
or litellm.azure_key
|
155 |
+
or get_secret("AZURE_OPENAI_API_KEY")
|
156 |
+
or get_secret("AZURE_API_KEY")
|
157 |
+
) # type: ignore
|
158 |
+
|
159 |
+
extra_body = optional_params.get("extra_body", {})
|
160 |
+
azure_ad_token: Optional[str] = None
|
161 |
+
if extra_body is not None:
|
162 |
+
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
163 |
+
else:
|
164 |
+
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
165 |
+
|
166 |
+
response = azure_assistants_api.get_assistants(
|
167 |
+
api_base=api_base,
|
168 |
+
api_key=api_key,
|
169 |
+
api_version=api_version,
|
170 |
+
azure_ad_token=azure_ad_token,
|
171 |
+
timeout=timeout,
|
172 |
+
max_retries=optional_params.max_retries,
|
173 |
+
client=client,
|
174 |
+
aget_assistants=aget_assistants, # type: ignore
|
175 |
+
litellm_params=litellm_params_dict,
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
raise litellm.exceptions.BadRequestError(
|
179 |
+
message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' is supported.".format(
|
180 |
+
custom_llm_provider
|
181 |
+
),
|
182 |
+
model="n/a",
|
183 |
+
llm_provider=custom_llm_provider,
|
184 |
+
response=httpx.Response(
|
185 |
+
status_code=400,
|
186 |
+
content="Unsupported provider",
|
187 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
188 |
+
),
|
189 |
+
)
|
190 |
+
|
191 |
+
if response is None:
|
192 |
+
raise litellm.exceptions.BadRequestError(
|
193 |
+
message="LiteLLM doesn't support {} for 'get_assistants'. Only 'openai' is supported.".format(
|
194 |
+
custom_llm_provider
|
195 |
+
),
|
196 |
+
model="n/a",
|
197 |
+
llm_provider=custom_llm_provider,
|
198 |
+
response=httpx.Response(
|
199 |
+
status_code=400,
|
200 |
+
content="Unsupported provider",
|
201 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
202 |
+
),
|
203 |
+
)
|
204 |
+
|
205 |
+
return response
|
206 |
+
|
207 |
+
|
208 |
+
async def acreate_assistants(
|
209 |
+
custom_llm_provider: Literal["openai", "azure"],
|
210 |
+
client: Optional[AsyncOpenAI] = None,
|
211 |
+
**kwargs,
|
212 |
+
) -> Assistant:
|
213 |
+
loop = asyncio.get_event_loop()
|
214 |
+
### PASS ARGS TO GET ASSISTANTS ###
|
215 |
+
kwargs["async_create_assistants"] = True
|
216 |
+
model = kwargs.pop("model", None)
|
217 |
+
try:
|
218 |
+
kwargs["client"] = client
|
219 |
+
# Use a partial function to pass your keyword arguments
|
220 |
+
func = partial(create_assistants, custom_llm_provider, model, **kwargs)
|
221 |
+
|
222 |
+
# Add the context to the function
|
223 |
+
ctx = contextvars.copy_context()
|
224 |
+
func_with_context = partial(ctx.run, func)
|
225 |
+
|
226 |
+
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
|
227 |
+
model=model, custom_llm_provider=custom_llm_provider
|
228 |
+
) # type: ignore
|
229 |
+
|
230 |
+
# Await normally
|
231 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
232 |
+
if asyncio.iscoroutine(init_response):
|
233 |
+
response = await init_response
|
234 |
+
else:
|
235 |
+
response = init_response
|
236 |
+
return response # type: ignore
|
237 |
+
except Exception as e:
|
238 |
+
raise exception_type(
|
239 |
+
model=model,
|
240 |
+
custom_llm_provider=custom_llm_provider,
|
241 |
+
original_exception=e,
|
242 |
+
completion_kwargs={},
|
243 |
+
extra_kwargs=kwargs,
|
244 |
+
)
|
245 |
+
|
246 |
+
|
247 |
+
def create_assistants(
|
248 |
+
custom_llm_provider: Literal["openai", "azure"],
|
249 |
+
model: str,
|
250 |
+
name: Optional[str] = None,
|
251 |
+
description: Optional[str] = None,
|
252 |
+
instructions: Optional[str] = None,
|
253 |
+
tools: Optional[List[Dict[str, Any]]] = None,
|
254 |
+
tool_resources: Optional[Dict[str, Any]] = None,
|
255 |
+
metadata: Optional[Dict[str, str]] = None,
|
256 |
+
temperature: Optional[float] = None,
|
257 |
+
top_p: Optional[float] = None,
|
258 |
+
response_format: Optional[Union[str, Dict[str, str]]] = None,
|
259 |
+
client: Optional[Any] = None,
|
260 |
+
api_key: Optional[str] = None,
|
261 |
+
api_base: Optional[str] = None,
|
262 |
+
api_version: Optional[str] = None,
|
263 |
+
**kwargs,
|
264 |
+
) -> Union[Assistant, Coroutine[Any, Any, Assistant]]:
|
265 |
+
async_create_assistants: Optional[bool] = kwargs.pop(
|
266 |
+
"async_create_assistants", None
|
267 |
+
)
|
268 |
+
if async_create_assistants is not None and not isinstance(
|
269 |
+
async_create_assistants, bool
|
270 |
+
):
|
271 |
+
raise ValueError(
|
272 |
+
"Invalid value passed in for async_create_assistants. Only bool or None allowed"
|
273 |
+
)
|
274 |
+
optional_params = GenericLiteLLMParams(
|
275 |
+
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
|
276 |
+
)
|
277 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
278 |
+
|
279 |
+
### TIMEOUT LOGIC ###
|
280 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
281 |
+
# set timeout for 10 minutes by default
|
282 |
+
|
283 |
+
if (
|
284 |
+
timeout is not None
|
285 |
+
and isinstance(timeout, httpx.Timeout)
|
286 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
287 |
+
):
|
288 |
+
read_timeout = timeout.read or 600
|
289 |
+
timeout = read_timeout # default 10 min timeout
|
290 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
291 |
+
timeout = float(timeout) # type: ignore
|
292 |
+
elif timeout is None:
|
293 |
+
timeout = 600.0
|
294 |
+
|
295 |
+
create_assistant_data = {
|
296 |
+
"model": model,
|
297 |
+
"name": name,
|
298 |
+
"description": description,
|
299 |
+
"instructions": instructions,
|
300 |
+
"tools": tools,
|
301 |
+
"tool_resources": tool_resources,
|
302 |
+
"metadata": metadata,
|
303 |
+
"temperature": temperature,
|
304 |
+
"top_p": top_p,
|
305 |
+
"response_format": response_format,
|
306 |
+
}
|
307 |
+
|
308 |
+
# only send params that are not None
|
309 |
+
create_assistant_data = {
|
310 |
+
k: v for k, v in create_assistant_data.items() if v is not None
|
311 |
+
}
|
312 |
+
|
313 |
+
response: Optional[Union[Coroutine[Any, Any, Assistant], Assistant]] = None
|
314 |
+
if custom_llm_provider == "openai":
|
315 |
+
api_base = (
|
316 |
+
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
317 |
+
or litellm.api_base
|
318 |
+
or os.getenv("OPENAI_BASE_URL")
|
319 |
+
or os.getenv("OPENAI_API_BASE")
|
320 |
+
or "https://api.openai.com/v1"
|
321 |
+
)
|
322 |
+
organization = (
|
323 |
+
optional_params.organization
|
324 |
+
or litellm.organization
|
325 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
326 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
327 |
+
)
|
328 |
+
# set API KEY
|
329 |
+
api_key = (
|
330 |
+
optional_params.api_key
|
331 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
332 |
+
or litellm.openai_key
|
333 |
+
or os.getenv("OPENAI_API_KEY")
|
334 |
+
)
|
335 |
+
|
336 |
+
response = openai_assistants_api.create_assistants(
|
337 |
+
api_base=api_base,
|
338 |
+
api_key=api_key,
|
339 |
+
timeout=timeout,
|
340 |
+
max_retries=optional_params.max_retries,
|
341 |
+
organization=organization,
|
342 |
+
create_assistant_data=create_assistant_data,
|
343 |
+
client=client,
|
344 |
+
async_create_assistants=async_create_assistants, # type: ignore
|
345 |
+
) # type: ignore
|
346 |
+
elif custom_llm_provider == "azure":
|
347 |
+
api_base = (
|
348 |
+
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
349 |
+
) # type: ignore
|
350 |
+
|
351 |
+
api_version = (
|
352 |
+
optional_params.api_version
|
353 |
+
or litellm.api_version
|
354 |
+
or get_secret("AZURE_API_VERSION")
|
355 |
+
) # type: ignore
|
356 |
+
|
357 |
+
api_key = (
|
358 |
+
optional_params.api_key
|
359 |
+
or litellm.api_key
|
360 |
+
or litellm.azure_key
|
361 |
+
or get_secret("AZURE_OPENAI_API_KEY")
|
362 |
+
or get_secret("AZURE_API_KEY")
|
363 |
+
) # type: ignore
|
364 |
+
|
365 |
+
extra_body = optional_params.get("extra_body", {})
|
366 |
+
azure_ad_token: Optional[str] = None
|
367 |
+
if extra_body is not None:
|
368 |
+
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
369 |
+
else:
|
370 |
+
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
371 |
+
|
372 |
+
if isinstance(client, OpenAI):
|
373 |
+
client = None # only pass client if it's AzureOpenAI
|
374 |
+
|
375 |
+
response = azure_assistants_api.create_assistants(
|
376 |
+
api_base=api_base,
|
377 |
+
api_key=api_key,
|
378 |
+
azure_ad_token=azure_ad_token,
|
379 |
+
api_version=api_version,
|
380 |
+
timeout=timeout,
|
381 |
+
max_retries=optional_params.max_retries,
|
382 |
+
client=client,
|
383 |
+
async_create_assistants=async_create_assistants,
|
384 |
+
create_assistant_data=create_assistant_data,
|
385 |
+
litellm_params=litellm_params_dict,
|
386 |
+
)
|
387 |
+
else:
|
388 |
+
raise litellm.exceptions.BadRequestError(
|
389 |
+
message="LiteLLM doesn't support {} for 'create_assistants'. Only 'openai' is supported.".format(
|
390 |
+
custom_llm_provider
|
391 |
+
),
|
392 |
+
model="n/a",
|
393 |
+
llm_provider=custom_llm_provider,
|
394 |
+
response=httpx.Response(
|
395 |
+
status_code=400,
|
396 |
+
content="Unsupported provider",
|
397 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
398 |
+
),
|
399 |
+
)
|
400 |
+
if response is None:
|
401 |
+
raise litellm.exceptions.InternalServerError(
|
402 |
+
message="No response returned from 'create_assistants'",
|
403 |
+
model=model,
|
404 |
+
llm_provider=custom_llm_provider,
|
405 |
+
)
|
406 |
+
return response
|
407 |
+
|
408 |
+
|
409 |
+
async def adelete_assistant(
|
410 |
+
custom_llm_provider: Literal["openai", "azure"],
|
411 |
+
client: Optional[AsyncOpenAI] = None,
|
412 |
+
**kwargs,
|
413 |
+
) -> AssistantDeleted:
|
414 |
+
loop = asyncio.get_event_loop()
|
415 |
+
### PASS ARGS TO GET ASSISTANTS ###
|
416 |
+
kwargs["async_delete_assistants"] = True
|
417 |
+
try:
|
418 |
+
kwargs["client"] = client
|
419 |
+
# Use a partial function to pass your keyword arguments
|
420 |
+
func = partial(delete_assistant, custom_llm_provider, **kwargs)
|
421 |
+
|
422 |
+
# Add the context to the function
|
423 |
+
ctx = contextvars.copy_context()
|
424 |
+
func_with_context = partial(ctx.run, func)
|
425 |
+
|
426 |
+
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
|
427 |
+
model="", custom_llm_provider=custom_llm_provider
|
428 |
+
) # type: ignore
|
429 |
+
|
430 |
+
# Await normally
|
431 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
432 |
+
if asyncio.iscoroutine(init_response):
|
433 |
+
response = await init_response
|
434 |
+
else:
|
435 |
+
response = init_response
|
436 |
+
return response # type: ignore
|
437 |
+
except Exception as e:
|
438 |
+
raise exception_type(
|
439 |
+
model="",
|
440 |
+
custom_llm_provider=custom_llm_provider,
|
441 |
+
original_exception=e,
|
442 |
+
completion_kwargs={},
|
443 |
+
extra_kwargs=kwargs,
|
444 |
+
)
|
445 |
+
|
446 |
+
|
447 |
+
def delete_assistant(
|
448 |
+
custom_llm_provider: Literal["openai", "azure"],
|
449 |
+
assistant_id: str,
|
450 |
+
client: Optional[Any] = None,
|
451 |
+
api_key: Optional[str] = None,
|
452 |
+
api_base: Optional[str] = None,
|
453 |
+
api_version: Optional[str] = None,
|
454 |
+
**kwargs,
|
455 |
+
) -> Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]:
|
456 |
+
optional_params = GenericLiteLLMParams(
|
457 |
+
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
|
458 |
+
)
|
459 |
+
|
460 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
461 |
+
|
462 |
+
async_delete_assistants: Optional[bool] = kwargs.pop(
|
463 |
+
"async_delete_assistants", None
|
464 |
+
)
|
465 |
+
if async_delete_assistants is not None and not isinstance(
|
466 |
+
async_delete_assistants, bool
|
467 |
+
):
|
468 |
+
raise ValueError(
|
469 |
+
"Invalid value passed in for async_delete_assistants. Only bool or None allowed"
|
470 |
+
)
|
471 |
+
|
472 |
+
### TIMEOUT LOGIC ###
|
473 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
474 |
+
# set timeout for 10 minutes by default
|
475 |
+
|
476 |
+
if (
|
477 |
+
timeout is not None
|
478 |
+
and isinstance(timeout, httpx.Timeout)
|
479 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
480 |
+
):
|
481 |
+
read_timeout = timeout.read or 600
|
482 |
+
timeout = read_timeout # default 10 min timeout
|
483 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
484 |
+
timeout = float(timeout) # type: ignore
|
485 |
+
elif timeout is None:
|
486 |
+
timeout = 600.0
|
487 |
+
|
488 |
+
response: Optional[
|
489 |
+
Union[AssistantDeleted, Coroutine[Any, Any, AssistantDeleted]]
|
490 |
+
] = None
|
491 |
+
if custom_llm_provider == "openai":
|
492 |
+
api_base = (
|
493 |
+
optional_params.api_base
|
494 |
+
or litellm.api_base
|
495 |
+
or os.getenv("OPENAI_BASE_URL")
|
496 |
+
or os.getenv("OPENAI_API_BASE")
|
497 |
+
or "https://api.openai.com/v1"
|
498 |
+
)
|
499 |
+
organization = (
|
500 |
+
optional_params.organization
|
501 |
+
or litellm.organization
|
502 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
503 |
+
or None
|
504 |
+
)
|
505 |
+
# set API KEY
|
506 |
+
api_key = (
|
507 |
+
optional_params.api_key
|
508 |
+
or litellm.api_key
|
509 |
+
or litellm.openai_key
|
510 |
+
or os.getenv("OPENAI_API_KEY")
|
511 |
+
)
|
512 |
+
|
513 |
+
response = openai_assistants_api.delete_assistant(
|
514 |
+
api_base=api_base,
|
515 |
+
api_key=api_key,
|
516 |
+
timeout=timeout,
|
517 |
+
max_retries=optional_params.max_retries,
|
518 |
+
organization=organization,
|
519 |
+
assistant_id=assistant_id,
|
520 |
+
client=client,
|
521 |
+
async_delete_assistants=async_delete_assistants,
|
522 |
+
)
|
523 |
+
elif custom_llm_provider == "azure":
|
524 |
+
api_base = (
|
525 |
+
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
526 |
+
) # type: ignore
|
527 |
+
|
528 |
+
api_version = (
|
529 |
+
optional_params.api_version
|
530 |
+
or litellm.api_version
|
531 |
+
or get_secret("AZURE_API_VERSION")
|
532 |
+
) # type: ignore
|
533 |
+
|
534 |
+
api_key = (
|
535 |
+
optional_params.api_key
|
536 |
+
or litellm.api_key
|
537 |
+
or litellm.azure_key
|
538 |
+
or get_secret("AZURE_OPENAI_API_KEY")
|
539 |
+
or get_secret("AZURE_API_KEY")
|
540 |
+
) # type: ignore
|
541 |
+
|
542 |
+
extra_body = optional_params.get("extra_body", {})
|
543 |
+
azure_ad_token: Optional[str] = None
|
544 |
+
if extra_body is not None:
|
545 |
+
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
546 |
+
else:
|
547 |
+
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
548 |
+
|
549 |
+
if isinstance(client, OpenAI):
|
550 |
+
client = None # only pass client if it's AzureOpenAI
|
551 |
+
|
552 |
+
response = azure_assistants_api.delete_assistant(
|
553 |
+
assistant_id=assistant_id,
|
554 |
+
api_base=api_base,
|
555 |
+
api_key=api_key,
|
556 |
+
azure_ad_token=azure_ad_token,
|
557 |
+
api_version=api_version,
|
558 |
+
timeout=timeout,
|
559 |
+
max_retries=optional_params.max_retries,
|
560 |
+
client=client,
|
561 |
+
async_delete_assistants=async_delete_assistants,
|
562 |
+
litellm_params=litellm_params_dict,
|
563 |
+
)
|
564 |
+
else:
|
565 |
+
raise litellm.exceptions.BadRequestError(
|
566 |
+
message="LiteLLM doesn't support {} for 'delete_assistant'. Only 'openai' is supported.".format(
|
567 |
+
custom_llm_provider
|
568 |
+
),
|
569 |
+
model="n/a",
|
570 |
+
llm_provider=custom_llm_provider,
|
571 |
+
response=httpx.Response(
|
572 |
+
status_code=400,
|
573 |
+
content="Unsupported provider",
|
574 |
+
request=httpx.Request(
|
575 |
+
method="delete_assistant", url="https://github.com/BerriAI/litellm"
|
576 |
+
),
|
577 |
+
),
|
578 |
+
)
|
579 |
+
if response is None:
|
580 |
+
raise litellm.exceptions.InternalServerError(
|
581 |
+
message="No response returned from 'delete_assistant'",
|
582 |
+
model="n/a",
|
583 |
+
llm_provider=custom_llm_provider,
|
584 |
+
)
|
585 |
+
return response
|
586 |
+
|
587 |
+
|
588 |
+
### THREADS ###
|
589 |
+
|
590 |
+
|
591 |
+
async def acreate_thread(
|
592 |
+
custom_llm_provider: Literal["openai", "azure"], **kwargs
|
593 |
+
) -> Thread:
|
594 |
+
loop = asyncio.get_event_loop()
|
595 |
+
### PASS ARGS TO GET ASSISTANTS ###
|
596 |
+
kwargs["acreate_thread"] = True
|
597 |
+
try:
|
598 |
+
# Use a partial function to pass your keyword arguments
|
599 |
+
func = partial(create_thread, custom_llm_provider, **kwargs)
|
600 |
+
|
601 |
+
# Add the context to the function
|
602 |
+
ctx = contextvars.copy_context()
|
603 |
+
func_with_context = partial(ctx.run, func)
|
604 |
+
|
605 |
+
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
|
606 |
+
model="", custom_llm_provider=custom_llm_provider
|
607 |
+
) # type: ignore
|
608 |
+
|
609 |
+
# Await normally
|
610 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
611 |
+
if asyncio.iscoroutine(init_response):
|
612 |
+
response = await init_response
|
613 |
+
else:
|
614 |
+
response = init_response
|
615 |
+
return response # type: ignore
|
616 |
+
except Exception as e:
|
617 |
+
raise exception_type(
|
618 |
+
model="",
|
619 |
+
custom_llm_provider=custom_llm_provider,
|
620 |
+
original_exception=e,
|
621 |
+
completion_kwargs={},
|
622 |
+
extra_kwargs=kwargs,
|
623 |
+
)
|
624 |
+
|
625 |
+
|
626 |
+
def create_thread(
|
627 |
+
custom_llm_provider: Literal["openai", "azure"],
|
628 |
+
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
|
629 |
+
metadata: Optional[dict] = None,
|
630 |
+
tool_resources: Optional[OpenAICreateThreadParamsToolResources] = None,
|
631 |
+
client: Optional[OpenAI] = None,
|
632 |
+
**kwargs,
|
633 |
+
) -> Thread:
|
634 |
+
"""
|
635 |
+
- get the llm provider
|
636 |
+
- if openai - route it there
|
637 |
+
- pass through relevant params
|
638 |
+
|
639 |
+
```
|
640 |
+
from litellm import create_thread
|
641 |
+
|
642 |
+
create_thread(
|
643 |
+
custom_llm_provider="openai",
|
644 |
+
### OPTIONAL ###
|
645 |
+
messages = {
|
646 |
+
"role": "user",
|
647 |
+
"content": "Hello, what is AI?"
|
648 |
+
},
|
649 |
+
{
|
650 |
+
"role": "user",
|
651 |
+
"content": "How does AI work? Explain it in simple terms."
|
652 |
+
}]
|
653 |
+
)
|
654 |
+
```
|
655 |
+
"""
|
656 |
+
acreate_thread = kwargs.get("acreate_thread", None)
|
657 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
658 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
659 |
+
|
660 |
+
### TIMEOUT LOGIC ###
|
661 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
662 |
+
# set timeout for 10 minutes by default
|
663 |
+
|
664 |
+
if (
|
665 |
+
timeout is not None
|
666 |
+
and isinstance(timeout, httpx.Timeout)
|
667 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
668 |
+
):
|
669 |
+
read_timeout = timeout.read or 600
|
670 |
+
timeout = read_timeout # default 10 min timeout
|
671 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
672 |
+
timeout = float(timeout) # type: ignore
|
673 |
+
elif timeout is None:
|
674 |
+
timeout = 600.0
|
675 |
+
|
676 |
+
api_base: Optional[str] = None
|
677 |
+
api_key: Optional[str] = None
|
678 |
+
|
679 |
+
response: Optional[Thread] = None
|
680 |
+
if custom_llm_provider == "openai":
|
681 |
+
api_base = (
|
682 |
+
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
683 |
+
or litellm.api_base
|
684 |
+
or os.getenv("OPENAI_BASE_URL")
|
685 |
+
or os.getenv("OPENAI_API_BASE")
|
686 |
+
or "https://api.openai.com/v1"
|
687 |
+
)
|
688 |
+
organization = (
|
689 |
+
optional_params.organization
|
690 |
+
or litellm.organization
|
691 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
692 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
693 |
+
)
|
694 |
+
# set API KEY
|
695 |
+
api_key = (
|
696 |
+
optional_params.api_key
|
697 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
698 |
+
or litellm.openai_key
|
699 |
+
or os.getenv("OPENAI_API_KEY")
|
700 |
+
)
|
701 |
+
response = openai_assistants_api.create_thread(
|
702 |
+
messages=messages,
|
703 |
+
metadata=metadata,
|
704 |
+
api_base=api_base,
|
705 |
+
api_key=api_key,
|
706 |
+
timeout=timeout,
|
707 |
+
max_retries=optional_params.max_retries,
|
708 |
+
organization=organization,
|
709 |
+
client=client,
|
710 |
+
acreate_thread=acreate_thread,
|
711 |
+
)
|
712 |
+
elif custom_llm_provider == "azure":
|
713 |
+
api_base = (
|
714 |
+
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
715 |
+
) # type: ignore
|
716 |
+
|
717 |
+
api_key = (
|
718 |
+
optional_params.api_key
|
719 |
+
or litellm.api_key
|
720 |
+
or litellm.azure_key
|
721 |
+
or get_secret("AZURE_OPENAI_API_KEY")
|
722 |
+
or get_secret("AZURE_API_KEY")
|
723 |
+
) # type: ignore
|
724 |
+
|
725 |
+
api_version: Optional[str] = (
|
726 |
+
optional_params.api_version
|
727 |
+
or litellm.api_version
|
728 |
+
or get_secret("AZURE_API_VERSION")
|
729 |
+
) # type: ignore
|
730 |
+
|
731 |
+
extra_body = optional_params.get("extra_body", {})
|
732 |
+
azure_ad_token: Optional[str] = None
|
733 |
+
if extra_body is not None:
|
734 |
+
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
735 |
+
else:
|
736 |
+
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
737 |
+
|
738 |
+
if isinstance(client, OpenAI):
|
739 |
+
client = None # only pass client if it's AzureOpenAI
|
740 |
+
|
741 |
+
response = azure_assistants_api.create_thread(
|
742 |
+
messages=messages,
|
743 |
+
metadata=metadata,
|
744 |
+
api_base=api_base,
|
745 |
+
api_key=api_key,
|
746 |
+
azure_ad_token=azure_ad_token,
|
747 |
+
api_version=api_version,
|
748 |
+
timeout=timeout,
|
749 |
+
max_retries=optional_params.max_retries,
|
750 |
+
client=client,
|
751 |
+
acreate_thread=acreate_thread,
|
752 |
+
litellm_params=litellm_params_dict,
|
753 |
+
)
|
754 |
+
else:
|
755 |
+
raise litellm.exceptions.BadRequestError(
|
756 |
+
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
|
757 |
+
custom_llm_provider
|
758 |
+
),
|
759 |
+
model="n/a",
|
760 |
+
llm_provider=custom_llm_provider,
|
761 |
+
response=httpx.Response(
|
762 |
+
status_code=400,
|
763 |
+
content="Unsupported provider",
|
764 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
765 |
+
),
|
766 |
+
)
|
767 |
+
return response # type: ignore
|
768 |
+
|
769 |
+
|
770 |
+
async def aget_thread(
|
771 |
+
custom_llm_provider: Literal["openai", "azure"],
|
772 |
+
thread_id: str,
|
773 |
+
client: Optional[AsyncOpenAI] = None,
|
774 |
+
**kwargs,
|
775 |
+
) -> Thread:
|
776 |
+
loop = asyncio.get_event_loop()
|
777 |
+
### PASS ARGS TO GET ASSISTANTS ###
|
778 |
+
kwargs["aget_thread"] = True
|
779 |
+
try:
|
780 |
+
# Use a partial function to pass your keyword arguments
|
781 |
+
func = partial(get_thread, custom_llm_provider, thread_id, client, **kwargs)
|
782 |
+
|
783 |
+
# Add the context to the function
|
784 |
+
ctx = contextvars.copy_context()
|
785 |
+
func_with_context = partial(ctx.run, func)
|
786 |
+
|
787 |
+
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
|
788 |
+
model="", custom_llm_provider=custom_llm_provider
|
789 |
+
) # type: ignore
|
790 |
+
|
791 |
+
# Await normally
|
792 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
793 |
+
if asyncio.iscoroutine(init_response):
|
794 |
+
response = await init_response
|
795 |
+
else:
|
796 |
+
response = init_response
|
797 |
+
return response # type: ignore
|
798 |
+
except Exception as e:
|
799 |
+
raise exception_type(
|
800 |
+
model="",
|
801 |
+
custom_llm_provider=custom_llm_provider,
|
802 |
+
original_exception=e,
|
803 |
+
completion_kwargs={},
|
804 |
+
extra_kwargs=kwargs,
|
805 |
+
)
|
806 |
+
|
807 |
+
|
808 |
+
def get_thread(
|
809 |
+
custom_llm_provider: Literal["openai", "azure"],
|
810 |
+
thread_id: str,
|
811 |
+
client=None,
|
812 |
+
**kwargs,
|
813 |
+
) -> Thread:
|
814 |
+
"""Get the thread object, given a thread_id"""
|
815 |
+
aget_thread = kwargs.pop("aget_thread", None)
|
816 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
817 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
818 |
+
### TIMEOUT LOGIC ###
|
819 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
820 |
+
# set timeout for 10 minutes by default
|
821 |
+
|
822 |
+
if (
|
823 |
+
timeout is not None
|
824 |
+
and isinstance(timeout, httpx.Timeout)
|
825 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
826 |
+
):
|
827 |
+
read_timeout = timeout.read or 600
|
828 |
+
timeout = read_timeout # default 10 min timeout
|
829 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
830 |
+
timeout = float(timeout) # type: ignore
|
831 |
+
elif timeout is None:
|
832 |
+
timeout = 600.0
|
833 |
+
api_base: Optional[str] = None
|
834 |
+
api_key: Optional[str] = None
|
835 |
+
response: Optional[Thread] = None
|
836 |
+
if custom_llm_provider == "openai":
|
837 |
+
api_base = (
|
838 |
+
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
839 |
+
or litellm.api_base
|
840 |
+
or os.getenv("OPENAI_BASE_URL")
|
841 |
+
or os.getenv("OPENAI_API_BASE")
|
842 |
+
or "https://api.openai.com/v1"
|
843 |
+
)
|
844 |
+
organization = (
|
845 |
+
optional_params.organization
|
846 |
+
or litellm.organization
|
847 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
848 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
849 |
+
)
|
850 |
+
# set API KEY
|
851 |
+
api_key = (
|
852 |
+
optional_params.api_key
|
853 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
854 |
+
or litellm.openai_key
|
855 |
+
or os.getenv("OPENAI_API_KEY")
|
856 |
+
)
|
857 |
+
|
858 |
+
response = openai_assistants_api.get_thread(
|
859 |
+
thread_id=thread_id,
|
860 |
+
api_base=api_base,
|
861 |
+
api_key=api_key,
|
862 |
+
timeout=timeout,
|
863 |
+
max_retries=optional_params.max_retries,
|
864 |
+
organization=organization,
|
865 |
+
client=client,
|
866 |
+
aget_thread=aget_thread,
|
867 |
+
)
|
868 |
+
elif custom_llm_provider == "azure":
|
869 |
+
api_base = (
|
870 |
+
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
871 |
+
) # type: ignore
|
872 |
+
|
873 |
+
api_version: Optional[str] = (
|
874 |
+
optional_params.api_version
|
875 |
+
or litellm.api_version
|
876 |
+
or get_secret("AZURE_API_VERSION")
|
877 |
+
) # type: ignore
|
878 |
+
|
879 |
+
api_key = (
|
880 |
+
optional_params.api_key
|
881 |
+
or litellm.api_key
|
882 |
+
or litellm.azure_key
|
883 |
+
or get_secret("AZURE_OPENAI_API_KEY")
|
884 |
+
or get_secret("AZURE_API_KEY")
|
885 |
+
) # type: ignore
|
886 |
+
|
887 |
+
extra_body = optional_params.get("extra_body", {})
|
888 |
+
azure_ad_token: Optional[str] = None
|
889 |
+
if extra_body is not None:
|
890 |
+
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
891 |
+
else:
|
892 |
+
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
893 |
+
|
894 |
+
if isinstance(client, OpenAI):
|
895 |
+
client = None # only pass client if it's AzureOpenAI
|
896 |
+
|
897 |
+
response = azure_assistants_api.get_thread(
|
898 |
+
thread_id=thread_id,
|
899 |
+
api_base=api_base,
|
900 |
+
api_key=api_key,
|
901 |
+
azure_ad_token=azure_ad_token,
|
902 |
+
api_version=api_version,
|
903 |
+
timeout=timeout,
|
904 |
+
max_retries=optional_params.max_retries,
|
905 |
+
client=client,
|
906 |
+
aget_thread=aget_thread,
|
907 |
+
litellm_params=litellm_params_dict,
|
908 |
+
)
|
909 |
+
else:
|
910 |
+
raise litellm.exceptions.BadRequestError(
|
911 |
+
message="LiteLLM doesn't support {} for 'get_thread'. Only 'openai' is supported.".format(
|
912 |
+
custom_llm_provider
|
913 |
+
),
|
914 |
+
model="n/a",
|
915 |
+
llm_provider=custom_llm_provider,
|
916 |
+
response=httpx.Response(
|
917 |
+
status_code=400,
|
918 |
+
content="Unsupported provider",
|
919 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
920 |
+
),
|
921 |
+
)
|
922 |
+
return response # type: ignore
|
923 |
+
|
924 |
+
|
925 |
+
### MESSAGES ###
|
926 |
+
|
927 |
+
|
928 |
+
async def a_add_message(
|
929 |
+
custom_llm_provider: Literal["openai", "azure"],
|
930 |
+
thread_id: str,
|
931 |
+
role: Literal["user", "assistant"],
|
932 |
+
content: str,
|
933 |
+
attachments: Optional[List[Attachment]] = None,
|
934 |
+
metadata: Optional[dict] = None,
|
935 |
+
client=None,
|
936 |
+
**kwargs,
|
937 |
+
) -> OpenAIMessage:
|
938 |
+
loop = asyncio.get_event_loop()
|
939 |
+
### PASS ARGS TO GET ASSISTANTS ###
|
940 |
+
kwargs["a_add_message"] = True
|
941 |
+
try:
|
942 |
+
# Use a partial function to pass your keyword arguments
|
943 |
+
func = partial(
|
944 |
+
add_message,
|
945 |
+
custom_llm_provider,
|
946 |
+
thread_id,
|
947 |
+
role,
|
948 |
+
content,
|
949 |
+
attachments,
|
950 |
+
metadata,
|
951 |
+
client,
|
952 |
+
**kwargs,
|
953 |
+
)
|
954 |
+
|
955 |
+
# Add the context to the function
|
956 |
+
ctx = contextvars.copy_context()
|
957 |
+
func_with_context = partial(ctx.run, func)
|
958 |
+
|
959 |
+
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
|
960 |
+
model="", custom_llm_provider=custom_llm_provider
|
961 |
+
) # type: ignore
|
962 |
+
|
963 |
+
# Await normally
|
964 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
965 |
+
if asyncio.iscoroutine(init_response):
|
966 |
+
response = await init_response
|
967 |
+
else:
|
968 |
+
# Call the synchronous function using run_in_executor
|
969 |
+
response = init_response
|
970 |
+
return response # type: ignore
|
971 |
+
except Exception as e:
|
972 |
+
raise exception_type(
|
973 |
+
model="",
|
974 |
+
custom_llm_provider=custom_llm_provider,
|
975 |
+
original_exception=e,
|
976 |
+
completion_kwargs={},
|
977 |
+
extra_kwargs=kwargs,
|
978 |
+
)
|
979 |
+
|
980 |
+
|
981 |
+
def add_message(
|
982 |
+
custom_llm_provider: Literal["openai", "azure"],
|
983 |
+
thread_id: str,
|
984 |
+
role: Literal["user", "assistant"],
|
985 |
+
content: str,
|
986 |
+
attachments: Optional[List[Attachment]] = None,
|
987 |
+
metadata: Optional[dict] = None,
|
988 |
+
client=None,
|
989 |
+
**kwargs,
|
990 |
+
) -> OpenAIMessage:
|
991 |
+
### COMMON OBJECTS ###
|
992 |
+
a_add_message = kwargs.pop("a_add_message", None)
|
993 |
+
_message_data = MessageData(
|
994 |
+
role=role, content=content, attachments=attachments, metadata=metadata
|
995 |
+
)
|
996 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
997 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
998 |
+
|
999 |
+
message_data = get_optional_params_add_message(
|
1000 |
+
role=_message_data["role"],
|
1001 |
+
content=_message_data["content"],
|
1002 |
+
attachments=_message_data["attachments"],
|
1003 |
+
metadata=_message_data["metadata"],
|
1004 |
+
custom_llm_provider=custom_llm_provider,
|
1005 |
+
)
|
1006 |
+
|
1007 |
+
### TIMEOUT LOGIC ###
|
1008 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
1009 |
+
# set timeout for 10 minutes by default
|
1010 |
+
|
1011 |
+
if (
|
1012 |
+
timeout is not None
|
1013 |
+
and isinstance(timeout, httpx.Timeout)
|
1014 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
1015 |
+
):
|
1016 |
+
read_timeout = timeout.read or 600
|
1017 |
+
timeout = read_timeout # default 10 min timeout
|
1018 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
1019 |
+
timeout = float(timeout) # type: ignore
|
1020 |
+
elif timeout is None:
|
1021 |
+
timeout = 600.0
|
1022 |
+
api_key: Optional[str] = None
|
1023 |
+
api_base: Optional[str] = None
|
1024 |
+
response: Optional[OpenAIMessage] = None
|
1025 |
+
if custom_llm_provider == "openai":
|
1026 |
+
api_base = (
|
1027 |
+
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
1028 |
+
or litellm.api_base
|
1029 |
+
or os.getenv("OPENAI_BASE_URL")
|
1030 |
+
or os.getenv("OPENAI_API_BASE")
|
1031 |
+
or "https://api.openai.com/v1"
|
1032 |
+
)
|
1033 |
+
organization = (
|
1034 |
+
optional_params.organization
|
1035 |
+
or litellm.organization
|
1036 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
1037 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
1038 |
+
)
|
1039 |
+
# set API KEY
|
1040 |
+
api_key = (
|
1041 |
+
optional_params.api_key
|
1042 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
1043 |
+
or litellm.openai_key
|
1044 |
+
or os.getenv("OPENAI_API_KEY")
|
1045 |
+
)
|
1046 |
+
response = openai_assistants_api.add_message(
|
1047 |
+
thread_id=thread_id,
|
1048 |
+
message_data=message_data,
|
1049 |
+
api_base=api_base,
|
1050 |
+
api_key=api_key,
|
1051 |
+
timeout=timeout,
|
1052 |
+
max_retries=optional_params.max_retries,
|
1053 |
+
organization=organization,
|
1054 |
+
client=client,
|
1055 |
+
a_add_message=a_add_message,
|
1056 |
+
)
|
1057 |
+
elif custom_llm_provider == "azure":
|
1058 |
+
api_base = (
|
1059 |
+
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
1060 |
+
) # type: ignore
|
1061 |
+
|
1062 |
+
api_version: Optional[str] = (
|
1063 |
+
optional_params.api_version
|
1064 |
+
or litellm.api_version
|
1065 |
+
or get_secret("AZURE_API_VERSION")
|
1066 |
+
) # type: ignore
|
1067 |
+
|
1068 |
+
api_key = (
|
1069 |
+
optional_params.api_key
|
1070 |
+
or litellm.api_key
|
1071 |
+
or litellm.azure_key
|
1072 |
+
or get_secret("AZURE_OPENAI_API_KEY")
|
1073 |
+
or get_secret("AZURE_API_KEY")
|
1074 |
+
) # type: ignore
|
1075 |
+
|
1076 |
+
extra_body = optional_params.get("extra_body", {})
|
1077 |
+
azure_ad_token: Optional[str] = None
|
1078 |
+
if extra_body is not None:
|
1079 |
+
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
1080 |
+
else:
|
1081 |
+
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
1082 |
+
|
1083 |
+
response = azure_assistants_api.add_message(
|
1084 |
+
thread_id=thread_id,
|
1085 |
+
message_data=message_data,
|
1086 |
+
api_base=api_base,
|
1087 |
+
api_key=api_key,
|
1088 |
+
api_version=api_version,
|
1089 |
+
azure_ad_token=azure_ad_token,
|
1090 |
+
timeout=timeout,
|
1091 |
+
max_retries=optional_params.max_retries,
|
1092 |
+
client=client,
|
1093 |
+
a_add_message=a_add_message,
|
1094 |
+
litellm_params=litellm_params_dict,
|
1095 |
+
)
|
1096 |
+
else:
|
1097 |
+
raise litellm.exceptions.BadRequestError(
|
1098 |
+
message="LiteLLM doesn't support {} for 'create_thread'. Only 'openai' is supported.".format(
|
1099 |
+
custom_llm_provider
|
1100 |
+
),
|
1101 |
+
model="n/a",
|
1102 |
+
llm_provider=custom_llm_provider,
|
1103 |
+
response=httpx.Response(
|
1104 |
+
status_code=400,
|
1105 |
+
content="Unsupported provider",
|
1106 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
1107 |
+
),
|
1108 |
+
)
|
1109 |
+
|
1110 |
+
return response # type: ignore
|
1111 |
+
|
1112 |
+
|
1113 |
+
async def aget_messages(
|
1114 |
+
custom_llm_provider: Literal["openai", "azure"],
|
1115 |
+
thread_id: str,
|
1116 |
+
client: Optional[AsyncOpenAI] = None,
|
1117 |
+
**kwargs,
|
1118 |
+
) -> AsyncCursorPage[OpenAIMessage]:
|
1119 |
+
loop = asyncio.get_event_loop()
|
1120 |
+
### PASS ARGS TO GET ASSISTANTS ###
|
1121 |
+
kwargs["aget_messages"] = True
|
1122 |
+
try:
|
1123 |
+
# Use a partial function to pass your keyword arguments
|
1124 |
+
func = partial(
|
1125 |
+
get_messages,
|
1126 |
+
custom_llm_provider,
|
1127 |
+
thread_id,
|
1128 |
+
client,
|
1129 |
+
**kwargs,
|
1130 |
+
)
|
1131 |
+
|
1132 |
+
# Add the context to the function
|
1133 |
+
ctx = contextvars.copy_context()
|
1134 |
+
func_with_context = partial(ctx.run, func)
|
1135 |
+
|
1136 |
+
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
|
1137 |
+
model="", custom_llm_provider=custom_llm_provider
|
1138 |
+
) # type: ignore
|
1139 |
+
|
1140 |
+
# Await normally
|
1141 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
1142 |
+
if asyncio.iscoroutine(init_response):
|
1143 |
+
response = await init_response
|
1144 |
+
else:
|
1145 |
+
# Call the synchronous function using run_in_executor
|
1146 |
+
response = init_response
|
1147 |
+
return response # type: ignore
|
1148 |
+
except Exception as e:
|
1149 |
+
raise exception_type(
|
1150 |
+
model="",
|
1151 |
+
custom_llm_provider=custom_llm_provider,
|
1152 |
+
original_exception=e,
|
1153 |
+
completion_kwargs={},
|
1154 |
+
extra_kwargs=kwargs,
|
1155 |
+
)
|
1156 |
+
|
1157 |
+
|
1158 |
+
def get_messages(
|
1159 |
+
custom_llm_provider: Literal["openai", "azure"],
|
1160 |
+
thread_id: str,
|
1161 |
+
client: Optional[Any] = None,
|
1162 |
+
**kwargs,
|
1163 |
+
) -> SyncCursorPage[OpenAIMessage]:
|
1164 |
+
aget_messages = kwargs.pop("aget_messages", None)
|
1165 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
1166 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
1167 |
+
|
1168 |
+
### TIMEOUT LOGIC ###
|
1169 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
1170 |
+
# set timeout for 10 minutes by default
|
1171 |
+
|
1172 |
+
if (
|
1173 |
+
timeout is not None
|
1174 |
+
and isinstance(timeout, httpx.Timeout)
|
1175 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
1176 |
+
):
|
1177 |
+
read_timeout = timeout.read or 600
|
1178 |
+
timeout = read_timeout # default 10 min timeout
|
1179 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
1180 |
+
timeout = float(timeout) # type: ignore
|
1181 |
+
elif timeout is None:
|
1182 |
+
timeout = 600.0
|
1183 |
+
|
1184 |
+
response: Optional[SyncCursorPage[OpenAIMessage]] = None
|
1185 |
+
api_key: Optional[str] = None
|
1186 |
+
api_base: Optional[str] = None
|
1187 |
+
if custom_llm_provider == "openai":
|
1188 |
+
api_base = (
|
1189 |
+
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
1190 |
+
or litellm.api_base
|
1191 |
+
or os.getenv("OPENAI_BASE_URL")
|
1192 |
+
or os.getenv("OPENAI_API_BASE")
|
1193 |
+
or "https://api.openai.com/v1"
|
1194 |
+
)
|
1195 |
+
organization = (
|
1196 |
+
optional_params.organization
|
1197 |
+
or litellm.organization
|
1198 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
1199 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
1200 |
+
)
|
1201 |
+
# set API KEY
|
1202 |
+
api_key = (
|
1203 |
+
optional_params.api_key
|
1204 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
1205 |
+
or litellm.openai_key
|
1206 |
+
or os.getenv("OPENAI_API_KEY")
|
1207 |
+
)
|
1208 |
+
response = openai_assistants_api.get_messages(
|
1209 |
+
thread_id=thread_id,
|
1210 |
+
api_base=api_base,
|
1211 |
+
api_key=api_key,
|
1212 |
+
timeout=timeout,
|
1213 |
+
max_retries=optional_params.max_retries,
|
1214 |
+
organization=organization,
|
1215 |
+
client=client,
|
1216 |
+
aget_messages=aget_messages,
|
1217 |
+
)
|
1218 |
+
elif custom_llm_provider == "azure":
|
1219 |
+
api_base = (
|
1220 |
+
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
1221 |
+
) # type: ignore
|
1222 |
+
|
1223 |
+
api_version: Optional[str] = (
|
1224 |
+
optional_params.api_version
|
1225 |
+
or litellm.api_version
|
1226 |
+
or get_secret("AZURE_API_VERSION")
|
1227 |
+
) # type: ignore
|
1228 |
+
|
1229 |
+
api_key = (
|
1230 |
+
optional_params.api_key
|
1231 |
+
or litellm.api_key
|
1232 |
+
or litellm.azure_key
|
1233 |
+
or get_secret("AZURE_OPENAI_API_KEY")
|
1234 |
+
or get_secret("AZURE_API_KEY")
|
1235 |
+
) # type: ignore
|
1236 |
+
|
1237 |
+
extra_body = optional_params.get("extra_body", {})
|
1238 |
+
azure_ad_token: Optional[str] = None
|
1239 |
+
if extra_body is not None:
|
1240 |
+
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
1241 |
+
else:
|
1242 |
+
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
1243 |
+
|
1244 |
+
response = azure_assistants_api.get_messages(
|
1245 |
+
thread_id=thread_id,
|
1246 |
+
api_base=api_base,
|
1247 |
+
api_key=api_key,
|
1248 |
+
api_version=api_version,
|
1249 |
+
azure_ad_token=azure_ad_token,
|
1250 |
+
timeout=timeout,
|
1251 |
+
max_retries=optional_params.max_retries,
|
1252 |
+
client=client,
|
1253 |
+
aget_messages=aget_messages,
|
1254 |
+
litellm_params=litellm_params_dict,
|
1255 |
+
)
|
1256 |
+
else:
|
1257 |
+
raise litellm.exceptions.BadRequestError(
|
1258 |
+
message="LiteLLM doesn't support {} for 'get_messages'. Only 'openai' is supported.".format(
|
1259 |
+
custom_llm_provider
|
1260 |
+
),
|
1261 |
+
model="n/a",
|
1262 |
+
llm_provider=custom_llm_provider,
|
1263 |
+
response=httpx.Response(
|
1264 |
+
status_code=400,
|
1265 |
+
content="Unsupported provider",
|
1266 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
1267 |
+
),
|
1268 |
+
)
|
1269 |
+
|
1270 |
+
return response # type: ignore
|
1271 |
+
|
1272 |
+
|
1273 |
+
### RUNS ###
|
1274 |
+
def arun_thread_stream(
|
1275 |
+
*,
|
1276 |
+
event_handler: Optional[AssistantEventHandler] = None,
|
1277 |
+
**kwargs,
|
1278 |
+
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
|
1279 |
+
kwargs["arun_thread"] = True
|
1280 |
+
return run_thread(stream=True, event_handler=event_handler, **kwargs) # type: ignore
|
1281 |
+
|
1282 |
+
|
1283 |
+
async def arun_thread(
|
1284 |
+
custom_llm_provider: Literal["openai", "azure"],
|
1285 |
+
thread_id: str,
|
1286 |
+
assistant_id: str,
|
1287 |
+
additional_instructions: Optional[str] = None,
|
1288 |
+
instructions: Optional[str] = None,
|
1289 |
+
metadata: Optional[dict] = None,
|
1290 |
+
model: Optional[str] = None,
|
1291 |
+
stream: Optional[bool] = None,
|
1292 |
+
tools: Optional[Iterable[AssistantToolParam]] = None,
|
1293 |
+
client: Optional[Any] = None,
|
1294 |
+
**kwargs,
|
1295 |
+
) -> Run:
|
1296 |
+
loop = asyncio.get_event_loop()
|
1297 |
+
### PASS ARGS TO GET ASSISTANTS ###
|
1298 |
+
kwargs["arun_thread"] = True
|
1299 |
+
try:
|
1300 |
+
# Use a partial function to pass your keyword arguments
|
1301 |
+
func = partial(
|
1302 |
+
run_thread,
|
1303 |
+
custom_llm_provider,
|
1304 |
+
thread_id,
|
1305 |
+
assistant_id,
|
1306 |
+
additional_instructions,
|
1307 |
+
instructions,
|
1308 |
+
metadata,
|
1309 |
+
model,
|
1310 |
+
stream,
|
1311 |
+
tools,
|
1312 |
+
client,
|
1313 |
+
**kwargs,
|
1314 |
+
)
|
1315 |
+
|
1316 |
+
# Add the context to the function
|
1317 |
+
ctx = contextvars.copy_context()
|
1318 |
+
func_with_context = partial(ctx.run, func)
|
1319 |
+
|
1320 |
+
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
|
1321 |
+
model="", custom_llm_provider=custom_llm_provider
|
1322 |
+
) # type: ignore
|
1323 |
+
|
1324 |
+
# Await normally
|
1325 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
1326 |
+
if asyncio.iscoroutine(init_response):
|
1327 |
+
response = await init_response
|
1328 |
+
else:
|
1329 |
+
# Call the synchronous function using run_in_executor
|
1330 |
+
response = init_response
|
1331 |
+
return response # type: ignore
|
1332 |
+
except Exception as e:
|
1333 |
+
raise exception_type(
|
1334 |
+
model="",
|
1335 |
+
custom_llm_provider=custom_llm_provider,
|
1336 |
+
original_exception=e,
|
1337 |
+
completion_kwargs={},
|
1338 |
+
extra_kwargs=kwargs,
|
1339 |
+
)
|
1340 |
+
|
1341 |
+
|
1342 |
+
def run_thread_stream(
|
1343 |
+
*,
|
1344 |
+
event_handler: Optional[AssistantEventHandler] = None,
|
1345 |
+
**kwargs,
|
1346 |
+
) -> AssistantStreamManager[AssistantEventHandler]:
|
1347 |
+
return run_thread(stream=True, event_handler=event_handler, **kwargs) # type: ignore
|
1348 |
+
|
1349 |
+
|
1350 |
+
def run_thread(
|
1351 |
+
custom_llm_provider: Literal["openai", "azure"],
|
1352 |
+
thread_id: str,
|
1353 |
+
assistant_id: str,
|
1354 |
+
additional_instructions: Optional[str] = None,
|
1355 |
+
instructions: Optional[str] = None,
|
1356 |
+
metadata: Optional[dict] = None,
|
1357 |
+
model: Optional[str] = None,
|
1358 |
+
stream: Optional[bool] = None,
|
1359 |
+
tools: Optional[Iterable[AssistantToolParam]] = None,
|
1360 |
+
client: Optional[Any] = None,
|
1361 |
+
event_handler: Optional[AssistantEventHandler] = None, # for stream=True calls
|
1362 |
+
**kwargs,
|
1363 |
+
) -> Run:
|
1364 |
+
"""Run a given thread + assistant."""
|
1365 |
+
arun_thread = kwargs.pop("arun_thread", None)
|
1366 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
1367 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
1368 |
+
|
1369 |
+
### TIMEOUT LOGIC ###
|
1370 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
1371 |
+
# set timeout for 10 minutes by default
|
1372 |
+
|
1373 |
+
if (
|
1374 |
+
timeout is not None
|
1375 |
+
and isinstance(timeout, httpx.Timeout)
|
1376 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
1377 |
+
):
|
1378 |
+
read_timeout = timeout.read or 600
|
1379 |
+
timeout = read_timeout # default 10 min timeout
|
1380 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
1381 |
+
timeout = float(timeout) # type: ignore
|
1382 |
+
elif timeout is None:
|
1383 |
+
timeout = 600.0
|
1384 |
+
|
1385 |
+
response: Optional[Run] = None
|
1386 |
+
if custom_llm_provider == "openai":
|
1387 |
+
api_base = (
|
1388 |
+
optional_params.api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
1389 |
+
or litellm.api_base
|
1390 |
+
or os.getenv("OPENAI_BASE_URL")
|
1391 |
+
or os.getenv("OPENAI_API_BASE")
|
1392 |
+
or "https://api.openai.com/v1"
|
1393 |
+
)
|
1394 |
+
organization = (
|
1395 |
+
optional_params.organization
|
1396 |
+
or litellm.organization
|
1397 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
1398 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
1399 |
+
)
|
1400 |
+
# set API KEY
|
1401 |
+
api_key = (
|
1402 |
+
optional_params.api_key
|
1403 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
1404 |
+
or litellm.openai_key
|
1405 |
+
or os.getenv("OPENAI_API_KEY")
|
1406 |
+
)
|
1407 |
+
|
1408 |
+
response = openai_assistants_api.run_thread(
|
1409 |
+
thread_id=thread_id,
|
1410 |
+
assistant_id=assistant_id,
|
1411 |
+
additional_instructions=additional_instructions,
|
1412 |
+
instructions=instructions,
|
1413 |
+
metadata=metadata,
|
1414 |
+
model=model,
|
1415 |
+
stream=stream,
|
1416 |
+
tools=tools,
|
1417 |
+
api_base=api_base,
|
1418 |
+
api_key=api_key,
|
1419 |
+
timeout=timeout,
|
1420 |
+
max_retries=optional_params.max_retries,
|
1421 |
+
organization=organization,
|
1422 |
+
client=client,
|
1423 |
+
arun_thread=arun_thread,
|
1424 |
+
event_handler=event_handler,
|
1425 |
+
)
|
1426 |
+
elif custom_llm_provider == "azure":
|
1427 |
+
api_base = (
|
1428 |
+
optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
1429 |
+
) # type: ignore
|
1430 |
+
|
1431 |
+
api_version = (
|
1432 |
+
optional_params.api_version
|
1433 |
+
or litellm.api_version
|
1434 |
+
or get_secret("AZURE_API_VERSION")
|
1435 |
+
) # type: ignore
|
1436 |
+
|
1437 |
+
api_key = (
|
1438 |
+
optional_params.api_key
|
1439 |
+
or litellm.api_key
|
1440 |
+
or litellm.azure_key
|
1441 |
+
or get_secret("AZURE_OPENAI_API_KEY")
|
1442 |
+
or get_secret("AZURE_API_KEY")
|
1443 |
+
) # type: ignore
|
1444 |
+
|
1445 |
+
extra_body = optional_params.get("extra_body", {})
|
1446 |
+
azure_ad_token = None
|
1447 |
+
if extra_body is not None:
|
1448 |
+
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
1449 |
+
else:
|
1450 |
+
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
1451 |
+
|
1452 |
+
response = azure_assistants_api.run_thread(
|
1453 |
+
thread_id=thread_id,
|
1454 |
+
assistant_id=assistant_id,
|
1455 |
+
additional_instructions=additional_instructions,
|
1456 |
+
instructions=instructions,
|
1457 |
+
metadata=metadata,
|
1458 |
+
model=model,
|
1459 |
+
stream=stream,
|
1460 |
+
tools=tools,
|
1461 |
+
api_base=str(api_base) if api_base is not None else None,
|
1462 |
+
api_key=str(api_key) if api_key is not None else None,
|
1463 |
+
api_version=str(api_version) if api_version is not None else None,
|
1464 |
+
azure_ad_token=str(azure_ad_token) if azure_ad_token is not None else None,
|
1465 |
+
timeout=timeout,
|
1466 |
+
max_retries=optional_params.max_retries,
|
1467 |
+
client=client,
|
1468 |
+
arun_thread=arun_thread,
|
1469 |
+
litellm_params=litellm_params_dict,
|
1470 |
+
) # type: ignore
|
1471 |
+
else:
|
1472 |
+
raise litellm.exceptions.BadRequestError(
|
1473 |
+
message="LiteLLM doesn't support {} for 'run_thread'. Only 'openai' is supported.".format(
|
1474 |
+
custom_llm_provider
|
1475 |
+
),
|
1476 |
+
model="n/a",
|
1477 |
+
llm_provider=custom_llm_provider,
|
1478 |
+
response=httpx.Response(
|
1479 |
+
status_code=400,
|
1480 |
+
content="Unsupported provider",
|
1481 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
1482 |
+
),
|
1483 |
+
)
|
1484 |
+
return response # type: ignore
|
litellm/assistants/utils.py
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Union
|
2 |
+
|
3 |
+
import litellm
|
4 |
+
|
5 |
+
from ..exceptions import UnsupportedParamsError
|
6 |
+
from ..types.llms.openai import *
|
7 |
+
|
8 |
+
|
9 |
+
def get_optional_params_add_message(
|
10 |
+
role: Optional[str],
|
11 |
+
content: Optional[
|
12 |
+
Union[
|
13 |
+
str,
|
14 |
+
List[
|
15 |
+
Union[
|
16 |
+
MessageContentTextObject,
|
17 |
+
MessageContentImageFileObject,
|
18 |
+
MessageContentImageURLObject,
|
19 |
+
]
|
20 |
+
],
|
21 |
+
]
|
22 |
+
],
|
23 |
+
attachments: Optional[List[Attachment]],
|
24 |
+
metadata: Optional[dict],
|
25 |
+
custom_llm_provider: str,
|
26 |
+
**kwargs,
|
27 |
+
):
|
28 |
+
"""
|
29 |
+
Azure doesn't support 'attachments' for creating a message
|
30 |
+
|
31 |
+
Reference - https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
|
32 |
+
"""
|
33 |
+
passed_params = locals()
|
34 |
+
custom_llm_provider = passed_params.pop("custom_llm_provider")
|
35 |
+
special_params = passed_params.pop("kwargs")
|
36 |
+
for k, v in special_params.items():
|
37 |
+
passed_params[k] = v
|
38 |
+
|
39 |
+
default_params = {
|
40 |
+
"role": None,
|
41 |
+
"content": None,
|
42 |
+
"attachments": None,
|
43 |
+
"metadata": None,
|
44 |
+
}
|
45 |
+
|
46 |
+
non_default_params = {
|
47 |
+
k: v
|
48 |
+
for k, v in passed_params.items()
|
49 |
+
if (k in default_params and v != default_params[k])
|
50 |
+
}
|
51 |
+
optional_params = {}
|
52 |
+
|
53 |
+
## raise exception if non-default value passed for non-openai/azure embedding calls
|
54 |
+
def _check_valid_arg(supported_params):
|
55 |
+
if len(non_default_params.keys()) > 0:
|
56 |
+
keys = list(non_default_params.keys())
|
57 |
+
for k in keys:
|
58 |
+
if (
|
59 |
+
litellm.drop_params is True and k not in supported_params
|
60 |
+
): # drop the unsupported non-default values
|
61 |
+
non_default_params.pop(k, None)
|
62 |
+
elif k not in supported_params:
|
63 |
+
raise litellm.utils.UnsupportedParamsError(
|
64 |
+
status_code=500,
|
65 |
+
message="k={}, not supported by {}. Supported params={}. To drop it from the call, set `litellm.drop_params = True`.".format(
|
66 |
+
k, custom_llm_provider, supported_params
|
67 |
+
),
|
68 |
+
)
|
69 |
+
return non_default_params
|
70 |
+
|
71 |
+
if custom_llm_provider == "openai":
|
72 |
+
optional_params = non_default_params
|
73 |
+
elif custom_llm_provider == "azure":
|
74 |
+
supported_params = (
|
75 |
+
litellm.AzureOpenAIAssistantsAPIConfig().get_supported_openai_create_message_params()
|
76 |
+
)
|
77 |
+
_check_valid_arg(supported_params=supported_params)
|
78 |
+
optional_params = litellm.AzureOpenAIAssistantsAPIConfig().map_openai_params_create_message_params(
|
79 |
+
non_default_params=non_default_params, optional_params=optional_params
|
80 |
+
)
|
81 |
+
for k in passed_params.keys():
|
82 |
+
if k not in default_params.keys():
|
83 |
+
optional_params[k] = passed_params[k]
|
84 |
+
return optional_params
|
85 |
+
|
86 |
+
|
87 |
+
def get_optional_params_image_gen(
|
88 |
+
n: Optional[int] = None,
|
89 |
+
quality: Optional[str] = None,
|
90 |
+
response_format: Optional[str] = None,
|
91 |
+
size: Optional[str] = None,
|
92 |
+
style: Optional[str] = None,
|
93 |
+
user: Optional[str] = None,
|
94 |
+
custom_llm_provider: Optional[str] = None,
|
95 |
+
**kwargs,
|
96 |
+
):
|
97 |
+
# retrieve all parameters passed to the function
|
98 |
+
passed_params = locals()
|
99 |
+
custom_llm_provider = passed_params.pop("custom_llm_provider")
|
100 |
+
special_params = passed_params.pop("kwargs")
|
101 |
+
for k, v in special_params.items():
|
102 |
+
passed_params[k] = v
|
103 |
+
|
104 |
+
default_params = {
|
105 |
+
"n": None,
|
106 |
+
"quality": None,
|
107 |
+
"response_format": None,
|
108 |
+
"size": None,
|
109 |
+
"style": None,
|
110 |
+
"user": None,
|
111 |
+
}
|
112 |
+
|
113 |
+
non_default_params = {
|
114 |
+
k: v
|
115 |
+
for k, v in passed_params.items()
|
116 |
+
if (k in default_params and v != default_params[k])
|
117 |
+
}
|
118 |
+
optional_params = {}
|
119 |
+
|
120 |
+
## raise exception if non-default value passed for non-openai/azure embedding calls
|
121 |
+
def _check_valid_arg(supported_params):
|
122 |
+
if len(non_default_params.keys()) > 0:
|
123 |
+
keys = list(non_default_params.keys())
|
124 |
+
for k in keys:
|
125 |
+
if (
|
126 |
+
litellm.drop_params is True and k not in supported_params
|
127 |
+
): # drop the unsupported non-default values
|
128 |
+
non_default_params.pop(k, None)
|
129 |
+
elif k not in supported_params:
|
130 |
+
raise UnsupportedParamsError(
|
131 |
+
status_code=500,
|
132 |
+
message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
|
133 |
+
)
|
134 |
+
return non_default_params
|
135 |
+
|
136 |
+
if (
|
137 |
+
custom_llm_provider == "openai"
|
138 |
+
or custom_llm_provider == "azure"
|
139 |
+
or custom_llm_provider in litellm.openai_compatible_providers
|
140 |
+
):
|
141 |
+
optional_params = non_default_params
|
142 |
+
elif custom_llm_provider == "bedrock":
|
143 |
+
supported_params = ["size"]
|
144 |
+
_check_valid_arg(supported_params=supported_params)
|
145 |
+
if size is not None:
|
146 |
+
width, height = size.split("x")
|
147 |
+
optional_params["width"] = int(width)
|
148 |
+
optional_params["height"] = int(height)
|
149 |
+
elif custom_llm_provider == "vertex_ai":
|
150 |
+
supported_params = ["n"]
|
151 |
+
"""
|
152 |
+
All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
|
153 |
+
"""
|
154 |
+
_check_valid_arg(supported_params=supported_params)
|
155 |
+
if n is not None:
|
156 |
+
optional_params["sampleCount"] = int(n)
|
157 |
+
|
158 |
+
for k in passed_params.keys():
|
159 |
+
if k not in default_params.keys():
|
160 |
+
optional_params[k] = passed_params[k]
|
161 |
+
return optional_params
|
litellm/batch_completion/Readme.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Implementation of `litellm.batch_completion`, `litellm.batch_completion_models`, `litellm.batch_completion_models_all_responses`
|
2 |
+
|
3 |
+
Doc: https://docs.litellm.ai/docs/completion/batching
|
4 |
+
|
5 |
+
|
6 |
+
LiteLLM Python SDK allows you to:
|
7 |
+
1. `litellm.batch_completion` Batch litellm.completion function for a given model.
|
8 |
+
2. `litellm.batch_completion_models` Send a request to multiple language models concurrently and return the response
|
9 |
+
as soon as one of the models responds.
|
10 |
+
3. `litellm.batch_completion_models_all_responses` Send a request to multiple language models concurrently and return a list of responses
|
11 |
+
from all models that respond.
|
litellm/batch_completion/main.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
2 |
+
from typing import List, Optional
|
3 |
+
|
4 |
+
import litellm
|
5 |
+
from litellm._logging import print_verbose
|
6 |
+
from litellm.utils import get_optional_params
|
7 |
+
|
8 |
+
from ..llms.vllm.completion import handler as vllm_handler
|
9 |
+
|
10 |
+
|
11 |
+
def batch_completion(
|
12 |
+
model: str,
|
13 |
+
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
|
14 |
+
messages: List = [],
|
15 |
+
functions: Optional[List] = None,
|
16 |
+
function_call: Optional[str] = None,
|
17 |
+
temperature: Optional[float] = None,
|
18 |
+
top_p: Optional[float] = None,
|
19 |
+
n: Optional[int] = None,
|
20 |
+
stream: Optional[bool] = None,
|
21 |
+
stop=None,
|
22 |
+
max_tokens: Optional[int] = None,
|
23 |
+
presence_penalty: Optional[float] = None,
|
24 |
+
frequency_penalty: Optional[float] = None,
|
25 |
+
logit_bias: Optional[dict] = None,
|
26 |
+
user: Optional[str] = None,
|
27 |
+
deployment_id=None,
|
28 |
+
request_timeout: Optional[int] = None,
|
29 |
+
timeout: Optional[int] = 600,
|
30 |
+
max_workers: Optional[int] = 100,
|
31 |
+
# Optional liteLLM function params
|
32 |
+
**kwargs,
|
33 |
+
):
|
34 |
+
"""
|
35 |
+
Batch litellm.completion function for a given model.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
model (str): The model to use for generating completions.
|
39 |
+
messages (List, optional): List of messages to use as input for generating completions. Defaults to [].
|
40 |
+
functions (List, optional): List of functions to use as input for generating completions. Defaults to [].
|
41 |
+
function_call (str, optional): The function call to use as input for generating completions. Defaults to "".
|
42 |
+
temperature (float, optional): The temperature parameter for generating completions. Defaults to None.
|
43 |
+
top_p (float, optional): The top-p parameter for generating completions. Defaults to None.
|
44 |
+
n (int, optional): The number of completions to generate. Defaults to None.
|
45 |
+
stream (bool, optional): Whether to stream completions or not. Defaults to None.
|
46 |
+
stop (optional): The stop parameter for generating completions. Defaults to None.
|
47 |
+
max_tokens (float, optional): The maximum number of tokens to generate. Defaults to None.
|
48 |
+
presence_penalty (float, optional): The presence penalty for generating completions. Defaults to None.
|
49 |
+
frequency_penalty (float, optional): The frequency penalty for generating completions. Defaults to None.
|
50 |
+
logit_bias (dict, optional): The logit bias for generating completions. Defaults to {}.
|
51 |
+
user (str, optional): The user string for generating completions. Defaults to "".
|
52 |
+
deployment_id (optional): The deployment ID for generating completions. Defaults to None.
|
53 |
+
request_timeout (int, optional): The request timeout for generating completions. Defaults to None.
|
54 |
+
max_workers (int,optional): The maximum number of threads to use for parallel processing.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
list: A list of completion results.
|
58 |
+
"""
|
59 |
+
args = locals()
|
60 |
+
|
61 |
+
batch_messages = messages
|
62 |
+
completions = []
|
63 |
+
model = model
|
64 |
+
custom_llm_provider = None
|
65 |
+
if model.split("/", 1)[0] in litellm.provider_list:
|
66 |
+
custom_llm_provider = model.split("/", 1)[0]
|
67 |
+
model = model.split("/", 1)[1]
|
68 |
+
if custom_llm_provider == "vllm":
|
69 |
+
optional_params = get_optional_params(
|
70 |
+
functions=functions,
|
71 |
+
function_call=function_call,
|
72 |
+
temperature=temperature,
|
73 |
+
top_p=top_p,
|
74 |
+
n=n,
|
75 |
+
stream=stream or False,
|
76 |
+
stop=stop,
|
77 |
+
max_tokens=max_tokens,
|
78 |
+
presence_penalty=presence_penalty,
|
79 |
+
frequency_penalty=frequency_penalty,
|
80 |
+
logit_bias=logit_bias,
|
81 |
+
user=user,
|
82 |
+
# params to identify the model
|
83 |
+
model=model,
|
84 |
+
custom_llm_provider=custom_llm_provider,
|
85 |
+
)
|
86 |
+
results = vllm_handler.batch_completions(
|
87 |
+
model=model,
|
88 |
+
messages=batch_messages,
|
89 |
+
custom_prompt_dict=litellm.custom_prompt_dict,
|
90 |
+
optional_params=optional_params,
|
91 |
+
)
|
92 |
+
# all non VLLM models for batch completion models
|
93 |
+
else:
|
94 |
+
|
95 |
+
def chunks(lst, n):
|
96 |
+
"""Yield successive n-sized chunks from lst."""
|
97 |
+
for i in range(0, len(lst), n):
|
98 |
+
yield lst[i : i + n]
|
99 |
+
|
100 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
101 |
+
for sub_batch in chunks(batch_messages, 100):
|
102 |
+
for message_list in sub_batch:
|
103 |
+
kwargs_modified = args.copy()
|
104 |
+
kwargs_modified.pop("max_workers")
|
105 |
+
kwargs_modified["messages"] = message_list
|
106 |
+
original_kwargs = {}
|
107 |
+
if "kwargs" in kwargs_modified:
|
108 |
+
original_kwargs = kwargs_modified.pop("kwargs")
|
109 |
+
future = executor.submit(
|
110 |
+
litellm.completion, **kwargs_modified, **original_kwargs
|
111 |
+
)
|
112 |
+
completions.append(future)
|
113 |
+
|
114 |
+
# Retrieve the results from the futures
|
115 |
+
# results = [future.result() for future in completions]
|
116 |
+
# return exceptions if any
|
117 |
+
results = []
|
118 |
+
for future in completions:
|
119 |
+
try:
|
120 |
+
results.append(future.result())
|
121 |
+
except Exception as exc:
|
122 |
+
results.append(exc)
|
123 |
+
|
124 |
+
return results
|
125 |
+
|
126 |
+
|
127 |
+
# send one request to multiple models
|
128 |
+
# return as soon as one of the llms responds
|
129 |
+
def batch_completion_models(*args, **kwargs):
|
130 |
+
"""
|
131 |
+
Send a request to multiple language models concurrently and return the response
|
132 |
+
as soon as one of the models responds.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
*args: Variable-length positional arguments passed to the completion function.
|
136 |
+
**kwargs: Additional keyword arguments:
|
137 |
+
- models (str or list of str): The language models to send requests to.
|
138 |
+
- Other keyword arguments to be passed to the completion function.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
str or None: The response from one of the language models, or None if no response is received.
|
142 |
+
|
143 |
+
Note:
|
144 |
+
This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
|
145 |
+
It sends requests concurrently and returns the response from the first model that responds.
|
146 |
+
"""
|
147 |
+
|
148 |
+
if "model" in kwargs:
|
149 |
+
kwargs.pop("model")
|
150 |
+
if "models" in kwargs:
|
151 |
+
models = kwargs["models"]
|
152 |
+
kwargs.pop("models")
|
153 |
+
futures = {}
|
154 |
+
with ThreadPoolExecutor(max_workers=len(models)) as executor:
|
155 |
+
for model in models:
|
156 |
+
futures[model] = executor.submit(
|
157 |
+
litellm.completion, *args, model=model, **kwargs
|
158 |
+
)
|
159 |
+
|
160 |
+
for model, future in sorted(
|
161 |
+
futures.items(), key=lambda x: models.index(x[0])
|
162 |
+
):
|
163 |
+
if future.result() is not None:
|
164 |
+
return future.result()
|
165 |
+
elif "deployments" in kwargs:
|
166 |
+
deployments = kwargs["deployments"]
|
167 |
+
kwargs.pop("deployments")
|
168 |
+
kwargs.pop("model_list")
|
169 |
+
nested_kwargs = kwargs.pop("kwargs", {})
|
170 |
+
futures = {}
|
171 |
+
with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
|
172 |
+
for deployment in deployments:
|
173 |
+
for key in kwargs.keys():
|
174 |
+
if (
|
175 |
+
key not in deployment
|
176 |
+
): # don't override deployment values e.g. model name, api base, etc.
|
177 |
+
deployment[key] = kwargs[key]
|
178 |
+
kwargs = {**deployment, **nested_kwargs}
|
179 |
+
futures[deployment["model"]] = executor.submit(
|
180 |
+
litellm.completion, **kwargs
|
181 |
+
)
|
182 |
+
|
183 |
+
while futures:
|
184 |
+
# wait for the first returned future
|
185 |
+
print_verbose("\n\n waiting for next result\n\n")
|
186 |
+
done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
|
187 |
+
print_verbose(f"done list\n{done}")
|
188 |
+
for future in done:
|
189 |
+
try:
|
190 |
+
result = future.result()
|
191 |
+
return result
|
192 |
+
except Exception:
|
193 |
+
# if model 1 fails, continue with response from model 2, model3
|
194 |
+
print_verbose(
|
195 |
+
"\n\ngot an exception, ignoring, removing from futures"
|
196 |
+
)
|
197 |
+
print_verbose(futures)
|
198 |
+
new_futures = {}
|
199 |
+
for key, value in futures.items():
|
200 |
+
if future == value:
|
201 |
+
print_verbose(f"removing key{key}")
|
202 |
+
continue
|
203 |
+
else:
|
204 |
+
new_futures[key] = value
|
205 |
+
futures = new_futures
|
206 |
+
print_verbose(f"new futures{futures}")
|
207 |
+
continue
|
208 |
+
|
209 |
+
print_verbose("\n\ndone looping through futures\n\n")
|
210 |
+
print_verbose(futures)
|
211 |
+
|
212 |
+
return None # If no response is received from any model
|
213 |
+
|
214 |
+
|
215 |
+
def batch_completion_models_all_responses(*args, **kwargs):
|
216 |
+
"""
|
217 |
+
Send a request to multiple language models concurrently and return a list of responses
|
218 |
+
from all models that respond.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
*args: Variable-length positional arguments passed to the completion function.
|
222 |
+
**kwargs: Additional keyword arguments:
|
223 |
+
- models (str or list of str): The language models to send requests to.
|
224 |
+
- Other keyword arguments to be passed to the completion function.
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
list: A list of responses from the language models that responded.
|
228 |
+
|
229 |
+
Note:
|
230 |
+
This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
|
231 |
+
It sends requests concurrently and collects responses from all models that respond.
|
232 |
+
"""
|
233 |
+
import concurrent.futures
|
234 |
+
|
235 |
+
# ANSI escape codes for colored output
|
236 |
+
|
237 |
+
if "model" in kwargs:
|
238 |
+
kwargs.pop("model")
|
239 |
+
if "models" in kwargs:
|
240 |
+
models = kwargs["models"]
|
241 |
+
kwargs.pop("models")
|
242 |
+
else:
|
243 |
+
raise Exception("'models' param not in kwargs")
|
244 |
+
|
245 |
+
responses = []
|
246 |
+
|
247 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
|
248 |
+
for idx, model in enumerate(models):
|
249 |
+
future = executor.submit(litellm.completion, *args, model=model, **kwargs)
|
250 |
+
if future.result() is not None:
|
251 |
+
responses.append(future.result())
|
252 |
+
|
253 |
+
return responses
|
litellm/batches/batch_utils.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Any, List, Literal, Tuple
|
3 |
+
|
4 |
+
import litellm
|
5 |
+
from litellm._logging import verbose_logger
|
6 |
+
from litellm.types.llms.openai import Batch
|
7 |
+
from litellm.types.utils import CallTypes, Usage
|
8 |
+
|
9 |
+
|
10 |
+
async def _handle_completed_batch(
|
11 |
+
batch: Batch,
|
12 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
|
13 |
+
) -> Tuple[float, Usage, List[str]]:
|
14 |
+
"""Helper function to process a completed batch and handle logging"""
|
15 |
+
# Get batch results
|
16 |
+
file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
|
17 |
+
batch, custom_llm_provider
|
18 |
+
)
|
19 |
+
|
20 |
+
# Calculate costs and usage
|
21 |
+
batch_cost = await _batch_cost_calculator(
|
22 |
+
custom_llm_provider=custom_llm_provider,
|
23 |
+
file_content_dictionary=file_content_dictionary,
|
24 |
+
)
|
25 |
+
batch_usage = _get_batch_job_total_usage_from_file_content(
|
26 |
+
file_content_dictionary=file_content_dictionary,
|
27 |
+
custom_llm_provider=custom_llm_provider,
|
28 |
+
)
|
29 |
+
|
30 |
+
batch_models = _get_batch_models_from_file_content(file_content_dictionary)
|
31 |
+
|
32 |
+
return batch_cost, batch_usage, batch_models
|
33 |
+
|
34 |
+
|
35 |
+
def _get_batch_models_from_file_content(
|
36 |
+
file_content_dictionary: List[dict],
|
37 |
+
) -> List[str]:
|
38 |
+
"""
|
39 |
+
Get the models from the file content
|
40 |
+
"""
|
41 |
+
batch_models = []
|
42 |
+
for _item in file_content_dictionary:
|
43 |
+
if _batch_response_was_successful(_item):
|
44 |
+
_response_body = _get_response_from_batch_job_output_file(_item)
|
45 |
+
_model = _response_body.get("model")
|
46 |
+
if _model:
|
47 |
+
batch_models.append(_model)
|
48 |
+
return batch_models
|
49 |
+
|
50 |
+
|
51 |
+
async def _batch_cost_calculator(
|
52 |
+
file_content_dictionary: List[dict],
|
53 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
54 |
+
) -> float:
|
55 |
+
"""
|
56 |
+
Calculate the cost of a batch based on the output file id
|
57 |
+
"""
|
58 |
+
if custom_llm_provider == "vertex_ai":
|
59 |
+
raise ValueError("Vertex AI does not support file content retrieval")
|
60 |
+
total_cost = _get_batch_job_cost_from_file_content(
|
61 |
+
file_content_dictionary=file_content_dictionary,
|
62 |
+
custom_llm_provider=custom_llm_provider,
|
63 |
+
)
|
64 |
+
verbose_logger.debug("total_cost=%s", total_cost)
|
65 |
+
return total_cost
|
66 |
+
|
67 |
+
|
68 |
+
async def _get_batch_output_file_content_as_dictionary(
|
69 |
+
batch: Batch,
|
70 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
71 |
+
) -> List[dict]:
|
72 |
+
"""
|
73 |
+
Get the batch output file content as a list of dictionaries
|
74 |
+
"""
|
75 |
+
from litellm.files.main import afile_content
|
76 |
+
|
77 |
+
if custom_llm_provider == "vertex_ai":
|
78 |
+
raise ValueError("Vertex AI does not support file content retrieval")
|
79 |
+
|
80 |
+
if batch.output_file_id is None:
|
81 |
+
raise ValueError("Output file id is None cannot retrieve file content")
|
82 |
+
|
83 |
+
_file_content = await afile_content(
|
84 |
+
file_id=batch.output_file_id,
|
85 |
+
custom_llm_provider=custom_llm_provider,
|
86 |
+
)
|
87 |
+
return _get_file_content_as_dictionary(_file_content.content)
|
88 |
+
|
89 |
+
|
90 |
+
def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
|
91 |
+
"""
|
92 |
+
Get the file content as a list of dictionaries from JSON Lines format
|
93 |
+
"""
|
94 |
+
try:
|
95 |
+
_file_content_str = file_content.decode("utf-8")
|
96 |
+
# Split by newlines and parse each line as a separate JSON object
|
97 |
+
json_objects = []
|
98 |
+
for line in _file_content_str.strip().split("\n"):
|
99 |
+
if line: # Skip empty lines
|
100 |
+
json_objects.append(json.loads(line))
|
101 |
+
verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
|
102 |
+
return json_objects
|
103 |
+
except Exception as e:
|
104 |
+
raise e
|
105 |
+
|
106 |
+
|
107 |
+
def _get_batch_job_cost_from_file_content(
|
108 |
+
file_content_dictionary: List[dict],
|
109 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
110 |
+
) -> float:
|
111 |
+
"""
|
112 |
+
Get the cost of a batch job from the file content
|
113 |
+
"""
|
114 |
+
try:
|
115 |
+
total_cost: float = 0.0
|
116 |
+
# parse the file content as json
|
117 |
+
verbose_logger.debug(
|
118 |
+
"file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
|
119 |
+
)
|
120 |
+
for _item in file_content_dictionary:
|
121 |
+
if _batch_response_was_successful(_item):
|
122 |
+
_response_body = _get_response_from_batch_job_output_file(_item)
|
123 |
+
total_cost += litellm.completion_cost(
|
124 |
+
completion_response=_response_body,
|
125 |
+
custom_llm_provider=custom_llm_provider,
|
126 |
+
call_type=CallTypes.aretrieve_batch.value,
|
127 |
+
)
|
128 |
+
verbose_logger.debug("total_cost=%s", total_cost)
|
129 |
+
return total_cost
|
130 |
+
except Exception as e:
|
131 |
+
verbose_logger.error("error in _get_batch_job_cost_from_file_content", e)
|
132 |
+
raise e
|
133 |
+
|
134 |
+
|
135 |
+
def _get_batch_job_total_usage_from_file_content(
|
136 |
+
file_content_dictionary: List[dict],
|
137 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
138 |
+
) -> Usage:
|
139 |
+
"""
|
140 |
+
Get the tokens of a batch job from the file content
|
141 |
+
"""
|
142 |
+
total_tokens: int = 0
|
143 |
+
prompt_tokens: int = 0
|
144 |
+
completion_tokens: int = 0
|
145 |
+
for _item in file_content_dictionary:
|
146 |
+
if _batch_response_was_successful(_item):
|
147 |
+
_response_body = _get_response_from_batch_job_output_file(_item)
|
148 |
+
usage: Usage = _get_batch_job_usage_from_response_body(_response_body)
|
149 |
+
total_tokens += usage.total_tokens
|
150 |
+
prompt_tokens += usage.prompt_tokens
|
151 |
+
completion_tokens += usage.completion_tokens
|
152 |
+
return Usage(
|
153 |
+
total_tokens=total_tokens,
|
154 |
+
prompt_tokens=prompt_tokens,
|
155 |
+
completion_tokens=completion_tokens,
|
156 |
+
)
|
157 |
+
|
158 |
+
|
159 |
+
def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
|
160 |
+
"""
|
161 |
+
Get the tokens of a batch job from the response body
|
162 |
+
"""
|
163 |
+
_usage_dict = response_body.get("usage", None) or {}
|
164 |
+
usage: Usage = Usage(**_usage_dict)
|
165 |
+
return usage
|
166 |
+
|
167 |
+
|
168 |
+
def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
|
169 |
+
"""
|
170 |
+
Get the response from the batch job output file
|
171 |
+
"""
|
172 |
+
_response: dict = batch_job_output_file.get("response", None) or {}
|
173 |
+
_response_body = _response.get("body", None) or {}
|
174 |
+
return _response_body
|
175 |
+
|
176 |
+
|
177 |
+
def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
|
178 |
+
"""
|
179 |
+
Check if the batch job response status == 200
|
180 |
+
"""
|
181 |
+
_response: dict = batch_job_output_file.get("response", None) or {}
|
182 |
+
return _response.get("status_code", None) == 200
|
litellm/batches/main.py
ADDED
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Main File for Batches API implementation
|
3 |
+
|
4 |
+
https://platform.openai.com/docs/api-reference/batch
|
5 |
+
|
6 |
+
- create_batch()
|
7 |
+
- retrieve_batch()
|
8 |
+
- cancel_batch()
|
9 |
+
- list_batch()
|
10 |
+
|
11 |
+
"""
|
12 |
+
|
13 |
+
import asyncio
|
14 |
+
import contextvars
|
15 |
+
import os
|
16 |
+
from functools import partial
|
17 |
+
from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
18 |
+
|
19 |
+
import httpx
|
20 |
+
|
21 |
+
import litellm
|
22 |
+
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
23 |
+
from litellm.llms.azure.batches.handler import AzureBatchesAPI
|
24 |
+
from litellm.llms.openai.openai import OpenAIBatchesAPI
|
25 |
+
from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
|
26 |
+
from litellm.secret_managers.main import get_secret_str
|
27 |
+
from litellm.types.llms.openai import (
|
28 |
+
Batch,
|
29 |
+
CancelBatchRequest,
|
30 |
+
CreateBatchRequest,
|
31 |
+
RetrieveBatchRequest,
|
32 |
+
)
|
33 |
+
from litellm.types.router import GenericLiteLLMParams
|
34 |
+
from litellm.types.utils import LiteLLMBatch
|
35 |
+
from litellm.utils import client, get_litellm_params, supports_httpx_timeout
|
36 |
+
|
37 |
+
####### ENVIRONMENT VARIABLES ###################
|
38 |
+
openai_batches_instance = OpenAIBatchesAPI()
|
39 |
+
azure_batches_instance = AzureBatchesAPI()
|
40 |
+
vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="")
|
41 |
+
#################################################
|
42 |
+
|
43 |
+
|
44 |
+
@client
|
45 |
+
async def acreate_batch(
|
46 |
+
completion_window: Literal["24h"],
|
47 |
+
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
|
48 |
+
input_file_id: str,
|
49 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
50 |
+
metadata: Optional[Dict[str, str]] = None,
|
51 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
52 |
+
extra_body: Optional[Dict[str, str]] = None,
|
53 |
+
**kwargs,
|
54 |
+
) -> Batch:
|
55 |
+
"""
|
56 |
+
Async: Creates and executes a batch from an uploaded file of request
|
57 |
+
|
58 |
+
LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
|
59 |
+
"""
|
60 |
+
try:
|
61 |
+
loop = asyncio.get_event_loop()
|
62 |
+
kwargs["acreate_batch"] = True
|
63 |
+
|
64 |
+
# Use a partial function to pass your keyword arguments
|
65 |
+
func = partial(
|
66 |
+
create_batch,
|
67 |
+
completion_window,
|
68 |
+
endpoint,
|
69 |
+
input_file_id,
|
70 |
+
custom_llm_provider,
|
71 |
+
metadata,
|
72 |
+
extra_headers,
|
73 |
+
extra_body,
|
74 |
+
**kwargs,
|
75 |
+
)
|
76 |
+
|
77 |
+
# Add the context to the function
|
78 |
+
ctx = contextvars.copy_context()
|
79 |
+
func_with_context = partial(ctx.run, func)
|
80 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
81 |
+
|
82 |
+
if asyncio.iscoroutine(init_response):
|
83 |
+
response = await init_response
|
84 |
+
else:
|
85 |
+
response = init_response
|
86 |
+
|
87 |
+
return response
|
88 |
+
except Exception as e:
|
89 |
+
raise e
|
90 |
+
|
91 |
+
|
92 |
+
@client
|
93 |
+
def create_batch(
|
94 |
+
completion_window: Literal["24h"],
|
95 |
+
endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
|
96 |
+
input_file_id: str,
|
97 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
98 |
+
metadata: Optional[Dict[str, str]] = None,
|
99 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
100 |
+
extra_body: Optional[Dict[str, str]] = None,
|
101 |
+
**kwargs,
|
102 |
+
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
103 |
+
"""
|
104 |
+
Creates and executes a batch from an uploaded file of request
|
105 |
+
|
106 |
+
LiteLLM Equivalent of POST: https://api.openai.com/v1/batches
|
107 |
+
"""
|
108 |
+
try:
|
109 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
110 |
+
litellm_call_id = kwargs.get("litellm_call_id", None)
|
111 |
+
proxy_server_request = kwargs.get("proxy_server_request", None)
|
112 |
+
model_info = kwargs.get("model_info", None)
|
113 |
+
_is_async = kwargs.pop("acreate_batch", False) is True
|
114 |
+
litellm_params = get_litellm_params(**kwargs)
|
115 |
+
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
|
116 |
+
### TIMEOUT LOGIC ###
|
117 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
118 |
+
litellm_logging_obj.update_environment_variables(
|
119 |
+
model=None,
|
120 |
+
user=None,
|
121 |
+
optional_params=optional_params.model_dump(),
|
122 |
+
litellm_params={
|
123 |
+
"litellm_call_id": litellm_call_id,
|
124 |
+
"proxy_server_request": proxy_server_request,
|
125 |
+
"model_info": model_info,
|
126 |
+
"metadata": metadata,
|
127 |
+
"preset_cache_key": None,
|
128 |
+
"stream_response": {},
|
129 |
+
**optional_params.model_dump(exclude_unset=True),
|
130 |
+
},
|
131 |
+
custom_llm_provider=custom_llm_provider,
|
132 |
+
)
|
133 |
+
|
134 |
+
if (
|
135 |
+
timeout is not None
|
136 |
+
and isinstance(timeout, httpx.Timeout)
|
137 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
138 |
+
):
|
139 |
+
read_timeout = timeout.read or 600
|
140 |
+
timeout = read_timeout # default 10 min timeout
|
141 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
142 |
+
timeout = float(timeout) # type: ignore
|
143 |
+
elif timeout is None:
|
144 |
+
timeout = 600.0
|
145 |
+
|
146 |
+
_create_batch_request = CreateBatchRequest(
|
147 |
+
completion_window=completion_window,
|
148 |
+
endpoint=endpoint,
|
149 |
+
input_file_id=input_file_id,
|
150 |
+
metadata=metadata,
|
151 |
+
extra_headers=extra_headers,
|
152 |
+
extra_body=extra_body,
|
153 |
+
)
|
154 |
+
api_base: Optional[str] = None
|
155 |
+
if custom_llm_provider == "openai":
|
156 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
157 |
+
api_base = (
|
158 |
+
optional_params.api_base
|
159 |
+
or litellm.api_base
|
160 |
+
or os.getenv("OPENAI_BASE_URL")
|
161 |
+
or os.getenv("OPENAI_API_BASE")
|
162 |
+
or "https://api.openai.com/v1"
|
163 |
+
)
|
164 |
+
organization = (
|
165 |
+
optional_params.organization
|
166 |
+
or litellm.organization
|
167 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
168 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
169 |
+
)
|
170 |
+
# set API KEY
|
171 |
+
api_key = (
|
172 |
+
optional_params.api_key
|
173 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
174 |
+
or litellm.openai_key
|
175 |
+
or os.getenv("OPENAI_API_KEY")
|
176 |
+
)
|
177 |
+
|
178 |
+
response = openai_batches_instance.create_batch(
|
179 |
+
api_base=api_base,
|
180 |
+
api_key=api_key,
|
181 |
+
organization=organization,
|
182 |
+
create_batch_data=_create_batch_request,
|
183 |
+
timeout=timeout,
|
184 |
+
max_retries=optional_params.max_retries,
|
185 |
+
_is_async=_is_async,
|
186 |
+
)
|
187 |
+
elif custom_llm_provider == "azure":
|
188 |
+
api_base = (
|
189 |
+
optional_params.api_base
|
190 |
+
or litellm.api_base
|
191 |
+
or get_secret_str("AZURE_API_BASE")
|
192 |
+
)
|
193 |
+
api_version = (
|
194 |
+
optional_params.api_version
|
195 |
+
or litellm.api_version
|
196 |
+
or get_secret_str("AZURE_API_VERSION")
|
197 |
+
)
|
198 |
+
|
199 |
+
api_key = (
|
200 |
+
optional_params.api_key
|
201 |
+
or litellm.api_key
|
202 |
+
or litellm.azure_key
|
203 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
204 |
+
or get_secret_str("AZURE_API_KEY")
|
205 |
+
)
|
206 |
+
|
207 |
+
extra_body = optional_params.get("extra_body", {})
|
208 |
+
if extra_body is not None:
|
209 |
+
extra_body.pop("azure_ad_token", None)
|
210 |
+
else:
|
211 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
212 |
+
|
213 |
+
response = azure_batches_instance.create_batch(
|
214 |
+
_is_async=_is_async,
|
215 |
+
api_base=api_base,
|
216 |
+
api_key=api_key,
|
217 |
+
api_version=api_version,
|
218 |
+
timeout=timeout,
|
219 |
+
max_retries=optional_params.max_retries,
|
220 |
+
create_batch_data=_create_batch_request,
|
221 |
+
litellm_params=litellm_params,
|
222 |
+
)
|
223 |
+
elif custom_llm_provider == "vertex_ai":
|
224 |
+
api_base = optional_params.api_base or ""
|
225 |
+
vertex_ai_project = (
|
226 |
+
optional_params.vertex_project
|
227 |
+
or litellm.vertex_project
|
228 |
+
or get_secret_str("VERTEXAI_PROJECT")
|
229 |
+
)
|
230 |
+
vertex_ai_location = (
|
231 |
+
optional_params.vertex_location
|
232 |
+
or litellm.vertex_location
|
233 |
+
or get_secret_str("VERTEXAI_LOCATION")
|
234 |
+
)
|
235 |
+
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
236 |
+
"VERTEXAI_CREDENTIALS"
|
237 |
+
)
|
238 |
+
|
239 |
+
response = vertex_ai_batches_instance.create_batch(
|
240 |
+
_is_async=_is_async,
|
241 |
+
api_base=api_base,
|
242 |
+
vertex_project=vertex_ai_project,
|
243 |
+
vertex_location=vertex_ai_location,
|
244 |
+
vertex_credentials=vertex_credentials,
|
245 |
+
timeout=timeout,
|
246 |
+
max_retries=optional_params.max_retries,
|
247 |
+
create_batch_data=_create_batch_request,
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
raise litellm.exceptions.BadRequestError(
|
251 |
+
message="LiteLLM doesn't support custom_llm_provider={} for 'create_batch'".format(
|
252 |
+
custom_llm_provider
|
253 |
+
),
|
254 |
+
model="n/a",
|
255 |
+
llm_provider=custom_llm_provider,
|
256 |
+
response=httpx.Response(
|
257 |
+
status_code=400,
|
258 |
+
content="Unsupported provider",
|
259 |
+
request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
|
260 |
+
),
|
261 |
+
)
|
262 |
+
return response
|
263 |
+
except Exception as e:
|
264 |
+
raise e
|
265 |
+
|
266 |
+
|
267 |
+
@client
|
268 |
+
async def aretrieve_batch(
|
269 |
+
batch_id: str,
|
270 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
271 |
+
metadata: Optional[Dict[str, str]] = None,
|
272 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
273 |
+
extra_body: Optional[Dict[str, str]] = None,
|
274 |
+
**kwargs,
|
275 |
+
) -> LiteLLMBatch:
|
276 |
+
"""
|
277 |
+
Async: Retrieves a batch.
|
278 |
+
|
279 |
+
LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id}
|
280 |
+
"""
|
281 |
+
try:
|
282 |
+
loop = asyncio.get_event_loop()
|
283 |
+
kwargs["aretrieve_batch"] = True
|
284 |
+
|
285 |
+
# Use a partial function to pass your keyword arguments
|
286 |
+
func = partial(
|
287 |
+
retrieve_batch,
|
288 |
+
batch_id,
|
289 |
+
custom_llm_provider,
|
290 |
+
metadata,
|
291 |
+
extra_headers,
|
292 |
+
extra_body,
|
293 |
+
**kwargs,
|
294 |
+
)
|
295 |
+
# Add the context to the function
|
296 |
+
ctx = contextvars.copy_context()
|
297 |
+
func_with_context = partial(ctx.run, func)
|
298 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
299 |
+
if asyncio.iscoroutine(init_response):
|
300 |
+
response = await init_response
|
301 |
+
else:
|
302 |
+
response = init_response # type: ignore
|
303 |
+
|
304 |
+
return response
|
305 |
+
except Exception as e:
|
306 |
+
raise e
|
307 |
+
|
308 |
+
|
309 |
+
@client
|
310 |
+
def retrieve_batch(
|
311 |
+
batch_id: str,
|
312 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
313 |
+
metadata: Optional[Dict[str, str]] = None,
|
314 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
315 |
+
extra_body: Optional[Dict[str, str]] = None,
|
316 |
+
**kwargs,
|
317 |
+
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
|
318 |
+
"""
|
319 |
+
Retrieves a batch.
|
320 |
+
|
321 |
+
LiteLLM Equivalent of GET https://api.openai.com/v1/batches/{batch_id}
|
322 |
+
"""
|
323 |
+
try:
|
324 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
325 |
+
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
|
326 |
+
### TIMEOUT LOGIC ###
|
327 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
328 |
+
litellm_params = get_litellm_params(
|
329 |
+
custom_llm_provider=custom_llm_provider,
|
330 |
+
**kwargs,
|
331 |
+
)
|
332 |
+
litellm_logging_obj.update_environment_variables(
|
333 |
+
model=None,
|
334 |
+
user=None,
|
335 |
+
optional_params=optional_params.model_dump(),
|
336 |
+
litellm_params=litellm_params,
|
337 |
+
custom_llm_provider=custom_llm_provider,
|
338 |
+
)
|
339 |
+
|
340 |
+
if (
|
341 |
+
timeout is not None
|
342 |
+
and isinstance(timeout, httpx.Timeout)
|
343 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
344 |
+
):
|
345 |
+
read_timeout = timeout.read or 600
|
346 |
+
timeout = read_timeout # default 10 min timeout
|
347 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
348 |
+
timeout = float(timeout) # type: ignore
|
349 |
+
elif timeout is None:
|
350 |
+
timeout = 600.0
|
351 |
+
|
352 |
+
_retrieve_batch_request = RetrieveBatchRequest(
|
353 |
+
batch_id=batch_id,
|
354 |
+
extra_headers=extra_headers,
|
355 |
+
extra_body=extra_body,
|
356 |
+
)
|
357 |
+
|
358 |
+
_is_async = kwargs.pop("aretrieve_batch", False) is True
|
359 |
+
api_base: Optional[str] = None
|
360 |
+
if custom_llm_provider == "openai":
|
361 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
362 |
+
api_base = (
|
363 |
+
optional_params.api_base
|
364 |
+
or litellm.api_base
|
365 |
+
or os.getenv("OPENAI_BASE_URL")
|
366 |
+
or os.getenv("OPENAI_API_BASE")
|
367 |
+
or "https://api.openai.com/v1"
|
368 |
+
)
|
369 |
+
organization = (
|
370 |
+
optional_params.organization
|
371 |
+
or litellm.organization
|
372 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
373 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
374 |
+
)
|
375 |
+
# set API KEY
|
376 |
+
api_key = (
|
377 |
+
optional_params.api_key
|
378 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
379 |
+
or litellm.openai_key
|
380 |
+
or os.getenv("OPENAI_API_KEY")
|
381 |
+
)
|
382 |
+
|
383 |
+
response = openai_batches_instance.retrieve_batch(
|
384 |
+
_is_async=_is_async,
|
385 |
+
retrieve_batch_data=_retrieve_batch_request,
|
386 |
+
api_base=api_base,
|
387 |
+
api_key=api_key,
|
388 |
+
organization=organization,
|
389 |
+
timeout=timeout,
|
390 |
+
max_retries=optional_params.max_retries,
|
391 |
+
)
|
392 |
+
elif custom_llm_provider == "azure":
|
393 |
+
api_base = (
|
394 |
+
optional_params.api_base
|
395 |
+
or litellm.api_base
|
396 |
+
or get_secret_str("AZURE_API_BASE")
|
397 |
+
)
|
398 |
+
api_version = (
|
399 |
+
optional_params.api_version
|
400 |
+
or litellm.api_version
|
401 |
+
or get_secret_str("AZURE_API_VERSION")
|
402 |
+
)
|
403 |
+
|
404 |
+
api_key = (
|
405 |
+
optional_params.api_key
|
406 |
+
or litellm.api_key
|
407 |
+
or litellm.azure_key
|
408 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
409 |
+
or get_secret_str("AZURE_API_KEY")
|
410 |
+
)
|
411 |
+
|
412 |
+
extra_body = optional_params.get("extra_body", {})
|
413 |
+
if extra_body is not None:
|
414 |
+
extra_body.pop("azure_ad_token", None)
|
415 |
+
else:
|
416 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
417 |
+
|
418 |
+
response = azure_batches_instance.retrieve_batch(
|
419 |
+
_is_async=_is_async,
|
420 |
+
api_base=api_base,
|
421 |
+
api_key=api_key,
|
422 |
+
api_version=api_version,
|
423 |
+
timeout=timeout,
|
424 |
+
max_retries=optional_params.max_retries,
|
425 |
+
retrieve_batch_data=_retrieve_batch_request,
|
426 |
+
litellm_params=litellm_params,
|
427 |
+
)
|
428 |
+
elif custom_llm_provider == "vertex_ai":
|
429 |
+
api_base = optional_params.api_base or ""
|
430 |
+
vertex_ai_project = (
|
431 |
+
optional_params.vertex_project
|
432 |
+
or litellm.vertex_project
|
433 |
+
or get_secret_str("VERTEXAI_PROJECT")
|
434 |
+
)
|
435 |
+
vertex_ai_location = (
|
436 |
+
optional_params.vertex_location
|
437 |
+
or litellm.vertex_location
|
438 |
+
or get_secret_str("VERTEXAI_LOCATION")
|
439 |
+
)
|
440 |
+
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
441 |
+
"VERTEXAI_CREDENTIALS"
|
442 |
+
)
|
443 |
+
|
444 |
+
response = vertex_ai_batches_instance.retrieve_batch(
|
445 |
+
_is_async=_is_async,
|
446 |
+
batch_id=batch_id,
|
447 |
+
api_base=api_base,
|
448 |
+
vertex_project=vertex_ai_project,
|
449 |
+
vertex_location=vertex_ai_location,
|
450 |
+
vertex_credentials=vertex_credentials,
|
451 |
+
timeout=timeout,
|
452 |
+
max_retries=optional_params.max_retries,
|
453 |
+
)
|
454 |
+
else:
|
455 |
+
raise litellm.exceptions.BadRequestError(
|
456 |
+
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
457 |
+
custom_llm_provider
|
458 |
+
),
|
459 |
+
model="n/a",
|
460 |
+
llm_provider=custom_llm_provider,
|
461 |
+
response=httpx.Response(
|
462 |
+
status_code=400,
|
463 |
+
content="Unsupported provider",
|
464 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
465 |
+
),
|
466 |
+
)
|
467 |
+
return response
|
468 |
+
except Exception as e:
|
469 |
+
raise e
|
470 |
+
|
471 |
+
|
472 |
+
async def alist_batches(
|
473 |
+
after: Optional[str] = None,
|
474 |
+
limit: Optional[int] = None,
|
475 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
476 |
+
metadata: Optional[Dict[str, str]] = None,
|
477 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
478 |
+
extra_body: Optional[Dict[str, str]] = None,
|
479 |
+
**kwargs,
|
480 |
+
):
|
481 |
+
"""
|
482 |
+
Async: List your organization's batches.
|
483 |
+
"""
|
484 |
+
try:
|
485 |
+
loop = asyncio.get_event_loop()
|
486 |
+
kwargs["alist_batches"] = True
|
487 |
+
|
488 |
+
# Use a partial function to pass your keyword arguments
|
489 |
+
func = partial(
|
490 |
+
list_batches,
|
491 |
+
after,
|
492 |
+
limit,
|
493 |
+
custom_llm_provider,
|
494 |
+
extra_headers,
|
495 |
+
extra_body,
|
496 |
+
**kwargs,
|
497 |
+
)
|
498 |
+
|
499 |
+
# Add the context to the function
|
500 |
+
ctx = contextvars.copy_context()
|
501 |
+
func_with_context = partial(ctx.run, func)
|
502 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
503 |
+
if asyncio.iscoroutine(init_response):
|
504 |
+
response = await init_response
|
505 |
+
else:
|
506 |
+
response = init_response # type: ignore
|
507 |
+
|
508 |
+
return response
|
509 |
+
except Exception as e:
|
510 |
+
raise e
|
511 |
+
|
512 |
+
|
513 |
+
def list_batches(
|
514 |
+
after: Optional[str] = None,
|
515 |
+
limit: Optional[int] = None,
|
516 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
517 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
518 |
+
extra_body: Optional[Dict[str, str]] = None,
|
519 |
+
**kwargs,
|
520 |
+
):
|
521 |
+
"""
|
522 |
+
Lists batches
|
523 |
+
|
524 |
+
List your organization's batches.
|
525 |
+
"""
|
526 |
+
try:
|
527 |
+
# set API KEY
|
528 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
529 |
+
litellm_params = get_litellm_params(
|
530 |
+
custom_llm_provider=custom_llm_provider,
|
531 |
+
**kwargs,
|
532 |
+
)
|
533 |
+
api_key = (
|
534 |
+
optional_params.api_key
|
535 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
536 |
+
or litellm.openai_key
|
537 |
+
or os.getenv("OPENAI_API_KEY")
|
538 |
+
)
|
539 |
+
### TIMEOUT LOGIC ###
|
540 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
541 |
+
# set timeout for 10 minutes by default
|
542 |
+
|
543 |
+
if (
|
544 |
+
timeout is not None
|
545 |
+
and isinstance(timeout, httpx.Timeout)
|
546 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
547 |
+
):
|
548 |
+
read_timeout = timeout.read or 600
|
549 |
+
timeout = read_timeout # default 10 min timeout
|
550 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
551 |
+
timeout = float(timeout) # type: ignore
|
552 |
+
elif timeout is None:
|
553 |
+
timeout = 600.0
|
554 |
+
|
555 |
+
_is_async = kwargs.pop("alist_batches", False) is True
|
556 |
+
if custom_llm_provider == "openai":
|
557 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
558 |
+
api_base = (
|
559 |
+
optional_params.api_base
|
560 |
+
or litellm.api_base
|
561 |
+
or os.getenv("OPENAI_BASE_URL")
|
562 |
+
or os.getenv("OPENAI_API_BASE")
|
563 |
+
or "https://api.openai.com/v1"
|
564 |
+
)
|
565 |
+
organization = (
|
566 |
+
optional_params.organization
|
567 |
+
or litellm.organization
|
568 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
569 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
570 |
+
)
|
571 |
+
|
572 |
+
response = openai_batches_instance.list_batches(
|
573 |
+
_is_async=_is_async,
|
574 |
+
after=after,
|
575 |
+
limit=limit,
|
576 |
+
api_base=api_base,
|
577 |
+
api_key=api_key,
|
578 |
+
organization=organization,
|
579 |
+
timeout=timeout,
|
580 |
+
max_retries=optional_params.max_retries,
|
581 |
+
)
|
582 |
+
elif custom_llm_provider == "azure":
|
583 |
+
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
584 |
+
api_version = (
|
585 |
+
optional_params.api_version
|
586 |
+
or litellm.api_version
|
587 |
+
or get_secret_str("AZURE_API_VERSION")
|
588 |
+
)
|
589 |
+
|
590 |
+
api_key = (
|
591 |
+
optional_params.api_key
|
592 |
+
or litellm.api_key
|
593 |
+
or litellm.azure_key
|
594 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
595 |
+
or get_secret_str("AZURE_API_KEY")
|
596 |
+
)
|
597 |
+
|
598 |
+
extra_body = optional_params.get("extra_body", {})
|
599 |
+
if extra_body is not None:
|
600 |
+
extra_body.pop("azure_ad_token", None)
|
601 |
+
else:
|
602 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
603 |
+
|
604 |
+
response = azure_batches_instance.list_batches(
|
605 |
+
_is_async=_is_async,
|
606 |
+
api_base=api_base,
|
607 |
+
api_key=api_key,
|
608 |
+
api_version=api_version,
|
609 |
+
timeout=timeout,
|
610 |
+
max_retries=optional_params.max_retries,
|
611 |
+
litellm_params=litellm_params,
|
612 |
+
)
|
613 |
+
else:
|
614 |
+
raise litellm.exceptions.BadRequestError(
|
615 |
+
message="LiteLLM doesn't support {} for 'list_batch'. Only 'openai' is supported.".format(
|
616 |
+
custom_llm_provider
|
617 |
+
),
|
618 |
+
model="n/a",
|
619 |
+
llm_provider=custom_llm_provider,
|
620 |
+
response=httpx.Response(
|
621 |
+
status_code=400,
|
622 |
+
content="Unsupported provider",
|
623 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
624 |
+
),
|
625 |
+
)
|
626 |
+
return response
|
627 |
+
except Exception as e:
|
628 |
+
raise e
|
629 |
+
|
630 |
+
|
631 |
+
async def acancel_batch(
|
632 |
+
batch_id: str,
|
633 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
634 |
+
metadata: Optional[Dict[str, str]] = None,
|
635 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
636 |
+
extra_body: Optional[Dict[str, str]] = None,
|
637 |
+
**kwargs,
|
638 |
+
) -> Batch:
|
639 |
+
"""
|
640 |
+
Async: Cancels a batch.
|
641 |
+
|
642 |
+
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
|
643 |
+
"""
|
644 |
+
try:
|
645 |
+
loop = asyncio.get_event_loop()
|
646 |
+
kwargs["acancel_batch"] = True
|
647 |
+
|
648 |
+
# Use a partial function to pass your keyword arguments
|
649 |
+
func = partial(
|
650 |
+
cancel_batch,
|
651 |
+
batch_id,
|
652 |
+
custom_llm_provider,
|
653 |
+
metadata,
|
654 |
+
extra_headers,
|
655 |
+
extra_body,
|
656 |
+
**kwargs,
|
657 |
+
)
|
658 |
+
# Add the context to the function
|
659 |
+
ctx = contextvars.copy_context()
|
660 |
+
func_with_context = partial(ctx.run, func)
|
661 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
662 |
+
if asyncio.iscoroutine(init_response):
|
663 |
+
response = await init_response
|
664 |
+
else:
|
665 |
+
response = init_response
|
666 |
+
|
667 |
+
return response
|
668 |
+
except Exception as e:
|
669 |
+
raise e
|
670 |
+
|
671 |
+
|
672 |
+
def cancel_batch(
|
673 |
+
batch_id: str,
|
674 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
675 |
+
metadata: Optional[Dict[str, str]] = None,
|
676 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
677 |
+
extra_body: Optional[Dict[str, str]] = None,
|
678 |
+
**kwargs,
|
679 |
+
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
|
680 |
+
"""
|
681 |
+
Cancels a batch.
|
682 |
+
|
683 |
+
LiteLLM Equivalent of POST https://api.openai.com/v1/batches/{batch_id}/cancel
|
684 |
+
"""
|
685 |
+
try:
|
686 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
687 |
+
litellm_params = get_litellm_params(
|
688 |
+
custom_llm_provider=custom_llm_provider,
|
689 |
+
**kwargs,
|
690 |
+
)
|
691 |
+
### TIMEOUT LOGIC ###
|
692 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
693 |
+
# set timeout for 10 minutes by default
|
694 |
+
|
695 |
+
if (
|
696 |
+
timeout is not None
|
697 |
+
and isinstance(timeout, httpx.Timeout)
|
698 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
699 |
+
):
|
700 |
+
read_timeout = timeout.read or 600
|
701 |
+
timeout = read_timeout # default 10 min timeout
|
702 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
703 |
+
timeout = float(timeout) # type: ignore
|
704 |
+
elif timeout is None:
|
705 |
+
timeout = 600.0
|
706 |
+
|
707 |
+
_cancel_batch_request = CancelBatchRequest(
|
708 |
+
batch_id=batch_id,
|
709 |
+
extra_headers=extra_headers,
|
710 |
+
extra_body=extra_body,
|
711 |
+
)
|
712 |
+
|
713 |
+
_is_async = kwargs.pop("acancel_batch", False) is True
|
714 |
+
api_base: Optional[str] = None
|
715 |
+
if custom_llm_provider == "openai":
|
716 |
+
api_base = (
|
717 |
+
optional_params.api_base
|
718 |
+
or litellm.api_base
|
719 |
+
or os.getenv("OPENAI_BASE_URL")
|
720 |
+
or os.getenv("OPENAI_API_BASE")
|
721 |
+
or "https://api.openai.com/v1"
|
722 |
+
)
|
723 |
+
organization = (
|
724 |
+
optional_params.organization
|
725 |
+
or litellm.organization
|
726 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
727 |
+
or None
|
728 |
+
)
|
729 |
+
api_key = (
|
730 |
+
optional_params.api_key
|
731 |
+
or litellm.api_key
|
732 |
+
or litellm.openai_key
|
733 |
+
or os.getenv("OPENAI_API_KEY")
|
734 |
+
)
|
735 |
+
|
736 |
+
response = openai_batches_instance.cancel_batch(
|
737 |
+
_is_async=_is_async,
|
738 |
+
cancel_batch_data=_cancel_batch_request,
|
739 |
+
api_base=api_base,
|
740 |
+
api_key=api_key,
|
741 |
+
organization=organization,
|
742 |
+
timeout=timeout,
|
743 |
+
max_retries=optional_params.max_retries,
|
744 |
+
)
|
745 |
+
elif custom_llm_provider == "azure":
|
746 |
+
api_base = (
|
747 |
+
optional_params.api_base
|
748 |
+
or litellm.api_base
|
749 |
+
or get_secret_str("AZURE_API_BASE")
|
750 |
+
)
|
751 |
+
api_version = (
|
752 |
+
optional_params.api_version
|
753 |
+
or litellm.api_version
|
754 |
+
or get_secret_str("AZURE_API_VERSION")
|
755 |
+
)
|
756 |
+
|
757 |
+
api_key = (
|
758 |
+
optional_params.api_key
|
759 |
+
or litellm.api_key
|
760 |
+
or litellm.azure_key
|
761 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
762 |
+
or get_secret_str("AZURE_API_KEY")
|
763 |
+
)
|
764 |
+
|
765 |
+
extra_body = optional_params.get("extra_body", {})
|
766 |
+
if extra_body is not None:
|
767 |
+
extra_body.pop("azure_ad_token", None)
|
768 |
+
else:
|
769 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
770 |
+
|
771 |
+
response = azure_batches_instance.cancel_batch(
|
772 |
+
_is_async=_is_async,
|
773 |
+
api_base=api_base,
|
774 |
+
api_key=api_key,
|
775 |
+
api_version=api_version,
|
776 |
+
timeout=timeout,
|
777 |
+
max_retries=optional_params.max_retries,
|
778 |
+
cancel_batch_data=_cancel_batch_request,
|
779 |
+
litellm_params=litellm_params,
|
780 |
+
)
|
781 |
+
else:
|
782 |
+
raise litellm.exceptions.BadRequestError(
|
783 |
+
message="LiteLLM doesn't support {} for 'cancel_batch'. Only 'openai' and 'azure' are supported.".format(
|
784 |
+
custom_llm_provider
|
785 |
+
),
|
786 |
+
model="n/a",
|
787 |
+
llm_provider=custom_llm_provider,
|
788 |
+
response=httpx.Response(
|
789 |
+
status_code=400,
|
790 |
+
content="Unsupported provider",
|
791 |
+
request=httpx.Request(method="cancel_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
|
792 |
+
),
|
793 |
+
)
|
794 |
+
return response
|
795 |
+
except Exception as e:
|
796 |
+
raise e
|
litellm/budget_manager.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# +-----------------------------------------------+
|
2 |
+
# | |
|
3 |
+
# | NOT PROXY BUDGET MANAGER |
|
4 |
+
# | proxy budget manager is in proxy_server.py |
|
5 |
+
# | |
|
6 |
+
# +-----------------------------------------------+
|
7 |
+
#
|
8 |
+
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
9 |
+
|
10 |
+
import json
|
11 |
+
import os
|
12 |
+
import threading
|
13 |
+
import time
|
14 |
+
from typing import Literal, Optional
|
15 |
+
|
16 |
+
import litellm
|
17 |
+
from litellm.constants import (
|
18 |
+
DAYS_IN_A_MONTH,
|
19 |
+
DAYS_IN_A_WEEK,
|
20 |
+
DAYS_IN_A_YEAR,
|
21 |
+
HOURS_IN_A_DAY,
|
22 |
+
)
|
23 |
+
from litellm.utils import ModelResponse
|
24 |
+
|
25 |
+
|
26 |
+
class BudgetManager:
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
project_name: str,
|
30 |
+
client_type: str = "local",
|
31 |
+
api_base: Optional[str] = None,
|
32 |
+
headers: Optional[dict] = None,
|
33 |
+
):
|
34 |
+
self.client_type = client_type
|
35 |
+
self.project_name = project_name
|
36 |
+
self.api_base = api_base or "https://api.litellm.ai"
|
37 |
+
self.headers = headers or {"Content-Type": "application/json"}
|
38 |
+
## load the data or init the initial dictionaries
|
39 |
+
self.load_data()
|
40 |
+
|
41 |
+
def print_verbose(self, print_statement):
|
42 |
+
try:
|
43 |
+
if litellm.set_verbose:
|
44 |
+
import logging
|
45 |
+
|
46 |
+
logging.info(print_statement)
|
47 |
+
except Exception:
|
48 |
+
pass
|
49 |
+
|
50 |
+
def load_data(self):
|
51 |
+
if self.client_type == "local":
|
52 |
+
# Check if user dict file exists
|
53 |
+
if os.path.isfile("user_cost.json"):
|
54 |
+
# Load the user dict
|
55 |
+
with open("user_cost.json", "r") as json_file:
|
56 |
+
self.user_dict = json.load(json_file)
|
57 |
+
else:
|
58 |
+
self.print_verbose("User Dictionary not found!")
|
59 |
+
self.user_dict = {}
|
60 |
+
self.print_verbose(f"user dict from local: {self.user_dict}")
|
61 |
+
elif self.client_type == "hosted":
|
62 |
+
# Load the user_dict from hosted db
|
63 |
+
url = self.api_base + "/get_budget"
|
64 |
+
data = {"project_name": self.project_name}
|
65 |
+
response = litellm.module_level_client.post(
|
66 |
+
url, headers=self.headers, json=data
|
67 |
+
)
|
68 |
+
response = response.json()
|
69 |
+
if response["status"] == "error":
|
70 |
+
self.user_dict = (
|
71 |
+
{}
|
72 |
+
) # assume this means the user dict hasn't been stored yet
|
73 |
+
else:
|
74 |
+
self.user_dict = response["data"]
|
75 |
+
|
76 |
+
def create_budget(
|
77 |
+
self,
|
78 |
+
total_budget: float,
|
79 |
+
user: str,
|
80 |
+
duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
|
81 |
+
created_at: float = time.time(),
|
82 |
+
):
|
83 |
+
self.user_dict[user] = {"total_budget": total_budget}
|
84 |
+
if duration is None:
|
85 |
+
return self.user_dict[user]
|
86 |
+
|
87 |
+
if duration == "daily":
|
88 |
+
duration_in_days = 1
|
89 |
+
elif duration == "weekly":
|
90 |
+
duration_in_days = DAYS_IN_A_WEEK
|
91 |
+
elif duration == "monthly":
|
92 |
+
duration_in_days = DAYS_IN_A_MONTH
|
93 |
+
elif duration == "yearly":
|
94 |
+
duration_in_days = DAYS_IN_A_YEAR
|
95 |
+
else:
|
96 |
+
raise ValueError(
|
97 |
+
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
|
98 |
+
)
|
99 |
+
self.user_dict[user] = {
|
100 |
+
"total_budget": total_budget,
|
101 |
+
"duration": duration_in_days,
|
102 |
+
"created_at": created_at,
|
103 |
+
"last_updated_at": created_at,
|
104 |
+
}
|
105 |
+
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
|
106 |
+
return self.user_dict[user]
|
107 |
+
|
108 |
+
def projected_cost(self, model: str, messages: list, user: str):
|
109 |
+
text = "".join(message["content"] for message in messages)
|
110 |
+
prompt_tokens = litellm.token_counter(model=model, text=text)
|
111 |
+
prompt_cost, _ = litellm.cost_per_token(
|
112 |
+
model=model, prompt_tokens=prompt_tokens, completion_tokens=0
|
113 |
+
)
|
114 |
+
current_cost = self.user_dict[user].get("current_cost", 0)
|
115 |
+
projected_cost = prompt_cost + current_cost
|
116 |
+
return projected_cost
|
117 |
+
|
118 |
+
def get_total_budget(self, user: str):
|
119 |
+
return self.user_dict[user]["total_budget"]
|
120 |
+
|
121 |
+
def update_cost(
|
122 |
+
self,
|
123 |
+
user: str,
|
124 |
+
completion_obj: Optional[ModelResponse] = None,
|
125 |
+
model: Optional[str] = None,
|
126 |
+
input_text: Optional[str] = None,
|
127 |
+
output_text: Optional[str] = None,
|
128 |
+
):
|
129 |
+
if model and input_text and output_text:
|
130 |
+
prompt_tokens = litellm.token_counter(
|
131 |
+
model=model, messages=[{"role": "user", "content": input_text}]
|
132 |
+
)
|
133 |
+
completion_tokens = litellm.token_counter(
|
134 |
+
model=model, messages=[{"role": "user", "content": output_text}]
|
135 |
+
)
|
136 |
+
(
|
137 |
+
prompt_tokens_cost_usd_dollar,
|
138 |
+
completion_tokens_cost_usd_dollar,
|
139 |
+
) = litellm.cost_per_token(
|
140 |
+
model=model,
|
141 |
+
prompt_tokens=prompt_tokens,
|
142 |
+
completion_tokens=completion_tokens,
|
143 |
+
)
|
144 |
+
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
|
145 |
+
elif completion_obj:
|
146 |
+
cost = litellm.completion_cost(completion_response=completion_obj)
|
147 |
+
model = completion_obj[
|
148 |
+
"model"
|
149 |
+
] # if this throws an error try, model = completion_obj['model']
|
150 |
+
else:
|
151 |
+
raise ValueError(
|
152 |
+
"Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
|
153 |
+
)
|
154 |
+
|
155 |
+
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
|
156 |
+
"current_cost", 0
|
157 |
+
)
|
158 |
+
if "model_cost" in self.user_dict[user]:
|
159 |
+
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
|
160 |
+
"model_cost"
|
161 |
+
].get(model, 0)
|
162 |
+
else:
|
163 |
+
self.user_dict[user]["model_cost"] = {model: cost}
|
164 |
+
|
165 |
+
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
|
166 |
+
return {"user": self.user_dict[user]}
|
167 |
+
|
168 |
+
def get_current_cost(self, user):
|
169 |
+
return self.user_dict[user].get("current_cost", 0)
|
170 |
+
|
171 |
+
def get_model_cost(self, user):
|
172 |
+
return self.user_dict[user].get("model_cost", 0)
|
173 |
+
|
174 |
+
def is_valid_user(self, user: str) -> bool:
|
175 |
+
return user in self.user_dict
|
176 |
+
|
177 |
+
def get_users(self):
|
178 |
+
return list(self.user_dict.keys())
|
179 |
+
|
180 |
+
def reset_cost(self, user):
|
181 |
+
self.user_dict[user]["current_cost"] = 0
|
182 |
+
self.user_dict[user]["model_cost"] = {}
|
183 |
+
return {"user": self.user_dict[user]}
|
184 |
+
|
185 |
+
def reset_on_duration(self, user: str):
|
186 |
+
# Get current and creation time
|
187 |
+
last_updated_at = self.user_dict[user]["last_updated_at"]
|
188 |
+
current_time = time.time()
|
189 |
+
|
190 |
+
# Convert duration from days to seconds
|
191 |
+
duration_in_seconds = (
|
192 |
+
self.user_dict[user]["duration"] * HOURS_IN_A_DAY * 60 * 60
|
193 |
+
)
|
194 |
+
|
195 |
+
# Check if duration has elapsed
|
196 |
+
if current_time - last_updated_at >= duration_in_seconds:
|
197 |
+
# Reset cost if duration has elapsed and update the creation time
|
198 |
+
self.reset_cost(user)
|
199 |
+
self.user_dict[user]["last_updated_at"] = current_time
|
200 |
+
self._save_data_thread() # Save the data
|
201 |
+
|
202 |
+
def update_budget_all_users(self):
|
203 |
+
for user in self.get_users():
|
204 |
+
if "duration" in self.user_dict[user]:
|
205 |
+
self.reset_on_duration(user)
|
206 |
+
|
207 |
+
def _save_data_thread(self):
|
208 |
+
thread = threading.Thread(
|
209 |
+
target=self.save_data
|
210 |
+
) # [Non-Blocking]: saves data without blocking execution
|
211 |
+
thread.start()
|
212 |
+
|
213 |
+
def save_data(self):
|
214 |
+
if self.client_type == "local":
|
215 |
+
import json
|
216 |
+
|
217 |
+
# save the user dict
|
218 |
+
with open("user_cost.json", "w") as json_file:
|
219 |
+
json.dump(
|
220 |
+
self.user_dict, json_file, indent=4
|
221 |
+
) # Indent for pretty formatting
|
222 |
+
return {"status": "success"}
|
223 |
+
elif self.client_type == "hosted":
|
224 |
+
url = self.api_base + "/set_budget"
|
225 |
+
data = {"project_name": self.project_name, "user_dict": self.user_dict}
|
226 |
+
response = litellm.module_level_client.post(
|
227 |
+
url, headers=self.headers, json=data
|
228 |
+
)
|
229 |
+
response = response.json()
|
230 |
+
return response
|
litellm/caching/Readme.md
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Caching on LiteLLM
|
2 |
+
|
3 |
+
LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case.
|
4 |
+
|
5 |
+
The following caching mechanisms are supported:
|
6 |
+
|
7 |
+
1. **RedisCache**
|
8 |
+
2. **RedisSemanticCache**
|
9 |
+
3. **QdrantSemanticCache**
|
10 |
+
4. **InMemoryCache**
|
11 |
+
5. **DiskCache**
|
12 |
+
6. **S3Cache**
|
13 |
+
7. **DualCache** (updates both Redis and an in-memory cache simultaneously)
|
14 |
+
|
15 |
+
## Folder Structure
|
16 |
+
|
17 |
+
```
|
18 |
+
litellm/caching/
|
19 |
+
├── base_cache.py
|
20 |
+
├── caching.py
|
21 |
+
├── caching_handler.py
|
22 |
+
├── disk_cache.py
|
23 |
+
├── dual_cache.py
|
24 |
+
├── in_memory_cache.py
|
25 |
+
├── qdrant_semantic_cache.py
|
26 |
+
├── redis_cache.py
|
27 |
+
├── redis_semantic_cache.py
|
28 |
+
├── s3_cache.py
|
29 |
+
```
|
30 |
+
|
31 |
+
## Documentation
|
32 |
+
- [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching)
|
33 |
+
- [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches)
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
|
litellm/caching/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .caching import Cache, LiteLLMCacheType
|
2 |
+
from .disk_cache import DiskCache
|
3 |
+
from .dual_cache import DualCache
|
4 |
+
from .in_memory_cache import InMemoryCache
|
5 |
+
from .qdrant_semantic_cache import QdrantSemanticCache
|
6 |
+
from .redis_cache import RedisCache
|
7 |
+
from .redis_cluster_cache import RedisClusterCache
|
8 |
+
from .redis_semantic_cache import RedisSemanticCache
|
9 |
+
from .s3_cache import S3Cache
|
litellm/caching/_internal_lru_cache.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import lru_cache
|
2 |
+
from typing import Callable, Optional, TypeVar
|
3 |
+
|
4 |
+
T = TypeVar("T")
|
5 |
+
|
6 |
+
|
7 |
+
def lru_cache_wrapper(
|
8 |
+
maxsize: Optional[int] = None,
|
9 |
+
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
10 |
+
"""
|
11 |
+
Wrapper for lru_cache that caches success and exceptions
|
12 |
+
"""
|
13 |
+
|
14 |
+
def decorator(f: Callable[..., T]) -> Callable[..., T]:
|
15 |
+
@lru_cache(maxsize=maxsize)
|
16 |
+
def wrapper(*args, **kwargs):
|
17 |
+
try:
|
18 |
+
return ("success", f(*args, **kwargs))
|
19 |
+
except Exception as e:
|
20 |
+
return ("error", e)
|
21 |
+
|
22 |
+
def wrapped(*args, **kwargs):
|
23 |
+
result = wrapper(*args, **kwargs)
|
24 |
+
if result[0] == "error":
|
25 |
+
raise result[1]
|
26 |
+
return result[1]
|
27 |
+
|
28 |
+
return wrapped
|
29 |
+
|
30 |
+
return decorator
|
litellm/caching/base_cache.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Base Cache implementation. All cache implementations should inherit from this class.
|
3 |
+
|
4 |
+
Has 4 methods:
|
5 |
+
- set_cache
|
6 |
+
- get_cache
|
7 |
+
- async_set_cache
|
8 |
+
- async_get_cache
|
9 |
+
"""
|
10 |
+
|
11 |
+
from abc import ABC, abstractmethod
|
12 |
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
13 |
+
|
14 |
+
if TYPE_CHECKING:
|
15 |
+
from opentelemetry.trace import Span as _Span
|
16 |
+
|
17 |
+
Span = Union[_Span, Any]
|
18 |
+
else:
|
19 |
+
Span = Any
|
20 |
+
|
21 |
+
|
22 |
+
class BaseCache(ABC):
|
23 |
+
def __init__(self, default_ttl: int = 60):
|
24 |
+
self.default_ttl = default_ttl
|
25 |
+
|
26 |
+
def get_ttl(self, **kwargs) -> Optional[int]:
|
27 |
+
kwargs_ttl: Optional[int] = kwargs.get("ttl")
|
28 |
+
if kwargs_ttl is not None:
|
29 |
+
try:
|
30 |
+
return int(kwargs_ttl)
|
31 |
+
except ValueError:
|
32 |
+
return self.default_ttl
|
33 |
+
return self.default_ttl
|
34 |
+
|
35 |
+
def set_cache(self, key, value, **kwargs):
|
36 |
+
raise NotImplementedError
|
37 |
+
|
38 |
+
async def async_set_cache(self, key, value, **kwargs):
|
39 |
+
raise NotImplementedError
|
40 |
+
|
41 |
+
@abstractmethod
|
42 |
+
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
43 |
+
pass
|
44 |
+
|
45 |
+
def get_cache(self, key, **kwargs):
|
46 |
+
raise NotImplementedError
|
47 |
+
|
48 |
+
async def async_get_cache(self, key, **kwargs):
|
49 |
+
raise NotImplementedError
|
50 |
+
|
51 |
+
async def batch_cache_write(self, key, value, **kwargs):
|
52 |
+
raise NotImplementedError
|
53 |
+
|
54 |
+
async def disconnect(self):
|
55 |
+
raise NotImplementedError
|
litellm/caching/caching.py
ADDED
@@ -0,0 +1,818 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# +-----------------------------------------------+
|
2 |
+
# | |
|
3 |
+
# | Give Feedback / Get Help |
|
4 |
+
# | https://github.com/BerriAI/litellm/issues/new |
|
5 |
+
# | |
|
6 |
+
# +-----------------------------------------------+
|
7 |
+
#
|
8 |
+
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
9 |
+
|
10 |
+
import ast
|
11 |
+
import hashlib
|
12 |
+
import json
|
13 |
+
import time
|
14 |
+
import traceback
|
15 |
+
from enum import Enum
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
from pydantic import BaseModel
|
19 |
+
|
20 |
+
import litellm
|
21 |
+
from litellm._logging import verbose_logger
|
22 |
+
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
|
23 |
+
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
|
24 |
+
from litellm.types.caching import *
|
25 |
+
from litellm.types.utils import EmbeddingResponse, all_litellm_params
|
26 |
+
|
27 |
+
from .base_cache import BaseCache
|
28 |
+
from .disk_cache import DiskCache
|
29 |
+
from .dual_cache import DualCache # noqa
|
30 |
+
from .in_memory_cache import InMemoryCache
|
31 |
+
from .qdrant_semantic_cache import QdrantSemanticCache
|
32 |
+
from .redis_cache import RedisCache
|
33 |
+
from .redis_cluster_cache import RedisClusterCache
|
34 |
+
from .redis_semantic_cache import RedisSemanticCache
|
35 |
+
from .s3_cache import S3Cache
|
36 |
+
|
37 |
+
|
38 |
+
def print_verbose(print_statement):
|
39 |
+
try:
|
40 |
+
verbose_logger.debug(print_statement)
|
41 |
+
if litellm.set_verbose:
|
42 |
+
print(print_statement) # noqa
|
43 |
+
except Exception:
|
44 |
+
pass
|
45 |
+
|
46 |
+
|
47 |
+
class CacheMode(str, Enum):
|
48 |
+
default_on = "default_on"
|
49 |
+
default_off = "default_off"
|
50 |
+
|
51 |
+
|
52 |
+
#### LiteLLM.Completion / Embedding Cache ####
|
53 |
+
class Cache:
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
57 |
+
mode: Optional[
|
58 |
+
CacheMode
|
59 |
+
] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
|
60 |
+
host: Optional[str] = None,
|
61 |
+
port: Optional[str] = None,
|
62 |
+
password: Optional[str] = None,
|
63 |
+
namespace: Optional[str] = None,
|
64 |
+
ttl: Optional[float] = None,
|
65 |
+
default_in_memory_ttl: Optional[float] = None,
|
66 |
+
default_in_redis_ttl: Optional[float] = None,
|
67 |
+
similarity_threshold: Optional[float] = None,
|
68 |
+
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
69 |
+
"completion",
|
70 |
+
"acompletion",
|
71 |
+
"embedding",
|
72 |
+
"aembedding",
|
73 |
+
"atranscription",
|
74 |
+
"transcription",
|
75 |
+
"atext_completion",
|
76 |
+
"text_completion",
|
77 |
+
"arerank",
|
78 |
+
"rerank",
|
79 |
+
],
|
80 |
+
# s3 Bucket, boto3 configuration
|
81 |
+
s3_bucket_name: Optional[str] = None,
|
82 |
+
s3_region_name: Optional[str] = None,
|
83 |
+
s3_api_version: Optional[str] = None,
|
84 |
+
s3_use_ssl: Optional[bool] = True,
|
85 |
+
s3_verify: Optional[Union[bool, str]] = None,
|
86 |
+
s3_endpoint_url: Optional[str] = None,
|
87 |
+
s3_aws_access_key_id: Optional[str] = None,
|
88 |
+
s3_aws_secret_access_key: Optional[str] = None,
|
89 |
+
s3_aws_session_token: Optional[str] = None,
|
90 |
+
s3_config: Optional[Any] = None,
|
91 |
+
s3_path: Optional[str] = None,
|
92 |
+
redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
|
93 |
+
redis_semantic_cache_index_name: Optional[str] = None,
|
94 |
+
redis_flush_size: Optional[int] = None,
|
95 |
+
redis_startup_nodes: Optional[List] = None,
|
96 |
+
disk_cache_dir: Optional[str] = None,
|
97 |
+
qdrant_api_base: Optional[str] = None,
|
98 |
+
qdrant_api_key: Optional[str] = None,
|
99 |
+
qdrant_collection_name: Optional[str] = None,
|
100 |
+
qdrant_quantization_config: Optional[str] = None,
|
101 |
+
qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
|
102 |
+
**kwargs,
|
103 |
+
):
|
104 |
+
"""
|
105 |
+
Initializes the cache based on the given type.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
|
109 |
+
|
110 |
+
# Redis Cache Args
|
111 |
+
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
112 |
+
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
113 |
+
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
114 |
+
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
|
115 |
+
ttl (float, optional): The ttl for the Redis cache
|
116 |
+
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
|
117 |
+
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
|
118 |
+
|
119 |
+
# Qdrant Cache Args
|
120 |
+
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
121 |
+
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
|
122 |
+
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
|
123 |
+
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
|
124 |
+
|
125 |
+
# Disk Cache Args
|
126 |
+
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
|
127 |
+
|
128 |
+
# S3 Cache Args
|
129 |
+
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
|
130 |
+
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
|
131 |
+
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
|
132 |
+
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
|
133 |
+
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
|
134 |
+
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
|
135 |
+
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
|
136 |
+
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
|
137 |
+
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
|
138 |
+
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
|
139 |
+
|
140 |
+
# Common Cache Args
|
141 |
+
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
142 |
+
**kwargs: Additional keyword arguments for redis.Redis() cache
|
143 |
+
|
144 |
+
Raises:
|
145 |
+
ValueError: If an invalid cache type is provided.
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
None. Cache is set as a litellm param
|
149 |
+
"""
|
150 |
+
if type == LiteLLMCacheType.REDIS:
|
151 |
+
if redis_startup_nodes:
|
152 |
+
self.cache: BaseCache = RedisClusterCache(
|
153 |
+
host=host,
|
154 |
+
port=port,
|
155 |
+
password=password,
|
156 |
+
redis_flush_size=redis_flush_size,
|
157 |
+
startup_nodes=redis_startup_nodes,
|
158 |
+
**kwargs,
|
159 |
+
)
|
160 |
+
else:
|
161 |
+
self.cache = RedisCache(
|
162 |
+
host=host,
|
163 |
+
port=port,
|
164 |
+
password=password,
|
165 |
+
redis_flush_size=redis_flush_size,
|
166 |
+
**kwargs,
|
167 |
+
)
|
168 |
+
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
|
169 |
+
self.cache = RedisSemanticCache(
|
170 |
+
host=host,
|
171 |
+
port=port,
|
172 |
+
password=password,
|
173 |
+
similarity_threshold=similarity_threshold,
|
174 |
+
embedding_model=redis_semantic_cache_embedding_model,
|
175 |
+
index_name=redis_semantic_cache_index_name,
|
176 |
+
**kwargs,
|
177 |
+
)
|
178 |
+
elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
|
179 |
+
self.cache = QdrantSemanticCache(
|
180 |
+
qdrant_api_base=qdrant_api_base,
|
181 |
+
qdrant_api_key=qdrant_api_key,
|
182 |
+
collection_name=qdrant_collection_name,
|
183 |
+
similarity_threshold=similarity_threshold,
|
184 |
+
quantization_config=qdrant_quantization_config,
|
185 |
+
embedding_model=qdrant_semantic_cache_embedding_model,
|
186 |
+
)
|
187 |
+
elif type == LiteLLMCacheType.LOCAL:
|
188 |
+
self.cache = InMemoryCache()
|
189 |
+
elif type == LiteLLMCacheType.S3:
|
190 |
+
self.cache = S3Cache(
|
191 |
+
s3_bucket_name=s3_bucket_name,
|
192 |
+
s3_region_name=s3_region_name,
|
193 |
+
s3_api_version=s3_api_version,
|
194 |
+
s3_use_ssl=s3_use_ssl,
|
195 |
+
s3_verify=s3_verify,
|
196 |
+
s3_endpoint_url=s3_endpoint_url,
|
197 |
+
s3_aws_access_key_id=s3_aws_access_key_id,
|
198 |
+
s3_aws_secret_access_key=s3_aws_secret_access_key,
|
199 |
+
s3_aws_session_token=s3_aws_session_token,
|
200 |
+
s3_config=s3_config,
|
201 |
+
s3_path=s3_path,
|
202 |
+
**kwargs,
|
203 |
+
)
|
204 |
+
elif type == LiteLLMCacheType.DISK:
|
205 |
+
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
|
206 |
+
if "cache" not in litellm.input_callback:
|
207 |
+
litellm.input_callback.append("cache")
|
208 |
+
if "cache" not in litellm.success_callback:
|
209 |
+
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
210 |
+
if "cache" not in litellm._async_success_callback:
|
211 |
+
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
212 |
+
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
|
213 |
+
self.type = type
|
214 |
+
self.namespace = namespace
|
215 |
+
self.redis_flush_size = redis_flush_size
|
216 |
+
self.ttl = ttl
|
217 |
+
self.mode: CacheMode = mode or CacheMode.default_on
|
218 |
+
|
219 |
+
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
|
220 |
+
self.ttl = default_in_memory_ttl
|
221 |
+
|
222 |
+
if (
|
223 |
+
self.type == LiteLLMCacheType.REDIS
|
224 |
+
or self.type == LiteLLMCacheType.REDIS_SEMANTIC
|
225 |
+
) and default_in_redis_ttl is not None:
|
226 |
+
self.ttl = default_in_redis_ttl
|
227 |
+
|
228 |
+
if self.namespace is not None and isinstance(self.cache, RedisCache):
|
229 |
+
self.cache.namespace = self.namespace
|
230 |
+
|
231 |
+
def get_cache_key(self, **kwargs) -> str:
|
232 |
+
"""
|
233 |
+
Get the cache key for the given arguments.
|
234 |
+
|
235 |
+
Args:
|
236 |
+
**kwargs: kwargs to litellm.completion() or embedding()
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
str: The cache key generated from the arguments, or None if no cache key could be generated.
|
240 |
+
"""
|
241 |
+
cache_key = ""
|
242 |
+
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
|
243 |
+
|
244 |
+
preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
|
245 |
+
if preset_cache_key is not None:
|
246 |
+
verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
|
247 |
+
return preset_cache_key
|
248 |
+
|
249 |
+
combined_kwargs = ModelParamHelper._get_all_llm_api_params()
|
250 |
+
litellm_param_kwargs = all_litellm_params
|
251 |
+
for param in kwargs:
|
252 |
+
if param in combined_kwargs:
|
253 |
+
param_value: Optional[str] = self._get_param_value(param, kwargs)
|
254 |
+
if param_value is not None:
|
255 |
+
cache_key += f"{str(param)}: {str(param_value)}"
|
256 |
+
elif (
|
257 |
+
param not in litellm_param_kwargs
|
258 |
+
): # check if user passed in optional param - e.g. top_k
|
259 |
+
if (
|
260 |
+
litellm.enable_caching_on_provider_specific_optional_params is True
|
261 |
+
): # feature flagged for now
|
262 |
+
if kwargs[param] is None:
|
263 |
+
continue # ignore None params
|
264 |
+
param_value = kwargs[param]
|
265 |
+
cache_key += f"{str(param)}: {str(param_value)}"
|
266 |
+
|
267 |
+
verbose_logger.debug("\nCreated cache key: %s", cache_key)
|
268 |
+
hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
|
269 |
+
hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
|
270 |
+
self._set_preset_cache_key_in_kwargs(
|
271 |
+
preset_cache_key=hashed_cache_key, **kwargs
|
272 |
+
)
|
273 |
+
return hashed_cache_key
|
274 |
+
|
275 |
+
def _get_param_value(
|
276 |
+
self,
|
277 |
+
param: str,
|
278 |
+
kwargs: dict,
|
279 |
+
) -> Optional[str]:
|
280 |
+
"""
|
281 |
+
Get the value for the given param from kwargs
|
282 |
+
"""
|
283 |
+
if param == "model":
|
284 |
+
return self._get_model_param_value(kwargs)
|
285 |
+
elif param == "file":
|
286 |
+
return self._get_file_param_value(kwargs)
|
287 |
+
return kwargs[param]
|
288 |
+
|
289 |
+
def _get_model_param_value(self, kwargs: dict) -> str:
|
290 |
+
"""
|
291 |
+
Handles getting the value for the 'model' param from kwargs
|
292 |
+
|
293 |
+
1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
|
294 |
+
2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
|
295 |
+
3. Else use the `model` passed in kwargs
|
296 |
+
"""
|
297 |
+
metadata: Dict = kwargs.get("metadata", {}) or {}
|
298 |
+
litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
|
299 |
+
metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
|
300 |
+
model_group: Optional[str] = metadata.get(
|
301 |
+
"model_group"
|
302 |
+
) or metadata_in_litellm_params.get("model_group")
|
303 |
+
caching_group = self._get_caching_group(metadata, model_group)
|
304 |
+
return caching_group or model_group or kwargs["model"]
|
305 |
+
|
306 |
+
def _get_caching_group(
|
307 |
+
self, metadata: dict, model_group: Optional[str]
|
308 |
+
) -> Optional[str]:
|
309 |
+
caching_groups: Optional[List] = metadata.get("caching_groups", [])
|
310 |
+
if caching_groups:
|
311 |
+
for group in caching_groups:
|
312 |
+
if model_group in group:
|
313 |
+
return str(group)
|
314 |
+
return None
|
315 |
+
|
316 |
+
def _get_file_param_value(self, kwargs: dict) -> str:
|
317 |
+
"""
|
318 |
+
Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
|
319 |
+
"""
|
320 |
+
file = kwargs.get("file")
|
321 |
+
metadata = kwargs.get("metadata", {})
|
322 |
+
litellm_params = kwargs.get("litellm_params", {})
|
323 |
+
return (
|
324 |
+
metadata.get("file_checksum")
|
325 |
+
or getattr(file, "name", None)
|
326 |
+
or metadata.get("file_name")
|
327 |
+
or litellm_params.get("file_name")
|
328 |
+
)
|
329 |
+
|
330 |
+
def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
|
331 |
+
"""
|
332 |
+
Get the preset cache key from kwargs["litellm_params"]
|
333 |
+
|
334 |
+
We use _get_preset_cache_keys for two reasons
|
335 |
+
|
336 |
+
1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
|
337 |
+
2. avoid doing duplicate / repeated work
|
338 |
+
"""
|
339 |
+
if kwargs:
|
340 |
+
if "litellm_params" in kwargs:
|
341 |
+
return kwargs["litellm_params"].get("preset_cache_key", None)
|
342 |
+
return None
|
343 |
+
|
344 |
+
def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
|
345 |
+
"""
|
346 |
+
Set the calculated cache key in kwargs
|
347 |
+
|
348 |
+
This is used to avoid doing duplicate / repeated work
|
349 |
+
|
350 |
+
Placed in kwargs["litellm_params"]
|
351 |
+
"""
|
352 |
+
if kwargs:
|
353 |
+
if "litellm_params" in kwargs:
|
354 |
+
kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
|
355 |
+
|
356 |
+
@staticmethod
|
357 |
+
def _get_hashed_cache_key(cache_key: str) -> str:
|
358 |
+
"""
|
359 |
+
Get the hashed cache key for the given cache key.
|
360 |
+
|
361 |
+
Use hashlib to create a sha256 hash of the cache key
|
362 |
+
|
363 |
+
Args:
|
364 |
+
cache_key (str): The cache key to hash.
|
365 |
+
|
366 |
+
Returns:
|
367 |
+
str: The hashed cache key.
|
368 |
+
"""
|
369 |
+
hash_object = hashlib.sha256(cache_key.encode())
|
370 |
+
# Hexadecimal representation of the hash
|
371 |
+
hash_hex = hash_object.hexdigest()
|
372 |
+
verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
|
373 |
+
return hash_hex
|
374 |
+
|
375 |
+
def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
|
376 |
+
"""
|
377 |
+
If a redis namespace is provided, add it to the cache key
|
378 |
+
|
379 |
+
Args:
|
380 |
+
hash_hex (str): The hashed cache key.
|
381 |
+
**kwargs: Additional keyword arguments.
|
382 |
+
|
383 |
+
Returns:
|
384 |
+
str: The final hashed cache key with the redis namespace.
|
385 |
+
"""
|
386 |
+
dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
|
387 |
+
namespace = (
|
388 |
+
dynamic_cache_control.get("namespace")
|
389 |
+
or kwargs.get("metadata", {}).get("redis_namespace")
|
390 |
+
or self.namespace
|
391 |
+
)
|
392 |
+
if namespace:
|
393 |
+
hash_hex = f"{namespace}:{hash_hex}"
|
394 |
+
verbose_logger.debug("Final hashed key: %s", hash_hex)
|
395 |
+
return hash_hex
|
396 |
+
|
397 |
+
def generate_streaming_content(self, content):
|
398 |
+
chunk_size = 5 # Adjust the chunk size as needed
|
399 |
+
for i in range(0, len(content), chunk_size):
|
400 |
+
yield {
|
401 |
+
"choices": [
|
402 |
+
{
|
403 |
+
"delta": {
|
404 |
+
"role": "assistant",
|
405 |
+
"content": content[i : i + chunk_size],
|
406 |
+
}
|
407 |
+
}
|
408 |
+
]
|
409 |
+
}
|
410 |
+
time.sleep(CACHED_STREAMING_CHUNK_DELAY)
|
411 |
+
|
412 |
+
def _get_cache_logic(
|
413 |
+
self,
|
414 |
+
cached_result: Optional[Any],
|
415 |
+
max_age: Optional[float],
|
416 |
+
):
|
417 |
+
"""
|
418 |
+
Common get cache logic across sync + async implementations
|
419 |
+
"""
|
420 |
+
# Check if a timestamp was stored with the cached response
|
421 |
+
if (
|
422 |
+
cached_result is not None
|
423 |
+
and isinstance(cached_result, dict)
|
424 |
+
and "timestamp" in cached_result
|
425 |
+
):
|
426 |
+
timestamp = cached_result["timestamp"]
|
427 |
+
current_time = time.time()
|
428 |
+
|
429 |
+
# Calculate age of the cached response
|
430 |
+
response_age = current_time - timestamp
|
431 |
+
|
432 |
+
# Check if the cached response is older than the max-age
|
433 |
+
if max_age is not None and response_age > max_age:
|
434 |
+
return None # Cached response is too old
|
435 |
+
|
436 |
+
# If the response is fresh, or there's no max-age requirement, return the cached response
|
437 |
+
# cached_response is in `b{} convert it to ModelResponse
|
438 |
+
cached_response = cached_result.get("response")
|
439 |
+
try:
|
440 |
+
if isinstance(cached_response, dict):
|
441 |
+
pass
|
442 |
+
else:
|
443 |
+
cached_response = json.loads(
|
444 |
+
cached_response # type: ignore
|
445 |
+
) # Convert string to dictionary
|
446 |
+
except Exception:
|
447 |
+
cached_response = ast.literal_eval(cached_response) # type: ignore
|
448 |
+
return cached_response
|
449 |
+
return cached_result
|
450 |
+
|
451 |
+
def get_cache(self, **kwargs):
|
452 |
+
"""
|
453 |
+
Retrieves the cached result for the given arguments.
|
454 |
+
|
455 |
+
Args:
|
456 |
+
*args: args to litellm.completion() or embedding()
|
457 |
+
**kwargs: kwargs to litellm.completion() or embedding()
|
458 |
+
|
459 |
+
Returns:
|
460 |
+
The cached result if it exists, otherwise None.
|
461 |
+
"""
|
462 |
+
try: # never block execution
|
463 |
+
if self.should_use_cache(**kwargs) is not True:
|
464 |
+
return
|
465 |
+
messages = kwargs.get("messages", [])
|
466 |
+
if "cache_key" in kwargs:
|
467 |
+
cache_key = kwargs["cache_key"]
|
468 |
+
else:
|
469 |
+
cache_key = self.get_cache_key(**kwargs)
|
470 |
+
if cache_key is not None:
|
471 |
+
cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
|
472 |
+
max_age = (
|
473 |
+
cache_control_args.get("s-maxage")
|
474 |
+
or cache_control_args.get("s-max-age")
|
475 |
+
or float("inf")
|
476 |
+
)
|
477 |
+
cached_result = self.cache.get_cache(cache_key, messages=messages)
|
478 |
+
cached_result = self.cache.get_cache(cache_key, messages=messages)
|
479 |
+
return self._get_cache_logic(
|
480 |
+
cached_result=cached_result, max_age=max_age
|
481 |
+
)
|
482 |
+
except Exception:
|
483 |
+
print_verbose(f"An exception occurred: {traceback.format_exc()}")
|
484 |
+
return None
|
485 |
+
|
486 |
+
async def async_get_cache(self, **kwargs):
|
487 |
+
"""
|
488 |
+
Async get cache implementation.
|
489 |
+
|
490 |
+
Used for embedding calls in async wrapper
|
491 |
+
"""
|
492 |
+
|
493 |
+
try: # never block execution
|
494 |
+
if self.should_use_cache(**kwargs) is not True:
|
495 |
+
return
|
496 |
+
|
497 |
+
kwargs.get("messages", [])
|
498 |
+
if "cache_key" in kwargs:
|
499 |
+
cache_key = kwargs["cache_key"]
|
500 |
+
else:
|
501 |
+
cache_key = self.get_cache_key(**kwargs)
|
502 |
+
if cache_key is not None:
|
503 |
+
cache_control_args = kwargs.get("cache", {})
|
504 |
+
max_age = cache_control_args.get(
|
505 |
+
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
|
506 |
+
)
|
507 |
+
cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
|
508 |
+
return self._get_cache_logic(
|
509 |
+
cached_result=cached_result, max_age=max_age
|
510 |
+
)
|
511 |
+
except Exception:
|
512 |
+
print_verbose(f"An exception occurred: {traceback.format_exc()}")
|
513 |
+
return None
|
514 |
+
|
515 |
+
def _add_cache_logic(self, result, **kwargs):
|
516 |
+
"""
|
517 |
+
Common implementation across sync + async add_cache functions
|
518 |
+
"""
|
519 |
+
try:
|
520 |
+
if "cache_key" in kwargs:
|
521 |
+
cache_key = kwargs["cache_key"]
|
522 |
+
else:
|
523 |
+
cache_key = self.get_cache_key(**kwargs)
|
524 |
+
if cache_key is not None:
|
525 |
+
if isinstance(result, BaseModel):
|
526 |
+
result = result.model_dump_json()
|
527 |
+
|
528 |
+
## DEFAULT TTL ##
|
529 |
+
if self.ttl is not None:
|
530 |
+
kwargs["ttl"] = self.ttl
|
531 |
+
## Get Cache-Controls ##
|
532 |
+
_cache_kwargs = kwargs.get("cache", None)
|
533 |
+
if isinstance(_cache_kwargs, dict):
|
534 |
+
for k, v in _cache_kwargs.items():
|
535 |
+
if k == "ttl":
|
536 |
+
kwargs["ttl"] = v
|
537 |
+
|
538 |
+
cached_data = {"timestamp": time.time(), "response": result}
|
539 |
+
return cache_key, cached_data, kwargs
|
540 |
+
else:
|
541 |
+
raise Exception("cache key is None")
|
542 |
+
except Exception as e:
|
543 |
+
raise e
|
544 |
+
|
545 |
+
def add_cache(self, result, **kwargs):
|
546 |
+
"""
|
547 |
+
Adds a result to the cache.
|
548 |
+
|
549 |
+
Args:
|
550 |
+
*args: args to litellm.completion() or embedding()
|
551 |
+
**kwargs: kwargs to litellm.completion() or embedding()
|
552 |
+
|
553 |
+
Returns:
|
554 |
+
None
|
555 |
+
"""
|
556 |
+
try:
|
557 |
+
if self.should_use_cache(**kwargs) is not True:
|
558 |
+
return
|
559 |
+
cache_key, cached_data, kwargs = self._add_cache_logic(
|
560 |
+
result=result, **kwargs
|
561 |
+
)
|
562 |
+
self.cache.set_cache(cache_key, cached_data, **kwargs)
|
563 |
+
except Exception as e:
|
564 |
+
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
565 |
+
|
566 |
+
async def async_add_cache(self, result, **kwargs):
|
567 |
+
"""
|
568 |
+
Async implementation of add_cache
|
569 |
+
"""
|
570 |
+
try:
|
571 |
+
if self.should_use_cache(**kwargs) is not True:
|
572 |
+
return
|
573 |
+
if self.type == "redis" and self.redis_flush_size is not None:
|
574 |
+
# high traffic - fill in results in memory and then flush
|
575 |
+
await self.batch_cache_write(result, **kwargs)
|
576 |
+
else:
|
577 |
+
cache_key, cached_data, kwargs = self._add_cache_logic(
|
578 |
+
result=result, **kwargs
|
579 |
+
)
|
580 |
+
|
581 |
+
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
|
582 |
+
except Exception as e:
|
583 |
+
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
584 |
+
|
585 |
+
def add_embedding_response_to_cache(
|
586 |
+
self,
|
587 |
+
result: EmbeddingResponse,
|
588 |
+
input: str,
|
589 |
+
kwargs: dict,
|
590 |
+
idx_in_result_data: int = 0,
|
591 |
+
) -> Tuple[str, dict, dict]:
|
592 |
+
preset_cache_key = self.get_cache_key(**{**kwargs, "input": input})
|
593 |
+
kwargs["cache_key"] = preset_cache_key
|
594 |
+
embedding_response = result.data[idx_in_result_data]
|
595 |
+
cache_key, cached_data, kwargs = self._add_cache_logic(
|
596 |
+
result=embedding_response,
|
597 |
+
**kwargs,
|
598 |
+
)
|
599 |
+
return cache_key, cached_data, kwargs
|
600 |
+
|
601 |
+
async def async_add_cache_pipeline(self, result, **kwargs):
|
602 |
+
"""
|
603 |
+
Async implementation of add_cache for Embedding calls
|
604 |
+
|
605 |
+
Does a bulk write, to prevent using too many clients
|
606 |
+
"""
|
607 |
+
try:
|
608 |
+
if self.should_use_cache(**kwargs) is not True:
|
609 |
+
return
|
610 |
+
|
611 |
+
# set default ttl if not set
|
612 |
+
if self.ttl is not None:
|
613 |
+
kwargs["ttl"] = self.ttl
|
614 |
+
|
615 |
+
cache_list = []
|
616 |
+
if isinstance(kwargs["input"], list):
|
617 |
+
for idx, i in enumerate(kwargs["input"]):
|
618 |
+
(
|
619 |
+
cache_key,
|
620 |
+
cached_data,
|
621 |
+
kwargs,
|
622 |
+
) = self.add_embedding_response_to_cache(result, i, kwargs, idx)
|
623 |
+
cache_list.append((cache_key, cached_data))
|
624 |
+
elif isinstance(kwargs["input"], str):
|
625 |
+
cache_key, cached_data, kwargs = self.add_embedding_response_to_cache(
|
626 |
+
result, kwargs["input"], kwargs
|
627 |
+
)
|
628 |
+
cache_list.append((cache_key, cached_data))
|
629 |
+
|
630 |
+
await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
|
631 |
+
# if async_set_cache_pipeline:
|
632 |
+
# await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
|
633 |
+
# else:
|
634 |
+
# tasks = []
|
635 |
+
# for val in cache_list:
|
636 |
+
# tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
|
637 |
+
# await asyncio.gather(*tasks)
|
638 |
+
except Exception as e:
|
639 |
+
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
640 |
+
|
641 |
+
def should_use_cache(self, **kwargs):
|
642 |
+
"""
|
643 |
+
Returns true if we should use the cache for LLM API calls
|
644 |
+
|
645 |
+
If cache is default_on then this is True
|
646 |
+
If cache is default_off then this is only true when user has opted in to use cache
|
647 |
+
"""
|
648 |
+
if self.mode == CacheMode.default_on:
|
649 |
+
return True
|
650 |
+
|
651 |
+
# when mode == default_off -> Cache is opt in only
|
652 |
+
_cache = kwargs.get("cache", None)
|
653 |
+
verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
|
654 |
+
if _cache and isinstance(_cache, dict):
|
655 |
+
if _cache.get("use-cache", False) is True:
|
656 |
+
return True
|
657 |
+
return False
|
658 |
+
|
659 |
+
async def batch_cache_write(self, result, **kwargs):
|
660 |
+
cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
|
661 |
+
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
|
662 |
+
|
663 |
+
async def ping(self):
|
664 |
+
cache_ping = getattr(self.cache, "ping")
|
665 |
+
if cache_ping:
|
666 |
+
return await cache_ping()
|
667 |
+
return None
|
668 |
+
|
669 |
+
async def delete_cache_keys(self, keys):
|
670 |
+
cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys")
|
671 |
+
if cache_delete_cache_keys:
|
672 |
+
return await cache_delete_cache_keys(keys)
|
673 |
+
return None
|
674 |
+
|
675 |
+
async def disconnect(self):
|
676 |
+
if hasattr(self.cache, "disconnect"):
|
677 |
+
await self.cache.disconnect()
|
678 |
+
|
679 |
+
def _supports_async(self) -> bool:
|
680 |
+
"""
|
681 |
+
Internal method to check if the cache type supports async get/set operations
|
682 |
+
|
683 |
+
Only S3 Cache Does NOT support async operations
|
684 |
+
|
685 |
+
"""
|
686 |
+
if self.type and self.type == LiteLLMCacheType.S3:
|
687 |
+
return False
|
688 |
+
return True
|
689 |
+
|
690 |
+
|
691 |
+
def enable_cache(
|
692 |
+
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
693 |
+
host: Optional[str] = None,
|
694 |
+
port: Optional[str] = None,
|
695 |
+
password: Optional[str] = None,
|
696 |
+
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
697 |
+
"completion",
|
698 |
+
"acompletion",
|
699 |
+
"embedding",
|
700 |
+
"aembedding",
|
701 |
+
"atranscription",
|
702 |
+
"transcription",
|
703 |
+
"atext_completion",
|
704 |
+
"text_completion",
|
705 |
+
"arerank",
|
706 |
+
"rerank",
|
707 |
+
],
|
708 |
+
**kwargs,
|
709 |
+
):
|
710 |
+
"""
|
711 |
+
Enable cache with the specified configuration.
|
712 |
+
|
713 |
+
Args:
|
714 |
+
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
|
715 |
+
host (Optional[str]): The host address of the cache server. Defaults to None.
|
716 |
+
port (Optional[str]): The port number of the cache server. Defaults to None.
|
717 |
+
password (Optional[str]): The password for the cache server. Defaults to None.
|
718 |
+
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
|
719 |
+
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
|
720 |
+
**kwargs: Additional keyword arguments.
|
721 |
+
|
722 |
+
Returns:
|
723 |
+
None
|
724 |
+
|
725 |
+
Raises:
|
726 |
+
None
|
727 |
+
"""
|
728 |
+
print_verbose("LiteLLM: Enabling Cache")
|
729 |
+
if "cache" not in litellm.input_callback:
|
730 |
+
litellm.input_callback.append("cache")
|
731 |
+
if "cache" not in litellm.success_callback:
|
732 |
+
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
733 |
+
if "cache" not in litellm._async_success_callback:
|
734 |
+
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
735 |
+
|
736 |
+
if litellm.cache is None:
|
737 |
+
litellm.cache = Cache(
|
738 |
+
type=type,
|
739 |
+
host=host,
|
740 |
+
port=port,
|
741 |
+
password=password,
|
742 |
+
supported_call_types=supported_call_types,
|
743 |
+
**kwargs,
|
744 |
+
)
|
745 |
+
print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
|
746 |
+
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
|
747 |
+
|
748 |
+
|
749 |
+
def update_cache(
|
750 |
+
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
751 |
+
host: Optional[str] = None,
|
752 |
+
port: Optional[str] = None,
|
753 |
+
password: Optional[str] = None,
|
754 |
+
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
755 |
+
"completion",
|
756 |
+
"acompletion",
|
757 |
+
"embedding",
|
758 |
+
"aembedding",
|
759 |
+
"atranscription",
|
760 |
+
"transcription",
|
761 |
+
"atext_completion",
|
762 |
+
"text_completion",
|
763 |
+
"arerank",
|
764 |
+
"rerank",
|
765 |
+
],
|
766 |
+
**kwargs,
|
767 |
+
):
|
768 |
+
"""
|
769 |
+
Update the cache for LiteLLM.
|
770 |
+
|
771 |
+
Args:
|
772 |
+
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
|
773 |
+
host (Optional[str]): The host of the cache. Defaults to None.
|
774 |
+
port (Optional[str]): The port of the cache. Defaults to None.
|
775 |
+
password (Optional[str]): The password for the cache. Defaults to None.
|
776 |
+
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
|
777 |
+
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
|
778 |
+
**kwargs: Additional keyword arguments for the cache.
|
779 |
+
|
780 |
+
Returns:
|
781 |
+
None
|
782 |
+
|
783 |
+
"""
|
784 |
+
print_verbose("LiteLLM: Updating Cache")
|
785 |
+
litellm.cache = Cache(
|
786 |
+
type=type,
|
787 |
+
host=host,
|
788 |
+
port=port,
|
789 |
+
password=password,
|
790 |
+
supported_call_types=supported_call_types,
|
791 |
+
**kwargs,
|
792 |
+
)
|
793 |
+
print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
|
794 |
+
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
|
795 |
+
|
796 |
+
|
797 |
+
def disable_cache():
|
798 |
+
"""
|
799 |
+
Disable the cache used by LiteLLM.
|
800 |
+
|
801 |
+
This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.
|
802 |
+
|
803 |
+
Parameters:
|
804 |
+
None
|
805 |
+
|
806 |
+
Returns:
|
807 |
+
None
|
808 |
+
"""
|
809 |
+
from contextlib import suppress
|
810 |
+
|
811 |
+
print_verbose("LiteLLM: Disabling Cache")
|
812 |
+
with suppress(ValueError):
|
813 |
+
litellm.input_callback.remove("cache")
|
814 |
+
litellm.success_callback.remove("cache")
|
815 |
+
litellm._async_success_callback.remove("cache")
|
816 |
+
|
817 |
+
litellm.cache = None
|
818 |
+
print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
|
litellm/caching/caching_handler.py
ADDED
@@ -0,0 +1,938 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This contains LLMCachingHandler
|
3 |
+
|
4 |
+
This exposes two methods:
|
5 |
+
- async_get_cache
|
6 |
+
- async_set_cache
|
7 |
+
|
8 |
+
This file is a wrapper around caching.py
|
9 |
+
|
10 |
+
This class is used to handle caching logic specific for LLM API requests (completion / embedding / text_completion / transcription etc)
|
11 |
+
|
12 |
+
It utilizes the (RedisCache, s3Cache, RedisSemanticCache, QdrantSemanticCache, InMemoryCache, DiskCache) based on what the user has setup
|
13 |
+
|
14 |
+
In each method it will call the appropriate method from caching.py
|
15 |
+
"""
|
16 |
+
|
17 |
+
import asyncio
|
18 |
+
import datetime
|
19 |
+
import inspect
|
20 |
+
import threading
|
21 |
+
from typing import (
|
22 |
+
TYPE_CHECKING,
|
23 |
+
Any,
|
24 |
+
AsyncGenerator,
|
25 |
+
Callable,
|
26 |
+
Dict,
|
27 |
+
Generator,
|
28 |
+
List,
|
29 |
+
Optional,
|
30 |
+
Tuple,
|
31 |
+
Union,
|
32 |
+
)
|
33 |
+
|
34 |
+
from pydantic import BaseModel
|
35 |
+
|
36 |
+
import litellm
|
37 |
+
from litellm._logging import print_verbose, verbose_logger
|
38 |
+
from litellm.caching.caching import S3Cache
|
39 |
+
from litellm.litellm_core_utils.logging_utils import (
|
40 |
+
_assemble_complete_response_from_streaming_chunks,
|
41 |
+
)
|
42 |
+
from litellm.types.rerank import RerankResponse
|
43 |
+
from litellm.types.utils import (
|
44 |
+
CallTypes,
|
45 |
+
Embedding,
|
46 |
+
EmbeddingResponse,
|
47 |
+
ModelResponse,
|
48 |
+
TextCompletionResponse,
|
49 |
+
TranscriptionResponse,
|
50 |
+
Usage,
|
51 |
+
)
|
52 |
+
|
53 |
+
if TYPE_CHECKING:
|
54 |
+
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
55 |
+
from litellm.utils import CustomStreamWrapper
|
56 |
+
else:
|
57 |
+
LiteLLMLoggingObj = Any
|
58 |
+
CustomStreamWrapper = Any
|
59 |
+
|
60 |
+
|
61 |
+
class CachingHandlerResponse(BaseModel):
|
62 |
+
"""
|
63 |
+
This is the response object for the caching handler. We need to separate embedding cached responses and (completion / text_completion / transcription) cached responses
|
64 |
+
|
65 |
+
For embeddings there can be a cache hit for some of the inputs in the list and a cache miss for others
|
66 |
+
"""
|
67 |
+
|
68 |
+
cached_result: Optional[Any] = None
|
69 |
+
final_embedding_cached_response: Optional[EmbeddingResponse] = None
|
70 |
+
embedding_all_elements_cache_hit: bool = False # this is set to True when all elements in the list have a cache hit in the embedding cache, if true return the final_embedding_cached_response no need to make an API call
|
71 |
+
|
72 |
+
|
73 |
+
class LLMCachingHandler:
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
original_function: Callable,
|
77 |
+
request_kwargs: Dict[str, Any],
|
78 |
+
start_time: datetime.datetime,
|
79 |
+
):
|
80 |
+
self.async_streaming_chunks: List[ModelResponse] = []
|
81 |
+
self.sync_streaming_chunks: List[ModelResponse] = []
|
82 |
+
self.request_kwargs = request_kwargs
|
83 |
+
self.original_function = original_function
|
84 |
+
self.start_time = start_time
|
85 |
+
pass
|
86 |
+
|
87 |
+
async def _async_get_cache(
|
88 |
+
self,
|
89 |
+
model: str,
|
90 |
+
original_function: Callable,
|
91 |
+
logging_obj: LiteLLMLoggingObj,
|
92 |
+
start_time: datetime.datetime,
|
93 |
+
call_type: str,
|
94 |
+
kwargs: Dict[str, Any],
|
95 |
+
args: Optional[Tuple[Any, ...]] = None,
|
96 |
+
) -> CachingHandlerResponse:
|
97 |
+
"""
|
98 |
+
Internal method to get from the cache.
|
99 |
+
Handles different call types (embeddings, chat/completions, text_completion, transcription)
|
100 |
+
and accordingly returns the cached response
|
101 |
+
|
102 |
+
Args:
|
103 |
+
model: str:
|
104 |
+
original_function: Callable:
|
105 |
+
logging_obj: LiteLLMLoggingObj:
|
106 |
+
start_time: datetime.datetime:
|
107 |
+
call_type: str:
|
108 |
+
kwargs: Dict[str, Any]:
|
109 |
+
args: Optional[Tuple[Any, ...]] = None:
|
110 |
+
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
CachingHandlerResponse:
|
114 |
+
Raises:
|
115 |
+
None
|
116 |
+
"""
|
117 |
+
from litellm.utils import CustomStreamWrapper
|
118 |
+
|
119 |
+
args = args or ()
|
120 |
+
|
121 |
+
final_embedding_cached_response: Optional[EmbeddingResponse] = None
|
122 |
+
embedding_all_elements_cache_hit: bool = False
|
123 |
+
cached_result: Optional[Any] = None
|
124 |
+
if (
|
125 |
+
(kwargs.get("caching", None) is None and litellm.cache is not None)
|
126 |
+
or kwargs.get("caching", False) is True
|
127 |
+
) and (
|
128 |
+
kwargs.get("cache", {}).get("no-cache", False) is not True
|
129 |
+
): # allow users to control returning cached responses from the completion function
|
130 |
+
if litellm.cache is not None and self._is_call_type_supported_by_cache(
|
131 |
+
original_function=original_function
|
132 |
+
):
|
133 |
+
verbose_logger.debug("Checking Async Cache")
|
134 |
+
cached_result = await self._retrieve_from_cache(
|
135 |
+
call_type=call_type,
|
136 |
+
kwargs=kwargs,
|
137 |
+
args=args,
|
138 |
+
)
|
139 |
+
|
140 |
+
if cached_result is not None and not isinstance(cached_result, list):
|
141 |
+
verbose_logger.debug("Cache Hit!")
|
142 |
+
cache_hit = True
|
143 |
+
end_time = datetime.datetime.now()
|
144 |
+
model, _, _, _ = litellm.get_llm_provider(
|
145 |
+
model=model,
|
146 |
+
custom_llm_provider=kwargs.get("custom_llm_provider", None),
|
147 |
+
api_base=kwargs.get("api_base", None),
|
148 |
+
api_key=kwargs.get("api_key", None),
|
149 |
+
)
|
150 |
+
self._update_litellm_logging_obj_environment(
|
151 |
+
logging_obj=logging_obj,
|
152 |
+
model=model,
|
153 |
+
kwargs=kwargs,
|
154 |
+
cached_result=cached_result,
|
155 |
+
is_async=True,
|
156 |
+
)
|
157 |
+
|
158 |
+
call_type = original_function.__name__
|
159 |
+
|
160 |
+
cached_result = self._convert_cached_result_to_model_response(
|
161 |
+
cached_result=cached_result,
|
162 |
+
call_type=call_type,
|
163 |
+
kwargs=kwargs,
|
164 |
+
logging_obj=logging_obj,
|
165 |
+
model=model,
|
166 |
+
custom_llm_provider=kwargs.get("custom_llm_provider", None),
|
167 |
+
args=args,
|
168 |
+
)
|
169 |
+
if kwargs.get("stream", False) is False:
|
170 |
+
# LOG SUCCESS
|
171 |
+
self._async_log_cache_hit_on_callbacks(
|
172 |
+
logging_obj=logging_obj,
|
173 |
+
cached_result=cached_result,
|
174 |
+
start_time=start_time,
|
175 |
+
end_time=end_time,
|
176 |
+
cache_hit=cache_hit,
|
177 |
+
)
|
178 |
+
cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
|
179 |
+
**kwargs
|
180 |
+
)
|
181 |
+
if (
|
182 |
+
isinstance(cached_result, BaseModel)
|
183 |
+
or isinstance(cached_result, CustomStreamWrapper)
|
184 |
+
) and hasattr(cached_result, "_hidden_params"):
|
185 |
+
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
|
186 |
+
return CachingHandlerResponse(cached_result=cached_result)
|
187 |
+
elif (
|
188 |
+
call_type == CallTypes.aembedding.value
|
189 |
+
and cached_result is not None
|
190 |
+
and isinstance(cached_result, list)
|
191 |
+
and litellm.cache is not None
|
192 |
+
and not isinstance(
|
193 |
+
litellm.cache.cache, S3Cache
|
194 |
+
) # s3 doesn't support bulk writing. Exclude.
|
195 |
+
):
|
196 |
+
(
|
197 |
+
final_embedding_cached_response,
|
198 |
+
embedding_all_elements_cache_hit,
|
199 |
+
) = self._process_async_embedding_cached_response(
|
200 |
+
final_embedding_cached_response=final_embedding_cached_response,
|
201 |
+
cached_result=cached_result,
|
202 |
+
kwargs=kwargs,
|
203 |
+
logging_obj=logging_obj,
|
204 |
+
start_time=start_time,
|
205 |
+
model=model,
|
206 |
+
)
|
207 |
+
return CachingHandlerResponse(
|
208 |
+
final_embedding_cached_response=final_embedding_cached_response,
|
209 |
+
embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
|
210 |
+
)
|
211 |
+
verbose_logger.debug(f"CACHE RESULT: {cached_result}")
|
212 |
+
return CachingHandlerResponse(
|
213 |
+
cached_result=cached_result,
|
214 |
+
final_embedding_cached_response=final_embedding_cached_response,
|
215 |
+
)
|
216 |
+
|
217 |
+
def _sync_get_cache(
|
218 |
+
self,
|
219 |
+
model: str,
|
220 |
+
original_function: Callable,
|
221 |
+
logging_obj: LiteLLMLoggingObj,
|
222 |
+
start_time: datetime.datetime,
|
223 |
+
call_type: str,
|
224 |
+
kwargs: Dict[str, Any],
|
225 |
+
args: Optional[Tuple[Any, ...]] = None,
|
226 |
+
) -> CachingHandlerResponse:
|
227 |
+
from litellm.utils import CustomStreamWrapper
|
228 |
+
|
229 |
+
args = args or ()
|
230 |
+
new_kwargs = kwargs.copy()
|
231 |
+
new_kwargs.update(
|
232 |
+
convert_args_to_kwargs(
|
233 |
+
self.original_function,
|
234 |
+
args,
|
235 |
+
)
|
236 |
+
)
|
237 |
+
cached_result: Optional[Any] = None
|
238 |
+
if litellm.cache is not None and self._is_call_type_supported_by_cache(
|
239 |
+
original_function=original_function
|
240 |
+
):
|
241 |
+
print_verbose("Checking Sync Cache")
|
242 |
+
cached_result = litellm.cache.get_cache(**new_kwargs)
|
243 |
+
if cached_result is not None:
|
244 |
+
if "detail" in cached_result:
|
245 |
+
# implies an error occurred
|
246 |
+
pass
|
247 |
+
else:
|
248 |
+
call_type = original_function.__name__
|
249 |
+
cached_result = self._convert_cached_result_to_model_response(
|
250 |
+
cached_result=cached_result,
|
251 |
+
call_type=call_type,
|
252 |
+
kwargs=kwargs,
|
253 |
+
logging_obj=logging_obj,
|
254 |
+
model=model,
|
255 |
+
custom_llm_provider=kwargs.get("custom_llm_provider", None),
|
256 |
+
args=args,
|
257 |
+
)
|
258 |
+
|
259 |
+
# LOG SUCCESS
|
260 |
+
cache_hit = True
|
261 |
+
end_time = datetime.datetime.now()
|
262 |
+
(
|
263 |
+
model,
|
264 |
+
custom_llm_provider,
|
265 |
+
dynamic_api_key,
|
266 |
+
api_base,
|
267 |
+
) = litellm.get_llm_provider(
|
268 |
+
model=model or "",
|
269 |
+
custom_llm_provider=kwargs.get("custom_llm_provider", None),
|
270 |
+
api_base=kwargs.get("api_base", None),
|
271 |
+
api_key=kwargs.get("api_key", None),
|
272 |
+
)
|
273 |
+
self._update_litellm_logging_obj_environment(
|
274 |
+
logging_obj=logging_obj,
|
275 |
+
model=model,
|
276 |
+
kwargs=kwargs,
|
277 |
+
cached_result=cached_result,
|
278 |
+
is_async=False,
|
279 |
+
)
|
280 |
+
|
281 |
+
threading.Thread(
|
282 |
+
target=logging_obj.success_handler,
|
283 |
+
args=(cached_result, start_time, end_time, cache_hit),
|
284 |
+
).start()
|
285 |
+
cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
|
286 |
+
**kwargs
|
287 |
+
)
|
288 |
+
if (
|
289 |
+
isinstance(cached_result, BaseModel)
|
290 |
+
or isinstance(cached_result, CustomStreamWrapper)
|
291 |
+
) and hasattr(cached_result, "_hidden_params"):
|
292 |
+
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
|
293 |
+
return CachingHandlerResponse(cached_result=cached_result)
|
294 |
+
return CachingHandlerResponse(cached_result=cached_result)
|
295 |
+
|
296 |
+
def _process_async_embedding_cached_response(
|
297 |
+
self,
|
298 |
+
final_embedding_cached_response: Optional[EmbeddingResponse],
|
299 |
+
cached_result: List[Optional[Dict[str, Any]]],
|
300 |
+
kwargs: Dict[str, Any],
|
301 |
+
logging_obj: LiteLLMLoggingObj,
|
302 |
+
start_time: datetime.datetime,
|
303 |
+
model: str,
|
304 |
+
) -> Tuple[Optional[EmbeddingResponse], bool]:
|
305 |
+
"""
|
306 |
+
Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
|
307 |
+
|
308 |
+
For embedding responses, there can be a cache hit for some of the inputs in the list and a cache miss for others
|
309 |
+
This function processes the cached embedding responses and returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
|
310 |
+
|
311 |
+
Args:
|
312 |
+
final_embedding_cached_response: Optional[EmbeddingResponse]:
|
313 |
+
cached_result: List[Optional[Dict[str, Any]]]:
|
314 |
+
kwargs: Dict[str, Any]:
|
315 |
+
logging_obj: LiteLLMLoggingObj:
|
316 |
+
start_time: datetime.datetime:
|
317 |
+
model: str:
|
318 |
+
|
319 |
+
Returns:
|
320 |
+
Tuple[Optional[EmbeddingResponse], bool]:
|
321 |
+
Returns the final embedding cached response and a boolean indicating if all elements in the list have a cache hit
|
322 |
+
|
323 |
+
|
324 |
+
"""
|
325 |
+
embedding_all_elements_cache_hit: bool = False
|
326 |
+
remaining_list = []
|
327 |
+
non_null_list = []
|
328 |
+
for idx, cr in enumerate(cached_result):
|
329 |
+
if cr is None:
|
330 |
+
remaining_list.append(kwargs["input"][idx])
|
331 |
+
else:
|
332 |
+
non_null_list.append((idx, cr))
|
333 |
+
original_kwargs_input = kwargs["input"]
|
334 |
+
kwargs["input"] = remaining_list
|
335 |
+
if len(non_null_list) > 0:
|
336 |
+
print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}")
|
337 |
+
final_embedding_cached_response = EmbeddingResponse(
|
338 |
+
model=kwargs.get("model"),
|
339 |
+
data=[None] * len(original_kwargs_input),
|
340 |
+
)
|
341 |
+
final_embedding_cached_response._hidden_params["cache_hit"] = True
|
342 |
+
|
343 |
+
prompt_tokens = 0
|
344 |
+
for val in non_null_list:
|
345 |
+
idx, cr = val # (idx, cr) tuple
|
346 |
+
if cr is not None:
|
347 |
+
final_embedding_cached_response.data[idx] = Embedding(
|
348 |
+
embedding=cr["embedding"],
|
349 |
+
index=idx,
|
350 |
+
object="embedding",
|
351 |
+
)
|
352 |
+
if isinstance(original_kwargs_input[idx], str):
|
353 |
+
from litellm.utils import token_counter
|
354 |
+
|
355 |
+
prompt_tokens += token_counter(
|
356 |
+
text=original_kwargs_input[idx], count_response_tokens=True
|
357 |
+
)
|
358 |
+
## USAGE
|
359 |
+
usage = Usage(
|
360 |
+
prompt_tokens=prompt_tokens,
|
361 |
+
completion_tokens=0,
|
362 |
+
total_tokens=prompt_tokens,
|
363 |
+
)
|
364 |
+
final_embedding_cached_response.usage = usage
|
365 |
+
if len(remaining_list) == 0:
|
366 |
+
# LOG SUCCESS
|
367 |
+
cache_hit = True
|
368 |
+
embedding_all_elements_cache_hit = True
|
369 |
+
end_time = datetime.datetime.now()
|
370 |
+
(
|
371 |
+
model,
|
372 |
+
custom_llm_provider,
|
373 |
+
dynamic_api_key,
|
374 |
+
api_base,
|
375 |
+
) = litellm.get_llm_provider(
|
376 |
+
model=model,
|
377 |
+
custom_llm_provider=kwargs.get("custom_llm_provider", None),
|
378 |
+
api_base=kwargs.get("api_base", None),
|
379 |
+
api_key=kwargs.get("api_key", None),
|
380 |
+
)
|
381 |
+
|
382 |
+
self._update_litellm_logging_obj_environment(
|
383 |
+
logging_obj=logging_obj,
|
384 |
+
model=model,
|
385 |
+
kwargs=kwargs,
|
386 |
+
cached_result=final_embedding_cached_response,
|
387 |
+
is_async=True,
|
388 |
+
is_embedding=True,
|
389 |
+
)
|
390 |
+
self._async_log_cache_hit_on_callbacks(
|
391 |
+
logging_obj=logging_obj,
|
392 |
+
cached_result=final_embedding_cached_response,
|
393 |
+
start_time=start_time,
|
394 |
+
end_time=end_time,
|
395 |
+
cache_hit=cache_hit,
|
396 |
+
)
|
397 |
+
return final_embedding_cached_response, embedding_all_elements_cache_hit
|
398 |
+
return final_embedding_cached_response, embedding_all_elements_cache_hit
|
399 |
+
|
400 |
+
def combine_usage(self, usage1: Usage, usage2: Usage) -> Usage:
|
401 |
+
return Usage(
|
402 |
+
prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
|
403 |
+
completion_tokens=usage1.completion_tokens + usage2.completion_tokens,
|
404 |
+
total_tokens=usage1.total_tokens + usage2.total_tokens,
|
405 |
+
)
|
406 |
+
|
407 |
+
def _combine_cached_embedding_response_with_api_result(
|
408 |
+
self,
|
409 |
+
_caching_handler_response: CachingHandlerResponse,
|
410 |
+
embedding_response: EmbeddingResponse,
|
411 |
+
start_time: datetime.datetime,
|
412 |
+
end_time: datetime.datetime,
|
413 |
+
) -> EmbeddingResponse:
|
414 |
+
"""
|
415 |
+
Combines the cached embedding response with the API EmbeddingResponse
|
416 |
+
|
417 |
+
For caching there can be a cache hit for some of the inputs in the list and a cache miss for others
|
418 |
+
This function combines the cached embedding response with the API EmbeddingResponse
|
419 |
+
|
420 |
+
Args:
|
421 |
+
caching_handler_response: CachingHandlerResponse:
|
422 |
+
embedding_response: EmbeddingResponse:
|
423 |
+
|
424 |
+
Returns:
|
425 |
+
EmbeddingResponse:
|
426 |
+
"""
|
427 |
+
if _caching_handler_response.final_embedding_cached_response is None:
|
428 |
+
return embedding_response
|
429 |
+
|
430 |
+
idx = 0
|
431 |
+
final_data_list = []
|
432 |
+
for item in _caching_handler_response.final_embedding_cached_response.data:
|
433 |
+
if item is None and embedding_response.data is not None:
|
434 |
+
final_data_list.append(embedding_response.data[idx])
|
435 |
+
idx += 1
|
436 |
+
else:
|
437 |
+
final_data_list.append(item)
|
438 |
+
|
439 |
+
_caching_handler_response.final_embedding_cached_response.data = final_data_list
|
440 |
+
_caching_handler_response.final_embedding_cached_response._hidden_params[
|
441 |
+
"cache_hit"
|
442 |
+
] = True
|
443 |
+
_caching_handler_response.final_embedding_cached_response._response_ms = (
|
444 |
+
end_time - start_time
|
445 |
+
).total_seconds() * 1000
|
446 |
+
|
447 |
+
## USAGE
|
448 |
+
if (
|
449 |
+
_caching_handler_response.final_embedding_cached_response.usage is not None
|
450 |
+
and embedding_response.usage is not None
|
451 |
+
):
|
452 |
+
_caching_handler_response.final_embedding_cached_response.usage = self.combine_usage(
|
453 |
+
usage1=_caching_handler_response.final_embedding_cached_response.usage,
|
454 |
+
usage2=embedding_response.usage,
|
455 |
+
)
|
456 |
+
|
457 |
+
return _caching_handler_response.final_embedding_cached_response
|
458 |
+
|
459 |
+
def _async_log_cache_hit_on_callbacks(
|
460 |
+
self,
|
461 |
+
logging_obj: LiteLLMLoggingObj,
|
462 |
+
cached_result: Any,
|
463 |
+
start_time: datetime.datetime,
|
464 |
+
end_time: datetime.datetime,
|
465 |
+
cache_hit: bool,
|
466 |
+
):
|
467 |
+
"""
|
468 |
+
Helper function to log the success of a cached result on callbacks
|
469 |
+
|
470 |
+
Args:
|
471 |
+
logging_obj (LiteLLMLoggingObj): The logging object.
|
472 |
+
cached_result: The cached result.
|
473 |
+
start_time (datetime): The start time of the operation.
|
474 |
+
end_time (datetime): The end time of the operation.
|
475 |
+
cache_hit (bool): Whether it was a cache hit.
|
476 |
+
"""
|
477 |
+
asyncio.create_task(
|
478 |
+
logging_obj.async_success_handler(
|
479 |
+
cached_result, start_time, end_time, cache_hit
|
480 |
+
)
|
481 |
+
)
|
482 |
+
threading.Thread(
|
483 |
+
target=logging_obj.success_handler,
|
484 |
+
args=(cached_result, start_time, end_time, cache_hit),
|
485 |
+
).start()
|
486 |
+
|
487 |
+
async def _retrieve_from_cache(
|
488 |
+
self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
|
489 |
+
) -> Optional[Any]:
|
490 |
+
"""
|
491 |
+
Internal method to
|
492 |
+
- get cache key
|
493 |
+
- check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
|
494 |
+
- async get cache value
|
495 |
+
- return the cached value
|
496 |
+
|
497 |
+
Args:
|
498 |
+
call_type: str:
|
499 |
+
kwargs: Dict[str, Any]:
|
500 |
+
args: Optional[Tuple[Any, ...]] = None:
|
501 |
+
|
502 |
+
Returns:
|
503 |
+
Optional[Any]:
|
504 |
+
Raises:
|
505 |
+
None
|
506 |
+
"""
|
507 |
+
if litellm.cache is None:
|
508 |
+
return None
|
509 |
+
|
510 |
+
new_kwargs = kwargs.copy()
|
511 |
+
new_kwargs.update(
|
512 |
+
convert_args_to_kwargs(
|
513 |
+
self.original_function,
|
514 |
+
args,
|
515 |
+
)
|
516 |
+
)
|
517 |
+
cached_result: Optional[Any] = None
|
518 |
+
if call_type == CallTypes.aembedding.value and isinstance(
|
519 |
+
new_kwargs["input"], list
|
520 |
+
):
|
521 |
+
tasks = []
|
522 |
+
for idx, i in enumerate(new_kwargs["input"]):
|
523 |
+
preset_cache_key = litellm.cache.get_cache_key(
|
524 |
+
**{**new_kwargs, "input": i}
|
525 |
+
)
|
526 |
+
tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
|
527 |
+
cached_result = await asyncio.gather(*tasks)
|
528 |
+
## check if cached result is None ##
|
529 |
+
if cached_result is not None and isinstance(cached_result, list):
|
530 |
+
# set cached_result to None if all elements are None
|
531 |
+
if all(result is None for result in cached_result):
|
532 |
+
cached_result = None
|
533 |
+
else:
|
534 |
+
if litellm.cache._supports_async() is True:
|
535 |
+
cached_result = await litellm.cache.async_get_cache(**new_kwargs)
|
536 |
+
else: # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
|
537 |
+
cached_result = litellm.cache.get_cache(**new_kwargs)
|
538 |
+
return cached_result
|
539 |
+
|
540 |
+
def _convert_cached_result_to_model_response(
|
541 |
+
self,
|
542 |
+
cached_result: Any,
|
543 |
+
call_type: str,
|
544 |
+
kwargs: Dict[str, Any],
|
545 |
+
logging_obj: LiteLLMLoggingObj,
|
546 |
+
model: str,
|
547 |
+
args: Tuple[Any, ...],
|
548 |
+
custom_llm_provider: Optional[str] = None,
|
549 |
+
) -> Optional[
|
550 |
+
Union[
|
551 |
+
ModelResponse,
|
552 |
+
TextCompletionResponse,
|
553 |
+
EmbeddingResponse,
|
554 |
+
RerankResponse,
|
555 |
+
TranscriptionResponse,
|
556 |
+
CustomStreamWrapper,
|
557 |
+
]
|
558 |
+
]:
|
559 |
+
"""
|
560 |
+
Internal method to process the cached result
|
561 |
+
|
562 |
+
Checks the call type and converts the cached result to the appropriate model response object
|
563 |
+
example if call type is text_completion -> returns TextCompletionResponse object
|
564 |
+
|
565 |
+
Args:
|
566 |
+
cached_result: Any:
|
567 |
+
call_type: str:
|
568 |
+
kwargs: Dict[str, Any]:
|
569 |
+
logging_obj: LiteLLMLoggingObj:
|
570 |
+
model: str:
|
571 |
+
custom_llm_provider: Optional[str] = None:
|
572 |
+
args: Optional[Tuple[Any, ...]] = None:
|
573 |
+
|
574 |
+
Returns:
|
575 |
+
Optional[Any]:
|
576 |
+
"""
|
577 |
+
from litellm.utils import convert_to_model_response_object
|
578 |
+
|
579 |
+
if (
|
580 |
+
call_type == CallTypes.acompletion.value
|
581 |
+
or call_type == CallTypes.completion.value
|
582 |
+
) and isinstance(cached_result, dict):
|
583 |
+
if kwargs.get("stream", False) is True:
|
584 |
+
cached_result = self._convert_cached_stream_response(
|
585 |
+
cached_result=cached_result,
|
586 |
+
call_type=call_type,
|
587 |
+
logging_obj=logging_obj,
|
588 |
+
model=model,
|
589 |
+
)
|
590 |
+
else:
|
591 |
+
cached_result = convert_to_model_response_object(
|
592 |
+
response_object=cached_result,
|
593 |
+
model_response_object=ModelResponse(),
|
594 |
+
)
|
595 |
+
if (
|
596 |
+
call_type == CallTypes.atext_completion.value
|
597 |
+
or call_type == CallTypes.text_completion.value
|
598 |
+
) and isinstance(cached_result, dict):
|
599 |
+
if kwargs.get("stream", False) is True:
|
600 |
+
cached_result = self._convert_cached_stream_response(
|
601 |
+
cached_result=cached_result,
|
602 |
+
call_type=call_type,
|
603 |
+
logging_obj=logging_obj,
|
604 |
+
model=model,
|
605 |
+
)
|
606 |
+
else:
|
607 |
+
cached_result = TextCompletionResponse(**cached_result)
|
608 |
+
elif (
|
609 |
+
call_type == CallTypes.aembedding.value
|
610 |
+
or call_type == CallTypes.embedding.value
|
611 |
+
) and isinstance(cached_result, dict):
|
612 |
+
cached_result = convert_to_model_response_object(
|
613 |
+
response_object=cached_result,
|
614 |
+
model_response_object=EmbeddingResponse(),
|
615 |
+
response_type="embedding",
|
616 |
+
)
|
617 |
+
|
618 |
+
elif (
|
619 |
+
call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
|
620 |
+
) and isinstance(cached_result, dict):
|
621 |
+
cached_result = convert_to_model_response_object(
|
622 |
+
response_object=cached_result,
|
623 |
+
model_response_object=None,
|
624 |
+
response_type="rerank",
|
625 |
+
)
|
626 |
+
elif (
|
627 |
+
call_type == CallTypes.atranscription.value
|
628 |
+
or call_type == CallTypes.transcription.value
|
629 |
+
) and isinstance(cached_result, dict):
|
630 |
+
hidden_params = {
|
631 |
+
"model": "whisper-1",
|
632 |
+
"custom_llm_provider": custom_llm_provider,
|
633 |
+
"cache_hit": True,
|
634 |
+
}
|
635 |
+
cached_result = convert_to_model_response_object(
|
636 |
+
response_object=cached_result,
|
637 |
+
model_response_object=TranscriptionResponse(),
|
638 |
+
response_type="audio_transcription",
|
639 |
+
hidden_params=hidden_params,
|
640 |
+
)
|
641 |
+
|
642 |
+
if (
|
643 |
+
hasattr(cached_result, "_hidden_params")
|
644 |
+
and cached_result._hidden_params is not None
|
645 |
+
and isinstance(cached_result._hidden_params, dict)
|
646 |
+
):
|
647 |
+
cached_result._hidden_params["cache_hit"] = True
|
648 |
+
return cached_result
|
649 |
+
|
650 |
+
def _convert_cached_stream_response(
|
651 |
+
self,
|
652 |
+
cached_result: Any,
|
653 |
+
call_type: str,
|
654 |
+
logging_obj: LiteLLMLoggingObj,
|
655 |
+
model: str,
|
656 |
+
) -> CustomStreamWrapper:
|
657 |
+
from litellm.utils import (
|
658 |
+
CustomStreamWrapper,
|
659 |
+
convert_to_streaming_response,
|
660 |
+
convert_to_streaming_response_async,
|
661 |
+
)
|
662 |
+
|
663 |
+
_stream_cached_result: Union[AsyncGenerator, Generator]
|
664 |
+
if (
|
665 |
+
call_type == CallTypes.acompletion.value
|
666 |
+
or call_type == CallTypes.atext_completion.value
|
667 |
+
):
|
668 |
+
_stream_cached_result = convert_to_streaming_response_async(
|
669 |
+
response_object=cached_result,
|
670 |
+
)
|
671 |
+
else:
|
672 |
+
_stream_cached_result = convert_to_streaming_response(
|
673 |
+
response_object=cached_result,
|
674 |
+
)
|
675 |
+
return CustomStreamWrapper(
|
676 |
+
completion_stream=_stream_cached_result,
|
677 |
+
model=model,
|
678 |
+
custom_llm_provider="cached_response",
|
679 |
+
logging_obj=logging_obj,
|
680 |
+
)
|
681 |
+
|
682 |
+
async def async_set_cache(
|
683 |
+
self,
|
684 |
+
result: Any,
|
685 |
+
original_function: Callable,
|
686 |
+
kwargs: Dict[str, Any],
|
687 |
+
args: Optional[Tuple[Any, ...]] = None,
|
688 |
+
):
|
689 |
+
"""
|
690 |
+
Internal method to check the type of the result & cache used and adds the result to the cache accordingly
|
691 |
+
|
692 |
+
Args:
|
693 |
+
result: Any:
|
694 |
+
original_function: Callable:
|
695 |
+
kwargs: Dict[str, Any]:
|
696 |
+
args: Optional[Tuple[Any, ...]] = None:
|
697 |
+
|
698 |
+
Returns:
|
699 |
+
None
|
700 |
+
Raises:
|
701 |
+
None
|
702 |
+
"""
|
703 |
+
if litellm.cache is None:
|
704 |
+
return
|
705 |
+
|
706 |
+
new_kwargs = kwargs.copy()
|
707 |
+
new_kwargs.update(
|
708 |
+
convert_args_to_kwargs(
|
709 |
+
original_function,
|
710 |
+
args,
|
711 |
+
)
|
712 |
+
)
|
713 |
+
# [OPTIONAL] ADD TO CACHE
|
714 |
+
if self._should_store_result_in_cache(
|
715 |
+
original_function=original_function, kwargs=new_kwargs
|
716 |
+
):
|
717 |
+
if (
|
718 |
+
isinstance(result, litellm.ModelResponse)
|
719 |
+
or isinstance(result, litellm.EmbeddingResponse)
|
720 |
+
or isinstance(result, TranscriptionResponse)
|
721 |
+
or isinstance(result, RerankResponse)
|
722 |
+
):
|
723 |
+
if (
|
724 |
+
isinstance(result, EmbeddingResponse)
|
725 |
+
and litellm.cache is not None
|
726 |
+
and not isinstance(
|
727 |
+
litellm.cache.cache, S3Cache
|
728 |
+
) # s3 doesn't support bulk writing. Exclude.
|
729 |
+
):
|
730 |
+
asyncio.create_task(
|
731 |
+
litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
|
732 |
+
)
|
733 |
+
elif isinstance(litellm.cache.cache, S3Cache):
|
734 |
+
threading.Thread(
|
735 |
+
target=litellm.cache.add_cache,
|
736 |
+
args=(result,),
|
737 |
+
kwargs=new_kwargs,
|
738 |
+
).start()
|
739 |
+
else:
|
740 |
+
asyncio.create_task(
|
741 |
+
litellm.cache.async_add_cache(
|
742 |
+
result.model_dump_json(), **new_kwargs
|
743 |
+
)
|
744 |
+
)
|
745 |
+
else:
|
746 |
+
asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))
|
747 |
+
|
748 |
+
def sync_set_cache(
|
749 |
+
self,
|
750 |
+
result: Any,
|
751 |
+
kwargs: Dict[str, Any],
|
752 |
+
args: Optional[Tuple[Any, ...]] = None,
|
753 |
+
):
|
754 |
+
"""
|
755 |
+
Sync internal method to add the result to the cache
|
756 |
+
"""
|
757 |
+
|
758 |
+
new_kwargs = kwargs.copy()
|
759 |
+
new_kwargs.update(
|
760 |
+
convert_args_to_kwargs(
|
761 |
+
self.original_function,
|
762 |
+
args,
|
763 |
+
)
|
764 |
+
)
|
765 |
+
if litellm.cache is None:
|
766 |
+
return
|
767 |
+
|
768 |
+
if self._should_store_result_in_cache(
|
769 |
+
original_function=self.original_function, kwargs=new_kwargs
|
770 |
+
):
|
771 |
+
litellm.cache.add_cache(result, **new_kwargs)
|
772 |
+
|
773 |
+
return
|
774 |
+
|
775 |
+
def _should_store_result_in_cache(
|
776 |
+
self, original_function: Callable, kwargs: Dict[str, Any]
|
777 |
+
) -> bool:
|
778 |
+
"""
|
779 |
+
Helper function to determine if the result should be stored in the cache.
|
780 |
+
|
781 |
+
Returns:
|
782 |
+
bool: True if the result should be stored in the cache, False otherwise.
|
783 |
+
"""
|
784 |
+
return (
|
785 |
+
(litellm.cache is not None)
|
786 |
+
and litellm.cache.supported_call_types is not None
|
787 |
+
and (str(original_function.__name__) in litellm.cache.supported_call_types)
|
788 |
+
and (kwargs.get("cache", {}).get("no-store", False) is not True)
|
789 |
+
)
|
790 |
+
|
791 |
+
def _is_call_type_supported_by_cache(
|
792 |
+
self,
|
793 |
+
original_function: Callable,
|
794 |
+
) -> bool:
|
795 |
+
"""
|
796 |
+
Helper function to determine if the call type is supported by the cache.
|
797 |
+
|
798 |
+
call types are acompletion, aembedding, atext_completion, atranscription, arerank
|
799 |
+
|
800 |
+
Defined on `litellm.types.utils.CallTypes`
|
801 |
+
|
802 |
+
Returns:
|
803 |
+
bool: True if the call type is supported by the cache, False otherwise.
|
804 |
+
"""
|
805 |
+
if (
|
806 |
+
litellm.cache is not None
|
807 |
+
and litellm.cache.supported_call_types is not None
|
808 |
+
and str(original_function.__name__) in litellm.cache.supported_call_types
|
809 |
+
):
|
810 |
+
return True
|
811 |
+
return False
|
812 |
+
|
813 |
+
async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
|
814 |
+
"""
|
815 |
+
Internal method to add the streaming response to the cache
|
816 |
+
|
817 |
+
|
818 |
+
- If 'streaming_chunk' has a 'finish_reason' then assemble a litellm.ModelResponse object
|
819 |
+
- Else append the chunk to self.async_streaming_chunks
|
820 |
+
|
821 |
+
"""
|
822 |
+
|
823 |
+
complete_streaming_response: Optional[
|
824 |
+
Union[ModelResponse, TextCompletionResponse]
|
825 |
+
] = _assemble_complete_response_from_streaming_chunks(
|
826 |
+
result=processed_chunk,
|
827 |
+
start_time=self.start_time,
|
828 |
+
end_time=datetime.datetime.now(),
|
829 |
+
request_kwargs=self.request_kwargs,
|
830 |
+
streaming_chunks=self.async_streaming_chunks,
|
831 |
+
is_async=True,
|
832 |
+
)
|
833 |
+
# if a complete_streaming_response is assembled, add it to the cache
|
834 |
+
if complete_streaming_response is not None:
|
835 |
+
await self.async_set_cache(
|
836 |
+
result=complete_streaming_response,
|
837 |
+
original_function=self.original_function,
|
838 |
+
kwargs=self.request_kwargs,
|
839 |
+
)
|
840 |
+
|
841 |
+
def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
|
842 |
+
"""
|
843 |
+
Sync internal method to add the streaming response to the cache
|
844 |
+
"""
|
845 |
+
complete_streaming_response: Optional[
|
846 |
+
Union[ModelResponse, TextCompletionResponse]
|
847 |
+
] = _assemble_complete_response_from_streaming_chunks(
|
848 |
+
result=processed_chunk,
|
849 |
+
start_time=self.start_time,
|
850 |
+
end_time=datetime.datetime.now(),
|
851 |
+
request_kwargs=self.request_kwargs,
|
852 |
+
streaming_chunks=self.sync_streaming_chunks,
|
853 |
+
is_async=False,
|
854 |
+
)
|
855 |
+
|
856 |
+
# if a complete_streaming_response is assembled, add it to the cache
|
857 |
+
if complete_streaming_response is not None:
|
858 |
+
self.sync_set_cache(
|
859 |
+
result=complete_streaming_response,
|
860 |
+
kwargs=self.request_kwargs,
|
861 |
+
)
|
862 |
+
|
863 |
+
def _update_litellm_logging_obj_environment(
|
864 |
+
self,
|
865 |
+
logging_obj: LiteLLMLoggingObj,
|
866 |
+
model: str,
|
867 |
+
kwargs: Dict[str, Any],
|
868 |
+
cached_result: Any,
|
869 |
+
is_async: bool,
|
870 |
+
is_embedding: bool = False,
|
871 |
+
):
|
872 |
+
"""
|
873 |
+
Helper function to update the LiteLLMLoggingObj environment variables.
|
874 |
+
|
875 |
+
Args:
|
876 |
+
logging_obj (LiteLLMLoggingObj): The logging object to update.
|
877 |
+
model (str): The model being used.
|
878 |
+
kwargs (Dict[str, Any]): The keyword arguments from the original function call.
|
879 |
+
cached_result (Any): The cached result to log.
|
880 |
+
is_async (bool): Whether the call is asynchronous or not.
|
881 |
+
is_embedding (bool): Whether the call is for embeddings or not.
|
882 |
+
|
883 |
+
Returns:
|
884 |
+
None
|
885 |
+
"""
|
886 |
+
litellm_params = {
|
887 |
+
"logger_fn": kwargs.get("logger_fn", None),
|
888 |
+
"acompletion": is_async,
|
889 |
+
"api_base": kwargs.get("api_base", ""),
|
890 |
+
"metadata": kwargs.get("metadata", {}),
|
891 |
+
"model_info": kwargs.get("model_info", {}),
|
892 |
+
"proxy_server_request": kwargs.get("proxy_server_request", None),
|
893 |
+
"stream_response": kwargs.get("stream_response", {}),
|
894 |
+
}
|
895 |
+
|
896 |
+
if litellm.cache is not None:
|
897 |
+
litellm_params[
|
898 |
+
"preset_cache_key"
|
899 |
+
] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
|
900 |
+
else:
|
901 |
+
litellm_params["preset_cache_key"] = None
|
902 |
+
|
903 |
+
logging_obj.update_environment_variables(
|
904 |
+
model=model,
|
905 |
+
user=kwargs.get("user", None),
|
906 |
+
optional_params={},
|
907 |
+
litellm_params=litellm_params,
|
908 |
+
input=(
|
909 |
+
kwargs.get("messages", "")
|
910 |
+
if not is_embedding
|
911 |
+
else kwargs.get("input", "")
|
912 |
+
),
|
913 |
+
api_key=kwargs.get("api_key", None),
|
914 |
+
original_response=str(cached_result),
|
915 |
+
additional_args=None,
|
916 |
+
stream=kwargs.get("stream", False),
|
917 |
+
)
|
918 |
+
|
919 |
+
|
920 |
+
def convert_args_to_kwargs(
|
921 |
+
original_function: Callable,
|
922 |
+
args: Optional[Tuple[Any, ...]] = None,
|
923 |
+
) -> Dict[str, Any]:
|
924 |
+
# Get the signature of the original function
|
925 |
+
signature = inspect.signature(original_function)
|
926 |
+
|
927 |
+
# Get parameter names in the order they appear in the original function
|
928 |
+
param_names = list(signature.parameters.keys())
|
929 |
+
|
930 |
+
# Create a mapping of positional arguments to parameter names
|
931 |
+
args_to_kwargs = {}
|
932 |
+
if args:
|
933 |
+
for index, arg in enumerate(args):
|
934 |
+
if index < len(param_names):
|
935 |
+
param_name = param_names[index]
|
936 |
+
args_to_kwargs[param_name] = arg
|
937 |
+
|
938 |
+
return args_to_kwargs
|
litellm/caching/disk_cache.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
3 |
+
|
4 |
+
from .base_cache import BaseCache
|
5 |
+
|
6 |
+
if TYPE_CHECKING:
|
7 |
+
from opentelemetry.trace import Span as _Span
|
8 |
+
|
9 |
+
Span = Union[_Span, Any]
|
10 |
+
else:
|
11 |
+
Span = Any
|
12 |
+
|
13 |
+
|
14 |
+
class DiskCache(BaseCache):
|
15 |
+
def __init__(self, disk_cache_dir: Optional[str] = None):
|
16 |
+
import diskcache as dc
|
17 |
+
|
18 |
+
# if users don't provider one, use the default litellm cache
|
19 |
+
if disk_cache_dir is None:
|
20 |
+
self.disk_cache = dc.Cache(".litellm_cache")
|
21 |
+
else:
|
22 |
+
self.disk_cache = dc.Cache(disk_cache_dir)
|
23 |
+
|
24 |
+
def set_cache(self, key, value, **kwargs):
|
25 |
+
if "ttl" in kwargs:
|
26 |
+
self.disk_cache.set(key, value, expire=kwargs["ttl"])
|
27 |
+
else:
|
28 |
+
self.disk_cache.set(key, value)
|
29 |
+
|
30 |
+
async def async_set_cache(self, key, value, **kwargs):
|
31 |
+
self.set_cache(key=key, value=value, **kwargs)
|
32 |
+
|
33 |
+
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
34 |
+
for cache_key, cache_value in cache_list:
|
35 |
+
if "ttl" in kwargs:
|
36 |
+
self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
|
37 |
+
else:
|
38 |
+
self.set_cache(key=cache_key, value=cache_value)
|
39 |
+
|
40 |
+
def get_cache(self, key, **kwargs):
|
41 |
+
original_cached_response = self.disk_cache.get(key)
|
42 |
+
if original_cached_response:
|
43 |
+
try:
|
44 |
+
cached_response = json.loads(original_cached_response) # type: ignore
|
45 |
+
except Exception:
|
46 |
+
cached_response = original_cached_response
|
47 |
+
return cached_response
|
48 |
+
return None
|
49 |
+
|
50 |
+
def batch_get_cache(self, keys: list, **kwargs):
|
51 |
+
return_val = []
|
52 |
+
for k in keys:
|
53 |
+
val = self.get_cache(key=k, **kwargs)
|
54 |
+
return_val.append(val)
|
55 |
+
return return_val
|
56 |
+
|
57 |
+
def increment_cache(self, key, value: int, **kwargs) -> int:
|
58 |
+
# get the value
|
59 |
+
init_value = self.get_cache(key=key) or 0
|
60 |
+
value = init_value + value # type: ignore
|
61 |
+
self.set_cache(key, value, **kwargs)
|
62 |
+
return value
|
63 |
+
|
64 |
+
async def async_get_cache(self, key, **kwargs):
|
65 |
+
return self.get_cache(key=key, **kwargs)
|
66 |
+
|
67 |
+
async def async_batch_get_cache(self, keys: list, **kwargs):
|
68 |
+
return_val = []
|
69 |
+
for k in keys:
|
70 |
+
val = self.get_cache(key=k, **kwargs)
|
71 |
+
return_val.append(val)
|
72 |
+
return return_val
|
73 |
+
|
74 |
+
async def async_increment(self, key, value: int, **kwargs) -> int:
|
75 |
+
# get the value
|
76 |
+
init_value = await self.async_get_cache(key=key) or 0
|
77 |
+
value = init_value + value # type: ignore
|
78 |
+
await self.async_set_cache(key, value, **kwargs)
|
79 |
+
return value
|
80 |
+
|
81 |
+
def flush_cache(self):
|
82 |
+
self.disk_cache.clear()
|
83 |
+
|
84 |
+
async def disconnect(self):
|
85 |
+
pass
|
86 |
+
|
87 |
+
def delete_cache(self, key):
|
88 |
+
self.disk_cache.pop(key)
|
litellm/caching/dual_cache.py
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.
|
3 |
+
|
4 |
+
Has 4 primary methods:
|
5 |
+
- set_cache
|
6 |
+
- get_cache
|
7 |
+
- async_set_cache
|
8 |
+
- async_get_cache
|
9 |
+
"""
|
10 |
+
|
11 |
+
import asyncio
|
12 |
+
import time
|
13 |
+
import traceback
|
14 |
+
from concurrent.futures import ThreadPoolExecutor
|
15 |
+
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
16 |
+
|
17 |
+
import litellm
|
18 |
+
from litellm._logging import print_verbose, verbose_logger
|
19 |
+
|
20 |
+
from .base_cache import BaseCache
|
21 |
+
from .in_memory_cache import InMemoryCache
|
22 |
+
from .redis_cache import RedisCache
|
23 |
+
|
24 |
+
if TYPE_CHECKING:
|
25 |
+
from opentelemetry.trace import Span as _Span
|
26 |
+
|
27 |
+
Span = Union[_Span, Any]
|
28 |
+
else:
|
29 |
+
Span = Any
|
30 |
+
|
31 |
+
from collections import OrderedDict
|
32 |
+
|
33 |
+
|
34 |
+
class LimitedSizeOrderedDict(OrderedDict):
|
35 |
+
def __init__(self, *args, max_size=100, **kwargs):
|
36 |
+
super().__init__(*args, **kwargs)
|
37 |
+
self.max_size = max_size
|
38 |
+
|
39 |
+
def __setitem__(self, key, value):
|
40 |
+
# If inserting a new key exceeds max size, remove the oldest item
|
41 |
+
if len(self) >= self.max_size:
|
42 |
+
self.popitem(last=False)
|
43 |
+
super().__setitem__(key, value)
|
44 |
+
|
45 |
+
|
46 |
+
class DualCache(BaseCache):
|
47 |
+
"""
|
48 |
+
DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
|
49 |
+
When data is updated or inserted, it is written to both the in-memory cache + Redis.
|
50 |
+
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
|
51 |
+
"""
|
52 |
+
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
in_memory_cache: Optional[InMemoryCache] = None,
|
56 |
+
redis_cache: Optional[RedisCache] = None,
|
57 |
+
default_in_memory_ttl: Optional[float] = None,
|
58 |
+
default_redis_ttl: Optional[float] = None,
|
59 |
+
default_redis_batch_cache_expiry: Optional[float] = None,
|
60 |
+
default_max_redis_batch_cache_size: int = 100,
|
61 |
+
) -> None:
|
62 |
+
super().__init__()
|
63 |
+
# If in_memory_cache is not provided, use the default InMemoryCache
|
64 |
+
self.in_memory_cache = in_memory_cache or InMemoryCache()
|
65 |
+
# If redis_cache is not provided, use the default RedisCache
|
66 |
+
self.redis_cache = redis_cache
|
67 |
+
self.last_redis_batch_access_time = LimitedSizeOrderedDict(
|
68 |
+
max_size=default_max_redis_batch_cache_size
|
69 |
+
)
|
70 |
+
self.redis_batch_cache_expiry = (
|
71 |
+
default_redis_batch_cache_expiry
|
72 |
+
or litellm.default_redis_batch_cache_expiry
|
73 |
+
or 10
|
74 |
+
)
|
75 |
+
self.default_in_memory_ttl = (
|
76 |
+
default_in_memory_ttl or litellm.default_in_memory_ttl
|
77 |
+
)
|
78 |
+
self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
|
79 |
+
|
80 |
+
def update_cache_ttl(
|
81 |
+
self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
|
82 |
+
):
|
83 |
+
if default_in_memory_ttl is not None:
|
84 |
+
self.default_in_memory_ttl = default_in_memory_ttl
|
85 |
+
|
86 |
+
if default_redis_ttl is not None:
|
87 |
+
self.default_redis_ttl = default_redis_ttl
|
88 |
+
|
89 |
+
def set_cache(self, key, value, local_only: bool = False, **kwargs):
|
90 |
+
# Update both Redis and in-memory cache
|
91 |
+
try:
|
92 |
+
if self.in_memory_cache is not None:
|
93 |
+
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
|
94 |
+
kwargs["ttl"] = self.default_in_memory_ttl
|
95 |
+
|
96 |
+
self.in_memory_cache.set_cache(key, value, **kwargs)
|
97 |
+
|
98 |
+
if self.redis_cache is not None and local_only is False:
|
99 |
+
self.redis_cache.set_cache(key, value, **kwargs)
|
100 |
+
except Exception as e:
|
101 |
+
print_verbose(e)
|
102 |
+
|
103 |
+
def increment_cache(
|
104 |
+
self, key, value: int, local_only: bool = False, **kwargs
|
105 |
+
) -> int:
|
106 |
+
"""
|
107 |
+
Key - the key in cache
|
108 |
+
|
109 |
+
Value - int - the value you want to increment by
|
110 |
+
|
111 |
+
Returns - int - the incremented value
|
112 |
+
"""
|
113 |
+
try:
|
114 |
+
result: int = value
|
115 |
+
if self.in_memory_cache is not None:
|
116 |
+
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
|
117 |
+
|
118 |
+
if self.redis_cache is not None and local_only is False:
|
119 |
+
result = self.redis_cache.increment_cache(key, value, **kwargs)
|
120 |
+
|
121 |
+
return result
|
122 |
+
except Exception as e:
|
123 |
+
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
124 |
+
raise e
|
125 |
+
|
126 |
+
def get_cache(
|
127 |
+
self,
|
128 |
+
key,
|
129 |
+
parent_otel_span: Optional[Span] = None,
|
130 |
+
local_only: bool = False,
|
131 |
+
**kwargs,
|
132 |
+
):
|
133 |
+
# Try to fetch from in-memory cache first
|
134 |
+
try:
|
135 |
+
result = None
|
136 |
+
if self.in_memory_cache is not None:
|
137 |
+
in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
|
138 |
+
|
139 |
+
if in_memory_result is not None:
|
140 |
+
result = in_memory_result
|
141 |
+
|
142 |
+
if result is None and self.redis_cache is not None and local_only is False:
|
143 |
+
# If not found in in-memory cache, try fetching from Redis
|
144 |
+
redis_result = self.redis_cache.get_cache(
|
145 |
+
key, parent_otel_span=parent_otel_span
|
146 |
+
)
|
147 |
+
|
148 |
+
if redis_result is not None:
|
149 |
+
# Update in-memory cache with the value from Redis
|
150 |
+
self.in_memory_cache.set_cache(key, redis_result, **kwargs)
|
151 |
+
|
152 |
+
result = redis_result
|
153 |
+
|
154 |
+
print_verbose(f"get cache: cache result: {result}")
|
155 |
+
return result
|
156 |
+
except Exception:
|
157 |
+
verbose_logger.error(traceback.format_exc())
|
158 |
+
|
159 |
+
def batch_get_cache(
|
160 |
+
self,
|
161 |
+
keys: list,
|
162 |
+
parent_otel_span: Optional[Span] = None,
|
163 |
+
local_only: bool = False,
|
164 |
+
**kwargs,
|
165 |
+
):
|
166 |
+
received_args = locals()
|
167 |
+
received_args.pop("self")
|
168 |
+
|
169 |
+
def run_in_new_loop():
|
170 |
+
"""Run the coroutine in a new event loop within this thread."""
|
171 |
+
new_loop = asyncio.new_event_loop()
|
172 |
+
try:
|
173 |
+
asyncio.set_event_loop(new_loop)
|
174 |
+
return new_loop.run_until_complete(
|
175 |
+
self.async_batch_get_cache(**received_args)
|
176 |
+
)
|
177 |
+
finally:
|
178 |
+
new_loop.close()
|
179 |
+
asyncio.set_event_loop(None)
|
180 |
+
|
181 |
+
try:
|
182 |
+
# First, try to get the current event loop
|
183 |
+
_ = asyncio.get_running_loop()
|
184 |
+
# If we're already in an event loop, run in a separate thread
|
185 |
+
# to avoid nested event loop issues
|
186 |
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
187 |
+
future = executor.submit(run_in_new_loop)
|
188 |
+
return future.result()
|
189 |
+
|
190 |
+
except RuntimeError:
|
191 |
+
# No running event loop, we can safely run in this thread
|
192 |
+
return run_in_new_loop()
|
193 |
+
|
194 |
+
async def async_get_cache(
|
195 |
+
self,
|
196 |
+
key,
|
197 |
+
parent_otel_span: Optional[Span] = None,
|
198 |
+
local_only: bool = False,
|
199 |
+
**kwargs,
|
200 |
+
):
|
201 |
+
# Try to fetch from in-memory cache first
|
202 |
+
try:
|
203 |
+
print_verbose(
|
204 |
+
f"async get cache: cache key: {key}; local_only: {local_only}"
|
205 |
+
)
|
206 |
+
result = None
|
207 |
+
if self.in_memory_cache is not None:
|
208 |
+
in_memory_result = await self.in_memory_cache.async_get_cache(
|
209 |
+
key, **kwargs
|
210 |
+
)
|
211 |
+
|
212 |
+
print_verbose(f"in_memory_result: {in_memory_result}")
|
213 |
+
if in_memory_result is not None:
|
214 |
+
result = in_memory_result
|
215 |
+
|
216 |
+
if result is None and self.redis_cache is not None and local_only is False:
|
217 |
+
# If not found in in-memory cache, try fetching from Redis
|
218 |
+
redis_result = await self.redis_cache.async_get_cache(
|
219 |
+
key, parent_otel_span=parent_otel_span
|
220 |
+
)
|
221 |
+
|
222 |
+
if redis_result is not None:
|
223 |
+
# Update in-memory cache with the value from Redis
|
224 |
+
await self.in_memory_cache.async_set_cache(
|
225 |
+
key, redis_result, **kwargs
|
226 |
+
)
|
227 |
+
|
228 |
+
result = redis_result
|
229 |
+
|
230 |
+
print_verbose(f"get cache: cache result: {result}")
|
231 |
+
return result
|
232 |
+
except Exception:
|
233 |
+
verbose_logger.error(traceback.format_exc())
|
234 |
+
|
235 |
+
def get_redis_batch_keys(
|
236 |
+
self,
|
237 |
+
current_time: float,
|
238 |
+
keys: List[str],
|
239 |
+
result: List[Any],
|
240 |
+
) -> List[str]:
|
241 |
+
sublist_keys = []
|
242 |
+
for key, value in zip(keys, result):
|
243 |
+
if value is None:
|
244 |
+
if (
|
245 |
+
key not in self.last_redis_batch_access_time
|
246 |
+
or current_time - self.last_redis_batch_access_time[key]
|
247 |
+
>= self.redis_batch_cache_expiry
|
248 |
+
):
|
249 |
+
sublist_keys.append(key)
|
250 |
+
return sublist_keys
|
251 |
+
|
252 |
+
async def async_batch_get_cache(
|
253 |
+
self,
|
254 |
+
keys: list,
|
255 |
+
parent_otel_span: Optional[Span] = None,
|
256 |
+
local_only: bool = False,
|
257 |
+
**kwargs,
|
258 |
+
):
|
259 |
+
try:
|
260 |
+
result = [None for _ in range(len(keys))]
|
261 |
+
if self.in_memory_cache is not None:
|
262 |
+
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
|
263 |
+
keys, **kwargs
|
264 |
+
)
|
265 |
+
|
266 |
+
if in_memory_result is not None:
|
267 |
+
result = in_memory_result
|
268 |
+
|
269 |
+
if None in result and self.redis_cache is not None and local_only is False:
|
270 |
+
"""
|
271 |
+
- for the none values in the result
|
272 |
+
- check the redis cache
|
273 |
+
"""
|
274 |
+
current_time = time.time()
|
275 |
+
sublist_keys = self.get_redis_batch_keys(current_time, keys, result)
|
276 |
+
|
277 |
+
# Only hit Redis if the last access time was more than 5 seconds ago
|
278 |
+
if len(sublist_keys) > 0:
|
279 |
+
# If not found in in-memory cache, try fetching from Redis
|
280 |
+
redis_result = await self.redis_cache.async_batch_get_cache(
|
281 |
+
sublist_keys, parent_otel_span=parent_otel_span
|
282 |
+
)
|
283 |
+
|
284 |
+
if redis_result is not None:
|
285 |
+
# Update in-memory cache with the value from Redis
|
286 |
+
for key, value in redis_result.items():
|
287 |
+
if value is not None:
|
288 |
+
await self.in_memory_cache.async_set_cache(
|
289 |
+
key, redis_result[key], **kwargs
|
290 |
+
)
|
291 |
+
# Update the last access time for each key fetched from Redis
|
292 |
+
self.last_redis_batch_access_time[key] = current_time
|
293 |
+
|
294 |
+
for key, value in redis_result.items():
|
295 |
+
index = keys.index(key)
|
296 |
+
result[index] = value
|
297 |
+
|
298 |
+
return result
|
299 |
+
except Exception:
|
300 |
+
verbose_logger.error(traceback.format_exc())
|
301 |
+
|
302 |
+
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
|
303 |
+
print_verbose(
|
304 |
+
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
|
305 |
+
)
|
306 |
+
try:
|
307 |
+
if self.in_memory_cache is not None:
|
308 |
+
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
|
309 |
+
|
310 |
+
if self.redis_cache is not None and local_only is False:
|
311 |
+
await self.redis_cache.async_set_cache(key, value, **kwargs)
|
312 |
+
except Exception as e:
|
313 |
+
verbose_logger.exception(
|
314 |
+
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
|
315 |
+
)
|
316 |
+
|
317 |
+
# async_batch_set_cache
|
318 |
+
async def async_set_cache_pipeline(
|
319 |
+
self, cache_list: list, local_only: bool = False, **kwargs
|
320 |
+
):
|
321 |
+
"""
|
322 |
+
Batch write values to the cache
|
323 |
+
"""
|
324 |
+
print_verbose(
|
325 |
+
f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
|
326 |
+
)
|
327 |
+
try:
|
328 |
+
if self.in_memory_cache is not None:
|
329 |
+
await self.in_memory_cache.async_set_cache_pipeline(
|
330 |
+
cache_list=cache_list, **kwargs
|
331 |
+
)
|
332 |
+
|
333 |
+
if self.redis_cache is not None and local_only is False:
|
334 |
+
await self.redis_cache.async_set_cache_pipeline(
|
335 |
+
cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
|
336 |
+
)
|
337 |
+
except Exception as e:
|
338 |
+
verbose_logger.exception(
|
339 |
+
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
|
340 |
+
)
|
341 |
+
|
342 |
+
async def async_increment_cache(
|
343 |
+
self,
|
344 |
+
key,
|
345 |
+
value: float,
|
346 |
+
parent_otel_span: Optional[Span] = None,
|
347 |
+
local_only: bool = False,
|
348 |
+
**kwargs,
|
349 |
+
) -> float:
|
350 |
+
"""
|
351 |
+
Key - the key in cache
|
352 |
+
|
353 |
+
Value - float - the value you want to increment by
|
354 |
+
|
355 |
+
Returns - float - the incremented value
|
356 |
+
"""
|
357 |
+
try:
|
358 |
+
result: float = value
|
359 |
+
if self.in_memory_cache is not None:
|
360 |
+
result = await self.in_memory_cache.async_increment(
|
361 |
+
key, value, **kwargs
|
362 |
+
)
|
363 |
+
|
364 |
+
if self.redis_cache is not None and local_only is False:
|
365 |
+
result = await self.redis_cache.async_increment(
|
366 |
+
key,
|
367 |
+
value,
|
368 |
+
parent_otel_span=parent_otel_span,
|
369 |
+
ttl=kwargs.get("ttl", None),
|
370 |
+
)
|
371 |
+
|
372 |
+
return result
|
373 |
+
except Exception as e:
|
374 |
+
raise e # don't log if exception is raised
|
375 |
+
|
376 |
+
async def async_set_cache_sadd(
|
377 |
+
self, key, value: List, local_only: bool = False, **kwargs
|
378 |
+
) -> None:
|
379 |
+
"""
|
380 |
+
Add value to a set
|
381 |
+
|
382 |
+
Key - the key in cache
|
383 |
+
|
384 |
+
Value - str - the value you want to add to the set
|
385 |
+
|
386 |
+
Returns - None
|
387 |
+
"""
|
388 |
+
try:
|
389 |
+
if self.in_memory_cache is not None:
|
390 |
+
_ = await self.in_memory_cache.async_set_cache_sadd(
|
391 |
+
key, value, ttl=kwargs.get("ttl", None)
|
392 |
+
)
|
393 |
+
|
394 |
+
if self.redis_cache is not None and local_only is False:
|
395 |
+
_ = await self.redis_cache.async_set_cache_sadd(
|
396 |
+
key, value, ttl=kwargs.get("ttl", None)
|
397 |
+
)
|
398 |
+
|
399 |
+
return None
|
400 |
+
except Exception as e:
|
401 |
+
raise e # don't log, if exception is raised
|
402 |
+
|
403 |
+
def flush_cache(self):
|
404 |
+
if self.in_memory_cache is not None:
|
405 |
+
self.in_memory_cache.flush_cache()
|
406 |
+
if self.redis_cache is not None:
|
407 |
+
self.redis_cache.flush_cache()
|
408 |
+
|
409 |
+
def delete_cache(self, key):
|
410 |
+
"""
|
411 |
+
Delete a key from the cache
|
412 |
+
"""
|
413 |
+
if self.in_memory_cache is not None:
|
414 |
+
self.in_memory_cache.delete_cache(key)
|
415 |
+
if self.redis_cache is not None:
|
416 |
+
self.redis_cache.delete_cache(key)
|
417 |
+
|
418 |
+
async def async_delete_cache(self, key: str):
|
419 |
+
"""
|
420 |
+
Delete a key from the cache
|
421 |
+
"""
|
422 |
+
if self.in_memory_cache is not None:
|
423 |
+
self.in_memory_cache.delete_cache(key)
|
424 |
+
if self.redis_cache is not None:
|
425 |
+
await self.redis_cache.async_delete_cache(key)
|
426 |
+
|
427 |
+
async def async_get_ttl(self, key: str) -> Optional[int]:
|
428 |
+
"""
|
429 |
+
Get the remaining TTL of a key in in-memory cache or redis
|
430 |
+
"""
|
431 |
+
ttl = await self.in_memory_cache.async_get_ttl(key)
|
432 |
+
if ttl is None and self.redis_cache is not None:
|
433 |
+
ttl = await self.redis_cache.async_get_ttl(key)
|
434 |
+
return ttl
|
litellm/caching/in_memory_cache.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
In-Memory Cache implementation
|
3 |
+
|
4 |
+
Has 4 methods:
|
5 |
+
- set_cache
|
6 |
+
- get_cache
|
7 |
+
- async_set_cache
|
8 |
+
- async_get_cache
|
9 |
+
"""
|
10 |
+
|
11 |
+
import json
|
12 |
+
import sys
|
13 |
+
import time
|
14 |
+
from typing import Any, List, Optional
|
15 |
+
|
16 |
+
from pydantic import BaseModel
|
17 |
+
|
18 |
+
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
19 |
+
|
20 |
+
from .base_cache import BaseCache
|
21 |
+
|
22 |
+
|
23 |
+
class InMemoryCache(BaseCache):
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
max_size_in_memory: Optional[int] = 200,
|
27 |
+
default_ttl: Optional[
|
28 |
+
int
|
29 |
+
] = 600, # default ttl is 10 minutes. At maximum litellm rate limiting logic requires objects to be in memory for 1 minute
|
30 |
+
max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB
|
31 |
+
):
|
32 |
+
"""
|
33 |
+
max_size_in_memory [int]: Maximum number of items in cache. done to prevent memory leaks. Use 200 items as a default
|
34 |
+
"""
|
35 |
+
self.max_size_in_memory = (
|
36 |
+
max_size_in_memory or 200
|
37 |
+
) # set an upper bound of 200 items in-memory
|
38 |
+
self.default_ttl = default_ttl or 600
|
39 |
+
self.max_size_per_item = (
|
40 |
+
max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
41 |
+
) # 1MB = 1024KB
|
42 |
+
|
43 |
+
# in-memory cache
|
44 |
+
self.cache_dict: dict = {}
|
45 |
+
self.ttl_dict: dict = {}
|
46 |
+
|
47 |
+
def check_value_size(self, value: Any):
|
48 |
+
"""
|
49 |
+
Check if value size exceeds max_size_per_item (1MB)
|
50 |
+
Returns True if value size is acceptable, False otherwise
|
51 |
+
"""
|
52 |
+
try:
|
53 |
+
# Fast path for common primitive types that are typically small
|
54 |
+
if (
|
55 |
+
isinstance(value, (bool, int, float, str))
|
56 |
+
and len(str(value))
|
57 |
+
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
58 |
+
): # Conservative estimate
|
59 |
+
return True
|
60 |
+
|
61 |
+
# Direct size check for bytes objects
|
62 |
+
if isinstance(value, bytes):
|
63 |
+
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
|
64 |
+
|
65 |
+
# Handle special types without full conversion when possible
|
66 |
+
if hasattr(value, "__sizeof__"): # Use __sizeof__ if available
|
67 |
+
size = value.__sizeof__() / 1024
|
68 |
+
return size <= self.max_size_per_item
|
69 |
+
|
70 |
+
# Fallback for complex types
|
71 |
+
if isinstance(value, BaseModel) and hasattr(
|
72 |
+
value, "model_dump"
|
73 |
+
): # Pydantic v2
|
74 |
+
value = value.model_dump()
|
75 |
+
elif hasattr(value, "isoformat"): # datetime objects
|
76 |
+
return True # datetime strings are always small
|
77 |
+
|
78 |
+
# Only convert to JSON if absolutely necessary
|
79 |
+
if not isinstance(value, (str, bytes)):
|
80 |
+
value = json.dumps(value, default=str)
|
81 |
+
|
82 |
+
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
|
83 |
+
|
84 |
+
except Exception:
|
85 |
+
return False
|
86 |
+
|
87 |
+
def evict_cache(self):
|
88 |
+
"""
|
89 |
+
Eviction policy:
|
90 |
+
- check if any items in ttl_dict are expired -> remove them from ttl_dict and cache_dict
|
91 |
+
|
92 |
+
|
93 |
+
This guarantees the following:
|
94 |
+
- 1. When item ttl not set: At minimumm each item will remain in memory for 5 minutes
|
95 |
+
- 2. When ttl is set: the item will remain in memory for at least that amount of time
|
96 |
+
- 3. the size of in-memory cache is bounded
|
97 |
+
|
98 |
+
"""
|
99 |
+
for key in list(self.ttl_dict.keys()):
|
100 |
+
if time.time() > self.ttl_dict[key]:
|
101 |
+
self.cache_dict.pop(key, None)
|
102 |
+
self.ttl_dict.pop(key, None)
|
103 |
+
|
104 |
+
# de-reference the removed item
|
105 |
+
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
|
106 |
+
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
|
107 |
+
# This can occur when an object is referenced by another object, but the reference is never removed.
|
108 |
+
|
109 |
+
def set_cache(self, key, value, **kwargs):
|
110 |
+
if len(self.cache_dict) >= self.max_size_in_memory:
|
111 |
+
# only evict when cache is full
|
112 |
+
self.evict_cache()
|
113 |
+
if not self.check_value_size(value):
|
114 |
+
return
|
115 |
+
|
116 |
+
self.cache_dict[key] = value
|
117 |
+
if "ttl" in kwargs and kwargs["ttl"] is not None:
|
118 |
+
self.ttl_dict[key] = time.time() + kwargs["ttl"]
|
119 |
+
else:
|
120 |
+
self.ttl_dict[key] = time.time() + self.default_ttl
|
121 |
+
|
122 |
+
async def async_set_cache(self, key, value, **kwargs):
|
123 |
+
self.set_cache(key=key, value=value, **kwargs)
|
124 |
+
|
125 |
+
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
|
126 |
+
for cache_key, cache_value in cache_list:
|
127 |
+
if ttl is not None:
|
128 |
+
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
|
129 |
+
else:
|
130 |
+
self.set_cache(key=cache_key, value=cache_value)
|
131 |
+
|
132 |
+
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
|
133 |
+
"""
|
134 |
+
Add value to set
|
135 |
+
"""
|
136 |
+
# get the value
|
137 |
+
init_value = self.get_cache(key=key) or set()
|
138 |
+
for val in value:
|
139 |
+
init_value.add(val)
|
140 |
+
self.set_cache(key, init_value, ttl=ttl)
|
141 |
+
return value
|
142 |
+
|
143 |
+
def get_cache(self, key, **kwargs):
|
144 |
+
if key in self.cache_dict:
|
145 |
+
if key in self.ttl_dict:
|
146 |
+
if time.time() > self.ttl_dict[key]:
|
147 |
+
self.cache_dict.pop(key, None)
|
148 |
+
return None
|
149 |
+
original_cached_response = self.cache_dict[key]
|
150 |
+
try:
|
151 |
+
cached_response = json.loads(original_cached_response)
|
152 |
+
except Exception:
|
153 |
+
cached_response = original_cached_response
|
154 |
+
return cached_response
|
155 |
+
return None
|
156 |
+
|
157 |
+
def batch_get_cache(self, keys: list, **kwargs):
|
158 |
+
return_val = []
|
159 |
+
for k in keys:
|
160 |
+
val = self.get_cache(key=k, **kwargs)
|
161 |
+
return_val.append(val)
|
162 |
+
return return_val
|
163 |
+
|
164 |
+
def increment_cache(self, key, value: int, **kwargs) -> int:
|
165 |
+
# get the value
|
166 |
+
init_value = self.get_cache(key=key) or 0
|
167 |
+
value = init_value + value
|
168 |
+
self.set_cache(key, value, **kwargs)
|
169 |
+
return value
|
170 |
+
|
171 |
+
async def async_get_cache(self, key, **kwargs):
|
172 |
+
return self.get_cache(key=key, **kwargs)
|
173 |
+
|
174 |
+
async def async_batch_get_cache(self, keys: list, **kwargs):
|
175 |
+
return_val = []
|
176 |
+
for k in keys:
|
177 |
+
val = self.get_cache(key=k, **kwargs)
|
178 |
+
return_val.append(val)
|
179 |
+
return return_val
|
180 |
+
|
181 |
+
async def async_increment(self, key, value: float, **kwargs) -> float:
|
182 |
+
# get the value
|
183 |
+
init_value = await self.async_get_cache(key=key) or 0
|
184 |
+
value = init_value + value
|
185 |
+
await self.async_set_cache(key, value, **kwargs)
|
186 |
+
return value
|
187 |
+
|
188 |
+
def flush_cache(self):
|
189 |
+
self.cache_dict.clear()
|
190 |
+
self.ttl_dict.clear()
|
191 |
+
|
192 |
+
async def disconnect(self):
|
193 |
+
pass
|
194 |
+
|
195 |
+
def delete_cache(self, key):
|
196 |
+
self.cache_dict.pop(key, None)
|
197 |
+
self.ttl_dict.pop(key, None)
|
198 |
+
|
199 |
+
async def async_get_ttl(self, key: str) -> Optional[int]:
|
200 |
+
"""
|
201 |
+
Get the remaining TTL of a key in in-memory cache
|
202 |
+
"""
|
203 |
+
return self.ttl_dict.get(key, None)
|
litellm/caching/llm_caching_handler.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Add the event loop to the cache key, to prevent event loop closed errors.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import asyncio
|
6 |
+
|
7 |
+
from .in_memory_cache import InMemoryCache
|
8 |
+
|
9 |
+
|
10 |
+
class LLMClientCache(InMemoryCache):
|
11 |
+
def update_cache_key_with_event_loop(self, key):
|
12 |
+
"""
|
13 |
+
Add the event loop to the cache key, to prevent event loop closed errors.
|
14 |
+
If none, use the key as is.
|
15 |
+
"""
|
16 |
+
try:
|
17 |
+
event_loop = asyncio.get_event_loop()
|
18 |
+
stringified_event_loop = str(id(event_loop))
|
19 |
+
return f"{key}-{stringified_event_loop}"
|
20 |
+
except Exception: # handle no current event loop
|
21 |
+
return key
|
22 |
+
|
23 |
+
def set_cache(self, key, value, **kwargs):
|
24 |
+
key = self.update_cache_key_with_event_loop(key)
|
25 |
+
return super().set_cache(key, value, **kwargs)
|
26 |
+
|
27 |
+
async def async_set_cache(self, key, value, **kwargs):
|
28 |
+
key = self.update_cache_key_with_event_loop(key)
|
29 |
+
return await super().async_set_cache(key, value, **kwargs)
|
30 |
+
|
31 |
+
def get_cache(self, key, **kwargs):
|
32 |
+
key = self.update_cache_key_with_event_loop(key)
|
33 |
+
|
34 |
+
return super().get_cache(key, **kwargs)
|
35 |
+
|
36 |
+
async def async_get_cache(self, key, **kwargs):
|
37 |
+
key = self.update_cache_key_with_event_loop(key)
|
38 |
+
|
39 |
+
return await super().async_get_cache(key, **kwargs)
|
litellm/caching/qdrant_semantic_cache.py
ADDED
@@ -0,0 +1,442 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Qdrant Semantic Cache implementation
|
3 |
+
|
4 |
+
Has 4 methods:
|
5 |
+
- set_cache
|
6 |
+
- get_cache
|
7 |
+
- async_set_cache
|
8 |
+
- async_get_cache
|
9 |
+
"""
|
10 |
+
|
11 |
+
import ast
|
12 |
+
import asyncio
|
13 |
+
import json
|
14 |
+
from typing import Any, cast
|
15 |
+
|
16 |
+
import litellm
|
17 |
+
from litellm._logging import print_verbose
|
18 |
+
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
|
19 |
+
from litellm.types.utils import EmbeddingResponse
|
20 |
+
|
21 |
+
from .base_cache import BaseCache
|
22 |
+
|
23 |
+
|
24 |
+
class QdrantSemanticCache(BaseCache):
|
25 |
+
def __init__( # noqa: PLR0915
|
26 |
+
self,
|
27 |
+
qdrant_api_base=None,
|
28 |
+
qdrant_api_key=None,
|
29 |
+
collection_name=None,
|
30 |
+
similarity_threshold=None,
|
31 |
+
quantization_config=None,
|
32 |
+
embedding_model="text-embedding-ada-002",
|
33 |
+
host_type=None,
|
34 |
+
):
|
35 |
+
import os
|
36 |
+
|
37 |
+
from litellm.llms.custom_httpx.http_handler import (
|
38 |
+
_get_httpx_client,
|
39 |
+
get_async_httpx_client,
|
40 |
+
httpxSpecialProvider,
|
41 |
+
)
|
42 |
+
from litellm.secret_managers.main import get_secret_str
|
43 |
+
|
44 |
+
if collection_name is None:
|
45 |
+
raise Exception("collection_name must be provided, passed None")
|
46 |
+
|
47 |
+
self.collection_name = collection_name
|
48 |
+
print_verbose(
|
49 |
+
f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
|
50 |
+
)
|
51 |
+
|
52 |
+
if similarity_threshold is None:
|
53 |
+
raise Exception("similarity_threshold must be provided, passed None")
|
54 |
+
self.similarity_threshold = similarity_threshold
|
55 |
+
self.embedding_model = embedding_model
|
56 |
+
headers = {}
|
57 |
+
|
58 |
+
# check if defined as os.environ/ variable
|
59 |
+
if qdrant_api_base:
|
60 |
+
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
|
61 |
+
"os.environ/"
|
62 |
+
):
|
63 |
+
qdrant_api_base = get_secret_str(qdrant_api_base)
|
64 |
+
if qdrant_api_key:
|
65 |
+
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
|
66 |
+
"os.environ/"
|
67 |
+
):
|
68 |
+
qdrant_api_key = get_secret_str(qdrant_api_key)
|
69 |
+
|
70 |
+
qdrant_api_base = (
|
71 |
+
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
|
72 |
+
)
|
73 |
+
qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
|
74 |
+
headers = {"Content-Type": "application/json"}
|
75 |
+
if qdrant_api_key:
|
76 |
+
headers["api-key"] = qdrant_api_key
|
77 |
+
|
78 |
+
if qdrant_api_base is None:
|
79 |
+
raise ValueError("Qdrant url must be provided")
|
80 |
+
|
81 |
+
self.qdrant_api_base = qdrant_api_base
|
82 |
+
self.qdrant_api_key = qdrant_api_key
|
83 |
+
print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
|
84 |
+
|
85 |
+
self.headers = headers
|
86 |
+
|
87 |
+
self.sync_client = _get_httpx_client()
|
88 |
+
self.async_client = get_async_httpx_client(
|
89 |
+
llm_provider=httpxSpecialProvider.Caching
|
90 |
+
)
|
91 |
+
|
92 |
+
if quantization_config is None:
|
93 |
+
print_verbose(
|
94 |
+
"Quantization config is not provided. Default binary quantization will be used."
|
95 |
+
)
|
96 |
+
collection_exists = self.sync_client.get(
|
97 |
+
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
|
98 |
+
headers=self.headers,
|
99 |
+
)
|
100 |
+
if collection_exists.status_code != 200:
|
101 |
+
raise ValueError(
|
102 |
+
f"Error from qdrant checking if /collections exist {collection_exists.text}"
|
103 |
+
)
|
104 |
+
|
105 |
+
if collection_exists.json()["result"]["exists"]:
|
106 |
+
collection_details = self.sync_client.get(
|
107 |
+
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
108 |
+
headers=self.headers,
|
109 |
+
)
|
110 |
+
self.collection_info = collection_details.json()
|
111 |
+
print_verbose(
|
112 |
+
f"Collection already exists.\nCollection details:{self.collection_info}"
|
113 |
+
)
|
114 |
+
else:
|
115 |
+
if quantization_config is None or quantization_config == "binary":
|
116 |
+
quantization_params = {
|
117 |
+
"binary": {
|
118 |
+
"always_ram": False,
|
119 |
+
}
|
120 |
+
}
|
121 |
+
elif quantization_config == "scalar":
|
122 |
+
quantization_params = {
|
123 |
+
"scalar": {
|
124 |
+
"type": "int8",
|
125 |
+
"quantile": QDRANT_SCALAR_QUANTILE,
|
126 |
+
"always_ram": False,
|
127 |
+
}
|
128 |
+
}
|
129 |
+
elif quantization_config == "product":
|
130 |
+
quantization_params = {
|
131 |
+
"product": {"compression": "x16", "always_ram": False}
|
132 |
+
}
|
133 |
+
else:
|
134 |
+
raise Exception(
|
135 |
+
"Quantization config must be one of 'scalar', 'binary' or 'product'"
|
136 |
+
)
|
137 |
+
|
138 |
+
new_collection_status = self.sync_client.put(
|
139 |
+
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
140 |
+
json={
|
141 |
+
"vectors": {"size": QDRANT_VECTOR_SIZE, "distance": "Cosine"},
|
142 |
+
"quantization_config": quantization_params,
|
143 |
+
},
|
144 |
+
headers=self.headers,
|
145 |
+
)
|
146 |
+
if new_collection_status.json()["result"]:
|
147 |
+
collection_details = self.sync_client.get(
|
148 |
+
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
149 |
+
headers=self.headers,
|
150 |
+
)
|
151 |
+
self.collection_info = collection_details.json()
|
152 |
+
print_verbose(
|
153 |
+
f"New collection created.\nCollection details:{self.collection_info}"
|
154 |
+
)
|
155 |
+
else:
|
156 |
+
raise Exception("Error while creating new collection")
|
157 |
+
|
158 |
+
def _get_cache_logic(self, cached_response: Any):
|
159 |
+
if cached_response is None:
|
160 |
+
return cached_response
|
161 |
+
try:
|
162 |
+
cached_response = json.loads(
|
163 |
+
cached_response
|
164 |
+
) # Convert string to dictionary
|
165 |
+
except Exception:
|
166 |
+
cached_response = ast.literal_eval(cached_response)
|
167 |
+
return cached_response
|
168 |
+
|
169 |
+
def set_cache(self, key, value, **kwargs):
|
170 |
+
print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
171 |
+
import uuid
|
172 |
+
|
173 |
+
# get the prompt
|
174 |
+
messages = kwargs["messages"]
|
175 |
+
prompt = ""
|
176 |
+
for message in messages:
|
177 |
+
prompt += message["content"]
|
178 |
+
|
179 |
+
# create an embedding for prompt
|
180 |
+
embedding_response = cast(
|
181 |
+
EmbeddingResponse,
|
182 |
+
litellm.embedding(
|
183 |
+
model=self.embedding_model,
|
184 |
+
input=prompt,
|
185 |
+
cache={"no-store": True, "no-cache": True},
|
186 |
+
),
|
187 |
+
)
|
188 |
+
|
189 |
+
# get the embedding
|
190 |
+
embedding = embedding_response["data"][0]["embedding"]
|
191 |
+
|
192 |
+
value = str(value)
|
193 |
+
assert isinstance(value, str)
|
194 |
+
|
195 |
+
data = {
|
196 |
+
"points": [
|
197 |
+
{
|
198 |
+
"id": str(uuid.uuid4()),
|
199 |
+
"vector": embedding,
|
200 |
+
"payload": {
|
201 |
+
"text": prompt,
|
202 |
+
"response": value,
|
203 |
+
},
|
204 |
+
},
|
205 |
+
]
|
206 |
+
}
|
207 |
+
self.sync_client.put(
|
208 |
+
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
|
209 |
+
headers=self.headers,
|
210 |
+
json=data,
|
211 |
+
)
|
212 |
+
return
|
213 |
+
|
214 |
+
def get_cache(self, key, **kwargs):
|
215 |
+
print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
|
216 |
+
|
217 |
+
# get the messages
|
218 |
+
messages = kwargs["messages"]
|
219 |
+
prompt = ""
|
220 |
+
for message in messages:
|
221 |
+
prompt += message["content"]
|
222 |
+
|
223 |
+
# convert to embedding
|
224 |
+
embedding_response = cast(
|
225 |
+
EmbeddingResponse,
|
226 |
+
litellm.embedding(
|
227 |
+
model=self.embedding_model,
|
228 |
+
input=prompt,
|
229 |
+
cache={"no-store": True, "no-cache": True},
|
230 |
+
),
|
231 |
+
)
|
232 |
+
|
233 |
+
# get the embedding
|
234 |
+
embedding = embedding_response["data"][0]["embedding"]
|
235 |
+
|
236 |
+
data = {
|
237 |
+
"vector": embedding,
|
238 |
+
"params": {
|
239 |
+
"quantization": {
|
240 |
+
"ignore": False,
|
241 |
+
"rescore": True,
|
242 |
+
"oversampling": 3.0,
|
243 |
+
}
|
244 |
+
},
|
245 |
+
"limit": 1,
|
246 |
+
"with_payload": True,
|
247 |
+
}
|
248 |
+
|
249 |
+
search_response = self.sync_client.post(
|
250 |
+
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
|
251 |
+
headers=self.headers,
|
252 |
+
json=data,
|
253 |
+
)
|
254 |
+
results = search_response.json()["result"]
|
255 |
+
|
256 |
+
if results is None:
|
257 |
+
return None
|
258 |
+
if isinstance(results, list):
|
259 |
+
if len(results) == 0:
|
260 |
+
return None
|
261 |
+
|
262 |
+
similarity = results[0]["score"]
|
263 |
+
cached_prompt = results[0]["payload"]["text"]
|
264 |
+
|
265 |
+
# check similarity, if more than self.similarity_threshold, return results
|
266 |
+
print_verbose(
|
267 |
+
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
|
268 |
+
)
|
269 |
+
if similarity >= self.similarity_threshold:
|
270 |
+
# cache hit !
|
271 |
+
cached_value = results[0]["payload"]["response"]
|
272 |
+
print_verbose(
|
273 |
+
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
|
274 |
+
)
|
275 |
+
return self._get_cache_logic(cached_response=cached_value)
|
276 |
+
else:
|
277 |
+
# cache miss !
|
278 |
+
return None
|
279 |
+
pass
|
280 |
+
|
281 |
+
async def async_set_cache(self, key, value, **kwargs):
|
282 |
+
import uuid
|
283 |
+
|
284 |
+
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
285 |
+
|
286 |
+
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
287 |
+
|
288 |
+
# get the prompt
|
289 |
+
messages = kwargs["messages"]
|
290 |
+
prompt = ""
|
291 |
+
for message in messages:
|
292 |
+
prompt += message["content"]
|
293 |
+
# create an embedding for prompt
|
294 |
+
router_model_names = (
|
295 |
+
[m["model_name"] for m in llm_model_list]
|
296 |
+
if llm_model_list is not None
|
297 |
+
else []
|
298 |
+
)
|
299 |
+
if llm_router is not None and self.embedding_model in router_model_names:
|
300 |
+
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
301 |
+
embedding_response = await llm_router.aembedding(
|
302 |
+
model=self.embedding_model,
|
303 |
+
input=prompt,
|
304 |
+
cache={"no-store": True, "no-cache": True},
|
305 |
+
metadata={
|
306 |
+
"user_api_key": user_api_key,
|
307 |
+
"semantic-cache-embedding": True,
|
308 |
+
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
309 |
+
},
|
310 |
+
)
|
311 |
+
else:
|
312 |
+
# convert to embedding
|
313 |
+
embedding_response = await litellm.aembedding(
|
314 |
+
model=self.embedding_model,
|
315 |
+
input=prompt,
|
316 |
+
cache={"no-store": True, "no-cache": True},
|
317 |
+
)
|
318 |
+
|
319 |
+
# get the embedding
|
320 |
+
embedding = embedding_response["data"][0]["embedding"]
|
321 |
+
|
322 |
+
value = str(value)
|
323 |
+
assert isinstance(value, str)
|
324 |
+
|
325 |
+
data = {
|
326 |
+
"points": [
|
327 |
+
{
|
328 |
+
"id": str(uuid.uuid4()),
|
329 |
+
"vector": embedding,
|
330 |
+
"payload": {
|
331 |
+
"text": prompt,
|
332 |
+
"response": value,
|
333 |
+
},
|
334 |
+
},
|
335 |
+
]
|
336 |
+
}
|
337 |
+
|
338 |
+
await self.async_client.put(
|
339 |
+
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
|
340 |
+
headers=self.headers,
|
341 |
+
json=data,
|
342 |
+
)
|
343 |
+
return
|
344 |
+
|
345 |
+
async def async_get_cache(self, key, **kwargs):
|
346 |
+
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
|
347 |
+
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
348 |
+
|
349 |
+
# get the messages
|
350 |
+
messages = kwargs["messages"]
|
351 |
+
prompt = ""
|
352 |
+
for message in messages:
|
353 |
+
prompt += message["content"]
|
354 |
+
|
355 |
+
router_model_names = (
|
356 |
+
[m["model_name"] for m in llm_model_list]
|
357 |
+
if llm_model_list is not None
|
358 |
+
else []
|
359 |
+
)
|
360 |
+
if llm_router is not None and self.embedding_model in router_model_names:
|
361 |
+
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
362 |
+
embedding_response = await llm_router.aembedding(
|
363 |
+
model=self.embedding_model,
|
364 |
+
input=prompt,
|
365 |
+
cache={"no-store": True, "no-cache": True},
|
366 |
+
metadata={
|
367 |
+
"user_api_key": user_api_key,
|
368 |
+
"semantic-cache-embedding": True,
|
369 |
+
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
370 |
+
},
|
371 |
+
)
|
372 |
+
else:
|
373 |
+
# convert to embedding
|
374 |
+
embedding_response = await litellm.aembedding(
|
375 |
+
model=self.embedding_model,
|
376 |
+
input=prompt,
|
377 |
+
cache={"no-store": True, "no-cache": True},
|
378 |
+
)
|
379 |
+
|
380 |
+
# get the embedding
|
381 |
+
embedding = embedding_response["data"][0]["embedding"]
|
382 |
+
|
383 |
+
data = {
|
384 |
+
"vector": embedding,
|
385 |
+
"params": {
|
386 |
+
"quantization": {
|
387 |
+
"ignore": False,
|
388 |
+
"rescore": True,
|
389 |
+
"oversampling": 3.0,
|
390 |
+
}
|
391 |
+
},
|
392 |
+
"limit": 1,
|
393 |
+
"with_payload": True,
|
394 |
+
}
|
395 |
+
|
396 |
+
search_response = await self.async_client.post(
|
397 |
+
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
|
398 |
+
headers=self.headers,
|
399 |
+
json=data,
|
400 |
+
)
|
401 |
+
|
402 |
+
results = search_response.json()["result"]
|
403 |
+
|
404 |
+
if results is None:
|
405 |
+
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
406 |
+
return None
|
407 |
+
if isinstance(results, list):
|
408 |
+
if len(results) == 0:
|
409 |
+
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
410 |
+
return None
|
411 |
+
|
412 |
+
similarity = results[0]["score"]
|
413 |
+
cached_prompt = results[0]["payload"]["text"]
|
414 |
+
|
415 |
+
# check similarity, if more than self.similarity_threshold, return results
|
416 |
+
print_verbose(
|
417 |
+
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
|
418 |
+
)
|
419 |
+
|
420 |
+
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
|
421 |
+
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
|
422 |
+
|
423 |
+
if similarity >= self.similarity_threshold:
|
424 |
+
# cache hit !
|
425 |
+
cached_value = results[0]["payload"]["response"]
|
426 |
+
print_verbose(
|
427 |
+
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
|
428 |
+
)
|
429 |
+
return self._get_cache_logic(cached_response=cached_value)
|
430 |
+
else:
|
431 |
+
# cache miss !
|
432 |
+
return None
|
433 |
+
pass
|
434 |
+
|
435 |
+
async def _collection_info(self):
|
436 |
+
return self.collection_info
|
437 |
+
|
438 |
+
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
439 |
+
tasks = []
|
440 |
+
for val in cache_list:
|
441 |
+
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
442 |
+
await asyncio.gather(*tasks)
|
litellm/caching/redis_cache.py
ADDED
@@ -0,0 +1,1162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Redis Cache implementation
|
3 |
+
|
4 |
+
Has 4 primary methods:
|
5 |
+
- set_cache
|
6 |
+
- get_cache
|
7 |
+
- async_set_cache
|
8 |
+
- async_get_cache
|
9 |
+
"""
|
10 |
+
|
11 |
+
import ast
|
12 |
+
import asyncio
|
13 |
+
import inspect
|
14 |
+
import json
|
15 |
+
import time
|
16 |
+
from datetime import timedelta
|
17 |
+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import litellm
|
20 |
+
from litellm._logging import print_verbose, verbose_logger
|
21 |
+
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
|
22 |
+
from litellm.types.caching import RedisPipelineIncrementOperation
|
23 |
+
from litellm.types.services import ServiceTypes
|
24 |
+
|
25 |
+
from .base_cache import BaseCache
|
26 |
+
|
27 |
+
if TYPE_CHECKING:
|
28 |
+
from opentelemetry.trace import Span as _Span
|
29 |
+
from redis.asyncio import Redis, RedisCluster
|
30 |
+
from redis.asyncio.client import Pipeline
|
31 |
+
from redis.asyncio.cluster import ClusterPipeline
|
32 |
+
|
33 |
+
pipeline = Pipeline
|
34 |
+
cluster_pipeline = ClusterPipeline
|
35 |
+
async_redis_client = Redis
|
36 |
+
async_redis_cluster_client = RedisCluster
|
37 |
+
Span = Union[_Span, Any]
|
38 |
+
else:
|
39 |
+
pipeline = Any
|
40 |
+
cluster_pipeline = Any
|
41 |
+
async_redis_client = Any
|
42 |
+
async_redis_cluster_client = Any
|
43 |
+
Span = Any
|
44 |
+
|
45 |
+
|
46 |
+
class RedisCache(BaseCache):
|
47 |
+
# if users don't provider one, use the default litellm cache
|
48 |
+
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
host=None,
|
52 |
+
port=None,
|
53 |
+
password=None,
|
54 |
+
redis_flush_size: Optional[int] = 100,
|
55 |
+
namespace: Optional[str] = None,
|
56 |
+
startup_nodes: Optional[List] = None, # for redis-cluster
|
57 |
+
socket_timeout: Optional[float] = 5.0, # default 5 second timeout
|
58 |
+
**kwargs,
|
59 |
+
):
|
60 |
+
from litellm._service_logger import ServiceLogging
|
61 |
+
|
62 |
+
from .._redis import get_redis_client, get_redis_connection_pool
|
63 |
+
|
64 |
+
redis_kwargs = {}
|
65 |
+
if host is not None:
|
66 |
+
redis_kwargs["host"] = host
|
67 |
+
if port is not None:
|
68 |
+
redis_kwargs["port"] = port
|
69 |
+
if password is not None:
|
70 |
+
redis_kwargs["password"] = password
|
71 |
+
if startup_nodes is not None:
|
72 |
+
redis_kwargs["startup_nodes"] = startup_nodes
|
73 |
+
if socket_timeout is not None:
|
74 |
+
redis_kwargs["socket_timeout"] = socket_timeout
|
75 |
+
|
76 |
+
### HEALTH MONITORING OBJECT ###
|
77 |
+
if kwargs.get("service_logger_obj", None) is not None and isinstance(
|
78 |
+
kwargs["service_logger_obj"], ServiceLogging
|
79 |
+
):
|
80 |
+
self.service_logger_obj = kwargs.pop("service_logger_obj")
|
81 |
+
else:
|
82 |
+
self.service_logger_obj = ServiceLogging()
|
83 |
+
|
84 |
+
redis_kwargs.update(kwargs)
|
85 |
+
self.redis_client = get_redis_client(**redis_kwargs)
|
86 |
+
self.redis_async_client: Optional[async_redis_client] = None
|
87 |
+
self.redis_kwargs = redis_kwargs
|
88 |
+
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
|
89 |
+
|
90 |
+
# redis namespaces
|
91 |
+
self.namespace = namespace
|
92 |
+
# for high traffic, we store the redis results in memory and then batch write to redis
|
93 |
+
self.redis_batch_writing_buffer: list = []
|
94 |
+
if redis_flush_size is None:
|
95 |
+
self.redis_flush_size: int = 100
|
96 |
+
else:
|
97 |
+
self.redis_flush_size = redis_flush_size
|
98 |
+
self.redis_version = "Unknown"
|
99 |
+
try:
|
100 |
+
if not inspect.iscoroutinefunction(self.redis_client):
|
101 |
+
self.redis_version = self.redis_client.info()["redis_version"] # type: ignore
|
102 |
+
except Exception:
|
103 |
+
pass
|
104 |
+
|
105 |
+
### ASYNC HEALTH PING ###
|
106 |
+
try:
|
107 |
+
# asyncio.get_running_loop().create_task(self.ping())
|
108 |
+
_ = asyncio.get_running_loop().create_task(self.ping())
|
109 |
+
except Exception as e:
|
110 |
+
if "no running event loop" in str(e):
|
111 |
+
verbose_logger.debug(
|
112 |
+
"Ignoring async redis ping. No running event loop."
|
113 |
+
)
|
114 |
+
else:
|
115 |
+
verbose_logger.error(
|
116 |
+
"Error connecting to Async Redis client - {}".format(str(e)),
|
117 |
+
extra={"error": str(e)},
|
118 |
+
)
|
119 |
+
|
120 |
+
### SYNC HEALTH PING ###
|
121 |
+
try:
|
122 |
+
if hasattr(self.redis_client, "ping"):
|
123 |
+
self.redis_client.ping() # type: ignore
|
124 |
+
except Exception as e:
|
125 |
+
verbose_logger.error(
|
126 |
+
"Error connecting to Sync Redis client", extra={"error": str(e)}
|
127 |
+
)
|
128 |
+
|
129 |
+
if litellm.default_redis_ttl is not None:
|
130 |
+
super().__init__(default_ttl=int(litellm.default_redis_ttl))
|
131 |
+
else:
|
132 |
+
super().__init__() # defaults to 60s
|
133 |
+
|
134 |
+
def init_async_client(
|
135 |
+
self,
|
136 |
+
) -> Union[async_redis_client, async_redis_cluster_client]:
|
137 |
+
from .._redis import get_redis_async_client
|
138 |
+
|
139 |
+
if self.redis_async_client is None:
|
140 |
+
self.redis_async_client = get_redis_async_client(
|
141 |
+
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
|
142 |
+
)
|
143 |
+
return self.redis_async_client
|
144 |
+
|
145 |
+
def check_and_fix_namespace(self, key: str) -> str:
|
146 |
+
"""
|
147 |
+
Make sure each key starts with the given namespace
|
148 |
+
"""
|
149 |
+
if self.namespace is not None and not key.startswith(self.namespace):
|
150 |
+
key = self.namespace + ":" + key
|
151 |
+
|
152 |
+
return key
|
153 |
+
|
154 |
+
def set_cache(self, key, value, **kwargs):
|
155 |
+
ttl = self.get_ttl(**kwargs)
|
156 |
+
print_verbose(
|
157 |
+
f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}, redis_version={self.redis_version}"
|
158 |
+
)
|
159 |
+
key = self.check_and_fix_namespace(key=key)
|
160 |
+
try:
|
161 |
+
start_time = time.time()
|
162 |
+
self.redis_client.set(name=key, value=str(value), ex=ttl)
|
163 |
+
end_time = time.time()
|
164 |
+
_duration = end_time - start_time
|
165 |
+
self.service_logger_obj.service_success_hook(
|
166 |
+
service=ServiceTypes.REDIS,
|
167 |
+
duration=_duration,
|
168 |
+
call_type="set_cache",
|
169 |
+
start_time=start_time,
|
170 |
+
end_time=end_time,
|
171 |
+
)
|
172 |
+
except Exception as e:
|
173 |
+
# NON blocking - notify users Redis is throwing an exception
|
174 |
+
print_verbose(
|
175 |
+
f"litellm.caching.caching: set() - Got exception from REDIS : {str(e)}"
|
176 |
+
)
|
177 |
+
|
178 |
+
def increment_cache(
|
179 |
+
self, key, value: int, ttl: Optional[float] = None, **kwargs
|
180 |
+
) -> int:
|
181 |
+
_redis_client = self.redis_client
|
182 |
+
start_time = time.time()
|
183 |
+
set_ttl = self.get_ttl(ttl=ttl)
|
184 |
+
try:
|
185 |
+
start_time = time.time()
|
186 |
+
result: int = _redis_client.incr(name=key, amount=value) # type: ignore
|
187 |
+
end_time = time.time()
|
188 |
+
_duration = end_time - start_time
|
189 |
+
self.service_logger_obj.service_success_hook(
|
190 |
+
service=ServiceTypes.REDIS,
|
191 |
+
duration=_duration,
|
192 |
+
call_type="increment_cache",
|
193 |
+
start_time=start_time,
|
194 |
+
end_time=end_time,
|
195 |
+
)
|
196 |
+
|
197 |
+
if set_ttl is not None:
|
198 |
+
# check if key already has ttl, if not -> set ttl
|
199 |
+
start_time = time.time()
|
200 |
+
current_ttl = _redis_client.ttl(key)
|
201 |
+
end_time = time.time()
|
202 |
+
_duration = end_time - start_time
|
203 |
+
self.service_logger_obj.service_success_hook(
|
204 |
+
service=ServiceTypes.REDIS,
|
205 |
+
duration=_duration,
|
206 |
+
call_type="increment_cache_ttl",
|
207 |
+
start_time=start_time,
|
208 |
+
end_time=end_time,
|
209 |
+
)
|
210 |
+
if current_ttl == -1:
|
211 |
+
# Key has no expiration
|
212 |
+
start_time = time.time()
|
213 |
+
_redis_client.expire(key, set_ttl) # type: ignore
|
214 |
+
end_time = time.time()
|
215 |
+
_duration = end_time - start_time
|
216 |
+
self.service_logger_obj.service_success_hook(
|
217 |
+
service=ServiceTypes.REDIS,
|
218 |
+
duration=_duration,
|
219 |
+
call_type="increment_cache_expire",
|
220 |
+
start_time=start_time,
|
221 |
+
end_time=end_time,
|
222 |
+
)
|
223 |
+
return result
|
224 |
+
except Exception as e:
|
225 |
+
## LOGGING ##
|
226 |
+
end_time = time.time()
|
227 |
+
_duration = end_time - start_time
|
228 |
+
verbose_logger.error(
|
229 |
+
"LiteLLM Redis Caching: increment_cache() - Got exception from REDIS %s, Writing value=%s",
|
230 |
+
str(e),
|
231 |
+
value,
|
232 |
+
)
|
233 |
+
raise e
|
234 |
+
|
235 |
+
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
|
236 |
+
start_time = time.time()
|
237 |
+
try:
|
238 |
+
keys = []
|
239 |
+
_redis_client = self.init_async_client()
|
240 |
+
if not hasattr(_redis_client, "scan_iter"):
|
241 |
+
verbose_logger.debug(
|
242 |
+
"Redis client does not support scan_iter, potentially using Redis Cluster. Returning empty list."
|
243 |
+
)
|
244 |
+
return []
|
245 |
+
|
246 |
+
async for key in _redis_client.scan_iter(match=pattern + "*", count=count): # type: ignore
|
247 |
+
keys.append(key)
|
248 |
+
if len(keys) >= count:
|
249 |
+
break
|
250 |
+
|
251 |
+
## LOGGING ##
|
252 |
+
end_time = time.time()
|
253 |
+
_duration = end_time - start_time
|
254 |
+
asyncio.create_task(
|
255 |
+
self.service_logger_obj.async_service_success_hook(
|
256 |
+
service=ServiceTypes.REDIS,
|
257 |
+
duration=_duration,
|
258 |
+
call_type="async_scan_iter",
|
259 |
+
start_time=start_time,
|
260 |
+
end_time=end_time,
|
261 |
+
)
|
262 |
+
) # DO NOT SLOW DOWN CALL B/C OF THIS
|
263 |
+
return keys
|
264 |
+
except Exception as e:
|
265 |
+
# NON blocking - notify users Redis is throwing an exception
|
266 |
+
## LOGGING ##
|
267 |
+
end_time = time.time()
|
268 |
+
_duration = end_time - start_time
|
269 |
+
asyncio.create_task(
|
270 |
+
self.service_logger_obj.async_service_failure_hook(
|
271 |
+
service=ServiceTypes.REDIS,
|
272 |
+
duration=_duration,
|
273 |
+
error=e,
|
274 |
+
call_type="async_scan_iter",
|
275 |
+
start_time=start_time,
|
276 |
+
end_time=end_time,
|
277 |
+
)
|
278 |
+
)
|
279 |
+
raise e
|
280 |
+
|
281 |
+
async def async_set_cache(self, key, value, **kwargs):
|
282 |
+
from redis.asyncio import Redis
|
283 |
+
|
284 |
+
start_time = time.time()
|
285 |
+
try:
|
286 |
+
_redis_client: Redis = self.init_async_client() # type: ignore
|
287 |
+
except Exception as e:
|
288 |
+
end_time = time.time()
|
289 |
+
_duration = end_time - start_time
|
290 |
+
asyncio.create_task(
|
291 |
+
self.service_logger_obj.async_service_failure_hook(
|
292 |
+
service=ServiceTypes.REDIS,
|
293 |
+
duration=_duration,
|
294 |
+
error=e,
|
295 |
+
start_time=start_time,
|
296 |
+
end_time=end_time,
|
297 |
+
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
298 |
+
call_type="async_set_cache",
|
299 |
+
)
|
300 |
+
)
|
301 |
+
verbose_logger.error(
|
302 |
+
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
|
303 |
+
str(e),
|
304 |
+
value,
|
305 |
+
)
|
306 |
+
raise e
|
307 |
+
|
308 |
+
key = self.check_and_fix_namespace(key=key)
|
309 |
+
ttl = self.get_ttl(**kwargs)
|
310 |
+
nx = kwargs.get("nx", False)
|
311 |
+
print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
|
312 |
+
|
313 |
+
try:
|
314 |
+
if not hasattr(_redis_client, "set"):
|
315 |
+
raise Exception("Redis client cannot set cache. Attribute not found.")
|
316 |
+
result = await _redis_client.set(
|
317 |
+
name=key,
|
318 |
+
value=json.dumps(value),
|
319 |
+
nx=nx,
|
320 |
+
ex=ttl,
|
321 |
+
)
|
322 |
+
print_verbose(
|
323 |
+
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
324 |
+
)
|
325 |
+
end_time = time.time()
|
326 |
+
_duration = end_time - start_time
|
327 |
+
asyncio.create_task(
|
328 |
+
self.service_logger_obj.async_service_success_hook(
|
329 |
+
service=ServiceTypes.REDIS,
|
330 |
+
duration=_duration,
|
331 |
+
call_type="async_set_cache",
|
332 |
+
start_time=start_time,
|
333 |
+
end_time=end_time,
|
334 |
+
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
335 |
+
event_metadata={"key": key},
|
336 |
+
)
|
337 |
+
)
|
338 |
+
return result
|
339 |
+
except Exception as e:
|
340 |
+
end_time = time.time()
|
341 |
+
_duration = end_time - start_time
|
342 |
+
asyncio.create_task(
|
343 |
+
self.service_logger_obj.async_service_failure_hook(
|
344 |
+
service=ServiceTypes.REDIS,
|
345 |
+
duration=_duration,
|
346 |
+
error=e,
|
347 |
+
call_type="async_set_cache",
|
348 |
+
start_time=start_time,
|
349 |
+
end_time=end_time,
|
350 |
+
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
351 |
+
event_metadata={"key": key},
|
352 |
+
)
|
353 |
+
)
|
354 |
+
verbose_logger.error(
|
355 |
+
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
|
356 |
+
str(e),
|
357 |
+
value,
|
358 |
+
)
|
359 |
+
|
360 |
+
async def _pipeline_helper(
|
361 |
+
self,
|
362 |
+
pipe: Union[pipeline, cluster_pipeline],
|
363 |
+
cache_list: List[Tuple[Any, Any]],
|
364 |
+
ttl: Optional[float],
|
365 |
+
) -> List:
|
366 |
+
"""
|
367 |
+
Helper function for executing a pipeline of set operations on Redis
|
368 |
+
"""
|
369 |
+
ttl = self.get_ttl(ttl=ttl)
|
370 |
+
# Iterate through each key-value pair in the cache_list and set them in the pipeline.
|
371 |
+
for cache_key, cache_value in cache_list:
|
372 |
+
cache_key = self.check_and_fix_namespace(key=cache_key)
|
373 |
+
print_verbose(
|
374 |
+
f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
|
375 |
+
)
|
376 |
+
json_cache_value = json.dumps(cache_value)
|
377 |
+
# Set the value with a TTL if it's provided.
|
378 |
+
_td: Optional[timedelta] = None
|
379 |
+
if ttl is not None:
|
380 |
+
_td = timedelta(seconds=ttl)
|
381 |
+
pipe.set( # type: ignore
|
382 |
+
name=cache_key,
|
383 |
+
value=json_cache_value,
|
384 |
+
ex=_td,
|
385 |
+
)
|
386 |
+
# Execute the pipeline and return the results.
|
387 |
+
results = await pipe.execute()
|
388 |
+
return results
|
389 |
+
|
390 |
+
async def async_set_cache_pipeline(
|
391 |
+
self, cache_list: List[Tuple[Any, Any]], ttl: Optional[float] = None, **kwargs
|
392 |
+
):
|
393 |
+
"""
|
394 |
+
Use Redis Pipelines for bulk write operations
|
395 |
+
"""
|
396 |
+
# don't waste a network request if there's nothing to set
|
397 |
+
if len(cache_list) == 0:
|
398 |
+
return
|
399 |
+
|
400 |
+
_redis_client = self.init_async_client()
|
401 |
+
start_time = time.time()
|
402 |
+
|
403 |
+
print_verbose(
|
404 |
+
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
|
405 |
+
)
|
406 |
+
cache_value: Any = None
|
407 |
+
try:
|
408 |
+
async with _redis_client.pipeline(transaction=False) as pipe:
|
409 |
+
results = await self._pipeline_helper(pipe, cache_list, ttl)
|
410 |
+
|
411 |
+
print_verbose(f"pipeline results: {results}")
|
412 |
+
# Optionally, you could process 'results' to make sure that all set operations were successful.
|
413 |
+
## LOGGING ##
|
414 |
+
end_time = time.time()
|
415 |
+
_duration = end_time - start_time
|
416 |
+
asyncio.create_task(
|
417 |
+
self.service_logger_obj.async_service_success_hook(
|
418 |
+
service=ServiceTypes.REDIS,
|
419 |
+
duration=_duration,
|
420 |
+
call_type="async_set_cache_pipeline",
|
421 |
+
start_time=start_time,
|
422 |
+
end_time=end_time,
|
423 |
+
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
424 |
+
)
|
425 |
+
)
|
426 |
+
return None
|
427 |
+
except Exception as e:
|
428 |
+
## LOGGING ##
|
429 |
+
end_time = time.time()
|
430 |
+
_duration = end_time - start_time
|
431 |
+
asyncio.create_task(
|
432 |
+
self.service_logger_obj.async_service_failure_hook(
|
433 |
+
service=ServiceTypes.REDIS,
|
434 |
+
duration=_duration,
|
435 |
+
error=e,
|
436 |
+
call_type="async_set_cache_pipeline",
|
437 |
+
start_time=start_time,
|
438 |
+
end_time=end_time,
|
439 |
+
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
440 |
+
)
|
441 |
+
)
|
442 |
+
|
443 |
+
verbose_logger.error(
|
444 |
+
"LiteLLM Redis Caching: async set_cache_pipeline() - Got exception from REDIS %s, Writing value=%s",
|
445 |
+
str(e),
|
446 |
+
cache_value,
|
447 |
+
)
|
448 |
+
|
449 |
+
async def _set_cache_sadd_helper(
|
450 |
+
self,
|
451 |
+
redis_client: async_redis_client,
|
452 |
+
key: str,
|
453 |
+
value: List,
|
454 |
+
ttl: Optional[float],
|
455 |
+
) -> None:
|
456 |
+
"""Helper function for async_set_cache_sadd. Separated for testing."""
|
457 |
+
ttl = self.get_ttl(ttl=ttl)
|
458 |
+
try:
|
459 |
+
await redis_client.sadd(key, *value) # type: ignore
|
460 |
+
if ttl is not None:
|
461 |
+
_td = timedelta(seconds=ttl)
|
462 |
+
await redis_client.expire(key, _td)
|
463 |
+
except Exception:
|
464 |
+
raise
|
465 |
+
|
466 |
+
async def async_set_cache_sadd(
|
467 |
+
self, key, value: List, ttl: Optional[float], **kwargs
|
468 |
+
):
|
469 |
+
from redis.asyncio import Redis
|
470 |
+
|
471 |
+
start_time = time.time()
|
472 |
+
try:
|
473 |
+
_redis_client: Redis = self.init_async_client() # type: ignore
|
474 |
+
except Exception as e:
|
475 |
+
end_time = time.time()
|
476 |
+
_duration = end_time - start_time
|
477 |
+
asyncio.create_task(
|
478 |
+
self.service_logger_obj.async_service_failure_hook(
|
479 |
+
service=ServiceTypes.REDIS,
|
480 |
+
duration=_duration,
|
481 |
+
error=e,
|
482 |
+
start_time=start_time,
|
483 |
+
end_time=end_time,
|
484 |
+
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
485 |
+
call_type="async_set_cache_sadd",
|
486 |
+
)
|
487 |
+
)
|
488 |
+
# NON blocking - notify users Redis is throwing an exception
|
489 |
+
verbose_logger.error(
|
490 |
+
"LiteLLM Redis Caching: async set() - Got exception from REDIS %s, Writing value=%s",
|
491 |
+
str(e),
|
492 |
+
value,
|
493 |
+
)
|
494 |
+
raise e
|
495 |
+
|
496 |
+
key = self.check_and_fix_namespace(key=key)
|
497 |
+
print_verbose(f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
|
498 |
+
try:
|
499 |
+
await self._set_cache_sadd_helper(
|
500 |
+
redis_client=_redis_client, key=key, value=value, ttl=ttl
|
501 |
+
)
|
502 |
+
print_verbose(
|
503 |
+
f"Successfully Set ASYNC Redis Cache SADD: key: {key}\nValue {value}\nttl={ttl}"
|
504 |
+
)
|
505 |
+
end_time = time.time()
|
506 |
+
_duration = end_time - start_time
|
507 |
+
asyncio.create_task(
|
508 |
+
self.service_logger_obj.async_service_success_hook(
|
509 |
+
service=ServiceTypes.REDIS,
|
510 |
+
duration=_duration,
|
511 |
+
call_type="async_set_cache_sadd",
|
512 |
+
start_time=start_time,
|
513 |
+
end_time=end_time,
|
514 |
+
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
515 |
+
)
|
516 |
+
)
|
517 |
+
except Exception as e:
|
518 |
+
end_time = time.time()
|
519 |
+
_duration = end_time - start_time
|
520 |
+
asyncio.create_task(
|
521 |
+
self.service_logger_obj.async_service_failure_hook(
|
522 |
+
service=ServiceTypes.REDIS,
|
523 |
+
duration=_duration,
|
524 |
+
error=e,
|
525 |
+
call_type="async_set_cache_sadd",
|
526 |
+
start_time=start_time,
|
527 |
+
end_time=end_time,
|
528 |
+
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
529 |
+
)
|
530 |
+
)
|
531 |
+
# NON blocking - notify users Redis is throwing an exception
|
532 |
+
verbose_logger.error(
|
533 |
+
"LiteLLM Redis Caching: async set_cache_sadd() - Got exception from REDIS %s, Writing value=%s",
|
534 |
+
str(e),
|
535 |
+
value,
|
536 |
+
)
|
537 |
+
|
538 |
+
async def batch_cache_write(self, key, value, **kwargs):
|
539 |
+
print_verbose(
|
540 |
+
f"in batch cache writing for redis buffer size={len(self.redis_batch_writing_buffer)}",
|
541 |
+
)
|
542 |
+
key = self.check_and_fix_namespace(key=key)
|
543 |
+
self.redis_batch_writing_buffer.append((key, value))
|
544 |
+
if len(self.redis_batch_writing_buffer) >= self.redis_flush_size:
|
545 |
+
await self.flush_cache_buffer() # logging done in here
|
546 |
+
|
547 |
+
async def async_increment(
|
548 |
+
self,
|
549 |
+
key,
|
550 |
+
value: float,
|
551 |
+
ttl: Optional[int] = None,
|
552 |
+
parent_otel_span: Optional[Span] = None,
|
553 |
+
) -> float:
|
554 |
+
from redis.asyncio import Redis
|
555 |
+
|
556 |
+
_redis_client: Redis = self.init_async_client() # type: ignore
|
557 |
+
start_time = time.time()
|
558 |
+
_used_ttl = self.get_ttl(ttl=ttl)
|
559 |
+
key = self.check_and_fix_namespace(key=key)
|
560 |
+
try:
|
561 |
+
result = await _redis_client.incrbyfloat(name=key, amount=value)
|
562 |
+
if _used_ttl is not None:
|
563 |
+
# check if key already has ttl, if not -> set ttl
|
564 |
+
current_ttl = await _redis_client.ttl(key)
|
565 |
+
if current_ttl == -1:
|
566 |
+
# Key has no expiration
|
567 |
+
await _redis_client.expire(key, _used_ttl)
|
568 |
+
|
569 |
+
## LOGGING ##
|
570 |
+
end_time = time.time()
|
571 |
+
_duration = end_time - start_time
|
572 |
+
|
573 |
+
asyncio.create_task(
|
574 |
+
self.service_logger_obj.async_service_success_hook(
|
575 |
+
service=ServiceTypes.REDIS,
|
576 |
+
duration=_duration,
|
577 |
+
call_type="async_increment",
|
578 |
+
start_time=start_time,
|
579 |
+
end_time=end_time,
|
580 |
+
parent_otel_span=parent_otel_span,
|
581 |
+
)
|
582 |
+
)
|
583 |
+
return result
|
584 |
+
except Exception as e:
|
585 |
+
## LOGGING ##
|
586 |
+
end_time = time.time()
|
587 |
+
_duration = end_time - start_time
|
588 |
+
asyncio.create_task(
|
589 |
+
self.service_logger_obj.async_service_failure_hook(
|
590 |
+
service=ServiceTypes.REDIS,
|
591 |
+
duration=_duration,
|
592 |
+
error=e,
|
593 |
+
call_type="async_increment",
|
594 |
+
start_time=start_time,
|
595 |
+
end_time=end_time,
|
596 |
+
parent_otel_span=parent_otel_span,
|
597 |
+
)
|
598 |
+
)
|
599 |
+
verbose_logger.error(
|
600 |
+
"LiteLLM Redis Caching: async async_increment() - Got exception from REDIS %s, Writing value=%s",
|
601 |
+
str(e),
|
602 |
+
value,
|
603 |
+
)
|
604 |
+
raise e
|
605 |
+
|
606 |
+
async def flush_cache_buffer(self):
|
607 |
+
print_verbose(
|
608 |
+
f"flushing to redis....reached size of buffer {len(self.redis_batch_writing_buffer)}"
|
609 |
+
)
|
610 |
+
await self.async_set_cache_pipeline(self.redis_batch_writing_buffer)
|
611 |
+
self.redis_batch_writing_buffer = []
|
612 |
+
|
613 |
+
def _get_cache_logic(self, cached_response: Any):
|
614 |
+
"""
|
615 |
+
Common 'get_cache_logic' across sync + async redis client implementations
|
616 |
+
"""
|
617 |
+
if cached_response is None:
|
618 |
+
return cached_response
|
619 |
+
# cached_response is in `b{} convert it to ModelResponse
|
620 |
+
cached_response = cached_response.decode("utf-8") # Convert bytes to string
|
621 |
+
try:
|
622 |
+
cached_response = json.loads(
|
623 |
+
cached_response
|
624 |
+
) # Convert string to dictionary
|
625 |
+
except Exception:
|
626 |
+
cached_response = ast.literal_eval(cached_response)
|
627 |
+
return cached_response
|
628 |
+
|
629 |
+
def get_cache(self, key, parent_otel_span: Optional[Span] = None, **kwargs):
|
630 |
+
try:
|
631 |
+
key = self.check_and_fix_namespace(key=key)
|
632 |
+
print_verbose(f"Get Redis Cache: key: {key}")
|
633 |
+
start_time = time.time()
|
634 |
+
cached_response = self.redis_client.get(key)
|
635 |
+
end_time = time.time()
|
636 |
+
_duration = end_time - start_time
|
637 |
+
self.service_logger_obj.service_success_hook(
|
638 |
+
service=ServiceTypes.REDIS,
|
639 |
+
duration=_duration,
|
640 |
+
call_type="get_cache",
|
641 |
+
start_time=start_time,
|
642 |
+
end_time=end_time,
|
643 |
+
parent_otel_span=parent_otel_span,
|
644 |
+
)
|
645 |
+
print_verbose(
|
646 |
+
f"Got Redis Cache: key: {key}, cached_response {cached_response}"
|
647 |
+
)
|
648 |
+
return self._get_cache_logic(cached_response=cached_response)
|
649 |
+
except Exception as e:
|
650 |
+
# NON blocking - notify users Redis is throwing an exception
|
651 |
+
verbose_logger.error(
|
652 |
+
"litellm.caching.caching: get() - Got exception from REDIS: ", e
|
653 |
+
)
|
654 |
+
|
655 |
+
def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
|
656 |
+
"""
|
657 |
+
Wrapper to call `mget` on the redis client
|
658 |
+
|
659 |
+
We use a wrapper so RedisCluster can override this method
|
660 |
+
"""
|
661 |
+
return self.redis_client.mget(keys=keys) # type: ignore
|
662 |
+
|
663 |
+
async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
|
664 |
+
"""
|
665 |
+
Wrapper to call `mget` on the redis client
|
666 |
+
|
667 |
+
We use a wrapper so RedisCluster can override this method
|
668 |
+
"""
|
669 |
+
async_redis_client = self.init_async_client()
|
670 |
+
return await async_redis_client.mget(keys=keys) # type: ignore
|
671 |
+
|
672 |
+
def batch_get_cache(
|
673 |
+
self,
|
674 |
+
key_list: Union[List[str], List[Optional[str]]],
|
675 |
+
parent_otel_span: Optional[Span] = None,
|
676 |
+
) -> dict:
|
677 |
+
"""
|
678 |
+
Use Redis for bulk read operations
|
679 |
+
|
680 |
+
Args:
|
681 |
+
key_list: List of keys to get from Redis
|
682 |
+
parent_otel_span: Optional parent OpenTelemetry span
|
683 |
+
|
684 |
+
Returns:
|
685 |
+
dict: A dictionary mapping keys to their cached values
|
686 |
+
"""
|
687 |
+
key_value_dict = {}
|
688 |
+
_key_list = [key for key in key_list if key is not None]
|
689 |
+
|
690 |
+
try:
|
691 |
+
_keys = []
|
692 |
+
for cache_key in _key_list:
|
693 |
+
cache_key = self.check_and_fix_namespace(key=cache_key or "")
|
694 |
+
_keys.append(cache_key)
|
695 |
+
start_time = time.time()
|
696 |
+
results: List = self._run_redis_mget_operation(keys=_keys)
|
697 |
+
end_time = time.time()
|
698 |
+
_duration = end_time - start_time
|
699 |
+
self.service_logger_obj.service_success_hook(
|
700 |
+
service=ServiceTypes.REDIS,
|
701 |
+
duration=_duration,
|
702 |
+
call_type="batch_get_cache",
|
703 |
+
start_time=start_time,
|
704 |
+
end_time=end_time,
|
705 |
+
parent_otel_span=parent_otel_span,
|
706 |
+
)
|
707 |
+
|
708 |
+
# Associate the results back with their keys.
|
709 |
+
# 'results' is a list of values corresponding to the order of keys in '_key_list'.
|
710 |
+
key_value_dict = dict(zip(_key_list, results))
|
711 |
+
|
712 |
+
decoded_results = {}
|
713 |
+
for k, v in key_value_dict.items():
|
714 |
+
if isinstance(k, bytes):
|
715 |
+
k = k.decode("utf-8")
|
716 |
+
v = self._get_cache_logic(v)
|
717 |
+
decoded_results[k] = v
|
718 |
+
|
719 |
+
return decoded_results
|
720 |
+
except Exception as e:
|
721 |
+
verbose_logger.error(f"Error occurred in batch get cache - {str(e)}")
|
722 |
+
return key_value_dict
|
723 |
+
|
724 |
+
async def async_get_cache(
|
725 |
+
self, key, parent_otel_span: Optional[Span] = None, **kwargs
|
726 |
+
):
|
727 |
+
from redis.asyncio import Redis
|
728 |
+
|
729 |
+
_redis_client: Redis = self.init_async_client() # type: ignore
|
730 |
+
key = self.check_and_fix_namespace(key=key)
|
731 |
+
start_time = time.time()
|
732 |
+
|
733 |
+
try:
|
734 |
+
print_verbose(f"Get Async Redis Cache: key: {key}")
|
735 |
+
cached_response = await _redis_client.get(key)
|
736 |
+
print_verbose(
|
737 |
+
f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
|
738 |
+
)
|
739 |
+
response = self._get_cache_logic(cached_response=cached_response)
|
740 |
+
|
741 |
+
end_time = time.time()
|
742 |
+
_duration = end_time - start_time
|
743 |
+
asyncio.create_task(
|
744 |
+
self.service_logger_obj.async_service_success_hook(
|
745 |
+
service=ServiceTypes.REDIS,
|
746 |
+
duration=_duration,
|
747 |
+
call_type="async_get_cache",
|
748 |
+
start_time=start_time,
|
749 |
+
end_time=end_time,
|
750 |
+
parent_otel_span=parent_otel_span,
|
751 |
+
event_metadata={"key": key},
|
752 |
+
)
|
753 |
+
)
|
754 |
+
return response
|
755 |
+
except Exception as e:
|
756 |
+
end_time = time.time()
|
757 |
+
_duration = end_time - start_time
|
758 |
+
asyncio.create_task(
|
759 |
+
self.service_logger_obj.async_service_failure_hook(
|
760 |
+
service=ServiceTypes.REDIS,
|
761 |
+
duration=_duration,
|
762 |
+
error=e,
|
763 |
+
call_type="async_get_cache",
|
764 |
+
start_time=start_time,
|
765 |
+
end_time=end_time,
|
766 |
+
parent_otel_span=parent_otel_span,
|
767 |
+
event_metadata={"key": key},
|
768 |
+
)
|
769 |
+
)
|
770 |
+
print_verbose(
|
771 |
+
f"litellm.caching.caching: async get() - Got exception from REDIS: {str(e)}"
|
772 |
+
)
|
773 |
+
|
774 |
+
async def async_batch_get_cache(
|
775 |
+
self,
|
776 |
+
key_list: Union[List[str], List[Optional[str]]],
|
777 |
+
parent_otel_span: Optional[Span] = None,
|
778 |
+
) -> dict:
|
779 |
+
"""
|
780 |
+
Use Redis for bulk read operations
|
781 |
+
|
782 |
+
Args:
|
783 |
+
key_list: List of keys to get from Redis
|
784 |
+
parent_otel_span: Optional parent OpenTelemetry span
|
785 |
+
|
786 |
+
Returns:
|
787 |
+
dict: A dictionary mapping keys to their cached values
|
788 |
+
|
789 |
+
`.mget` does not support None keys. This will filter out None keys.
|
790 |
+
"""
|
791 |
+
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
|
792 |
+
key_value_dict = {}
|
793 |
+
start_time = time.time()
|
794 |
+
_key_list = [key for key in key_list if key is not None]
|
795 |
+
try:
|
796 |
+
_keys = []
|
797 |
+
for cache_key in _key_list:
|
798 |
+
cache_key = self.check_and_fix_namespace(key=cache_key)
|
799 |
+
_keys.append(cache_key)
|
800 |
+
results = await self._async_run_redis_mget_operation(keys=_keys)
|
801 |
+
## LOGGING ##
|
802 |
+
end_time = time.time()
|
803 |
+
_duration = end_time - start_time
|
804 |
+
asyncio.create_task(
|
805 |
+
self.service_logger_obj.async_service_success_hook(
|
806 |
+
service=ServiceTypes.REDIS,
|
807 |
+
duration=_duration,
|
808 |
+
call_type="async_batch_get_cache",
|
809 |
+
start_time=start_time,
|
810 |
+
end_time=end_time,
|
811 |
+
parent_otel_span=parent_otel_span,
|
812 |
+
)
|
813 |
+
)
|
814 |
+
|
815 |
+
# Associate the results back with their keys.
|
816 |
+
# 'results' is a list of values corresponding to the order of keys in 'key_list'.
|
817 |
+
key_value_dict = dict(zip(_key_list, results))
|
818 |
+
|
819 |
+
decoded_results = {}
|
820 |
+
for k, v in key_value_dict.items():
|
821 |
+
if isinstance(k, bytes):
|
822 |
+
k = k.decode("utf-8")
|
823 |
+
v = self._get_cache_logic(v)
|
824 |
+
decoded_results[k] = v
|
825 |
+
|
826 |
+
return decoded_results
|
827 |
+
except Exception as e:
|
828 |
+
## LOGGING ##
|
829 |
+
end_time = time.time()
|
830 |
+
_duration = end_time - start_time
|
831 |
+
asyncio.create_task(
|
832 |
+
self.service_logger_obj.async_service_failure_hook(
|
833 |
+
service=ServiceTypes.REDIS,
|
834 |
+
duration=_duration,
|
835 |
+
error=e,
|
836 |
+
call_type="async_batch_get_cache",
|
837 |
+
start_time=start_time,
|
838 |
+
end_time=end_time,
|
839 |
+
parent_otel_span=parent_otel_span,
|
840 |
+
)
|
841 |
+
)
|
842 |
+
verbose_logger.error(f"Error occurred in async batch get cache - {str(e)}")
|
843 |
+
return key_value_dict
|
844 |
+
|
845 |
+
def sync_ping(self) -> bool:
|
846 |
+
"""
|
847 |
+
Tests if the sync redis client is correctly setup.
|
848 |
+
"""
|
849 |
+
print_verbose("Pinging Sync Redis Cache")
|
850 |
+
start_time = time.time()
|
851 |
+
try:
|
852 |
+
response: bool = self.redis_client.ping() # type: ignore
|
853 |
+
print_verbose(f"Redis Cache PING: {response}")
|
854 |
+
## LOGGING ##
|
855 |
+
end_time = time.time()
|
856 |
+
_duration = end_time - start_time
|
857 |
+
self.service_logger_obj.service_success_hook(
|
858 |
+
service=ServiceTypes.REDIS,
|
859 |
+
duration=_duration,
|
860 |
+
call_type="sync_ping",
|
861 |
+
start_time=start_time,
|
862 |
+
end_time=end_time,
|
863 |
+
)
|
864 |
+
return response
|
865 |
+
except Exception as e:
|
866 |
+
# NON blocking - notify users Redis is throwing an exception
|
867 |
+
## LOGGING ##
|
868 |
+
end_time = time.time()
|
869 |
+
_duration = end_time - start_time
|
870 |
+
self.service_logger_obj.service_failure_hook(
|
871 |
+
service=ServiceTypes.REDIS,
|
872 |
+
duration=_duration,
|
873 |
+
error=e,
|
874 |
+
call_type="sync_ping",
|
875 |
+
)
|
876 |
+
verbose_logger.error(
|
877 |
+
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
|
878 |
+
)
|
879 |
+
raise e
|
880 |
+
|
881 |
+
async def ping(self) -> bool:
|
882 |
+
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping`
|
883 |
+
_redis_client: Any = self.init_async_client()
|
884 |
+
start_time = time.time()
|
885 |
+
print_verbose("Pinging Async Redis Cache")
|
886 |
+
try:
|
887 |
+
response = await _redis_client.ping()
|
888 |
+
## LOGGING ##
|
889 |
+
end_time = time.time()
|
890 |
+
_duration = end_time - start_time
|
891 |
+
asyncio.create_task(
|
892 |
+
self.service_logger_obj.async_service_success_hook(
|
893 |
+
service=ServiceTypes.REDIS,
|
894 |
+
duration=_duration,
|
895 |
+
call_type="async_ping",
|
896 |
+
)
|
897 |
+
)
|
898 |
+
return response
|
899 |
+
except Exception as e:
|
900 |
+
# NON blocking - notify users Redis is throwing an exception
|
901 |
+
## LOGGING ##
|
902 |
+
end_time = time.time()
|
903 |
+
_duration = end_time - start_time
|
904 |
+
asyncio.create_task(
|
905 |
+
self.service_logger_obj.async_service_failure_hook(
|
906 |
+
service=ServiceTypes.REDIS,
|
907 |
+
duration=_duration,
|
908 |
+
error=e,
|
909 |
+
call_type="async_ping",
|
910 |
+
)
|
911 |
+
)
|
912 |
+
verbose_logger.error(
|
913 |
+
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
|
914 |
+
)
|
915 |
+
raise e
|
916 |
+
|
917 |
+
async def delete_cache_keys(self, keys):
|
918 |
+
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
|
919 |
+
_redis_client: Any = self.init_async_client()
|
920 |
+
# keys is a list, unpack it so it gets passed as individual elements to delete
|
921 |
+
await _redis_client.delete(*keys)
|
922 |
+
|
923 |
+
def client_list(self) -> List:
|
924 |
+
client_list: List = self.redis_client.client_list() # type: ignore
|
925 |
+
return client_list
|
926 |
+
|
927 |
+
def info(self):
|
928 |
+
info = self.redis_client.info()
|
929 |
+
return info
|
930 |
+
|
931 |
+
def flush_cache(self):
|
932 |
+
self.redis_client.flushall()
|
933 |
+
|
934 |
+
def flushall(self):
|
935 |
+
self.redis_client.flushall()
|
936 |
+
|
937 |
+
async def disconnect(self):
|
938 |
+
await self.async_redis_conn_pool.disconnect(inuse_connections=True)
|
939 |
+
|
940 |
+
async def async_delete_cache(self, key: str):
|
941 |
+
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
|
942 |
+
_redis_client: Any = self.init_async_client()
|
943 |
+
# keys is str
|
944 |
+
return await _redis_client.delete(key)
|
945 |
+
|
946 |
+
def delete_cache(self, key):
|
947 |
+
self.redis_client.delete(key)
|
948 |
+
|
949 |
+
async def _pipeline_increment_helper(
|
950 |
+
self,
|
951 |
+
pipe: pipeline,
|
952 |
+
increment_list: List[RedisPipelineIncrementOperation],
|
953 |
+
) -> Optional[List[float]]:
|
954 |
+
"""Helper function for pipeline increment operations"""
|
955 |
+
# Iterate through each increment operation and add commands to pipeline
|
956 |
+
for increment_op in increment_list:
|
957 |
+
cache_key = self.check_and_fix_namespace(key=increment_op["key"])
|
958 |
+
print_verbose(
|
959 |
+
f"Increment ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {increment_op['increment_value']}\nttl={increment_op['ttl']}"
|
960 |
+
)
|
961 |
+
pipe.incrbyfloat(cache_key, increment_op["increment_value"])
|
962 |
+
if increment_op["ttl"] is not None:
|
963 |
+
_td = timedelta(seconds=increment_op["ttl"])
|
964 |
+
pipe.expire(cache_key, _td)
|
965 |
+
# Execute the pipeline and return results
|
966 |
+
results = await pipe.execute()
|
967 |
+
print_verbose(f"Increment ASYNC Redis Cache PIPELINE: results: {results}")
|
968 |
+
return results
|
969 |
+
|
970 |
+
async def async_increment_pipeline(
|
971 |
+
self, increment_list: List[RedisPipelineIncrementOperation], **kwargs
|
972 |
+
) -> Optional[List[float]]:
|
973 |
+
"""
|
974 |
+
Use Redis Pipelines for bulk increment operations
|
975 |
+
Args:
|
976 |
+
increment_list: List of RedisPipelineIncrementOperation dicts containing:
|
977 |
+
- key: str
|
978 |
+
- increment_value: float
|
979 |
+
- ttl_seconds: int
|
980 |
+
"""
|
981 |
+
# don't waste a network request if there's nothing to increment
|
982 |
+
if len(increment_list) == 0:
|
983 |
+
return None
|
984 |
+
|
985 |
+
from redis.asyncio import Redis
|
986 |
+
|
987 |
+
_redis_client: Redis = self.init_async_client() # type: ignore
|
988 |
+
start_time = time.time()
|
989 |
+
|
990 |
+
print_verbose(
|
991 |
+
f"Increment Async Redis Cache Pipeline: increment list: {increment_list}"
|
992 |
+
)
|
993 |
+
|
994 |
+
try:
|
995 |
+
async with _redis_client.pipeline(transaction=False) as pipe:
|
996 |
+
results = await self._pipeline_increment_helper(pipe, increment_list)
|
997 |
+
|
998 |
+
print_verbose(f"pipeline increment results: {results}")
|
999 |
+
|
1000 |
+
## LOGGING ##
|
1001 |
+
end_time = time.time()
|
1002 |
+
_duration = end_time - start_time
|
1003 |
+
asyncio.create_task(
|
1004 |
+
self.service_logger_obj.async_service_success_hook(
|
1005 |
+
service=ServiceTypes.REDIS,
|
1006 |
+
duration=_duration,
|
1007 |
+
call_type="async_increment_pipeline",
|
1008 |
+
start_time=start_time,
|
1009 |
+
end_time=end_time,
|
1010 |
+
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
1011 |
+
)
|
1012 |
+
)
|
1013 |
+
return results
|
1014 |
+
except Exception as e:
|
1015 |
+
## LOGGING ##
|
1016 |
+
end_time = time.time()
|
1017 |
+
_duration = end_time - start_time
|
1018 |
+
asyncio.create_task(
|
1019 |
+
self.service_logger_obj.async_service_failure_hook(
|
1020 |
+
service=ServiceTypes.REDIS,
|
1021 |
+
duration=_duration,
|
1022 |
+
error=e,
|
1023 |
+
call_type="async_increment_pipeline",
|
1024 |
+
start_time=start_time,
|
1025 |
+
end_time=end_time,
|
1026 |
+
parent_otel_span=_get_parent_otel_span_from_kwargs(kwargs),
|
1027 |
+
)
|
1028 |
+
)
|
1029 |
+
verbose_logger.error(
|
1030 |
+
"LiteLLM Redis Caching: async increment_pipeline() - Got exception from REDIS %s",
|
1031 |
+
str(e),
|
1032 |
+
)
|
1033 |
+
raise e
|
1034 |
+
|
1035 |
+
async def async_get_ttl(self, key: str) -> Optional[int]:
|
1036 |
+
"""
|
1037 |
+
Get the remaining TTL of a key in Redis
|
1038 |
+
|
1039 |
+
Args:
|
1040 |
+
key (str): The key to get TTL for
|
1041 |
+
|
1042 |
+
Returns:
|
1043 |
+
Optional[int]: The remaining TTL in seconds, or None if key doesn't exist
|
1044 |
+
|
1045 |
+
Redis ref: https://redis.io/docs/latest/commands/ttl/
|
1046 |
+
"""
|
1047 |
+
try:
|
1048 |
+
# typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl`
|
1049 |
+
_redis_client: Any = self.init_async_client()
|
1050 |
+
ttl = await _redis_client.ttl(key)
|
1051 |
+
if ttl <= -1: # -1 means the key does not exist, -2 key does not exist
|
1052 |
+
return None
|
1053 |
+
return ttl
|
1054 |
+
except Exception as e:
|
1055 |
+
verbose_logger.debug(f"Redis TTL Error: {e}")
|
1056 |
+
return None
|
1057 |
+
|
1058 |
+
async def async_rpush(
|
1059 |
+
self,
|
1060 |
+
key: str,
|
1061 |
+
values: List[Any],
|
1062 |
+
parent_otel_span: Optional[Span] = None,
|
1063 |
+
**kwargs,
|
1064 |
+
) -> int:
|
1065 |
+
"""
|
1066 |
+
Append one or multiple values to a list stored at key
|
1067 |
+
|
1068 |
+
Args:
|
1069 |
+
key: The Redis key of the list
|
1070 |
+
values: One or more values to append to the list
|
1071 |
+
parent_otel_span: Optional parent OpenTelemetry span
|
1072 |
+
|
1073 |
+
Returns:
|
1074 |
+
int: The length of the list after the push operation
|
1075 |
+
"""
|
1076 |
+
_redis_client: Any = self.init_async_client()
|
1077 |
+
start_time = time.time()
|
1078 |
+
try:
|
1079 |
+
response = await _redis_client.rpush(key, *values)
|
1080 |
+
## LOGGING ##
|
1081 |
+
end_time = time.time()
|
1082 |
+
_duration = end_time - start_time
|
1083 |
+
asyncio.create_task(
|
1084 |
+
self.service_logger_obj.async_service_success_hook(
|
1085 |
+
service=ServiceTypes.REDIS,
|
1086 |
+
duration=_duration,
|
1087 |
+
call_type="async_rpush",
|
1088 |
+
)
|
1089 |
+
)
|
1090 |
+
return response
|
1091 |
+
except Exception as e:
|
1092 |
+
# NON blocking - notify users Redis is throwing an exception
|
1093 |
+
## LOGGING ##
|
1094 |
+
end_time = time.time()
|
1095 |
+
_duration = end_time - start_time
|
1096 |
+
asyncio.create_task(
|
1097 |
+
self.service_logger_obj.async_service_failure_hook(
|
1098 |
+
service=ServiceTypes.REDIS,
|
1099 |
+
duration=_duration,
|
1100 |
+
error=e,
|
1101 |
+
call_type="async_rpush",
|
1102 |
+
)
|
1103 |
+
)
|
1104 |
+
verbose_logger.error(
|
1105 |
+
f"LiteLLM Redis Cache RPUSH: - Got exception from REDIS : {str(e)}"
|
1106 |
+
)
|
1107 |
+
raise e
|
1108 |
+
|
1109 |
+
async def async_lpop(
|
1110 |
+
self,
|
1111 |
+
key: str,
|
1112 |
+
count: Optional[int] = None,
|
1113 |
+
parent_otel_span: Optional[Span] = None,
|
1114 |
+
**kwargs,
|
1115 |
+
) -> Union[Any, List[Any]]:
|
1116 |
+
_redis_client: Any = self.init_async_client()
|
1117 |
+
start_time = time.time()
|
1118 |
+
print_verbose(f"LPOP from Redis list: key: {key}, count: {count}")
|
1119 |
+
try:
|
1120 |
+
result = await _redis_client.lpop(key, count)
|
1121 |
+
## LOGGING ##
|
1122 |
+
end_time = time.time()
|
1123 |
+
_duration = end_time - start_time
|
1124 |
+
asyncio.create_task(
|
1125 |
+
self.service_logger_obj.async_service_success_hook(
|
1126 |
+
service=ServiceTypes.REDIS,
|
1127 |
+
duration=_duration,
|
1128 |
+
call_type="async_lpop",
|
1129 |
+
)
|
1130 |
+
)
|
1131 |
+
|
1132 |
+
# Handle result parsing if needed
|
1133 |
+
if isinstance(result, bytes):
|
1134 |
+
try:
|
1135 |
+
return result.decode("utf-8")
|
1136 |
+
except Exception:
|
1137 |
+
return result
|
1138 |
+
elif isinstance(result, list) and all(
|
1139 |
+
isinstance(item, bytes) for item in result
|
1140 |
+
):
|
1141 |
+
try:
|
1142 |
+
return [item.decode("utf-8") for item in result]
|
1143 |
+
except Exception:
|
1144 |
+
return result
|
1145 |
+
return result
|
1146 |
+
except Exception as e:
|
1147 |
+
# NON blocking - notify users Redis is throwing an exception
|
1148 |
+
## LOGGING ##
|
1149 |
+
end_time = time.time()
|
1150 |
+
_duration = end_time - start_time
|
1151 |
+
asyncio.create_task(
|
1152 |
+
self.service_logger_obj.async_service_failure_hook(
|
1153 |
+
service=ServiceTypes.REDIS,
|
1154 |
+
duration=_duration,
|
1155 |
+
error=e,
|
1156 |
+
call_type="async_lpop",
|
1157 |
+
)
|
1158 |
+
)
|
1159 |
+
verbose_logger.error(
|
1160 |
+
f"LiteLLM Redis Cache LPOP: - Got exception from REDIS : {str(e)}"
|
1161 |
+
)
|
1162 |
+
raise e
|
litellm/caching/redis_cluster_cache.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Redis Cluster Cache implementation
|
3 |
+
|
4 |
+
Key differences:
|
5 |
+
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
|
6 |
+
"""
|
7 |
+
|
8 |
+
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
9 |
+
|
10 |
+
from litellm.caching.redis_cache import RedisCache
|
11 |
+
|
12 |
+
if TYPE_CHECKING:
|
13 |
+
from opentelemetry.trace import Span as _Span
|
14 |
+
from redis.asyncio import Redis, RedisCluster
|
15 |
+
from redis.asyncio.client import Pipeline
|
16 |
+
|
17 |
+
pipeline = Pipeline
|
18 |
+
async_redis_client = Redis
|
19 |
+
Span = Union[_Span, Any]
|
20 |
+
else:
|
21 |
+
pipeline = Any
|
22 |
+
async_redis_client = Any
|
23 |
+
Span = Any
|
24 |
+
|
25 |
+
|
26 |
+
class RedisClusterCache(RedisCache):
|
27 |
+
def __init__(self, *args, **kwargs):
|
28 |
+
super().__init__(*args, **kwargs)
|
29 |
+
self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
|
30 |
+
self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None
|
31 |
+
|
32 |
+
def init_async_client(self):
|
33 |
+
from redis.asyncio import RedisCluster
|
34 |
+
|
35 |
+
from .._redis import get_redis_async_client
|
36 |
+
|
37 |
+
if self.redis_async_redis_cluster_client:
|
38 |
+
return self.redis_async_redis_cluster_client
|
39 |
+
|
40 |
+
_redis_client = get_redis_async_client(
|
41 |
+
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
|
42 |
+
)
|
43 |
+
if isinstance(_redis_client, RedisCluster):
|
44 |
+
self.redis_async_redis_cluster_client = _redis_client
|
45 |
+
|
46 |
+
return _redis_client
|
47 |
+
|
48 |
+
def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
|
49 |
+
"""
|
50 |
+
Overrides `_run_redis_mget_operation` in redis_cache.py
|
51 |
+
"""
|
52 |
+
return self.redis_client.mget_nonatomic(keys=keys) # type: ignore
|
53 |
+
|
54 |
+
async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
|
55 |
+
"""
|
56 |
+
Overrides `_async_run_redis_mget_operation` in redis_cache.py
|
57 |
+
"""
|
58 |
+
async_redis_cluster_client = self.init_async_client()
|
59 |
+
return await async_redis_cluster_client.mget_nonatomic(keys=keys) # type: ignore
|
litellm/caching/redis_semantic_cache.py
ADDED
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Redis Semantic Cache implementation for LiteLLM
|
3 |
+
|
4 |
+
The RedisSemanticCache provides semantic caching functionality using Redis as a backend.
|
5 |
+
This cache stores responses based on the semantic similarity of prompts rather than
|
6 |
+
exact matching, allowing for more flexible caching of LLM responses.
|
7 |
+
|
8 |
+
This implementation uses RedisVL's SemanticCache to find semantically similar prompts
|
9 |
+
and their cached responses.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import ast
|
13 |
+
import asyncio
|
14 |
+
import json
|
15 |
+
import os
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, cast
|
17 |
+
|
18 |
+
import litellm
|
19 |
+
from litellm._logging import print_verbose
|
20 |
+
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
21 |
+
get_str_from_messages,
|
22 |
+
)
|
23 |
+
from litellm.types.utils import EmbeddingResponse
|
24 |
+
|
25 |
+
from .base_cache import BaseCache
|
26 |
+
|
27 |
+
|
28 |
+
class RedisSemanticCache(BaseCache):
|
29 |
+
"""
|
30 |
+
Redis-backed semantic cache for LLM responses.
|
31 |
+
|
32 |
+
This cache uses vector similarity to find semantically similar prompts that have been
|
33 |
+
previously sent to the LLM, allowing for cache hits even when prompts are not identical
|
34 |
+
but carry similar meaning.
|
35 |
+
"""
|
36 |
+
|
37 |
+
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
host: Optional[str] = None,
|
42 |
+
port: Optional[str] = None,
|
43 |
+
password: Optional[str] = None,
|
44 |
+
redis_url: Optional[str] = None,
|
45 |
+
similarity_threshold: Optional[float] = None,
|
46 |
+
embedding_model: str = "text-embedding-ada-002",
|
47 |
+
index_name: Optional[str] = None,
|
48 |
+
**kwargs,
|
49 |
+
):
|
50 |
+
"""
|
51 |
+
Initialize the Redis Semantic Cache.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
host: Redis host address
|
55 |
+
port: Redis port
|
56 |
+
password: Redis password
|
57 |
+
redis_url: Full Redis URL (alternative to separate host/port/password)
|
58 |
+
similarity_threshold: Threshold for semantic similarity (0.0 to 1.0)
|
59 |
+
where 1.0 requires exact matches and 0.0 accepts any match
|
60 |
+
embedding_model: Model to use for generating embeddings
|
61 |
+
index_name: Name for the Redis index
|
62 |
+
ttl: Default time-to-live for cache entries in seconds
|
63 |
+
**kwargs: Additional arguments passed to the Redis client
|
64 |
+
|
65 |
+
Raises:
|
66 |
+
Exception: If similarity_threshold is not provided or required Redis
|
67 |
+
connection information is missing
|
68 |
+
"""
|
69 |
+
from redisvl.extensions.llmcache import SemanticCache
|
70 |
+
from redisvl.utils.vectorize import CustomTextVectorizer
|
71 |
+
|
72 |
+
if index_name is None:
|
73 |
+
index_name = self.DEFAULT_REDIS_INDEX_NAME
|
74 |
+
|
75 |
+
print_verbose(f"Redis semantic-cache initializing index - {index_name}")
|
76 |
+
|
77 |
+
# Validate similarity threshold
|
78 |
+
if similarity_threshold is None:
|
79 |
+
raise ValueError("similarity_threshold must be provided, passed None")
|
80 |
+
|
81 |
+
# Store configuration
|
82 |
+
self.similarity_threshold = similarity_threshold
|
83 |
+
|
84 |
+
# Convert similarity threshold [0,1] to distance threshold [0,2]
|
85 |
+
# For cosine distance: 0 = most similar, 2 = least similar
|
86 |
+
# While similarity: 1 = most similar, 0 = least similar
|
87 |
+
self.distance_threshold = 1 - similarity_threshold
|
88 |
+
self.embedding_model = embedding_model
|
89 |
+
|
90 |
+
# Set up Redis connection
|
91 |
+
if redis_url is None:
|
92 |
+
try:
|
93 |
+
# Attempt to use provided parameters or fallback to environment variables
|
94 |
+
host = host or os.environ["REDIS_HOST"]
|
95 |
+
port = port or os.environ["REDIS_PORT"]
|
96 |
+
password = password or os.environ["REDIS_PASSWORD"]
|
97 |
+
except KeyError as e:
|
98 |
+
# Raise a more informative exception if any of the required keys are missing
|
99 |
+
missing_var = e.args[0]
|
100 |
+
raise ValueError(
|
101 |
+
f"Missing required Redis configuration: {missing_var}. "
|
102 |
+
f"Provide {missing_var} or redis_url."
|
103 |
+
) from e
|
104 |
+
|
105 |
+
redis_url = f"redis://:{password}@{host}:{port}"
|
106 |
+
|
107 |
+
print_verbose(f"Redis semantic-cache redis_url: {redis_url}")
|
108 |
+
|
109 |
+
# Initialize the Redis vectorizer and cache
|
110 |
+
cache_vectorizer = CustomTextVectorizer(self._get_embedding)
|
111 |
+
|
112 |
+
self.llmcache = SemanticCache(
|
113 |
+
name=index_name,
|
114 |
+
redis_url=redis_url,
|
115 |
+
vectorizer=cache_vectorizer,
|
116 |
+
distance_threshold=self.distance_threshold,
|
117 |
+
overwrite=False,
|
118 |
+
)
|
119 |
+
|
120 |
+
def _get_ttl(self, **kwargs) -> Optional[int]:
|
121 |
+
"""
|
122 |
+
Get the TTL (time-to-live) value for cache entries.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
**kwargs: Keyword arguments that may contain a custom TTL
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
Optional[int]: The TTL value in seconds, or None if no TTL should be applied
|
129 |
+
"""
|
130 |
+
ttl = kwargs.get("ttl")
|
131 |
+
if ttl is not None:
|
132 |
+
ttl = int(ttl)
|
133 |
+
return ttl
|
134 |
+
|
135 |
+
def _get_embedding(self, prompt: str) -> List[float]:
|
136 |
+
"""
|
137 |
+
Generate an embedding vector for the given prompt using the configured embedding model.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
prompt: The text to generate an embedding for
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
List[float]: The embedding vector
|
144 |
+
"""
|
145 |
+
# Create an embedding from prompt
|
146 |
+
embedding_response = cast(
|
147 |
+
EmbeddingResponse,
|
148 |
+
litellm.embedding(
|
149 |
+
model=self.embedding_model,
|
150 |
+
input=prompt,
|
151 |
+
cache={"no-store": True, "no-cache": True},
|
152 |
+
),
|
153 |
+
)
|
154 |
+
embedding = embedding_response["data"][0]["embedding"]
|
155 |
+
return embedding
|
156 |
+
|
157 |
+
def _get_cache_logic(self, cached_response: Any) -> Any:
|
158 |
+
"""
|
159 |
+
Process the cached response to prepare it for use.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
cached_response: The raw cached response
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
The processed cache response, or None if input was None
|
166 |
+
"""
|
167 |
+
if cached_response is None:
|
168 |
+
return cached_response
|
169 |
+
|
170 |
+
# Convert bytes to string if needed
|
171 |
+
if isinstance(cached_response, bytes):
|
172 |
+
cached_response = cached_response.decode("utf-8")
|
173 |
+
|
174 |
+
# Convert string representation to Python object
|
175 |
+
try:
|
176 |
+
cached_response = json.loads(cached_response)
|
177 |
+
except json.JSONDecodeError:
|
178 |
+
try:
|
179 |
+
cached_response = ast.literal_eval(cached_response)
|
180 |
+
except (ValueError, SyntaxError) as e:
|
181 |
+
print_verbose(f"Error parsing cached response: {str(e)}")
|
182 |
+
return None
|
183 |
+
|
184 |
+
return cached_response
|
185 |
+
|
186 |
+
def set_cache(self, key: str, value: Any, **kwargs) -> None:
|
187 |
+
"""
|
188 |
+
Store a value in the semantic cache.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
key: The cache key (not directly used in semantic caching)
|
192 |
+
value: The response value to cache
|
193 |
+
**kwargs: Additional arguments including 'messages' for the prompt
|
194 |
+
and optional 'ttl' for time-to-live
|
195 |
+
"""
|
196 |
+
print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
|
197 |
+
|
198 |
+
value_str: Optional[str] = None
|
199 |
+
try:
|
200 |
+
# Extract the prompt from messages
|
201 |
+
messages = kwargs.get("messages", [])
|
202 |
+
if not messages:
|
203 |
+
print_verbose("No messages provided for semantic caching")
|
204 |
+
return
|
205 |
+
|
206 |
+
prompt = get_str_from_messages(messages)
|
207 |
+
value_str = str(value)
|
208 |
+
|
209 |
+
# Get TTL and store in Redis semantic cache
|
210 |
+
ttl = self._get_ttl(**kwargs)
|
211 |
+
if ttl is not None:
|
212 |
+
self.llmcache.store(prompt, value_str, ttl=int(ttl))
|
213 |
+
else:
|
214 |
+
self.llmcache.store(prompt, value_str)
|
215 |
+
except Exception as e:
|
216 |
+
print_verbose(
|
217 |
+
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
|
218 |
+
)
|
219 |
+
|
220 |
+
def get_cache(self, key: str, **kwargs) -> Any:
|
221 |
+
"""
|
222 |
+
Retrieve a semantically similar cached response.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
key: The cache key (not directly used in semantic caching)
|
226 |
+
**kwargs: Additional arguments including 'messages' for the prompt
|
227 |
+
|
228 |
+
Returns:
|
229 |
+
The cached response if a semantically similar prompt is found, else None
|
230 |
+
"""
|
231 |
+
print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}")
|
232 |
+
|
233 |
+
try:
|
234 |
+
# Extract the prompt from messages
|
235 |
+
messages = kwargs.get("messages", [])
|
236 |
+
if not messages:
|
237 |
+
print_verbose("No messages provided for semantic cache lookup")
|
238 |
+
return None
|
239 |
+
|
240 |
+
prompt = get_str_from_messages(messages)
|
241 |
+
# Check the cache for semantically similar prompts
|
242 |
+
results = self.llmcache.check(prompt=prompt)
|
243 |
+
|
244 |
+
# Return None if no similar prompts found
|
245 |
+
if not results:
|
246 |
+
return None
|
247 |
+
|
248 |
+
# Process the best matching result
|
249 |
+
cache_hit = results[0]
|
250 |
+
vector_distance = float(cache_hit["vector_distance"])
|
251 |
+
|
252 |
+
# Convert vector distance back to similarity score
|
253 |
+
# For cosine distance: 0 = most similar, 2 = least similar
|
254 |
+
# While similarity: 1 = most similar, 0 = least similar
|
255 |
+
similarity = 1 - vector_distance
|
256 |
+
|
257 |
+
cached_prompt = cache_hit["prompt"]
|
258 |
+
cached_response = cache_hit["response"]
|
259 |
+
|
260 |
+
print_verbose(
|
261 |
+
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
|
262 |
+
f"actual similarity: {similarity}, "
|
263 |
+
f"current prompt: {prompt}, "
|
264 |
+
f"cached prompt: {cached_prompt}"
|
265 |
+
)
|
266 |
+
|
267 |
+
return self._get_cache_logic(cached_response=cached_response)
|
268 |
+
except Exception as e:
|
269 |
+
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
|
270 |
+
|
271 |
+
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
|
272 |
+
"""
|
273 |
+
Asynchronously generate an embedding for the given prompt.
|
274 |
+
|
275 |
+
Args:
|
276 |
+
prompt: The text to generate an embedding for
|
277 |
+
**kwargs: Additional arguments that may contain metadata
|
278 |
+
|
279 |
+
Returns:
|
280 |
+
List[float]: The embedding vector
|
281 |
+
"""
|
282 |
+
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
283 |
+
|
284 |
+
# Route the embedding request through the proxy if appropriate
|
285 |
+
router_model_names = (
|
286 |
+
[m["model_name"] for m in llm_model_list]
|
287 |
+
if llm_model_list is not None
|
288 |
+
else []
|
289 |
+
)
|
290 |
+
|
291 |
+
try:
|
292 |
+
if llm_router is not None and self.embedding_model in router_model_names:
|
293 |
+
# Use the router for embedding generation
|
294 |
+
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
295 |
+
embedding_response = await llm_router.aembedding(
|
296 |
+
model=self.embedding_model,
|
297 |
+
input=prompt,
|
298 |
+
cache={"no-store": True, "no-cache": True},
|
299 |
+
metadata={
|
300 |
+
"user_api_key": user_api_key,
|
301 |
+
"semantic-cache-embedding": True,
|
302 |
+
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
303 |
+
},
|
304 |
+
)
|
305 |
+
else:
|
306 |
+
# Generate embedding directly
|
307 |
+
embedding_response = await litellm.aembedding(
|
308 |
+
model=self.embedding_model,
|
309 |
+
input=prompt,
|
310 |
+
cache={"no-store": True, "no-cache": True},
|
311 |
+
)
|
312 |
+
|
313 |
+
# Extract and return the embedding vector
|
314 |
+
return embedding_response["data"][0]["embedding"]
|
315 |
+
except Exception as e:
|
316 |
+
print_verbose(f"Error generating async embedding: {str(e)}")
|
317 |
+
raise ValueError(f"Failed to generate embedding: {str(e)}") from e
|
318 |
+
|
319 |
+
async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
|
320 |
+
"""
|
321 |
+
Asynchronously store a value in the semantic cache.
|
322 |
+
|
323 |
+
Args:
|
324 |
+
key: The cache key (not directly used in semantic caching)
|
325 |
+
value: The response value to cache
|
326 |
+
**kwargs: Additional arguments including 'messages' for the prompt
|
327 |
+
and optional 'ttl' for time-to-live
|
328 |
+
"""
|
329 |
+
print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}")
|
330 |
+
|
331 |
+
try:
|
332 |
+
# Extract the prompt from messages
|
333 |
+
messages = kwargs.get("messages", [])
|
334 |
+
if not messages:
|
335 |
+
print_verbose("No messages provided for semantic caching")
|
336 |
+
return
|
337 |
+
|
338 |
+
prompt = get_str_from_messages(messages)
|
339 |
+
value_str = str(value)
|
340 |
+
|
341 |
+
# Generate embedding for the value (response) to cache
|
342 |
+
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
|
343 |
+
|
344 |
+
# Get TTL and store in Redis semantic cache
|
345 |
+
ttl = self._get_ttl(**kwargs)
|
346 |
+
if ttl is not None:
|
347 |
+
await self.llmcache.astore(
|
348 |
+
prompt,
|
349 |
+
value_str,
|
350 |
+
vector=prompt_embedding, # Pass through custom embedding
|
351 |
+
ttl=ttl,
|
352 |
+
)
|
353 |
+
else:
|
354 |
+
await self.llmcache.astore(
|
355 |
+
prompt,
|
356 |
+
value_str,
|
357 |
+
vector=prompt_embedding, # Pass through custom embedding
|
358 |
+
)
|
359 |
+
except Exception as e:
|
360 |
+
print_verbose(f"Error in async_set_cache: {str(e)}")
|
361 |
+
|
362 |
+
async def async_get_cache(self, key: str, **kwargs) -> Any:
|
363 |
+
"""
|
364 |
+
Asynchronously retrieve a semantically similar cached response.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
key: The cache key (not directly used in semantic caching)
|
368 |
+
**kwargs: Additional arguments including 'messages' for the prompt
|
369 |
+
|
370 |
+
Returns:
|
371 |
+
The cached response if a semantically similar prompt is found, else None
|
372 |
+
"""
|
373 |
+
print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}")
|
374 |
+
|
375 |
+
try:
|
376 |
+
# Extract the prompt from messages
|
377 |
+
messages = kwargs.get("messages", [])
|
378 |
+
if not messages:
|
379 |
+
print_verbose("No messages provided for semantic cache lookup")
|
380 |
+
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
381 |
+
return None
|
382 |
+
|
383 |
+
prompt = get_str_from_messages(messages)
|
384 |
+
|
385 |
+
# Generate embedding for the prompt
|
386 |
+
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
|
387 |
+
|
388 |
+
# Check the cache for semantically similar prompts
|
389 |
+
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
|
390 |
+
|
391 |
+
# handle results / cache hit
|
392 |
+
if not results:
|
393 |
+
kwargs.setdefault("metadata", {})[
|
394 |
+
"semantic-similarity"
|
395 |
+
] = 0.0 # TODO why here but not above??
|
396 |
+
return None
|
397 |
+
|
398 |
+
cache_hit = results[0]
|
399 |
+
vector_distance = float(cache_hit["vector_distance"])
|
400 |
+
|
401 |
+
# Convert vector distance back to similarity
|
402 |
+
# For cosine distance: 0 = most similar, 2 = least similar
|
403 |
+
# While similarity: 1 = most similar, 0 = least similar
|
404 |
+
similarity = 1 - vector_distance
|
405 |
+
|
406 |
+
cached_prompt = cache_hit["prompt"]
|
407 |
+
cached_response = cache_hit["response"]
|
408 |
+
|
409 |
+
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
|
410 |
+
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
|
411 |
+
|
412 |
+
print_verbose(
|
413 |
+
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
|
414 |
+
f"actual similarity: {similarity}, "
|
415 |
+
f"current prompt: {prompt}, "
|
416 |
+
f"cached prompt: {cached_prompt}"
|
417 |
+
)
|
418 |
+
|
419 |
+
return self._get_cache_logic(cached_response=cached_response)
|
420 |
+
except Exception as e:
|
421 |
+
print_verbose(f"Error in async_get_cache: {str(e)}")
|
422 |
+
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
423 |
+
|
424 |
+
async def _index_info(self) -> Dict[str, Any]:
|
425 |
+
"""
|
426 |
+
Get information about the Redis index.
|
427 |
+
|
428 |
+
Returns:
|
429 |
+
Dict[str, Any]: Information about the Redis index
|
430 |
+
"""
|
431 |
+
aindex = await self.llmcache._get_async_index()
|
432 |
+
return await aindex.info()
|
433 |
+
|
434 |
+
async def async_set_cache_pipeline(
|
435 |
+
self, cache_list: List[Tuple[str, Any]], **kwargs
|
436 |
+
) -> None:
|
437 |
+
"""
|
438 |
+
Asynchronously store multiple values in the semantic cache.
|
439 |
+
|
440 |
+
Args:
|
441 |
+
cache_list: List of (key, value) tuples to cache
|
442 |
+
**kwargs: Additional arguments
|
443 |
+
"""
|
444 |
+
try:
|
445 |
+
tasks = []
|
446 |
+
for val in cache_list:
|
447 |
+
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
448 |
+
await asyncio.gather(*tasks)
|
449 |
+
except Exception as e:
|
450 |
+
print_verbose(f"Error in async_set_cache_pipeline: {str(e)}")
|
litellm/caching/s3_cache.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
S3 Cache implementation
|
3 |
+
WARNING: DO NOT USE THIS IN PRODUCTION - This is not ASYNC
|
4 |
+
|
5 |
+
Has 4 methods:
|
6 |
+
- set_cache
|
7 |
+
- get_cache
|
8 |
+
- async_set_cache
|
9 |
+
- async_get_cache
|
10 |
+
"""
|
11 |
+
|
12 |
+
import ast
|
13 |
+
import asyncio
|
14 |
+
import json
|
15 |
+
from typing import Optional
|
16 |
+
|
17 |
+
from litellm._logging import print_verbose, verbose_logger
|
18 |
+
|
19 |
+
from .base_cache import BaseCache
|
20 |
+
|
21 |
+
|
22 |
+
class S3Cache(BaseCache):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
s3_bucket_name,
|
26 |
+
s3_region_name=None,
|
27 |
+
s3_api_version=None,
|
28 |
+
s3_use_ssl: Optional[bool] = True,
|
29 |
+
s3_verify=None,
|
30 |
+
s3_endpoint_url=None,
|
31 |
+
s3_aws_access_key_id=None,
|
32 |
+
s3_aws_secret_access_key=None,
|
33 |
+
s3_aws_session_token=None,
|
34 |
+
s3_config=None,
|
35 |
+
s3_path=None,
|
36 |
+
**kwargs,
|
37 |
+
):
|
38 |
+
import boto3
|
39 |
+
|
40 |
+
self.bucket_name = s3_bucket_name
|
41 |
+
self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
|
42 |
+
# Create an S3 client with custom endpoint URL
|
43 |
+
|
44 |
+
self.s3_client = boto3.client(
|
45 |
+
"s3",
|
46 |
+
region_name=s3_region_name,
|
47 |
+
endpoint_url=s3_endpoint_url,
|
48 |
+
api_version=s3_api_version,
|
49 |
+
use_ssl=s3_use_ssl,
|
50 |
+
verify=s3_verify,
|
51 |
+
aws_access_key_id=s3_aws_access_key_id,
|
52 |
+
aws_secret_access_key=s3_aws_secret_access_key,
|
53 |
+
aws_session_token=s3_aws_session_token,
|
54 |
+
config=s3_config,
|
55 |
+
**kwargs,
|
56 |
+
)
|
57 |
+
|
58 |
+
def set_cache(self, key, value, **kwargs):
|
59 |
+
try:
|
60 |
+
print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
|
61 |
+
ttl = kwargs.get("ttl", None)
|
62 |
+
# Convert value to JSON before storing in S3
|
63 |
+
serialized_value = json.dumps(value)
|
64 |
+
key = self.key_prefix + key
|
65 |
+
|
66 |
+
if ttl is not None:
|
67 |
+
cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"
|
68 |
+
import datetime
|
69 |
+
|
70 |
+
# Calculate expiration time
|
71 |
+
expiration_time = datetime.datetime.now() + ttl
|
72 |
+
|
73 |
+
# Upload the data to S3 with the calculated expiration time
|
74 |
+
self.s3_client.put_object(
|
75 |
+
Bucket=self.bucket_name,
|
76 |
+
Key=key,
|
77 |
+
Body=serialized_value,
|
78 |
+
Expires=expiration_time,
|
79 |
+
CacheControl=cache_control,
|
80 |
+
ContentType="application/json",
|
81 |
+
ContentLanguage="en",
|
82 |
+
ContentDisposition=f'inline; filename="{key}.json"',
|
83 |
+
)
|
84 |
+
else:
|
85 |
+
cache_control = "immutable, max-age=31536000, s-maxage=31536000"
|
86 |
+
# Upload the data to S3 without specifying Expires
|
87 |
+
self.s3_client.put_object(
|
88 |
+
Bucket=self.bucket_name,
|
89 |
+
Key=key,
|
90 |
+
Body=serialized_value,
|
91 |
+
CacheControl=cache_control,
|
92 |
+
ContentType="application/json",
|
93 |
+
ContentLanguage="en",
|
94 |
+
ContentDisposition=f'inline; filename="{key}.json"',
|
95 |
+
)
|
96 |
+
except Exception as e:
|
97 |
+
# NON blocking - notify users S3 is throwing an exception
|
98 |
+
print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
|
99 |
+
|
100 |
+
async def async_set_cache(self, key, value, **kwargs):
|
101 |
+
self.set_cache(key=key, value=value, **kwargs)
|
102 |
+
|
103 |
+
def get_cache(self, key, **kwargs):
|
104 |
+
import botocore
|
105 |
+
|
106 |
+
try:
|
107 |
+
key = self.key_prefix + key
|
108 |
+
|
109 |
+
print_verbose(f"Get S3 Cache: key: {key}")
|
110 |
+
# Download the data from S3
|
111 |
+
cached_response = self.s3_client.get_object(
|
112 |
+
Bucket=self.bucket_name, Key=key
|
113 |
+
)
|
114 |
+
|
115 |
+
if cached_response is not None:
|
116 |
+
# cached_response is in `b{} convert it to ModelResponse
|
117 |
+
cached_response = (
|
118 |
+
cached_response["Body"].read().decode("utf-8")
|
119 |
+
) # Convert bytes to string
|
120 |
+
try:
|
121 |
+
cached_response = json.loads(
|
122 |
+
cached_response
|
123 |
+
) # Convert string to dictionary
|
124 |
+
except Exception:
|
125 |
+
cached_response = ast.literal_eval(cached_response)
|
126 |
+
if not isinstance(cached_response, dict):
|
127 |
+
cached_response = dict(cached_response)
|
128 |
+
verbose_logger.debug(
|
129 |
+
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
|
130 |
+
)
|
131 |
+
|
132 |
+
return cached_response
|
133 |
+
except botocore.exceptions.ClientError as e: # type: ignore
|
134 |
+
if e.response["Error"]["Code"] == "NoSuchKey":
|
135 |
+
verbose_logger.debug(
|
136 |
+
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
|
137 |
+
)
|
138 |
+
return None
|
139 |
+
|
140 |
+
except Exception as e:
|
141 |
+
# NON blocking - notify users S3 is throwing an exception
|
142 |
+
verbose_logger.error(
|
143 |
+
f"S3 Caching: get_cache() - Got exception from S3: {e}"
|
144 |
+
)
|
145 |
+
|
146 |
+
async def async_get_cache(self, key, **kwargs):
|
147 |
+
return self.get_cache(key=key, **kwargs)
|
148 |
+
|
149 |
+
def flush_cache(self):
|
150 |
+
pass
|
151 |
+
|
152 |
+
async def disconnect(self):
|
153 |
+
pass
|
154 |
+
|
155 |
+
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
156 |
+
tasks = []
|
157 |
+
for val in cache_list:
|
158 |
+
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
159 |
+
await asyncio.gather(*tasks)
|
litellm/constants.py
ADDED
@@ -0,0 +1,543 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Literal
|
2 |
+
|
3 |
+
ROUTER_MAX_FALLBACKS = 5
|
4 |
+
DEFAULT_BATCH_SIZE = 512
|
5 |
+
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
|
6 |
+
DEFAULT_MAX_RETRIES = 2
|
7 |
+
DEFAULT_MAX_RECURSE_DEPTH = 10
|
8 |
+
DEFAULT_FAILURE_THRESHOLD_PERCENT = (
|
9 |
+
0.5 # default cooldown a deployment if 50% of requests fail in a given minute
|
10 |
+
)
|
11 |
+
DEFAULT_MAX_TOKENS = 4096
|
12 |
+
DEFAULT_ALLOWED_FAILS = 3
|
13 |
+
DEFAULT_REDIS_SYNC_INTERVAL = 1
|
14 |
+
DEFAULT_COOLDOWN_TIME_SECONDS = 5
|
15 |
+
DEFAULT_REPLICATE_POLLING_RETRIES = 5
|
16 |
+
DEFAULT_REPLICATE_POLLING_DELAY_SECONDS = 1
|
17 |
+
DEFAULT_IMAGE_TOKEN_COUNT = 250
|
18 |
+
DEFAULT_IMAGE_WIDTH = 300
|
19 |
+
DEFAULT_IMAGE_HEIGHT = 300
|
20 |
+
DEFAULT_MAX_TOKENS = 256 # used when providers need a default
|
21 |
+
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 1024 # 1MB = 1024KB
|
22 |
+
SINGLE_DEPLOYMENT_TRAFFIC_FAILURE_THRESHOLD = 1000 # Minimum number of requests to consider "reasonable traffic". Used for single-deployment cooldown logic.
|
23 |
+
|
24 |
+
DEFAULT_REASONING_EFFORT_LOW_THINKING_BUDGET = 1024
|
25 |
+
DEFAULT_REASONING_EFFORT_MEDIUM_THINKING_BUDGET = 2048
|
26 |
+
DEFAULT_REASONING_EFFORT_HIGH_THINKING_BUDGET = 4096
|
27 |
+
|
28 |
+
########## Networking constants ##############################################################
|
29 |
+
_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
|
30 |
+
|
31 |
+
########### v2 Architecture constants for managing writing updates to the database ###########
|
32 |
+
REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer"
|
33 |
+
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer"
|
34 |
+
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_team_spend_update_buffer"
|
35 |
+
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_tag_spend_update_buffer"
|
36 |
+
MAX_REDIS_BUFFER_DEQUEUE_COUNT = 100
|
37 |
+
MAX_SIZE_IN_MEMORY_QUEUE = 10000
|
38 |
+
MAX_IN_MEMORY_QUEUE_FLUSH_COUNT = 1000
|
39 |
+
###############################################################################################
|
40 |
+
MINIMUM_PROMPT_CACHE_TOKEN_COUNT = (
|
41 |
+
1024 # minimum number of tokens to cache a prompt by Anthropic
|
42 |
+
)
|
43 |
+
DEFAULT_TRIM_RATIO = 0.75 # default ratio of tokens to trim from the end of a prompt
|
44 |
+
HOURS_IN_A_DAY = 24
|
45 |
+
DAYS_IN_A_WEEK = 7
|
46 |
+
DAYS_IN_A_MONTH = 28
|
47 |
+
DAYS_IN_A_YEAR = 365
|
48 |
+
REPLICATE_MODEL_NAME_WITH_ID_LENGTH = 64
|
49 |
+
#### TOKEN COUNTING ####
|
50 |
+
FUNCTION_DEFINITION_TOKEN_COUNT = 9
|
51 |
+
SYSTEM_MESSAGE_TOKEN_COUNT = 4
|
52 |
+
TOOL_CHOICE_OBJECT_TOKEN_COUNT = 4
|
53 |
+
DEFAULT_MOCK_RESPONSE_PROMPT_TOKEN_COUNT = 10
|
54 |
+
DEFAULT_MOCK_RESPONSE_COMPLETION_TOKEN_COUNT = 20
|
55 |
+
MAX_SHORT_SIDE_FOR_IMAGE_HIGH_RES = 768
|
56 |
+
MAX_LONG_SIDE_FOR_IMAGE_HIGH_RES = 2000
|
57 |
+
MAX_TILE_WIDTH = 512
|
58 |
+
MAX_TILE_HEIGHT = 512
|
59 |
+
OPENAI_FILE_SEARCH_COST_PER_1K_CALLS = 2.5 / 1000
|
60 |
+
MIN_NON_ZERO_TEMPERATURE = 0.0001
|
61 |
+
#### RELIABILITY ####
|
62 |
+
REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
|
63 |
+
DEFAULT_MAX_LRU_CACHE_SIZE = 16
|
64 |
+
INITIAL_RETRY_DELAY = 0.5
|
65 |
+
MAX_RETRY_DELAY = 8.0
|
66 |
+
JITTER = 0.75
|
67 |
+
DEFAULT_IN_MEMORY_TTL = 5 # default time to live for the in-memory cache
|
68 |
+
DEFAULT_POLLING_INTERVAL = 0.03 # default polling interval for the scheduler
|
69 |
+
AZURE_OPERATION_POLLING_TIMEOUT = 120
|
70 |
+
REDIS_SOCKET_TIMEOUT = 0.1
|
71 |
+
REDIS_CONNECTION_POOL_TIMEOUT = 5
|
72 |
+
NON_LLM_CONNECTION_TIMEOUT = 15 # timeout for adjacent services (e.g. jwt auth)
|
73 |
+
MAX_EXCEPTION_MESSAGE_LENGTH = 2000
|
74 |
+
BEDROCK_MAX_POLICY_SIZE = 75
|
75 |
+
REPLICATE_POLLING_DELAY_SECONDS = 0.5
|
76 |
+
DEFAULT_ANTHROPIC_CHAT_MAX_TOKENS = 4096
|
77 |
+
TOGETHER_AI_4_B = 4
|
78 |
+
TOGETHER_AI_8_B = 8
|
79 |
+
TOGETHER_AI_21_B = 21
|
80 |
+
TOGETHER_AI_41_B = 41
|
81 |
+
TOGETHER_AI_80_B = 80
|
82 |
+
TOGETHER_AI_110_B = 110
|
83 |
+
TOGETHER_AI_EMBEDDING_150_M = 150
|
84 |
+
TOGETHER_AI_EMBEDDING_350_M = 350
|
85 |
+
QDRANT_SCALAR_QUANTILE = 0.99
|
86 |
+
QDRANT_VECTOR_SIZE = 1536
|
87 |
+
CACHED_STREAMING_CHUNK_DELAY = 0.02
|
88 |
+
MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB = 512
|
89 |
+
DEFAULT_MAX_TOKENS_FOR_TRITON = 2000
|
90 |
+
#### Networking settings ####
|
91 |
+
request_timeout: float = 6000 # time in seconds
|
92 |
+
STREAM_SSE_DONE_STRING: str = "[DONE]"
|
93 |
+
### SPEND TRACKING ###
|
94 |
+
DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND = 0.001400 # price per second for a100 80GB
|
95 |
+
FIREWORKS_AI_56_B_MOE = 56
|
96 |
+
FIREWORKS_AI_176_B_MOE = 176
|
97 |
+
FIREWORKS_AI_4_B = 4
|
98 |
+
FIREWORKS_AI_16_B = 16
|
99 |
+
FIREWORKS_AI_80_B = 80
|
100 |
+
|
101 |
+
LITELLM_CHAT_PROVIDERS = [
|
102 |
+
"openai",
|
103 |
+
"openai_like",
|
104 |
+
"xai",
|
105 |
+
"custom_openai",
|
106 |
+
"text-completion-openai",
|
107 |
+
"cohere",
|
108 |
+
"cohere_chat",
|
109 |
+
"clarifai",
|
110 |
+
"anthropic",
|
111 |
+
"anthropic_text",
|
112 |
+
"replicate",
|
113 |
+
"huggingface",
|
114 |
+
"together_ai",
|
115 |
+
"openrouter",
|
116 |
+
"vertex_ai",
|
117 |
+
"vertex_ai_beta",
|
118 |
+
"gemini",
|
119 |
+
"ai21",
|
120 |
+
"baseten",
|
121 |
+
"azure",
|
122 |
+
"azure_text",
|
123 |
+
"azure_ai",
|
124 |
+
"sagemaker",
|
125 |
+
"sagemaker_chat",
|
126 |
+
"bedrock",
|
127 |
+
"vllm",
|
128 |
+
"nlp_cloud",
|
129 |
+
"petals",
|
130 |
+
"oobabooga",
|
131 |
+
"ollama",
|
132 |
+
"ollama_chat",
|
133 |
+
"deepinfra",
|
134 |
+
"perplexity",
|
135 |
+
"mistral",
|
136 |
+
"groq",
|
137 |
+
"nvidia_nim",
|
138 |
+
"cerebras",
|
139 |
+
"ai21_chat",
|
140 |
+
"volcengine",
|
141 |
+
"codestral",
|
142 |
+
"text-completion-codestral",
|
143 |
+
"deepseek",
|
144 |
+
"sambanova",
|
145 |
+
"maritalk",
|
146 |
+
"cloudflare",
|
147 |
+
"fireworks_ai",
|
148 |
+
"friendliai",
|
149 |
+
"watsonx",
|
150 |
+
"watsonx_text",
|
151 |
+
"triton",
|
152 |
+
"predibase",
|
153 |
+
"databricks",
|
154 |
+
"empower",
|
155 |
+
"github",
|
156 |
+
"custom",
|
157 |
+
"litellm_proxy",
|
158 |
+
"hosted_vllm",
|
159 |
+
"llamafile",
|
160 |
+
"lm_studio",
|
161 |
+
"galadriel",
|
162 |
+
]
|
163 |
+
|
164 |
+
|
165 |
+
OPENAI_CHAT_COMPLETION_PARAMS = [
|
166 |
+
"functions",
|
167 |
+
"function_call",
|
168 |
+
"temperature",
|
169 |
+
"temperature",
|
170 |
+
"top_p",
|
171 |
+
"n",
|
172 |
+
"stream",
|
173 |
+
"stream_options",
|
174 |
+
"stop",
|
175 |
+
"max_completion_tokens",
|
176 |
+
"modalities",
|
177 |
+
"prediction",
|
178 |
+
"audio",
|
179 |
+
"max_tokens",
|
180 |
+
"presence_penalty",
|
181 |
+
"frequency_penalty",
|
182 |
+
"logit_bias",
|
183 |
+
"user",
|
184 |
+
"request_timeout",
|
185 |
+
"api_base",
|
186 |
+
"api_version",
|
187 |
+
"api_key",
|
188 |
+
"deployment_id",
|
189 |
+
"organization",
|
190 |
+
"base_url",
|
191 |
+
"default_headers",
|
192 |
+
"timeout",
|
193 |
+
"response_format",
|
194 |
+
"seed",
|
195 |
+
"tools",
|
196 |
+
"tool_choice",
|
197 |
+
"max_retries",
|
198 |
+
"parallel_tool_calls",
|
199 |
+
"logprobs",
|
200 |
+
"top_logprobs",
|
201 |
+
"reasoning_effort",
|
202 |
+
"extra_headers",
|
203 |
+
"thinking",
|
204 |
+
]
|
205 |
+
|
206 |
+
openai_compatible_endpoints: List = [
|
207 |
+
"api.perplexity.ai",
|
208 |
+
"api.endpoints.anyscale.com/v1",
|
209 |
+
"api.deepinfra.com/v1/openai",
|
210 |
+
"api.mistral.ai/v1",
|
211 |
+
"codestral.mistral.ai/v1/chat/completions",
|
212 |
+
"codestral.mistral.ai/v1/fim/completions",
|
213 |
+
"api.groq.com/openai/v1",
|
214 |
+
"https://integrate.api.nvidia.com/v1",
|
215 |
+
"api.deepseek.com/v1",
|
216 |
+
"api.together.xyz/v1",
|
217 |
+
"app.empower.dev/api/v1",
|
218 |
+
"https://api.friendli.ai/serverless/v1",
|
219 |
+
"api.sambanova.ai/v1",
|
220 |
+
"api.x.ai/v1",
|
221 |
+
"api.galadriel.ai/v1",
|
222 |
+
]
|
223 |
+
|
224 |
+
|
225 |
+
openai_compatible_providers: List = [
|
226 |
+
"anyscale",
|
227 |
+
"mistral",
|
228 |
+
"groq",
|
229 |
+
"nvidia_nim",
|
230 |
+
"cerebras",
|
231 |
+
"sambanova",
|
232 |
+
"ai21_chat",
|
233 |
+
"ai21",
|
234 |
+
"volcengine",
|
235 |
+
"codestral",
|
236 |
+
"deepseek",
|
237 |
+
"deepinfra",
|
238 |
+
"perplexity",
|
239 |
+
"xinference",
|
240 |
+
"xai",
|
241 |
+
"together_ai",
|
242 |
+
"fireworks_ai",
|
243 |
+
"empower",
|
244 |
+
"friendliai",
|
245 |
+
"azure_ai",
|
246 |
+
"github",
|
247 |
+
"litellm_proxy",
|
248 |
+
"hosted_vllm",
|
249 |
+
"llamafile",
|
250 |
+
"lm_studio",
|
251 |
+
"galadriel",
|
252 |
+
]
|
253 |
+
openai_text_completion_compatible_providers: List = (
|
254 |
+
[ # providers that support `/v1/completions`
|
255 |
+
"together_ai",
|
256 |
+
"fireworks_ai",
|
257 |
+
"hosted_vllm",
|
258 |
+
"llamafile",
|
259 |
+
]
|
260 |
+
)
|
261 |
+
_openai_like_providers: List = [
|
262 |
+
"predibase",
|
263 |
+
"databricks",
|
264 |
+
"watsonx",
|
265 |
+
] # private helper. similar to openai but require some custom auth / endpoint handling, so can't use the openai sdk
|
266 |
+
# well supported replicate llms
|
267 |
+
replicate_models: List = [
|
268 |
+
# llama replicate supported LLMs
|
269 |
+
"replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf",
|
270 |
+
"a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52",
|
271 |
+
"meta/codellama-13b:1c914d844307b0588599b8393480a3ba917b660c7e9dfae681542b5325f228db",
|
272 |
+
# Vicuna
|
273 |
+
"replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b",
|
274 |
+
"joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe",
|
275 |
+
# Flan T-5
|
276 |
+
"daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f",
|
277 |
+
# Others
|
278 |
+
"replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5",
|
279 |
+
"replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad",
|
280 |
+
]
|
281 |
+
|
282 |
+
clarifai_models: List = [
|
283 |
+
"clarifai/meta.Llama-3.Llama-3-8B-Instruct",
|
284 |
+
"clarifai/gcp.generate.gemma-1_1-7b-it",
|
285 |
+
"clarifai/mistralai.completion.mixtral-8x22B",
|
286 |
+
"clarifai/cohere.generate.command-r-plus",
|
287 |
+
"clarifai/databricks.drbx.dbrx-instruct",
|
288 |
+
"clarifai/mistralai.completion.mistral-large",
|
289 |
+
"clarifai/mistralai.completion.mistral-medium",
|
290 |
+
"clarifai/mistralai.completion.mistral-small",
|
291 |
+
"clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1",
|
292 |
+
"clarifai/gcp.generate.gemma-2b-it",
|
293 |
+
"clarifai/gcp.generate.gemma-7b-it",
|
294 |
+
"clarifai/deci.decilm.deciLM-7B-instruct",
|
295 |
+
"clarifai/mistralai.completion.mistral-7B-Instruct",
|
296 |
+
"clarifai/gcp.generate.gemini-pro",
|
297 |
+
"clarifai/anthropic.completion.claude-v1",
|
298 |
+
"clarifai/anthropic.completion.claude-instant-1_2",
|
299 |
+
"clarifai/anthropic.completion.claude-instant",
|
300 |
+
"clarifai/anthropic.completion.claude-v2",
|
301 |
+
"clarifai/anthropic.completion.claude-2_1",
|
302 |
+
"clarifai/meta.Llama-2.codeLlama-70b-Python",
|
303 |
+
"clarifai/meta.Llama-2.codeLlama-70b-Instruct",
|
304 |
+
"clarifai/openai.completion.gpt-3_5-turbo-instruct",
|
305 |
+
"clarifai/meta.Llama-2.llama2-7b-chat",
|
306 |
+
"clarifai/meta.Llama-2.llama2-13b-chat",
|
307 |
+
"clarifai/meta.Llama-2.llama2-70b-chat",
|
308 |
+
"clarifai/openai.chat-completion.gpt-4-turbo",
|
309 |
+
"clarifai/microsoft.text-generation.phi-2",
|
310 |
+
"clarifai/meta.Llama-2.llama2-7b-chat-vllm",
|
311 |
+
"clarifai/upstage.solar.solar-10_7b-instruct",
|
312 |
+
"clarifai/openchat.openchat.openchat-3_5-1210",
|
313 |
+
"clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B",
|
314 |
+
"clarifai/gcp.generate.text-bison",
|
315 |
+
"clarifai/meta.Llama-2.llamaGuard-7b",
|
316 |
+
"clarifai/fblgit.una-cybertron.una-cybertron-7b-v2",
|
317 |
+
"clarifai/openai.chat-completion.GPT-4",
|
318 |
+
"clarifai/openai.chat-completion.GPT-3_5-turbo",
|
319 |
+
"clarifai/ai21.complete.Jurassic2-Grande",
|
320 |
+
"clarifai/ai21.complete.Jurassic2-Grande-Instruct",
|
321 |
+
"clarifai/ai21.complete.Jurassic2-Jumbo-Instruct",
|
322 |
+
"clarifai/ai21.complete.Jurassic2-Jumbo",
|
323 |
+
"clarifai/ai21.complete.Jurassic2-Large",
|
324 |
+
"clarifai/cohere.generate.cohere-generate-command",
|
325 |
+
"clarifai/wizardlm.generate.wizardCoder-Python-34B",
|
326 |
+
"clarifai/wizardlm.generate.wizardLM-70B",
|
327 |
+
"clarifai/tiiuae.falcon.falcon-40b-instruct",
|
328 |
+
"clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat",
|
329 |
+
"clarifai/gcp.generate.code-gecko",
|
330 |
+
"clarifai/gcp.generate.code-bison",
|
331 |
+
"clarifai/mistralai.completion.mistral-7B-OpenOrca",
|
332 |
+
"clarifai/mistralai.completion.openHermes-2-mistral-7B",
|
333 |
+
"clarifai/wizardlm.generate.wizardLM-13B",
|
334 |
+
"clarifai/huggingface-research.zephyr.zephyr-7B-alpha",
|
335 |
+
"clarifai/wizardlm.generate.wizardCoder-15B",
|
336 |
+
"clarifai/microsoft.text-generation.phi-1_5",
|
337 |
+
"clarifai/databricks.Dolly-v2.dolly-v2-12b",
|
338 |
+
"clarifai/bigcode.code.StarCoder",
|
339 |
+
"clarifai/salesforce.xgen.xgen-7b-8k-instruct",
|
340 |
+
"clarifai/mosaicml.mpt.mpt-7b-instruct",
|
341 |
+
"clarifai/anthropic.completion.claude-3-opus",
|
342 |
+
"clarifai/anthropic.completion.claude-3-sonnet",
|
343 |
+
"clarifai/gcp.generate.gemini-1_5-pro",
|
344 |
+
"clarifai/gcp.generate.imagen-2",
|
345 |
+
"clarifai/salesforce.blip.general-english-image-caption-blip-2",
|
346 |
+
]
|
347 |
+
|
348 |
+
|
349 |
+
huggingface_models: List = [
|
350 |
+
"meta-llama/Llama-2-7b-hf",
|
351 |
+
"meta-llama/Llama-2-7b-chat-hf",
|
352 |
+
"meta-llama/Llama-2-13b-hf",
|
353 |
+
"meta-llama/Llama-2-13b-chat-hf",
|
354 |
+
"meta-llama/Llama-2-70b-hf",
|
355 |
+
"meta-llama/Llama-2-70b-chat-hf",
|
356 |
+
"meta-llama/Llama-2-7b",
|
357 |
+
"meta-llama/Llama-2-7b-chat",
|
358 |
+
"meta-llama/Llama-2-13b",
|
359 |
+
"meta-llama/Llama-2-13b-chat",
|
360 |
+
"meta-llama/Llama-2-70b",
|
361 |
+
"meta-llama/Llama-2-70b-chat",
|
362 |
+
] # these have been tested on extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/providers
|
363 |
+
empower_models = [
|
364 |
+
"empower/empower-functions",
|
365 |
+
"empower/empower-functions-small",
|
366 |
+
]
|
367 |
+
|
368 |
+
together_ai_models: List = [
|
369 |
+
# llama llms - chat
|
370 |
+
"togethercomputer/llama-2-70b-chat",
|
371 |
+
# llama llms - language / instruct
|
372 |
+
"togethercomputer/llama-2-70b",
|
373 |
+
"togethercomputer/LLaMA-2-7B-32K",
|
374 |
+
"togethercomputer/Llama-2-7B-32K-Instruct",
|
375 |
+
"togethercomputer/llama-2-7b",
|
376 |
+
# falcon llms
|
377 |
+
"togethercomputer/falcon-40b-instruct",
|
378 |
+
"togethercomputer/falcon-7b-instruct",
|
379 |
+
# alpaca
|
380 |
+
"togethercomputer/alpaca-7b",
|
381 |
+
# chat llms
|
382 |
+
"HuggingFaceH4/starchat-alpha",
|
383 |
+
# code llms
|
384 |
+
"togethercomputer/CodeLlama-34b",
|
385 |
+
"togethercomputer/CodeLlama-34b-Instruct",
|
386 |
+
"togethercomputer/CodeLlama-34b-Python",
|
387 |
+
"defog/sqlcoder",
|
388 |
+
"NumbersStation/nsql-llama-2-7B",
|
389 |
+
"WizardLM/WizardCoder-15B-V1.0",
|
390 |
+
"WizardLM/WizardCoder-Python-34B-V1.0",
|
391 |
+
# language llms
|
392 |
+
"NousResearch/Nous-Hermes-Llama2-13b",
|
393 |
+
"Austism/chronos-hermes-13b",
|
394 |
+
"upstage/SOLAR-0-70b-16bit",
|
395 |
+
"WizardLM/WizardLM-70B-V1.0",
|
396 |
+
] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...)
|
397 |
+
|
398 |
+
|
399 |
+
baseten_models: List = [
|
400 |
+
"qvv0xeq",
|
401 |
+
"q841o8w",
|
402 |
+
"31dxrj3",
|
403 |
+
] # FALCON 7B # WizardLM # Mosaic ML
|
404 |
+
|
405 |
+
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
|
406 |
+
"cohere",
|
407 |
+
"anthropic",
|
408 |
+
"mistral",
|
409 |
+
"amazon",
|
410 |
+
"meta",
|
411 |
+
"llama",
|
412 |
+
"ai21",
|
413 |
+
"nova",
|
414 |
+
"deepseek_r1",
|
415 |
+
]
|
416 |
+
|
417 |
+
open_ai_embedding_models: List = ["text-embedding-ada-002"]
|
418 |
+
cohere_embedding_models: List = [
|
419 |
+
"embed-english-v3.0",
|
420 |
+
"embed-english-light-v3.0",
|
421 |
+
"embed-multilingual-v3.0",
|
422 |
+
"embed-english-v2.0",
|
423 |
+
"embed-english-light-v2.0",
|
424 |
+
"embed-multilingual-v2.0",
|
425 |
+
]
|
426 |
+
bedrock_embedding_models: List = [
|
427 |
+
"amazon.titan-embed-text-v1",
|
428 |
+
"cohere.embed-english-v3",
|
429 |
+
"cohere.embed-multilingual-v3",
|
430 |
+
]
|
431 |
+
|
432 |
+
known_tokenizer_config = {
|
433 |
+
"mistralai/Mistral-7B-Instruct-v0.1": {
|
434 |
+
"tokenizer": {
|
435 |
+
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
|
436 |
+
"bos_token": "<s>",
|
437 |
+
"eos_token": "</s>",
|
438 |
+
},
|
439 |
+
"status": "success",
|
440 |
+
},
|
441 |
+
"meta-llama/Meta-Llama-3-8B-Instruct": {
|
442 |
+
"tokenizer": {
|
443 |
+
"chat_template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
|
444 |
+
"bos_token": "<|begin_of_text|>",
|
445 |
+
"eos_token": "",
|
446 |
+
},
|
447 |
+
"status": "success",
|
448 |
+
},
|
449 |
+
"deepseek-r1/deepseek-r1-7b-instruct": {
|
450 |
+
"tokenizer": {
|
451 |
+
"add_bos_token": True,
|
452 |
+
"add_eos_token": False,
|
453 |
+
"bos_token": {
|
454 |
+
"__type": "AddedToken",
|
455 |
+
"content": "<|begin▁of▁sentence|>",
|
456 |
+
"lstrip": False,
|
457 |
+
"normalized": True,
|
458 |
+
"rstrip": False,
|
459 |
+
"single_word": False,
|
460 |
+
},
|
461 |
+
"clean_up_tokenization_spaces": False,
|
462 |
+
"eos_token": {
|
463 |
+
"__type": "AddedToken",
|
464 |
+
"content": "<|end▁of▁sentence|>",
|
465 |
+
"lstrip": False,
|
466 |
+
"normalized": True,
|
467 |
+
"rstrip": False,
|
468 |
+
"single_word": False,
|
469 |
+
},
|
470 |
+
"legacy": True,
|
471 |
+
"model_max_length": 16384,
|
472 |
+
"pad_token": {
|
473 |
+
"__type": "AddedToken",
|
474 |
+
"content": "<|end▁of▁sentence|>",
|
475 |
+
"lstrip": False,
|
476 |
+
"normalized": True,
|
477 |
+
"rstrip": False,
|
478 |
+
"single_word": False,
|
479 |
+
},
|
480 |
+
"sp_model_kwargs": {},
|
481 |
+
"unk_token": None,
|
482 |
+
"tokenizer_class": "LlamaTokenizerFast",
|
483 |
+
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|><think>\\n'}}{% endif %}",
|
484 |
+
},
|
485 |
+
"status": "success",
|
486 |
+
},
|
487 |
+
}
|
488 |
+
|
489 |
+
|
490 |
+
OPENAI_FINISH_REASONS = ["stop", "length", "function_call", "content_filter", "null"]
|
491 |
+
HUMANLOOP_PROMPT_CACHE_TTL_SECONDS = 60 # 1 minute
|
492 |
+
RESPONSE_FORMAT_TOOL_NAME = "json_tool_call" # default tool name used when converting response format to tool call
|
493 |
+
|
494 |
+
########################### Logging Callback Constants ###########################
|
495 |
+
AZURE_STORAGE_MSFT_VERSION = "2019-07-07"
|
496 |
+
PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES = 5
|
497 |
+
MCP_TOOL_NAME_PREFIX = "mcp_tool"
|
498 |
+
|
499 |
+
########################### LiteLLM Proxy Specific Constants ###########################
|
500 |
+
########################################################################################
|
501 |
+
MAX_SPENDLOG_ROWS_TO_QUERY = (
|
502 |
+
1_000_000 # if spendLogs has more than 1M rows, do not query the DB
|
503 |
+
)
|
504 |
+
DEFAULT_SOFT_BUDGET = (
|
505 |
+
50.0 # by default all litellm proxy keys have a soft budget of 50.0
|
506 |
+
)
|
507 |
+
# makes it clear this is a rate limit error for a litellm virtual key
|
508 |
+
RATE_LIMIT_ERROR_MESSAGE_FOR_VIRTUAL_KEY = "LiteLLM Virtual Key user_api_key_hash"
|
509 |
+
|
510 |
+
# pass through route constansts
|
511 |
+
BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
|
512 |
+
"agents/",
|
513 |
+
"knowledgebases/",
|
514 |
+
"flows/",
|
515 |
+
"retrieveAndGenerate/",
|
516 |
+
"rerank/",
|
517 |
+
"generateQuery/",
|
518 |
+
"optimize-prompt/",
|
519 |
+
]
|
520 |
+
|
521 |
+
BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600 # 1 hour
|
522 |
+
BATCH_STATUS_POLL_MAX_ATTEMPTS = 24 # for 24 hours
|
523 |
+
|
524 |
+
HEALTH_CHECK_TIMEOUT_SECONDS = 60 # 60 seconds
|
525 |
+
|
526 |
+
UI_SESSION_TOKEN_TEAM_ID = "litellm-dashboard"
|
527 |
+
LITELLM_PROXY_ADMIN_NAME = "default_user_id"
|
528 |
+
|
529 |
+
########################### DB CRON JOB NAMES ###########################
|
530 |
+
DB_SPEND_UPDATE_JOB_NAME = "db_spend_update_job"
|
531 |
+
PROMETHEUS_EMIT_BUDGET_METRICS_JOB_NAME = "prometheus_emit_budget_metrics_job"
|
532 |
+
DEFAULT_CRON_JOB_LOCK_TTL_SECONDS = 60 # 1 minute
|
533 |
+
PROXY_BUDGET_RESCHEDULER_MIN_TIME = 597
|
534 |
+
PROXY_BUDGET_RESCHEDULER_MAX_TIME = 605
|
535 |
+
PROXY_BATCH_WRITE_AT = 10 # in seconds
|
536 |
+
DEFAULT_HEALTH_CHECK_INTERVAL = 300 # 5 minutes
|
537 |
+
PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS = 9
|
538 |
+
DEFAULT_MODEL_CREATED_AT_TIME = 1677610602 # returns on `/models` endpoint
|
539 |
+
DEFAULT_SLACK_ALERTING_THRESHOLD = 300
|
540 |
+
MAX_TEAM_LIST_LIMIT = 20
|
541 |
+
DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD = 0.7
|
542 |
+
LENGTH_OF_LITELLM_GENERATED_KEY = 16
|
543 |
+
SECRET_MANAGER_REFRESH_INTERVAL = 86400
|
litellm/cost.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"gpt-3.5-turbo-0613": 0.00015000000000000001,
|
3 |
+
"claude-2": 0.00016454,
|
4 |
+
"gpt-4-0613": 0.015408
|
5 |
+
}
|
litellm/cost_calculator.py
ADDED
@@ -0,0 +1,1378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# What is this?
|
2 |
+
## File for 'response_cost' calculation in Logging
|
3 |
+
import time
|
4 |
+
from functools import lru_cache
|
5 |
+
from typing import Any, List, Literal, Optional, Tuple, Union, cast
|
6 |
+
|
7 |
+
from pydantic import BaseModel
|
8 |
+
|
9 |
+
import litellm
|
10 |
+
import litellm._logging
|
11 |
+
from litellm import verbose_logger
|
12 |
+
from litellm.constants import (
|
13 |
+
DEFAULT_MAX_LRU_CACHE_SIZE,
|
14 |
+
DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND,
|
15 |
+
)
|
16 |
+
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
|
17 |
+
StandardBuiltInToolCostTracking,
|
18 |
+
)
|
19 |
+
from litellm.litellm_core_utils.llm_cost_calc.utils import (
|
20 |
+
_generic_cost_per_character,
|
21 |
+
generic_cost_per_token,
|
22 |
+
select_cost_metric_for_model,
|
23 |
+
)
|
24 |
+
from litellm.llms.anthropic.cost_calculation import (
|
25 |
+
cost_per_token as anthropic_cost_per_token,
|
26 |
+
)
|
27 |
+
from litellm.llms.azure.cost_calculation import (
|
28 |
+
cost_per_token as azure_openai_cost_per_token,
|
29 |
+
)
|
30 |
+
from litellm.llms.bedrock.image.cost_calculator import (
|
31 |
+
cost_calculator as bedrock_image_cost_calculator,
|
32 |
+
)
|
33 |
+
from litellm.llms.databricks.cost_calculator import (
|
34 |
+
cost_per_token as databricks_cost_per_token,
|
35 |
+
)
|
36 |
+
from litellm.llms.deepseek.cost_calculator import (
|
37 |
+
cost_per_token as deepseek_cost_per_token,
|
38 |
+
)
|
39 |
+
from litellm.llms.fireworks_ai.cost_calculator import (
|
40 |
+
cost_per_token as fireworks_ai_cost_per_token,
|
41 |
+
)
|
42 |
+
from litellm.llms.gemini.cost_calculator import cost_per_token as gemini_cost_per_token
|
43 |
+
from litellm.llms.openai.cost_calculation import (
|
44 |
+
cost_per_second as openai_cost_per_second,
|
45 |
+
)
|
46 |
+
from litellm.llms.openai.cost_calculation import cost_per_token as openai_cost_per_token
|
47 |
+
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
|
48 |
+
from litellm.llms.vertex_ai.cost_calculator import (
|
49 |
+
cost_per_character as google_cost_per_character,
|
50 |
+
)
|
51 |
+
from litellm.llms.vertex_ai.cost_calculator import (
|
52 |
+
cost_per_token as google_cost_per_token,
|
53 |
+
)
|
54 |
+
from litellm.llms.vertex_ai.cost_calculator import cost_router as google_cost_router
|
55 |
+
from litellm.llms.vertex_ai.image_generation.cost_calculator import (
|
56 |
+
cost_calculator as vertex_ai_image_cost_calculator,
|
57 |
+
)
|
58 |
+
from litellm.responses.utils import ResponseAPILoggingUtils
|
59 |
+
from litellm.types.llms.openai import (
|
60 |
+
HttpxBinaryResponseContent,
|
61 |
+
ImageGenerationRequestQuality,
|
62 |
+
OpenAIModerationResponse,
|
63 |
+
OpenAIRealtimeStreamList,
|
64 |
+
OpenAIRealtimeStreamResponseBaseObject,
|
65 |
+
OpenAIRealtimeStreamSessionEvents,
|
66 |
+
ResponseAPIUsage,
|
67 |
+
ResponsesAPIResponse,
|
68 |
+
)
|
69 |
+
from litellm.types.rerank import RerankBilledUnits, RerankResponse
|
70 |
+
from litellm.types.utils import (
|
71 |
+
CallTypesLiteral,
|
72 |
+
LiteLLMRealtimeStreamLoggingObject,
|
73 |
+
LlmProviders,
|
74 |
+
LlmProvidersSet,
|
75 |
+
ModelInfo,
|
76 |
+
PassthroughCallTypes,
|
77 |
+
StandardBuiltInToolsParams,
|
78 |
+
Usage,
|
79 |
+
)
|
80 |
+
from litellm.utils import (
|
81 |
+
CallTypes,
|
82 |
+
CostPerToken,
|
83 |
+
EmbeddingResponse,
|
84 |
+
ImageResponse,
|
85 |
+
ModelResponse,
|
86 |
+
ProviderConfigManager,
|
87 |
+
TextCompletionResponse,
|
88 |
+
TranscriptionResponse,
|
89 |
+
_cached_get_model_info_helper,
|
90 |
+
token_counter,
|
91 |
+
)
|
92 |
+
|
93 |
+
|
94 |
+
def _cost_per_token_custom_pricing_helper(
|
95 |
+
prompt_tokens: float = 0,
|
96 |
+
completion_tokens: float = 0,
|
97 |
+
response_time_ms: Optional[float] = 0.0,
|
98 |
+
### CUSTOM PRICING ###
|
99 |
+
custom_cost_per_token: Optional[CostPerToken] = None,
|
100 |
+
custom_cost_per_second: Optional[float] = None,
|
101 |
+
) -> Optional[Tuple[float, float]]:
|
102 |
+
"""Internal helper function for calculating cost, if custom pricing given"""
|
103 |
+
if custom_cost_per_token is None and custom_cost_per_second is None:
|
104 |
+
return None
|
105 |
+
|
106 |
+
if custom_cost_per_token is not None:
|
107 |
+
input_cost = custom_cost_per_token["input_cost_per_token"] * prompt_tokens
|
108 |
+
output_cost = custom_cost_per_token["output_cost_per_token"] * completion_tokens
|
109 |
+
return input_cost, output_cost
|
110 |
+
elif custom_cost_per_second is not None:
|
111 |
+
output_cost = custom_cost_per_second * response_time_ms / 1000 # type: ignore
|
112 |
+
return 0, output_cost
|
113 |
+
|
114 |
+
return None
|
115 |
+
|
116 |
+
|
117 |
+
def cost_per_token( # noqa: PLR0915
|
118 |
+
model: str = "",
|
119 |
+
prompt_tokens: int = 0,
|
120 |
+
completion_tokens: int = 0,
|
121 |
+
response_time_ms: Optional[float] = 0.0,
|
122 |
+
custom_llm_provider: Optional[str] = None,
|
123 |
+
region_name=None,
|
124 |
+
### CHARACTER PRICING ###
|
125 |
+
prompt_characters: Optional[int] = None,
|
126 |
+
completion_characters: Optional[int] = None,
|
127 |
+
### PROMPT CACHING PRICING ### - used for anthropic
|
128 |
+
cache_creation_input_tokens: Optional[int] = 0,
|
129 |
+
cache_read_input_tokens: Optional[int] = 0,
|
130 |
+
### CUSTOM PRICING ###
|
131 |
+
custom_cost_per_token: Optional[CostPerToken] = None,
|
132 |
+
custom_cost_per_second: Optional[float] = None,
|
133 |
+
### NUMBER OF QUERIES ###
|
134 |
+
number_of_queries: Optional[int] = None,
|
135 |
+
### USAGE OBJECT ###
|
136 |
+
usage_object: Optional[Usage] = None, # just read the usage object if provided
|
137 |
+
### BILLED UNITS ###
|
138 |
+
rerank_billed_units: Optional[RerankBilledUnits] = None,
|
139 |
+
### CALL TYPE ###
|
140 |
+
call_type: CallTypesLiteral = "completion",
|
141 |
+
audio_transcription_file_duration: float = 0.0, # for audio transcription calls - the file time in seconds
|
142 |
+
) -> Tuple[float, float]: # type: ignore
|
143 |
+
"""
|
144 |
+
Calculates the cost per token for a given model, prompt tokens, and completion tokens.
|
145 |
+
|
146 |
+
Parameters:
|
147 |
+
model (str): The name of the model to use. Default is ""
|
148 |
+
prompt_tokens (int): The number of tokens in the prompt.
|
149 |
+
completion_tokens (int): The number of tokens in the completion.
|
150 |
+
response_time (float): The amount of time, in milliseconds, it took the call to complete.
|
151 |
+
prompt_characters (float): The number of characters in the prompt. Used for vertex ai cost calculation.
|
152 |
+
completion_characters (float): The number of characters in the completion response. Used for vertex ai cost calculation.
|
153 |
+
custom_llm_provider (str): The llm provider to whom the call was made (see init.py for full list)
|
154 |
+
custom_cost_per_token: Optional[CostPerToken]: the cost per input + output token for the llm api call.
|
155 |
+
custom_cost_per_second: Optional[float]: the cost per second for the llm api call.
|
156 |
+
call_type: Optional[str]: the call type
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
tuple: A tuple containing the cost in USD dollars for prompt tokens and completion tokens, respectively.
|
160 |
+
"""
|
161 |
+
if model is None:
|
162 |
+
raise Exception("Invalid arg. Model cannot be none.")
|
163 |
+
|
164 |
+
## RECONSTRUCT USAGE BLOCK ##
|
165 |
+
if usage_object is not None:
|
166 |
+
usage_block = usage_object
|
167 |
+
else:
|
168 |
+
usage_block = Usage(
|
169 |
+
prompt_tokens=prompt_tokens,
|
170 |
+
completion_tokens=completion_tokens,
|
171 |
+
total_tokens=prompt_tokens + completion_tokens,
|
172 |
+
cache_creation_input_tokens=cache_creation_input_tokens,
|
173 |
+
cache_read_input_tokens=cache_read_input_tokens,
|
174 |
+
)
|
175 |
+
|
176 |
+
## CUSTOM PRICING ##
|
177 |
+
response_cost = _cost_per_token_custom_pricing_helper(
|
178 |
+
prompt_tokens=prompt_tokens,
|
179 |
+
completion_tokens=completion_tokens,
|
180 |
+
response_time_ms=response_time_ms,
|
181 |
+
custom_cost_per_second=custom_cost_per_second,
|
182 |
+
custom_cost_per_token=custom_cost_per_token,
|
183 |
+
)
|
184 |
+
|
185 |
+
if response_cost is not None:
|
186 |
+
return response_cost[0], response_cost[1]
|
187 |
+
|
188 |
+
# given
|
189 |
+
prompt_tokens_cost_usd_dollar: float = 0
|
190 |
+
completion_tokens_cost_usd_dollar: float = 0
|
191 |
+
model_cost_ref = litellm.model_cost
|
192 |
+
model_with_provider = model
|
193 |
+
if custom_llm_provider is not None:
|
194 |
+
model_with_provider = custom_llm_provider + "/" + model
|
195 |
+
if region_name is not None:
|
196 |
+
model_with_provider_and_region = (
|
197 |
+
f"{custom_llm_provider}/{region_name}/{model}"
|
198 |
+
)
|
199 |
+
if (
|
200 |
+
model_with_provider_and_region in model_cost_ref
|
201 |
+
): # use region based pricing, if it's available
|
202 |
+
model_with_provider = model_with_provider_and_region
|
203 |
+
else:
|
204 |
+
_, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
|
205 |
+
model_without_prefix = model
|
206 |
+
model_parts = model.split("/", 1)
|
207 |
+
if len(model_parts) > 1:
|
208 |
+
model_without_prefix = model_parts[1]
|
209 |
+
else:
|
210 |
+
model_without_prefix = model
|
211 |
+
"""
|
212 |
+
Code block that formats model to lookup in litellm.model_cost
|
213 |
+
Option1. model = "bedrock/ap-northeast-1/anthropic.claude-instant-v1". This is the most accurate since it is region based. Should always be option 1
|
214 |
+
Option2. model = "openai/gpt-4" - model = provider/model
|
215 |
+
Option3. model = "anthropic.claude-3" - model = model
|
216 |
+
"""
|
217 |
+
if (
|
218 |
+
model_with_provider in model_cost_ref
|
219 |
+
): # Option 2. use model with provider, model = "openai/gpt-4"
|
220 |
+
model = model_with_provider
|
221 |
+
elif model in model_cost_ref: # Option 1. use model passed, model="gpt-4"
|
222 |
+
model = model
|
223 |
+
elif (
|
224 |
+
model_without_prefix in model_cost_ref
|
225 |
+
): # Option 3. if user passed model="bedrock/anthropic.claude-3", use model="anthropic.claude-3"
|
226 |
+
model = model_without_prefix
|
227 |
+
|
228 |
+
# see this https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models
|
229 |
+
if call_type == "speech" or call_type == "aspeech":
|
230 |
+
speech_model_info = litellm.get_model_info(
|
231 |
+
model=model_without_prefix, custom_llm_provider=custom_llm_provider
|
232 |
+
)
|
233 |
+
cost_metric = select_cost_metric_for_model(speech_model_info)
|
234 |
+
prompt_cost: float = 0.0
|
235 |
+
completion_cost: float = 0.0
|
236 |
+
if cost_metric == "cost_per_character":
|
237 |
+
if prompt_characters is None:
|
238 |
+
raise ValueError(
|
239 |
+
"prompt_characters must be provided for tts calls. prompt_characters={}, model={}, custom_llm_provider={}, call_type={}".format(
|
240 |
+
prompt_characters,
|
241 |
+
model,
|
242 |
+
custom_llm_provider,
|
243 |
+
call_type,
|
244 |
+
)
|
245 |
+
)
|
246 |
+
_prompt_cost, _completion_cost = _generic_cost_per_character(
|
247 |
+
model=model_without_prefix,
|
248 |
+
custom_llm_provider=custom_llm_provider,
|
249 |
+
prompt_characters=prompt_characters,
|
250 |
+
completion_characters=0,
|
251 |
+
custom_prompt_cost=None,
|
252 |
+
custom_completion_cost=0,
|
253 |
+
)
|
254 |
+
if _prompt_cost is None or _completion_cost is None:
|
255 |
+
raise ValueError(
|
256 |
+
"cost for tts call is None. prompt_cost={}, completion_cost={}, model={}, custom_llm_provider={}, prompt_characters={}, completion_characters={}".format(
|
257 |
+
_prompt_cost,
|
258 |
+
_completion_cost,
|
259 |
+
model_without_prefix,
|
260 |
+
custom_llm_provider,
|
261 |
+
prompt_characters,
|
262 |
+
completion_characters,
|
263 |
+
)
|
264 |
+
)
|
265 |
+
prompt_cost = _prompt_cost
|
266 |
+
completion_cost = _completion_cost
|
267 |
+
elif cost_metric == "cost_per_token":
|
268 |
+
prompt_cost, completion_cost = generic_cost_per_token(
|
269 |
+
model=model_without_prefix,
|
270 |
+
usage=usage_block,
|
271 |
+
custom_llm_provider=custom_llm_provider,
|
272 |
+
)
|
273 |
+
|
274 |
+
return prompt_cost, completion_cost
|
275 |
+
elif call_type == "arerank" or call_type == "rerank":
|
276 |
+
return rerank_cost(
|
277 |
+
model=model,
|
278 |
+
custom_llm_provider=custom_llm_provider,
|
279 |
+
billed_units=rerank_billed_units,
|
280 |
+
)
|
281 |
+
elif (
|
282 |
+
call_type == "aretrieve_batch"
|
283 |
+
or call_type == "retrieve_batch"
|
284 |
+
or call_type == CallTypes.aretrieve_batch
|
285 |
+
or call_type == CallTypes.retrieve_batch
|
286 |
+
):
|
287 |
+
return batch_cost_calculator(
|
288 |
+
usage=usage_block, model=model, custom_llm_provider=custom_llm_provider
|
289 |
+
)
|
290 |
+
elif call_type == "atranscription" or call_type == "transcription":
|
291 |
+
return openai_cost_per_second(
|
292 |
+
model=model,
|
293 |
+
custom_llm_provider=custom_llm_provider,
|
294 |
+
duration=audio_transcription_file_duration,
|
295 |
+
)
|
296 |
+
elif custom_llm_provider == "vertex_ai":
|
297 |
+
cost_router = google_cost_router(
|
298 |
+
model=model_without_prefix,
|
299 |
+
custom_llm_provider=custom_llm_provider,
|
300 |
+
call_type=call_type,
|
301 |
+
)
|
302 |
+
if cost_router == "cost_per_character":
|
303 |
+
return google_cost_per_character(
|
304 |
+
model=model_without_prefix,
|
305 |
+
custom_llm_provider=custom_llm_provider,
|
306 |
+
prompt_characters=prompt_characters,
|
307 |
+
completion_characters=completion_characters,
|
308 |
+
usage=usage_block,
|
309 |
+
)
|
310 |
+
elif cost_router == "cost_per_token":
|
311 |
+
return google_cost_per_token(
|
312 |
+
model=model_without_prefix,
|
313 |
+
custom_llm_provider=custom_llm_provider,
|
314 |
+
usage=usage_block,
|
315 |
+
)
|
316 |
+
elif custom_llm_provider == "anthropic":
|
317 |
+
return anthropic_cost_per_token(model=model, usage=usage_block)
|
318 |
+
elif custom_llm_provider == "openai":
|
319 |
+
return openai_cost_per_token(model=model, usage=usage_block)
|
320 |
+
elif custom_llm_provider == "databricks":
|
321 |
+
return databricks_cost_per_token(model=model, usage=usage_block)
|
322 |
+
elif custom_llm_provider == "fireworks_ai":
|
323 |
+
return fireworks_ai_cost_per_token(model=model, usage=usage_block)
|
324 |
+
elif custom_llm_provider == "azure":
|
325 |
+
return azure_openai_cost_per_token(
|
326 |
+
model=model, usage=usage_block, response_time_ms=response_time_ms
|
327 |
+
)
|
328 |
+
elif custom_llm_provider == "gemini":
|
329 |
+
return gemini_cost_per_token(model=model, usage=usage_block)
|
330 |
+
elif custom_llm_provider == "deepseek":
|
331 |
+
return deepseek_cost_per_token(model=model, usage=usage_block)
|
332 |
+
else:
|
333 |
+
model_info = _cached_get_model_info_helper(
|
334 |
+
model=model, custom_llm_provider=custom_llm_provider
|
335 |
+
)
|
336 |
+
|
337 |
+
if model_info["input_cost_per_token"] > 0:
|
338 |
+
## COST PER TOKEN ##
|
339 |
+
prompt_tokens_cost_usd_dollar = (
|
340 |
+
model_info["input_cost_per_token"] * prompt_tokens
|
341 |
+
)
|
342 |
+
elif (
|
343 |
+
model_info.get("input_cost_per_second", None) is not None
|
344 |
+
and response_time_ms is not None
|
345 |
+
):
|
346 |
+
verbose_logger.debug(
|
347 |
+
"For model=%s - input_cost_per_second: %s; response time: %s",
|
348 |
+
model,
|
349 |
+
model_info.get("input_cost_per_second", None),
|
350 |
+
response_time_ms,
|
351 |
+
)
|
352 |
+
## COST PER SECOND ##
|
353 |
+
prompt_tokens_cost_usd_dollar = (
|
354 |
+
model_info["input_cost_per_second"] * response_time_ms / 1000 # type: ignore
|
355 |
+
)
|
356 |
+
|
357 |
+
if model_info["output_cost_per_token"] > 0:
|
358 |
+
completion_tokens_cost_usd_dollar = (
|
359 |
+
model_info["output_cost_per_token"] * completion_tokens
|
360 |
+
)
|
361 |
+
elif (
|
362 |
+
model_info.get("output_cost_per_second", None) is not None
|
363 |
+
and response_time_ms is not None
|
364 |
+
):
|
365 |
+
verbose_logger.debug(
|
366 |
+
"For model=%s - output_cost_per_second: %s; response time: %s",
|
367 |
+
model,
|
368 |
+
model_info.get("output_cost_per_second", None),
|
369 |
+
response_time_ms,
|
370 |
+
)
|
371 |
+
## COST PER SECOND ##
|
372 |
+
completion_tokens_cost_usd_dollar = (
|
373 |
+
model_info["output_cost_per_second"] * response_time_ms / 1000 # type: ignore
|
374 |
+
)
|
375 |
+
|
376 |
+
verbose_logger.debug(
|
377 |
+
"Returned custom cost for model=%s - prompt_tokens_cost_usd_dollar: %s, completion_tokens_cost_usd_dollar: %s",
|
378 |
+
model,
|
379 |
+
prompt_tokens_cost_usd_dollar,
|
380 |
+
completion_tokens_cost_usd_dollar,
|
381 |
+
)
|
382 |
+
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
|
383 |
+
|
384 |
+
|
385 |
+
def get_replicate_completion_pricing(completion_response: dict, total_time=0.0):
|
386 |
+
# see https://replicate.com/pricing
|
387 |
+
# for all litellm currently supported LLMs, almost all requests go to a100_80gb
|
388 |
+
a100_80gb_price_per_second_public = DEFAULT_REPLICATE_GPU_PRICE_PER_SECOND # assume all calls sent to A100 80GB for now
|
389 |
+
if total_time == 0.0: # total time is in ms
|
390 |
+
start_time = completion_response.get("created", time.time())
|
391 |
+
end_time = getattr(completion_response, "ended", time.time())
|
392 |
+
total_time = end_time - start_time
|
393 |
+
|
394 |
+
return a100_80gb_price_per_second_public * total_time / 1000
|
395 |
+
|
396 |
+
|
397 |
+
def has_hidden_params(obj: Any) -> bool:
|
398 |
+
return hasattr(obj, "_hidden_params")
|
399 |
+
|
400 |
+
|
401 |
+
def _get_provider_for_cost_calc(
|
402 |
+
model: Optional[str],
|
403 |
+
custom_llm_provider: Optional[str] = None,
|
404 |
+
) -> Optional[str]:
|
405 |
+
if custom_llm_provider is not None:
|
406 |
+
return custom_llm_provider
|
407 |
+
if model is None:
|
408 |
+
return None
|
409 |
+
try:
|
410 |
+
_, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
|
411 |
+
except Exception as e:
|
412 |
+
verbose_logger.debug(
|
413 |
+
f"litellm.cost_calculator.py::_get_provider_for_cost_calc() - Error inferring custom_llm_provider - {str(e)}"
|
414 |
+
)
|
415 |
+
return None
|
416 |
+
|
417 |
+
return custom_llm_provider
|
418 |
+
|
419 |
+
|
420 |
+
def _select_model_name_for_cost_calc(
|
421 |
+
model: Optional[str],
|
422 |
+
completion_response: Optional[Any],
|
423 |
+
base_model: Optional[str] = None,
|
424 |
+
custom_pricing: Optional[bool] = None,
|
425 |
+
custom_llm_provider: Optional[str] = None,
|
426 |
+
router_model_id: Optional[str] = None,
|
427 |
+
) -> Optional[str]:
|
428 |
+
"""
|
429 |
+
1. If custom pricing is true, return received model name
|
430 |
+
2. If base_model is set (e.g. for azure models), return that
|
431 |
+
3. If completion response has model set return that
|
432 |
+
4. Check if model is passed in return that
|
433 |
+
"""
|
434 |
+
|
435 |
+
return_model: Optional[str] = None
|
436 |
+
region_name: Optional[str] = None
|
437 |
+
custom_llm_provider = _get_provider_for_cost_calc(
|
438 |
+
model=model, custom_llm_provider=custom_llm_provider
|
439 |
+
)
|
440 |
+
|
441 |
+
completion_response_model: Optional[str] = None
|
442 |
+
if completion_response is not None:
|
443 |
+
if isinstance(completion_response, BaseModel):
|
444 |
+
completion_response_model = getattr(completion_response, "model", None)
|
445 |
+
elif isinstance(completion_response, dict):
|
446 |
+
completion_response_model = completion_response.get("model", None)
|
447 |
+
hidden_params: Optional[dict] = getattr(completion_response, "_hidden_params", None)
|
448 |
+
|
449 |
+
if custom_pricing is True:
|
450 |
+
if router_model_id is not None and router_model_id in litellm.model_cost:
|
451 |
+
return_model = router_model_id
|
452 |
+
else:
|
453 |
+
return_model = model
|
454 |
+
|
455 |
+
if base_model is not None:
|
456 |
+
return_model = base_model
|
457 |
+
|
458 |
+
if completion_response_model is None and hidden_params is not None:
|
459 |
+
if (
|
460 |
+
hidden_params.get("model", None) is not None
|
461 |
+
and len(hidden_params["model"]) > 0
|
462 |
+
):
|
463 |
+
return_model = hidden_params.get("model", model)
|
464 |
+
if hidden_params is not None and hidden_params.get("region_name", None) is not None:
|
465 |
+
region_name = hidden_params.get("region_name", None)
|
466 |
+
|
467 |
+
if return_model is None and completion_response_model is not None:
|
468 |
+
return_model = completion_response_model
|
469 |
+
|
470 |
+
if return_model is None and model is not None:
|
471 |
+
return_model = model
|
472 |
+
|
473 |
+
if (
|
474 |
+
return_model is not None
|
475 |
+
and custom_llm_provider is not None
|
476 |
+
and not _model_contains_known_llm_provider(return_model)
|
477 |
+
): # add provider prefix if not already present, to match model_cost
|
478 |
+
if region_name is not None:
|
479 |
+
return_model = f"{custom_llm_provider}/{region_name}/{return_model}"
|
480 |
+
else:
|
481 |
+
return_model = f"{custom_llm_provider}/{return_model}"
|
482 |
+
|
483 |
+
return return_model
|
484 |
+
|
485 |
+
|
486 |
+
@lru_cache(maxsize=DEFAULT_MAX_LRU_CACHE_SIZE)
|
487 |
+
def _model_contains_known_llm_provider(model: str) -> bool:
|
488 |
+
"""
|
489 |
+
Check if the model contains a known llm provider
|
490 |
+
"""
|
491 |
+
_provider_prefix = model.split("/")[0]
|
492 |
+
return _provider_prefix in LlmProvidersSet
|
493 |
+
|
494 |
+
|
495 |
+
def _get_usage_object(
|
496 |
+
completion_response: Any,
|
497 |
+
) -> Optional[Usage]:
|
498 |
+
usage_obj = cast(
|
499 |
+
Union[Usage, ResponseAPIUsage, dict, BaseModel],
|
500 |
+
(
|
501 |
+
completion_response.get("usage")
|
502 |
+
if isinstance(completion_response, dict)
|
503 |
+
else getattr(completion_response, "get", lambda x: None)("usage")
|
504 |
+
),
|
505 |
+
)
|
506 |
+
|
507 |
+
if usage_obj is None:
|
508 |
+
return None
|
509 |
+
if isinstance(usage_obj, Usage):
|
510 |
+
return usage_obj
|
511 |
+
elif (
|
512 |
+
usage_obj is not None
|
513 |
+
and (isinstance(usage_obj, dict) or isinstance(usage_obj, ResponseAPIUsage))
|
514 |
+
and ResponseAPILoggingUtils._is_response_api_usage(usage_obj)
|
515 |
+
):
|
516 |
+
return ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
|
517 |
+
usage_obj
|
518 |
+
)
|
519 |
+
elif isinstance(usage_obj, dict):
|
520 |
+
return Usage(**usage_obj)
|
521 |
+
elif isinstance(usage_obj, BaseModel):
|
522 |
+
return Usage(**usage_obj.model_dump())
|
523 |
+
else:
|
524 |
+
verbose_logger.debug(
|
525 |
+
f"Unknown usage object type: {type(usage_obj)}, usage_obj: {usage_obj}"
|
526 |
+
)
|
527 |
+
return None
|
528 |
+
|
529 |
+
|
530 |
+
def _is_known_usage_objects(usage_obj):
|
531 |
+
"""Returns True if the usage obj is a known Usage type"""
|
532 |
+
return isinstance(usage_obj, litellm.Usage) or isinstance(
|
533 |
+
usage_obj, ResponseAPIUsage
|
534 |
+
)
|
535 |
+
|
536 |
+
|
537 |
+
def _infer_call_type(
|
538 |
+
call_type: Optional[CallTypesLiteral], completion_response: Any
|
539 |
+
) -> Optional[CallTypesLiteral]:
|
540 |
+
if call_type is not None:
|
541 |
+
return call_type
|
542 |
+
|
543 |
+
if completion_response is None:
|
544 |
+
return None
|
545 |
+
|
546 |
+
if isinstance(completion_response, ModelResponse):
|
547 |
+
return "completion"
|
548 |
+
elif isinstance(completion_response, EmbeddingResponse):
|
549 |
+
return "embedding"
|
550 |
+
elif isinstance(completion_response, TranscriptionResponse):
|
551 |
+
return "transcription"
|
552 |
+
elif isinstance(completion_response, HttpxBinaryResponseContent):
|
553 |
+
return "speech"
|
554 |
+
elif isinstance(completion_response, RerankResponse):
|
555 |
+
return "rerank"
|
556 |
+
elif isinstance(completion_response, ImageResponse):
|
557 |
+
return "image_generation"
|
558 |
+
elif isinstance(completion_response, TextCompletionResponse):
|
559 |
+
return "text_completion"
|
560 |
+
|
561 |
+
return call_type
|
562 |
+
|
563 |
+
|
564 |
+
def completion_cost( # noqa: PLR0915
|
565 |
+
completion_response=None,
|
566 |
+
model: Optional[str] = None,
|
567 |
+
prompt="",
|
568 |
+
messages: List = [],
|
569 |
+
completion="",
|
570 |
+
total_time: Optional[float] = 0.0, # used for replicate, sagemaker
|
571 |
+
call_type: Optional[CallTypesLiteral] = None,
|
572 |
+
### REGION ###
|
573 |
+
custom_llm_provider=None,
|
574 |
+
region_name=None, # used for bedrock pricing
|
575 |
+
### IMAGE GEN ###
|
576 |
+
size: Optional[str] = None,
|
577 |
+
quality: Optional[str] = None,
|
578 |
+
n: Optional[int] = None, # number of images
|
579 |
+
### CUSTOM PRICING ###
|
580 |
+
custom_cost_per_token: Optional[CostPerToken] = None,
|
581 |
+
custom_cost_per_second: Optional[float] = None,
|
582 |
+
optional_params: Optional[dict] = None,
|
583 |
+
custom_pricing: Optional[bool] = None,
|
584 |
+
base_model: Optional[str] = None,
|
585 |
+
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
|
586 |
+
litellm_model_name: Optional[str] = None,
|
587 |
+
router_model_id: Optional[str] = None,
|
588 |
+
) -> float:
|
589 |
+
"""
|
590 |
+
Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.
|
591 |
+
|
592 |
+
Parameters:
|
593 |
+
completion_response (litellm.ModelResponses): [Required] The response received from a LiteLLM completion request.
|
594 |
+
|
595 |
+
[OPTIONAL PARAMS]
|
596 |
+
model (str): Optional. The name of the language model used in the completion calls
|
597 |
+
prompt (str): Optional. The input prompt passed to the llm
|
598 |
+
completion (str): Optional. The output completion text from the llm
|
599 |
+
total_time (float, int): Optional. (Only used for Replicate LLMs) The total time used for the request in seconds
|
600 |
+
custom_cost_per_token: Optional[CostPerToken]: the cost per input + output token for the llm api call.
|
601 |
+
custom_cost_per_second: Optional[float]: the cost per second for the llm api call.
|
602 |
+
|
603 |
+
Returns:
|
604 |
+
float: The cost in USD dollars for the completion based on the provided parameters.
|
605 |
+
|
606 |
+
Exceptions:
|
607 |
+
Raises exception if model not in the litellm model cost map. Register model, via custom pricing or PR - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
|
608 |
+
|
609 |
+
|
610 |
+
Note:
|
611 |
+
- If completion_response is provided, the function extracts token information and the model name from it.
|
612 |
+
- If completion_response is not provided, the function calculates token counts based on the model and input text.
|
613 |
+
- The cost is calculated based on the model, prompt tokens, and completion tokens.
|
614 |
+
- For certain models containing "togethercomputer" in the name, prices are based on the model size.
|
615 |
+
- For un-mapped Replicate models, the cost is calculated based on the total time used for the request.
|
616 |
+
"""
|
617 |
+
try:
|
618 |
+
call_type = _infer_call_type(call_type, completion_response) or "completion"
|
619 |
+
|
620 |
+
if (
|
621 |
+
(call_type == "aimage_generation" or call_type == "image_generation")
|
622 |
+
and model is not None
|
623 |
+
and isinstance(model, str)
|
624 |
+
and len(model) == 0
|
625 |
+
and custom_llm_provider == "azure"
|
626 |
+
):
|
627 |
+
model = "dall-e-2" # for dall-e-2, azure expects an empty model name
|
628 |
+
# Handle Inputs to completion_cost
|
629 |
+
prompt_tokens = 0
|
630 |
+
prompt_characters: Optional[int] = None
|
631 |
+
completion_tokens = 0
|
632 |
+
completion_characters: Optional[int] = None
|
633 |
+
cache_creation_input_tokens: Optional[int] = None
|
634 |
+
cache_read_input_tokens: Optional[int] = None
|
635 |
+
audio_transcription_file_duration: float = 0.0
|
636 |
+
cost_per_token_usage_object: Optional[Usage] = _get_usage_object(
|
637 |
+
completion_response=completion_response
|
638 |
+
)
|
639 |
+
rerank_billed_units: Optional[RerankBilledUnits] = None
|
640 |
+
|
641 |
+
selected_model = _select_model_name_for_cost_calc(
|
642 |
+
model=model,
|
643 |
+
completion_response=completion_response,
|
644 |
+
custom_llm_provider=custom_llm_provider,
|
645 |
+
custom_pricing=custom_pricing,
|
646 |
+
base_model=base_model,
|
647 |
+
router_model_id=router_model_id,
|
648 |
+
)
|
649 |
+
|
650 |
+
potential_model_names = [selected_model]
|
651 |
+
if model is not None:
|
652 |
+
potential_model_names.append(model)
|
653 |
+
for idx, model in enumerate(potential_model_names):
|
654 |
+
try:
|
655 |
+
verbose_logger.info(
|
656 |
+
f"selected model name for cost calculation: {model}"
|
657 |
+
)
|
658 |
+
|
659 |
+
if completion_response is not None and (
|
660 |
+
isinstance(completion_response, BaseModel)
|
661 |
+
or isinstance(completion_response, dict)
|
662 |
+
): # tts returns a custom class
|
663 |
+
if isinstance(completion_response, dict):
|
664 |
+
usage_obj: Optional[
|
665 |
+
Union[dict, Usage]
|
666 |
+
] = completion_response.get("usage", {})
|
667 |
+
else:
|
668 |
+
usage_obj = getattr(completion_response, "usage", {})
|
669 |
+
if isinstance(usage_obj, BaseModel) and not _is_known_usage_objects(
|
670 |
+
usage_obj=usage_obj
|
671 |
+
):
|
672 |
+
setattr(
|
673 |
+
completion_response,
|
674 |
+
"usage",
|
675 |
+
litellm.Usage(**usage_obj.model_dump()),
|
676 |
+
)
|
677 |
+
if usage_obj is None:
|
678 |
+
_usage = {}
|
679 |
+
elif isinstance(usage_obj, BaseModel):
|
680 |
+
_usage = usage_obj.model_dump()
|
681 |
+
else:
|
682 |
+
_usage = usage_obj
|
683 |
+
|
684 |
+
if ResponseAPILoggingUtils._is_response_api_usage(_usage):
|
685 |
+
_usage = ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
|
686 |
+
_usage
|
687 |
+
).model_dump()
|
688 |
+
|
689 |
+
# get input/output tokens from completion_response
|
690 |
+
prompt_tokens = _usage.get("prompt_tokens", 0)
|
691 |
+
completion_tokens = _usage.get("completion_tokens", 0)
|
692 |
+
cache_creation_input_tokens = _usage.get(
|
693 |
+
"cache_creation_input_tokens", 0
|
694 |
+
)
|
695 |
+
cache_read_input_tokens = _usage.get("cache_read_input_tokens", 0)
|
696 |
+
if (
|
697 |
+
"prompt_tokens_details" in _usage
|
698 |
+
and _usage["prompt_tokens_details"] != {}
|
699 |
+
and _usage["prompt_tokens_details"]
|
700 |
+
):
|
701 |
+
prompt_tokens_details = _usage.get("prompt_tokens_details", {})
|
702 |
+
cache_read_input_tokens = prompt_tokens_details.get(
|
703 |
+
"cached_tokens", 0
|
704 |
+
)
|
705 |
+
|
706 |
+
total_time = getattr(completion_response, "_response_ms", 0)
|
707 |
+
|
708 |
+
hidden_params = getattr(completion_response, "_hidden_params", None)
|
709 |
+
if hidden_params is not None:
|
710 |
+
custom_llm_provider = hidden_params.get(
|
711 |
+
"custom_llm_provider", custom_llm_provider or None
|
712 |
+
)
|
713 |
+
region_name = hidden_params.get("region_name", region_name)
|
714 |
+
size = hidden_params.get("optional_params", {}).get(
|
715 |
+
"size", "1024-x-1024"
|
716 |
+
) # openai default
|
717 |
+
quality = hidden_params.get("optional_params", {}).get(
|
718 |
+
"quality", "standard"
|
719 |
+
) # openai default
|
720 |
+
n = hidden_params.get("optional_params", {}).get(
|
721 |
+
"n", 1
|
722 |
+
) # openai default
|
723 |
+
else:
|
724 |
+
if model is None:
|
725 |
+
raise ValueError(
|
726 |
+
f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
|
727 |
+
)
|
728 |
+
if len(messages) > 0:
|
729 |
+
prompt_tokens = token_counter(model=model, messages=messages)
|
730 |
+
elif len(prompt) > 0:
|
731 |
+
prompt_tokens = token_counter(model=model, text=prompt)
|
732 |
+
completion_tokens = token_counter(model=model, text=completion)
|
733 |
+
|
734 |
+
if model is None:
|
735 |
+
raise ValueError(
|
736 |
+
f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
|
737 |
+
)
|
738 |
+
if custom_llm_provider is None:
|
739 |
+
try:
|
740 |
+
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
|
741 |
+
model=model
|
742 |
+
) # strip the llm provider from the model name -> for image gen cost calculation
|
743 |
+
except Exception as e:
|
744 |
+
verbose_logger.debug(
|
745 |
+
"litellm.cost_calculator.py::completion_cost() - Error inferring custom_llm_provider - {}".format(
|
746 |
+
str(e)
|
747 |
+
)
|
748 |
+
)
|
749 |
+
if (
|
750 |
+
call_type == CallTypes.image_generation.value
|
751 |
+
or call_type == CallTypes.aimage_generation.value
|
752 |
+
or call_type
|
753 |
+
== PassthroughCallTypes.passthrough_image_generation.value
|
754 |
+
):
|
755 |
+
### IMAGE GENERATION COST CALCULATION ###
|
756 |
+
if custom_llm_provider == "vertex_ai":
|
757 |
+
if isinstance(completion_response, ImageResponse):
|
758 |
+
return vertex_ai_image_cost_calculator(
|
759 |
+
model=model,
|
760 |
+
image_response=completion_response,
|
761 |
+
)
|
762 |
+
elif custom_llm_provider == "bedrock":
|
763 |
+
if isinstance(completion_response, ImageResponse):
|
764 |
+
return bedrock_image_cost_calculator(
|
765 |
+
model=model,
|
766 |
+
size=size,
|
767 |
+
image_response=completion_response,
|
768 |
+
optional_params=optional_params,
|
769 |
+
)
|
770 |
+
raise TypeError(
|
771 |
+
"completion_response must be of type ImageResponse for bedrock image cost calculation"
|
772 |
+
)
|
773 |
+
else:
|
774 |
+
return default_image_cost_calculator(
|
775 |
+
model=model,
|
776 |
+
quality=quality,
|
777 |
+
custom_llm_provider=custom_llm_provider,
|
778 |
+
n=n,
|
779 |
+
size=size,
|
780 |
+
optional_params=optional_params,
|
781 |
+
)
|
782 |
+
elif (
|
783 |
+
call_type == CallTypes.speech.value
|
784 |
+
or call_type == CallTypes.aspeech.value
|
785 |
+
):
|
786 |
+
prompt_characters = litellm.utils._count_characters(text=prompt)
|
787 |
+
elif (
|
788 |
+
call_type == CallTypes.atranscription.value
|
789 |
+
or call_type == CallTypes.transcription.value
|
790 |
+
):
|
791 |
+
audio_transcription_file_duration = getattr(
|
792 |
+
completion_response, "duration", 0.0
|
793 |
+
)
|
794 |
+
elif (
|
795 |
+
call_type == CallTypes.rerank.value
|
796 |
+
or call_type == CallTypes.arerank.value
|
797 |
+
):
|
798 |
+
if completion_response is not None and isinstance(
|
799 |
+
completion_response, RerankResponse
|
800 |
+
):
|
801 |
+
meta_obj = completion_response.meta
|
802 |
+
if meta_obj is not None:
|
803 |
+
billed_units = meta_obj.get("billed_units", {}) or {}
|
804 |
+
else:
|
805 |
+
billed_units = {}
|
806 |
+
|
807 |
+
rerank_billed_units = RerankBilledUnits(
|
808 |
+
search_units=billed_units.get("search_units"),
|
809 |
+
total_tokens=billed_units.get("total_tokens"),
|
810 |
+
)
|
811 |
+
|
812 |
+
search_units = (
|
813 |
+
billed_units.get("search_units") or 1
|
814 |
+
) # cohere charges per request by default.
|
815 |
+
completion_tokens = search_units
|
816 |
+
elif call_type == CallTypes.arealtime.value and isinstance(
|
817 |
+
completion_response, LiteLLMRealtimeStreamLoggingObject
|
818 |
+
):
|
819 |
+
if (
|
820 |
+
cost_per_token_usage_object is None
|
821 |
+
or custom_llm_provider is None
|
822 |
+
):
|
823 |
+
raise ValueError(
|
824 |
+
"usage object and custom_llm_provider must be provided for realtime stream cost calculation. Got cost_per_token_usage_object={}, custom_llm_provider={}".format(
|
825 |
+
cost_per_token_usage_object,
|
826 |
+
custom_llm_provider,
|
827 |
+
)
|
828 |
+
)
|
829 |
+
return handle_realtime_stream_cost_calculation(
|
830 |
+
results=completion_response.results,
|
831 |
+
combined_usage_object=cost_per_token_usage_object,
|
832 |
+
custom_llm_provider=custom_llm_provider,
|
833 |
+
litellm_model_name=model,
|
834 |
+
)
|
835 |
+
# Calculate cost based on prompt_tokens, completion_tokens
|
836 |
+
if (
|
837 |
+
"togethercomputer" in model
|
838 |
+
or "together_ai" in model
|
839 |
+
or custom_llm_provider == "together_ai"
|
840 |
+
):
|
841 |
+
# together ai prices based on size of llm
|
842 |
+
# get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json
|
843 |
+
|
844 |
+
model = get_model_params_and_category(
|
845 |
+
model, call_type=CallTypes(call_type)
|
846 |
+
)
|
847 |
+
|
848 |
+
# replicate llms are calculate based on time for request running
|
849 |
+
# see https://replicate.com/pricing
|
850 |
+
elif (
|
851 |
+
model in litellm.replicate_models or "replicate" in model
|
852 |
+
) and model not in litellm.model_cost:
|
853 |
+
# for unmapped replicate model, default to replicate's time tracking logic
|
854 |
+
return get_replicate_completion_pricing(completion_response, total_time) # type: ignore
|
855 |
+
|
856 |
+
if model is None:
|
857 |
+
raise ValueError(
|
858 |
+
f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
|
859 |
+
)
|
860 |
+
|
861 |
+
if (
|
862 |
+
custom_llm_provider is not None
|
863 |
+
and custom_llm_provider == "vertex_ai"
|
864 |
+
):
|
865 |
+
# Calculate the prompt characters + response characters
|
866 |
+
if len(messages) > 0:
|
867 |
+
prompt_string = litellm.utils.get_formatted_prompt(
|
868 |
+
data={"messages": messages}, call_type="completion"
|
869 |
+
)
|
870 |
+
|
871 |
+
prompt_characters = litellm.utils._count_characters(
|
872 |
+
text=prompt_string
|
873 |
+
)
|
874 |
+
if completion_response is not None and isinstance(
|
875 |
+
completion_response, ModelResponse
|
876 |
+
):
|
877 |
+
completion_string = litellm.utils.get_response_string(
|
878 |
+
response_obj=completion_response
|
879 |
+
)
|
880 |
+
completion_characters = litellm.utils._count_characters(
|
881 |
+
text=completion_string
|
882 |
+
)
|
883 |
+
|
884 |
+
(
|
885 |
+
prompt_tokens_cost_usd_dollar,
|
886 |
+
completion_tokens_cost_usd_dollar,
|
887 |
+
) = cost_per_token(
|
888 |
+
model=model,
|
889 |
+
prompt_tokens=prompt_tokens,
|
890 |
+
completion_tokens=completion_tokens,
|
891 |
+
custom_llm_provider=custom_llm_provider,
|
892 |
+
response_time_ms=total_time,
|
893 |
+
region_name=region_name,
|
894 |
+
custom_cost_per_second=custom_cost_per_second,
|
895 |
+
custom_cost_per_token=custom_cost_per_token,
|
896 |
+
prompt_characters=prompt_characters,
|
897 |
+
completion_characters=completion_characters,
|
898 |
+
cache_creation_input_tokens=cache_creation_input_tokens,
|
899 |
+
cache_read_input_tokens=cache_read_input_tokens,
|
900 |
+
usage_object=cost_per_token_usage_object,
|
901 |
+
call_type=cast(CallTypesLiteral, call_type),
|
902 |
+
audio_transcription_file_duration=audio_transcription_file_duration,
|
903 |
+
rerank_billed_units=rerank_billed_units,
|
904 |
+
)
|
905 |
+
_final_cost = (
|
906 |
+
prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
|
907 |
+
)
|
908 |
+
_final_cost += (
|
909 |
+
StandardBuiltInToolCostTracking.get_cost_for_built_in_tools(
|
910 |
+
model=model,
|
911 |
+
response_object=completion_response,
|
912 |
+
standard_built_in_tools_params=standard_built_in_tools_params,
|
913 |
+
custom_llm_provider=custom_llm_provider,
|
914 |
+
)
|
915 |
+
)
|
916 |
+
return _final_cost
|
917 |
+
except Exception as e:
|
918 |
+
verbose_logger.debug(
|
919 |
+
"litellm.cost_calculator.py::completion_cost() - Error calculating cost for model={} - {}".format(
|
920 |
+
model, str(e)
|
921 |
+
)
|
922 |
+
)
|
923 |
+
if idx == len(potential_model_names) - 1:
|
924 |
+
raise e
|
925 |
+
raise Exception(
|
926 |
+
"Unable to calculat cost for received potential model names - {}".format(
|
927 |
+
potential_model_names
|
928 |
+
)
|
929 |
+
)
|
930 |
+
except Exception as e:
|
931 |
+
raise e
|
932 |
+
|
933 |
+
|
934 |
+
def get_response_cost_from_hidden_params(
|
935 |
+
hidden_params: Union[dict, BaseModel],
|
936 |
+
) -> Optional[float]:
|
937 |
+
if isinstance(hidden_params, BaseModel):
|
938 |
+
_hidden_params_dict = hidden_params.model_dump()
|
939 |
+
else:
|
940 |
+
_hidden_params_dict = hidden_params
|
941 |
+
|
942 |
+
additional_headers = _hidden_params_dict.get("additional_headers", {})
|
943 |
+
if (
|
944 |
+
additional_headers
|
945 |
+
and "llm_provider-x-litellm-response-cost" in additional_headers
|
946 |
+
):
|
947 |
+
response_cost = additional_headers["llm_provider-x-litellm-response-cost"]
|
948 |
+
if response_cost is None:
|
949 |
+
return None
|
950 |
+
return float(additional_headers["llm_provider-x-litellm-response-cost"])
|
951 |
+
return None
|
952 |
+
|
953 |
+
|
954 |
+
def response_cost_calculator(
|
955 |
+
response_object: Union[
|
956 |
+
ModelResponse,
|
957 |
+
EmbeddingResponse,
|
958 |
+
ImageResponse,
|
959 |
+
TranscriptionResponse,
|
960 |
+
TextCompletionResponse,
|
961 |
+
HttpxBinaryResponseContent,
|
962 |
+
RerankResponse,
|
963 |
+
ResponsesAPIResponse,
|
964 |
+
LiteLLMRealtimeStreamLoggingObject,
|
965 |
+
OpenAIModerationResponse,
|
966 |
+
],
|
967 |
+
model: str,
|
968 |
+
custom_llm_provider: Optional[str],
|
969 |
+
call_type: Literal[
|
970 |
+
"embedding",
|
971 |
+
"aembedding",
|
972 |
+
"completion",
|
973 |
+
"acompletion",
|
974 |
+
"atext_completion",
|
975 |
+
"text_completion",
|
976 |
+
"image_generation",
|
977 |
+
"aimage_generation",
|
978 |
+
"moderation",
|
979 |
+
"amoderation",
|
980 |
+
"atranscription",
|
981 |
+
"transcription",
|
982 |
+
"aspeech",
|
983 |
+
"speech",
|
984 |
+
"rerank",
|
985 |
+
"arerank",
|
986 |
+
],
|
987 |
+
optional_params: dict,
|
988 |
+
cache_hit: Optional[bool] = None,
|
989 |
+
base_model: Optional[str] = None,
|
990 |
+
custom_pricing: Optional[bool] = None,
|
991 |
+
prompt: str = "",
|
992 |
+
standard_built_in_tools_params: Optional[StandardBuiltInToolsParams] = None,
|
993 |
+
litellm_model_name: Optional[str] = None,
|
994 |
+
router_model_id: Optional[str] = None,
|
995 |
+
) -> float:
|
996 |
+
"""
|
997 |
+
Returns
|
998 |
+
- float or None: cost of response
|
999 |
+
"""
|
1000 |
+
try:
|
1001 |
+
response_cost: float = 0.0
|
1002 |
+
if cache_hit is not None and cache_hit is True:
|
1003 |
+
response_cost = 0.0
|
1004 |
+
else:
|
1005 |
+
if isinstance(response_object, BaseModel):
|
1006 |
+
response_object._hidden_params["optional_params"] = optional_params
|
1007 |
+
|
1008 |
+
if hasattr(response_object, "_hidden_params"):
|
1009 |
+
provider_response_cost = get_response_cost_from_hidden_params(
|
1010 |
+
response_object._hidden_params
|
1011 |
+
)
|
1012 |
+
if provider_response_cost is not None:
|
1013 |
+
return provider_response_cost
|
1014 |
+
|
1015 |
+
response_cost = completion_cost(
|
1016 |
+
completion_response=response_object,
|
1017 |
+
model=model,
|
1018 |
+
call_type=call_type,
|
1019 |
+
custom_llm_provider=custom_llm_provider,
|
1020 |
+
optional_params=optional_params,
|
1021 |
+
custom_pricing=custom_pricing,
|
1022 |
+
base_model=base_model,
|
1023 |
+
prompt=prompt,
|
1024 |
+
standard_built_in_tools_params=standard_built_in_tools_params,
|
1025 |
+
litellm_model_name=litellm_model_name,
|
1026 |
+
router_model_id=router_model_id,
|
1027 |
+
)
|
1028 |
+
return response_cost
|
1029 |
+
except Exception as e:
|
1030 |
+
raise e
|
1031 |
+
|
1032 |
+
|
1033 |
+
def rerank_cost(
|
1034 |
+
model: str,
|
1035 |
+
custom_llm_provider: Optional[str],
|
1036 |
+
billed_units: Optional[RerankBilledUnits] = None,
|
1037 |
+
) -> Tuple[float, float]:
|
1038 |
+
"""
|
1039 |
+
Returns
|
1040 |
+
- float or None: cost of response OR none if error.
|
1041 |
+
"""
|
1042 |
+
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
|
1043 |
+
model=model, custom_llm_provider=custom_llm_provider
|
1044 |
+
)
|
1045 |
+
|
1046 |
+
try:
|
1047 |
+
config = ProviderConfigManager.get_provider_rerank_config(
|
1048 |
+
model=model,
|
1049 |
+
api_base=None,
|
1050 |
+
present_version_params=[],
|
1051 |
+
provider=LlmProviders(custom_llm_provider),
|
1052 |
+
)
|
1053 |
+
|
1054 |
+
try:
|
1055 |
+
model_info: Optional[ModelInfo] = litellm.get_model_info(
|
1056 |
+
model=model, custom_llm_provider=custom_llm_provider
|
1057 |
+
)
|
1058 |
+
except Exception:
|
1059 |
+
model_info = None
|
1060 |
+
|
1061 |
+
return config.calculate_rerank_cost(
|
1062 |
+
model=model,
|
1063 |
+
custom_llm_provider=custom_llm_provider,
|
1064 |
+
billed_units=billed_units,
|
1065 |
+
model_info=model_info,
|
1066 |
+
)
|
1067 |
+
except Exception as e:
|
1068 |
+
raise e
|
1069 |
+
|
1070 |
+
|
1071 |
+
def transcription_cost(
|
1072 |
+
model: str, custom_llm_provider: Optional[str], duration: float
|
1073 |
+
) -> Tuple[float, float]:
|
1074 |
+
return openai_cost_per_second(
|
1075 |
+
model=model, custom_llm_provider=custom_llm_provider, duration=duration
|
1076 |
+
)
|
1077 |
+
|
1078 |
+
|
1079 |
+
def default_image_cost_calculator(
|
1080 |
+
model: str,
|
1081 |
+
custom_llm_provider: Optional[str] = None,
|
1082 |
+
quality: Optional[str] = None,
|
1083 |
+
n: Optional[int] = 1, # Default to 1 image
|
1084 |
+
size: Optional[str] = "1024-x-1024", # OpenAI default
|
1085 |
+
optional_params: Optional[dict] = None,
|
1086 |
+
) -> float:
|
1087 |
+
"""
|
1088 |
+
Default image cost calculator for image generation
|
1089 |
+
|
1090 |
+
Args:
|
1091 |
+
model (str): Model name
|
1092 |
+
image_response (ImageResponse): Response from image generation
|
1093 |
+
quality (Optional[str]): Image quality setting
|
1094 |
+
n (Optional[int]): Number of images generated
|
1095 |
+
size (Optional[str]): Image size (e.g. "1024x1024" or "1024-x-1024")
|
1096 |
+
|
1097 |
+
Returns:
|
1098 |
+
float: Cost in USD for the image generation
|
1099 |
+
|
1100 |
+
Raises:
|
1101 |
+
Exception: If model pricing not found in cost map
|
1102 |
+
"""
|
1103 |
+
# Standardize size format to use "-x-"
|
1104 |
+
size_str: str = size or "1024-x-1024"
|
1105 |
+
size_str = (
|
1106 |
+
size_str.replace("x", "-x-")
|
1107 |
+
if "x" in size_str and "-x-" not in size_str
|
1108 |
+
else size_str
|
1109 |
+
)
|
1110 |
+
|
1111 |
+
# Parse dimensions
|
1112 |
+
height, width = map(int, size_str.split("-x-"))
|
1113 |
+
|
1114 |
+
# Build model names for cost lookup
|
1115 |
+
base_model_name = f"{size_str}/{model}"
|
1116 |
+
if custom_llm_provider and model.startswith(custom_llm_provider):
|
1117 |
+
base_model_name = (
|
1118 |
+
f"{custom_llm_provider}/{size_str}/{model.replace(custom_llm_provider, '')}"
|
1119 |
+
)
|
1120 |
+
model_name_with_quality = (
|
1121 |
+
f"{quality}/{base_model_name}" if quality else base_model_name
|
1122 |
+
)
|
1123 |
+
|
1124 |
+
# gpt-image-1 models use low, medium, high quality. If user did not specify quality, use medium fot gpt-image-1 model family
|
1125 |
+
model_name_with_v2_quality = (
|
1126 |
+
f"{ImageGenerationRequestQuality.MEDIUM.value}/{base_model_name}"
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
verbose_logger.debug(
|
1130 |
+
f"Looking up cost for models: {model_name_with_quality}, {base_model_name}"
|
1131 |
+
)
|
1132 |
+
|
1133 |
+
model_without_provider = f"{size_str}/{model.split('/')[-1]}"
|
1134 |
+
model_with_quality_without_provider = (
|
1135 |
+
f"{quality}/{model_without_provider}" if quality else model_without_provider
|
1136 |
+
)
|
1137 |
+
|
1138 |
+
# Try model with quality first, fall back to base model name
|
1139 |
+
cost_info: Optional[dict] = None
|
1140 |
+
models_to_check = [
|
1141 |
+
model_name_with_quality,
|
1142 |
+
base_model_name,
|
1143 |
+
model_name_with_v2_quality,
|
1144 |
+
model_with_quality_without_provider,
|
1145 |
+
model_without_provider,
|
1146 |
+
model,
|
1147 |
+
]
|
1148 |
+
for model in models_to_check:
|
1149 |
+
if model in litellm.model_cost:
|
1150 |
+
cost_info = litellm.model_cost[model]
|
1151 |
+
break
|
1152 |
+
if cost_info is None:
|
1153 |
+
raise Exception(
|
1154 |
+
f"Model not found in cost map. Tried checking {models_to_check}"
|
1155 |
+
)
|
1156 |
+
|
1157 |
+
return cost_info["input_cost_per_pixel"] * height * width * n
|
1158 |
+
|
1159 |
+
|
1160 |
+
def batch_cost_calculator(
|
1161 |
+
usage: Usage,
|
1162 |
+
model: str,
|
1163 |
+
custom_llm_provider: Optional[str] = None,
|
1164 |
+
) -> Tuple[float, float]:
|
1165 |
+
"""
|
1166 |
+
Calculate the cost of a batch job
|
1167 |
+
"""
|
1168 |
+
|
1169 |
+
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
|
1170 |
+
model=model, custom_llm_provider=custom_llm_provider
|
1171 |
+
)
|
1172 |
+
|
1173 |
+
verbose_logger.info(
|
1174 |
+
"Calculating batch cost per token. model=%s, custom_llm_provider=%s",
|
1175 |
+
model,
|
1176 |
+
custom_llm_provider,
|
1177 |
+
)
|
1178 |
+
|
1179 |
+
try:
|
1180 |
+
model_info: Optional[ModelInfo] = litellm.get_model_info(
|
1181 |
+
model=model, custom_llm_provider=custom_llm_provider
|
1182 |
+
)
|
1183 |
+
except Exception:
|
1184 |
+
model_info = None
|
1185 |
+
|
1186 |
+
if not model_info:
|
1187 |
+
return 0.0, 0.0
|
1188 |
+
|
1189 |
+
input_cost_per_token_batches = model_info.get("input_cost_per_token_batches")
|
1190 |
+
input_cost_per_token = model_info.get("input_cost_per_token")
|
1191 |
+
output_cost_per_token_batches = model_info.get("output_cost_per_token_batches")
|
1192 |
+
output_cost_per_token = model_info.get("output_cost_per_token")
|
1193 |
+
total_prompt_cost = 0.0
|
1194 |
+
total_completion_cost = 0.0
|
1195 |
+
if input_cost_per_token_batches:
|
1196 |
+
total_prompt_cost = usage.prompt_tokens * input_cost_per_token_batches
|
1197 |
+
elif input_cost_per_token:
|
1198 |
+
total_prompt_cost = (
|
1199 |
+
usage.prompt_tokens * (input_cost_per_token) / 2
|
1200 |
+
) # batch cost is usually half of the regular token cost
|
1201 |
+
if output_cost_per_token_batches:
|
1202 |
+
total_completion_cost = usage.completion_tokens * output_cost_per_token_batches
|
1203 |
+
elif output_cost_per_token:
|
1204 |
+
total_completion_cost = (
|
1205 |
+
usage.completion_tokens * (output_cost_per_token) / 2
|
1206 |
+
) # batch cost is usually half of the regular token cost
|
1207 |
+
|
1208 |
+
return total_prompt_cost, total_completion_cost
|
1209 |
+
|
1210 |
+
|
1211 |
+
class RealtimeAPITokenUsageProcessor:
|
1212 |
+
@staticmethod
|
1213 |
+
def collect_usage_from_realtime_stream_results(
|
1214 |
+
results: OpenAIRealtimeStreamList,
|
1215 |
+
) -> List[Usage]:
|
1216 |
+
"""
|
1217 |
+
Collect usage from realtime stream results
|
1218 |
+
"""
|
1219 |
+
response_done_events: List[OpenAIRealtimeStreamResponseBaseObject] = cast(
|
1220 |
+
List[OpenAIRealtimeStreamResponseBaseObject],
|
1221 |
+
[result for result in results if result["type"] == "response.done"],
|
1222 |
+
)
|
1223 |
+
usage_objects: List[Usage] = []
|
1224 |
+
for result in response_done_events:
|
1225 |
+
usage_object = (
|
1226 |
+
ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(
|
1227 |
+
result["response"].get("usage", {})
|
1228 |
+
)
|
1229 |
+
)
|
1230 |
+
usage_objects.append(usage_object)
|
1231 |
+
return usage_objects
|
1232 |
+
|
1233 |
+
@staticmethod
|
1234 |
+
def combine_usage_objects(usage_objects: List[Usage]) -> Usage:
|
1235 |
+
"""
|
1236 |
+
Combine multiple Usage objects into a single Usage object, checking model keys for nested values.
|
1237 |
+
"""
|
1238 |
+
from litellm.types.utils import (
|
1239 |
+
CompletionTokensDetails,
|
1240 |
+
PromptTokensDetailsWrapper,
|
1241 |
+
Usage,
|
1242 |
+
)
|
1243 |
+
|
1244 |
+
combined = Usage()
|
1245 |
+
|
1246 |
+
# Sum basic token counts
|
1247 |
+
for usage in usage_objects:
|
1248 |
+
# Handle direct attributes by checking what exists in the model
|
1249 |
+
for attr in dir(usage):
|
1250 |
+
if not attr.startswith("_") and not callable(getattr(usage, attr)):
|
1251 |
+
current_val = getattr(combined, attr, 0)
|
1252 |
+
new_val = getattr(usage, attr, 0)
|
1253 |
+
if (
|
1254 |
+
new_val is not None
|
1255 |
+
and isinstance(new_val, (int, float))
|
1256 |
+
and isinstance(current_val, (int, float))
|
1257 |
+
):
|
1258 |
+
setattr(combined, attr, current_val + new_val)
|
1259 |
+
# Handle nested prompt_tokens_details
|
1260 |
+
if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
|
1261 |
+
if (
|
1262 |
+
not hasattr(combined, "prompt_tokens_details")
|
1263 |
+
or not combined.prompt_tokens_details
|
1264 |
+
):
|
1265 |
+
combined.prompt_tokens_details = PromptTokensDetailsWrapper()
|
1266 |
+
|
1267 |
+
# Check what keys exist in the model's prompt_tokens_details
|
1268 |
+
for attr in dir(usage.prompt_tokens_details):
|
1269 |
+
if not attr.startswith("_") and not callable(
|
1270 |
+
getattr(usage.prompt_tokens_details, attr)
|
1271 |
+
):
|
1272 |
+
current_val = getattr(combined.prompt_tokens_details, attr, 0)
|
1273 |
+
new_val = getattr(usage.prompt_tokens_details, attr, 0)
|
1274 |
+
if new_val is not None:
|
1275 |
+
setattr(
|
1276 |
+
combined.prompt_tokens_details,
|
1277 |
+
attr,
|
1278 |
+
current_val + new_val,
|
1279 |
+
)
|
1280 |
+
|
1281 |
+
# Handle nested completion_tokens_details
|
1282 |
+
if (
|
1283 |
+
hasattr(usage, "completion_tokens_details")
|
1284 |
+
and usage.completion_tokens_details
|
1285 |
+
):
|
1286 |
+
if (
|
1287 |
+
not hasattr(combined, "completion_tokens_details")
|
1288 |
+
or not combined.completion_tokens_details
|
1289 |
+
):
|
1290 |
+
combined.completion_tokens_details = CompletionTokensDetails()
|
1291 |
+
|
1292 |
+
# Check what keys exist in the model's completion_tokens_details
|
1293 |
+
for attr in dir(usage.completion_tokens_details):
|
1294 |
+
if not attr.startswith("_") and not callable(
|
1295 |
+
getattr(usage.completion_tokens_details, attr)
|
1296 |
+
):
|
1297 |
+
current_val = getattr(
|
1298 |
+
combined.completion_tokens_details, attr, 0
|
1299 |
+
)
|
1300 |
+
new_val = getattr(usage.completion_tokens_details, attr, 0)
|
1301 |
+
if new_val is not None:
|
1302 |
+
setattr(
|
1303 |
+
combined.completion_tokens_details,
|
1304 |
+
attr,
|
1305 |
+
current_val + new_val,
|
1306 |
+
)
|
1307 |
+
|
1308 |
+
return combined
|
1309 |
+
|
1310 |
+
@staticmethod
|
1311 |
+
def collect_and_combine_usage_from_realtime_stream_results(
|
1312 |
+
results: OpenAIRealtimeStreamList,
|
1313 |
+
) -> Usage:
|
1314 |
+
"""
|
1315 |
+
Collect and combine usage from realtime stream results
|
1316 |
+
"""
|
1317 |
+
collected_usage_objects = (
|
1318 |
+
RealtimeAPITokenUsageProcessor.collect_usage_from_realtime_stream_results(
|
1319 |
+
results
|
1320 |
+
)
|
1321 |
+
)
|
1322 |
+
combined_usage_object = RealtimeAPITokenUsageProcessor.combine_usage_objects(
|
1323 |
+
collected_usage_objects
|
1324 |
+
)
|
1325 |
+
return combined_usage_object
|
1326 |
+
|
1327 |
+
@staticmethod
|
1328 |
+
def create_logging_realtime_object(
|
1329 |
+
usage: Usage, results: OpenAIRealtimeStreamList
|
1330 |
+
) -> LiteLLMRealtimeStreamLoggingObject:
|
1331 |
+
return LiteLLMRealtimeStreamLoggingObject(
|
1332 |
+
usage=usage,
|
1333 |
+
results=results,
|
1334 |
+
)
|
1335 |
+
|
1336 |
+
|
1337 |
+
def handle_realtime_stream_cost_calculation(
|
1338 |
+
results: OpenAIRealtimeStreamList,
|
1339 |
+
combined_usage_object: Usage,
|
1340 |
+
custom_llm_provider: str,
|
1341 |
+
litellm_model_name: str,
|
1342 |
+
) -> float:
|
1343 |
+
"""
|
1344 |
+
Handles the cost calculation for realtime stream responses.
|
1345 |
+
|
1346 |
+
Pick the 'response.done' events. Calculate total cost across all 'response.done' events.
|
1347 |
+
|
1348 |
+
Args:
|
1349 |
+
results: A list of OpenAIRealtimeStreamBaseObject objects
|
1350 |
+
"""
|
1351 |
+
received_model = None
|
1352 |
+
potential_model_names = []
|
1353 |
+
for result in results:
|
1354 |
+
if result["type"] == "session.created":
|
1355 |
+
received_model = cast(OpenAIRealtimeStreamSessionEvents, result)["session"][
|
1356 |
+
"model"
|
1357 |
+
]
|
1358 |
+
potential_model_names.append(received_model)
|
1359 |
+
|
1360 |
+
potential_model_names.append(litellm_model_name)
|
1361 |
+
input_cost_per_token = 0.0
|
1362 |
+
output_cost_per_token = 0.0
|
1363 |
+
|
1364 |
+
for model_name in potential_model_names:
|
1365 |
+
try:
|
1366 |
+
_input_cost_per_token, _output_cost_per_token = generic_cost_per_token(
|
1367 |
+
model=model_name,
|
1368 |
+
usage=combined_usage_object,
|
1369 |
+
custom_llm_provider=custom_llm_provider,
|
1370 |
+
)
|
1371 |
+
except Exception:
|
1372 |
+
continue
|
1373 |
+
input_cost_per_token += _input_cost_per_token
|
1374 |
+
output_cost_per_token += _output_cost_per_token
|
1375 |
+
break # exit if we find a valid model
|
1376 |
+
total_cost = input_cost_per_token + output_cost_per_token
|
1377 |
+
|
1378 |
+
return total_cost
|
litellm/exceptions.py
ADDED
@@ -0,0 +1,809 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# +-----------------------------------------------+
|
2 |
+
# | |
|
3 |
+
# | Give Feedback / Get Help |
|
4 |
+
# | https://github.com/BerriAI/litellm/issues/new |
|
5 |
+
# | |
|
6 |
+
# +-----------------------------------------------+
|
7 |
+
#
|
8 |
+
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
9 |
+
|
10 |
+
## LiteLLM versions of the OpenAI Exception Types
|
11 |
+
|
12 |
+
from typing import Optional
|
13 |
+
|
14 |
+
import httpx
|
15 |
+
import openai
|
16 |
+
|
17 |
+
from litellm.types.utils import LiteLLMCommonStrings
|
18 |
+
|
19 |
+
|
20 |
+
class AuthenticationError(openai.AuthenticationError): # type: ignore
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
message,
|
24 |
+
llm_provider,
|
25 |
+
model,
|
26 |
+
response: Optional[httpx.Response] = None,
|
27 |
+
litellm_debug_info: Optional[str] = None,
|
28 |
+
max_retries: Optional[int] = None,
|
29 |
+
num_retries: Optional[int] = None,
|
30 |
+
):
|
31 |
+
self.status_code = 401
|
32 |
+
self.message = "litellm.AuthenticationError: {}".format(message)
|
33 |
+
self.llm_provider = llm_provider
|
34 |
+
self.model = model
|
35 |
+
self.litellm_debug_info = litellm_debug_info
|
36 |
+
self.max_retries = max_retries
|
37 |
+
self.num_retries = num_retries
|
38 |
+
self.response = response or httpx.Response(
|
39 |
+
status_code=self.status_code,
|
40 |
+
request=httpx.Request(
|
41 |
+
method="GET", url="https://litellm.ai"
|
42 |
+
), # mock request object
|
43 |
+
)
|
44 |
+
super().__init__(
|
45 |
+
self.message, response=self.response, body=None
|
46 |
+
) # Call the base class constructor with the parameters it needs
|
47 |
+
|
48 |
+
def __str__(self):
|
49 |
+
_message = self.message
|
50 |
+
if self.num_retries:
|
51 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
52 |
+
if self.max_retries:
|
53 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
54 |
+
return _message
|
55 |
+
|
56 |
+
def __repr__(self):
|
57 |
+
_message = self.message
|
58 |
+
if self.num_retries:
|
59 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
60 |
+
if self.max_retries:
|
61 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
62 |
+
return _message
|
63 |
+
|
64 |
+
|
65 |
+
# raise when invalid models passed, example gpt-8
|
66 |
+
class NotFoundError(openai.NotFoundError): # type: ignore
|
67 |
+
def __init__(
|
68 |
+
self,
|
69 |
+
message,
|
70 |
+
model,
|
71 |
+
llm_provider,
|
72 |
+
response: Optional[httpx.Response] = None,
|
73 |
+
litellm_debug_info: Optional[str] = None,
|
74 |
+
max_retries: Optional[int] = None,
|
75 |
+
num_retries: Optional[int] = None,
|
76 |
+
):
|
77 |
+
self.status_code = 404
|
78 |
+
self.message = "litellm.NotFoundError: {}".format(message)
|
79 |
+
self.model = model
|
80 |
+
self.llm_provider = llm_provider
|
81 |
+
self.litellm_debug_info = litellm_debug_info
|
82 |
+
self.max_retries = max_retries
|
83 |
+
self.num_retries = num_retries
|
84 |
+
self.response = response or httpx.Response(
|
85 |
+
status_code=self.status_code,
|
86 |
+
request=httpx.Request(
|
87 |
+
method="GET", url="https://litellm.ai"
|
88 |
+
), # mock request object
|
89 |
+
)
|
90 |
+
super().__init__(
|
91 |
+
self.message, response=self.response, body=None
|
92 |
+
) # Call the base class constructor with the parameters it needs
|
93 |
+
|
94 |
+
def __str__(self):
|
95 |
+
_message = self.message
|
96 |
+
if self.num_retries:
|
97 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
98 |
+
if self.max_retries:
|
99 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
100 |
+
return _message
|
101 |
+
|
102 |
+
def __repr__(self):
|
103 |
+
_message = self.message
|
104 |
+
if self.num_retries:
|
105 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
106 |
+
if self.max_retries:
|
107 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
108 |
+
return _message
|
109 |
+
|
110 |
+
|
111 |
+
class BadRequestError(openai.BadRequestError): # type: ignore
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
message,
|
115 |
+
model,
|
116 |
+
llm_provider,
|
117 |
+
response: Optional[httpx.Response] = None,
|
118 |
+
litellm_debug_info: Optional[str] = None,
|
119 |
+
max_retries: Optional[int] = None,
|
120 |
+
num_retries: Optional[int] = None,
|
121 |
+
body: Optional[dict] = None,
|
122 |
+
):
|
123 |
+
self.status_code = 400
|
124 |
+
self.message = "litellm.BadRequestError: {}".format(message)
|
125 |
+
self.model = model
|
126 |
+
self.llm_provider = llm_provider
|
127 |
+
self.litellm_debug_info = litellm_debug_info
|
128 |
+
response = httpx.Response(
|
129 |
+
status_code=self.status_code,
|
130 |
+
request=httpx.Request(
|
131 |
+
method="GET", url="https://litellm.ai"
|
132 |
+
), # mock request object
|
133 |
+
)
|
134 |
+
self.max_retries = max_retries
|
135 |
+
self.num_retries = num_retries
|
136 |
+
super().__init__(
|
137 |
+
self.message, response=response, body=body
|
138 |
+
) # Call the base class constructor with the parameters it needs
|
139 |
+
|
140 |
+
def __str__(self):
|
141 |
+
_message = self.message
|
142 |
+
if self.num_retries:
|
143 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
144 |
+
if self.max_retries:
|
145 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
146 |
+
return _message
|
147 |
+
|
148 |
+
def __repr__(self):
|
149 |
+
_message = self.message
|
150 |
+
if self.num_retries:
|
151 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
152 |
+
if self.max_retries:
|
153 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
154 |
+
return _message
|
155 |
+
|
156 |
+
|
157 |
+
class UnprocessableEntityError(openai.UnprocessableEntityError): # type: ignore
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
message,
|
161 |
+
model,
|
162 |
+
llm_provider,
|
163 |
+
response: httpx.Response,
|
164 |
+
litellm_debug_info: Optional[str] = None,
|
165 |
+
max_retries: Optional[int] = None,
|
166 |
+
num_retries: Optional[int] = None,
|
167 |
+
):
|
168 |
+
self.status_code = 422
|
169 |
+
self.message = "litellm.UnprocessableEntityError: {}".format(message)
|
170 |
+
self.model = model
|
171 |
+
self.llm_provider = llm_provider
|
172 |
+
self.litellm_debug_info = litellm_debug_info
|
173 |
+
self.max_retries = max_retries
|
174 |
+
self.num_retries = num_retries
|
175 |
+
super().__init__(
|
176 |
+
self.message, response=response, body=None
|
177 |
+
) # Call the base class constructor with the parameters it needs
|
178 |
+
|
179 |
+
def __str__(self):
|
180 |
+
_message = self.message
|
181 |
+
if self.num_retries:
|
182 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
183 |
+
if self.max_retries:
|
184 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
185 |
+
return _message
|
186 |
+
|
187 |
+
def __repr__(self):
|
188 |
+
_message = self.message
|
189 |
+
if self.num_retries:
|
190 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
191 |
+
if self.max_retries:
|
192 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
193 |
+
return _message
|
194 |
+
|
195 |
+
|
196 |
+
class Timeout(openai.APITimeoutError): # type: ignore
|
197 |
+
def __init__(
|
198 |
+
self,
|
199 |
+
message,
|
200 |
+
model,
|
201 |
+
llm_provider,
|
202 |
+
litellm_debug_info: Optional[str] = None,
|
203 |
+
max_retries: Optional[int] = None,
|
204 |
+
num_retries: Optional[int] = None,
|
205 |
+
headers: Optional[dict] = None,
|
206 |
+
exception_status_code: Optional[int] = None,
|
207 |
+
):
|
208 |
+
request = httpx.Request(
|
209 |
+
method="POST",
|
210 |
+
url="https://api.openai.com/v1",
|
211 |
+
)
|
212 |
+
super().__init__(
|
213 |
+
request=request
|
214 |
+
) # Call the base class constructor with the parameters it needs
|
215 |
+
self.status_code = exception_status_code or 408
|
216 |
+
self.message = "litellm.Timeout: {}".format(message)
|
217 |
+
self.model = model
|
218 |
+
self.llm_provider = llm_provider
|
219 |
+
self.litellm_debug_info = litellm_debug_info
|
220 |
+
self.max_retries = max_retries
|
221 |
+
self.num_retries = num_retries
|
222 |
+
self.headers = headers
|
223 |
+
|
224 |
+
# custom function to convert to str
|
225 |
+
def __str__(self):
|
226 |
+
_message = self.message
|
227 |
+
if self.num_retries:
|
228 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
229 |
+
if self.max_retries:
|
230 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
231 |
+
return _message
|
232 |
+
|
233 |
+
def __repr__(self):
|
234 |
+
_message = self.message
|
235 |
+
if self.num_retries:
|
236 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
237 |
+
if self.max_retries:
|
238 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
239 |
+
return _message
|
240 |
+
|
241 |
+
|
242 |
+
class PermissionDeniedError(openai.PermissionDeniedError): # type:ignore
|
243 |
+
def __init__(
|
244 |
+
self,
|
245 |
+
message,
|
246 |
+
llm_provider,
|
247 |
+
model,
|
248 |
+
response: httpx.Response,
|
249 |
+
litellm_debug_info: Optional[str] = None,
|
250 |
+
max_retries: Optional[int] = None,
|
251 |
+
num_retries: Optional[int] = None,
|
252 |
+
):
|
253 |
+
self.status_code = 403
|
254 |
+
self.message = "litellm.PermissionDeniedError: {}".format(message)
|
255 |
+
self.llm_provider = llm_provider
|
256 |
+
self.model = model
|
257 |
+
self.litellm_debug_info = litellm_debug_info
|
258 |
+
self.max_retries = max_retries
|
259 |
+
self.num_retries = num_retries
|
260 |
+
super().__init__(
|
261 |
+
self.message, response=response, body=None
|
262 |
+
) # Call the base class constructor with the parameters it needs
|
263 |
+
|
264 |
+
def __str__(self):
|
265 |
+
_message = self.message
|
266 |
+
if self.num_retries:
|
267 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
268 |
+
if self.max_retries:
|
269 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
270 |
+
return _message
|
271 |
+
|
272 |
+
def __repr__(self):
|
273 |
+
_message = self.message
|
274 |
+
if self.num_retries:
|
275 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
276 |
+
if self.max_retries:
|
277 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
278 |
+
return _message
|
279 |
+
|
280 |
+
|
281 |
+
class RateLimitError(openai.RateLimitError): # type: ignore
|
282 |
+
def __init__(
|
283 |
+
self,
|
284 |
+
message,
|
285 |
+
llm_provider,
|
286 |
+
model,
|
287 |
+
response: Optional[httpx.Response] = None,
|
288 |
+
litellm_debug_info: Optional[str] = None,
|
289 |
+
max_retries: Optional[int] = None,
|
290 |
+
num_retries: Optional[int] = None,
|
291 |
+
):
|
292 |
+
self.status_code = 429
|
293 |
+
self.message = "litellm.RateLimitError: {}".format(message)
|
294 |
+
self.llm_provider = llm_provider
|
295 |
+
self.model = model
|
296 |
+
self.litellm_debug_info = litellm_debug_info
|
297 |
+
self.max_retries = max_retries
|
298 |
+
self.num_retries = num_retries
|
299 |
+
_response_headers = (
|
300 |
+
getattr(response, "headers", None) if response is not None else None
|
301 |
+
)
|
302 |
+
self.response = httpx.Response(
|
303 |
+
status_code=429,
|
304 |
+
headers=_response_headers,
|
305 |
+
request=httpx.Request(
|
306 |
+
method="POST",
|
307 |
+
url=" https://cloud.google.com/vertex-ai/",
|
308 |
+
),
|
309 |
+
)
|
310 |
+
super().__init__(
|
311 |
+
self.message, response=self.response, body=None
|
312 |
+
) # Call the base class constructor with the parameters it needs
|
313 |
+
self.code = "429"
|
314 |
+
self.type = "throttling_error"
|
315 |
+
|
316 |
+
def __str__(self):
|
317 |
+
_message = self.message
|
318 |
+
if self.num_retries:
|
319 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
320 |
+
if self.max_retries:
|
321 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
322 |
+
return _message
|
323 |
+
|
324 |
+
def __repr__(self):
|
325 |
+
_message = self.message
|
326 |
+
if self.num_retries:
|
327 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
328 |
+
if self.max_retries:
|
329 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
330 |
+
return _message
|
331 |
+
|
332 |
+
|
333 |
+
# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
|
334 |
+
class ContextWindowExceededError(BadRequestError): # type: ignore
|
335 |
+
def __init__(
|
336 |
+
self,
|
337 |
+
message,
|
338 |
+
model,
|
339 |
+
llm_provider,
|
340 |
+
response: Optional[httpx.Response] = None,
|
341 |
+
litellm_debug_info: Optional[str] = None,
|
342 |
+
):
|
343 |
+
self.status_code = 400
|
344 |
+
self.model = model
|
345 |
+
self.llm_provider = llm_provider
|
346 |
+
self.litellm_debug_info = litellm_debug_info
|
347 |
+
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
348 |
+
self.response = httpx.Response(status_code=400, request=request)
|
349 |
+
super().__init__(
|
350 |
+
message=message,
|
351 |
+
model=self.model, # type: ignore
|
352 |
+
llm_provider=self.llm_provider, # type: ignore
|
353 |
+
response=self.response,
|
354 |
+
litellm_debug_info=self.litellm_debug_info,
|
355 |
+
) # Call the base class constructor with the parameters it needs
|
356 |
+
|
357 |
+
# set after, to make it clear the raised error is a context window exceeded error
|
358 |
+
self.message = "litellm.ContextWindowExceededError: {}".format(self.message)
|
359 |
+
|
360 |
+
def __str__(self):
|
361 |
+
_message = self.message
|
362 |
+
if self.num_retries:
|
363 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
364 |
+
if self.max_retries:
|
365 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
366 |
+
return _message
|
367 |
+
|
368 |
+
def __repr__(self):
|
369 |
+
_message = self.message
|
370 |
+
if self.num_retries:
|
371 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
372 |
+
if self.max_retries:
|
373 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
374 |
+
return _message
|
375 |
+
|
376 |
+
|
377 |
+
# sub class of bad request error - meant to help us catch guardrails-related errors on proxy.
|
378 |
+
class RejectedRequestError(BadRequestError): # type: ignore
|
379 |
+
def __init__(
|
380 |
+
self,
|
381 |
+
message,
|
382 |
+
model,
|
383 |
+
llm_provider,
|
384 |
+
request_data: dict,
|
385 |
+
litellm_debug_info: Optional[str] = None,
|
386 |
+
):
|
387 |
+
self.status_code = 400
|
388 |
+
self.message = "litellm.RejectedRequestError: {}".format(message)
|
389 |
+
self.model = model
|
390 |
+
self.llm_provider = llm_provider
|
391 |
+
self.litellm_debug_info = litellm_debug_info
|
392 |
+
self.request_data = request_data
|
393 |
+
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
394 |
+
response = httpx.Response(status_code=400, request=request)
|
395 |
+
super().__init__(
|
396 |
+
message=self.message,
|
397 |
+
model=self.model, # type: ignore
|
398 |
+
llm_provider=self.llm_provider, # type: ignore
|
399 |
+
response=response,
|
400 |
+
litellm_debug_info=self.litellm_debug_info,
|
401 |
+
) # Call the base class constructor with the parameters it needs
|
402 |
+
|
403 |
+
def __str__(self):
|
404 |
+
_message = self.message
|
405 |
+
if self.num_retries:
|
406 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
407 |
+
if self.max_retries:
|
408 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
409 |
+
return _message
|
410 |
+
|
411 |
+
def __repr__(self):
|
412 |
+
_message = self.message
|
413 |
+
if self.num_retries:
|
414 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
415 |
+
if self.max_retries:
|
416 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
417 |
+
return _message
|
418 |
+
|
419 |
+
|
420 |
+
class ContentPolicyViolationError(BadRequestError): # type: ignore
|
421 |
+
# Error code: 400 - {'error': {'code': 'content_policy_violation', 'message': 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', 'param': None, 'type': 'invalid_request_error'}}
|
422 |
+
def __init__(
|
423 |
+
self,
|
424 |
+
message,
|
425 |
+
model,
|
426 |
+
llm_provider,
|
427 |
+
response: Optional[httpx.Response] = None,
|
428 |
+
litellm_debug_info: Optional[str] = None,
|
429 |
+
):
|
430 |
+
self.status_code = 400
|
431 |
+
self.message = "litellm.ContentPolicyViolationError: {}".format(message)
|
432 |
+
self.model = model
|
433 |
+
self.llm_provider = llm_provider
|
434 |
+
self.litellm_debug_info = litellm_debug_info
|
435 |
+
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
436 |
+
self.response = httpx.Response(status_code=400, request=request)
|
437 |
+
super().__init__(
|
438 |
+
message=self.message,
|
439 |
+
model=self.model, # type: ignore
|
440 |
+
llm_provider=self.llm_provider, # type: ignore
|
441 |
+
response=self.response,
|
442 |
+
litellm_debug_info=self.litellm_debug_info,
|
443 |
+
) # Call the base class constructor with the parameters it needs
|
444 |
+
|
445 |
+
def __str__(self):
|
446 |
+
_message = self.message
|
447 |
+
if self.num_retries:
|
448 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
449 |
+
if self.max_retries:
|
450 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
451 |
+
return _message
|
452 |
+
|
453 |
+
def __repr__(self):
|
454 |
+
_message = self.message
|
455 |
+
if self.num_retries:
|
456 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
457 |
+
if self.max_retries:
|
458 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
459 |
+
return _message
|
460 |
+
|
461 |
+
|
462 |
+
class ServiceUnavailableError(openai.APIStatusError): # type: ignore
|
463 |
+
def __init__(
|
464 |
+
self,
|
465 |
+
message,
|
466 |
+
llm_provider,
|
467 |
+
model,
|
468 |
+
response: Optional[httpx.Response] = None,
|
469 |
+
litellm_debug_info: Optional[str] = None,
|
470 |
+
max_retries: Optional[int] = None,
|
471 |
+
num_retries: Optional[int] = None,
|
472 |
+
):
|
473 |
+
self.status_code = 503
|
474 |
+
self.message = "litellm.ServiceUnavailableError: {}".format(message)
|
475 |
+
self.llm_provider = llm_provider
|
476 |
+
self.model = model
|
477 |
+
self.litellm_debug_info = litellm_debug_info
|
478 |
+
self.max_retries = max_retries
|
479 |
+
self.num_retries = num_retries
|
480 |
+
self.response = httpx.Response(
|
481 |
+
status_code=self.status_code,
|
482 |
+
request=httpx.Request(
|
483 |
+
method="POST",
|
484 |
+
url=" https://cloud.google.com/vertex-ai/",
|
485 |
+
),
|
486 |
+
)
|
487 |
+
super().__init__(
|
488 |
+
self.message, response=self.response, body=None
|
489 |
+
) # Call the base class constructor with the parameters it needs
|
490 |
+
|
491 |
+
def __str__(self):
|
492 |
+
_message = self.message
|
493 |
+
if self.num_retries:
|
494 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
495 |
+
if self.max_retries:
|
496 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
497 |
+
return _message
|
498 |
+
|
499 |
+
def __repr__(self):
|
500 |
+
_message = self.message
|
501 |
+
if self.num_retries:
|
502 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
503 |
+
if self.max_retries:
|
504 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
505 |
+
return _message
|
506 |
+
|
507 |
+
|
508 |
+
class InternalServerError(openai.InternalServerError): # type: ignore
|
509 |
+
def __init__(
|
510 |
+
self,
|
511 |
+
message,
|
512 |
+
llm_provider,
|
513 |
+
model,
|
514 |
+
response: Optional[httpx.Response] = None,
|
515 |
+
litellm_debug_info: Optional[str] = None,
|
516 |
+
max_retries: Optional[int] = None,
|
517 |
+
num_retries: Optional[int] = None,
|
518 |
+
):
|
519 |
+
self.status_code = 500
|
520 |
+
self.message = "litellm.InternalServerError: {}".format(message)
|
521 |
+
self.llm_provider = llm_provider
|
522 |
+
self.model = model
|
523 |
+
self.litellm_debug_info = litellm_debug_info
|
524 |
+
self.max_retries = max_retries
|
525 |
+
self.num_retries = num_retries
|
526 |
+
self.response = httpx.Response(
|
527 |
+
status_code=self.status_code,
|
528 |
+
request=httpx.Request(
|
529 |
+
method="POST",
|
530 |
+
url=" https://cloud.google.com/vertex-ai/",
|
531 |
+
),
|
532 |
+
)
|
533 |
+
super().__init__(
|
534 |
+
self.message, response=self.response, body=None
|
535 |
+
) # Call the base class constructor with the parameters it needs
|
536 |
+
|
537 |
+
def __str__(self):
|
538 |
+
_message = self.message
|
539 |
+
if self.num_retries:
|
540 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
541 |
+
if self.max_retries:
|
542 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
543 |
+
return _message
|
544 |
+
|
545 |
+
def __repr__(self):
|
546 |
+
_message = self.message
|
547 |
+
if self.num_retries:
|
548 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
549 |
+
if self.max_retries:
|
550 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
551 |
+
return _message
|
552 |
+
|
553 |
+
|
554 |
+
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
|
555 |
+
class APIError(openai.APIError): # type: ignore
|
556 |
+
def __init__(
|
557 |
+
self,
|
558 |
+
status_code: int,
|
559 |
+
message,
|
560 |
+
llm_provider,
|
561 |
+
model,
|
562 |
+
request: Optional[httpx.Request] = None,
|
563 |
+
litellm_debug_info: Optional[str] = None,
|
564 |
+
max_retries: Optional[int] = None,
|
565 |
+
num_retries: Optional[int] = None,
|
566 |
+
):
|
567 |
+
self.status_code = status_code
|
568 |
+
self.message = "litellm.APIError: {}".format(message)
|
569 |
+
self.llm_provider = llm_provider
|
570 |
+
self.model = model
|
571 |
+
self.litellm_debug_info = litellm_debug_info
|
572 |
+
self.max_retries = max_retries
|
573 |
+
self.num_retries = num_retries
|
574 |
+
if request is None:
|
575 |
+
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
576 |
+
super().__init__(self.message, request=request, body=None) # type: ignore
|
577 |
+
|
578 |
+
def __str__(self):
|
579 |
+
_message = self.message
|
580 |
+
if self.num_retries:
|
581 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
582 |
+
if self.max_retries:
|
583 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
584 |
+
return _message
|
585 |
+
|
586 |
+
def __repr__(self):
|
587 |
+
_message = self.message
|
588 |
+
if self.num_retries:
|
589 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
590 |
+
if self.max_retries:
|
591 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
592 |
+
return _message
|
593 |
+
|
594 |
+
|
595 |
+
# raised if an invalid request (not get, delete, put, post) is made
|
596 |
+
class APIConnectionError(openai.APIConnectionError): # type: ignore
|
597 |
+
def __init__(
|
598 |
+
self,
|
599 |
+
message,
|
600 |
+
llm_provider,
|
601 |
+
model,
|
602 |
+
request: Optional[httpx.Request] = None,
|
603 |
+
litellm_debug_info: Optional[str] = None,
|
604 |
+
max_retries: Optional[int] = None,
|
605 |
+
num_retries: Optional[int] = None,
|
606 |
+
):
|
607 |
+
self.message = "litellm.APIConnectionError: {}".format(message)
|
608 |
+
self.llm_provider = llm_provider
|
609 |
+
self.model = model
|
610 |
+
self.status_code = 500
|
611 |
+
self.litellm_debug_info = litellm_debug_info
|
612 |
+
self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
613 |
+
self.max_retries = max_retries
|
614 |
+
self.num_retries = num_retries
|
615 |
+
super().__init__(message=self.message, request=self.request)
|
616 |
+
|
617 |
+
def __str__(self):
|
618 |
+
_message = self.message
|
619 |
+
if self.num_retries:
|
620 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
621 |
+
if self.max_retries:
|
622 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
623 |
+
return _message
|
624 |
+
|
625 |
+
def __repr__(self):
|
626 |
+
_message = self.message
|
627 |
+
if self.num_retries:
|
628 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
629 |
+
if self.max_retries:
|
630 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
631 |
+
return _message
|
632 |
+
|
633 |
+
|
634 |
+
# raised if an invalid request (not get, delete, put, post) is made
|
635 |
+
class APIResponseValidationError(openai.APIResponseValidationError): # type: ignore
|
636 |
+
def __init__(
|
637 |
+
self,
|
638 |
+
message,
|
639 |
+
llm_provider,
|
640 |
+
model,
|
641 |
+
litellm_debug_info: Optional[str] = None,
|
642 |
+
max_retries: Optional[int] = None,
|
643 |
+
num_retries: Optional[int] = None,
|
644 |
+
):
|
645 |
+
self.message = "litellm.APIResponseValidationError: {}".format(message)
|
646 |
+
self.llm_provider = llm_provider
|
647 |
+
self.model = model
|
648 |
+
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
649 |
+
response = httpx.Response(status_code=500, request=request)
|
650 |
+
self.litellm_debug_info = litellm_debug_info
|
651 |
+
self.max_retries = max_retries
|
652 |
+
self.num_retries = num_retries
|
653 |
+
super().__init__(response=response, body=None, message=message)
|
654 |
+
|
655 |
+
def __str__(self):
|
656 |
+
_message = self.message
|
657 |
+
if self.num_retries:
|
658 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
659 |
+
if self.max_retries:
|
660 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
661 |
+
return _message
|
662 |
+
|
663 |
+
def __repr__(self):
|
664 |
+
_message = self.message
|
665 |
+
if self.num_retries:
|
666 |
+
_message += f" LiteLLM Retried: {self.num_retries} times"
|
667 |
+
if self.max_retries:
|
668 |
+
_message += f", LiteLLM Max Retries: {self.max_retries}"
|
669 |
+
return _message
|
670 |
+
|
671 |
+
|
672 |
+
class JSONSchemaValidationError(APIResponseValidationError):
|
673 |
+
def __init__(
|
674 |
+
self, model: str, llm_provider: str, raw_response: str, schema: str
|
675 |
+
) -> None:
|
676 |
+
self.raw_response = raw_response
|
677 |
+
self.schema = schema
|
678 |
+
self.model = model
|
679 |
+
message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
|
680 |
+
model, raw_response, schema
|
681 |
+
)
|
682 |
+
self.message = message
|
683 |
+
super().__init__(model=model, message=message, llm_provider=llm_provider)
|
684 |
+
|
685 |
+
|
686 |
+
class OpenAIError(openai.OpenAIError): # type: ignore
|
687 |
+
def __init__(self, original_exception=None):
|
688 |
+
super().__init__()
|
689 |
+
self.llm_provider = "openai"
|
690 |
+
|
691 |
+
|
692 |
+
class UnsupportedParamsError(BadRequestError):
|
693 |
+
def __init__(
|
694 |
+
self,
|
695 |
+
message,
|
696 |
+
llm_provider: Optional[str] = None,
|
697 |
+
model: Optional[str] = None,
|
698 |
+
status_code: int = 400,
|
699 |
+
response: Optional[httpx.Response] = None,
|
700 |
+
litellm_debug_info: Optional[str] = None,
|
701 |
+
max_retries: Optional[int] = None,
|
702 |
+
num_retries: Optional[int] = None,
|
703 |
+
):
|
704 |
+
self.status_code = 400
|
705 |
+
self.message = "litellm.UnsupportedParamsError: {}".format(message)
|
706 |
+
self.model = model
|
707 |
+
self.llm_provider = llm_provider
|
708 |
+
self.litellm_debug_info = litellm_debug_info
|
709 |
+
response = response or httpx.Response(
|
710 |
+
status_code=self.status_code,
|
711 |
+
request=httpx.Request(
|
712 |
+
method="GET", url="https://litellm.ai"
|
713 |
+
), # mock request object
|
714 |
+
)
|
715 |
+
self.max_retries = max_retries
|
716 |
+
self.num_retries = num_retries
|
717 |
+
|
718 |
+
|
719 |
+
LITELLM_EXCEPTION_TYPES = [
|
720 |
+
AuthenticationError,
|
721 |
+
NotFoundError,
|
722 |
+
BadRequestError,
|
723 |
+
UnprocessableEntityError,
|
724 |
+
UnsupportedParamsError,
|
725 |
+
Timeout,
|
726 |
+
PermissionDeniedError,
|
727 |
+
RateLimitError,
|
728 |
+
ContextWindowExceededError,
|
729 |
+
RejectedRequestError,
|
730 |
+
ContentPolicyViolationError,
|
731 |
+
InternalServerError,
|
732 |
+
ServiceUnavailableError,
|
733 |
+
APIError,
|
734 |
+
APIConnectionError,
|
735 |
+
APIResponseValidationError,
|
736 |
+
OpenAIError,
|
737 |
+
InternalServerError,
|
738 |
+
JSONSchemaValidationError,
|
739 |
+
]
|
740 |
+
|
741 |
+
|
742 |
+
class BudgetExceededError(Exception):
|
743 |
+
def __init__(
|
744 |
+
self, current_cost: float, max_budget: float, message: Optional[str] = None
|
745 |
+
):
|
746 |
+
self.current_cost = current_cost
|
747 |
+
self.max_budget = max_budget
|
748 |
+
message = (
|
749 |
+
message
|
750 |
+
or f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
|
751 |
+
)
|
752 |
+
self.message = message
|
753 |
+
super().__init__(message)
|
754 |
+
|
755 |
+
|
756 |
+
## DEPRECATED ##
|
757 |
+
class InvalidRequestError(openai.BadRequestError): # type: ignore
|
758 |
+
def __init__(self, message, model, llm_provider):
|
759 |
+
self.status_code = 400
|
760 |
+
self.message = message
|
761 |
+
self.model = model
|
762 |
+
self.llm_provider = llm_provider
|
763 |
+
self.response = httpx.Response(
|
764 |
+
status_code=400,
|
765 |
+
request=httpx.Request(
|
766 |
+
method="GET", url="https://litellm.ai"
|
767 |
+
), # mock request object
|
768 |
+
)
|
769 |
+
super().__init__(
|
770 |
+
message=self.message, response=self.response, body=None
|
771 |
+
) # Call the base class constructor with the parameters it needs
|
772 |
+
|
773 |
+
|
774 |
+
class MockException(openai.APIError):
|
775 |
+
# used for testing
|
776 |
+
def __init__(
|
777 |
+
self,
|
778 |
+
status_code: int,
|
779 |
+
message,
|
780 |
+
llm_provider,
|
781 |
+
model,
|
782 |
+
request: Optional[httpx.Request] = None,
|
783 |
+
litellm_debug_info: Optional[str] = None,
|
784 |
+
max_retries: Optional[int] = None,
|
785 |
+
num_retries: Optional[int] = None,
|
786 |
+
):
|
787 |
+
self.status_code = status_code
|
788 |
+
self.message = "litellm.MockException: {}".format(message)
|
789 |
+
self.llm_provider = llm_provider
|
790 |
+
self.model = model
|
791 |
+
self.litellm_debug_info = litellm_debug_info
|
792 |
+
self.max_retries = max_retries
|
793 |
+
self.num_retries = num_retries
|
794 |
+
if request is None:
|
795 |
+
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
|
796 |
+
super().__init__(self.message, request=request, body=None) # type: ignore
|
797 |
+
|
798 |
+
|
799 |
+
class LiteLLMUnknownProvider(BadRequestError):
|
800 |
+
def __init__(self, model: str, custom_llm_provider: Optional[str] = None):
|
801 |
+
self.message = LiteLLMCommonStrings.llm_provider_not_provided.value.format(
|
802 |
+
model=model, custom_llm_provider=custom_llm_provider
|
803 |
+
)
|
804 |
+
super().__init__(
|
805 |
+
self.message, model=model, llm_provider=custom_llm_provider, response=None
|
806 |
+
)
|
807 |
+
|
808 |
+
def __str__(self):
|
809 |
+
return self.message
|
litellm/experimental_mcp_client/Readme.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LiteLLM MCP Client
|
2 |
+
|
3 |
+
LiteLLM MCP Client is a client that allows you to use MCP tools with LiteLLM.
|
4 |
+
|
5 |
+
|
6 |
+
|
litellm/experimental_mcp_client/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .tools import call_openai_tool, load_mcp_tools
|
2 |
+
|
3 |
+
__all__ = ["load_mcp_tools", "call_openai_tool"]
|
litellm/experimental_mcp_client/client.py
ADDED
File without changes
|
litellm/experimental_mcp_client/tools.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Dict, List, Literal, Union
|
3 |
+
|
4 |
+
from mcp import ClientSession
|
5 |
+
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
6 |
+
from mcp.types import CallToolResult as MCPCallToolResult
|
7 |
+
from mcp.types import Tool as MCPTool
|
8 |
+
from openai.types.chat import ChatCompletionToolParam
|
9 |
+
from openai.types.shared_params.function_definition import FunctionDefinition
|
10 |
+
|
11 |
+
from litellm.types.utils import ChatCompletionMessageToolCall
|
12 |
+
|
13 |
+
|
14 |
+
########################################################
|
15 |
+
# List MCP Tool functions
|
16 |
+
########################################################
|
17 |
+
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
|
18 |
+
"""Convert an MCP tool to an OpenAI tool."""
|
19 |
+
return ChatCompletionToolParam(
|
20 |
+
type="function",
|
21 |
+
function=FunctionDefinition(
|
22 |
+
name=mcp_tool.name,
|
23 |
+
description=mcp_tool.description or "",
|
24 |
+
parameters=mcp_tool.inputSchema,
|
25 |
+
strict=False,
|
26 |
+
),
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
async def load_mcp_tools(
|
31 |
+
session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
|
32 |
+
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
|
33 |
+
"""
|
34 |
+
Load all available MCP tools
|
35 |
+
|
36 |
+
Args:
|
37 |
+
session: The MCP session to use
|
38 |
+
format: The format to convert the tools to
|
39 |
+
By default, the tools are returned in MCP format.
|
40 |
+
|
41 |
+
If format is set to "openai", the tools are converted to OpenAI API compatible tools.
|
42 |
+
"""
|
43 |
+
tools = await session.list_tools()
|
44 |
+
if format == "openai":
|
45 |
+
return [
|
46 |
+
transform_mcp_tool_to_openai_tool(mcp_tool=tool) for tool in tools.tools
|
47 |
+
]
|
48 |
+
return tools.tools
|
49 |
+
|
50 |
+
|
51 |
+
########################################################
|
52 |
+
# Call MCP Tool functions
|
53 |
+
########################################################
|
54 |
+
|
55 |
+
|
56 |
+
async def call_mcp_tool(
|
57 |
+
session: ClientSession,
|
58 |
+
call_tool_request_params: MCPCallToolRequestParams,
|
59 |
+
) -> MCPCallToolResult:
|
60 |
+
"""Call an MCP tool."""
|
61 |
+
tool_result = await session.call_tool(
|
62 |
+
name=call_tool_request_params.name,
|
63 |
+
arguments=call_tool_request_params.arguments,
|
64 |
+
)
|
65 |
+
return tool_result
|
66 |
+
|
67 |
+
|
68 |
+
def _get_function_arguments(function: FunctionDefinition) -> dict:
|
69 |
+
"""Helper to safely get and parse function arguments."""
|
70 |
+
arguments = function.get("arguments", {})
|
71 |
+
if isinstance(arguments, str):
|
72 |
+
try:
|
73 |
+
arguments = json.loads(arguments)
|
74 |
+
except json.JSONDecodeError:
|
75 |
+
arguments = {}
|
76 |
+
return arguments if isinstance(arguments, dict) else {}
|
77 |
+
|
78 |
+
|
79 |
+
def transform_openai_tool_call_request_to_mcp_tool_call_request(
|
80 |
+
openai_tool: Union[ChatCompletionMessageToolCall, Dict],
|
81 |
+
) -> MCPCallToolRequestParams:
|
82 |
+
"""Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
|
83 |
+
function = openai_tool["function"]
|
84 |
+
return MCPCallToolRequestParams(
|
85 |
+
name=function["name"],
|
86 |
+
arguments=_get_function_arguments(function),
|
87 |
+
)
|
88 |
+
|
89 |
+
|
90 |
+
async def call_openai_tool(
|
91 |
+
session: ClientSession,
|
92 |
+
openai_tool: ChatCompletionMessageToolCall,
|
93 |
+
) -> MCPCallToolResult:
|
94 |
+
"""
|
95 |
+
Call an OpenAI tool using MCP client.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
session: The MCP session to use
|
99 |
+
openai_tool: The OpenAI tool to call. You can get this from the `choices[0].message.tool_calls[0]` of the response from the OpenAI API.
|
100 |
+
Returns:
|
101 |
+
The result of the MCP tool call.
|
102 |
+
"""
|
103 |
+
mcp_tool_call_request_params = (
|
104 |
+
transform_openai_tool_call_request_to_mcp_tool_call_request(
|
105 |
+
openai_tool=openai_tool,
|
106 |
+
)
|
107 |
+
)
|
108 |
+
return await call_mcp_tool(
|
109 |
+
session=session,
|
110 |
+
call_tool_request_params=mcp_tool_call_request_params,
|
111 |
+
)
|
litellm/files/main.py
ADDED
@@ -0,0 +1,891 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Main File for Files API implementation
|
3 |
+
|
4 |
+
https://platform.openai.com/docs/api-reference/files
|
5 |
+
|
6 |
+
"""
|
7 |
+
|
8 |
+
import asyncio
|
9 |
+
import contextvars
|
10 |
+
import os
|
11 |
+
from functools import partial
|
12 |
+
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
|
13 |
+
|
14 |
+
import httpx
|
15 |
+
|
16 |
+
import litellm
|
17 |
+
from litellm import get_secret_str
|
18 |
+
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
19 |
+
from litellm.llms.azure.files.handler import AzureOpenAIFilesAPI
|
20 |
+
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
21 |
+
from litellm.llms.openai.openai import FileDeleted, FileObject, OpenAIFilesAPI
|
22 |
+
from litellm.llms.vertex_ai.files.handler import VertexAIFilesHandler
|
23 |
+
from litellm.types.llms.openai import (
|
24 |
+
CreateFileRequest,
|
25 |
+
FileContentRequest,
|
26 |
+
FileTypes,
|
27 |
+
HttpxBinaryResponseContent,
|
28 |
+
OpenAIFileObject,
|
29 |
+
)
|
30 |
+
from litellm.types.router import *
|
31 |
+
from litellm.types.utils import LlmProviders
|
32 |
+
from litellm.utils import (
|
33 |
+
ProviderConfigManager,
|
34 |
+
client,
|
35 |
+
get_litellm_params,
|
36 |
+
supports_httpx_timeout,
|
37 |
+
)
|
38 |
+
|
39 |
+
base_llm_http_handler = BaseLLMHTTPHandler()
|
40 |
+
|
41 |
+
####### ENVIRONMENT VARIABLES ###################
|
42 |
+
openai_files_instance = OpenAIFilesAPI()
|
43 |
+
azure_files_instance = AzureOpenAIFilesAPI()
|
44 |
+
vertex_ai_files_instance = VertexAIFilesHandler()
|
45 |
+
#################################################
|
46 |
+
|
47 |
+
|
48 |
+
@client
|
49 |
+
async def acreate_file(
|
50 |
+
file: FileTypes,
|
51 |
+
purpose: Literal["assistants", "batch", "fine-tune"],
|
52 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
53 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
54 |
+
extra_body: Optional[Dict[str, str]] = None,
|
55 |
+
**kwargs,
|
56 |
+
) -> OpenAIFileObject:
|
57 |
+
"""
|
58 |
+
Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
|
59 |
+
|
60 |
+
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
61 |
+
"""
|
62 |
+
try:
|
63 |
+
loop = asyncio.get_event_loop()
|
64 |
+
kwargs["acreate_file"] = True
|
65 |
+
|
66 |
+
call_args = {
|
67 |
+
"file": file,
|
68 |
+
"purpose": purpose,
|
69 |
+
"custom_llm_provider": custom_llm_provider,
|
70 |
+
"extra_headers": extra_headers,
|
71 |
+
"extra_body": extra_body,
|
72 |
+
**kwargs,
|
73 |
+
}
|
74 |
+
|
75 |
+
# Use a partial function to pass your keyword arguments
|
76 |
+
func = partial(create_file, **call_args)
|
77 |
+
|
78 |
+
# Add the context to the function
|
79 |
+
ctx = contextvars.copy_context()
|
80 |
+
func_with_context = partial(ctx.run, func)
|
81 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
82 |
+
if asyncio.iscoroutine(init_response):
|
83 |
+
response = await init_response
|
84 |
+
else:
|
85 |
+
response = init_response # type: ignore
|
86 |
+
|
87 |
+
return response
|
88 |
+
except Exception as e:
|
89 |
+
raise e
|
90 |
+
|
91 |
+
|
92 |
+
@client
|
93 |
+
def create_file(
|
94 |
+
file: FileTypes,
|
95 |
+
purpose: Literal["assistants", "batch", "fine-tune"],
|
96 |
+
custom_llm_provider: Optional[Literal["openai", "azure", "vertex_ai"]] = None,
|
97 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
98 |
+
extra_body: Optional[Dict[str, str]] = None,
|
99 |
+
**kwargs,
|
100 |
+
) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
|
101 |
+
"""
|
102 |
+
Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.
|
103 |
+
|
104 |
+
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
105 |
+
|
106 |
+
Specify either provider_list or custom_llm_provider.
|
107 |
+
"""
|
108 |
+
try:
|
109 |
+
_is_async = kwargs.pop("acreate_file", False) is True
|
110 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
111 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
112 |
+
logging_obj = cast(
|
113 |
+
Optional[LiteLLMLoggingObj], kwargs.get("litellm_logging_obj")
|
114 |
+
)
|
115 |
+
if logging_obj is None:
|
116 |
+
raise ValueError("logging_obj is required")
|
117 |
+
client = kwargs.get("client")
|
118 |
+
|
119 |
+
### TIMEOUT LOGIC ###
|
120 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
121 |
+
# set timeout for 10 minutes by default
|
122 |
+
|
123 |
+
if (
|
124 |
+
timeout is not None
|
125 |
+
and isinstance(timeout, httpx.Timeout)
|
126 |
+
and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
|
127 |
+
):
|
128 |
+
read_timeout = timeout.read or 600
|
129 |
+
timeout = read_timeout # default 10 min timeout
|
130 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
131 |
+
timeout = float(timeout) # type: ignore
|
132 |
+
elif timeout is None:
|
133 |
+
timeout = 600.0
|
134 |
+
|
135 |
+
_create_file_request = CreateFileRequest(
|
136 |
+
file=file,
|
137 |
+
purpose=purpose,
|
138 |
+
extra_headers=extra_headers,
|
139 |
+
extra_body=extra_body,
|
140 |
+
)
|
141 |
+
|
142 |
+
provider_config = ProviderConfigManager.get_provider_files_config(
|
143 |
+
model="",
|
144 |
+
provider=LlmProviders(custom_llm_provider),
|
145 |
+
)
|
146 |
+
if provider_config is not None:
|
147 |
+
response = base_llm_http_handler.create_file(
|
148 |
+
provider_config=provider_config,
|
149 |
+
litellm_params=litellm_params_dict,
|
150 |
+
create_file_data=_create_file_request,
|
151 |
+
headers=extra_headers or {},
|
152 |
+
api_base=optional_params.api_base,
|
153 |
+
api_key=optional_params.api_key,
|
154 |
+
logging_obj=logging_obj,
|
155 |
+
_is_async=_is_async,
|
156 |
+
client=client
|
157 |
+
if client is not None
|
158 |
+
and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
|
159 |
+
else None,
|
160 |
+
timeout=timeout,
|
161 |
+
)
|
162 |
+
elif custom_llm_provider == "openai":
|
163 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
164 |
+
api_base = (
|
165 |
+
optional_params.api_base
|
166 |
+
or litellm.api_base
|
167 |
+
or os.getenv("OPENAI_BASE_URL")
|
168 |
+
or os.getenv("OPENAI_API_BASE")
|
169 |
+
or "https://api.openai.com/v1"
|
170 |
+
)
|
171 |
+
organization = (
|
172 |
+
optional_params.organization
|
173 |
+
or litellm.organization
|
174 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
175 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
176 |
+
)
|
177 |
+
# set API KEY
|
178 |
+
api_key = (
|
179 |
+
optional_params.api_key
|
180 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
181 |
+
or litellm.openai_key
|
182 |
+
or os.getenv("OPENAI_API_KEY")
|
183 |
+
)
|
184 |
+
|
185 |
+
response = openai_files_instance.create_file(
|
186 |
+
_is_async=_is_async,
|
187 |
+
api_base=api_base,
|
188 |
+
api_key=api_key,
|
189 |
+
timeout=timeout,
|
190 |
+
max_retries=optional_params.max_retries,
|
191 |
+
organization=organization,
|
192 |
+
create_file_data=_create_file_request,
|
193 |
+
)
|
194 |
+
elif custom_llm_provider == "azure":
|
195 |
+
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
196 |
+
api_version = (
|
197 |
+
optional_params.api_version
|
198 |
+
or litellm.api_version
|
199 |
+
or get_secret_str("AZURE_API_VERSION")
|
200 |
+
) # type: ignore
|
201 |
+
|
202 |
+
api_key = (
|
203 |
+
optional_params.api_key
|
204 |
+
or litellm.api_key
|
205 |
+
or litellm.azure_key
|
206 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
207 |
+
or get_secret_str("AZURE_API_KEY")
|
208 |
+
) # type: ignore
|
209 |
+
|
210 |
+
extra_body = optional_params.get("extra_body", {})
|
211 |
+
if extra_body is not None:
|
212 |
+
extra_body.pop("azure_ad_token", None)
|
213 |
+
else:
|
214 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
215 |
+
|
216 |
+
response = azure_files_instance.create_file(
|
217 |
+
_is_async=_is_async,
|
218 |
+
api_base=api_base,
|
219 |
+
api_key=api_key,
|
220 |
+
api_version=api_version,
|
221 |
+
timeout=timeout,
|
222 |
+
max_retries=optional_params.max_retries,
|
223 |
+
create_file_data=_create_file_request,
|
224 |
+
litellm_params=litellm_params_dict,
|
225 |
+
)
|
226 |
+
elif custom_llm_provider == "vertex_ai":
|
227 |
+
api_base = optional_params.api_base or ""
|
228 |
+
vertex_ai_project = (
|
229 |
+
optional_params.vertex_project
|
230 |
+
or litellm.vertex_project
|
231 |
+
or get_secret_str("VERTEXAI_PROJECT")
|
232 |
+
)
|
233 |
+
vertex_ai_location = (
|
234 |
+
optional_params.vertex_location
|
235 |
+
or litellm.vertex_location
|
236 |
+
or get_secret_str("VERTEXAI_LOCATION")
|
237 |
+
)
|
238 |
+
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
239 |
+
"VERTEXAI_CREDENTIALS"
|
240 |
+
)
|
241 |
+
|
242 |
+
response = vertex_ai_files_instance.create_file(
|
243 |
+
_is_async=_is_async,
|
244 |
+
api_base=api_base,
|
245 |
+
vertex_project=vertex_ai_project,
|
246 |
+
vertex_location=vertex_ai_location,
|
247 |
+
vertex_credentials=vertex_credentials,
|
248 |
+
timeout=timeout,
|
249 |
+
max_retries=optional_params.max_retries,
|
250 |
+
create_file_data=_create_file_request,
|
251 |
+
)
|
252 |
+
else:
|
253 |
+
raise litellm.exceptions.BadRequestError(
|
254 |
+
message="LiteLLM doesn't support {} for 'create_file'. Only ['openai', 'azure', 'vertex_ai'] are supported.".format(
|
255 |
+
custom_llm_provider
|
256 |
+
),
|
257 |
+
model="n/a",
|
258 |
+
llm_provider=custom_llm_provider,
|
259 |
+
response=httpx.Response(
|
260 |
+
status_code=400,
|
261 |
+
content="Unsupported provider",
|
262 |
+
request=httpx.Request(method="create_file", url="https://github.com/BerriAI/litellm"), # type: ignore
|
263 |
+
),
|
264 |
+
)
|
265 |
+
return response
|
266 |
+
except Exception as e:
|
267 |
+
raise e
|
268 |
+
|
269 |
+
|
270 |
+
async def afile_retrieve(
|
271 |
+
file_id: str,
|
272 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
273 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
274 |
+
extra_body: Optional[Dict[str, str]] = None,
|
275 |
+
**kwargs,
|
276 |
+
):
|
277 |
+
"""
|
278 |
+
Async: Get file contents
|
279 |
+
|
280 |
+
LiteLLM Equivalent of GET https://api.openai.com/v1/files
|
281 |
+
"""
|
282 |
+
try:
|
283 |
+
loop = asyncio.get_event_loop()
|
284 |
+
kwargs["is_async"] = True
|
285 |
+
|
286 |
+
# Use a partial function to pass your keyword arguments
|
287 |
+
func = partial(
|
288 |
+
file_retrieve,
|
289 |
+
file_id,
|
290 |
+
custom_llm_provider,
|
291 |
+
extra_headers,
|
292 |
+
extra_body,
|
293 |
+
**kwargs,
|
294 |
+
)
|
295 |
+
|
296 |
+
# Add the context to the function
|
297 |
+
ctx = contextvars.copy_context()
|
298 |
+
func_with_context = partial(ctx.run, func)
|
299 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
300 |
+
if asyncio.iscoroutine(init_response):
|
301 |
+
response = await init_response
|
302 |
+
else:
|
303 |
+
response = init_response
|
304 |
+
|
305 |
+
return response
|
306 |
+
except Exception as e:
|
307 |
+
raise e
|
308 |
+
|
309 |
+
|
310 |
+
def file_retrieve(
|
311 |
+
file_id: str,
|
312 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
313 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
314 |
+
extra_body: Optional[Dict[str, str]] = None,
|
315 |
+
**kwargs,
|
316 |
+
) -> FileObject:
|
317 |
+
"""
|
318 |
+
Returns the contents of the specified file.
|
319 |
+
|
320 |
+
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
321 |
+
"""
|
322 |
+
try:
|
323 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
324 |
+
### TIMEOUT LOGIC ###
|
325 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
326 |
+
# set timeout for 10 minutes by default
|
327 |
+
|
328 |
+
if (
|
329 |
+
timeout is not None
|
330 |
+
and isinstance(timeout, httpx.Timeout)
|
331 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
332 |
+
):
|
333 |
+
read_timeout = timeout.read or 600
|
334 |
+
timeout = read_timeout # default 10 min timeout
|
335 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
336 |
+
timeout = float(timeout) # type: ignore
|
337 |
+
elif timeout is None:
|
338 |
+
timeout = 600.0
|
339 |
+
|
340 |
+
_is_async = kwargs.pop("is_async", False) is True
|
341 |
+
|
342 |
+
if custom_llm_provider == "openai":
|
343 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
344 |
+
api_base = (
|
345 |
+
optional_params.api_base
|
346 |
+
or litellm.api_base
|
347 |
+
or os.getenv("OPENAI_BASE_URL")
|
348 |
+
or os.getenv("OPENAI_API_BASE")
|
349 |
+
or "https://api.openai.com/v1"
|
350 |
+
)
|
351 |
+
organization = (
|
352 |
+
optional_params.organization
|
353 |
+
or litellm.organization
|
354 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
355 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
356 |
+
)
|
357 |
+
# set API KEY
|
358 |
+
api_key = (
|
359 |
+
optional_params.api_key
|
360 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
361 |
+
or litellm.openai_key
|
362 |
+
or os.getenv("OPENAI_API_KEY")
|
363 |
+
)
|
364 |
+
|
365 |
+
response = openai_files_instance.retrieve_file(
|
366 |
+
file_id=file_id,
|
367 |
+
_is_async=_is_async,
|
368 |
+
api_base=api_base,
|
369 |
+
api_key=api_key,
|
370 |
+
timeout=timeout,
|
371 |
+
max_retries=optional_params.max_retries,
|
372 |
+
organization=organization,
|
373 |
+
)
|
374 |
+
elif custom_llm_provider == "azure":
|
375 |
+
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
376 |
+
api_version = (
|
377 |
+
optional_params.api_version
|
378 |
+
or litellm.api_version
|
379 |
+
or get_secret_str("AZURE_API_VERSION")
|
380 |
+
) # type: ignore
|
381 |
+
|
382 |
+
api_key = (
|
383 |
+
optional_params.api_key
|
384 |
+
or litellm.api_key
|
385 |
+
or litellm.azure_key
|
386 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
387 |
+
or get_secret_str("AZURE_API_KEY")
|
388 |
+
) # type: ignore
|
389 |
+
|
390 |
+
extra_body = optional_params.get("extra_body", {})
|
391 |
+
if extra_body is not None:
|
392 |
+
extra_body.pop("azure_ad_token", None)
|
393 |
+
else:
|
394 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
395 |
+
|
396 |
+
response = azure_files_instance.retrieve_file(
|
397 |
+
_is_async=_is_async,
|
398 |
+
api_base=api_base,
|
399 |
+
api_key=api_key,
|
400 |
+
api_version=api_version,
|
401 |
+
timeout=timeout,
|
402 |
+
max_retries=optional_params.max_retries,
|
403 |
+
file_id=file_id,
|
404 |
+
)
|
405 |
+
else:
|
406 |
+
raise litellm.exceptions.BadRequestError(
|
407 |
+
message="LiteLLM doesn't support {} for 'file_retrieve'. Only 'openai' and 'azure' are supported.".format(
|
408 |
+
custom_llm_provider
|
409 |
+
),
|
410 |
+
model="n/a",
|
411 |
+
llm_provider=custom_llm_provider,
|
412 |
+
response=httpx.Response(
|
413 |
+
status_code=400,
|
414 |
+
content="Unsupported provider",
|
415 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
416 |
+
),
|
417 |
+
)
|
418 |
+
return cast(FileObject, response)
|
419 |
+
except Exception as e:
|
420 |
+
raise e
|
421 |
+
|
422 |
+
|
423 |
+
# Delete file
|
424 |
+
async def afile_delete(
|
425 |
+
file_id: str,
|
426 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
427 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
428 |
+
extra_body: Optional[Dict[str, str]] = None,
|
429 |
+
**kwargs,
|
430 |
+
) -> Coroutine[Any, Any, FileObject]:
|
431 |
+
"""
|
432 |
+
Async: Delete file
|
433 |
+
|
434 |
+
LiteLLM Equivalent of DELETE https://api.openai.com/v1/files
|
435 |
+
"""
|
436 |
+
try:
|
437 |
+
loop = asyncio.get_event_loop()
|
438 |
+
kwargs["is_async"] = True
|
439 |
+
|
440 |
+
# Use a partial function to pass your keyword arguments
|
441 |
+
func = partial(
|
442 |
+
file_delete,
|
443 |
+
file_id,
|
444 |
+
custom_llm_provider,
|
445 |
+
extra_headers,
|
446 |
+
extra_body,
|
447 |
+
**kwargs,
|
448 |
+
)
|
449 |
+
|
450 |
+
# Add the context to the function
|
451 |
+
ctx = contextvars.copy_context()
|
452 |
+
func_with_context = partial(ctx.run, func)
|
453 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
454 |
+
if asyncio.iscoroutine(init_response):
|
455 |
+
response = await init_response
|
456 |
+
else:
|
457 |
+
response = init_response # type: ignore
|
458 |
+
|
459 |
+
return cast(FileDeleted, response) # type: ignore
|
460 |
+
except Exception as e:
|
461 |
+
raise e
|
462 |
+
|
463 |
+
|
464 |
+
def file_delete(
|
465 |
+
file_id: str,
|
466 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
467 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
468 |
+
extra_body: Optional[Dict[str, str]] = None,
|
469 |
+
**kwargs,
|
470 |
+
) -> FileDeleted:
|
471 |
+
"""
|
472 |
+
Delete file
|
473 |
+
|
474 |
+
LiteLLM Equivalent of DELETE https://api.openai.com/v1/files
|
475 |
+
"""
|
476 |
+
try:
|
477 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
478 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
479 |
+
### TIMEOUT LOGIC ###
|
480 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
481 |
+
# set timeout for 10 minutes by default
|
482 |
+
client = kwargs.get("client")
|
483 |
+
|
484 |
+
if (
|
485 |
+
timeout is not None
|
486 |
+
and isinstance(timeout, httpx.Timeout)
|
487 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
488 |
+
):
|
489 |
+
read_timeout = timeout.read or 600
|
490 |
+
timeout = read_timeout # default 10 min timeout
|
491 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
492 |
+
timeout = float(timeout) # type: ignore
|
493 |
+
elif timeout is None:
|
494 |
+
timeout = 600.0
|
495 |
+
_is_async = kwargs.pop("is_async", False) is True
|
496 |
+
if custom_llm_provider == "openai":
|
497 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
498 |
+
api_base = (
|
499 |
+
optional_params.api_base
|
500 |
+
or litellm.api_base
|
501 |
+
or os.getenv("OPENAI_BASE_URL")
|
502 |
+
or os.getenv("OPENAI_API_BASE")
|
503 |
+
or "https://api.openai.com/v1"
|
504 |
+
)
|
505 |
+
organization = (
|
506 |
+
optional_params.organization
|
507 |
+
or litellm.organization
|
508 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
509 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
510 |
+
)
|
511 |
+
# set API KEY
|
512 |
+
api_key = (
|
513 |
+
optional_params.api_key
|
514 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
515 |
+
or litellm.openai_key
|
516 |
+
or os.getenv("OPENAI_API_KEY")
|
517 |
+
)
|
518 |
+
response = openai_files_instance.delete_file(
|
519 |
+
file_id=file_id,
|
520 |
+
_is_async=_is_async,
|
521 |
+
api_base=api_base,
|
522 |
+
api_key=api_key,
|
523 |
+
timeout=timeout,
|
524 |
+
max_retries=optional_params.max_retries,
|
525 |
+
organization=organization,
|
526 |
+
)
|
527 |
+
elif custom_llm_provider == "azure":
|
528 |
+
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
529 |
+
api_version = (
|
530 |
+
optional_params.api_version
|
531 |
+
or litellm.api_version
|
532 |
+
or get_secret_str("AZURE_API_VERSION")
|
533 |
+
) # type: ignore
|
534 |
+
|
535 |
+
api_key = (
|
536 |
+
optional_params.api_key
|
537 |
+
or litellm.api_key
|
538 |
+
or litellm.azure_key
|
539 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
540 |
+
or get_secret_str("AZURE_API_KEY")
|
541 |
+
) # type: ignore
|
542 |
+
|
543 |
+
extra_body = optional_params.get("extra_body", {})
|
544 |
+
if extra_body is not None:
|
545 |
+
extra_body.pop("azure_ad_token", None)
|
546 |
+
else:
|
547 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
548 |
+
|
549 |
+
response = azure_files_instance.delete_file(
|
550 |
+
_is_async=_is_async,
|
551 |
+
api_base=api_base,
|
552 |
+
api_key=api_key,
|
553 |
+
api_version=api_version,
|
554 |
+
timeout=timeout,
|
555 |
+
max_retries=optional_params.max_retries,
|
556 |
+
file_id=file_id,
|
557 |
+
client=client,
|
558 |
+
litellm_params=litellm_params_dict,
|
559 |
+
)
|
560 |
+
else:
|
561 |
+
raise litellm.exceptions.BadRequestError(
|
562 |
+
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
563 |
+
custom_llm_provider
|
564 |
+
),
|
565 |
+
model="n/a",
|
566 |
+
llm_provider=custom_llm_provider,
|
567 |
+
response=httpx.Response(
|
568 |
+
status_code=400,
|
569 |
+
content="Unsupported provider",
|
570 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
571 |
+
),
|
572 |
+
)
|
573 |
+
return cast(FileDeleted, response)
|
574 |
+
except Exception as e:
|
575 |
+
raise e
|
576 |
+
|
577 |
+
|
578 |
+
# List files
|
579 |
+
async def afile_list(
|
580 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
581 |
+
purpose: Optional[str] = None,
|
582 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
583 |
+
extra_body: Optional[Dict[str, str]] = None,
|
584 |
+
**kwargs,
|
585 |
+
):
|
586 |
+
"""
|
587 |
+
Async: List files
|
588 |
+
|
589 |
+
LiteLLM Equivalent of GET https://api.openai.com/v1/files
|
590 |
+
"""
|
591 |
+
try:
|
592 |
+
loop = asyncio.get_event_loop()
|
593 |
+
kwargs["is_async"] = True
|
594 |
+
|
595 |
+
# Use a partial function to pass your keyword arguments
|
596 |
+
func = partial(
|
597 |
+
file_list,
|
598 |
+
custom_llm_provider,
|
599 |
+
purpose,
|
600 |
+
extra_headers,
|
601 |
+
extra_body,
|
602 |
+
**kwargs,
|
603 |
+
)
|
604 |
+
|
605 |
+
# Add the context to the function
|
606 |
+
ctx = contextvars.copy_context()
|
607 |
+
func_with_context = partial(ctx.run, func)
|
608 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
609 |
+
if asyncio.iscoroutine(init_response):
|
610 |
+
response = await init_response
|
611 |
+
else:
|
612 |
+
response = init_response # type: ignore
|
613 |
+
|
614 |
+
return response
|
615 |
+
except Exception as e:
|
616 |
+
raise e
|
617 |
+
|
618 |
+
|
619 |
+
def file_list(
|
620 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
621 |
+
purpose: Optional[str] = None,
|
622 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
623 |
+
extra_body: Optional[Dict[str, str]] = None,
|
624 |
+
**kwargs,
|
625 |
+
):
|
626 |
+
"""
|
627 |
+
List files
|
628 |
+
|
629 |
+
LiteLLM Equivalent of GET https://api.openai.com/v1/files
|
630 |
+
"""
|
631 |
+
try:
|
632 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
633 |
+
### TIMEOUT LOGIC ###
|
634 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
635 |
+
# set timeout for 10 minutes by default
|
636 |
+
|
637 |
+
if (
|
638 |
+
timeout is not None
|
639 |
+
and isinstance(timeout, httpx.Timeout)
|
640 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
641 |
+
):
|
642 |
+
read_timeout = timeout.read or 600
|
643 |
+
timeout = read_timeout # default 10 min timeout
|
644 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
645 |
+
timeout = float(timeout) # type: ignore
|
646 |
+
elif timeout is None:
|
647 |
+
timeout = 600.0
|
648 |
+
|
649 |
+
_is_async = kwargs.pop("is_async", False) is True
|
650 |
+
if custom_llm_provider == "openai":
|
651 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
652 |
+
api_base = (
|
653 |
+
optional_params.api_base
|
654 |
+
or litellm.api_base
|
655 |
+
or os.getenv("OPENAI_BASE_URL")
|
656 |
+
or os.getenv("OPENAI_API_BASE")
|
657 |
+
or "https://api.openai.com/v1"
|
658 |
+
)
|
659 |
+
organization = (
|
660 |
+
optional_params.organization
|
661 |
+
or litellm.organization
|
662 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
663 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
664 |
+
)
|
665 |
+
# set API KEY
|
666 |
+
api_key = (
|
667 |
+
optional_params.api_key
|
668 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
669 |
+
or litellm.openai_key
|
670 |
+
or os.getenv("OPENAI_API_KEY")
|
671 |
+
)
|
672 |
+
|
673 |
+
response = openai_files_instance.list_files(
|
674 |
+
purpose=purpose,
|
675 |
+
_is_async=_is_async,
|
676 |
+
api_base=api_base,
|
677 |
+
api_key=api_key,
|
678 |
+
timeout=timeout,
|
679 |
+
max_retries=optional_params.max_retries,
|
680 |
+
organization=organization,
|
681 |
+
)
|
682 |
+
elif custom_llm_provider == "azure":
|
683 |
+
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
684 |
+
api_version = (
|
685 |
+
optional_params.api_version
|
686 |
+
or litellm.api_version
|
687 |
+
or get_secret_str("AZURE_API_VERSION")
|
688 |
+
) # type: ignore
|
689 |
+
|
690 |
+
api_key = (
|
691 |
+
optional_params.api_key
|
692 |
+
or litellm.api_key
|
693 |
+
or litellm.azure_key
|
694 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
695 |
+
or get_secret_str("AZURE_API_KEY")
|
696 |
+
) # type: ignore
|
697 |
+
|
698 |
+
extra_body = optional_params.get("extra_body", {})
|
699 |
+
if extra_body is not None:
|
700 |
+
extra_body.pop("azure_ad_token", None)
|
701 |
+
else:
|
702 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
703 |
+
|
704 |
+
response = azure_files_instance.list_files(
|
705 |
+
_is_async=_is_async,
|
706 |
+
api_base=api_base,
|
707 |
+
api_key=api_key,
|
708 |
+
api_version=api_version,
|
709 |
+
timeout=timeout,
|
710 |
+
max_retries=optional_params.max_retries,
|
711 |
+
purpose=purpose,
|
712 |
+
)
|
713 |
+
else:
|
714 |
+
raise litellm.exceptions.BadRequestError(
|
715 |
+
message="LiteLLM doesn't support {} for 'file_list'. Only 'openai' and 'azure' are supported.".format(
|
716 |
+
custom_llm_provider
|
717 |
+
),
|
718 |
+
model="n/a",
|
719 |
+
llm_provider=custom_llm_provider,
|
720 |
+
response=httpx.Response(
|
721 |
+
status_code=400,
|
722 |
+
content="Unsupported provider",
|
723 |
+
request=httpx.Request(method="file_list", url="https://github.com/BerriAI/litellm"), # type: ignore
|
724 |
+
),
|
725 |
+
)
|
726 |
+
return response
|
727 |
+
except Exception as e:
|
728 |
+
raise e
|
729 |
+
|
730 |
+
|
731 |
+
async def afile_content(
|
732 |
+
file_id: str,
|
733 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
734 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
735 |
+
extra_body: Optional[Dict[str, str]] = None,
|
736 |
+
**kwargs,
|
737 |
+
) -> HttpxBinaryResponseContent:
|
738 |
+
"""
|
739 |
+
Async: Get file contents
|
740 |
+
|
741 |
+
LiteLLM Equivalent of GET https://api.openai.com/v1/files
|
742 |
+
"""
|
743 |
+
try:
|
744 |
+
loop = asyncio.get_event_loop()
|
745 |
+
kwargs["afile_content"] = True
|
746 |
+
|
747 |
+
# Use a partial function to pass your keyword arguments
|
748 |
+
func = partial(
|
749 |
+
file_content,
|
750 |
+
file_id,
|
751 |
+
custom_llm_provider,
|
752 |
+
extra_headers,
|
753 |
+
extra_body,
|
754 |
+
**kwargs,
|
755 |
+
)
|
756 |
+
|
757 |
+
# Add the context to the function
|
758 |
+
ctx = contextvars.copy_context()
|
759 |
+
func_with_context = partial(ctx.run, func)
|
760 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
761 |
+
if asyncio.iscoroutine(init_response):
|
762 |
+
response = await init_response
|
763 |
+
else:
|
764 |
+
response = init_response # type: ignore
|
765 |
+
|
766 |
+
return response
|
767 |
+
except Exception as e:
|
768 |
+
raise e
|
769 |
+
|
770 |
+
|
771 |
+
def file_content(
|
772 |
+
file_id: str,
|
773 |
+
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
774 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
775 |
+
extra_body: Optional[Dict[str, str]] = None,
|
776 |
+
**kwargs,
|
777 |
+
) -> Union[HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]]:
|
778 |
+
"""
|
779 |
+
Returns the contents of the specified file.
|
780 |
+
|
781 |
+
LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files
|
782 |
+
"""
|
783 |
+
try:
|
784 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
785 |
+
litellm_params_dict = get_litellm_params(**kwargs)
|
786 |
+
### TIMEOUT LOGIC ###
|
787 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
788 |
+
client = kwargs.get("client")
|
789 |
+
# set timeout for 10 minutes by default
|
790 |
+
|
791 |
+
if (
|
792 |
+
timeout is not None
|
793 |
+
and isinstance(timeout, httpx.Timeout)
|
794 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
795 |
+
):
|
796 |
+
read_timeout = timeout.read or 600
|
797 |
+
timeout = read_timeout # default 10 min timeout
|
798 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
799 |
+
timeout = float(timeout) # type: ignore
|
800 |
+
elif timeout is None:
|
801 |
+
timeout = 600.0
|
802 |
+
|
803 |
+
_file_content_request = FileContentRequest(
|
804 |
+
file_id=file_id,
|
805 |
+
extra_headers=extra_headers,
|
806 |
+
extra_body=extra_body,
|
807 |
+
)
|
808 |
+
|
809 |
+
_is_async = kwargs.pop("afile_content", False) is True
|
810 |
+
|
811 |
+
if custom_llm_provider == "openai":
|
812 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
813 |
+
api_base = (
|
814 |
+
optional_params.api_base
|
815 |
+
or litellm.api_base
|
816 |
+
or os.getenv("OPENAI_BASE_URL")
|
817 |
+
or os.getenv("OPENAI_API_BASE")
|
818 |
+
or "https://api.openai.com/v1"
|
819 |
+
)
|
820 |
+
organization = (
|
821 |
+
optional_params.organization
|
822 |
+
or litellm.organization
|
823 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
824 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
825 |
+
)
|
826 |
+
# set API KEY
|
827 |
+
api_key = (
|
828 |
+
optional_params.api_key
|
829 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
830 |
+
or litellm.openai_key
|
831 |
+
or os.getenv("OPENAI_API_KEY")
|
832 |
+
)
|
833 |
+
|
834 |
+
response = openai_files_instance.file_content(
|
835 |
+
_is_async=_is_async,
|
836 |
+
file_content_request=_file_content_request,
|
837 |
+
api_base=api_base,
|
838 |
+
api_key=api_key,
|
839 |
+
timeout=timeout,
|
840 |
+
max_retries=optional_params.max_retries,
|
841 |
+
organization=organization,
|
842 |
+
)
|
843 |
+
elif custom_llm_provider == "azure":
|
844 |
+
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
845 |
+
api_version = (
|
846 |
+
optional_params.api_version
|
847 |
+
or litellm.api_version
|
848 |
+
or get_secret_str("AZURE_API_VERSION")
|
849 |
+
) # type: ignore
|
850 |
+
|
851 |
+
api_key = (
|
852 |
+
optional_params.api_key
|
853 |
+
or litellm.api_key
|
854 |
+
or litellm.azure_key
|
855 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
856 |
+
or get_secret_str("AZURE_API_KEY")
|
857 |
+
) # type: ignore
|
858 |
+
|
859 |
+
extra_body = optional_params.get("extra_body", {})
|
860 |
+
if extra_body is not None:
|
861 |
+
extra_body.pop("azure_ad_token", None)
|
862 |
+
else:
|
863 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
864 |
+
|
865 |
+
response = azure_files_instance.file_content(
|
866 |
+
_is_async=_is_async,
|
867 |
+
api_base=api_base,
|
868 |
+
api_key=api_key,
|
869 |
+
api_version=api_version,
|
870 |
+
timeout=timeout,
|
871 |
+
max_retries=optional_params.max_retries,
|
872 |
+
file_content_request=_file_content_request,
|
873 |
+
client=client,
|
874 |
+
litellm_params=litellm_params_dict,
|
875 |
+
)
|
876 |
+
else:
|
877 |
+
raise litellm.exceptions.BadRequestError(
|
878 |
+
message="LiteLLM doesn't support {} for 'custom_llm_provider'. Supported providers are 'openai', 'azure', 'vertex_ai'.".format(
|
879 |
+
custom_llm_provider
|
880 |
+
),
|
881 |
+
model="n/a",
|
882 |
+
llm_provider=custom_llm_provider,
|
883 |
+
response=httpx.Response(
|
884 |
+
status_code=400,
|
885 |
+
content="Unsupported provider",
|
886 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
887 |
+
),
|
888 |
+
)
|
889 |
+
return response
|
890 |
+
except Exception as e:
|
891 |
+
raise e
|
litellm/fine_tuning/main.py
ADDED
@@ -0,0 +1,761 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Main File for Fine Tuning API implementation
|
3 |
+
|
4 |
+
https://platform.openai.com/docs/api-reference/fine-tuning
|
5 |
+
|
6 |
+
- fine_tuning.jobs.create()
|
7 |
+
- fine_tuning.jobs.list()
|
8 |
+
- client.fine_tuning.jobs.list_events()
|
9 |
+
"""
|
10 |
+
|
11 |
+
import asyncio
|
12 |
+
import contextvars
|
13 |
+
import os
|
14 |
+
from functools import partial
|
15 |
+
from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
16 |
+
|
17 |
+
import httpx
|
18 |
+
|
19 |
+
import litellm
|
20 |
+
from litellm._logging import verbose_logger
|
21 |
+
from litellm.llms.azure.fine_tuning.handler import AzureOpenAIFineTuningAPI
|
22 |
+
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
|
23 |
+
from litellm.llms.vertex_ai.fine_tuning.handler import VertexFineTuningAPI
|
24 |
+
from litellm.secret_managers.main import get_secret_str
|
25 |
+
from litellm.types.llms.openai import (
|
26 |
+
FineTuningJob,
|
27 |
+
FineTuningJobCreate,
|
28 |
+
Hyperparameters,
|
29 |
+
)
|
30 |
+
from litellm.types.router import *
|
31 |
+
from litellm.utils import client, supports_httpx_timeout
|
32 |
+
|
33 |
+
####### ENVIRONMENT VARIABLES ###################
|
34 |
+
openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
|
35 |
+
azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
|
36 |
+
vertex_fine_tuning_apis_instance = VertexFineTuningAPI()
|
37 |
+
#################################################
|
38 |
+
|
39 |
+
|
40 |
+
@client
|
41 |
+
async def acreate_fine_tuning_job(
|
42 |
+
model: str,
|
43 |
+
training_file: str,
|
44 |
+
hyperparameters: Optional[dict] = {},
|
45 |
+
suffix: Optional[str] = None,
|
46 |
+
validation_file: Optional[str] = None,
|
47 |
+
integrations: Optional[List[str]] = None,
|
48 |
+
seed: Optional[int] = None,
|
49 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
50 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
51 |
+
extra_body: Optional[Dict[str, str]] = None,
|
52 |
+
**kwargs,
|
53 |
+
) -> FineTuningJob:
|
54 |
+
"""
|
55 |
+
Async: Creates and executes a batch from an uploaded file of request
|
56 |
+
|
57 |
+
"""
|
58 |
+
verbose_logger.debug(
|
59 |
+
"inside acreate_fine_tuning_job model=%s and kwargs=%s", model, kwargs
|
60 |
+
)
|
61 |
+
try:
|
62 |
+
loop = asyncio.get_event_loop()
|
63 |
+
kwargs["acreate_fine_tuning_job"] = True
|
64 |
+
|
65 |
+
# Use a partial function to pass your keyword arguments
|
66 |
+
func = partial(
|
67 |
+
create_fine_tuning_job,
|
68 |
+
model,
|
69 |
+
training_file,
|
70 |
+
hyperparameters,
|
71 |
+
suffix,
|
72 |
+
validation_file,
|
73 |
+
integrations,
|
74 |
+
seed,
|
75 |
+
custom_llm_provider,
|
76 |
+
extra_headers,
|
77 |
+
extra_body,
|
78 |
+
**kwargs,
|
79 |
+
)
|
80 |
+
|
81 |
+
# Add the context to the function
|
82 |
+
ctx = contextvars.copy_context()
|
83 |
+
func_with_context = partial(ctx.run, func)
|
84 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
85 |
+
if asyncio.iscoroutine(init_response):
|
86 |
+
response = await init_response
|
87 |
+
else:
|
88 |
+
response = init_response # type: ignore
|
89 |
+
return response
|
90 |
+
except Exception as e:
|
91 |
+
raise e
|
92 |
+
|
93 |
+
|
94 |
+
@client
|
95 |
+
def create_fine_tuning_job(
|
96 |
+
model: str,
|
97 |
+
training_file: str,
|
98 |
+
hyperparameters: Optional[dict] = {},
|
99 |
+
suffix: Optional[str] = None,
|
100 |
+
validation_file: Optional[str] = None,
|
101 |
+
integrations: Optional[List[str]] = None,
|
102 |
+
seed: Optional[int] = None,
|
103 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
104 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
105 |
+
extra_body: Optional[Dict[str, str]] = None,
|
106 |
+
**kwargs,
|
107 |
+
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
|
108 |
+
"""
|
109 |
+
Creates a fine-tuning job which begins the process of creating a new model from a given dataset.
|
110 |
+
|
111 |
+
Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete
|
112 |
+
|
113 |
+
"""
|
114 |
+
try:
|
115 |
+
_is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
|
116 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
117 |
+
|
118 |
+
# handle hyperparameters
|
119 |
+
hyperparameters = hyperparameters or {} # original hyperparameters
|
120 |
+
_oai_hyperparameters: Hyperparameters = Hyperparameters(
|
121 |
+
**hyperparameters
|
122 |
+
) # Typed Hyperparameters for OpenAI Spec
|
123 |
+
### TIMEOUT LOGIC ###
|
124 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
125 |
+
# set timeout for 10 minutes by default
|
126 |
+
|
127 |
+
if (
|
128 |
+
timeout is not None
|
129 |
+
and isinstance(timeout, httpx.Timeout)
|
130 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
131 |
+
):
|
132 |
+
read_timeout = timeout.read or 600
|
133 |
+
timeout = read_timeout # default 10 min timeout
|
134 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
135 |
+
timeout = float(timeout) # type: ignore
|
136 |
+
elif timeout is None:
|
137 |
+
timeout = 600.0
|
138 |
+
|
139 |
+
# OpenAI
|
140 |
+
if custom_llm_provider == "openai":
|
141 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
142 |
+
api_base = (
|
143 |
+
optional_params.api_base
|
144 |
+
or litellm.api_base
|
145 |
+
or os.getenv("OPENAI_BASE_URL")
|
146 |
+
or os.getenv("OPENAI_API_BASE")
|
147 |
+
or "https://api.openai.com/v1"
|
148 |
+
)
|
149 |
+
organization = (
|
150 |
+
optional_params.organization
|
151 |
+
or litellm.organization
|
152 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
153 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
154 |
+
)
|
155 |
+
# set API KEY
|
156 |
+
api_key = (
|
157 |
+
optional_params.api_key
|
158 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
159 |
+
or litellm.openai_key
|
160 |
+
or os.getenv("OPENAI_API_KEY")
|
161 |
+
)
|
162 |
+
|
163 |
+
create_fine_tuning_job_data = FineTuningJobCreate(
|
164 |
+
model=model,
|
165 |
+
training_file=training_file,
|
166 |
+
hyperparameters=_oai_hyperparameters,
|
167 |
+
suffix=suffix,
|
168 |
+
validation_file=validation_file,
|
169 |
+
integrations=integrations,
|
170 |
+
seed=seed,
|
171 |
+
)
|
172 |
+
|
173 |
+
create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
|
174 |
+
exclude_none=True
|
175 |
+
)
|
176 |
+
|
177 |
+
response = openai_fine_tuning_apis_instance.create_fine_tuning_job(
|
178 |
+
api_base=api_base,
|
179 |
+
api_key=api_key,
|
180 |
+
api_version=optional_params.api_version,
|
181 |
+
organization=organization,
|
182 |
+
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
|
183 |
+
timeout=timeout,
|
184 |
+
max_retries=optional_params.max_retries,
|
185 |
+
_is_async=_is_async,
|
186 |
+
client=kwargs.get(
|
187 |
+
"client", None
|
188 |
+
), # note, when we add this to `GenericLiteLLMParams` it impacts a lot of other tests + linting
|
189 |
+
)
|
190 |
+
# Azure OpenAI
|
191 |
+
elif custom_llm_provider == "azure":
|
192 |
+
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
193 |
+
|
194 |
+
api_version = (
|
195 |
+
optional_params.api_version
|
196 |
+
or litellm.api_version
|
197 |
+
or get_secret_str("AZURE_API_VERSION")
|
198 |
+
) # type: ignore
|
199 |
+
|
200 |
+
api_key = (
|
201 |
+
optional_params.api_key
|
202 |
+
or litellm.api_key
|
203 |
+
or litellm.azure_key
|
204 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
205 |
+
or get_secret_str("AZURE_API_KEY")
|
206 |
+
) # type: ignore
|
207 |
+
|
208 |
+
extra_body = optional_params.get("extra_body", {})
|
209 |
+
if extra_body is not None:
|
210 |
+
extra_body.pop("azure_ad_token", None)
|
211 |
+
else:
|
212 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
213 |
+
create_fine_tuning_job_data = FineTuningJobCreate(
|
214 |
+
model=model,
|
215 |
+
training_file=training_file,
|
216 |
+
hyperparameters=_oai_hyperparameters,
|
217 |
+
suffix=suffix,
|
218 |
+
validation_file=validation_file,
|
219 |
+
integrations=integrations,
|
220 |
+
seed=seed,
|
221 |
+
)
|
222 |
+
|
223 |
+
create_fine_tuning_job_data_dict = create_fine_tuning_job_data.model_dump(
|
224 |
+
exclude_none=True
|
225 |
+
)
|
226 |
+
|
227 |
+
response = azure_fine_tuning_apis_instance.create_fine_tuning_job(
|
228 |
+
api_base=api_base,
|
229 |
+
api_key=api_key,
|
230 |
+
api_version=api_version,
|
231 |
+
create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
|
232 |
+
timeout=timeout,
|
233 |
+
max_retries=optional_params.max_retries,
|
234 |
+
_is_async=_is_async,
|
235 |
+
organization=optional_params.organization,
|
236 |
+
)
|
237 |
+
elif custom_llm_provider == "vertex_ai":
|
238 |
+
api_base = optional_params.api_base or ""
|
239 |
+
vertex_ai_project = (
|
240 |
+
optional_params.vertex_project
|
241 |
+
or litellm.vertex_project
|
242 |
+
or get_secret_str("VERTEXAI_PROJECT")
|
243 |
+
)
|
244 |
+
vertex_ai_location = (
|
245 |
+
optional_params.vertex_location
|
246 |
+
or litellm.vertex_location
|
247 |
+
or get_secret_str("VERTEXAI_LOCATION")
|
248 |
+
)
|
249 |
+
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
250 |
+
"VERTEXAI_CREDENTIALS"
|
251 |
+
)
|
252 |
+
create_fine_tuning_job_data = FineTuningJobCreate(
|
253 |
+
model=model,
|
254 |
+
training_file=training_file,
|
255 |
+
hyperparameters=_oai_hyperparameters,
|
256 |
+
suffix=suffix,
|
257 |
+
validation_file=validation_file,
|
258 |
+
integrations=integrations,
|
259 |
+
seed=seed,
|
260 |
+
)
|
261 |
+
response = vertex_fine_tuning_apis_instance.create_fine_tuning_job(
|
262 |
+
_is_async=_is_async,
|
263 |
+
create_fine_tuning_job_data=create_fine_tuning_job_data,
|
264 |
+
vertex_credentials=vertex_credentials,
|
265 |
+
vertex_project=vertex_ai_project,
|
266 |
+
vertex_location=vertex_ai_location,
|
267 |
+
timeout=timeout,
|
268 |
+
api_base=api_base,
|
269 |
+
kwargs=kwargs,
|
270 |
+
original_hyperparameters=hyperparameters,
|
271 |
+
)
|
272 |
+
else:
|
273 |
+
raise litellm.exceptions.BadRequestError(
|
274 |
+
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
275 |
+
custom_llm_provider
|
276 |
+
),
|
277 |
+
model="n/a",
|
278 |
+
llm_provider=custom_llm_provider,
|
279 |
+
response=httpx.Response(
|
280 |
+
status_code=400,
|
281 |
+
content="Unsupported provider",
|
282 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
283 |
+
),
|
284 |
+
)
|
285 |
+
return response
|
286 |
+
except Exception as e:
|
287 |
+
verbose_logger.error("got exception in create_fine_tuning_job=%s", str(e))
|
288 |
+
raise e
|
289 |
+
|
290 |
+
|
291 |
+
async def acancel_fine_tuning_job(
|
292 |
+
fine_tuning_job_id: str,
|
293 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
294 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
295 |
+
extra_body: Optional[Dict[str, str]] = None,
|
296 |
+
**kwargs,
|
297 |
+
) -> FineTuningJob:
|
298 |
+
"""
|
299 |
+
Async: Immediately cancel a fine-tune job.
|
300 |
+
"""
|
301 |
+
try:
|
302 |
+
loop = asyncio.get_event_loop()
|
303 |
+
kwargs["acancel_fine_tuning_job"] = True
|
304 |
+
|
305 |
+
# Use a partial function to pass your keyword arguments
|
306 |
+
func = partial(
|
307 |
+
cancel_fine_tuning_job,
|
308 |
+
fine_tuning_job_id,
|
309 |
+
custom_llm_provider,
|
310 |
+
extra_headers,
|
311 |
+
extra_body,
|
312 |
+
**kwargs,
|
313 |
+
)
|
314 |
+
|
315 |
+
# Add the context to the function
|
316 |
+
ctx = contextvars.copy_context()
|
317 |
+
func_with_context = partial(ctx.run, func)
|
318 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
319 |
+
if asyncio.iscoroutine(init_response):
|
320 |
+
response = await init_response
|
321 |
+
else:
|
322 |
+
response = init_response # type: ignore
|
323 |
+
return response
|
324 |
+
except Exception as e:
|
325 |
+
raise e
|
326 |
+
|
327 |
+
|
328 |
+
def cancel_fine_tuning_job(
|
329 |
+
fine_tuning_job_id: str,
|
330 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
331 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
332 |
+
extra_body: Optional[Dict[str, str]] = None,
|
333 |
+
**kwargs,
|
334 |
+
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
|
335 |
+
"""
|
336 |
+
Immediately cancel a fine-tune job.
|
337 |
+
|
338 |
+
Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete
|
339 |
+
|
340 |
+
"""
|
341 |
+
try:
|
342 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
343 |
+
### TIMEOUT LOGIC ###
|
344 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
345 |
+
# set timeout for 10 minutes by default
|
346 |
+
|
347 |
+
if (
|
348 |
+
timeout is not None
|
349 |
+
and isinstance(timeout, httpx.Timeout)
|
350 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
351 |
+
):
|
352 |
+
read_timeout = timeout.read or 600
|
353 |
+
timeout = read_timeout # default 10 min timeout
|
354 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
355 |
+
timeout = float(timeout) # type: ignore
|
356 |
+
elif timeout is None:
|
357 |
+
timeout = 600.0
|
358 |
+
|
359 |
+
_is_async = kwargs.pop("acancel_fine_tuning_job", False) is True
|
360 |
+
|
361 |
+
# OpenAI
|
362 |
+
if custom_llm_provider == "openai":
|
363 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
364 |
+
api_base = (
|
365 |
+
optional_params.api_base
|
366 |
+
or litellm.api_base
|
367 |
+
or os.getenv("OPENAI_BASE_URL")
|
368 |
+
or os.getenv("OPENAI_API_BASE")
|
369 |
+
or "https://api.openai.com/v1"
|
370 |
+
)
|
371 |
+
organization = (
|
372 |
+
optional_params.organization
|
373 |
+
or litellm.organization
|
374 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
375 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
376 |
+
)
|
377 |
+
# set API KEY
|
378 |
+
api_key = (
|
379 |
+
optional_params.api_key
|
380 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
381 |
+
or litellm.openai_key
|
382 |
+
or os.getenv("OPENAI_API_KEY")
|
383 |
+
)
|
384 |
+
|
385 |
+
response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
|
386 |
+
api_base=api_base,
|
387 |
+
api_key=api_key,
|
388 |
+
api_version=optional_params.api_version,
|
389 |
+
organization=organization,
|
390 |
+
fine_tuning_job_id=fine_tuning_job_id,
|
391 |
+
timeout=timeout,
|
392 |
+
max_retries=optional_params.max_retries,
|
393 |
+
_is_async=_is_async,
|
394 |
+
client=kwargs.get("client", None),
|
395 |
+
)
|
396 |
+
# Azure OpenAI
|
397 |
+
elif custom_llm_provider == "azure":
|
398 |
+
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
399 |
+
|
400 |
+
api_version = (
|
401 |
+
optional_params.api_version
|
402 |
+
or litellm.api_version
|
403 |
+
or get_secret_str("AZURE_API_VERSION")
|
404 |
+
) # type: ignore
|
405 |
+
|
406 |
+
api_key = (
|
407 |
+
optional_params.api_key
|
408 |
+
or litellm.api_key
|
409 |
+
or litellm.azure_key
|
410 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
411 |
+
or get_secret_str("AZURE_API_KEY")
|
412 |
+
) # type: ignore
|
413 |
+
|
414 |
+
extra_body = optional_params.get("extra_body", {})
|
415 |
+
if extra_body is not None:
|
416 |
+
extra_body.pop("azure_ad_token", None)
|
417 |
+
else:
|
418 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
419 |
+
|
420 |
+
response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
|
421 |
+
api_base=api_base,
|
422 |
+
api_key=api_key,
|
423 |
+
api_version=api_version,
|
424 |
+
fine_tuning_job_id=fine_tuning_job_id,
|
425 |
+
timeout=timeout,
|
426 |
+
max_retries=optional_params.max_retries,
|
427 |
+
_is_async=_is_async,
|
428 |
+
organization=optional_params.organization,
|
429 |
+
)
|
430 |
+
else:
|
431 |
+
raise litellm.exceptions.BadRequestError(
|
432 |
+
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
433 |
+
custom_llm_provider
|
434 |
+
),
|
435 |
+
model="n/a",
|
436 |
+
llm_provider=custom_llm_provider,
|
437 |
+
response=httpx.Response(
|
438 |
+
status_code=400,
|
439 |
+
content="Unsupported provider",
|
440 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
441 |
+
),
|
442 |
+
)
|
443 |
+
return response
|
444 |
+
except Exception as e:
|
445 |
+
raise e
|
446 |
+
|
447 |
+
|
448 |
+
async def alist_fine_tuning_jobs(
|
449 |
+
after: Optional[str] = None,
|
450 |
+
limit: Optional[int] = None,
|
451 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
452 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
453 |
+
extra_body: Optional[Dict[str, str]] = None,
|
454 |
+
**kwargs,
|
455 |
+
):
|
456 |
+
"""
|
457 |
+
Async: List your organization's fine-tuning jobs
|
458 |
+
"""
|
459 |
+
try:
|
460 |
+
loop = asyncio.get_event_loop()
|
461 |
+
kwargs["alist_fine_tuning_jobs"] = True
|
462 |
+
|
463 |
+
# Use a partial function to pass your keyword arguments
|
464 |
+
func = partial(
|
465 |
+
list_fine_tuning_jobs,
|
466 |
+
after,
|
467 |
+
limit,
|
468 |
+
custom_llm_provider,
|
469 |
+
extra_headers,
|
470 |
+
extra_body,
|
471 |
+
**kwargs,
|
472 |
+
)
|
473 |
+
|
474 |
+
# Add the context to the function
|
475 |
+
ctx = contextvars.copy_context()
|
476 |
+
func_with_context = partial(ctx.run, func)
|
477 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
478 |
+
if asyncio.iscoroutine(init_response):
|
479 |
+
response = await init_response
|
480 |
+
else:
|
481 |
+
response = init_response # type: ignore
|
482 |
+
return response
|
483 |
+
except Exception as e:
|
484 |
+
raise e
|
485 |
+
|
486 |
+
|
487 |
+
def list_fine_tuning_jobs(
|
488 |
+
after: Optional[str] = None,
|
489 |
+
limit: Optional[int] = None,
|
490 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
491 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
492 |
+
extra_body: Optional[Dict[str, str]] = None,
|
493 |
+
**kwargs,
|
494 |
+
):
|
495 |
+
"""
|
496 |
+
List your organization's fine-tuning jobs
|
497 |
+
|
498 |
+
Params:
|
499 |
+
|
500 |
+
- after: Optional[str] = None, Identifier for the last job from the previous pagination request.
|
501 |
+
- limit: Optional[int] = None, Number of fine-tuning jobs to retrieve. Defaults to 20
|
502 |
+
"""
|
503 |
+
try:
|
504 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
505 |
+
### TIMEOUT LOGIC ###
|
506 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
507 |
+
# set timeout for 10 minutes by default
|
508 |
+
|
509 |
+
if (
|
510 |
+
timeout is not None
|
511 |
+
and isinstance(timeout, httpx.Timeout)
|
512 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
513 |
+
):
|
514 |
+
read_timeout = timeout.read or 600
|
515 |
+
timeout = read_timeout # default 10 min timeout
|
516 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
517 |
+
timeout = float(timeout) # type: ignore
|
518 |
+
elif timeout is None:
|
519 |
+
timeout = 600.0
|
520 |
+
|
521 |
+
_is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True
|
522 |
+
|
523 |
+
# OpenAI
|
524 |
+
if custom_llm_provider == "openai":
|
525 |
+
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
|
526 |
+
api_base = (
|
527 |
+
optional_params.api_base
|
528 |
+
or litellm.api_base
|
529 |
+
or os.getenv("OPENAI_BASE_URL")
|
530 |
+
or os.getenv("OPENAI_API_BASE")
|
531 |
+
or "https://api.openai.com/v1"
|
532 |
+
)
|
533 |
+
organization = (
|
534 |
+
optional_params.organization
|
535 |
+
or litellm.organization
|
536 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
537 |
+
or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
|
538 |
+
)
|
539 |
+
# set API KEY
|
540 |
+
api_key = (
|
541 |
+
optional_params.api_key
|
542 |
+
or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
|
543 |
+
or litellm.openai_key
|
544 |
+
or os.getenv("OPENAI_API_KEY")
|
545 |
+
)
|
546 |
+
|
547 |
+
response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
|
548 |
+
api_base=api_base,
|
549 |
+
api_key=api_key,
|
550 |
+
api_version=optional_params.api_version,
|
551 |
+
organization=organization,
|
552 |
+
after=after,
|
553 |
+
limit=limit,
|
554 |
+
timeout=timeout,
|
555 |
+
max_retries=optional_params.max_retries,
|
556 |
+
_is_async=_is_async,
|
557 |
+
client=kwargs.get("client", None),
|
558 |
+
)
|
559 |
+
# Azure OpenAI
|
560 |
+
elif custom_llm_provider == "azure":
|
561 |
+
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
562 |
+
|
563 |
+
api_version = (
|
564 |
+
optional_params.api_version
|
565 |
+
or litellm.api_version
|
566 |
+
or get_secret_str("AZURE_API_VERSION")
|
567 |
+
) # type: ignore
|
568 |
+
|
569 |
+
api_key = (
|
570 |
+
optional_params.api_key
|
571 |
+
or litellm.api_key
|
572 |
+
or litellm.azure_key
|
573 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
574 |
+
or get_secret_str("AZURE_API_KEY")
|
575 |
+
) # type: ignore
|
576 |
+
|
577 |
+
extra_body = optional_params.get("extra_body", {})
|
578 |
+
if extra_body is not None:
|
579 |
+
extra_body.pop("azure_ad_token", None)
|
580 |
+
else:
|
581 |
+
get_secret("AZURE_AD_TOKEN") # type: ignore
|
582 |
+
|
583 |
+
response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
|
584 |
+
api_base=api_base,
|
585 |
+
api_key=api_key,
|
586 |
+
api_version=api_version,
|
587 |
+
after=after,
|
588 |
+
limit=limit,
|
589 |
+
timeout=timeout,
|
590 |
+
max_retries=optional_params.max_retries,
|
591 |
+
_is_async=_is_async,
|
592 |
+
organization=optional_params.organization,
|
593 |
+
)
|
594 |
+
else:
|
595 |
+
raise litellm.exceptions.BadRequestError(
|
596 |
+
message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
|
597 |
+
custom_llm_provider
|
598 |
+
),
|
599 |
+
model="n/a",
|
600 |
+
llm_provider=custom_llm_provider,
|
601 |
+
response=httpx.Response(
|
602 |
+
status_code=400,
|
603 |
+
content="Unsupported provider",
|
604 |
+
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
605 |
+
),
|
606 |
+
)
|
607 |
+
return response
|
608 |
+
except Exception as e:
|
609 |
+
raise e
|
610 |
+
|
611 |
+
|
612 |
+
async def aretrieve_fine_tuning_job(
|
613 |
+
fine_tuning_job_id: str,
|
614 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
615 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
616 |
+
extra_body: Optional[Dict[str, str]] = None,
|
617 |
+
**kwargs,
|
618 |
+
) -> FineTuningJob:
|
619 |
+
"""
|
620 |
+
Async: Get info about a fine-tuning job.
|
621 |
+
"""
|
622 |
+
try:
|
623 |
+
loop = asyncio.get_event_loop()
|
624 |
+
kwargs["aretrieve_fine_tuning_job"] = True
|
625 |
+
|
626 |
+
# Use a partial function to pass your keyword arguments
|
627 |
+
func = partial(
|
628 |
+
retrieve_fine_tuning_job,
|
629 |
+
fine_tuning_job_id,
|
630 |
+
custom_llm_provider,
|
631 |
+
extra_headers,
|
632 |
+
extra_body,
|
633 |
+
**kwargs,
|
634 |
+
)
|
635 |
+
|
636 |
+
# Add the context to the function
|
637 |
+
ctx = contextvars.copy_context()
|
638 |
+
func_with_context = partial(ctx.run, func)
|
639 |
+
init_response = await loop.run_in_executor(None, func_with_context)
|
640 |
+
if asyncio.iscoroutine(init_response):
|
641 |
+
response = await init_response
|
642 |
+
else:
|
643 |
+
response = init_response # type: ignore
|
644 |
+
return response
|
645 |
+
except Exception as e:
|
646 |
+
raise e
|
647 |
+
|
648 |
+
|
649 |
+
def retrieve_fine_tuning_job(
|
650 |
+
fine_tuning_job_id: str,
|
651 |
+
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
652 |
+
extra_headers: Optional[Dict[str, str]] = None,
|
653 |
+
extra_body: Optional[Dict[str, str]] = None,
|
654 |
+
**kwargs,
|
655 |
+
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
|
656 |
+
"""
|
657 |
+
Get info about a fine-tuning job.
|
658 |
+
"""
|
659 |
+
try:
|
660 |
+
optional_params = GenericLiteLLMParams(**kwargs)
|
661 |
+
### TIMEOUT LOGIC ###
|
662 |
+
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
663 |
+
# set timeout for 10 minutes by default
|
664 |
+
|
665 |
+
if (
|
666 |
+
timeout is not None
|
667 |
+
and isinstance(timeout, httpx.Timeout)
|
668 |
+
and supports_httpx_timeout(custom_llm_provider) is False
|
669 |
+
):
|
670 |
+
read_timeout = timeout.read or 600
|
671 |
+
timeout = read_timeout # default 10 min timeout
|
672 |
+
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
|
673 |
+
timeout = float(timeout) # type: ignore
|
674 |
+
elif timeout is None:
|
675 |
+
timeout = 600.0
|
676 |
+
|
677 |
+
_is_async = kwargs.pop("aretrieve_fine_tuning_job", False) is True
|
678 |
+
|
679 |
+
# OpenAI
|
680 |
+
if custom_llm_provider == "openai":
|
681 |
+
api_base = (
|
682 |
+
optional_params.api_base
|
683 |
+
or litellm.api_base
|
684 |
+
or os.getenv("OPENAI_BASE_URL")
|
685 |
+
or os.getenv("OPENAI_API_BASE")
|
686 |
+
or "https://api.openai.com/v1"
|
687 |
+
)
|
688 |
+
organization = (
|
689 |
+
optional_params.organization
|
690 |
+
or litellm.organization
|
691 |
+
or os.getenv("OPENAI_ORGANIZATION", None)
|
692 |
+
or None
|
693 |
+
)
|
694 |
+
api_key = (
|
695 |
+
optional_params.api_key
|
696 |
+
or litellm.api_key
|
697 |
+
or litellm.openai_key
|
698 |
+
or os.getenv("OPENAI_API_KEY")
|
699 |
+
)
|
700 |
+
|
701 |
+
response = openai_fine_tuning_apis_instance.retrieve_fine_tuning_job(
|
702 |
+
api_base=api_base,
|
703 |
+
api_key=api_key,
|
704 |
+
api_version=optional_params.api_version,
|
705 |
+
organization=organization,
|
706 |
+
fine_tuning_job_id=fine_tuning_job_id,
|
707 |
+
timeout=timeout,
|
708 |
+
max_retries=optional_params.max_retries,
|
709 |
+
_is_async=_is_async,
|
710 |
+
client=kwargs.get("client", None),
|
711 |
+
)
|
712 |
+
# Azure OpenAI
|
713 |
+
elif custom_llm_provider == "azure":
|
714 |
+
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
715 |
+
|
716 |
+
api_version = (
|
717 |
+
optional_params.api_version
|
718 |
+
or litellm.api_version
|
719 |
+
or get_secret_str("AZURE_API_VERSION")
|
720 |
+
) # type: ignore
|
721 |
+
|
722 |
+
api_key = (
|
723 |
+
optional_params.api_key
|
724 |
+
or litellm.api_key
|
725 |
+
or litellm.azure_key
|
726 |
+
or get_secret_str("AZURE_OPENAI_API_KEY")
|
727 |
+
or get_secret_str("AZURE_API_KEY")
|
728 |
+
) # type: ignore
|
729 |
+
|
730 |
+
extra_body = optional_params.get("extra_body", {})
|
731 |
+
if extra_body is not None:
|
732 |
+
extra_body.pop("azure_ad_token", None)
|
733 |
+
else:
|
734 |
+
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
735 |
+
|
736 |
+
response = azure_fine_tuning_apis_instance.retrieve_fine_tuning_job(
|
737 |
+
api_base=api_base,
|
738 |
+
api_key=api_key,
|
739 |
+
api_version=api_version,
|
740 |
+
fine_tuning_job_id=fine_tuning_job_id,
|
741 |
+
timeout=timeout,
|
742 |
+
max_retries=optional_params.max_retries,
|
743 |
+
_is_async=_is_async,
|
744 |
+
organization=optional_params.organization,
|
745 |
+
)
|
746 |
+
else:
|
747 |
+
raise litellm.exceptions.BadRequestError(
|
748 |
+
message="LiteLLM doesn't support {} for 'retrieve_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
|
749 |
+
custom_llm_provider
|
750 |
+
),
|
751 |
+
model="n/a",
|
752 |
+
llm_provider=custom_llm_provider,
|
753 |
+
response=httpx.Response(
|
754 |
+
status_code=400,
|
755 |
+
content="Unsupported provider",
|
756 |
+
request=httpx.Request(method="retrieve_fine_tuning_job", url="https://github.com/BerriAI/litellm"), # type: ignore
|
757 |
+
),
|
758 |
+
)
|
759 |
+
return response
|
760 |
+
except Exception as e:
|
761 |
+
raise e
|
litellm/integrations/Readme.md
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Integrations
|
2 |
+
|
3 |
+
This folder contains logging integrations for litellm
|
4 |
+
|
5 |
+
eg. logging to Datadog, Langfuse, Prometheus, s3, GCS Bucket, etc.
|
litellm/integrations/SlackAlerting/Readme.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Slack Alerting on LiteLLM Gateway
|
2 |
+
|
3 |
+
This folder contains the Slack Alerting integration for LiteLLM Gateway.
|
4 |
+
|
5 |
+
## Folder Structure
|
6 |
+
|
7 |
+
- `slack_alerting.py`: This is the main file that handles sending different types of alerts
|
8 |
+
- `batching_handler.py`: Handles Batching + sending Httpx Post requests to slack. Slack alerts are sent every 10s or when events are greater than X events. Done to ensure litellm has good performance under high traffic
|
9 |
+
- `types.py`: This file contains the AlertType enum which is used to define the different types of alerts that can be sent to Slack.
|
10 |
+
- `utils.py`: This file contains common utils used specifically for slack alerting
|
11 |
+
|
12 |
+
## Further Reading
|
13 |
+
- [Doc setting up Alerting on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/alerting)
|
litellm/integrations/SlackAlerting/batching_handler.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Handles Batching + sending Httpx Post requests to slack
|
3 |
+
|
4 |
+
Slack alerts are sent every 10s or when events are greater than X events
|
5 |
+
|
6 |
+
see custom_batch_logger.py for more details / defaults
|
7 |
+
"""
|
8 |
+
|
9 |
+
from typing import TYPE_CHECKING, Any
|
10 |
+
|
11 |
+
from litellm._logging import verbose_proxy_logger
|
12 |
+
|
13 |
+
if TYPE_CHECKING:
|
14 |
+
from .slack_alerting import SlackAlerting as _SlackAlerting
|
15 |
+
|
16 |
+
SlackAlertingType = _SlackAlerting
|
17 |
+
else:
|
18 |
+
SlackAlertingType = Any
|
19 |
+
|
20 |
+
|
21 |
+
def squash_payloads(queue):
|
22 |
+
squashed = {}
|
23 |
+
if len(queue) == 0:
|
24 |
+
return squashed
|
25 |
+
if len(queue) == 1:
|
26 |
+
return {"key": {"item": queue[0], "count": 1}}
|
27 |
+
|
28 |
+
for item in queue:
|
29 |
+
url = item["url"]
|
30 |
+
alert_type = item["alert_type"]
|
31 |
+
_key = (url, alert_type)
|
32 |
+
|
33 |
+
if _key in squashed:
|
34 |
+
squashed[_key]["count"] += 1
|
35 |
+
# Merge the payloads
|
36 |
+
|
37 |
+
else:
|
38 |
+
squashed[_key] = {"item": item, "count": 1}
|
39 |
+
|
40 |
+
return squashed
|
41 |
+
|
42 |
+
|
43 |
+
def _print_alerting_payload_warning(
|
44 |
+
payload: dict, slackAlertingInstance: SlackAlertingType
|
45 |
+
):
|
46 |
+
"""
|
47 |
+
Print the payload to the console when
|
48 |
+
slackAlertingInstance.alerting_args.log_to_console is True
|
49 |
+
|
50 |
+
Relevant issue: https://github.com/BerriAI/litellm/issues/7372
|
51 |
+
"""
|
52 |
+
if slackAlertingInstance.alerting_args.log_to_console is True:
|
53 |
+
verbose_proxy_logger.warning(payload)
|
54 |
+
|
55 |
+
|
56 |
+
async def send_to_webhook(slackAlertingInstance: SlackAlertingType, item, count):
|
57 |
+
"""
|
58 |
+
Send a single slack alert to the webhook
|
59 |
+
"""
|
60 |
+
import json
|
61 |
+
|
62 |
+
payload = item.get("payload", {})
|
63 |
+
try:
|
64 |
+
if count > 1:
|
65 |
+
payload["text"] = f"[Num Alerts: {count}]\n\n{payload['text']}"
|
66 |
+
|
67 |
+
response = await slackAlertingInstance.async_http_handler.post(
|
68 |
+
url=item["url"],
|
69 |
+
headers=item["headers"],
|
70 |
+
data=json.dumps(payload),
|
71 |
+
)
|
72 |
+
if response.status_code != 200:
|
73 |
+
verbose_proxy_logger.debug(
|
74 |
+
f"Error sending slack alert to url={item['url']}. Error={response.text}"
|
75 |
+
)
|
76 |
+
except Exception as e:
|
77 |
+
verbose_proxy_logger.debug(f"Error sending slack alert: {str(e)}")
|
78 |
+
finally:
|
79 |
+
_print_alerting_payload_warning(
|
80 |
+
payload, slackAlertingInstance=slackAlertingInstance
|
81 |
+
)
|
litellm/integrations/SlackAlerting/slack_alerting.py
ADDED
@@ -0,0 +1,1825 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#### What this does ####
|
2 |
+
# Class for sending Slack Alerts #
|
3 |
+
import asyncio
|
4 |
+
import datetime
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
import time
|
8 |
+
from datetime import timedelta
|
9 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
10 |
+
|
11 |
+
from openai import APIError
|
12 |
+
|
13 |
+
import litellm
|
14 |
+
import litellm.litellm_core_utils
|
15 |
+
import litellm.litellm_core_utils.litellm_logging
|
16 |
+
import litellm.types
|
17 |
+
from litellm._logging import verbose_logger, verbose_proxy_logger
|
18 |
+
from litellm.caching.caching import DualCache
|
19 |
+
from litellm.constants import HOURS_IN_A_DAY
|
20 |
+
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
21 |
+
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
22 |
+
from litellm.litellm_core_utils.exception_mapping_utils import (
|
23 |
+
_add_key_name_and_team_to_alert,
|
24 |
+
)
|
25 |
+
from litellm.llms.custom_httpx.http_handler import (
|
26 |
+
get_async_httpx_client,
|
27 |
+
httpxSpecialProvider,
|
28 |
+
)
|
29 |
+
from litellm.proxy._types import AlertType, CallInfo, VirtualKeyEvent, WebhookEvent
|
30 |
+
from litellm.types.integrations.slack_alerting import *
|
31 |
+
|
32 |
+
from ..email_templates.templates import *
|
33 |
+
from .batching_handler import send_to_webhook, squash_payloads
|
34 |
+
from .utils import _add_langfuse_trace_id_to_alert, process_slack_alerting_variables
|
35 |
+
|
36 |
+
if TYPE_CHECKING:
|
37 |
+
from litellm.router import Router as _Router
|
38 |
+
|
39 |
+
Router = _Router
|
40 |
+
else:
|
41 |
+
Router = Any
|
42 |
+
|
43 |
+
|
44 |
+
class SlackAlerting(CustomBatchLogger):
|
45 |
+
"""
|
46 |
+
Class for sending Slack Alerts
|
47 |
+
"""
|
48 |
+
|
49 |
+
# Class variables or attributes
|
50 |
+
def __init__(
|
51 |
+
self,
|
52 |
+
internal_usage_cache: Optional[DualCache] = None,
|
53 |
+
alerting_threshold: Optional[
|
54 |
+
float
|
55 |
+
] = None, # threshold for slow / hanging llm responses (in seconds)
|
56 |
+
alerting: Optional[List] = [],
|
57 |
+
alert_types: List[AlertType] = DEFAULT_ALERT_TYPES,
|
58 |
+
alert_to_webhook_url: Optional[
|
59 |
+
Dict[AlertType, Union[List[str], str]]
|
60 |
+
] = None, # if user wants to separate alerts to diff channels
|
61 |
+
alerting_args={},
|
62 |
+
default_webhook_url: Optional[str] = None,
|
63 |
+
**kwargs,
|
64 |
+
):
|
65 |
+
if alerting_threshold is None:
|
66 |
+
alerting_threshold = 300
|
67 |
+
self.alerting_threshold = alerting_threshold
|
68 |
+
self.alerting = alerting
|
69 |
+
self.alert_types = alert_types
|
70 |
+
self.internal_usage_cache = internal_usage_cache or DualCache()
|
71 |
+
self.async_http_handler = get_async_httpx_client(
|
72 |
+
llm_provider=httpxSpecialProvider.LoggingCallback
|
73 |
+
)
|
74 |
+
self.alert_to_webhook_url = process_slack_alerting_variables(
|
75 |
+
alert_to_webhook_url=alert_to_webhook_url
|
76 |
+
)
|
77 |
+
self.is_running = False
|
78 |
+
self.alerting_args = SlackAlertingArgs(**alerting_args)
|
79 |
+
self.default_webhook_url = default_webhook_url
|
80 |
+
self.flush_lock = asyncio.Lock()
|
81 |
+
super().__init__(**kwargs, flush_lock=self.flush_lock)
|
82 |
+
|
83 |
+
def update_values(
|
84 |
+
self,
|
85 |
+
alerting: Optional[List] = None,
|
86 |
+
alerting_threshold: Optional[float] = None,
|
87 |
+
alert_types: Optional[List[AlertType]] = None,
|
88 |
+
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]] = None,
|
89 |
+
alerting_args: Optional[Dict] = None,
|
90 |
+
llm_router: Optional[Router] = None,
|
91 |
+
):
|
92 |
+
if alerting is not None:
|
93 |
+
self.alerting = alerting
|
94 |
+
asyncio.create_task(self.periodic_flush())
|
95 |
+
if alerting_threshold is not None:
|
96 |
+
self.alerting_threshold = alerting_threshold
|
97 |
+
if alert_types is not None:
|
98 |
+
self.alert_types = alert_types
|
99 |
+
if alerting_args is not None:
|
100 |
+
self.alerting_args = SlackAlertingArgs(**alerting_args)
|
101 |
+
if alert_to_webhook_url is not None:
|
102 |
+
# update the dict
|
103 |
+
if self.alert_to_webhook_url is None:
|
104 |
+
self.alert_to_webhook_url = process_slack_alerting_variables(
|
105 |
+
alert_to_webhook_url=alert_to_webhook_url
|
106 |
+
)
|
107 |
+
else:
|
108 |
+
_new_values = (
|
109 |
+
process_slack_alerting_variables(
|
110 |
+
alert_to_webhook_url=alert_to_webhook_url
|
111 |
+
)
|
112 |
+
or {}
|
113 |
+
)
|
114 |
+
self.alert_to_webhook_url.update(_new_values)
|
115 |
+
if llm_router is not None:
|
116 |
+
self.llm_router = llm_router
|
117 |
+
|
118 |
+
async def deployment_in_cooldown(self):
|
119 |
+
pass
|
120 |
+
|
121 |
+
async def deployment_removed_from_cooldown(self):
|
122 |
+
pass
|
123 |
+
|
124 |
+
def _all_possible_alert_types(self):
|
125 |
+
# used by the UI to show all supported alert types
|
126 |
+
# Note: This is not the alerts the user has configured, instead it's all possible alert types a user can select
|
127 |
+
# return list of all values AlertType enum
|
128 |
+
return list(AlertType)
|
129 |
+
|
130 |
+
def _response_taking_too_long_callback_helper(
|
131 |
+
self,
|
132 |
+
kwargs, # kwargs to completion
|
133 |
+
start_time,
|
134 |
+
end_time, # start/end time
|
135 |
+
):
|
136 |
+
try:
|
137 |
+
time_difference = end_time - start_time
|
138 |
+
# Convert the timedelta to float (in seconds)
|
139 |
+
time_difference_float = time_difference.total_seconds()
|
140 |
+
litellm_params = kwargs.get("litellm_params", {})
|
141 |
+
model = kwargs.get("model", "")
|
142 |
+
api_base = litellm.get_api_base(model=model, optional_params=litellm_params)
|
143 |
+
messages = kwargs.get("messages", None)
|
144 |
+
# if messages does not exist fallback to "input"
|
145 |
+
if messages is None:
|
146 |
+
messages = kwargs.get("input", None)
|
147 |
+
|
148 |
+
# only use first 100 chars for alerting
|
149 |
+
_messages = str(messages)[:100]
|
150 |
+
|
151 |
+
return time_difference_float, model, api_base, _messages
|
152 |
+
except Exception as e:
|
153 |
+
raise e
|
154 |
+
|
155 |
+
def _get_deployment_latencies_to_alert(self, metadata=None):
|
156 |
+
if metadata is None:
|
157 |
+
return None
|
158 |
+
|
159 |
+
if "_latency_per_deployment" in metadata:
|
160 |
+
# Translate model_id to -> api_base
|
161 |
+
# _latency_per_deployment is a dictionary that looks like this:
|
162 |
+
"""
|
163 |
+
_latency_per_deployment: {
|
164 |
+
api_base: 0.01336697916666667
|
165 |
+
}
|
166 |
+
"""
|
167 |
+
_message_to_send = ""
|
168 |
+
_deployment_latencies = metadata["_latency_per_deployment"]
|
169 |
+
if len(_deployment_latencies) == 0:
|
170 |
+
return None
|
171 |
+
_deployment_latency_map: Optional[dict] = None
|
172 |
+
try:
|
173 |
+
# try sorting deployments by latency
|
174 |
+
_deployment_latencies = sorted(
|
175 |
+
_deployment_latencies.items(), key=lambda x: x[1]
|
176 |
+
)
|
177 |
+
_deployment_latency_map = dict(_deployment_latencies)
|
178 |
+
except Exception:
|
179 |
+
pass
|
180 |
+
|
181 |
+
if _deployment_latency_map is None:
|
182 |
+
return
|
183 |
+
|
184 |
+
for api_base, latency in _deployment_latency_map.items():
|
185 |
+
_message_to_send += f"\n{api_base}: {round(latency,2)}s"
|
186 |
+
_message_to_send = "```" + _message_to_send + "```"
|
187 |
+
return _message_to_send
|
188 |
+
|
189 |
+
async def response_taking_too_long_callback(
|
190 |
+
self,
|
191 |
+
kwargs, # kwargs to completion
|
192 |
+
completion_response, # response from completion
|
193 |
+
start_time,
|
194 |
+
end_time, # start/end time
|
195 |
+
):
|
196 |
+
if self.alerting is None or self.alert_types is None:
|
197 |
+
return
|
198 |
+
|
199 |
+
(
|
200 |
+
time_difference_float,
|
201 |
+
model,
|
202 |
+
api_base,
|
203 |
+
messages,
|
204 |
+
) = self._response_taking_too_long_callback_helper(
|
205 |
+
kwargs=kwargs,
|
206 |
+
start_time=start_time,
|
207 |
+
end_time=end_time,
|
208 |
+
)
|
209 |
+
if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions:
|
210 |
+
messages = "Message not logged. litellm.redact_messages_in_exceptions=True"
|
211 |
+
request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`"
|
212 |
+
slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`"
|
213 |
+
alerting_metadata: dict = {}
|
214 |
+
if time_difference_float > self.alerting_threshold:
|
215 |
+
# add deployment latencies to alert
|
216 |
+
if (
|
217 |
+
kwargs is not None
|
218 |
+
and "litellm_params" in kwargs
|
219 |
+
and "metadata" in kwargs["litellm_params"]
|
220 |
+
):
|
221 |
+
_metadata: dict = kwargs["litellm_params"]["metadata"]
|
222 |
+
request_info = _add_key_name_and_team_to_alert(
|
223 |
+
request_info=request_info, metadata=_metadata
|
224 |
+
)
|
225 |
+
|
226 |
+
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
227 |
+
metadata=_metadata
|
228 |
+
)
|
229 |
+
if _deployment_latency_map is not None:
|
230 |
+
request_info += (
|
231 |
+
f"\nAvailable Deployment Latencies\n{_deployment_latency_map}"
|
232 |
+
)
|
233 |
+
|
234 |
+
if "alerting_metadata" in _metadata:
|
235 |
+
alerting_metadata = _metadata["alerting_metadata"]
|
236 |
+
await self.send_alert(
|
237 |
+
message=slow_message + request_info,
|
238 |
+
level="Low",
|
239 |
+
alert_type=AlertType.llm_too_slow,
|
240 |
+
alerting_metadata=alerting_metadata,
|
241 |
+
)
|
242 |
+
|
243 |
+
async def async_update_daily_reports(
|
244 |
+
self, deployment_metrics: DeploymentMetrics
|
245 |
+
) -> int:
|
246 |
+
"""
|
247 |
+
Store the perf by deployment in cache
|
248 |
+
- Number of failed requests per deployment
|
249 |
+
- Latency / output tokens per deployment
|
250 |
+
|
251 |
+
'deployment_id:daily_metrics:failed_requests'
|
252 |
+
'deployment_id:daily_metrics:latency_per_output_token'
|
253 |
+
|
254 |
+
Returns
|
255 |
+
int - count of metrics set (1 - if just latency, 2 - if failed + latency)
|
256 |
+
"""
|
257 |
+
|
258 |
+
return_val = 0
|
259 |
+
try:
|
260 |
+
## FAILED REQUESTS ##
|
261 |
+
if deployment_metrics.failed_request:
|
262 |
+
await self.internal_usage_cache.async_increment_cache(
|
263 |
+
key="{}:{}".format(
|
264 |
+
deployment_metrics.id,
|
265 |
+
SlackAlertingCacheKeys.failed_requests_key.value,
|
266 |
+
),
|
267 |
+
value=1,
|
268 |
+
parent_otel_span=None, # no attached request, this is a background operation
|
269 |
+
)
|
270 |
+
|
271 |
+
return_val += 1
|
272 |
+
|
273 |
+
## LATENCY ##
|
274 |
+
if deployment_metrics.latency_per_output_token is not None:
|
275 |
+
await self.internal_usage_cache.async_increment_cache(
|
276 |
+
key="{}:{}".format(
|
277 |
+
deployment_metrics.id, SlackAlertingCacheKeys.latency_key.value
|
278 |
+
),
|
279 |
+
value=deployment_metrics.latency_per_output_token,
|
280 |
+
parent_otel_span=None, # no attached request, this is a background operation
|
281 |
+
)
|
282 |
+
|
283 |
+
return_val += 1
|
284 |
+
|
285 |
+
return return_val
|
286 |
+
except Exception:
|
287 |
+
return 0
|
288 |
+
|
289 |
+
async def send_daily_reports(self, router) -> bool: # noqa: PLR0915
|
290 |
+
"""
|
291 |
+
Send a daily report on:
|
292 |
+
- Top 5 deployments with most failed requests
|
293 |
+
- Top 5 slowest deployments (normalized by latency/output tokens)
|
294 |
+
|
295 |
+
Get the value from redis cache (if available) or in-memory and send it
|
296 |
+
|
297 |
+
Cleanup:
|
298 |
+
- reset values in cache -> prevent memory leak
|
299 |
+
|
300 |
+
Returns:
|
301 |
+
True -> if successfuly sent
|
302 |
+
False -> if not sent
|
303 |
+
"""
|
304 |
+
|
305 |
+
ids = router.get_model_ids()
|
306 |
+
|
307 |
+
# get keys
|
308 |
+
failed_request_keys = [
|
309 |
+
"{}:{}".format(id, SlackAlertingCacheKeys.failed_requests_key.value)
|
310 |
+
for id in ids
|
311 |
+
]
|
312 |
+
latency_keys = [
|
313 |
+
"{}:{}".format(id, SlackAlertingCacheKeys.latency_key.value) for id in ids
|
314 |
+
]
|
315 |
+
|
316 |
+
combined_metrics_keys = failed_request_keys + latency_keys # reduce cache calls
|
317 |
+
|
318 |
+
combined_metrics_values = await self.internal_usage_cache.async_batch_get_cache(
|
319 |
+
keys=combined_metrics_keys
|
320 |
+
) # [1, 2, None, ..]
|
321 |
+
|
322 |
+
if combined_metrics_values is None:
|
323 |
+
return False
|
324 |
+
|
325 |
+
all_none = True
|
326 |
+
for val in combined_metrics_values:
|
327 |
+
if val is not None and val > 0:
|
328 |
+
all_none = False
|
329 |
+
break
|
330 |
+
|
331 |
+
if all_none:
|
332 |
+
return False
|
333 |
+
|
334 |
+
failed_request_values = combined_metrics_values[
|
335 |
+
: len(failed_request_keys)
|
336 |
+
] # # [1, 2, None, ..]
|
337 |
+
latency_values = combined_metrics_values[len(failed_request_keys) :]
|
338 |
+
|
339 |
+
# find top 5 failed
|
340 |
+
## Replace None values with a placeholder value (-1 in this case)
|
341 |
+
placeholder_value = 0
|
342 |
+
replaced_failed_values = [
|
343 |
+
value if value is not None else placeholder_value
|
344 |
+
for value in failed_request_values
|
345 |
+
]
|
346 |
+
|
347 |
+
## Get the indices of top 5 keys with the highest numerical values (ignoring None and 0 values)
|
348 |
+
top_5_failed = sorted(
|
349 |
+
range(len(replaced_failed_values)),
|
350 |
+
key=lambda i: replaced_failed_values[i],
|
351 |
+
reverse=True,
|
352 |
+
)[:5]
|
353 |
+
top_5_failed = [
|
354 |
+
index for index in top_5_failed if replaced_failed_values[index] > 0
|
355 |
+
]
|
356 |
+
|
357 |
+
# find top 5 slowest
|
358 |
+
# Replace None values with a placeholder value (-1 in this case)
|
359 |
+
placeholder_value = 0
|
360 |
+
replaced_slowest_values = [
|
361 |
+
value if value is not None else placeholder_value
|
362 |
+
for value in latency_values
|
363 |
+
]
|
364 |
+
|
365 |
+
# Get the indices of top 5 values with the highest numerical values (ignoring None and 0 values)
|
366 |
+
top_5_slowest = sorted(
|
367 |
+
range(len(replaced_slowest_values)),
|
368 |
+
key=lambda i: replaced_slowest_values[i],
|
369 |
+
reverse=True,
|
370 |
+
)[:5]
|
371 |
+
top_5_slowest = [
|
372 |
+
index for index in top_5_slowest if replaced_slowest_values[index] > 0
|
373 |
+
]
|
374 |
+
|
375 |
+
# format alert -> return the litellm model name + api base
|
376 |
+
message = f"\n\nTime: `{time.time()}`s\nHere are today's key metrics 📈: \n\n"
|
377 |
+
|
378 |
+
message += "\n\n*❗️ Top Deployments with Most Failed Requests:*\n\n"
|
379 |
+
if not top_5_failed:
|
380 |
+
message += "\tNone\n"
|
381 |
+
for i in range(len(top_5_failed)):
|
382 |
+
key = failed_request_keys[top_5_failed[i]].split(":")[0]
|
383 |
+
_deployment = router.get_model_info(key)
|
384 |
+
if isinstance(_deployment, dict):
|
385 |
+
deployment_name = _deployment["litellm_params"].get("model", "")
|
386 |
+
else:
|
387 |
+
return False
|
388 |
+
|
389 |
+
api_base = litellm.get_api_base(
|
390 |
+
model=deployment_name,
|
391 |
+
optional_params=(
|
392 |
+
_deployment["litellm_params"] if _deployment is not None else {}
|
393 |
+
),
|
394 |
+
)
|
395 |
+
if api_base is None:
|
396 |
+
api_base = ""
|
397 |
+
value = replaced_failed_values[top_5_failed[i]]
|
398 |
+
message += f"\t{i+1}. Deployment: `{deployment_name}`, Failed Requests: `{value}`, API Base: `{api_base}`\n"
|
399 |
+
|
400 |
+
message += "\n\n*😅 Top Slowest Deployments:*\n\n"
|
401 |
+
if not top_5_slowest:
|
402 |
+
message += "\tNone\n"
|
403 |
+
for i in range(len(top_5_slowest)):
|
404 |
+
key = latency_keys[top_5_slowest[i]].split(":")[0]
|
405 |
+
_deployment = router.get_model_info(key)
|
406 |
+
if _deployment is not None:
|
407 |
+
deployment_name = _deployment["litellm_params"].get("model", "")
|
408 |
+
else:
|
409 |
+
deployment_name = ""
|
410 |
+
api_base = litellm.get_api_base(
|
411 |
+
model=deployment_name,
|
412 |
+
optional_params=(
|
413 |
+
_deployment["litellm_params"] if _deployment is not None else {}
|
414 |
+
),
|
415 |
+
)
|
416 |
+
value = round(replaced_slowest_values[top_5_slowest[i]], 3)
|
417 |
+
message += f"\t{i+1}. Deployment: `{deployment_name}`, Latency per output token: `{value}s/token`, API Base: `{api_base}`\n\n"
|
418 |
+
|
419 |
+
# cache cleanup -> reset values to 0
|
420 |
+
latency_cache_keys = [(key, 0) for key in latency_keys]
|
421 |
+
failed_request_cache_keys = [(key, 0) for key in failed_request_keys]
|
422 |
+
combined_metrics_cache_keys = latency_cache_keys + failed_request_cache_keys
|
423 |
+
await self.internal_usage_cache.async_set_cache_pipeline(
|
424 |
+
cache_list=combined_metrics_cache_keys
|
425 |
+
)
|
426 |
+
|
427 |
+
message += f"\n\nNext Run is at: `{time.time() + self.alerting_args.daily_report_frequency}`s"
|
428 |
+
|
429 |
+
# send alert
|
430 |
+
await self.send_alert(
|
431 |
+
message=message,
|
432 |
+
level="Low",
|
433 |
+
alert_type=AlertType.daily_reports,
|
434 |
+
alerting_metadata={},
|
435 |
+
)
|
436 |
+
|
437 |
+
return True
|
438 |
+
|
439 |
+
async def response_taking_too_long(
|
440 |
+
self,
|
441 |
+
start_time: Optional[datetime.datetime] = None,
|
442 |
+
end_time: Optional[datetime.datetime] = None,
|
443 |
+
type: Literal["hanging_request", "slow_response"] = "hanging_request",
|
444 |
+
request_data: Optional[dict] = None,
|
445 |
+
):
|
446 |
+
if self.alerting is None or self.alert_types is None:
|
447 |
+
return
|
448 |
+
model: str = ""
|
449 |
+
if request_data is not None:
|
450 |
+
model = request_data.get("model", "")
|
451 |
+
messages = request_data.get("messages", None)
|
452 |
+
if messages is None:
|
453 |
+
# if messages does not exist fallback to "input"
|
454 |
+
messages = request_data.get("input", None)
|
455 |
+
|
456 |
+
# try casting messages to str and get the first 100 characters, else mark as None
|
457 |
+
try:
|
458 |
+
messages = str(messages)
|
459 |
+
messages = messages[:100]
|
460 |
+
except Exception:
|
461 |
+
messages = ""
|
462 |
+
|
463 |
+
if (
|
464 |
+
litellm.turn_off_message_logging
|
465 |
+
or litellm.redact_messages_in_exceptions
|
466 |
+
):
|
467 |
+
messages = (
|
468 |
+
"Message not logged. litellm.redact_messages_in_exceptions=True"
|
469 |
+
)
|
470 |
+
request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`"
|
471 |
+
else:
|
472 |
+
request_info = ""
|
473 |
+
|
474 |
+
if type == "hanging_request":
|
475 |
+
await asyncio.sleep(
|
476 |
+
self.alerting_threshold
|
477 |
+
) # Set it to 5 minutes - i'd imagine this might be different for streaming, non-streaming, non-completion (embedding + img) requests
|
478 |
+
alerting_metadata: dict = {}
|
479 |
+
if await self._request_is_completed(request_data=request_data) is True:
|
480 |
+
return
|
481 |
+
|
482 |
+
if request_data is not None:
|
483 |
+
if request_data.get("deployment", None) is not None and isinstance(
|
484 |
+
request_data["deployment"], dict
|
485 |
+
):
|
486 |
+
_api_base = litellm.get_api_base(
|
487 |
+
model=model,
|
488 |
+
optional_params=request_data["deployment"].get(
|
489 |
+
"litellm_params", {}
|
490 |
+
),
|
491 |
+
)
|
492 |
+
|
493 |
+
if _api_base is None:
|
494 |
+
_api_base = ""
|
495 |
+
|
496 |
+
request_info += f"\nAPI Base: {_api_base}"
|
497 |
+
elif request_data.get("metadata", None) is not None and isinstance(
|
498 |
+
request_data["metadata"], dict
|
499 |
+
):
|
500 |
+
# In hanging requests sometime it has not made it to the point where the deployment is passed to the `request_data``
|
501 |
+
# in that case we fallback to the api base set in the request metadata
|
502 |
+
_metadata: dict = request_data["metadata"]
|
503 |
+
_api_base = _metadata.get("api_base", "")
|
504 |
+
|
505 |
+
request_info = _add_key_name_and_team_to_alert(
|
506 |
+
request_info=request_info, metadata=_metadata
|
507 |
+
)
|
508 |
+
|
509 |
+
if _api_base is None:
|
510 |
+
_api_base = ""
|
511 |
+
|
512 |
+
if "alerting_metadata" in _metadata:
|
513 |
+
alerting_metadata = _metadata["alerting_metadata"]
|
514 |
+
request_info += f"\nAPI Base: `{_api_base}`"
|
515 |
+
# only alert hanging responses if they have not been marked as success
|
516 |
+
alerting_message = (
|
517 |
+
f"`Requests are hanging - {self.alerting_threshold}s+ request time`"
|
518 |
+
)
|
519 |
+
|
520 |
+
if "langfuse" in litellm.success_callback:
|
521 |
+
langfuse_url = await _add_langfuse_trace_id_to_alert(
|
522 |
+
request_data=request_data,
|
523 |
+
)
|
524 |
+
|
525 |
+
if langfuse_url is not None:
|
526 |
+
request_info += "\n🪢 Langfuse Trace: {}".format(langfuse_url)
|
527 |
+
|
528 |
+
# add deployment latencies to alert
|
529 |
+
_deployment_latency_map = self._get_deployment_latencies_to_alert(
|
530 |
+
metadata=request_data.get("metadata", {})
|
531 |
+
)
|
532 |
+
if _deployment_latency_map is not None:
|
533 |
+
request_info += f"\nDeployment Latencies\n{_deployment_latency_map}"
|
534 |
+
|
535 |
+
await self.send_alert(
|
536 |
+
message=alerting_message + request_info,
|
537 |
+
level="Medium",
|
538 |
+
alert_type=AlertType.llm_requests_hanging,
|
539 |
+
alerting_metadata=alerting_metadata,
|
540 |
+
)
|
541 |
+
|
542 |
+
async def failed_tracking_alert(self, error_message: str, failing_model: str):
|
543 |
+
"""
|
544 |
+
Raise alert when tracking failed for specific model
|
545 |
+
|
546 |
+
Args:
|
547 |
+
error_message (str): Error message
|
548 |
+
failing_model (str): Model that failed tracking
|
549 |
+
"""
|
550 |
+
if self.alerting is None or self.alert_types is None:
|
551 |
+
# do nothing if alerting is not switched on
|
552 |
+
return
|
553 |
+
if "failed_tracking_spend" not in self.alert_types:
|
554 |
+
return
|
555 |
+
|
556 |
+
_cache: DualCache = self.internal_usage_cache
|
557 |
+
message = "Failed Tracking Cost for " + error_message
|
558 |
+
_cache_key = "budget_alerts:failed_tracking:{}".format(failing_model)
|
559 |
+
result = await _cache.async_get_cache(key=_cache_key)
|
560 |
+
if result is None:
|
561 |
+
await self.send_alert(
|
562 |
+
message=message,
|
563 |
+
level="High",
|
564 |
+
alert_type=AlertType.failed_tracking_spend,
|
565 |
+
alerting_metadata={},
|
566 |
+
)
|
567 |
+
await _cache.async_set_cache(
|
568 |
+
key=_cache_key,
|
569 |
+
value="SENT",
|
570 |
+
ttl=self.alerting_args.budget_alert_ttl,
|
571 |
+
)
|
572 |
+
|
573 |
+
async def budget_alerts( # noqa: PLR0915
|
574 |
+
self,
|
575 |
+
type: Literal[
|
576 |
+
"token_budget",
|
577 |
+
"soft_budget",
|
578 |
+
"user_budget",
|
579 |
+
"team_budget",
|
580 |
+
"proxy_budget",
|
581 |
+
"projected_limit_exceeded",
|
582 |
+
],
|
583 |
+
user_info: CallInfo,
|
584 |
+
):
|
585 |
+
## PREVENTITIVE ALERTING ## - https://github.com/BerriAI/litellm/issues/2727
|
586 |
+
# - Alert once within 24hr period
|
587 |
+
# - Cache this information
|
588 |
+
# - Don't re-alert, if alert already sent
|
589 |
+
_cache: DualCache = self.internal_usage_cache
|
590 |
+
|
591 |
+
if self.alerting is None or self.alert_types is None:
|
592 |
+
# do nothing if alerting is not switched on
|
593 |
+
return
|
594 |
+
if "budget_alerts" not in self.alert_types:
|
595 |
+
return
|
596 |
+
_id: Optional[str] = "default_id" # used for caching
|
597 |
+
user_info_json = user_info.model_dump(exclude_none=True)
|
598 |
+
user_info_str = self._get_user_info_str(user_info)
|
599 |
+
event: Optional[
|
600 |
+
Literal[
|
601 |
+
"budget_crossed",
|
602 |
+
"threshold_crossed",
|
603 |
+
"projected_limit_exceeded",
|
604 |
+
"soft_budget_crossed",
|
605 |
+
]
|
606 |
+
] = None
|
607 |
+
event_group: Optional[
|
608 |
+
Literal["internal_user", "team", "key", "proxy", "customer"]
|
609 |
+
] = None
|
610 |
+
event_message: str = ""
|
611 |
+
webhook_event: Optional[WebhookEvent] = None
|
612 |
+
if type == "proxy_budget":
|
613 |
+
event_group = "proxy"
|
614 |
+
event_message += "Proxy Budget: "
|
615 |
+
elif type == "soft_budget":
|
616 |
+
event_group = "proxy"
|
617 |
+
event_message += "Soft Budget Crossed: "
|
618 |
+
elif type == "user_budget":
|
619 |
+
event_group = "internal_user"
|
620 |
+
event_message += "User Budget: "
|
621 |
+
_id = user_info.user_id or _id
|
622 |
+
elif type == "team_budget":
|
623 |
+
event_group = "team"
|
624 |
+
event_message += "Team Budget: "
|
625 |
+
_id = user_info.team_id or _id
|
626 |
+
elif type == "token_budget":
|
627 |
+
event_group = "key"
|
628 |
+
event_message += "Key Budget: "
|
629 |
+
_id = user_info.token
|
630 |
+
elif type == "projected_limit_exceeded":
|
631 |
+
event_group = "key"
|
632 |
+
event_message += "Key Budget: Projected Limit Exceeded"
|
633 |
+
event = "projected_limit_exceeded"
|
634 |
+
_id = user_info.token
|
635 |
+
|
636 |
+
# percent of max_budget left to spend
|
637 |
+
if user_info.max_budget is None and user_info.soft_budget is None:
|
638 |
+
return
|
639 |
+
percent_left: float = 0
|
640 |
+
if user_info.max_budget is not None:
|
641 |
+
if user_info.max_budget > 0:
|
642 |
+
percent_left = (
|
643 |
+
user_info.max_budget - user_info.spend
|
644 |
+
) / user_info.max_budget
|
645 |
+
|
646 |
+
# check if crossed budget
|
647 |
+
if user_info.max_budget is not None:
|
648 |
+
if user_info.spend >= user_info.max_budget:
|
649 |
+
event = "budget_crossed"
|
650 |
+
event_message += (
|
651 |
+
f"Budget Crossed\n Total Budget:`{user_info.max_budget}`"
|
652 |
+
)
|
653 |
+
elif percent_left <= SLACK_ALERTING_THRESHOLD_5_PERCENT:
|
654 |
+
event = "threshold_crossed"
|
655 |
+
event_message += "5% Threshold Crossed "
|
656 |
+
elif percent_left <= SLACK_ALERTING_THRESHOLD_15_PERCENT:
|
657 |
+
event = "threshold_crossed"
|
658 |
+
event_message += "15% Threshold Crossed"
|
659 |
+
elif user_info.soft_budget is not None:
|
660 |
+
if user_info.spend >= user_info.soft_budget:
|
661 |
+
event = "soft_budget_crossed"
|
662 |
+
if event is not None and event_group is not None:
|
663 |
+
_cache_key = "budget_alerts:{}:{}".format(event, _id)
|
664 |
+
result = await _cache.async_get_cache(key=_cache_key)
|
665 |
+
if result is None:
|
666 |
+
webhook_event = WebhookEvent(
|
667 |
+
event=event,
|
668 |
+
event_group=event_group,
|
669 |
+
event_message=event_message,
|
670 |
+
**user_info_json,
|
671 |
+
)
|
672 |
+
await self.send_alert(
|
673 |
+
message=event_message + "\n\n" + user_info_str,
|
674 |
+
level="High",
|
675 |
+
alert_type=AlertType.budget_alerts,
|
676 |
+
user_info=webhook_event,
|
677 |
+
alerting_metadata={},
|
678 |
+
)
|
679 |
+
await _cache.async_set_cache(
|
680 |
+
key=_cache_key,
|
681 |
+
value="SENT",
|
682 |
+
ttl=self.alerting_args.budget_alert_ttl,
|
683 |
+
)
|
684 |
+
|
685 |
+
return
|
686 |
+
return
|
687 |
+
|
688 |
+
def _get_user_info_str(self, user_info: CallInfo) -> str:
|
689 |
+
"""
|
690 |
+
Create a standard message for a budget alert
|
691 |
+
"""
|
692 |
+
_all_fields_as_dict = user_info.model_dump(exclude_none=True)
|
693 |
+
_all_fields_as_dict.pop("token")
|
694 |
+
msg = ""
|
695 |
+
for k, v in _all_fields_as_dict.items():
|
696 |
+
msg += f"*{k}:* `{v}`\n"
|
697 |
+
|
698 |
+
return msg
|
699 |
+
|
700 |
+
async def customer_spend_alert(
|
701 |
+
self,
|
702 |
+
token: Optional[str],
|
703 |
+
key_alias: Optional[str],
|
704 |
+
end_user_id: Optional[str],
|
705 |
+
response_cost: Optional[float],
|
706 |
+
max_budget: Optional[float],
|
707 |
+
):
|
708 |
+
if (
|
709 |
+
self.alerting is not None
|
710 |
+
and "webhook" in self.alerting
|
711 |
+
and end_user_id is not None
|
712 |
+
and token is not None
|
713 |
+
and response_cost is not None
|
714 |
+
):
|
715 |
+
# log customer spend
|
716 |
+
event = WebhookEvent(
|
717 |
+
spend=response_cost,
|
718 |
+
max_budget=max_budget,
|
719 |
+
token=token,
|
720 |
+
customer_id=end_user_id,
|
721 |
+
user_id=None,
|
722 |
+
team_id=None,
|
723 |
+
user_email=None,
|
724 |
+
key_alias=key_alias,
|
725 |
+
projected_exceeded_date=None,
|
726 |
+
projected_spend=None,
|
727 |
+
event="spend_tracked",
|
728 |
+
event_group="customer",
|
729 |
+
event_message="Customer spend tracked. Customer={}, spend={}".format(
|
730 |
+
end_user_id, response_cost
|
731 |
+
),
|
732 |
+
)
|
733 |
+
|
734 |
+
await self.send_webhook_alert(webhook_event=event)
|
735 |
+
|
736 |
+
def _count_outage_alerts(self, alerts: List[int]) -> str:
|
737 |
+
"""
|
738 |
+
Parameters:
|
739 |
+
- alerts: List[int] -> list of error codes (either 408 or 500+)
|
740 |
+
|
741 |
+
Returns:
|
742 |
+
- str -> formatted string. This is an alert message, giving a human-friendly description of the errors.
|
743 |
+
"""
|
744 |
+
error_breakdown = {"Timeout Errors": 0, "API Errors": 0, "Unknown Errors": 0}
|
745 |
+
for alert in alerts:
|
746 |
+
if alert == 408:
|
747 |
+
error_breakdown["Timeout Errors"] += 1
|
748 |
+
elif alert >= 500:
|
749 |
+
error_breakdown["API Errors"] += 1
|
750 |
+
else:
|
751 |
+
error_breakdown["Unknown Errors"] += 1
|
752 |
+
|
753 |
+
error_msg = ""
|
754 |
+
for key, value in error_breakdown.items():
|
755 |
+
if value > 0:
|
756 |
+
error_msg += "\n{}: {}\n".format(key, value)
|
757 |
+
|
758 |
+
return error_msg
|
759 |
+
|
760 |
+
def _outage_alert_msg_factory(
|
761 |
+
self,
|
762 |
+
alert_type: Literal["Major", "Minor"],
|
763 |
+
key: Literal["Model", "Region"],
|
764 |
+
key_val: str,
|
765 |
+
provider: str,
|
766 |
+
api_base: Optional[str],
|
767 |
+
outage_value: BaseOutageModel,
|
768 |
+
) -> str:
|
769 |
+
"""Format an alert message for slack"""
|
770 |
+
headers = {f"{key} Name": key_val, "Provider": provider}
|
771 |
+
if api_base is not None:
|
772 |
+
headers["API Base"] = api_base # type: ignore
|
773 |
+
|
774 |
+
headers_str = "\n"
|
775 |
+
for k, v in headers.items():
|
776 |
+
headers_str += f"*{k}:* `{v}`\n"
|
777 |
+
return f"""\n\n
|
778 |
+
*⚠️ {alert_type} Service Outage*
|
779 |
+
|
780 |
+
{headers_str}
|
781 |
+
|
782 |
+
*Errors:*
|
783 |
+
{self._count_outage_alerts(alerts=outage_value["alerts"])}
|
784 |
+
|
785 |
+
*Last Check:* `{round(time.time() - outage_value["last_updated_at"], 4)}s ago`\n\n
|
786 |
+
"""
|
787 |
+
|
788 |
+
async def region_outage_alerts(
|
789 |
+
self,
|
790 |
+
exception: APIError,
|
791 |
+
deployment_id: str,
|
792 |
+
) -> None:
|
793 |
+
"""
|
794 |
+
Send slack alert if specific provider region is having an outage.
|
795 |
+
|
796 |
+
Track for 408 (Timeout) and >=500 Error codes
|
797 |
+
"""
|
798 |
+
## CREATE (PROVIDER+REGION) ID ##
|
799 |
+
if self.llm_router is None:
|
800 |
+
return
|
801 |
+
|
802 |
+
deployment = self.llm_router.get_deployment(model_id=deployment_id)
|
803 |
+
|
804 |
+
if deployment is None:
|
805 |
+
return
|
806 |
+
|
807 |
+
model = deployment.litellm_params.model
|
808 |
+
### GET PROVIDER ###
|
809 |
+
provider = deployment.litellm_params.custom_llm_provider
|
810 |
+
if provider is None:
|
811 |
+
model, provider, _, _ = litellm.get_llm_provider(model=model)
|
812 |
+
|
813 |
+
### GET REGION ###
|
814 |
+
region_name = deployment.litellm_params.region_name
|
815 |
+
if region_name is None:
|
816 |
+
region_name = litellm.utils._get_model_region(
|
817 |
+
custom_llm_provider=provider, litellm_params=deployment.litellm_params
|
818 |
+
)
|
819 |
+
|
820 |
+
if region_name is None:
|
821 |
+
return
|
822 |
+
|
823 |
+
### UNIQUE CACHE KEY ###
|
824 |
+
cache_key = provider + region_name
|
825 |
+
|
826 |
+
outage_value: Optional[
|
827 |
+
ProviderRegionOutageModel
|
828 |
+
] = await self.internal_usage_cache.async_get_cache(key=cache_key)
|
829 |
+
|
830 |
+
if (
|
831 |
+
getattr(exception, "status_code", None) is None
|
832 |
+
or (
|
833 |
+
exception.status_code != 408 # type: ignore
|
834 |
+
and exception.status_code < 500 # type: ignore
|
835 |
+
)
|
836 |
+
or self.llm_router is None
|
837 |
+
):
|
838 |
+
return
|
839 |
+
|
840 |
+
if outage_value is None:
|
841 |
+
_deployment_set = set()
|
842 |
+
_deployment_set.add(deployment_id)
|
843 |
+
outage_value = ProviderRegionOutageModel(
|
844 |
+
provider_region_id=cache_key,
|
845 |
+
alerts=[exception.status_code], # type: ignore
|
846 |
+
minor_alert_sent=False,
|
847 |
+
major_alert_sent=False,
|
848 |
+
last_updated_at=time.time(),
|
849 |
+
deployment_ids=_deployment_set,
|
850 |
+
)
|
851 |
+
|
852 |
+
## add to cache ##
|
853 |
+
await self.internal_usage_cache.async_set_cache(
|
854 |
+
key=cache_key,
|
855 |
+
value=outage_value,
|
856 |
+
ttl=self.alerting_args.region_outage_alert_ttl,
|
857 |
+
)
|
858 |
+
return
|
859 |
+
|
860 |
+
if len(outage_value["alerts"]) < self.alerting_args.max_outage_alert_list_size:
|
861 |
+
outage_value["alerts"].append(exception.status_code) # type: ignore
|
862 |
+
else: # prevent memory leaks
|
863 |
+
pass
|
864 |
+
_deployment_set = outage_value["deployment_ids"]
|
865 |
+
_deployment_set.add(deployment_id)
|
866 |
+
outage_value["deployment_ids"] = _deployment_set
|
867 |
+
outage_value["last_updated_at"] = time.time()
|
868 |
+
|
869 |
+
## MINOR OUTAGE ALERT SENT ##
|
870 |
+
if (
|
871 |
+
outage_value["minor_alert_sent"] is False
|
872 |
+
and len(outage_value["alerts"])
|
873 |
+
>= self.alerting_args.minor_outage_alert_threshold
|
874 |
+
and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
|
875 |
+
):
|
876 |
+
msg = self._outage_alert_msg_factory(
|
877 |
+
alert_type="Minor",
|
878 |
+
key="Region",
|
879 |
+
key_val=region_name,
|
880 |
+
api_base=None,
|
881 |
+
outage_value=outage_value,
|
882 |
+
provider=provider,
|
883 |
+
)
|
884 |
+
# send minor alert
|
885 |
+
await self.send_alert(
|
886 |
+
message=msg,
|
887 |
+
level="Medium",
|
888 |
+
alert_type=AlertType.outage_alerts,
|
889 |
+
alerting_metadata={},
|
890 |
+
)
|
891 |
+
# set to true
|
892 |
+
outage_value["minor_alert_sent"] = True
|
893 |
+
|
894 |
+
## MAJOR OUTAGE ALERT SENT ##
|
895 |
+
elif (
|
896 |
+
outage_value["major_alert_sent"] is False
|
897 |
+
and len(outage_value["alerts"])
|
898 |
+
>= self.alerting_args.major_outage_alert_threshold
|
899 |
+
and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
|
900 |
+
):
|
901 |
+
msg = self._outage_alert_msg_factory(
|
902 |
+
alert_type="Major",
|
903 |
+
key="Region",
|
904 |
+
key_val=region_name,
|
905 |
+
api_base=None,
|
906 |
+
outage_value=outage_value,
|
907 |
+
provider=provider,
|
908 |
+
)
|
909 |
+
|
910 |
+
# send minor alert
|
911 |
+
await self.send_alert(
|
912 |
+
message=msg,
|
913 |
+
level="High",
|
914 |
+
alert_type=AlertType.outage_alerts,
|
915 |
+
alerting_metadata={},
|
916 |
+
)
|
917 |
+
# set to true
|
918 |
+
outage_value["major_alert_sent"] = True
|
919 |
+
|
920 |
+
## update cache ##
|
921 |
+
await self.internal_usage_cache.async_set_cache(
|
922 |
+
key=cache_key, value=outage_value
|
923 |
+
)
|
924 |
+
|
925 |
+
async def outage_alerts(
|
926 |
+
self,
|
927 |
+
exception: APIError,
|
928 |
+
deployment_id: str,
|
929 |
+
) -> None:
|
930 |
+
"""
|
931 |
+
Send slack alert if model is badly configured / having an outage (408, 401, 429, >=500).
|
932 |
+
|
933 |
+
key = model_id
|
934 |
+
|
935 |
+
value = {
|
936 |
+
- model_id
|
937 |
+
- threshold
|
938 |
+
- alerts []
|
939 |
+
}
|
940 |
+
|
941 |
+
ttl = 1hr
|
942 |
+
max_alerts_size = 10
|
943 |
+
"""
|
944 |
+
try:
|
945 |
+
outage_value: Optional[OutageModel] = await self.internal_usage_cache.async_get_cache(key=deployment_id) # type: ignore
|
946 |
+
if (
|
947 |
+
getattr(exception, "status_code", None) is None
|
948 |
+
or (
|
949 |
+
exception.status_code != 408 # type: ignore
|
950 |
+
and exception.status_code < 500 # type: ignore
|
951 |
+
)
|
952 |
+
or self.llm_router is None
|
953 |
+
):
|
954 |
+
return
|
955 |
+
|
956 |
+
### EXTRACT MODEL DETAILS ###
|
957 |
+
deployment = self.llm_router.get_deployment(model_id=deployment_id)
|
958 |
+
if deployment is None:
|
959 |
+
return
|
960 |
+
|
961 |
+
model = deployment.litellm_params.model
|
962 |
+
provider = deployment.litellm_params.custom_llm_provider
|
963 |
+
if provider is None:
|
964 |
+
try:
|
965 |
+
model, provider, _, _ = litellm.get_llm_provider(model=model)
|
966 |
+
except Exception:
|
967 |
+
provider = ""
|
968 |
+
api_base = litellm.get_api_base(
|
969 |
+
model=model, optional_params=deployment.litellm_params
|
970 |
+
)
|
971 |
+
|
972 |
+
if outage_value is None:
|
973 |
+
outage_value = OutageModel(
|
974 |
+
model_id=deployment_id,
|
975 |
+
alerts=[exception.status_code], # type: ignore
|
976 |
+
minor_alert_sent=False,
|
977 |
+
major_alert_sent=False,
|
978 |
+
last_updated_at=time.time(),
|
979 |
+
)
|
980 |
+
|
981 |
+
## add to cache ##
|
982 |
+
await self.internal_usage_cache.async_set_cache(
|
983 |
+
key=deployment_id,
|
984 |
+
value=outage_value,
|
985 |
+
ttl=self.alerting_args.outage_alert_ttl,
|
986 |
+
)
|
987 |
+
return
|
988 |
+
|
989 |
+
if (
|
990 |
+
len(outage_value["alerts"])
|
991 |
+
< self.alerting_args.max_outage_alert_list_size
|
992 |
+
):
|
993 |
+
outage_value["alerts"].append(exception.status_code) # type: ignore
|
994 |
+
else: # prevent memory leaks
|
995 |
+
pass
|
996 |
+
|
997 |
+
outage_value["last_updated_at"] = time.time()
|
998 |
+
|
999 |
+
## MINOR OUTAGE ALERT SENT ##
|
1000 |
+
if (
|
1001 |
+
outage_value["minor_alert_sent"] is False
|
1002 |
+
and len(outage_value["alerts"])
|
1003 |
+
>= self.alerting_args.minor_outage_alert_threshold
|
1004 |
+
):
|
1005 |
+
msg = self._outage_alert_msg_factory(
|
1006 |
+
alert_type="Minor",
|
1007 |
+
key="Model",
|
1008 |
+
key_val=model,
|
1009 |
+
api_base=api_base,
|
1010 |
+
outage_value=outage_value,
|
1011 |
+
provider=provider,
|
1012 |
+
)
|
1013 |
+
# send minor alert
|
1014 |
+
await self.send_alert(
|
1015 |
+
message=msg,
|
1016 |
+
level="Medium",
|
1017 |
+
alert_type=AlertType.outage_alerts,
|
1018 |
+
alerting_metadata={},
|
1019 |
+
)
|
1020 |
+
# set to true
|
1021 |
+
outage_value["minor_alert_sent"] = True
|
1022 |
+
elif (
|
1023 |
+
outage_value["major_alert_sent"] is False
|
1024 |
+
and len(outage_value["alerts"])
|
1025 |
+
>= self.alerting_args.major_outage_alert_threshold
|
1026 |
+
):
|
1027 |
+
msg = self._outage_alert_msg_factory(
|
1028 |
+
alert_type="Major",
|
1029 |
+
key="Model",
|
1030 |
+
key_val=model,
|
1031 |
+
api_base=api_base,
|
1032 |
+
outage_value=outage_value,
|
1033 |
+
provider=provider,
|
1034 |
+
)
|
1035 |
+
# send minor alert
|
1036 |
+
await self.send_alert(
|
1037 |
+
message=msg,
|
1038 |
+
level="High",
|
1039 |
+
alert_type=AlertType.outage_alerts,
|
1040 |
+
alerting_metadata={},
|
1041 |
+
)
|
1042 |
+
# set to true
|
1043 |
+
outage_value["major_alert_sent"] = True
|
1044 |
+
|
1045 |
+
## update cache ##
|
1046 |
+
await self.internal_usage_cache.async_set_cache(
|
1047 |
+
key=deployment_id, value=outage_value
|
1048 |
+
)
|
1049 |
+
except Exception:
|
1050 |
+
pass
|
1051 |
+
|
1052 |
+
async def model_added_alert(
|
1053 |
+
self, model_name: str, litellm_model_name: str, passed_model_info: Any
|
1054 |
+
):
|
1055 |
+
base_model_from_user = getattr(passed_model_info, "base_model", None)
|
1056 |
+
model_info = {}
|
1057 |
+
base_model = ""
|
1058 |
+
if base_model_from_user is not None:
|
1059 |
+
model_info = litellm.model_cost.get(base_model_from_user, {})
|
1060 |
+
base_model = f"Base Model: `{base_model_from_user}`\n"
|
1061 |
+
else:
|
1062 |
+
model_info = litellm.model_cost.get(litellm_model_name, {})
|
1063 |
+
model_info_str = ""
|
1064 |
+
for k, v in model_info.items():
|
1065 |
+
if k == "input_cost_per_token" or k == "output_cost_per_token":
|
1066 |
+
# when converting to string it should not be 1.63e-06
|
1067 |
+
v = "{:.8f}".format(v)
|
1068 |
+
|
1069 |
+
model_info_str += f"{k}: {v}\n"
|
1070 |
+
|
1071 |
+
message = f"""
|
1072 |
+
*🚅 New Model Added*
|
1073 |
+
Model Name: `{model_name}`
|
1074 |
+
{base_model}
|
1075 |
+
|
1076 |
+
Usage OpenAI Python SDK:
|
1077 |
+
```
|
1078 |
+
import openai
|
1079 |
+
client = openai.OpenAI(
|
1080 |
+
api_key="your_api_key",
|
1081 |
+
base_url={os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")}
|
1082 |
+
)
|
1083 |
+
|
1084 |
+
response = client.chat.completions.create(
|
1085 |
+
model="{model_name}", # model to send to the proxy
|
1086 |
+
messages = [
|
1087 |
+
{{
|
1088 |
+
"role": "user",
|
1089 |
+
"content": "this is a test request, write a short poem"
|
1090 |
+
}}
|
1091 |
+
]
|
1092 |
+
)
|
1093 |
+
```
|
1094 |
+
|
1095 |
+
Model Info:
|
1096 |
+
```
|
1097 |
+
{model_info_str}
|
1098 |
+
```
|
1099 |
+
"""
|
1100 |
+
|
1101 |
+
alert_val = self.send_alert(
|
1102 |
+
message=message,
|
1103 |
+
level="Low",
|
1104 |
+
alert_type=AlertType.new_model_added,
|
1105 |
+
alerting_metadata={},
|
1106 |
+
)
|
1107 |
+
|
1108 |
+
if alert_val is not None and asyncio.iscoroutine(alert_val):
|
1109 |
+
await alert_val
|
1110 |
+
|
1111 |
+
async def model_removed_alert(self, model_name: str):
|
1112 |
+
pass
|
1113 |
+
|
1114 |
+
async def send_webhook_alert(self, webhook_event: WebhookEvent) -> bool:
|
1115 |
+
"""
|
1116 |
+
Sends structured alert to webhook, if set.
|
1117 |
+
|
1118 |
+
Currently only implemented for budget alerts
|
1119 |
+
|
1120 |
+
Returns -> True if sent, False if not.
|
1121 |
+
|
1122 |
+
Raises Exception
|
1123 |
+
- if WEBHOOK_URL is not set
|
1124 |
+
"""
|
1125 |
+
|
1126 |
+
webhook_url = os.getenv("WEBHOOK_URL", None)
|
1127 |
+
if webhook_url is None:
|
1128 |
+
raise Exception("Missing webhook_url from environment")
|
1129 |
+
|
1130 |
+
payload = webhook_event.model_dump_json()
|
1131 |
+
headers = {"Content-type": "application/json"}
|
1132 |
+
|
1133 |
+
response = await self.async_http_handler.post(
|
1134 |
+
url=webhook_url,
|
1135 |
+
headers=headers,
|
1136 |
+
data=payload,
|
1137 |
+
)
|
1138 |
+
if response.status_code == 200:
|
1139 |
+
return True
|
1140 |
+
else:
|
1141 |
+
print("Error sending webhook alert. Error=", response.text) # noqa
|
1142 |
+
|
1143 |
+
return False
|
1144 |
+
|
1145 |
+
async def _check_if_using_premium_email_feature(
|
1146 |
+
self,
|
1147 |
+
premium_user: bool,
|
1148 |
+
email_logo_url: Optional[str] = None,
|
1149 |
+
email_support_contact: Optional[str] = None,
|
1150 |
+
):
|
1151 |
+
from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
|
1152 |
+
|
1153 |
+
if premium_user is not True:
|
1154 |
+
if email_logo_url is not None or email_support_contact is not None:
|
1155 |
+
raise ValueError(
|
1156 |
+
f"Trying to Customize Email Alerting\n {CommonProxyErrors.not_premium_user.value}"
|
1157 |
+
)
|
1158 |
+
return
|
1159 |
+
|
1160 |
+
async def send_key_created_or_user_invited_email(
|
1161 |
+
self, webhook_event: WebhookEvent
|
1162 |
+
) -> bool:
|
1163 |
+
try:
|
1164 |
+
from litellm.proxy.utils import send_email
|
1165 |
+
|
1166 |
+
if self.alerting is None or "email" not in self.alerting:
|
1167 |
+
# do nothing if user does not want email alerts
|
1168 |
+
verbose_proxy_logger.error(
|
1169 |
+
"Error sending email alert - 'email' not in self.alerting %s",
|
1170 |
+
self.alerting,
|
1171 |
+
)
|
1172 |
+
return False
|
1173 |
+
from litellm.proxy.proxy_server import premium_user, prisma_client
|
1174 |
+
|
1175 |
+
email_logo_url = os.getenv(
|
1176 |
+
"SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None)
|
1177 |
+
)
|
1178 |
+
email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
|
1179 |
+
await self._check_if_using_premium_email_feature(
|
1180 |
+
premium_user, email_logo_url, email_support_contact
|
1181 |
+
)
|
1182 |
+
if email_logo_url is None:
|
1183 |
+
email_logo_url = LITELLM_LOGO_URL
|
1184 |
+
if email_support_contact is None:
|
1185 |
+
email_support_contact = LITELLM_SUPPORT_CONTACT
|
1186 |
+
|
1187 |
+
event_name = webhook_event.event_message
|
1188 |
+
recipient_email = webhook_event.user_email
|
1189 |
+
recipient_user_id = webhook_event.user_id
|
1190 |
+
if (
|
1191 |
+
recipient_email is None
|
1192 |
+
and recipient_user_id is not None
|
1193 |
+
and prisma_client is not None
|
1194 |
+
):
|
1195 |
+
user_row = await prisma_client.db.litellm_usertable.find_unique(
|
1196 |
+
where={"user_id": recipient_user_id}
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
if user_row is not None:
|
1200 |
+
recipient_email = user_row.user_email
|
1201 |
+
|
1202 |
+
key_token = webhook_event.token
|
1203 |
+
key_budget = webhook_event.max_budget
|
1204 |
+
base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")
|
1205 |
+
|
1206 |
+
email_html_content = "Alert from LiteLLM Server"
|
1207 |
+
if recipient_email is None:
|
1208 |
+
verbose_proxy_logger.error(
|
1209 |
+
"Trying to send email alert to no recipient",
|
1210 |
+
extra=webhook_event.dict(),
|
1211 |
+
)
|
1212 |
+
|
1213 |
+
if webhook_event.event == "key_created":
|
1214 |
+
email_html_content = KEY_CREATED_EMAIL_TEMPLATE.format(
|
1215 |
+
email_logo_url=email_logo_url,
|
1216 |
+
recipient_email=recipient_email,
|
1217 |
+
key_budget=key_budget,
|
1218 |
+
key_token=key_token,
|
1219 |
+
base_url=base_url,
|
1220 |
+
email_support_contact=email_support_contact,
|
1221 |
+
)
|
1222 |
+
elif webhook_event.event == "internal_user_created":
|
1223 |
+
# GET TEAM NAME
|
1224 |
+
team_id = webhook_event.team_id
|
1225 |
+
team_name = "Default Team"
|
1226 |
+
if team_id is not None and prisma_client is not None:
|
1227 |
+
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
1228 |
+
where={"team_id": team_id}
|
1229 |
+
)
|
1230 |
+
if team_row is not None:
|
1231 |
+
team_name = team_row.team_alias or "-"
|
1232 |
+
email_html_content = USER_INVITED_EMAIL_TEMPLATE.format(
|
1233 |
+
email_logo_url=email_logo_url,
|
1234 |
+
recipient_email=recipient_email,
|
1235 |
+
team_name=team_name,
|
1236 |
+
base_url=base_url,
|
1237 |
+
email_support_contact=email_support_contact,
|
1238 |
+
)
|
1239 |
+
else:
|
1240 |
+
verbose_proxy_logger.error(
|
1241 |
+
"Trying to send email alert on unknown webhook event",
|
1242 |
+
extra=webhook_event.model_dump(),
|
1243 |
+
)
|
1244 |
+
|
1245 |
+
webhook_event.model_dump_json()
|
1246 |
+
email_event = {
|
1247 |
+
"to": recipient_email,
|
1248 |
+
"subject": f"LiteLLM: {event_name}",
|
1249 |
+
"html": email_html_content,
|
1250 |
+
}
|
1251 |
+
|
1252 |
+
await send_email(
|
1253 |
+
receiver_email=email_event["to"],
|
1254 |
+
subject=email_event["subject"],
|
1255 |
+
html=email_event["html"],
|
1256 |
+
)
|
1257 |
+
|
1258 |
+
return True
|
1259 |
+
|
1260 |
+
except Exception as e:
|
1261 |
+
verbose_proxy_logger.error("Error sending email alert %s", str(e))
|
1262 |
+
return False
|
1263 |
+
|
1264 |
+
async def send_email_alert_using_smtp(
|
1265 |
+
self, webhook_event: WebhookEvent, alert_type: str
|
1266 |
+
) -> bool:
|
1267 |
+
"""
|
1268 |
+
Sends structured Email alert to an SMTP server
|
1269 |
+
|
1270 |
+
Currently only implemented for budget alerts
|
1271 |
+
|
1272 |
+
Returns -> True if sent, False if not.
|
1273 |
+
"""
|
1274 |
+
from litellm.proxy.proxy_server import premium_user
|
1275 |
+
from litellm.proxy.utils import send_email
|
1276 |
+
|
1277 |
+
email_logo_url = os.getenv(
|
1278 |
+
"SMTP_SENDER_LOGO", os.getenv("EMAIL_LOGO_URL", None)
|
1279 |
+
)
|
1280 |
+
email_support_contact = os.getenv("EMAIL_SUPPORT_CONTACT", None)
|
1281 |
+
await self._check_if_using_premium_email_feature(
|
1282 |
+
premium_user, email_logo_url, email_support_contact
|
1283 |
+
)
|
1284 |
+
|
1285 |
+
if email_logo_url is None:
|
1286 |
+
email_logo_url = LITELLM_LOGO_URL
|
1287 |
+
if email_support_contact is None:
|
1288 |
+
email_support_contact = LITELLM_SUPPORT_CONTACT
|
1289 |
+
|
1290 |
+
event_name = webhook_event.event_message
|
1291 |
+
recipient_email = webhook_event.user_email
|
1292 |
+
user_name = webhook_event.user_id
|
1293 |
+
max_budget = webhook_event.max_budget
|
1294 |
+
email_html_content = "Alert from LiteLLM Server"
|
1295 |
+
if recipient_email is None:
|
1296 |
+
verbose_proxy_logger.error(
|
1297 |
+
"Trying to send email alert to no recipient", extra=webhook_event.dict()
|
1298 |
+
)
|
1299 |
+
|
1300 |
+
if webhook_event.event == "budget_crossed":
|
1301 |
+
email_html_content = f"""
|
1302 |
+
<img src="{email_logo_url}" alt="LiteLLM Logo" width="150" height="50" />
|
1303 |
+
|
1304 |
+
<p> Hi {user_name}, <br/>
|
1305 |
+
|
1306 |
+
Your LLM API usage this month has reached your account's <b> monthly budget of ${max_budget} </b> <br /> <br />
|
1307 |
+
|
1308 |
+
API requests will be rejected until either (a) you increase your monthly budget or (b) your monthly usage resets at the beginning of the next calendar month. <br /> <br />
|
1309 |
+
|
1310 |
+
If you have any questions, please send an email to {email_support_contact} <br /> <br />
|
1311 |
+
|
1312 |
+
Best, <br />
|
1313 |
+
The LiteLLM team <br />
|
1314 |
+
"""
|
1315 |
+
|
1316 |
+
webhook_event.model_dump_json()
|
1317 |
+
email_event = {
|
1318 |
+
"to": recipient_email,
|
1319 |
+
"subject": f"LiteLLM: {event_name}",
|
1320 |
+
"html": email_html_content,
|
1321 |
+
}
|
1322 |
+
|
1323 |
+
await send_email(
|
1324 |
+
receiver_email=email_event["to"],
|
1325 |
+
subject=email_event["subject"],
|
1326 |
+
html=email_event["html"],
|
1327 |
+
)
|
1328 |
+
if webhook_event.event_group == "team":
|
1329 |
+
from litellm.integrations.email_alerting import send_team_budget_alert
|
1330 |
+
|
1331 |
+
await send_team_budget_alert(webhook_event=webhook_event)
|
1332 |
+
|
1333 |
+
return False
|
1334 |
+
|
1335 |
+
async def send_alert(
|
1336 |
+
self,
|
1337 |
+
message: str,
|
1338 |
+
level: Literal["Low", "Medium", "High"],
|
1339 |
+
alert_type: AlertType,
|
1340 |
+
alerting_metadata: dict,
|
1341 |
+
user_info: Optional[WebhookEvent] = None,
|
1342 |
+
**kwargs,
|
1343 |
+
):
|
1344 |
+
"""
|
1345 |
+
Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
|
1346 |
+
|
1347 |
+
- Responses taking too long
|
1348 |
+
- Requests are hanging
|
1349 |
+
- Calls are failing
|
1350 |
+
- DB Read/Writes are failing
|
1351 |
+
- Proxy Close to max budget
|
1352 |
+
- Key Close to max budget
|
1353 |
+
|
1354 |
+
Parameters:
|
1355 |
+
level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
|
1356 |
+
message: str - what is the alert about
|
1357 |
+
"""
|
1358 |
+
if self.alerting is None:
|
1359 |
+
return
|
1360 |
+
|
1361 |
+
if (
|
1362 |
+
"webhook" in self.alerting
|
1363 |
+
and alert_type == "budget_alerts"
|
1364 |
+
and user_info is not None
|
1365 |
+
):
|
1366 |
+
await self.send_webhook_alert(webhook_event=user_info)
|
1367 |
+
|
1368 |
+
if (
|
1369 |
+
"email" in self.alerting
|
1370 |
+
and alert_type == "budget_alerts"
|
1371 |
+
and user_info is not None
|
1372 |
+
):
|
1373 |
+
# only send budget alerts over Email
|
1374 |
+
await self.send_email_alert_using_smtp(
|
1375 |
+
webhook_event=user_info, alert_type=alert_type
|
1376 |
+
)
|
1377 |
+
|
1378 |
+
if "slack" not in self.alerting:
|
1379 |
+
return
|
1380 |
+
if alert_type not in self.alert_types:
|
1381 |
+
return
|
1382 |
+
|
1383 |
+
from datetime import datetime
|
1384 |
+
|
1385 |
+
# Get the current timestamp
|
1386 |
+
current_time = datetime.now().strftime("%H:%M:%S")
|
1387 |
+
_proxy_base_url = os.getenv("PROXY_BASE_URL", None)
|
1388 |
+
if alert_type == "daily_reports" or alert_type == "new_model_added":
|
1389 |
+
formatted_message = message
|
1390 |
+
else:
|
1391 |
+
formatted_message = (
|
1392 |
+
f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
|
1393 |
+
)
|
1394 |
+
|
1395 |
+
if kwargs:
|
1396 |
+
for key, value in kwargs.items():
|
1397 |
+
formatted_message += f"\n\n{key}: `{value}`\n\n"
|
1398 |
+
if alerting_metadata:
|
1399 |
+
for key, value in alerting_metadata.items():
|
1400 |
+
formatted_message += f"\n\n*Alerting Metadata*: \n{key}: `{value}`\n\n"
|
1401 |
+
if _proxy_base_url is not None:
|
1402 |
+
formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
|
1403 |
+
|
1404 |
+
# check if we find the slack webhook url in self.alert_to_webhook_url
|
1405 |
+
if (
|
1406 |
+
self.alert_to_webhook_url is not None
|
1407 |
+
and alert_type in self.alert_to_webhook_url
|
1408 |
+
):
|
1409 |
+
slack_webhook_url: Optional[
|
1410 |
+
Union[str, List[str]]
|
1411 |
+
] = self.alert_to_webhook_url[alert_type]
|
1412 |
+
elif self.default_webhook_url is not None:
|
1413 |
+
slack_webhook_url = self.default_webhook_url
|
1414 |
+
else:
|
1415 |
+
slack_webhook_url = os.getenv("SLACK_WEBHOOK_URL", None)
|
1416 |
+
|
1417 |
+
if slack_webhook_url is None:
|
1418 |
+
raise ValueError("Missing SLACK_WEBHOOK_URL from environment")
|
1419 |
+
payload = {"text": formatted_message}
|
1420 |
+
headers = {"Content-type": "application/json"}
|
1421 |
+
|
1422 |
+
if isinstance(slack_webhook_url, list):
|
1423 |
+
for url in slack_webhook_url:
|
1424 |
+
self.log_queue.append(
|
1425 |
+
{
|
1426 |
+
"url": url,
|
1427 |
+
"headers": headers,
|
1428 |
+
"payload": payload,
|
1429 |
+
"alert_type": alert_type,
|
1430 |
+
}
|
1431 |
+
)
|
1432 |
+
else:
|
1433 |
+
self.log_queue.append(
|
1434 |
+
{
|
1435 |
+
"url": slack_webhook_url,
|
1436 |
+
"headers": headers,
|
1437 |
+
"payload": payload,
|
1438 |
+
"alert_type": alert_type,
|
1439 |
+
}
|
1440 |
+
)
|
1441 |
+
|
1442 |
+
if len(self.log_queue) >= self.batch_size:
|
1443 |
+
await self.flush_queue()
|
1444 |
+
|
1445 |
+
async def async_send_batch(self):
|
1446 |
+
if not self.log_queue:
|
1447 |
+
return
|
1448 |
+
|
1449 |
+
squashed_queue = squash_payloads(self.log_queue)
|
1450 |
+
tasks = [
|
1451 |
+
send_to_webhook(
|
1452 |
+
slackAlertingInstance=self, item=item["item"], count=item["count"]
|
1453 |
+
)
|
1454 |
+
for item in squashed_queue.values()
|
1455 |
+
]
|
1456 |
+
await asyncio.gather(*tasks)
|
1457 |
+
self.log_queue.clear()
|
1458 |
+
|
1459 |
+
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
1460 |
+
"""Log deployment latency"""
|
1461 |
+
try:
|
1462 |
+
if "daily_reports" in self.alert_types:
|
1463 |
+
litellm_params = kwargs.get("litellm_params", {}) or {}
|
1464 |
+
model_info = litellm_params.get("model_info", {}) or {}
|
1465 |
+
model_id = model_info.get("id", "") or ""
|
1466 |
+
response_s: timedelta = end_time - start_time
|
1467 |
+
|
1468 |
+
final_value = response_s
|
1469 |
+
|
1470 |
+
if isinstance(response_obj, litellm.ModelResponse) and (
|
1471 |
+
hasattr(response_obj, "usage")
|
1472 |
+
and response_obj.usage is not None # type: ignore
|
1473 |
+
and hasattr(response_obj.usage, "completion_tokens") # type: ignore
|
1474 |
+
):
|
1475 |
+
completion_tokens = response_obj.usage.completion_tokens # type: ignore
|
1476 |
+
if completion_tokens is not None and completion_tokens > 0:
|
1477 |
+
final_value = float(
|
1478 |
+
response_s.total_seconds() / completion_tokens
|
1479 |
+
)
|
1480 |
+
if isinstance(final_value, timedelta):
|
1481 |
+
final_value = final_value.total_seconds()
|
1482 |
+
|
1483 |
+
await self.async_update_daily_reports(
|
1484 |
+
DeploymentMetrics(
|
1485 |
+
id=model_id,
|
1486 |
+
failed_request=False,
|
1487 |
+
latency_per_output_token=final_value,
|
1488 |
+
updated_at=litellm.utils.get_utc_datetime(),
|
1489 |
+
)
|
1490 |
+
)
|
1491 |
+
except Exception as e:
|
1492 |
+
verbose_proxy_logger.error(
|
1493 |
+
f"[Non-Blocking Error] Slack Alerting: Got error in logging LLM deployment latency: {str(e)}"
|
1494 |
+
)
|
1495 |
+
pass
|
1496 |
+
|
1497 |
+
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
1498 |
+
"""Log failure + deployment latency"""
|
1499 |
+
_litellm_params = kwargs.get("litellm_params", {})
|
1500 |
+
_model_info = _litellm_params.get("model_info", {}) or {}
|
1501 |
+
model_id = _model_info.get("id", "")
|
1502 |
+
try:
|
1503 |
+
if "daily_reports" in self.alert_types:
|
1504 |
+
try:
|
1505 |
+
await self.async_update_daily_reports(
|
1506 |
+
DeploymentMetrics(
|
1507 |
+
id=model_id,
|
1508 |
+
failed_request=True,
|
1509 |
+
latency_per_output_token=None,
|
1510 |
+
updated_at=litellm.utils.get_utc_datetime(),
|
1511 |
+
)
|
1512 |
+
)
|
1513 |
+
except Exception as e:
|
1514 |
+
verbose_logger.debug(f"Exception raises -{str(e)}")
|
1515 |
+
|
1516 |
+
if isinstance(kwargs.get("exception", ""), APIError):
|
1517 |
+
if "outage_alerts" in self.alert_types:
|
1518 |
+
await self.outage_alerts(
|
1519 |
+
exception=kwargs["exception"],
|
1520 |
+
deployment_id=model_id,
|
1521 |
+
)
|
1522 |
+
|
1523 |
+
if "region_outage_alerts" in self.alert_types:
|
1524 |
+
await self.region_outage_alerts(
|
1525 |
+
exception=kwargs["exception"], deployment_id=model_id
|
1526 |
+
)
|
1527 |
+
except Exception:
|
1528 |
+
pass
|
1529 |
+
|
1530 |
+
async def _run_scheduler_helper(self, llm_router) -> bool:
|
1531 |
+
"""
|
1532 |
+
Returns:
|
1533 |
+
- True -> report sent
|
1534 |
+
- False -> report not sent
|
1535 |
+
"""
|
1536 |
+
report_sent_bool = False
|
1537 |
+
|
1538 |
+
report_sent = await self.internal_usage_cache.async_get_cache(
|
1539 |
+
key=SlackAlertingCacheKeys.report_sent_key.value,
|
1540 |
+
parent_otel_span=None,
|
1541 |
+
) # None | float
|
1542 |
+
|
1543 |
+
current_time = time.time()
|
1544 |
+
|
1545 |
+
if report_sent is None:
|
1546 |
+
await self.internal_usage_cache.async_set_cache(
|
1547 |
+
key=SlackAlertingCacheKeys.report_sent_key.value,
|
1548 |
+
value=current_time,
|
1549 |
+
)
|
1550 |
+
elif isinstance(report_sent, float):
|
1551 |
+
# Check if current time - interval >= time last sent
|
1552 |
+
interval_seconds = self.alerting_args.daily_report_frequency
|
1553 |
+
|
1554 |
+
if current_time - report_sent >= interval_seconds:
|
1555 |
+
# Sneak in the reporting logic here
|
1556 |
+
await self.send_daily_reports(router=llm_router)
|
1557 |
+
# Also, don't forget to update the report_sent time after sending the report!
|
1558 |
+
await self.internal_usage_cache.async_set_cache(
|
1559 |
+
key=SlackAlertingCacheKeys.report_sent_key.value,
|
1560 |
+
value=current_time,
|
1561 |
+
)
|
1562 |
+
report_sent_bool = True
|
1563 |
+
|
1564 |
+
return report_sent_bool
|
1565 |
+
|
1566 |
+
async def _run_scheduled_daily_report(self, llm_router: Optional[Any] = None):
|
1567 |
+
"""
|
1568 |
+
If 'daily_reports' enabled
|
1569 |
+
|
1570 |
+
Ping redis cache every 5 minutes to check if we should send the report
|
1571 |
+
|
1572 |
+
If yes -> call send_daily_report()
|
1573 |
+
"""
|
1574 |
+
if llm_router is None or self.alert_types is None:
|
1575 |
+
return
|
1576 |
+
|
1577 |
+
if "daily_reports" in self.alert_types:
|
1578 |
+
while True:
|
1579 |
+
await self._run_scheduler_helper(llm_router=llm_router)
|
1580 |
+
interval = random.randint(
|
1581 |
+
self.alerting_args.report_check_interval - 3,
|
1582 |
+
self.alerting_args.report_check_interval + 3,
|
1583 |
+
) # shuffle to prevent collisions
|
1584 |
+
await asyncio.sleep(interval)
|
1585 |
+
return
|
1586 |
+
|
1587 |
+
async def send_weekly_spend_report(
|
1588 |
+
self,
|
1589 |
+
time_range: str = "7d",
|
1590 |
+
):
|
1591 |
+
"""
|
1592 |
+
Send a spend report for a configurable time range.
|
1593 |
+
|
1594 |
+
Args:
|
1595 |
+
time_range: A string specifying the time range for the report, e.g., "1d", "7d", "30d"
|
1596 |
+
"""
|
1597 |
+
if self.alerting is None or "spend_reports" not in self.alert_types:
|
1598 |
+
return
|
1599 |
+
|
1600 |
+
try:
|
1601 |
+
from litellm.proxy.spend_tracking.spend_management_endpoints import (
|
1602 |
+
_get_spend_report_for_time_range,
|
1603 |
+
)
|
1604 |
+
|
1605 |
+
# Parse the time range
|
1606 |
+
days = int(time_range[:-1])
|
1607 |
+
if time_range[-1].lower() != "d":
|
1608 |
+
raise ValueError("Time range must be specified in days, e.g., '7d'")
|
1609 |
+
|
1610 |
+
todays_date = datetime.datetime.now().date()
|
1611 |
+
start_date = todays_date - datetime.timedelta(days=days)
|
1612 |
+
|
1613 |
+
_event_cache_key = f"weekly_spend_report_sent_{start_date.strftime('%Y-%m-%d')}_{todays_date.strftime('%Y-%m-%d')}"
|
1614 |
+
if await self.internal_usage_cache.async_get_cache(key=_event_cache_key):
|
1615 |
+
return
|
1616 |
+
|
1617 |
+
_resp = await _get_spend_report_for_time_range(
|
1618 |
+
start_date=start_date.strftime("%Y-%m-%d"),
|
1619 |
+
end_date=todays_date.strftime("%Y-%m-%d"),
|
1620 |
+
)
|
1621 |
+
if _resp is None or _resp == ([], []):
|
1622 |
+
return
|
1623 |
+
|
1624 |
+
spend_per_team, spend_per_tag = _resp
|
1625 |
+
|
1626 |
+
_spend_message = f"*💸 Spend Report for `{start_date.strftime('%m-%d-%Y')} - {todays_date.strftime('%m-%d-%Y')}` ({days} days)*\n"
|
1627 |
+
|
1628 |
+
if spend_per_team is not None:
|
1629 |
+
_spend_message += "\n*Team Spend Report:*\n"
|
1630 |
+
for spend in spend_per_team:
|
1631 |
+
_team_spend = round(float(spend["total_spend"]), 4)
|
1632 |
+
_spend_message += (
|
1633 |
+
f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
|
1634 |
+
)
|
1635 |
+
|
1636 |
+
if spend_per_tag is not None:
|
1637 |
+
_spend_message += "\n*Tag Spend Report:*\n"
|
1638 |
+
for spend in spend_per_tag:
|
1639 |
+
_tag_spend = round(float(spend["total_spend"]), 4)
|
1640 |
+
_spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
|
1641 |
+
|
1642 |
+
await self.send_alert(
|
1643 |
+
message=_spend_message,
|
1644 |
+
level="Low",
|
1645 |
+
alert_type=AlertType.spend_reports,
|
1646 |
+
alerting_metadata={},
|
1647 |
+
)
|
1648 |
+
|
1649 |
+
await self.internal_usage_cache.async_set_cache(
|
1650 |
+
key=_event_cache_key,
|
1651 |
+
value="SENT",
|
1652 |
+
ttl=duration_in_seconds(time_range),
|
1653 |
+
)
|
1654 |
+
|
1655 |
+
except ValueError as ve:
|
1656 |
+
verbose_proxy_logger.error(f"Invalid time range format: {ve}")
|
1657 |
+
except Exception as e:
|
1658 |
+
verbose_proxy_logger.error(f"Error sending spend report: {e}")
|
1659 |
+
|
1660 |
+
async def send_monthly_spend_report(self):
|
1661 |
+
""" """
|
1662 |
+
try:
|
1663 |
+
from calendar import monthrange
|
1664 |
+
|
1665 |
+
from litellm.proxy.spend_tracking.spend_management_endpoints import (
|
1666 |
+
_get_spend_report_for_time_range,
|
1667 |
+
)
|
1668 |
+
|
1669 |
+
todays_date = datetime.datetime.now().date()
|
1670 |
+
first_day_of_month = todays_date.replace(day=1)
|
1671 |
+
_, last_day_of_month = monthrange(todays_date.year, todays_date.month)
|
1672 |
+
last_day_of_month = first_day_of_month + datetime.timedelta(
|
1673 |
+
days=last_day_of_month - 1
|
1674 |
+
)
|
1675 |
+
|
1676 |
+
_event_cache_key = f"monthly_spend_report_sent_{first_day_of_month.strftime('%Y-%m-%d')}_{last_day_of_month.strftime('%Y-%m-%d')}"
|
1677 |
+
if await self.internal_usage_cache.async_get_cache(key=_event_cache_key):
|
1678 |
+
return
|
1679 |
+
|
1680 |
+
_resp = await _get_spend_report_for_time_range(
|
1681 |
+
start_date=first_day_of_month.strftime("%Y-%m-%d"),
|
1682 |
+
end_date=last_day_of_month.strftime("%Y-%m-%d"),
|
1683 |
+
)
|
1684 |
+
|
1685 |
+
if _resp is None or _resp == ([], []):
|
1686 |
+
return
|
1687 |
+
|
1688 |
+
monthly_spend_per_team, monthly_spend_per_tag = _resp
|
1689 |
+
|
1690 |
+
_spend_message = f"*💸 Monthly Spend Report for `{first_day_of_month.strftime('%m-%d-%Y')} - {last_day_of_month.strftime('%m-%d-%Y')}` *\n"
|
1691 |
+
|
1692 |
+
if monthly_spend_per_team is not None:
|
1693 |
+
_spend_message += "\n*Team Spend Report:*\n"
|
1694 |
+
for spend in monthly_spend_per_team:
|
1695 |
+
_team_spend = spend["total_spend"]
|
1696 |
+
_team_spend = float(_team_spend)
|
1697 |
+
# round to 4 decimal places
|
1698 |
+
_team_spend = round(_team_spend, 4)
|
1699 |
+
_spend_message += (
|
1700 |
+
f"Team: `{spend['team_alias']}` | Spend: `${_team_spend}`\n"
|
1701 |
+
)
|
1702 |
+
|
1703 |
+
if monthly_spend_per_tag is not None:
|
1704 |
+
_spend_message += "\n*Tag Spend Report:*\n"
|
1705 |
+
for spend in monthly_spend_per_tag:
|
1706 |
+
_tag_spend = spend["total_spend"]
|
1707 |
+
_tag_spend = float(_tag_spend)
|
1708 |
+
# round to 4 decimal places
|
1709 |
+
_tag_spend = round(_tag_spend, 4)
|
1710 |
+
_spend_message += f"Tag: `{spend['individual_request_tag']}` | Spend: `${_tag_spend}`\n"
|
1711 |
+
|
1712 |
+
await self.send_alert(
|
1713 |
+
message=_spend_message,
|
1714 |
+
level="Low",
|
1715 |
+
alert_type=AlertType.spend_reports,
|
1716 |
+
alerting_metadata={},
|
1717 |
+
)
|
1718 |
+
|
1719 |
+
await self.internal_usage_cache.async_set_cache(
|
1720 |
+
key=_event_cache_key,
|
1721 |
+
value="SENT",
|
1722 |
+
ttl=(30 * HOURS_IN_A_DAY * 60 * 60), # 1 month
|
1723 |
+
)
|
1724 |
+
|
1725 |
+
except Exception as e:
|
1726 |
+
verbose_proxy_logger.exception("Error sending weekly spend report %s", e)
|
1727 |
+
|
1728 |
+
async def send_fallback_stats_from_prometheus(self):
|
1729 |
+
"""
|
1730 |
+
Helper to send fallback statistics from prometheus server -> to slack
|
1731 |
+
|
1732 |
+
This runs once per day and sends an overview of all the fallback statistics
|
1733 |
+
"""
|
1734 |
+
try:
|
1735 |
+
from litellm.integrations.prometheus_helpers.prometheus_api import (
|
1736 |
+
get_fallback_metric_from_prometheus,
|
1737 |
+
)
|
1738 |
+
|
1739 |
+
# call prometheuslogger.
|
1740 |
+
falllback_success_info_prometheus = (
|
1741 |
+
await get_fallback_metric_from_prometheus()
|
1742 |
+
)
|
1743 |
+
|
1744 |
+
fallback_message = (
|
1745 |
+
f"*Fallback Statistics:*\n{falllback_success_info_prometheus}"
|
1746 |
+
)
|
1747 |
+
|
1748 |
+
await self.send_alert(
|
1749 |
+
message=fallback_message,
|
1750 |
+
level="Low",
|
1751 |
+
alert_type=AlertType.fallback_reports,
|
1752 |
+
alerting_metadata={},
|
1753 |
+
)
|
1754 |
+
|
1755 |
+
except Exception as e:
|
1756 |
+
verbose_proxy_logger.error("Error sending weekly spend report %s", e)
|
1757 |
+
|
1758 |
+
pass
|
1759 |
+
|
1760 |
+
async def send_virtual_key_event_slack(
|
1761 |
+
self,
|
1762 |
+
key_event: VirtualKeyEvent,
|
1763 |
+
alert_type: AlertType,
|
1764 |
+
event_name: str,
|
1765 |
+
):
|
1766 |
+
"""
|
1767 |
+
Handles sending Virtual Key related alerts
|
1768 |
+
|
1769 |
+
Example:
|
1770 |
+
- New Virtual Key Created
|
1771 |
+
- Internal User Updated
|
1772 |
+
- Team Created, Updated, Deleted
|
1773 |
+
"""
|
1774 |
+
try:
|
1775 |
+
message = f"`{event_name}`\n"
|
1776 |
+
|
1777 |
+
key_event_dict = key_event.model_dump()
|
1778 |
+
|
1779 |
+
# Add Created by information first
|
1780 |
+
message += "*Action Done by:*\n"
|
1781 |
+
for key, value in key_event_dict.items():
|
1782 |
+
if "created_by" in key:
|
1783 |
+
message += f"{key}: `{value}`\n"
|
1784 |
+
|
1785 |
+
# Add args sent to function in the alert
|
1786 |
+
message += "\n*Arguments passed:*\n"
|
1787 |
+
request_kwargs = key_event.request_kwargs
|
1788 |
+
for key, value in request_kwargs.items():
|
1789 |
+
if key == "user_api_key_dict":
|
1790 |
+
continue
|
1791 |
+
message += f"{key}: `{value}`\n"
|
1792 |
+
|
1793 |
+
await self.send_alert(
|
1794 |
+
message=message,
|
1795 |
+
level="High",
|
1796 |
+
alert_type=alert_type,
|
1797 |
+
alerting_metadata={},
|
1798 |
+
)
|
1799 |
+
|
1800 |
+
except Exception as e:
|
1801 |
+
verbose_proxy_logger.error(
|
1802 |
+
"Error sending send_virtual_key_event_slack %s", e
|
1803 |
+
)
|
1804 |
+
|
1805 |
+
return
|
1806 |
+
|
1807 |
+
async def _request_is_completed(self, request_data: Optional[dict]) -> bool:
|
1808 |
+
"""
|
1809 |
+
Returns True if the request is completed - either as a success or failure
|
1810 |
+
"""
|
1811 |
+
if request_data is None:
|
1812 |
+
return False
|
1813 |
+
|
1814 |
+
if (
|
1815 |
+
request_data.get("litellm_status", "") != "success"
|
1816 |
+
and request_data.get("litellm_status", "") != "fail"
|
1817 |
+
):
|
1818 |
+
## CHECK IF CACHE IS UPDATED
|
1819 |
+
litellm_call_id = request_data.get("litellm_call_id", "")
|
1820 |
+
status: Optional[str] = await self.internal_usage_cache.async_get_cache(
|
1821 |
+
key="request_status:{}".format(litellm_call_id), local_only=True
|
1822 |
+
)
|
1823 |
+
if status is not None and (status == "success" or status == "fail"):
|
1824 |
+
return True
|
1825 |
+
return False
|
litellm/integrations/SlackAlerting/utils.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utils used for slack alerting
|
3 |
+
"""
|
4 |
+
|
5 |
+
import asyncio
|
6 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
7 |
+
|
8 |
+
from litellm.proxy._types import AlertType
|
9 |
+
from litellm.secret_managers.main import get_secret
|
10 |
+
|
11 |
+
if TYPE_CHECKING:
|
12 |
+
from litellm.litellm_core_utils.litellm_logging import Logging as _Logging
|
13 |
+
|
14 |
+
Logging = _Logging
|
15 |
+
else:
|
16 |
+
Logging = Any
|
17 |
+
|
18 |
+
|
19 |
+
def process_slack_alerting_variables(
|
20 |
+
alert_to_webhook_url: Optional[Dict[AlertType, Union[List[str], str]]]
|
21 |
+
) -> Optional[Dict[AlertType, Union[List[str], str]]]:
|
22 |
+
"""
|
23 |
+
process alert_to_webhook_url
|
24 |
+
- check if any urls are set as os.environ/SLACK_WEBHOOK_URL_1 read env var and set the correct value
|
25 |
+
"""
|
26 |
+
if alert_to_webhook_url is None:
|
27 |
+
return None
|
28 |
+
|
29 |
+
for alert_type, webhook_urls in alert_to_webhook_url.items():
|
30 |
+
if isinstance(webhook_urls, list):
|
31 |
+
_webhook_values: List[str] = []
|
32 |
+
for webhook_url in webhook_urls:
|
33 |
+
if "os.environ/" in webhook_url:
|
34 |
+
_env_value = get_secret(secret_name=webhook_url)
|
35 |
+
if not isinstance(_env_value, str):
|
36 |
+
raise ValueError(
|
37 |
+
f"Invalid webhook url value for: {webhook_url}. Got type={type(_env_value)}"
|
38 |
+
)
|
39 |
+
_webhook_values.append(_env_value)
|
40 |
+
else:
|
41 |
+
_webhook_values.append(webhook_url)
|
42 |
+
|
43 |
+
alert_to_webhook_url[alert_type] = _webhook_values
|
44 |
+
else:
|
45 |
+
_webhook_value_str: str = webhook_urls
|
46 |
+
if "os.environ/" in webhook_urls:
|
47 |
+
_env_value = get_secret(secret_name=webhook_urls)
|
48 |
+
if not isinstance(_env_value, str):
|
49 |
+
raise ValueError(
|
50 |
+
f"Invalid webhook url value for: {webhook_urls}. Got type={type(_env_value)}"
|
51 |
+
)
|
52 |
+
_webhook_value_str = _env_value
|
53 |
+
else:
|
54 |
+
_webhook_value_str = webhook_urls
|
55 |
+
|
56 |
+
alert_to_webhook_url[alert_type] = _webhook_value_str
|
57 |
+
|
58 |
+
return alert_to_webhook_url
|
59 |
+
|
60 |
+
|
61 |
+
async def _add_langfuse_trace_id_to_alert(
|
62 |
+
request_data: Optional[dict] = None,
|
63 |
+
) -> Optional[str]:
|
64 |
+
"""
|
65 |
+
Returns langfuse trace url
|
66 |
+
|
67 |
+
- check:
|
68 |
+
-> existing_trace_id
|
69 |
+
-> trace_id
|
70 |
+
-> litellm_call_id
|
71 |
+
"""
|
72 |
+
# do nothing for now
|
73 |
+
if (
|
74 |
+
request_data is not None
|
75 |
+
and request_data.get("litellm_logging_obj", None) is not None
|
76 |
+
):
|
77 |
+
trace_id: Optional[str] = None
|
78 |
+
litellm_logging_obj: Logging = request_data["litellm_logging_obj"]
|
79 |
+
|
80 |
+
for _ in range(3):
|
81 |
+
trace_id = litellm_logging_obj._get_trace_id(service_name="langfuse")
|
82 |
+
if trace_id is not None:
|
83 |
+
break
|
84 |
+
await asyncio.sleep(3) # wait 3s before retrying for trace id
|
85 |
+
|
86 |
+
_langfuse_object = litellm_logging_obj._get_callback_object(
|
87 |
+
service_name="langfuse"
|
88 |
+
)
|
89 |
+
if _langfuse_object is not None:
|
90 |
+
base_url = _langfuse_object.Langfuse.base_url
|
91 |
+
return f"{base_url}/trace/{trace_id}"
|
92 |
+
return None
|
litellm/integrations/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import *
|
litellm/integrations/_types/open_inference.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
|
4 |
+
class SpanAttributes:
|
5 |
+
OUTPUT_VALUE = "output.value"
|
6 |
+
OUTPUT_MIME_TYPE = "output.mime_type"
|
7 |
+
"""
|
8 |
+
The type of output.value. If unspecified, the type is plain text by default.
|
9 |
+
If type is JSON, the value is a string representing a JSON object.
|
10 |
+
"""
|
11 |
+
INPUT_VALUE = "input.value"
|
12 |
+
INPUT_MIME_TYPE = "input.mime_type"
|
13 |
+
"""
|
14 |
+
The type of input.value. If unspecified, the type is plain text by default.
|
15 |
+
If type is JSON, the value is a string representing a JSON object.
|
16 |
+
"""
|
17 |
+
|
18 |
+
EMBEDDING_EMBEDDINGS = "embedding.embeddings"
|
19 |
+
"""
|
20 |
+
A list of objects containing embedding data, including the vector and represented piece of text.
|
21 |
+
"""
|
22 |
+
EMBEDDING_MODEL_NAME = "embedding.model_name"
|
23 |
+
"""
|
24 |
+
The name of the embedding model.
|
25 |
+
"""
|
26 |
+
|
27 |
+
LLM_FUNCTION_CALL = "llm.function_call"
|
28 |
+
"""
|
29 |
+
For models and APIs that support function calling. Records attributes such as the function
|
30 |
+
name and arguments to the called function.
|
31 |
+
"""
|
32 |
+
LLM_INVOCATION_PARAMETERS = "llm.invocation_parameters"
|
33 |
+
"""
|
34 |
+
Invocation parameters passed to the LLM or API, such as the model name, temperature, etc.
|
35 |
+
"""
|
36 |
+
LLM_INPUT_MESSAGES = "llm.input_messages"
|
37 |
+
"""
|
38 |
+
Messages provided to a chat API.
|
39 |
+
"""
|
40 |
+
LLM_OUTPUT_MESSAGES = "llm.output_messages"
|
41 |
+
"""
|
42 |
+
Messages received from a chat API.
|
43 |
+
"""
|
44 |
+
LLM_MODEL_NAME = "llm.model_name"
|
45 |
+
"""
|
46 |
+
The name of the model being used.
|
47 |
+
"""
|
48 |
+
LLM_PROVIDER = "llm.provider"
|
49 |
+
"""
|
50 |
+
The provider of the model, such as OpenAI, Azure, Google, etc.
|
51 |
+
"""
|
52 |
+
LLM_SYSTEM = "llm.system"
|
53 |
+
"""
|
54 |
+
The AI product as identified by the client or server
|
55 |
+
"""
|
56 |
+
LLM_PROMPTS = "llm.prompts"
|
57 |
+
"""
|
58 |
+
Prompts provided to a completions API.
|
59 |
+
"""
|
60 |
+
LLM_PROMPT_TEMPLATE = "llm.prompt_template.template"
|
61 |
+
"""
|
62 |
+
The prompt template as a Python f-string.
|
63 |
+
"""
|
64 |
+
LLM_PROMPT_TEMPLATE_VARIABLES = "llm.prompt_template.variables"
|
65 |
+
"""
|
66 |
+
A list of input variables to the prompt template.
|
67 |
+
"""
|
68 |
+
LLM_PROMPT_TEMPLATE_VERSION = "llm.prompt_template.version"
|
69 |
+
"""
|
70 |
+
The version of the prompt template being used.
|
71 |
+
"""
|
72 |
+
LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt"
|
73 |
+
"""
|
74 |
+
Number of tokens in the prompt.
|
75 |
+
"""
|
76 |
+
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_WRITE = "llm.token_count.prompt_details.cache_write"
|
77 |
+
"""
|
78 |
+
Number of tokens in the prompt that were written to cache.
|
79 |
+
"""
|
80 |
+
LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ = "llm.token_count.prompt_details.cache_read"
|
81 |
+
"""
|
82 |
+
Number of tokens in the prompt that were read from cache.
|
83 |
+
"""
|
84 |
+
LLM_TOKEN_COUNT_PROMPT_DETAILS_AUDIO = "llm.token_count.prompt_details.audio"
|
85 |
+
"""
|
86 |
+
The number of audio input tokens presented in the prompt
|
87 |
+
"""
|
88 |
+
LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"
|
89 |
+
"""
|
90 |
+
Number of tokens in the completion.
|
91 |
+
"""
|
92 |
+
LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING = "llm.token_count.completion_details.reasoning"
|
93 |
+
"""
|
94 |
+
Number of tokens used for reasoning steps in the completion.
|
95 |
+
"""
|
96 |
+
LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO = "llm.token_count.completion_details.audio"
|
97 |
+
"""
|
98 |
+
The number of audio input tokens generated by the model
|
99 |
+
"""
|
100 |
+
LLM_TOKEN_COUNT_TOTAL = "llm.token_count.total"
|
101 |
+
"""
|
102 |
+
Total number of tokens, including both prompt and completion.
|
103 |
+
"""
|
104 |
+
|
105 |
+
LLM_TOOLS = "llm.tools"
|
106 |
+
"""
|
107 |
+
List of tools that are advertised to the LLM to be able to call
|
108 |
+
"""
|
109 |
+
|
110 |
+
TOOL_NAME = "tool.name"
|
111 |
+
"""
|
112 |
+
Name of the tool being used.
|
113 |
+
"""
|
114 |
+
TOOL_DESCRIPTION = "tool.description"
|
115 |
+
"""
|
116 |
+
Description of the tool's purpose, typically used to select the tool.
|
117 |
+
"""
|
118 |
+
TOOL_PARAMETERS = "tool.parameters"
|
119 |
+
"""
|
120 |
+
Parameters of the tool represented a dictionary JSON string, e.g.
|
121 |
+
see https://platform.openai.com/docs/guides/gpt/function-calling
|
122 |
+
"""
|
123 |
+
|
124 |
+
RETRIEVAL_DOCUMENTS = "retrieval.documents"
|
125 |
+
|
126 |
+
METADATA = "metadata"
|
127 |
+
"""
|
128 |
+
Metadata attributes are used to store user-defined key-value pairs.
|
129 |
+
For example, LangChain uses metadata to store user-defined attributes for a chain.
|
130 |
+
"""
|
131 |
+
|
132 |
+
TAG_TAGS = "tag.tags"
|
133 |
+
"""
|
134 |
+
Custom categorical tags for the span.
|
135 |
+
"""
|
136 |
+
|
137 |
+
OPENINFERENCE_SPAN_KIND = "openinference.span.kind"
|
138 |
+
|
139 |
+
SESSION_ID = "session.id"
|
140 |
+
"""
|
141 |
+
The id of the session
|
142 |
+
"""
|
143 |
+
USER_ID = "user.id"
|
144 |
+
"""
|
145 |
+
The id of the user
|
146 |
+
"""
|
147 |
+
|
148 |
+
PROMPT_VENDOR = "prompt.vendor"
|
149 |
+
"""
|
150 |
+
The vendor or origin of the prompt, e.g. a prompt library, a specialized service, etc.
|
151 |
+
"""
|
152 |
+
PROMPT_ID = "prompt.id"
|
153 |
+
"""
|
154 |
+
A vendor-specific id used to locate the prompt.
|
155 |
+
"""
|
156 |
+
PROMPT_URL = "prompt.url"
|
157 |
+
"""
|
158 |
+
A vendor-specific url used to locate the prompt.
|
159 |
+
"""
|
160 |
+
|
161 |
+
|
162 |
+
class MessageAttributes:
|
163 |
+
"""
|
164 |
+
Attributes for a message sent to or from an LLM
|
165 |
+
"""
|
166 |
+
|
167 |
+
MESSAGE_ROLE = "message.role"
|
168 |
+
"""
|
169 |
+
The role of the message, such as "user", "agent", "function".
|
170 |
+
"""
|
171 |
+
MESSAGE_CONTENT = "message.content"
|
172 |
+
"""
|
173 |
+
The content of the message to or from the llm, must be a string.
|
174 |
+
"""
|
175 |
+
MESSAGE_CONTENTS = "message.contents"
|
176 |
+
"""
|
177 |
+
The message contents to the llm, it is an array of
|
178 |
+
`message_content` prefixed attributes.
|
179 |
+
"""
|
180 |
+
MESSAGE_NAME = "message.name"
|
181 |
+
"""
|
182 |
+
The name of the message, often used to identify the function
|
183 |
+
that was used to generate the message.
|
184 |
+
"""
|
185 |
+
MESSAGE_TOOL_CALLS = "message.tool_calls"
|
186 |
+
"""
|
187 |
+
The tool calls generated by the model, such as function calls.
|
188 |
+
"""
|
189 |
+
MESSAGE_FUNCTION_CALL_NAME = "message.function_call_name"
|
190 |
+
"""
|
191 |
+
The function name that is a part of the message list.
|
192 |
+
This is populated for role 'function' or 'agent' as a mechanism to identify
|
193 |
+
the function that was called during the execution of a tool.
|
194 |
+
"""
|
195 |
+
MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = "message.function_call_arguments_json"
|
196 |
+
"""
|
197 |
+
The JSON string representing the arguments passed to the function
|
198 |
+
during a function call.
|
199 |
+
"""
|
200 |
+
MESSAGE_TOOL_CALL_ID = "message.tool_call_id"
|
201 |
+
"""
|
202 |
+
The id of the tool call.
|
203 |
+
"""
|
204 |
+
|
205 |
+
|
206 |
+
class MessageContentAttributes:
|
207 |
+
"""
|
208 |
+
Attributes for the contents of user messages sent to an LLM.
|
209 |
+
"""
|
210 |
+
|
211 |
+
MESSAGE_CONTENT_TYPE = "message_content.type"
|
212 |
+
"""
|
213 |
+
The type of the content, such as "text" or "image".
|
214 |
+
"""
|
215 |
+
MESSAGE_CONTENT_TEXT = "message_content.text"
|
216 |
+
"""
|
217 |
+
The text content of the message, if the type is "text".
|
218 |
+
"""
|
219 |
+
MESSAGE_CONTENT_IMAGE = "message_content.image"
|
220 |
+
"""
|
221 |
+
The image content of the message, if the type is "image".
|
222 |
+
An image can be made available to the model by passing a link to
|
223 |
+
the image or by passing the base64 encoded image directly in the
|
224 |
+
request.
|
225 |
+
"""
|
226 |
+
|
227 |
+
|
228 |
+
class ImageAttributes:
|
229 |
+
"""
|
230 |
+
Attributes for images
|
231 |
+
"""
|
232 |
+
|
233 |
+
IMAGE_URL = "image.url"
|
234 |
+
"""
|
235 |
+
An http or base64 image url
|
236 |
+
"""
|
237 |
+
|
238 |
+
|
239 |
+
class AudioAttributes:
|
240 |
+
"""
|
241 |
+
Attributes for audio
|
242 |
+
"""
|
243 |
+
|
244 |
+
AUDIO_URL = "audio.url"
|
245 |
+
"""
|
246 |
+
The url to an audio file
|
247 |
+
"""
|
248 |
+
AUDIO_MIME_TYPE = "audio.mime_type"
|
249 |
+
"""
|
250 |
+
The mime type of the audio file
|
251 |
+
"""
|
252 |
+
AUDIO_TRANSCRIPT = "audio.transcript"
|
253 |
+
"""
|
254 |
+
The transcript of the audio file
|
255 |
+
"""
|
256 |
+
|
257 |
+
|
258 |
+
class DocumentAttributes:
|
259 |
+
"""
|
260 |
+
Attributes for a document.
|
261 |
+
"""
|
262 |
+
|
263 |
+
DOCUMENT_ID = "document.id"
|
264 |
+
"""
|
265 |
+
The id of the document.
|
266 |
+
"""
|
267 |
+
DOCUMENT_SCORE = "document.score"
|
268 |
+
"""
|
269 |
+
The score of the document
|
270 |
+
"""
|
271 |
+
DOCUMENT_CONTENT = "document.content"
|
272 |
+
"""
|
273 |
+
The content of the document.
|
274 |
+
"""
|
275 |
+
DOCUMENT_METADATA = "document.metadata"
|
276 |
+
"""
|
277 |
+
The metadata of the document represented as a dictionary
|
278 |
+
JSON string, e.g. `"{ 'title': 'foo' }"`
|
279 |
+
"""
|
280 |
+
|
281 |
+
|
282 |
+
class RerankerAttributes:
|
283 |
+
"""
|
284 |
+
Attributes for a reranker
|
285 |
+
"""
|
286 |
+
|
287 |
+
RERANKER_INPUT_DOCUMENTS = "reranker.input_documents"
|
288 |
+
"""
|
289 |
+
List of documents as input to the reranker
|
290 |
+
"""
|
291 |
+
RERANKER_OUTPUT_DOCUMENTS = "reranker.output_documents"
|
292 |
+
"""
|
293 |
+
List of documents as output from the reranker
|
294 |
+
"""
|
295 |
+
RERANKER_QUERY = "reranker.query"
|
296 |
+
"""
|
297 |
+
Query string for the reranker
|
298 |
+
"""
|
299 |
+
RERANKER_MODEL_NAME = "reranker.model_name"
|
300 |
+
"""
|
301 |
+
Model name of the reranker
|
302 |
+
"""
|
303 |
+
RERANKER_TOP_K = "reranker.top_k"
|
304 |
+
"""
|
305 |
+
Top K parameter of the reranker
|
306 |
+
"""
|
307 |
+
|
308 |
+
|
309 |
+
class EmbeddingAttributes:
|
310 |
+
"""
|
311 |
+
Attributes for an embedding
|
312 |
+
"""
|
313 |
+
|
314 |
+
EMBEDDING_TEXT = "embedding.text"
|
315 |
+
"""
|
316 |
+
The text represented by the embedding.
|
317 |
+
"""
|
318 |
+
EMBEDDING_VECTOR = "embedding.vector"
|
319 |
+
"""
|
320 |
+
The embedding vector.
|
321 |
+
"""
|
322 |
+
|
323 |
+
|
324 |
+
class ToolCallAttributes:
|
325 |
+
"""
|
326 |
+
Attributes for a tool call
|
327 |
+
"""
|
328 |
+
|
329 |
+
TOOL_CALL_ID = "tool_call.id"
|
330 |
+
"""
|
331 |
+
The id of the tool call.
|
332 |
+
"""
|
333 |
+
TOOL_CALL_FUNCTION_NAME = "tool_call.function.name"
|
334 |
+
"""
|
335 |
+
The name of function that is being called during a tool call.
|
336 |
+
"""
|
337 |
+
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = "tool_call.function.arguments"
|
338 |
+
"""
|
339 |
+
The JSON string representing the arguments passed to the function
|
340 |
+
during a tool call.
|
341 |
+
"""
|
342 |
+
|
343 |
+
|
344 |
+
class ToolAttributes:
|
345 |
+
"""
|
346 |
+
Attributes for a tools
|
347 |
+
"""
|
348 |
+
|
349 |
+
TOOL_JSON_SCHEMA = "tool.json_schema"
|
350 |
+
"""
|
351 |
+
The json schema of a tool input, It is RECOMMENDED that this be in the
|
352 |
+
OpenAI tool calling format: https://platform.openai.com/docs/assistants/tools
|
353 |
+
"""
|
354 |
+
|
355 |
+
|
356 |
+
class OpenInferenceSpanKindValues(Enum):
|
357 |
+
TOOL = "TOOL"
|
358 |
+
CHAIN = "CHAIN"
|
359 |
+
LLM = "LLM"
|
360 |
+
RETRIEVER = "RETRIEVER"
|
361 |
+
EMBEDDING = "EMBEDDING"
|
362 |
+
AGENT = "AGENT"
|
363 |
+
RERANKER = "RERANKER"
|
364 |
+
UNKNOWN = "UNKNOWN"
|
365 |
+
GUARDRAIL = "GUARDRAIL"
|
366 |
+
EVALUATOR = "EVALUATOR"
|
367 |
+
|
368 |
+
|
369 |
+
class OpenInferenceMimeTypeValues(Enum):
|
370 |
+
TEXT = "text/plain"
|
371 |
+
JSON = "application/json"
|
372 |
+
|
373 |
+
|
374 |
+
class OpenInferenceLLMSystemValues(Enum):
|
375 |
+
OPENAI = "openai"
|
376 |
+
ANTHROPIC = "anthropic"
|
377 |
+
COHERE = "cohere"
|
378 |
+
MISTRALAI = "mistralai"
|
379 |
+
VERTEXAI = "vertexai"
|
380 |
+
|
381 |
+
|
382 |
+
class OpenInferenceLLMProviderValues(Enum):
|
383 |
+
OPENAI = "openai"
|
384 |
+
ANTHROPIC = "anthropic"
|
385 |
+
COHERE = "cohere"
|
386 |
+
MISTRALAI = "mistralai"
|
387 |
+
GOOGLE = "google"
|
388 |
+
AZURE = "azure"
|
389 |
+
AWS = "aws"
|
litellm/integrations/additional_logging_utils.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Base class for Additional Logging Utils for CustomLoggers
|
3 |
+
|
4 |
+
- Health Check for the logging util
|
5 |
+
- Get Request / Response Payload for the logging util
|
6 |
+
"""
|
7 |
+
|
8 |
+
from abc import ABC, abstractmethod
|
9 |
+
from datetime import datetime
|
10 |
+
from typing import Optional
|
11 |
+
|
12 |
+
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
|
13 |
+
|
14 |
+
|
15 |
+
class AdditionalLoggingUtils(ABC):
|
16 |
+
def __init__(self):
|
17 |
+
super().__init__()
|
18 |
+
|
19 |
+
@abstractmethod
|
20 |
+
async def async_health_check(self) -> IntegrationHealthCheckStatus:
|
21 |
+
"""
|
22 |
+
Check if the service is healthy
|
23 |
+
"""
|
24 |
+
pass
|
25 |
+
|
26 |
+
@abstractmethod
|
27 |
+
async def get_request_response_payload(
|
28 |
+
self,
|
29 |
+
request_id: str,
|
30 |
+
start_time_utc: Optional[datetime],
|
31 |
+
end_time_utc: Optional[datetime],
|
32 |
+
) -> Optional[dict]:
|
33 |
+
"""
|
34 |
+
Get the request and response payload for a given `request_id`
|
35 |
+
"""
|
36 |
+
return None
|
litellm/integrations/agentops/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .agentops import AgentOps
|
2 |
+
|
3 |
+
__all__ = ["AgentOps"]
|