import gradio as gr
from transformers import AutoConfig


def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
    hf_token = hf_token.strip()
    try:
        cfg = AutoConfig.from_pretrained(
            name,
            trust_remote_code=True,
            token=hf_token or None,
        )
    except Exception as e:
        raise gr.Error(str(e))
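
    # DeepSeek-V2/V3 use Multi-head Latent Attention (MLA), which caches a compressed
    # latent per token instead of full per-head keys and values.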
    use_mla = cfg.architectures[0].startswith(("DeepseekV2", "DeepseekV3"))
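
    # Multimodal configs nest the language-model settings under `text_config`.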
    if hasattr(cfg, "text_config"):
        cfg = cfg.text_config

    num_layers = cfg.num_hidden_layers
    model_config = [
        ["num_layers", num_layers],
        ["max_ctx_len", cfg.max_position_embeddings],
    ]
    if ctx_len > cfg.max_position_embeddings:
        gr.Warning(
            "Requested context length is larger than the max value supported by the model"
        )

    # TODO: show attention type, show calculation
    if use_mla:
        kv_lora_rank = cfg.kv_lora_rank
        qk_rope_head_dim = cfg.qk_rope_head_dim
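        # MLA caches one compressed KV latent (kv_lora_rank) plus the decoupled RoPE
        # key (qk_rope_head_dim) per token per layer, shared across all heads.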
        nelems_per_token = num_layers * (kv_lora_rank + qk_rope_head_dim)
        model_config.append(["kv_lora_rank", kv_lora_rank])
        model_config.append(["qk_rope_head_dim", qk_rope_head_dim])
    else:
        num_kv_heads = cfg.num_key_value_heads
        head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
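        # Standard MHA/GQA caches both K and V (hence the factor of 2) for every
        # KV head in every layer.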
        nelems_per_token = num_layers * num_kv_heads * head_dim * 2
        model_config.append(["num_kv_heads", num_kv_heads])
        model_config.append(["head_dim", head_dim])

    if dtype == "fp16/bf16":
        nbytes_per_elem = 2
    elif dtype == "fp8":
        nbytes_per_elem = 1 + 2 / cfg.hidden_size  # assume per-token scaling
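
    # Total KV cache = elements per token x context length x concurrent users x bytes
    # per element, converted to GB (1e9 bytes).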
    kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
    return kv_cache_size, model_config
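

# Rough sanity check (illustrative numbers, assuming Qwen/QwQ-32B uses 64 layers,
# 8 KV heads, and head_dim 128): 64 * 8 * 128 * 2 = 131,072 cached elements per token,
# so a 128k-token context for a single user in bf16 comes to roughly 33.6 GB.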

DESCRIPTION = (
    "NOTE:\n"
    " - For gated repos, you will need to provide your HF token in the box below. You can "
    "generate a new one at https://huggingface.co/settings/tokens. The token won't be stored "
    "(you can check `app.py`).\n"
    " - We don't take into account KV cache savings from sliding window attention (most "
    "serving frameworks don't optimize for this anyway).\n"
    " - For Multi-head Latent Attention (MLA) used in DeepSeek-V2/V3, we calculate the "
    "compressed KV cache as intended by MLA. This might not be supported by certain framework"
    "+hardware combinations, e.g. llama.cpp and MLX, which fall back to Multi-head Attention "
    "(MHA)."
)
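
# Gradio UI. Inputs: model id, context length, number of users, KV cache dtype, and an
# optional HF token. Outputs: estimated KV cache size (GB) and the relevant config values.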
demo = gr.Interface(
    description=DESCRIPTION,
    fn=calculate,
    inputs=[
        gr.Textbox(label="model_id", value="Qwen/QwQ-32B"),
        gr.Number(label="Context length", value=128_000),
        gr.Number(label="No. of users", value=1),
        gr.Dropdown(
            label="KV cache dtype", choices=["fp16/bf16", "fp8"], value="fp16/bf16"
        ),
        gr.Textbox(label="HF token"),
    ],
    outputs=[
        gr.Number(label="KV cache size (GB)", precision=2),
        gr.Dataframe(
            label="Model config", headers=["Key", "Value"], datatype=["str", "int"]
        ),
    ],
)
demo.launch()