Bagel-7B-Demo / app.py
KingNish's picture
Update app.py
09d9d95 verified
raw
history blame
21.5 kB
import spaces
import gradio as gr
import numpy as np
import os
import torch
import random
import subprocess
# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD presumably tells
# flash-attn's build script not to compile the CUDA kernels from source — confirm.
import sys

subprocess.run(
    # Run pip through the current interpreter so the package lands in the same
    # environment this app runs in (a bare "pip" on PATH may belong elsewhere).
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    # Extend, rather than replace, the inherited environment: passing only the
    # one variable would strip PATH, HOME, etc. from the child process.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    # Best-effort install, as before: a failure here is not raised.
    check=False,
)
from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from PIL import Image
from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from inferencer import InterleaveInferencer
from modeling.autoencoder import load_ae
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.bagel import (
BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from huggingface_hub import snapshot_download
# Fetch the BAGEL-7B-MoT repository files (configs, weights, tokenizer/text
# files matching the patterns below) into ./model, using ./model/cache as the
# hub cache directory.
save_dir = "./model"
repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
cache_dir = f"{save_dir}/cache"
snapshot_download(
    repo_id=repo_id,
    local_dir=save_dir,
    cache_dir=cache_dir,
    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
    # NOTE(review): both flags are deprecated in recent huggingface_hub
    # releases and may be ignored there — confirm against the pinned version.
    local_dir_use_symlinks=False,
    resume_download=True,
)
# Model Initialization
# Build the BAGEL model skeleton, tokenizer and image transforms from the files
# in model_path. No weights are loaded here; they are dispatched from
# ema.safetensors by load_checkpoint_and_dispatch further down.
model_path = "./model" # same directory snapshot_download writes to (https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT)

# Language-model config from llm_config.json, with these fields overridden.
llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
# Name of the decoder-layer class; the same name is listed as non-splittable
# in the device-map step, so each such layer stays on one device.
llm_config.layer_module = "Qwen2MoTDecoderLayer"

# Vision-encoder config from vit_config.json, with overrides.
vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
# Build the ViT with one layer fewer than the config file declares.
# NOTE(review): presumably the checkpoint omits the final ViT layer — confirm.
vit_config.num_hidden_layers -= 1

# Autoencoder model and its config, loaded from ae.safetensors.
vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

# Top-level BAGEL config tying together the LLM, ViT and VAE configs.
config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)

# Instantiate the modules under accelerate's init_empty_weights so parameters
# are created without real storage; real tensors are attached at dispatch time.
with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    # Must run inside the same empty-weights context (note meta=True).
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

# Tokenizer extended with BAGEL's special tokens; new_token_ids is handed to
# the inferencer, the third return value is unused here.
tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

# Image preprocessing for the VAE and ViT inputs, both passed to the inferencer.
# NOTE(review): argument meaning assumed to be (max size, min size, stride) —
# confirm against data.transforms.ImageTransform.
vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)
# Model loading and multi-GPU inference preparation.
# Plan module placement over every visible GPU (each capped at 80GiB); modules
# of the listed classes are never split across devices.
device_map = infer_auto_device_map(
    model,
    max_memory={gpu_index: "80GiB" for gpu_index in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)

# Modules forced onto the same device as the token embedding (the first entry).
same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed'
]

single_gpu = torch.cuda.device_count() == 1
anchor_module = same_device_modules[0]
if single_gpu:
    first_device = device_map.get(anchor_module, "cuda:0")
else:
    first_device = device_map.get(anchor_module)

for module_name in same_device_modules:
    if module_name in device_map:
        device_map[module_name] = first_device
    elif single_gpu:
        # With one GPU, explicitly pin modules the planner did not list.
        device_map[module_name] = "cuda:0"

# Materialize the EMA weights in bfloat16 according to the device map.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    dtype=torch.bfloat16,
    force_hooks=True,
).eval()

