Spaces: Running on Zero

Commit 40fb840: package
lzyhha committed
Parent(s): 6dd0ec6

Files changed:
- requirements.txt +5 -7
- visualcloze.py +17 -16
requirements.txt CHANGED

```diff
@@ -1,19 +1,17 @@
--extra-index-url https://download.pytorch.org/whl/cu124
 torch==2.1.0
+torchdiffeq==0.2.5
 torchvision==0.16.0
-numpy
+numpy==1.26.3
 diffusers==0.32.1
-accelerate==1.
-transformers==4.
+accelerate==1.2.1
+transformers==4.47.1
 huggingface-hub==0.25.0
 tensorboard
 gradio
-torchdiffeq
 click
-torchvision
 opencv-python
 scikit-image
-numba
+numba==0.60.0
 scipy
 tqdm
 einops
```
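The net effect is to drop the CUDA extra index URL and the duplicate torchvision entry, and to pin the previously unpinned or loosely pinned packages. The snippet below is a minimal sketch, not part of the commit, that asserts the running environment matches the new pins; the `PINS` mapping simply mirrors the pinned lines above.

```python
# Minimal sketch (not part of the commit): check that installed package
# versions match the pins introduced in this requirements.txt.
from importlib.metadata import version

PINS = {
    "torch": "2.1.0",
    "torchdiffeq": "0.2.5",
    "torchvision": "0.16.0",
    "numpy": "1.26.3",
    "diffusers": "0.32.1",
    "accelerate": "1.2.1",
    "transformers": "4.47.1",
    "huggingface-hub": "0.25.0",
    "numba": "0.60.0",
}

for pkg, pinned in PINS.items():
    installed = version(pkg)  # raises PackageNotFoundError if not installed
    assert installed == pinned, f"{pkg}: expected {pinned}, found {installed}"
print("All pinned packages match.")
```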
visualcloze.py CHANGED

```diff
@@ -12,6 +12,7 @@ from flux.util import load_clip, load_t5, load_flow_model
 from transport import Sampler, create_transport
 from imgproc import to_rgb_if_rgba
 
+
 def center_crop(image, target_size):
     width, height = image.size
     new_width, new_height = target_size
@@ -90,26 +91,26 @@ class VisualClozeModel:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
 
-        #
-
-
+        # Initialize model
+        print("Initializing model...")
+        self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
 
-        #
-
-
-
+        # Initialize VAE
+        print("Initializing VAE...")
+        self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
+        self.ae.requires_grad_(False)
 
-        #
-
-
-
+        # Initialize text encoders
+        print("Initializing text encoders...")
+        self.t5 = load_t5(self.device, max_length=self.max_length)
+        self.clip = load_clip(self.device)
 
-
+        self.model.eval().to(self.device, dtype=self.dtype)
 
-        #
-
-
-
+        # Load model weights
+        ckpt = torch.load(model_path)
+        self.model.load_state_dict(ckpt, strict=False)
+        del ckpt
 
         # Initialize sampler
         transport = create_transport(
```
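For context, the added block assembles the full VisualCloze pipeline inside `VisualClozeModel`: the FLUX flow model, the FLUX.1-dev VAE, and the T5/CLIP text encoders, followed by a fine-tuned checkpoint loaded with `strict=False`. Since `strict=False` silently skips mismatched keys, the hypothetical helper below (not part of the commit) shows how the same step can surface what was actually loaded; loading the checkpoint to CPU first also avoids a transient GPU memory spike before the weights are moved to the target device.

```python
# Hypothetical helper (not part of the commit) illustrating a more
# defensive version of the checkpoint-loading step added above.
import torch

def load_weights_verbose(model: torch.nn.Module, model_path: str) -> torch.nn.Module:
    # map_location="cpu" keeps the raw checkpoint out of GPU memory;
    # the model itself is moved to the device elsewhere, as in the commit.
    ckpt = torch.load(model_path, map_location="cpu")
    # With strict=False, load_state_dict returns the skipped keys
    # instead of raising on a mismatch.
    missing, unexpected = model.load_state_dict(ckpt, strict=False)
    if missing:
        print(f"Missing keys (left at initial values): {missing}")
    if unexpected:
        print(f"Unexpected keys (ignored): {unexpected}")
    del ckpt  # free the state-dict copy, as the commit does
    return model
```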