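"""Gradio demo for ConceptAttention.

Generates images with a FLUX pipeline (flux-schnell) and visualizes
per-concept attention heatmaps, or explains an uploaded real image.

Paper: https://arxiv.org/abs/2502.04320
Code:  https://github.com/helblazer811/ConceptAttention
"""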
import spaces
import gradio as gr
from PIL import Image
import io
import base64
import subprocess
import os

from concept_attention import ConceptAttentionFluxPipeline
IMG_SIZE = 210

def update_default_concepts(prompt):
    """Return the default concept list for a known example prompt (empty otherwise)."""
    default_concepts = {
        "A dog by a tree": ["dog", "grass", "tree", "background"],
        "A man on the beach": ["man", "dirt", "ocean", "sky"],
        "A hot air balloon": ["balloon", "sky", "water", "tree"]
    }
    return gr.update(value=default_concepts.get(prompt, []))
# Load the ConceptAttention FLUX pipeline once at import time.
# (Optional kwargs such as offload_model=True or device="cuda:0" were left
# commented out in the original source.)
pipeline = ConceptAttentionFluxPipeline(model_name="flux-schnell")
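# Standalone usage sketch (mirrors the handler calls below; the argument
# values are illustrative only, not part of the app):
#
#   out = pipeline.generate_image(prompt="A dog by a tree",
#                                 concepts=["dog", "tree"],
#                                 width=1024, height=1024, seed=42)
#   out.image                 # generated PIL image
#   out.concept_heatmaps      # one heatmap per concept
#   out.cross_attention_maps  # one cross-attention map per concept
#
# NOTE: `import spaces` suggests this runs on a ZeroGPU Space, where GPU-bound
# entry points are usually wrapped with the @spaces.GPU decorator; none is
# applied in this file, which may be intentional if the pipeline manages
# devices itself.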
def convert_pil_to_bytes(img):
    """Resize an image for display and return it as a base64-encoded PNG string.

    (Not referenced by the handlers below; kept as a utility.)
    """
    img = img.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST)
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return img_str
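# Example (hypothetical) use of the base64 string as an inline data URI,
# e.g. inside a gr.HTML component:
#
#   html = f"<img src='data:image/png;base64,{convert_pil_to_bytes(img)}'/>"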
def encode_image(image, prompt, concepts, seed, layer_start_index, noise_timestep, num_samples):
    """Explain a real image: encode it with the pipeline and return the image plus
    labeled ConceptAttention and cross-attention heatmaps for the galleries."""
    try:
        if image is None:
            raise gr.Error("Please provide an image", duration=10)
        if not prompt:
            prompt = ""
        prompt = prompt.strip()
        if len(concepts) == 0:
            raise gr.Error("Please enter at least 1 concept", duration=10)
        if len(concepts) > 9:
            raise gr.Error("Please enter at most 9 concepts", duration=10)
        image = image.convert("RGB")
        pipeline_output = pipeline.encode_image(
            image=image,
            prompt=prompt,
            concepts=concepts,
            width=1024,
            height=1024,
            seed=seed,
            num_samples=num_samples,
            noise_timestep=noise_timestep,
            num_steps=4,
            layer_indices=list(range(layer_start_index, 19)),  # blocks layer_start_index..18
            softmax=len(concepts) > 1
        )
        output_image = pipeline_output.image
        # Resize the ConceptAttention heatmaps for display and pair each with its label.
        output_space_heatmaps = pipeline_output.concept_heatmaps
        output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps]
        output_space_maps_and_labels = [(output_space_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
        # Do the same for the cross-attention maps.
        cross_attention_heatmaps = pipeline_output.cross_attention_maps
        cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
        cross_attention_maps_and_labels = []
        prompt_tokens = prompt.split()
        for concept_index in range(len(concepts)):
            concept = concepts[concept_index]
            if concept in prompt_tokens:
                cross_attention_maps_and_labels.append(
                    (cross_attention_heatmaps[concept_index], concept)
                )
            else:
                # A concept outside the prompt has no genuine cross-attention map
                # (it exists only as a byproduct of ConceptAttention's concept
                # tokens), so show a blank placeholder instead.
                empty_image = Image.new("RGB", (IMG_SIZE, IMG_SIZE), (39, 39, 42))
                cross_attention_maps_and_labels.append(
                    (empty_image, concept)
                )
        return output_image, \
            gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
            gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
    except gr.Error:
        # This handler drives three outputs, so the error path returns three values.
        return None, gr.update(value=[], columns=1), gr.update(value=[], columns=1)
def generate_image(prompt, concepts, seed, layer_start_index, timestep_start_index):
    """Generate an image from the prompt and return it plus labeled
    ConceptAttention and cross-attention heatmaps for the galleries."""
    try:
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a prompt", duration=10)
        prompt = prompt.strip()
        if len(concepts) == 0:
            raise gr.Error("Please enter at least 1 concept", duration=10)
        if len(concepts) > 9:
            raise gr.Error("Please enter at most 9 concepts", duration=10)
        pipeline_output = pipeline.generate_image(
            prompt=prompt,
            concepts=concepts,
            width=1024,
            height=1024,
            seed=seed,
            timesteps=list(range(timestep_start_index, 4)),  # collect attention from these steps
            num_inference_steps=4,
            layer_indices=list(range(layer_start_index, 19)),  # blocks layer_start_index..18
            softmax=len(concepts) > 1
        )
        output_image = pipeline_output.image
        # Resize the ConceptAttention heatmaps for display and pair each with its label.
        output_space_heatmaps = pipeline_output.concept_heatmaps
        output_space_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in output_space_heatmaps]
        output_space_maps_and_labels = [(output_space_heatmaps[concept_index], concepts[concept_index]) for concept_index in range(len(concepts))]
        # Do the same for the cross-attention maps.
        cross_attention_heatmaps = pipeline_output.cross_attention_maps
        cross_attention_heatmaps = [heatmap.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for heatmap in cross_attention_heatmaps]
        cross_attention_maps_and_labels = []
        prompt_tokens = prompt.split()
        for concept_index in range(len(concepts)):
            concept = concepts[concept_index]
            if concept in prompt_tokens:
                cross_attention_maps_and_labels.append(
                    (cross_attention_heatmaps[concept_index], concept)
                )
            else:
                # A concept outside the prompt has no genuine cross-attention map
                # (it exists only as a byproduct of ConceptAttention's concept
                # tokens), so show a blank placeholder instead.
                empty_image = Image.new("RGB", (IMG_SIZE, IMG_SIZE), (39, 39, 42))
                cross_attention_maps_and_labels.append(
                    (empty_image, concept)
                )
        return output_image, \
            gr.update(value=output_space_maps_and_labels, columns=len(output_space_maps_and_labels)), \
            gr.update(value=cross_attention_maps_and_labels, columns=len(cross_attention_maps_and_labels))
    except gr.Error:
        return None, gr.update(value=[], columns=1), gr.update(value=[], columns=1)
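# Both handlers above repeat the same resize-and-label post-processing. A
# shared helper like the sketch below could remove that duplication (the name
# is a suggestion; it is not wired into the handlers):
def resize_and_label_heatmaps(heatmaps, concepts):
    """Resize each heatmap for display and pair it with its concept label."""
    resized = [h.resize((IMG_SIZE, IMG_SIZE), resample=Image.NEAREST) for h in heatmaps]
    return list(zip(resized, concepts))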
with gr.Blocks(
    css="""
        .container {
            max-width: 1300px;
            margin: 0 auto;
            padding: 20px;
        }
        .application {
            max-width: 1200px;
        }
        .generated-image {
            display: flex;
            align-items: center;
            justify-content: center;
            height: 100%; /* Ensures full height */
        }
        .input {
            height: 47px;
        }
        .input-column-label {}
        .gallery {
            height: 220px;
        }
        .run-button-column {
            width: 100px !important;
        }
        .gallery-container {
            scrollbar-width: thin;
            scrollbar-color: grey black;
        }
        @media (min-width: 1280px) {
            .svg-container {
                min-width: 250px;
                display: flex;
                flex-direction: column;
                padding-top: 340px;
            }
            .callout {
                width: 250px;
            }
            .input-row {
                height: 100px;
            }
            .input-column {
                flex-direction: column;
                gap: 0px;
                height: 100%;
            }
        }
        @media (max-width: 1280px) {
            .svg-container {
                display: none !important;
            }
            .callout {
                display: none;
            }
        }
        .header {
            display: flex;
            flex-direction: column;
        }
        #title {
            font-size: 4.4em;
            color: #F3B13E;
            text-align: center;
            margin: 5px;
        }
        #subtitle {
            font-size: 3.0em;
            color: #FAE2BA;
            text-align: center;
            margin: 5px;
        }
        #abstract {
            text-align: center;
            font-size: 2.0em;
            color: rgb(219, 219, 219);
            margin: 5px;
            margin-top: 10px;
        }
        #links {
            text-align: center;
            font-size: 2.0em;
            margin: 5px;
        }
        #links a {
            color: #93B7E9;
            text-decoration: none;
        }
        .caption-label {
            font-size: 1.15em;
        }
        .gallery label {
            font-size: 1.15em;
        }
    """
) as demo:
    with gr.Row(elem_classes="container", scale=8):
        with gr.Column(elem_classes="application-content", scale=10):
            with gr.Row(scale=3, elem_classes="header"):
                gr.HTML("""
                    <h1 id='title'> ConceptAttention </h1>
                    <h1 id='subtitle'> Visualize Any Concepts in Your Generated Images </h1>
                    <h1 id='abstract'> Interpret diffusion models with precise, high-quality heatmaps. </h1>
                    <h1 id='links'> <a href='https://arxiv.org/abs/2502.04320'> Paper </a> | <a href='https://github.com/helblazer811/ConceptAttention'> Code </a> </h1>
                """)
            with gr.Tab(label="Generate Image"):
                with gr.Row(elem_classes="input-row", scale=2):
                    with gr.Column(scale=4, elem_classes="input-column", min_width=250):
                        gr.HTML(
                            "Write a Prompt",
                            elem_classes="input-column-label"
                        )
                        prompt = gr.Dropdown(
                            ["A dog by a tree", "A man on the beach", "A hot air balloon"],
                            container=False,
                            allow_custom_value=True,
                            elem_classes="input"
                        )
                    with gr.Column(scale=7, elem_classes="input-column"):
                        gr.HTML(
                            "Select or Write Concepts",
                            elem_classes="input-column-label"
                        )
                        concepts = gr.Dropdown(
                            ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
                            value=["dog", "grass", "tree", "background"],
                            multiselect=True,
                            label="Concepts",
                            container=False,
                            allow_custom_value=True,
                            elem_classes="input",
                            max_choices=5
                        )
                    with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
                        gr.HTML(
                            "&#8203;",  # Zero-width space keeps the label row aligned.
                            elem_classes="input-column-label"
                        )
                        submit_btn = gr.Button(
                            "Run",
                            elem_classes="input"
                        )
                with gr.Row(elem_classes="gallery-container", scale=8):
                    with gr.Column(scale=1, min_width=250):
                        generated_image = gr.Image(
                            elem_classes="generated-image",
                            show_label=False
                        )
                    with gr.Column(scale=4):
                        concept_attention_gallery = gr.Gallery(
                            label="Concept Attention (Ours)",
                            show_label=True,
                            rows=1,
                            object_fit="contain",
                            height="200px",
                            elem_classes="gallery",
                            elem_id="concept-attention-gallery"
                        )
                        cross_attention_gallery = gr.Gallery(
                            label="Cross Attention",
                            show_label=True,
                            rows=1,
                            object_fit="contain",
                            height="200px",
                            elem_classes="gallery"
                        )
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                    layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
                    timestep_start_index = gr.Slider(minimum=0, maximum=4, step=1, label="Timestep Start Index", value=2)
                submit_btn.click(
                    fn=generate_image,
                    inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
                    outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
                )
                prompt.change(update_default_concepts, inputs=[prompt], outputs=[concepts])
                # Automatically process the first example on launch.
                demo.load(
                    generate_image,
                    inputs=[prompt, concepts, seed, layer_start_index, timestep_start_index],
                    outputs=[generated_image, concept_attention_gallery, cross_attention_gallery]
                )
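            # NOTE: the next tab reuses the names `prompt`, `concepts`, and
            # `submit_btn`; the .click() above is already bound to this tab's
            # components, so rebinding them below is safe.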
            with gr.Tab(label="Explain a Real Image"):
                with gr.Row(elem_classes="input-row", scale=2):
                    with gr.Column(scale=4, elem_classes="input-column", min_width=250):
                        gr.HTML(
                            "Write a Prompt (Optional)",
                            elem_classes="input-column-label"
                        )
                        prompt = gr.Textbox(
                            placeholder="Write a prompt (Optional)",
                            container=False,
                            elem_classes="input"
                        )
                    with gr.Column(scale=7, elem_classes="input-column"):
                        gr.HTML(
                            "Select or Write Concepts",
                            elem_classes="input-column-label"
                        )
                        concepts = gr.Dropdown(
                            ["dog", "grass", "tree", "dragon", "sky", "rock", "cloud", "balloon", "water", "background"],
                            value=["dog", "grass", "tree", "background"],
                            multiselect=True,
                            label="Concepts",
                            container=False,
                            allow_custom_value=True,
                            elem_classes="input",
                            max_choices=5
                        )
                    with gr.Column(scale=1, min_width=100, elem_classes="input-column run-button-column"):
                        gr.HTML(
                            "&#8203;",  # Zero-width space keeps the label row aligned.
                            elem_classes="input-column-label"
                        )
                        submit_btn = gr.Button(
                            "Run",
                            elem_classes="input"
                        )
                with gr.Row(elem_classes="gallery-container", scale=8, equal_height=True):
                    with gr.Column(scale=1, min_width=250):
                        input_image = gr.Image(
                            elem_classes="generated-image",
                            show_label=False,
                            interactive=True,
                            type="pil",
                            image_mode="RGB",
                            scale=1
                        )
                    with gr.Column(scale=2):
                        concept_attention_gallery = gr.Gallery(
                            label="Concept Attention (Ours)",
                            show_label=True,
                            rows=1,
                            object_fit="contain",
                            height="200px",
                            elem_classes="gallery",
                            elem_id="concept-attention-gallery"
                        )
                        cross_attention_gallery = gr.Gallery(
                            label="Cross Attention",
                            show_label=True,
                            rows=1,
                            object_fit="contain",
                            height="200px",
                            elem_classes="gallery"
                        )
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                    num_samples = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Samples", value=4)
                    layer_start_index = gr.Slider(minimum=0, maximum=18, step=1, label="Layer Start Index", value=10)
                    noise_timestep = gr.Slider(minimum=0, maximum=4, step=1, label="Noise Timestep", value=2)
                submit_btn.click(
                    fn=encode_image,
                    inputs=[input_image, prompt, concepts, seed, layer_start_index, noise_timestep, num_samples],
                    outputs=[input_image, concept_attention_gallery, cross_attention_gallery]
                )
                # Unlike the first tab, this tab is not auto-processed on launch.
        with gr.Column(scale=2, min_width=200, elem_classes="svg-column"):
            with gr.Row(scale=8):
                gr.HTML("<div></div>")  # Spacer.
            with gr.Row(scale=4, elem_classes="svg-container"):
                concept_attention_callout_svg = gr.HTML(
                    "<img src='/gradio_api/file=ConceptAttentionCallout.svg' class='callout'/>"
                )
                cross_attention_callout_svg = gr.HTML(
                    "<img src='/gradio_api/file=CrossAttentionCallout.svg' class='callout'/>"
                )
            with gr.Row(scale=4):
                gr.HTML("<div></div>")  # Spacer.
if __name__ == "__main__":
    # Clear any stale ZeroGPU offload cache left over from previous runs.
    if os.path.exists("/data-nvme/zerogpu-offload"):
        subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
    # allowed_paths lets Gradio serve the callout SVGs from the repo root.
    demo.launch(
        allowed_paths=["."]
    )