# Single inferencer shared by all three UI callbacks.
inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs for reproducible generation.

    Args:
        seed: Seed from the UI. A positive value seeds every RNG and switches
            cuDNN to deterministic mode; ``0``, a negative value or ``None``
            leaves the RNG state untouched (non-reproducible, "random" run).
            Numeric non-int values such as ``42.0`` are converted with
            ``int()`` because ``np.random.seed`` rejects floats.

    Returns:
        The seed actually used, as an ``int`` (``0`` for ``None``).
    """
    seed = 0 if seed is None else int(seed)
    if seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        # Prefer deterministic cuDNN kernels over autotuned (non-deterministic) ones.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    return seed
# Output shape for each aspect ratio offered in the UI; the longer side is 1024.
# Presumably (height, width): "4:3" maps to (768, 1024).
_IMAGE_RATIO_SHAPES = {
    "1:1": (1024, 1024),
    "4:3": (768, 1024),
    "3:4": (1024, 768),
    "16:9": (576, 1024),
    "9:16": (1024, 576),
}


# Text to Image function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
                  timestep_shift=3.0, num_timesteps=50,
                  cfg_renorm_min=1.0, cfg_renorm_type="global",
                  max_think_token_n=1024, do_sample=False, text_temperature=0.3,
                  seed=0, image_ratio="1:1"):
    """Generate an image from *prompt*, streaming progress to the UI.

    Args:
        prompt: Text prompt.
        show_thinking: If true, the inferencer is asked to "think" first and the
            thinking-related arguments below are honoured; otherwise they are
            replaced by fixed defaults (1024 tokens, no sampling, temperature 0.3).
        cfg_text_scale, cfg_interval, timestep_shift, num_timesteps,
        cfg_renorm_min, cfg_renorm_type: Diffusion / CFG hyperparameters passed
            straight to the inferencer; ``cfg_interval`` is the start of the
            interval, whose end is fixed at 1.0.
        max_think_token_n, do_sample, text_temperature: Thinking-text settings.
        seed: Passed to ``set_seed`` (0 = not reseeded).
        image_ratio: One of the keys of ``_IMAGE_RATIO_SHAPES``.

    Yields:
        ``(image, thinking_text)`` after every inferencer item. ``image`` stays
        ``None`` until the inferencer yields a non-string item.

    Raises:
        ValueError: If *image_ratio* is not a supported ratio.
    """
    image_shapes = _IMAGE_RATIO_SHAPES.get(image_ratio)
    if image_shapes is None:
        raise ValueError(
            f"Unsupported image_ratio {image_ratio!r}; "
            f"expected one of {sorted(_IMAGE_RATIO_SHAPES)}"
        )

    # Set seed for reproducibility
    set_seed(seed)

    # Thinking-only knobs fall back to fixed defaults when thinking is off.
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
        image_shapes=image_shapes,
    )

    thinking_text = ""
    image = None
    for item in inferencer(text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
        # String items are streamed thinking text; anything else is the image.
        if isinstance(item, str):
            thinking_text += item
        else:
            image = item
        yield image, thinking_text
# Image Understanding function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
                        do_sample=False, text_temperature=0.3, max_new_tokens=512):
    """Answer *prompt* about *image*, streaming the text answer to the UI.

    Args:
        image: Uploaded image as a PIL image or NumPy array (converted to an RGB
            PIL image via ``pil_img2rgb``); ``None`` when nothing was uploaded.
        prompt: Question / instruction about the image.
        show_thinking: Forwarded to the inferencer as ``think``.
        do_sample, text_temperature: Text-generation sampling settings.
        max_new_tokens: Forwarded as ``max_think_token_n`` (the UI describes it
            as the maximum length including potential thinking).

    Yields:
        The accumulated answer text after every inferencer item, or a single
        "Please upload an image." message when *image* is ``None``.
    """
    if image is None:
        # This is a generator: ``return "msg"`` would only become the
        # StopIteration value and never reach the textbox, so yield it instead.
        yield "Please upload an image."
        return

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = pil_img2rgb(image)

    inference_hyper = dict(
        do_sample=do_sample,
        temperature=text_temperature,
        max_think_token_n=max_new_tokens,  # Set max_length
    )

    answer = ""
    for item in inferencer(image=image, text=prompt, think=show_thinking,
                           understanding_output=True, **inference_hyper):
        # Only text items contribute to the answer; other items are ignored.
        if isinstance(item, str):
            answer += item
        yield answer
# Image Editing function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
               cfg_img_scale=2.0, cfg_interval=0.0,
               timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
               cfg_renorm_type="text_channel", max_think_token_n=1024,
               do_sample=False, text_temperature=0.3, seed=0):
    """Edit *image* according to *prompt*, streaming progress to the UI.

    Args:
        image: Input image as a PIL image or NumPy array (converted to an RGB
            PIL image via ``pil_img2rgb``); ``None`` when nothing was uploaded.
        prompt: Edit instruction.
        show_thinking: If true, the thinking-related arguments are honoured;
            otherwise fixed defaults are used (1024 tokens, no sampling, 0.3).
        cfg_text_scale, cfg_img_scale, cfg_interval, timestep_shift,
        num_timesteps, cfg_renorm_min, cfg_renorm_type: Diffusion / CFG
            hyperparameters; ``cfg_interval`` is the start of the interval,
            whose end is fixed at 1.0.
        max_think_token_n, do_sample, text_temperature: Thinking-text settings.
        seed: Passed to ``set_seed`` (0 = not reseeded).

    Yields:
        ``(image, thinking_text)`` after every inferencer item. ``image`` stays
        ``None`` until the inferencer yields a non-string item.

    Raises:
        gr.Error: If no image was uploaded (Gradio shows the message to the user).
    """
    if image is None:
        # A ``return value`` inside a generator is discarded by Gradio, and the
        # outputs are (image, text) anyway, so surface the problem as a UI error.
        raise gr.Error("Please upload an image.")

    # Set seed for reproducibility
    set_seed(seed)

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = pil_img2rgb(image)

    # Thinking-only knobs fall back to fixed defaults when thinking is off.
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_img_scale=cfg_img_scale,
        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
    )

    thinking_text = ""
    edited = None
    for item in inferencer(image=image, text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
        # String items are streamed thinking text; anything else is the image.
        if isinstance(item, str):
            thinking_text += item
        else:
            edited = item
        yield edited, thinking_text
# Helper function to load example images
def load_example_image(image_path):
    """Open *image_path* with PIL to use as a default example image.

    Any failure is printed and ``None`` is returned instead, so a missing or
    unreadable example file just leaves the input component empty.
    """
    loaded = None
    try:
        loaded = Image.open(image_path)
    except Exception as err:
        print(f"Error loading example image: {err}")
    return loaded
# Gradio UI: three tabs (text-to-image, image editing, image understanding)
# wired to text_to_image / edit_image / image_understanding, plus a footer of
# project links. Launched at import time.
with gr.Blocks() as demo:
    gr.Markdown("""
