MrRokot committed (verified)
Commit 89c47c3 · 1 parent: bb7bc30

Upload 2 files

Files changed (2):
  1. app.py +176 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,176 @@
from PIL import Image

import torch

from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
)

import gradio as gr
import spaces  # ZERO GPU


MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
MODEL_NAME = MODEL_NAMES[0]

model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
)
model.to("cuda" if torch.cuda.is_available() else "cpu")
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)


def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
    return (
        [f"1{noun}"]
        + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
        + [f"{maximum+1}+{noun}s"]
    )


PEOPLE_TAGS = (
    _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
)
RATING_MAP = {
    "general": "safe",
    "sensitive": "sensitive",
    "questionable": "nsfw",
    "explicit": "explicit, nsfw",
}

DESCRIPTION_MD = """
# WD Tagger with 🤗 transformers
Currently supports the following model(s):
- [p1atdev/wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf)
""".strip()


def postprocess_results(
    results: dict[str, float], general_threshold: float, character_threshold: float
):
    results = {
        k: v for k, v in sorted(results.items(), key=lambda item: item[1], reverse=True)
    }

    rating = {}
    character = {}
    general = {}

    for k, v in results.items():
        if k.startswith("rating:"):
            rating[k.replace("rating:", "")] = v
            continue
        elif k.startswith("character:"):
            character[k.replace("character:", "")] = v
            continue

        general[k] = v

    character = {k: v for k, v in character.items() if v >= character_threshold}
    general = {k: v for k, v in general.items() if v >= general_threshold}

    return rating, character, general


def animagine_prompt(rating: list[str], character: list[str], general: list[str]):
    people_tags: list[str] = []
    other_tags: list[str] = []
    rating_tag = RATING_MAP[rating[0]]

    for tag in general:
        if tag in PEOPLE_TAGS:
            people_tags.append(tag)
        else:
            other_tags.append(tag)

    all_tags = people_tags + character + other_tags + [rating_tag]

    return ", ".join(all_tags)


@spaces.GPU(enable_queue=True)
def predict_tags(
    image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8
):
    inputs = processor.preprocess(image, return_tensors="pt")

    outputs = model(**inputs.to(model.device, model.dtype))
    logits = torch.sigmoid(outputs.logits[0])  # take the first logits

    # get probabilities
    results = {
        model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
    }

    # rating, character, general
    rating, character, general = postprocess_results(
        results, general_threshold, character_threshold
    )

    prompt = animagine_prompt(
        list(rating.keys()), list(character.keys()), list(general.keys())
    )

    return rating, character, general, prompt


def demo():
    with gr.Blocks() as ui:
        gr.Markdown(DESCRIPTION_MD)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input image", type="pil")

                with gr.Group():
                    general_threshold = gr.Slider(
                        label="Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.3,
                        step=0.01,
                        interactive=True,
                    )
                    character_threshold = gr.Slider(
                        label="Character threshold",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.8,
                        step=0.01,
                        interactive=True,
                    )

                    _model_radio = gr.Dropdown(
                        choices=MODEL_NAMES,
                        label="Model",
                        value=MODEL_NAMES[0],
                        interactive=True,
                    )

                start_btn = gr.Button(value="Start", variant="primary")

            with gr.Column():
                prompt_text = gr.Text(label="Prompt")

                rating_tags_label = gr.Label(label="Rating tags")
                character_tags_label = gr.Label(label="Character tags")
                general_tags_label = gr.Label(label="General tags")

        start_btn.click(
            predict_tags,
            inputs=[input_image, general_threshold, character_threshold],
            outputs=[
                rating_tags_label,
                character_tags_label,
                general_tags_label,
                prompt_text,
            ],
        )

    return ui


if __name__ == "__main__":
    demo().queue().launch()
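
For a quick local smoke test of the tagger pipeline, predict_tags can be called directly instead of going through the Gradio UI. A minimal sketch, under two assumptions: sample.png is a hypothetical image path, and the spaces.GPU decorator passes through unchanged outside a ZeroGPU environment (importing the module also triggers the model download on first run):

from PIL import Image

from app import predict_tags  # imports the module added above

# sample.png is a placeholder path; any RGB image works
image = Image.open("sample.png").convert("RGB")
rating, character, general, prompt = predict_tags(image)

print("prompt:", prompt)
print("top rating:", max(rating, key=rating.get))  # rating dict is sorted by confidence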
requirements.txt ADDED
@@ -0,0 +1,5 @@
torch
torchvision
accelerate
transformers==4.38.2
spaces
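
To mirror the Space environment locally, a standard install should suffice (a sketch; the Space installs these automatically, and a CUDA-enabled torch build is only needed for GPU inference):

pip install -r requirements.txt
python app.py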