{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "# Copyright 2025 Bytedance Ltd. and/or its affiliates.\n", "# SPDX-License-Identifier: Apache-2.0" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "from copy import deepcopy\n", "from typing import (\n", " Any,\n", " AsyncIterable,\n", " Callable,\n", " Dict,\n", " Generator,\n", " List,\n", " NamedTuple,\n", " Optional,\n", " Tuple,\n", " Union,\n", ")\n", "import requests\n", "from io import BytesIO\n", "\n", "from PIL import Image\n", "import torch\n", "from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights\n", "\n", "from data.transforms import ImageTransform\n", "from data.data_utils import pil_img2rgb, add_special_tokens\n", "from modeling.bagel import (\n", " BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel\n", ")\n", "from modeling.qwen2 import Qwen2Tokenizer\n", "from modeling.bagel.qwen2_navit import NaiveCache\n", "from modeling.autoencoder import load_ae\n", "from safetensors.torch import load_file" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Initialization" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "model_path = \"/path/to/BAGEL-7B-MoT/weights\" # Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT\n", "\n", "# LLM config preparing\n", "llm_config = Qwen2Config.from_json_file(os.path.join(model_path, \"llm_config.json\"))\n", "llm_config.qk_norm = True\n", "llm_config.tie_word_embeddings = False\n", "llm_config.layer_module = \"Qwen2MoTDecoderLayer\"\n", "\n", "# ViT config preparing\n", "vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, \"vit_config.json\"))\n", "vit_config.rope = False\n", "vit_config.num_hidden_layers = vit_config.num_hidden_layers - 1\n", "\n", "# VAE loading\n", "vae_model, vae_config = load_ae(local_path=os.path.join(model_path, \"ae.safetensors\"))\n", "\n", "# Bagel config preparing\n", "config = BagelConfig(\n", " visual_gen=True,\n", " visual_und=True,\n", " llm_config=llm_config, \n", " vit_config=vit_config,\n", " vae_config=vae_config,\n", " vit_max_num_patch_per_side=70,\n", " connector_act='gelu_pytorch_tanh',\n", " latent_patch_size=2,\n", " max_latent_size=64,\n", ")\n", "\n", "with init_empty_weights():\n", " language_model = Qwen2ForCausalLM(llm_config)\n", " vit_model = SiglipVisionModel(vit_config)\n", " model = Bagel(language_model, vit_model, config)\n", " model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)\n", "\n", "# Tokenizer Preparing\n", "tokenizer = Qwen2Tokenizer.from_pretrained(model_path)\n", "tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)\n", "\n", "# Image Transform Preparing\n", "vae_transform = ImageTransform(1024, 512, 16)\n", "vit_transform = ImageTransform(980, 224, 14)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Loading and Multi GPU Infernece Preparing" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "max_mem_per_gpu = \"40GiB\" # Modify it according to your GPU setting. 
On an A100, 80 GiB is sufficient to load on a single GPU.\n", "\n", "device_map = infer_auto_device_map(\n", " model,\n", " max_memory={i: max_mem_per_gpu for i in range(torch.cuda.device_count())},\n", " no_split_module_classes=[\"Bagel\", \"Qwen2MoTDecoderLayer\"],\n", ")\n", "print(device_map)\n", "\n", "# Keep these modules on the same device as the token embeddings\n", "same_device_modules = [\n", " 'language_model.model.embed_tokens',\n", " 'time_embedder',\n", " 'latent_pos_embed',\n", " 'vae2llm',\n", " 'llm2vae',\n", " 'connector',\n", " 'vit_pos_embed'\n", "]\n", "\n", "if torch.cuda.device_count() == 1:\n", " first_device = device_map.get(same_device_modules[0], \"cuda:0\")\n", " for k in same_device_modules:\n", " if k in device_map:\n", " device_map[k] = first_device\n", " else:\n", " device_map[k] = \"cuda:0\"\n", "else:\n", " first_device = device_map.get(same_device_modules[0])\n", " for k in same_device_modules:\n", " if k in device_map:\n", " device_map[k] = first_device\n", "\n", "# Thanks @onion-liu: https://github.com/ByteDance-Seed/Bagel/pull/8\n", "model = load_checkpoint_and_dispatch(\n", " model,\n", " checkpoint=os.path.join(model_path, \"ema.safetensors\"),\n", " device_map=device_map,\n", " offload_buffers=True,\n", " dtype=torch.bfloat16,\n", " force_hooks=True,\n", " offload_folder=\"/tmp/offload\"\n", ")\n", "\n", "model = model.eval()\n", "print('Model loaded')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inferencer Preparation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "from inferencer import InterleaveInferencer\n", "\n", "inferencer = InterleaveInferencer(\n", " model=model, \n", " vae_model=vae_model, \n", " tokenizer=tokenizer, \n", " vae_transform=vae_transform, \n", " vit_transform=vit_transform, \n", " new_token_ids=new_token_ids\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "import random\n", "import numpy as np\n", "\n", "seed = 42\n", "random.seed(seed)\n", "np.random.seed(seed)\n", "torch.manual_seed(seed)\n", "if torch.cuda.is_available():\n", " torch.cuda.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", "torch.backends.cudnn.deterministic = True\n", "torch.backends.cudnn.benchmark = False" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**About Inference Hyperparameters:**\n", "- **`cfg_text_scale`:** Controls how strongly the model follows the text prompt. `1.0` disables text guidance. Typical range: `4.0–8.0`.\n", "- **`cfg_img_scale`:** Controls how much the model preserves input image details. `1.0` disables image guidance. Typical range: `1.0–2.0`.\n", "- **`cfg_interval`:** Fraction of denoising steps where CFG is applied. Later steps can skip CFG to reduce computation. Typical: `[0.4, 1.0]`.\n", "- **`timestep_shift`:** Shifts the distribution of denoising steps. Higher values allocate more steps at the start (affects layout); lower values allocate more at the end (improves details).\n", "- **`num_timesteps`:** Total denoising steps. Typical: `50`.\n", "- **`cfg_renorm_min`:** Minimum value for CFG-Renorm. `1.0` disables renorm. 
Typical: `0`.\n", "- **`cfg_renorm_type`:** CFG-Renorm method: \n", " - `global`: Normalize over all tokens and channels (default for T2I).\n", " - `channel`: Normalize across channels for each token.\n", " - `text_channel`: Like `channel`, but only applies to text condition (good for editing, may cause blur).\n", "- **If edited images appear blurry, try `global` CFG-Renorm, decrease `cfg_renorm_min`, or decrease `cfg_scale`.**\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Image Generation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "inference_hyper=dict(\n", " cfg_text_scale=4.0,\n", " cfg_img_scale=1.0,\n", " cfg_interval=[0.4, 1.0],\n", " timestep_shift=3.0,\n", " num_timesteps=50,\n", " cfg_renorm_min=1.0,\n", " cfg_renorm_type=\"global\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "prompt = \"A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere.\"\n", "\n", "print(prompt)\n", "print('-' * 10)\n", "output_dict = inferencer(text=prompt, **inference_hyper)\n", "display(output_dict['image'])" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "## Image Generation with Think" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "inference_hyper=dict(\n", " max_think_token_n=1000,\n", " do_sample=False,\n", " # text_temperature=0.3,\n", " cfg_text_scale=4.0,\n", " cfg_img_scale=1.0,\n", " cfg_interval=[0.4, 1.0],\n", " timestep_shift=3.0,\n", " num_timesteps=50,\n", " cfg_renorm_min=1.0,\n", " cfg_renorm_type=\"global\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "prompt = 'a car made of small cars'\n", "\n", "print(prompt)\n", "print('-' * 10)\n", "output_dict = inferencer(text=prompt, think=True, **inference_hyper)\n", "print(output_dict['text'])\n", "display(output_dict['image'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Editing" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "inference_hyper=dict(\n", " cfg_text_scale=4.0,\n", " cfg_img_scale=2.0,\n", " cfg_interval=[0.0, 1.0],\n", " timestep_shift=3.0,\n", " num_timesteps=50,\n", " cfg_renorm_min=1.0,\n", " cfg_renorm_type=\"text_channel\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "image = Image.open('test_images/women.jpg')\n", "prompt = 'She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes.'\n", "\n", "display(image)\n", "print(prompt)\n", "print('-'*10)\n", "output_dict = inferencer(image=image, text=prompt, **inference_hyper)\n", "display(output_dict['image'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Edit with Think" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, 
"outputs": [], "source": [ "inference_hyper=dict(\n", " max_think_token_n=1000,\n", " do_sample=False,\n", " # text_temperature=0.3,\n", " cfg_text_scale=4.0,\n", " cfg_img_scale=2.0,\n", " cfg_interval=[0.0, 1.0],\n", " timestep_shift=3.0,\n", " num_timesteps=50,\n", " cfg_renorm_min=0.0,\n", " cfg_renorm_type=\"text_channel\",\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "image = Image.open('test_images/octupusy.jpg')\n", "prompt = 'Could you display the sculpture that takes after this design?'\n", "\n", "display(image)\n", "print('-'*10)\n", "output_dict = inferencer(image=image, text=prompt, think=True, **inference_hyper)\n", "print(output_dict['text'])\n", "display(output_dict['image'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Understanding" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "inference_hyper=dict(\n", " max_think_token_n=1000,\n", " do_sample=False,\n", " # text_temperature=0.3,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "image = Image.open('test_images/meme.jpg')\n", "prompt = \"Can someone explain what’s funny about this meme??\"\n", "\n", "display(image)\n", "print(prompt)\n", "print('-'*10)\n", "output_dict = inferencer(image=image, text=prompt, understanding_output=True, **inference_hyper)\n", "print(output_dict['text'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "fileId": "1bfaa82d-51b0-4c13-9e4c-295ba28bcd8a", "filePath": "/mnt/bn/seed-aws-va/chaorui/code/cdt-hf/notebooks/chat.ipynb", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.2" } }, "nbformat": 4, "nbformat_minor": 4 }