import gradio as gr
from functools import partial
import torch
import spaces
import DDCM_blind_face_image_restoration
import latent_DDCM_CCFG
import latent_DDCM_compression
from latent_models import load_model
import os
# import transformers
# transformers.utils.move_cache()
# On ZeroGPU Spaces, normalize a "true" value of the SPACES_ZERO_GPU flag to "1".
# NOTE(review): `spaces` has already been imported when this runs; if that package
# reads SPACES_ZERO_GPU at import time, this normalization will not affect it —
# confirm, and move it above `import spaces` if needed.
if os.getenv("SPACES_ZERO_GPU") == "true":
    os.environ["SPACES_ZERO_GPU"] = "1"
# Both Stable Diffusion 2.1 variants are loaded once, at module import, onto the
# CPU (float16=True, compile=False), keyed by the image sizes offered in the UI's
# "Image size" radio buttons. Only element [0] of load_model's return value is
# kept — presumably the model object; TODO confirm what the other elements are.
avail_models = {'512x512': load_model('stabilityai/stable-diffusion-2-1-base', 1000, float16=True, device=torch.device("cpu"), compile=False)[0],
                '768x768': load_model('stabilityai/stable-diffusion-2-1', 1000, float16=True, device=torch.device("cpu"), compile=False)[0]
                }
# Compression entry point with the preloaded models bound in. Used both to
# compress an input image and (with a bit-stream argument) to decompress one.
compression_func = partial(latent_DDCM_compression.main, avail_models=avail_models)
def get_t_and_k_from_file_name(file_name):
    """Parse the DDCM settings encoded in a bit-stream file name.

    Each value is read from just after the first occurrence of its marker
    letter ('T', 'K', 'M') in the file's base name, up to the next '-'.

    Args:
        file_name: Path (or bare name) of the bit-stream file.

    Returns:
        Tuple ``(T, K, model_type)``: the number of diffusion timesteps (int),
        the codebook size (int), and the model key (str, e.g. '512x512').

    Raises:
        ValueError: if the name does not contain parseable T/K/M fields.
    """
    # Parse only the last path component: the upload directory itself may
    # contain the letters 'T', 'K' or 'M', which would derail the splitting.
    base_name = os.path.basename(file_name)
    try:
        T = int(base_name.split('T')[1].split('-')[0])
        K = int(base_name.split('K')[1].split('-')[0])
        model_type = base_name.split('M')[1].split('-')[0]
    except (IndexError, ValueError) as err:
        raise ValueError(
            f"Cannot parse T, K and model type from bit-stream file name {base_name!r}") from err
    return T, K, model_type
def ccfg(text_input, T, K, ccfg_scale, model_type, compressed_file_in=None):
    """Thin wrapper around ``latent_DDCM_CCFG.main`` with the preloaded models bound in.

    Used both by the text-to-image tab (prompt given, no bit-stream) and by
    bit-stream decompression (prompt is None, bit-stream given).

    Args:
        text_input: Text prompt, or None when decoding a bit-stream.
        T: Number of diffusion timesteps.
        K: Size of each codebook.
        ccfg_scale: Sub-sampled codebook size (K-tilde); capped at ``K``.
        model_type: Key into ``avail_models`` (e.g. '512x512').
        compressed_file_in: Bit-stream file to decode, or None.

    Returns:
        Whatever ``latent_DDCM_CCFG.main`` returns.
    """
    # K-tilde may never exceed the full codebook size K.
    capped_scale = min(ccfg_scale, K)
    return latent_DDCM_CCFG.main(
        text_input,
        T,
        K,
        capped_scale,
        model_type,
        compressed_file_in,
        avail_models=avail_models,
    )
