kwilliamson committed on
Commit
e65d65f
·
1 Parent(s): f744cd1

Initial commit

Browse files
Files changed (12) hide show
  1. .gitignore +10 -0
  2. .idea/.gitignore +8 -0
  3. .python-version +1 -0
  4. README.md +0 -14
  5. app.py +9 -0
  6. artifacts/img.png +0 -0
  7. models.py +80 -0
  8. pyproject.toml +7 -0
  9. requirements.txt +221 -0
  10. space.yaml +2 -0
  11. ui.py +57 -0
  12. uv.lock +7 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
README.md CHANGED
@@ -1,14 +0,0 @@
1
- ---
2
- title: ColorMasking
3
- emoji: 🌖
4
- colorFrom: purple
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 5.14.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: A very simple demonstration of color masking & diffusion
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from models import LineartGenerator
2
+ from ui import SketchToImageApp
3
+
4
+
5
def main() -> None:
    """Build the lineart generator, wrap it in the Gradio UI, and serve it."""
    # Constructing the generator loads all model weights up front.
    generator = LineartGenerator()
    interface = SketchToImageApp(generator).create_interface()
    interface.launch()


if __name__ == "__main__":
    main()
9
+
artifacts/img.png ADDED
models.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ from controlnet_aux import LineartDetector
5
+ from diffusers import (
6
+ ControlNetModel,
7
+ StableDiffusionControlNetPipeline,
8
+ UniPCMultistepScheduler
9
+ )
10
+
11
+
12
class LineartGenerator:
    """Generate images from annotated lineart with a ControlNet pipeline.

    The class bundles three steps: loading an uploaded lineart file,
    merging it with the user's brush strokes, and running the merged
    picture through a lineart detector followed by a Stable Diffusion
    ControlNet pipeline.
    """

    def __init__(self, device: str | None = None, seed: int = 0):
        """Select the device/dtype and load every model.

        Args:
            device: Torch device string. When falsy, "cuda" is used if
                available, otherwise "cpu".
            seed: Seed used for every call to ``generate_image``.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision only makes sense on GPU; also match "cuda:0" etc.
        on_gpu = str(self.device).startswith("cuda")
        self.dtype = torch.float16 if on_gpu else torch.float32
        self.seed = seed
        self._initialize_models()

    def _initialize_models(self):
        """Load the lineart detector, the ControlNet and the diffusion pipeline."""
        self.lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
        checkpoint = "ControlNet-1-1-preview/control_v11p_sd15_lineart"
        # Pass the dtype chosen in __init__; without it the weights are
        # loaded in float32 even on GPU and ``self.dtype`` is dead code.
        self.controlnet = ControlNetModel.from_pretrained(
            checkpoint,
            torch_dtype=self.dtype,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            torch_dtype=self.dtype,
        )
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.to(self.device)

    @staticmethod
    def _open_grayscale(uploaded_file, size: tuple) -> Image.Image:
        """Open *uploaded_file* as a grayscale image resized to *size*.

        Accepts either an object exposing the path as ``.name`` or a plain
        path string.
        """
        path = getattr(uploaded_file, "name", uploaded_file)
        with open(path, "rb") as file_obj:
            # convert() forces the pixel data to be read before the file closes.
            return Image.open(file_obj).convert("L").resize(size)

    @staticmethod
    def load_lineart_image(uploaded_file, size: tuple = (512, 512)) -> np.ndarray | None:
        """Return the uploaded lineart as a grayscale ``uint8`` array.

        Args:
            uploaded_file: Uploaded file (object with ``.name`` or a path).
            size: Target ``(width, height)`` passed to ``Image.resize``.

        Returns:
            The resized grayscale image as an array, or ``None`` when no
            file was supplied.
        """
        if not uploaded_file:
            return None

        image = LineartGenerator._open_grayscale(uploaded_file, size)
        return np.array(image, dtype=np.uint8)

    @staticmethod
    def merge_lineart_and_brush(brush_canvas, uploaded_file,
                                size: tuple = (512, 512)) -> Image.Image | None:
        """Overlay the original lineart on top of the user's brush strokes.

        Pixels of the lineart brighter than 240 become fully transparent;
        all others become opaque black, so the lines stay visible above the
        brush layer.

        Args:
            brush_canvas: Sketchpad value; its ``"composite"`` entry is used
                as the brush layer.
            uploaded_file: Uploaded lineart (object with ``.name`` or a path).
            size: Target ``(width, height)`` of the merged image.

        Returns:
            The merged RGBA image, or ``None`` when either input is missing.
        """
        if brush_canvas is None or uploaded_file is None:
            return None

        # Reload the original lineart and threshold it in one vectorised pass.
        gray = np.asarray(LineartGenerator._open_grayscale(uploaded_file, size))
        is_background = gray > 240
        rgba = np.zeros((*gray.shape, 4), dtype=np.uint8)
        rgba[is_background, :3] = 255   # transparent white: (255, 255, 255, 0)
        rgba[~is_background, 3] = 255   # opaque black:      (0, 0, 0, 255)
        lineart_rgba = Image.fromarray(rgba)

        composite = brush_canvas["composite"]
        if composite is None:
            # Nothing drawn yet: fall back to a plain white canvas.
            brush_layer = Image.new("RGBA", lineart_rgba.size, (255, 255, 255, 255))
        else:
            brush_layer = Image.fromarray(composite).convert("RGBA")
            # alpha_composite requires identical sizes.
            if brush_layer.size != lineart_rgba.size:
                brush_layer = brush_layer.resize(lineart_rgba.size)

        return Image.alpha_composite(brush_layer, lineart_rgba)

    def generate_image(self, annotated_lineart: Image.Image,
                       prompt: str = "",
                       num_inference_steps: int = 30) -> Image.Image:
        """Run the ControlNet pipeline on an annotated lineart image.

        Args:
            annotated_lineart: Merged lineart/brush image.
            prompt: Text prompt forwarded to the pipeline.
            num_inference_steps: Number of denoising steps.

        Returns:
            The first image produced by the pipeline.

        Raises:
            ValueError: If ``annotated_lineart`` is ``None``.
        """
        if annotated_lineart is None:
            raise ValueError("No annotated lineart provided!")

        annotated_lineart = annotated_lineart.resize((512, 512))
        refined_lineart = self.lineart_detector(annotated_lineart).convert("RGBA")
        # A dedicated CPU generator gives reproducible noise without
        # reseeding torch's global RNG on every request.
        generator = torch.Generator(device="cpu").manual_seed(self.seed)

        result = self.pipe(
            prompt=prompt,
            image=refined_lineart,
            num_inference_steps=num_inference_steps,
            generator=generator,
        )
        return result.images[0]
pyproject.toml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "color-mapping"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = []
requirements.txt ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aif360==0.6.1
2
+ aiohttp==3.9.5
3
+ aiosignal==1.3.1
4
+ altair==5.3.0
5
+ annotated-types==0.7.0
6
+ anyio==4.3.0
7
+ appnope==0.1.4
8
+ argon2-cffi==23.1.0
9
+ argon2-cffi-bindings==21.2.0
10
+ arrow==1.3.0
11
+ association-metrics==0.0.1
12
+ asttokens==2.4.1
13
+ async-lru==2.0.4
14
+ attrs==23.2.0
15
+ Babel==2.15.0
16
+ beautifulsoup4==4.12.3
17
+ bleach==6.1.0
18
+ blinker==1.8.2
19
+ boto3==1.34.135
20
+ botocore==1.34.135
21
+ cachetools==5.3.3
22
+ certifi==2024.2.2
23
+ cffi==1.16.0
24
+ charset-normalizer==3.3.2
25
+ click==8.1.7
26
+ comm==0.2.2
27
+ contourpy==1.2.1
28
+ cssutils==2.11.0
29
+ cycler==0.12.1
30
+ dataclasses-json==0.6.6
31
+ dataframe-image==0.2.3
32
+ debugpy==1.8.1
33
+ decorator==5.1.1
34
+ defusedxml==0.7.1
35
+ diskcache==5.6.3
36
+ dnspython==2.6.1
37
+ email_validator==2.1.1
38
+ executing==2.0.1
39
+ faiss-cpu==1.8.0
40
+ fastapi==0.111.0
41
+ fastapi-cli==0.0.4
42
+ fastjsonschema==2.19.1
43
+ filelock==3.14.0
44
+ fonttools==4.52.4
45
+ fqdn==1.5.1
46
+ frozenlist==1.4.1
47
+ fsspec==2024.5.0
48
+ gensim==4.3.2
49
+ gitdb==4.0.11
50
+ GitPython==3.1.43
51
+ greenlet==3.0.3
52
+ h11==0.14.0
53
+ html2image==2.0.4.3
54
+ httpcore==1.0.5
55
+ httptools==0.6.1
56
+ httpx==0.27.0
57
+ huggingface-hub==0.23.2
58
+ idna==3.7
59
+ imbalanced-learn==0.12.3
60
+ imblearn==0.0
61
+ iniconfig==2.0.0
62
+ ipykernel==6.29.4
63
+ ipython==8.24.0
64
+ ipywidgets==8.1.2
65
+ isoduration==20.11.0
66
+ jedi==0.19.1
67
+ Jinja2==3.1.4
68
+ jmespath==1.0.1
69
+ joblib==1.4.2
70
+ json5==0.9.25
71
+ jsonpatch==1.33
72
+ jsonpointer==2.4
73
+ jsonschema==4.22.0
74
+ jsonschema-specifications==2023.12.1
75
+ jupyter==1.0.0
76
+ jupyter-console==6.6.3
77
+ jupyter-events==0.10.0
78
+ jupyter-lsp==2.2.5
79
+ jupyter_client==8.6.1
80
+ jupyter_core==5.7.2
81
+ jupyter_server==2.14.0
82
+ jupyter_server_terminals==0.5.3
83
+ jupyterlab==4.1.8
84
+ jupyterlab_pygments==0.3.0
85
+ jupyterlab_server==2.27.1
86
+ jupyterlab_widgets==3.0.10
87
+ kiwisolver==1.4.5
88
+ langchain==0.2.1
89
+ langchain-community==0.2.1
90
+ langchain-core==0.2.3
91
+ langchain-text-splitters==0.2.0
92
+ langchainhub==0.1.17
93
+ langsmith==0.1.67
94
+ Levenshtein==0.25.1
95
+ llama_cpp_python==0.2.76
96
+ lxml==5.2.2
97
+ markdown-it-py==3.0.0
98
+ MarkupSafe==2.1.5
99
+ marshmallow==3.21.2
100
+ matplotlib==3.9.0
101
+ matplotlib-inline==0.1.7
102
+ mdurl==0.1.2
103
+ mistune==3.0.2
104
+ mpmath==1.3.0
105
+ multidict==6.0.5
106
+ mypy-extensions==1.0.0
107
+ nbclient==0.10.0
108
+ nbconvert==7.16.4
109
+ nbformat==5.10.4
110
+ nest-asyncio==1.6.0
111
+ networkx==3.3
112
+ notebook==7.1.3
113
+ notebook_shim==0.2.4
114
+ numpy==1.26.4
115
+ orjson==3.10.3
116
+ overrides==7.7.0
117
+ packaging==23.2
118
+ pandas==2.2.2
119
+ pandocfilters==1.5.1
120
+ parso==0.8.4
121
+ patsy==0.5.6
122
+ pexpect==4.9.0
123
+ pika==1.3.2
124
+ pillow==10.3.0
125
+ platformdirs==4.2.1
126
+ plotly==5.22.0
127
+ pluggy==1.5.0
128
+ polars==0.20.31
129
+ prometheus_client==0.20.0
130
+ prompt-toolkit==3.0.43
131
+ protobuf==4.25.3
132
+ psutil==5.9.8
133
+ psycopg2==2.9.10
134
+ ptyprocess==0.7.0
135
+ pure-eval==0.2.2
136
+ pyaml==24.4.0
137
+ pyarrow==16.1.0
138
+ pycparser==2.22
139
+ pydantic==2.10.3
140
+ pydantic_core==2.27.1
141
+ pydash==8.0.1
142
+ pydeck==0.9.1
143
+ Pygments==2.18.0
144
+ pykalman==0.9.7
145
+ pyparsing==3.1.2
146
+ pytest==8.2.1
147
+ pytest-mock==3.14.0
148
+ python-dateutil==2.9.0.post0
149
+ python-dotenv==1.0.1
150
+ python-json-logger==2.0.7
151
+ python-Levenshtein==0.25.1
152
+ python-multipart==0.0.9
153
+ pytz==2024.1
154
+ PyWavelets==1.6.0
155
+ PyYAML==6.0.1
156
+ pyzmq==26.0.3
157
+ qtconsole==5.5.2
158
+ QtPy==2.4.1
159
+ rapidfuzz==3.9.0
160
+ redis==5.2.1
161
+ referencing==0.35.1
162
+ regex==2024.5.15
163
+ requests==2.31.0
164
+ rfc3339-validator==0.1.4
165
+ rfc3986-validator==0.1.1
166
+ rich==13.7.1
167
+ rpds-py==0.18.1
168
+ s3transfer==0.10.2
169
+ safetensors==0.4.3
170
+ scikit-learn==1.4.2
171
+ scikit-optimize==0.10.2
172
+ scipy==1.13.0
173
+ seaborn==0.13.2
174
+ Send2Trash==1.8.3
175
+ sentence-transformers==3.0.0
176
+ shellingham==1.5.4
177
+ simpy==4.1.1
178
+ six==1.16.0
179
+ smart-open==7.0.4
180
+ smmap==5.0.1
181
+ sniffio==1.3.1
182
+ soupsieve==2.5
183
+ SQLAlchemy==2.0.36
184
+ stack-data==0.6.3
185
+ starlette==0.37.2
186
+ statsmodels==0.14.2
187
+ streamlit==1.35.0
188
+ sympy==1.12.1
189
+ tenacity==8.3.0
190
+ terminado==0.18.1
191
+ threadpoolctl==3.5.0
192
+ tinycss2==1.3.0
193
+ tokenizers==0.19.1
194
+ toml==0.10.2
195
+ toolz==0.12.1
196
+ torch==2.2.2
197
+ tornado==6.4
198
+ tqdm==4.66.4
199
+ traitlets==5.14.3
200
+ transformers==4.41.2
201
+ typer==0.12.3
202
+ types-python-dateutil==2.9.0.20240316
203
+ types-requests==2.32.0.20240602
204
+ typing-inspect==0.9.0
205
+ typing_extensions==4.12.2
206
+ tzdata==2024.1
207
+ ujson==5.10.0
208
+ uri-template==1.3.0
209
+ urllib3==2.2.1
210
+ uvicorn==0.30.0
211
+ uvloop==0.19.0
212
+ watchfiles==0.22.0
213
+ wcwidth==0.2.13
214
+ webcolors==1.13
215
+ webencodings==0.5.1
216
+ websocket-client==1.8.0
217
+ websockets==12.0
218
+ widgetsnbextension==4.0.10
219
+ wrapt==1.16.0
220
+ yarl==1.9.4
221
+ controlnet_aux
space.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ sdk: gradio
2
+ accelerator: gpu
ui.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from models import LineartGenerator
4
+
5
+
6
class SketchToImageApp:
    """Gradio front end for the lineart-to-image workflow.

    The user uploads a lineart sketch, paints over it in a sketchpad, and
    the merged result is sent to the wrapped ``LineartGenerator``.
    """

    def __init__(self, lineart_generator: LineartGenerator):
        """Store the generator used for loading, merging and inference."""
        self.lineart_generator = lineart_generator

    def generate_image(self, brush_canvas: dict, uploaded_file) -> Image.Image:
        """Merge the sketchpad content with the lineart and generate an image.

        Args:
            brush_canvas: Value of the sketchpad component.
            uploaded_file: Value of the file-upload component.

        Returns:
            The image produced by the generator.

        Raises:
            gr.Error: If there is nothing to generate from (no upload or no
                canvas), so the UI shows an actionable message.
        """
        merged_lineart = self.lineart_generator.merge_lineart_and_brush(brush_canvas, uploaded_file)
        if merged_lineart is None:
            # The merge step returns None when an input is missing; report
            # that to the user instead of failing deeper in the generator.
            raise gr.Error("Please upload a lineart sketch before generating.")
        return self.lineart_generator.generate_image(merged_lineart, num_inference_steps=30)

    def create_interface(self) -> gr.Blocks:
        """Build the Blocks layout and wire up its events.

        Returns:
            The assembled (not yet launched) ``gr.Blocks`` app.
        """
        with gr.Blocks() as app:
            gr.Markdown(
                "# Lineart & Color Mask With Controlnet\n"
                "Brush strokes will be applied behind the processed lineart so that the "
                "black lines always remain visible."
            )

            lineart_file_input = gr.File(
                label="Upload Lineart Sketch",
                file_types=["image"],
                file_count="single",
            )

            with gr.Row():
                brush_canvas_input = gr.Sketchpad(
                    label="Annotate Your Lineart",
                    type="numpy",
                    brush=gr.Brush(),
                    width=512,
                    height=512,
                )

            # Show the uploaded lineart in the sketchpad as soon as it changes.
            lineart_file_input.change(
                fn=self.lineart_generator.load_lineart_image,
                inputs=lineart_file_input,
                outputs=brush_canvas_input,
            )

            generate_button = gr.Button("Generate")
            output_image = gr.Image(label="Generated Image")

            generate_button.click(
                fn=self.generate_image,
                inputs=[brush_canvas_input, lineart_file_input],
                outputs=output_image,
            )

        return app
uv.lock ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ version = 1
2
+ requires-python = ">=3.11"
3
+
4
+ [[package]]
5
+ name = "color-mapping"
6
+ version = "0.1.0"
7
+ source = { virtual = "." }