lzyhha committed
Commit 40fb840 · 1 Parent(s): 6dd0ec6
Files changed (2)
  1. requirements.txt +5 -7
  2. visualcloze.py +17 -16
requirements.txt CHANGED
@@ -1,19 +1,17 @@
---extra-index-url https://download.pytorch.org/whl/cu124
 torch==2.1.0
+torchdiffeq==0.2.5
 torchvision==0.16.0
-numpy<2
+numpy==1.26.3
 diffusers==0.32.1
-accelerate==1.1.1
-transformers==4.46.2
+accelerate==1.2.1
+transformers==4.47.1
 huggingface-hub==0.25.0
 tensorboard
 gradio
-torchdiffeq
 click
-torchvision
 opencv-python
 scikit-image
-numba
+numba==0.60.0
 scipy
 tqdm
 einops
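
The updated file pins the previously unpinned packages (torchdiffeq, numpy, numba) to exact versions, bumps accelerate and transformers, and drops the CUDA wheel index along with a duplicate torchvision entry. Below is a minimal sketch of a startup check that the installed versions actually match these pins; the pin list is copied from the file above, but the check itself is illustrative and not part of the repo:

from importlib.metadata import version

# Exact pins taken from requirements.txt after this commit.
PINS = {
    "torch": "2.1.0",
    "torchdiffeq": "0.2.5",
    "torchvision": "0.16.0",
    "numpy": "1.26.3",
    "diffusers": "0.32.1",
    "accelerate": "1.2.1",
    "transformers": "4.47.1",
    "huggingface-hub": "0.25.0",
    "numba": "0.60.0",
}

for pkg, want in PINS.items():
    got = version(pkg)  # raises PackageNotFoundError if pkg is absent
    assert got == want, f"{pkg}: expected {want}, found {got}"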
visualcloze.py CHANGED
@@ -12,6 +12,7 @@ from flux.util import load_clip, load_t5, load_flow_model
 from transport import Sampler, create_transport
 from imgproc import to_rgb_if_rgba
 
+
 def center_crop(image, target_size):
     width, height = image.size
     new_width, new_height = target_size
@@ -90,26 +91,26 @@ class VisualClozeModel:
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[self.precision]
 
-        # # Initialize model
-        # print("Initializing model...")
-        # self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
+        # Initialize model
+        print("Initializing model...")
+        self.model = load_flow_model(model_name, device=self.device, lora_rank=self.lora_rank)
 
-        # # Initialize VAE
-        # print("Initializing VAE...")
-        # self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
-        # self.ae.requires_grad_(False)
+        # Initialize VAE
+        print("Initializing VAE...")
+        self.ae = AutoencoderKL.from_pretrained(f"black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=self.dtype).to(self.device)
+        self.ae.requires_grad_(False)
 
-        # # Initialize text encoders
-        # print("Initializing text encoders...")
-        # self.t5 = load_t5(self.device, max_length=self.max_length)
-        # self.clip = load_clip(self.device)
+        # Initialize text encoders
+        print("Initializing text encoders...")
+        self.t5 = load_t5(self.device, max_length=self.max_length)
+        self.clip = load_clip(self.device)
 
-        # self.model.eval().to(self.device, dtype=self.dtype)
+        self.model.eval().to(self.device, dtype=self.dtype)
 
-        # # Load model weights
-        # ckpt = torch.load(model_path)
-        # self.model.load_state_dict(ckpt, strict=False)
-        # del ckpt
+        # Load model weights
+        ckpt = torch.load(model_path)
+        self.model.load_state_dict(ckpt, strict=False)
+        del ckpt
 
         # Initialize sampler
         transport = create_transport(
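
The visualcloze.py change re-enables the full initialization path that was previously commented out: the flow model via load_flow_model, the FLUX.1-dev VAE, the T5 and CLIP text encoders, and finally the checkpoint load. One line worth flagging is the bare ckpt = torch.load(model_path): without a map_location, tensors are restored to the device they were saved from, which can fail or double GPU memory use on a smaller machine. A hedged alternative sketch (an option, not what the commit ships) that routes the checkpoint through CPU and restricts unpickling to tensor data:

import torch
from torch import nn

def load_weights(model: nn.Module, model_path: str) -> None:
    # map_location="cpu" keeps the checkpoint off the GPU until
    # load_state_dict copies it into the already-placed model;
    # weights_only=True (PyTorch >= 1.13) refuses to unpickle arbitrary
    # Python objects, assuming the file is a plain tensor state dict.
    ckpt = torch.load(model_path, map_location="cpu", weights_only=True)
    # strict=False mirrors the commit, presumably so checkpoints covering
    # only a subset of keys (e.g. LoRA weights) still load.
    model.load_state_dict(ckpt, strict=False)
    del ckpt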