@spaces.GPU
def decompress_given_bitstream(bitstream, method):
    """Decode an uploaded DDCM bit-stream back into an image.

    T, K and the model type are recovered from the bit-stream's file name, and
    the file is then forwarded to the decoder matching ``method``.

    Args:
        bitstream: Uploaded file from a ``gr.File`` component (its ``.name`` is
            read for the encoded settings), or None if nothing was uploaded.
        method: One of 'compression', 'blind' or 'ccfg'.

    Returns:
        Whatever the selected decoder returns.

    Raises:
        gr.Error: no file was uploaded, or its name does not encode T/K/model.
        NotImplementedError: ``method`` is not a known decompression method.
    """
    if bitstream is None:
        # gr.Error must be *raised* to be shown in the UI; merely constructing it
        # would fall through to an AttributeError on `None.name` below.
        raise gr.Error("Please provide a bit-stream file when performing decompression")
    file_name = bitstream.name
    try:
        T, K, model_type = get_t_and_k_from_file_name(file_name)
    except (IndexError, ValueError) as err:
        raise gr.Error("Could not read T, K and the model type from the bit-stream file name. "
                       "Please upload a bit-stream produced by this demo without renaming it.") from err
    if method == 'compression':
        return compression_func(None, T, K, model_type, bitstream)
    elif method == 'blind':
        return DDCM_blind_face_image_restoration.inference(None, T, K, 'NIQE', 1, True, bitstream)
    elif method == 'ccfg':
        # A negative K-tilde passes through ccfg's min(..., K) unchanged.
        return ccfg(None, T, K, -1, model_type, bitstream)
    else:
        raise NotImplementedError(f"Unknown decompression method: {method!r}")
def validate_K(K):
    """Warn (without blocking the request) if codebook size K is not a power of two.

    Args:
        K: Value of a ``gr.Number`` component. Depending on the Gradio version
            and the component's ``precision`` it may arrive as an int, a float
            (e.g. ``2048.0``) or None when the field is empty.
    """
    if K is None:
        # Empty field: nothing to check here.
        return
    if isinstance(K, float) and K.is_integer():
        # `&` is only defined for ints, so normalize integral floats first.
        K = int(K)
    # A positive integer is a power of two iff exactly one bit is set.
    is_power_of_two = isinstance(K, int) and K > 0 and (K & (K - 1)) == 0
    if not is_power_of_two:
        gr.Warning("For efficient bit usage, K should be a power of 2.")
# Bit-stream decompression handler for each demo tab, keyed by the `method`
# name that decompress_given_bitstream dispatches on.
method_to_func = {
    method_name: partial(decompress_given_bitstream, method=method_name)
    for method_name in ('compression', 'blind', 'ccfg')
}
# Page heading, rendered with gr.HTML at the top of the demo.
title = "<h1>Compressed Image Generation with Denoising Diffusion Codebook Models</h1>"
# Introductory text shown under the title (passed to gr.HTML).
# NOTE(review): gr.HTML does not turn plain newlines into line breaks, and the
# "* Equal contribution" line has no asterisked author names in this text —
# confirm the intended HTML markup / author list is intact.
intro = """
* Equal contribution
Technion - Israel Institute of Technology
Denoising Diffusion Codebook Models (DDCM) is a novel (and simple) generative approach based on any Denoising Diffusion Model (DDM), that is able to produce high-quality image samples along with their losslessly compressed bit-stream representations.
DDCM can easily be utilized for perceptual image compression, as well as for solving a variety of compressed conditional generation tasks such as text-conditional image generation and image restoration, where each generated sample is accompanied by a compressed bit-stream.
The tabs below correspond to demos of different practical applications. Open each tab to see the application's specific instructions.
Note: The demos below rely on relatively old pre-trained diffusion models such as Stable Diffusion 2.1, simply for the purpose of demonstrating the capabilities of DDCM. Feel free to implement our DDCM-based methods using newer diffusion models to further improve performance.
"""
# Footer text (Markdown) with the citation, license and contact information.
article = r"""
If you find our work useful, please ⭐ our GitHub repository. Thanks!
📝 **Citation**
```bibtex
@article{ohayon2025compressedimagegenerationdenoising,
title={Compressed Image Generation with Denoising Diffusion Codebook Models},
author={Guy Ohayon and Hila Manor and Tomer Michaeli and Michael Elad},
year={2025},
eprint={2502.01189},
journal={arXiv},
primaryClass={eess.IV},
url={https://arxiv.org/abs/2502.01189},
}
```
📋 **License**
This project is released under the MIT license.
📧 **Contact**
If you have any questions, please feel free to contact us at guyoep@gmail.com (Guy Ohayon) and hila.manor@campus.technion.ac.il (Hila Manor).
"""
# Extra CSS passed to gr.Blocks: enlarges (21px) and bolds the tab-header buttons.
custom_css = """
.tabs button {
font-size: 21px !important;
font-weight: bold !important;
}
"""
def _add_decompression_section(method, button_label="Decompress image"):
    """Add the "decompress a previously generated bit-stream" controls to the current tab.

    All three demo tabs share the same upload -> decompress -> show-image UI;
    only the decoder and the button text differ. Must be called inside an
    active gr.Blocks / gr.Tab context.

    Args:
        method: Key into ``method_to_func`` ('compression', 'blind' or 'ccfg').
        button_label: Text shown on the decompress button.
    """
    gr.Markdown("### Decompress a previously generated bit-stream")
    with gr.Row():
        with gr.Column(scale=2):
            bitstream_in = gr.File(label="Compressed bit-stream (input)", scale=0)
            decompress_button = gr.Button(button_label)
        with gr.Column(scale=3):
            bitstream_image = gr.Image(label="Decompressed image (from uploaded bit-stream)", scale=2)
    decompress_button.click(method_to_func[method], inputs=bitstream_in, outputs=bitstream_image)