<div>
<img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
</div>
""")

    with gr.Tab("📝 Text to Image"):
        txt_input = gr.Textbox(
            label="Prompt",
            value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
        )

        with gr.Row():
            show_thinking = gr.Checkbox(label="Thinking", value=False)

        # Add hyperparameter controls in an accordion
        with gr.Accordion("Inference Hyperparameters", open=False):
            # Parameters are laid out two per row
            with gr.Group():
                with gr.Row():
                    seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
                                     label="Seed", info="0 for random seed, positive for reproducible results")
                    image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
                                              value="1:1", label="Image Ratio",
                                              info="The longer side is fixed to 1024")

                with gr.Row():
                    cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
                                               label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0)")
                    cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
                                             label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")

                with gr.Row():
                    cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
                                                  value="global", label="CFG Renorm Type",
                                                  info="If the generated image is blurry, use 'global'")
                    cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                               label="CFG Renorm Min", info="1.0 disables CFG-Renorm")

                with gr.Row():
                    num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
                                              label="Timesteps", info="Total denoising steps")
                    timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
                                               label="Timestep Shift", info="Higher values for layout, lower for details")

                # Thinking parameters in a single row; hidden until "Thinking" is checked
                thinking_params = gr.Group(visible=False)
                with thinking_params:
                    with gr.Row():
                        do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
                        # 4096 (was the typo 4006) lies on the 64-step grid starting at 64.
                        max_think_token_n = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, interactive=True,
                                                      label="Max Think Tokens", info="Maximum number of tokens for thinking")
                        text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
                                                     label="Temperature", info="Controls randomness in text generation")

        thinking_output = gr.Textbox(label="Thinking Process", visible=False)
        img_output = gr.Image(label="Generated Image")
        gen_btn = gr.Button("Generate")

        # Dynamically show/hide thinking process box and parameters
        def update_thinking_visibility(show):
            return gr.update(visible=show), gr.update(visible=show)

        show_thinking.change(
            fn=update_thinking_visibility,
            inputs=[show_thinking],
            outputs=[thinking_output, thinking_params]
        )

        # Input order must match text_to_image's positional parameters.
        gen_btn.click(
            fn=text_to_image,
            inputs=[
                txt_input, show_thinking, cfg_text_scale,
                cfg_interval, timestep_shift,
                num_timesteps, cfg_renorm_min, cfg_renorm_type,
                max_think_token_n, do_sample, text_temperature, seed, image_ratio
            ],
            outputs=[img_output, thinking_output]
        )

    with gr.Tab("🖌️ Image Edit"):
        with gr.Row():
            with gr.Column(scale=1):
                edit_image_input = gr.Image(label="Input Image", value=load_example_image('test_images/women.jpg'))
                edit_prompt = gr.Textbox(
                    label="Prompt",
                    value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes."
                )

            with gr.Column(scale=1):
                edit_image_output = gr.Image(label="Result")
                edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)

        with gr.Row():
            edit_show_thinking = gr.Checkbox(label="Thinking", value=False)

        # Add hyperparameter controls in an accordion
        with gr.Accordion("Inference Hyperparameters", open=False):
            with gr.Group():
                with gr.Row():
                    edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
                                          label="Seed", info="0 for random seed, positive for reproducible results")
                    edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
                                                    label="CFG Text Scale", info="Controls how strongly the model follows the text prompt")

                with gr.Row():
                    edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
                                                   label="CFG Image Scale", info="Controls how much the model preserves input image details")
                    edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                                  label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")

                with gr.Row():
                    edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
                                                       value="text_channel", label="CFG Renorm Type",
                                                       info="If the generated image is blurry, use 'global'")
                    edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                                    label="CFG Renorm Min", info="1.0 disables CFG-Renorm")

                with gr.Row():
                    edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
                                                   label="Timesteps", info="Total denoising steps")
                    edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
                                                    label="Timestep Shift", info="Higher values for layout, lower for details")

                # Thinking parameters in a single row; hidden until "Thinking" is checked
                edit_thinking_params = gr.Group(visible=False)
                with edit_thinking_params:
                    with gr.Row():
                        edit_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
                        # 4096 (was the typo 4006) lies on the 64-step grid starting at 64.
                        edit_max_think_token_n = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, interactive=True,
                                                           label="Max Think Tokens", info="Maximum number of tokens for thinking")
                        edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
                                                          label="Temperature", info="Controls randomness in text generation")

        edit_btn = gr.Button("Submit")

        # Dynamically show/hide thinking process box for editing
        def update_edit_thinking_visibility(show):
            return gr.update(visible=show), gr.update(visible=show)

        edit_show_thinking.change(
            fn=update_edit_thinking_visibility,
            inputs=[edit_show_thinking],
            outputs=[edit_thinking_output, edit_thinking_params]
        )

        # Input order must match edit_image's positional parameters.
        edit_btn.click(
            fn=edit_image,
            inputs=[
                edit_image_input, edit_prompt, edit_show_thinking,
                edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
                edit_timestep_shift, edit_num_timesteps,
                edit_cfg_renorm_min, edit_cfg_renorm_type,
                edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
            ],
            outputs=[edit_image_output, edit_thinking_output]
        )

    with gr.Tab("🖼️ Image Understanding"):
        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image(label="Input Image", value=load_example_image('test_images/meme.jpg'))
                understand_prompt = gr.Textbox(
                    label="Prompt",
                    value="Can someone explain what's funny about this meme??"
                )

            with gr.Column(scale=1):
                txt_output = gr.Textbox(label="Result", lines=20)

        with gr.Row():
            understand_show_thinking = gr.Checkbox(label="Thinking", value=False)

        # Add hyperparameter controls in an accordion
        with gr.Accordion("Inference Hyperparameters", open=False):
            with gr.Row():
                understand_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
                understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True,
                                                        label="Temperature", info="Controls randomness in text generation (0=deterministic, 1=creative)")
                understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
                                                      label="Max New Tokens", info="Maximum length of generated text, including potential thinking")

        img_understand_btn = gr.Button("Submit")

        # Input order must match image_understanding's positional parameters.
        img_understand_btn.click(
            fn=image_understanding,
            inputs=[
                img_input, understand_prompt, understand_show_thinking,
                understand_do_sample, understand_text_temperature, understand_max_new_tokens
            ],
            outputs=txt_output
        )

    # Footer badges linking to the project resources.
    # NOTE(review): the mailto target below looks like a scraped "[email protected]"
    # obfuscation placeholder, not a real address — restore the intended email.
    gr.Markdown("""
