1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
import torch
from safetensors.torch import load_file as load_safetensor
from diffusers import AutoencoderKL


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load(tokenizer_path="tokenizer", text_encoder_path="text_encoder"):
    """Load the CLIP text encoder and tokenizer.

    Returns:
        Tuple of (clip_model, clip_tokenizer).
    """
    safetensor_fp16 = f"./{text_encoder_path}/model.fp16.safetensors"  # or use model.safetensors
    config_path = f"./{text_encoder_path}/config.json"
    
    # Load tokenizer
    clip_tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
    
    # Load CLIPTextModelWithProjection from the config file and safetensor
    clip_model = CLIPTextModelWithProjection.from_pretrained(
        text_encoder_path, 
        config=config_path, 
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )
    
    # Overwrite the weights loaded by from_pretrained with the fp16 safetensor checkpoint
    state_dict = load_safetensor(safetensor_fp16)
    clip_model.load_state_dict(state_dict)
    clip_model = clip_model.to(device)
    
    return clip_model, clip_tokenizer

def load_vae(vae_path='vae'):
    """Load the AutoencoderKL VAE from a local directory and move it to the device."""
    return AutoencoderKL.from_pretrained(vae_path).to(device)
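
# Hedged sketch: decoding latents with the VAE loaded above. decode_latents is
# not part of the original file; its latent scaling assumes the standard
# Stable Diffusion convention (vae.config.scaling_factor, ~0.18215), so check
# your model's config before relying on it.
def decode_latents(latents, vae):
    """Decode a batch of latents into image tensors in [0, 1]."""
    latents = latents / vae.config.scaling_factor  # undo SD latent scaling
    with torch.no_grad():
        images = vae.decode(latents).sample        # (B, 3, H, W), roughly [-1, 1]
    return (images / 2 + 0.5).clamp(0, 1)          # rescale to [0, 1] for viewing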

# Example function for encoding prompts into CLIP hidden states
def encode_prompt(prompt, tokenizer, clip_model):
    """Tokenize a prompt and return the encoder's last hidden state."""
    # Move token ids to the same device as the model; otherwise this
    # fails with a device mismatch when the model is on CUDA.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    return clip_model(**inputs).last_hidden_state
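
# Hedged usage example: assumes local ./tokenizer, ./text_encoder (containing
# model.fp16.safetensors and config.json), and ./vae directories exist, as the
# defaults above imply.
if __name__ == "__main__":
    clip_model, clip_tokenizer = load()
    hidden = encode_prompt("a photo of an astronaut riding a horse",
                           clip_tokenizer, clip_model)
    print(hidden.shape)  # (1, seq_len, hidden_size)

    vae = load_vae()
    latents = torch.randn(1, 4, 64, 64, device=device, dtype=vae.dtype)
    images = decode_latents(latents, vae)
    print(images.shape)  # (1, 3, 512, 512) for 64x64 latents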