import gradio as gr
from transformers import AutoConfig


def calculate(name: str, ctx_len: int, num_users: int, dtype: str, hf_token: str):
    hf_token = hf_token.strip()
    try:
        cfg = AutoConfig.from_pretrained(
            name,
            trust_remote_code=True,
            token=hf_token or None,
        )
    except Exception as e:
        raise gr.Error(str(e))

    # DeepSeek-V2/V3 use Multi-head Latent Attention (MLA), which caches a compressed KV.
    use_mla = cfg.architectures[0].startswith(("DeepseekV2", "DeepseekV3"))
    if hasattr(cfg, "text_config"):  # multimodal models wrap the LLM config in text_config
        cfg = cfg.text_config
    num_layers = cfg.num_hidden_layers

    model_config = [
        ["num_layers", num_layers],
        ["max_ctx_len", cfg.max_position_embeddings],
    ]
    if ctx_len > cfg.max_position_embeddings:
        gr.Warning(
            "Requested context length is larger than the max value supported by the model"
        )

    # TODO: show attention type, show calculation
    if use_mla:
        # MLA caches one compressed latent plus the decoupled RoPE key slice per layer.
        kv_lora_rank = cfg.kv_lora_rank
        qk_rope_head_dim = cfg.qk_rope_head_dim
        nelems_per_token = num_layers * (kv_lora_rank + qk_rope_head_dim)
        model_config.append(["kv_lora_rank", kv_lora_rank])
        model_config.append(["qk_rope_head_dim", qk_rope_head_dim])
    else:
        # Standard MHA/GQA: both K and V are cached for each KV head in every layer.
        num_kv_heads = cfg.num_key_value_heads
        head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)
        nelems_per_token = num_layers * num_kv_heads * head_dim * 2
        model_config.append(["num_kv_heads", num_kv_heads])
        model_config.append(["head_dim", head_dim])

    if dtype == "fp16/bf16":
        nbytes_per_elem = 2
    elif dtype == "fp8":
        nbytes_per_elem = 1 + 2 / cfg.hidden_size  # assume per-token scaling

    kv_cache_size = nelems_per_token * ctx_len * num_users * nbytes_per_elem / 1e9
    return kv_cache_size, model_config


DESCRIPTION = (
    "NOTE:\n"
    " - For gated repos, you will need to provide your HF token in the box below. You can "
    "generate a new one at https://huggingface.co/settings/tokens. The token won't be stored "
    "(you can check `app.py`).\n"
    " - We don't take into account KV cache savings from sliding window attention (most "
    "serving frameworks don't optimize for this anyway).\n"
    " - For Multi-head Latent Attention (MLA) used in DeepSeek-V2/V3, we calculate the "
    "compressed KV cache as intended by MLA. This might not be supported on certain framework"
    "+hardware combinations, e.g. llama.cpp and MLX, which will fall back to Multi-head "
    "Attention (MHA)."
)

demo = gr.Interface(
    description=DESCRIPTION,
    fn=calculate,
    inputs=[
        gr.Textbox(label="model_id", value="Qwen/QwQ-32B"),
        gr.Number(label="Context length", value=128_000),
        gr.Number(label="No. of users", value=1),
        gr.Dropdown(label="KV cache dtype", choices=["fp16/bf16", "fp8"]),
        gr.Textbox(label="HF token"),
    ],
    outputs=[
        gr.Number(label="KV cache size (GB)", precision=2),
        gr.Dataframe(
            label="Model config", headers=["Key", "Value"], datatype=["str", "int"]
        ),
    ],
)
demo.launch()
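
# Worked example of the MHA/GQA formula above (a sanity check, not produced by this app;
# the numbers assume a Llama-3-8B-style config with 32 layers, 8 KV heads, head_dim 128):
#   nelems_per_token = 32 * 8 * 128 * 2                    = 65,536
#   bytes per token at fp16/bf16 (2 bytes per element)     = 131,072  (~0.13 MB)
#   KV cache for 1 user at a 128,000-token context         ~ 16.78 GB  (131,072 * 128,000 / 1e9)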