with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.HTML(title)
    gr.HTML(intro)

    with gr.Tab("Image Compression"):
        gr.Markdown(
            "- To change the bit rate, modify the number of diffusion timesteps (T) and/or the codebook sizes (K).")
        gr.Markdown("- The input image will be center-cropped and resized to the specified size (512x512 or 768x768).")
        with gr.Row():
            with gr.Column(scale=2):
                input_image = gr.Image(label="Input image", scale=2, image_mode='RGB', type='pil')
                with gr.Group():
                    with gr.Row():
                        # precision=0: T and K are counts, so hand the backend ints.
                        T = gr.Number(label="Diffusion timesteps (T)", minimum=50, maximum=1000, value=1000,
                                      scale=2, precision=0)
                        K = gr.Number(label="Size of each codebook (K)", minimum=2, maximum=8192, value=2048,
                                      scale=3, precision=0)
                    with gr.Row():
                        model_type = gr.Radio(["768x768", "512x512"], label="Image size", value="512x512")
                compress = gr.Button("Compress image")
            with gr.Column(scale=3):
                decompressed_image = gr.Image(label="Decompressed image", scale=2)
                compressed_file_out = gr.File(label="Compressed bit-stream (output)", scale=0)
        compress.click(validate_K, inputs=[K]).then(compression_func, inputs=[input_image, T, K, model_type],
                                                   outputs=[decompressed_image, compressed_file_out])
        gr.Examples(
            [[f"examples/compression/{idx}.jpg", 1000, 256, '512x512']
             for idx in (1, 2, 4, 7, 8, 13, 15, 17, 18, 19, 21, 22, 23)],
            inputs=[input_image, T, K, model_type],
            outputs=[decompressed_image, compressed_file_out],
            fn=compression_func,
            cache_examples='lazy')
        _add_decompression_section('compression', "Decompress image")

    with gr.Tab("Real-World Face Image Restoration"):
        gr.Markdown(
            "Please mark if your input face image is already aligned. "
            "If not, we will try to automatically detect, crop and align the faces, and raise an error if no faces are found. Expect better results if your input image is already aligned.")
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    input_image = gr.Image(label="Input image", scale=2, type='filepath')
                    aligned = gr.Checkbox(label='Input face image is aligned')
                with gr.Group():
                    with gr.Row():
                        T = gr.Number(label="Diffusion timesteps (T)", minimum=50, maximum=1000, value=1000,
                                      precision=0)
                        K = gr.Number(label="Size of each codebook (K)", minimum=2, maximum=8192, value=2048,
                                      precision=0)
                    iqa_metric = gr.Radio(['NIQE', 'TOPIQ', 'CLIP-IQA'], label='Perceptual quality measure to optimize',
                                          value='NIQE')
                    iqa_coef = gr.Number(
                        label="Perception-distortion tradeoff coefficient (λ)",
                        info="Higher -> better perceptual quality",
                        minimum=0, maximum=1, value=1)
                restore = gr.Button("Restore and compress")
            with gr.Column(scale=3):
                decompressed_image = gr.Gallery(label="Restored faces gallery", type="numpy", show_label=True,
                                                format="png")
                compressed_file_out = gr.File(label="Compressed bit-stream (output)", scale=0, file_count='multiple')
        restore.click(validate_K, inputs=[K]).then(DDCM_blind_face_image_restoration.inference,
                                                  inputs=[input_image, T, K, iqa_metric, iqa_coef, aligned],
                                                  outputs=[decompressed_image, compressed_file_out])
        gr.Examples([
            ["examples/bfr/00000055.png", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/00000085.png", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/00000113.png", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/00000137.png", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/wider/0034.jpg", 1000, 4096, 'NIQE', 1, True],
            ["examples/bfr/webphoto/00042_00.jpg", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/lfw/Ana_Palacio_0001_00.jpg", 1000, 4096, 'TOPIQ', 0.1, True],
            ["examples/bfr/01.png", 1000, 4096, 'NIQE', 0.1, False],
            ["examples/bfr/03.jpg", 1000, 4096, 'TOPIQ', 0.1, False],
        ],
            inputs=[input_image, T, K, iqa_metric, iqa_coef, aligned],
            outputs=[decompressed_image, compressed_file_out],
            fn=DDCM_blind_face_image_restoration.inference,
            cache_examples='lazy')
        _add_decompression_section('blind', "Decompress image")

    with gr.Tab("Compressed Text-to-Image Generation"):
        gr.Markdown(
            "This application demonstrates the capabilities of our new *compressed* classifier-free guidance method, which *does not require the input condition for decompression*."
            " \n"  # newline
            "Each image is generated along with its compressed bit-stream representation, and the input condition is implicitly encoded in the bit-stream.")
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    text_input = gr.Textbox(label="Input text prompt", scale=1, value="An image of a dog")
                    with gr.Row():
                        T = gr.Number(label="Diffusion timesteps (T)", minimum=50, maximum=1000, value=1000,
                                      scale=1, precision=0)
                        K = gr.Number(label="Size of each codebook (K)", minimum=2, maximum=256, value=128,
                                      scale=1, precision=0)
                        # U+0303 is the combining tilde, rendering as "K̃".
                        K_tilde = gr.Number(label="Sub-sampled codebooks' sizes (K\u0303)", scale=1,
                                            info="Behaves like a guidance scale", minimum=2, maximum=256, value=32,
                                            precision=0)
                    model_type = gr.Radio(["768x768", "512x512"], label="Image size", value="512x512")
                button = gr.Button("Generate and compress")
            with gr.Column(scale=3):
                decompressed_image = gr.Image(label="Generated image", scale=2)
                compressed_file_out = gr.File(label="Compressed bit-stream (output)", scale=0)
        button.click(validate_K, inputs=[K]).then(ccfg, inputs=[text_input, T, K, K_tilde, model_type],
                                                 outputs=[decompressed_image, compressed_file_out])
        gr.Examples([
            ["An image of a dog", 1000, 64, 4, '512x512'],
            ["Rainbow over the mountains", 1000, 64, 4, '512x512'],
            ["A cat playing soccer", 1000, 64, 4, '512x512'],
        ],
            inputs=[text_input, T, K, K_tilde, model_type],
            outputs=[decompressed_image, compressed_file_out],
            fn=ccfg,
            cache_examples='lazy')
        _add_decompression_section('ccfg', "Decompress")

    gr.Markdown(article)
# Enable request queuing, then start the server (Blocks.queue returns the Blocks itself).
demo.queue().launch(state_session_capacity=500)