<div style="display: flex; justify-content: flex-start; flex-wrap: wrap; gap: 10px;">
<a href="https://bagel-ai.org/">
<img
src="https://img.shields.io/badge/BAGEL-Website-0A66C2?logo=safari&logoColor=white"
alt="BAGEL Website"
/>
</a>
<a href="https://arxiv.org/abs/2505.14683">
<img
src="https://img.shields.io/badge/BAGEL-Paper-red?logo=arxiv&logoColor=red"
alt="BAGEL Paper on arXiv"
/>
</a>
<a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
<img
src="https://img.shields.io/badge/BAGEL-Hugging%20Face-orange?logo=huggingface&logoColor=yellow"
alt="BAGEL on Hugging Face"
/>
</a>
<a href="https://demo.bagel-ai.org/">
<img
src="https://img.shields.io/badge/BAGEL-Demo-blue?logo=googleplay&logoColor=blue"
alt="BAGEL Demo"
/>
</a>
<a href="https://discord.gg/Z836xxzy">
<img
src="https://img.shields.io/badge/BAGEL-Discord-5865F2?logo=discord&logoColor=purple"
alt="BAGEL Discord"
/>
</a>
<a href="mailto:[email protected]">
<img
src="https://img.shields.io/badge/BAGEL-Email-D14836?logo=gmail&logoColor=red"
alt="BAGEL Email"
/>
</a>
</div>
""")

demo.launch()