gabehubner committed
Commit d291162 · verified · 1 Parent(s): 23311f1

Upload folder using huggingface_hub

Files changed (3)
  1. README.md +2 -82
  2. config.json +43 -0
  3. diffusion_pytorch_model.safetensors +3 -0
README.md CHANGED
@@ -1,83 +1,3 @@
- ---
- tags:
- - pytorch
- - vae
- - diffusion
- - image-generation
- - cc3m
- license: mit
- datasets:
- - pixparse/cc3m-wds
- library_name: diffusers
- model-index:
- - name: vae-256px-8z
-   results:
-   - task:
-       type: image-generation
-     dataset:
-       type: conceptual-captions
-       name: Conceptual Captions
-     metrics:
-     - type: Frechet Inception Distance (FID)
-       value: 9.43
-     - type: Learned Perceptual Image Patch Similarity (LPIPS)
-       value: 0.163
-     - type: ID-similarity
-       value: 0.0010186772755879851
-     source:
-       name: Conceptual Captions GitHub
-       url: https://github.com/google-research-datasets/conceptual-captions
- ---
-
- # UNet-Style VAE for 256x256 Image Reconstruction
-
- This model is a UNet-style Variational Autoencoder (VAE) trained on the [CC3M](https://huggingface.co/datasets/pixparse/cc3m-wds) dataset for high-quality image reconstruction and generation. It integrates adversarial, perceptual, and identity-preserving loss terms to improve semantic and visual fidelity.
-
- ## Architecture
-
- - **Encoder/Decoder**: Multi-scale UNet architecture
- - **Latent Space**: 8-channel latent bottleneck with reparameterization (mu, logvar)
- - **Losses**:
-   - L1 reconstruction loss
-   - KL divergence with annealing
-   - LPIPS perceptual loss (VGG backbone)
-   - Identity loss via MoCo-v2 embeddings
-   - Adversarial loss via Patch Discriminator w/ Spectral Norm
-
- $$
- \mathcal{L}_{total} = \mathcal{L}_{recon} + \mathcal{L}_{LPIPS} + 0.5 \cdot \mathcal{L}_{GAN} + 0.1 \cdot \mathcal{L}_{ID} + 10^{-6} \cdot \mathcal{L}_{KL}
- $$
-
- ## Reconstructions
-
- | Input | Output |
- |-------|--------|
- | ![input](./input_grid.png) | ![output](./recon_grid.png) |
-
- ## Training Config
-
- | Hyperparameter | Value |
- |-----------------------|----------------------------|
- | Dataset | CC3M (850k images) |
- | Image Resolution | 256 x 256 |
- | Batch Size | 16 |
- | Optimizer | AdamW |
- | Learning Rate | 5e-5 |
- | Precision | bf16 (mixed precision) |
- | Total Steps | 210,000 |
- | GAN Start Step | 50,000 |
- | KL Annealing | Yes (10% of training) |
- | Augmentations | Crop, flip, jitter, blur, rotation |
-
- Trained using a cosine learning rate schedule with gradient clipping and automatic mixed precision (`torch.cuda.amp`).
-
- ## Usage Example
-
- ```python
- import torch
- from diffusers import AutoencoderKL
- vae = AutoencoderKL.from_pretrained("gabehubner/vae-256px-8z")
- vae.eval()
-
- input_tensor = torch.randn(1, 3, 256, 256) # Replace with your actual input
- with torch.no_grad():
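
The diff cuts the removed usage example off at `with torch.no_grad():`. For reference, a minimal sketch of how such a round trip typically continues, assuming the checkpoint loads as a standard diffusers `AutoencoderKL` (the `encode`/`decode` calls below are the diffusers API; the continuation itself is an assumption, not the card's original text):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("gabehubner/vae-256px-8z")
vae.eval()

input_tensor = torch.randn(1, 3, 256, 256)  # replace with a real image batch, typically scaled to [-1, 1]
with torch.no_grad():
    # Encode to the posterior, sample a latent, and decode back to pixel space
    latents = vae.encode(input_tensor).latent_dist.sample()
    reconstruction = vae.decode(latents).sample

print(latents.shape, reconstruction.shape)
```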
 
+ # VAE
+
+ A UNet-style VAE trained on CC3M with adversarial and perceptual losses.
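
The loss weighting described in the removed card reduces to a one-line weighted sum. A minimal sketch; the function name and the placeholder values are illustrative, not from the repo:

```python
import torch

def total_loss(l_recon, l_lpips, l_gan, l_id, l_kl):
    # L_total = L_recon + L_LPIPS + 0.5*L_GAN + 0.1*L_ID + 1e-6*L_KL
    return l_recon + l_lpips + 0.5 * l_gan + 0.1 * l_id + 1e-6 * l_kl

# Dummy scalar terms, purely to exercise the function
terms = [torch.tensor(v) for v in (0.12, 0.20, 0.80, 0.05, 3000.0)]
print(total_loss(*terms))
```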
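
Likewise, the removed training notes (AdamW at 5e-5, a cosine schedule over 210,000 steps, gradient clipping, bf16 autocast) map onto standard PyTorch pieces. A self-contained sketch of one step with a stand-in module; the clipping norm and other details are assumptions, not the repo's actual loop:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Conv2d(3, 3, 3, padding=1)  # stand-in for the UNet-style VAE
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=210_000)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
x = torch.randn(16, 3, 256, 256, device=device)  # batch size 16 at 256x256

optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type=device, dtype=torch.bfloat16):  # bf16 mixed precision
    loss = F.l1_loss(model(x), x)  # stand-in for the combined loss above
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip value assumed
optimizer.step()
scheduler.step()
```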
config.json ADDED
@@ -0,0 +1,43 @@
+ {
+   "_class_name": "VAEWrapper",
+   "_diffusers_version": "0.33.1",
+   "act_fn": "silu",
+   "attention_resolutions": [
+     32
+   ],
+   "block_out_channels": [
+     64
+   ],
+   "channel_multipliers": [
+     1,
+     2,
+     2,
+     4,
+     4
+   ],
+   "double_z": true,
+   "down_block_types": [
+     "DownEncoderBlock2D"
+   ],
+   "force_upcast": true,
+   "hidden_channels": 128,
+   "image_size": 256,
+   "in_channels": 3,
+   "latent_channels": 4,
+   "latents_mean": null,
+   "latents_std": null,
+   "layers_per_block": 1,
+   "mid_block_add_attention": true,
+   "norm_num_groups": 32,
+   "num_res_blocks": 3,
+   "out_channels": 3,
+   "sample_size": 32,
+   "scaling_factor": 0.18215,
+   "shift_factor": null,
+   "up_block_types": [
+     "UpDecoderBlock2D"
+   ],
+   "use_post_quant_conv": true,
+   "use_quant_conv": true,
+   "z_channels": 8
+ }
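
Note that `_class_name` is `VAEWrapper`, a custom class rather than a stock diffusers model, and the config carries both `latent_channels: 4` (a diffusers default) and `z_channels: 8` (matching the card's 8-channel bottleneck), so loading may need the repo's own code. A minimal sketch that only fetches and inspects the config shipped in this commit:

```python
import json
from huggingface_hub import hf_hub_download

path = hf_hub_download(repo_id="gabehubner/vae-256px-8z", filename="config.json")
with open(path) as f:
    cfg = json.load(f)

print(cfg["_class_name"])          # VAEWrapper
print(cfg["z_channels"])           # 8, the latent bottleneck width from the card
print(cfg["channel_multipliers"])  # [1, 2, 2, 4, 4]
```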
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b84b814c16e6e5bcb8e0300200a960c1199e27ccae73ad1215d48695eac74f82
+ size 322169676
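
The weights file is stored as a Git LFS pointer, so the three lines above are all that lives in the git history; the blob itself is addressed by its SHA-256. A short sketch to check a downloaded copy against the pointer's `oid` and `size`:

```python
import hashlib
import os

EXPECTED_SHA256 = "b84b814c16e6e5bcb8e0300200a960c1199e27ccae73ad1215d48695eac74f82"
EXPECTED_SIZE = 322_169_676

path = "diffusion_pytorch_model.safetensors"  # wherever the file was downloaded

assert os.path.getsize(path) == EXPECTED_SIZE, "size mismatch"

h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
assert h.hexdigest() == EXPECTED_SHA256, "hash mismatch"
print("pointer verified")
```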