Spaces: Sleeping
Yihe Tang committed
Commit: 17b2682
Parent(s): 5396c1d

Initial deployment with model and dependencies
Browse files
- app.py +53 -4
- checkpoints/gemini/checkpoint.pth +3 -0
- checkpoints/gemini/config.yaml +10 -0
- requirements.txt +10 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/model.cpython-310.pyc +0 -0
- src/__pycache__/network.cpython-310.pyc +0 -0
- src/model.py +92 -0
- src/network.py +271 -0
- src/utils/__init__.py +0 -0
- src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/utils/__pycache__/argument_utils.cpython-310.pyc +0 -0
- src/utils/__pycache__/img_utils.cpython-310.pyc +0 -0
- src/utils/argument_utils.py +128 -0
- src/utils/img_utils.py +136 -0
- test_img.png +0 -0
app.py CHANGED
@@ -1,7 +1,56 @@
+"""
+Interface for HuggingFace deployment
+"""
+
 import gradio as gr
+import numpy as np
+from src.model import AffordanceModel
+from src.utils.argument_utils import get_yaml_config
+import cv2
+
+print("Loading config...")
+config = get_yaml_config("checkpoints/gemini/config.yaml")
+print("Building model...")
+model = AffordanceModel(config)
+print("Model built successfully!")
+
+def predict(image, text):
+    """
+    Gradio inference function
+    Args:
+        image: PIL Image (Gradio's default image input format)
+        text: str
+    Returns:
+        visualization of the heatmap
+    """
+    # Convert PIL image to numpy array
+    image = np.array(image)
+
+    # Run model inference
+    heatmap = model.inference(image, text)  # Returns (H, W) array
+
+    # Visualize heatmap (convert to RGB for display)
+    # Scale to 0-255 and apply colormap
+    heatmap_vis = (heatmap * 255).astype(np.uint8)
+    heatmap_colored = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
+    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
+
+    return heatmap_colored
 
+# Create Gradio interface
+demo = gr.Interface(
+    fn=predict,
+    inputs=[
+        gr.Image(type="pil", label="Input Image"),  # Accepts uploaded images
+        gr.Textbox(label="Text Query", placeholder="Enter text description...")
+    ],
+    outputs=gr.Image(label="Affordance Heatmap"),
+    title="Affordance Detection",
+    description="Upload an image and provide a text query to detect affordances.",
+    examples=[
+        ["test.png", "rim"]  # Add your test image and query
+    ]
+)
 
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
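As a quick, self-contained sanity check of the visualization step inside predict (a sketch only; np.random stands in for the model's heatmap output):

import numpy as np
import cv2

heatmap = np.random.rand(224, 224)                    # stand-in for model.inference output in [0, 1]
heatmap_vis = (heatmap * 255).astype(np.uint8)        # scale to 0-255
heatmap_colored = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)   # OpenCV colormaps return BGR
heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)   # convert to RGB for Gradio display
print(heatmap_colored.shape)                          # (224, 224, 3)

Note that examples points at test.png, whereas the image added in this commit is test_img.png.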
checkpoints/gemini/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffb327d0b8aa9610430e851c740620c38b09c4ffcabca889d5eacec63627f08a
+size 38003927
checkpoints/gemini/config.yaml ADDED
@@ -0,0 +1,10 @@
+model:
+  in_channels: 1024
+  filters: [256, 64, 1]
+  kernel_sizes: [3, 3, 1]
+  strides: [1, 1, 1]
+  norm: None
+  activation: lrelu
+  lang_emb_dim: 1024
+  film_mode: zero
+checkpoint_path: checkpoints/gemini/checkpoint.pth
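A minimal sketch of how this config is consumed by the code below (src/utils/argument_utils.py wraps the YAML in a ConfigDict, and src/model.py unpacks config.model into Conv2DFiLMNet); assumes it is run from the repository root:

from src.utils.argument_utils import get_yaml_config
from src.network import Conv2DFiLMNet

config = get_yaml_config("checkpoints/gemini/config.yaml")
net = Conv2DFiLMNet(**config.model)     # in_channels=1024, filters=[256, 64, 1], ...
net.build()
print(config.model.norm)                # None ("None" strings become Python None in ConfigDict)
print(config.checkpoint_path)           # checkpoints/gemini/checkpoint.pth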
requirements.txt ADDED
@@ -0,0 +1,10 @@
+torch
+torchvision
+numpy
+opencv-python-headless  # using headless version for server deployment
+Pillow  # for PIL
+openai  # for OpenAI API
+gradio
+pyyaml  # for yaml
+matplotlib
+IPython
src/__init__.py ADDED
File without changes

src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes)

src/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.79 kB)

src/__pycache__/network.cpython-310.pyc ADDED
Binary file (6.93 kB)
src/model.py ADDED
@@ -0,0 +1,92 @@
+"""
+Affordance model definition
+"""
+import torch
+import torch.nn.functional as F
+
+import os
+import sys
+import os.path as osp
+
+import numpy as np
+import cv2
+import matplotlib.pyplot as plt
+
+from .network import Conv2DFiLMNet
+from .utils.img_utils import *
+from .utils.argument_utils import get_yaml_config
+
+use_cuda = True
+device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
+
+def build_network(config):
+    """
+    Build model
+    """
+    net = Conv2DFiLMNet(**config)
+
+    net.build()
+    net = net.to(device)
+
+    return net
+
+class AffordanceModel:
+
+    def __init__(self, config):
+
+        print("============ Building network and loading checkpoint =============")
+        # build network
+        self.model = build_network(config.model)
+
+        # load checkpoint
+        self.load_checkpoint(config.checkpoint_path)
+
+        print("============ Building DINO model =============")
+        torch_path = osp.join(osp.dirname(osp.dirname(__file__)), "data/torch_home")
+        self.dinov2 = load_pretrained_dino(torch_path=torch_path)
+
+    def load_checkpoint(self, checkpoint_path):
+        """
+        Load checkpoint
+        """
+        checkpoint = torch.load(checkpoint_path, map_location=device)
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.model.eval()
+
+    @torch.no_grad()
+    def inference(self, img, text, keep_orig_size=True, temp=1.0):
+        """
+        Inference model output on query image and text.
+        img: np.ndarray, (H, W, C)
+        text: str
+
+        Returns:
+            out: np array, (H, W)
+        """
+        self.model.eval()
+
+        # prepare input for model
+        img = rescale_img(img, max_size=672)  # rescale image in case it is too large
+        processed_img, lang_emb = preprocess_data(img, text, self.model._lang_emb_dim)
+        img = torch.stack(processed_img, dim=0).to(device)
+        img_feat = get_dino_features(self.dinov2, img, repeat_to_orig_size=keep_orig_size).permute(0, 3, 1, 2)
+
+        lang_emb = lang_emb.to(device)
+
+        # forward pass
+        out = self.model(img_feat, lang_emb)  # (1, 1, h, w)
+        out = out.squeeze()
+
+        # post-process output
+        out = torch.sigmoid(out)
+        out = out.detach().cpu().numpy()
+        return out
+
+
+if __name__ == "__main__":
+    config_path = osp.join(osp.dirname(osp.dirname(__file__)), "checkpoints/gemini/config.yaml")
+    config = get_yaml_config(config_path)
+
+    affordance_model = AffordanceModel(config)
+
+    from IPython import embed; embed(); exit(0)
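A small sketch of inspecting the committed checkpoint the same way load_checkpoint consumes it (assuming git-lfs has materialized checkpoints/gemini/checkpoint.pth locally; the printed key names are illustrative):

import torch

ckpt = torch.load("checkpoints/gemini/checkpoint.pth", map_location="cpu")
state = ckpt["model_state_dict"]              # the key load_checkpoint() reads
print(list(state.keys())[:4])                 # e.g. conv_blocks.0.conv2d.weight, conv_blocks.0.film.fc_gamma.weight, ...
print(sum(p.numel() for p in state.values()), "parameters")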
src/network.py ADDED
@@ -0,0 +1,271 @@
+"""
+Defines basic network building blocks and network architecture
+Some code adapted from PerAct: https://github.com/peract/peract
+"""
+
+import torch
+import torch.nn as nn
+
+from typing import List
+
+########################################
+### layers
+########################################
+LRELU_SLOPE = 0.02
+
+def act_layer(act):
+    if act == 'relu':
+        return nn.ReLU()
+    elif act == 'lrelu':
+        return nn.LeakyReLU(LRELU_SLOPE)
+    elif act == 'elu':
+        return nn.ELU()
+    elif act == 'tanh':
+        return nn.Tanh()
+    elif act == 'prelu':
+        return nn.PReLU()
+    else:
+        raise ValueError('%s not recognized.' % act)
+
+def norm_layer2d(norm, channels):
+    if norm == 'batch':
+        return nn.BatchNorm2d(channels)
+    elif norm == 'instance':
+        return nn.InstanceNorm2d(channels, affine=True)
+    elif norm == 'layer':
+        return nn.GroupNorm(1, channels, affine=True)
+    elif norm == 'group':
+        return nn.GroupNorm(4, channels, affine=True)
+    else:
+        raise ValueError('%s not recognized.' % norm)
+
+
+########################################
+### network blocks
+########################################
+
+class FiLMBlockRand(nn.Module):
+    """
+    FiLM block with random init gamma and beta.
+    x = gamma * x + beta
+    Adapted from PerAct (and original FiLM paper)
+    """
+    def __init__(self, lang_emb_dim, num_channels):
+        super(FiLMBlockRand, self).__init__()
+
+        self.fc_gamma = nn.Linear(lang_emb_dim, num_channels)
+        self.fc_beta = nn.Linear(lang_emb_dim, num_channels)
+
+    def forward(self, x, lang_emb):
+        gamma = self.fc_gamma(lang_emb)
+        beta = self.fc_beta(lang_emb)
+
+        beta = beta.view(x.size(0), x.size(1), 1, 1)
+        gamma = gamma.view(x.size(0), x.size(1), 1, 1)
+
+        x = gamma * x + beta
+
+        return x
+
+
+class FiLMBlockZero(nn.Module):
+    """
+    FiLM block with zero init gamma and beta.
+    x = (1 + gamma) * x + beta
+    Adapted from RT-1 https://github.com/google-research/robotics_transformer/blob/master/film_efficientnet/film_conditioning_layer.py
+    """
+    def __init__(self, lang_emb_dim, num_channels):
+        super(FiLMBlockZero, self).__init__()
+
+        self.fc_gamma = nn.Linear(lang_emb_dim, num_channels)
+        self.fc_beta = nn.Linear(lang_emb_dim, num_channels)
+
+        nn.init.zeros_(self.fc_gamma.weight)
+        nn.init.zeros_(self.fc_gamma.bias)
+        nn.init.zeros_(self.fc_beta.weight)
+        nn.init.zeros_(self.fc_beta.bias)
+
+    def forward(self, x, lang_emb):
+        gamma = self.fc_gamma(lang_emb)
+        beta = self.fc_beta(lang_emb)
+
+        beta = beta.view(x.size(0), x.size(1), 1, 1)
+        gamma = gamma.view(x.size(0), x.size(1), 1, 1)
+
+        x = (1 + gamma) * x + beta
+
+        return x
+
+
+class Conv2DBlock(nn.Module):
+
+    def __init__(self, in_channels, out_channels, kernel_sizes, strides,
+                 norm=None, activation=None, padding_mode='replicate'):
+        super(Conv2DBlock, self).__init__()
+        padding = kernel_sizes // 2 if isinstance(kernel_sizes, int) else (
+            kernel_sizes[0] // 2, kernel_sizes[1] // 2)
+        self.conv2d = nn.Conv2d(
+            in_channels, out_channels, kernel_sizes, strides, padding=padding,
+            padding_mode=padding_mode)
+
+        if activation is None:
+            nn.init.xavier_uniform_(self.conv2d.weight,
+                                    gain=nn.init.calculate_gain('linear'))
+            nn.init.zeros_(self.conv2d.bias)
+        elif activation == 'tanh':
+            nn.init.xavier_uniform_(self.conv2d.weight,
+                                    gain=nn.init.calculate_gain('tanh'))
+            nn.init.zeros_(self.conv2d.bias)
+        elif activation == 'lrelu':
+            nn.init.kaiming_uniform_(self.conv2d.weight, a=LRELU_SLOPE,
+                                     nonlinearity='leaky_relu')
+            nn.init.zeros_(self.conv2d.bias)
+        elif activation == 'relu':
+            nn.init.kaiming_uniform_(self.conv2d.weight, nonlinearity='relu')
+            nn.init.zeros_(self.conv2d.bias)
+        else:
+            raise ValueError()
+
+        self.activation = activation
+        self.norm = norm
+        if norm is not None:
+            self.norm = norm_layer2d(norm, out_channels)
+        if activation is not None:
+            self.activation = act_layer(activation)
+
+    def forward(self, x):
+        x = self.conv2d(x)
+        x = self.norm(x) if self.norm is not None else x
+        x = self.activation(x) if self.activation is not None else x
+        return x
+
+class Conv2DFiLMBlock(Conv2DBlock):
+
+    def __init__(self, in_channels, out_channels, kernel_sizes, strides,
+                 lang_emb_dim,
+                 norm=None, activation=None, padding_mode='replicate',
+                 film_mode='rand', film_place='after'
+                 ):
+        super(Conv2DFiLMBlock, self).__init__(
+            in_channels, out_channels, kernel_sizes, strides, norm, activation,
+            padding_mode)
+
+        self.film_place = film_place
+        if film_place == 'after':
+            film_channels = out_channels
+        elif film_place == 'before':
+            film_channels = in_channels
+        else:
+            raise ValueError(f"film_place {film_place} not recognized")
+
+        if film_mode == 'rand':
+            self.film = FiLMBlockRand(lang_emb_dim, film_channels)
+        elif film_mode == 'zero':
+            self.film = FiLMBlockZero(lang_emb_dim, film_channels)
+        else:
+            raise ValueError(f"film_mode {film_mode} not recognized")
+
+    def forward(self, x, lang_emb):
+        if self.film_place == 'before':
+            x = self.film(x, lang_emb)
+            x = self.conv2d(x)
+            x = self.norm(x) if self.norm is not None else x
+            x = self.activation(x) if self.activation is not None else x
+
+        elif self.film_place == 'after':
+            x = self.conv2d(x)  # (B, C, H, W)
+            x = self.norm(x) if self.norm is not None else x
+            x = self.film(x, lang_emb)  # lang_emb: (B, lang_emb_dim), output: (B, C, H, W)
+            x = self.activation(x) if self.activation is not None else x
+
+        else:
+            raise ValueError(f"film_place {self.film_place} not recognized")
+
+        return x
+
+
+##############################################
+#### Network
+##############################################
+
+class Conv2DFiLMNet(nn.Module):
+
+    def __init__(self,
+                 in_channels: int,
+                 filters: List[int],  # num of output channels for each Conv2D layer
+                 kernel_sizes: List[int],
+                 strides: List[int],
+                 norm: str = None,
+                 activation: str = 'relu',
+
+                 lang_emb_dim: int = 256,
+                 film_mode: str = 'zero',
+                 film_place: str = 'after'
+                 ):
+        super(Conv2DFiLMNet, self).__init__()
+
+        self._in_channels = in_channels
+        self._filters = filters
+        self._kernel_sizes = kernel_sizes
+        self._strides = strides
+        self._norm = norm
+        self._activation = activation
+
+        self._lang_emb_dim = lang_emb_dim
+        self._film_mode = film_mode
+        self._film_place = film_place
+
+    def build(self):
+        self.conv_blocks = nn.ModuleList()
+        for i in range(len(self._filters)):
+            in_channels = self._in_channels if i == 0 else self._filters[i-1]
+            out_channels = self._filters[i]
+            kernel_sizes = self._kernel_sizes[i]
+            strides = self._strides[i]
+            norm = self._norm
+            activation = self._activation if i < len(self._filters) - 1 else None  # no activation for the last layer
+            conv_block = Conv2DFiLMBlock(
+                in_channels, out_channels, kernel_sizes, strides,
+                self._lang_emb_dim,
+                norm=norm, activation=activation,
+                film_mode=self._film_mode,
+                film_place=self._film_place
+            )
+            self.conv_blocks.append(conv_block)
+
+    def forward(self, x, lang_emb):
+        """
+        Args:
+            x: (B, C, H, W)
+            lang_emb: (B, lang_emb_dim)
+        """
+        for conv_block in self.conv_blocks:
+            x = conv_block(x, lang_emb)
+        return x
+
+
+if __name__ == "__main__":
+
+    use_cuda = False
+    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
+    # from IPython import embed; embed(); exit(0)
+
+    # Test Conv2DFiLMNet
+    in_channels = 1024
+    filters = [256, 64, 1]
+    kernel_sizes = [3, 3, 1]
+    strides = [1, 1, 1]
+    norm = None
+    activation = 'lrelu'
+    lang_emb_dim = 1536
+    film_mode = 'zero'
+
+    net = Conv2DFiLMNet(
+        in_channels, filters, kernel_sizes, strides, norm, activation,
+        lang_emb_dim, film_mode
+    )
+
+    net.build()
+
+    from IPython import embed; embed(); exit(0)
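A minimal sketch exercising Conv2DFiLMNet with the dimensions from checkpoints/gemini/config.yaml (dummy tensors; assumes it is run from the repository root):

import torch
from src.network import Conv2DFiLMNet

net = Conv2DFiLMNet(
    in_channels=1024, filters=[256, 64, 1], kernel_sizes=[3, 3, 1],
    strides=[1, 1, 1], norm=None, activation='lrelu',
    lang_emb_dim=1024, film_mode='zero',
)
net.build()

feat = torch.randn(2, 1024, 32, 32)   # (B, C, H, W), e.g. per-pixel DINOv2 features
lang = torch.randn(2, 1024)           # (B, lang_emb_dim) language embedding
out = net(feat, lang)                 # each conv block is FiLM-conditioned on lang
print(out.shape)                      # torch.Size([2, 1, 32, 32]) -- 3x3/stride-1 convs preserve H, W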
src/utils/__init__.py ADDED
File without changes

src/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (160 Bytes)

src/utils/__pycache__/argument_utils.cpython-310.pyc ADDED
Binary file (4.17 kB)

src/utils/__pycache__/img_utils.cpython-310.pyc ADDED
Binary file (4.16 kB)
src/utils/argument_utils.py ADDED
@@ -0,0 +1,128 @@
+"""load and save YAML config file. Originally from VoxPoser"""
+import os
+import sys
+import yaml
+import json
+
+class ConfigDict(dict):
+    def __init__(self, config):
+        """recursively build config"""
+        # self.config = config
+        for key, value in config.items():
+            if isinstance(value, str) and value.lower() == 'none':
+                value = None
+            if isinstance(value, dict):
+                self[key] = ConfigDict(value)
+            else:
+                self[key] = value
+
+    def __getattr__(self, key):
+        if key in self:
+            return self[key]
+        elif key == 'config':
+            return self
+        else:
+            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
+    def __setattr__(self, key, value):
+        self[key] = value
+    def __delattr__(self, key):
+        del self[key]
+    def __getstate__(self):
+        return self.config
+    def __setstate__(self, state):
+        self.config = state
+        self.__init__(state)
+
+    def update(self, config):
+        """update with another dict"""
+        if isinstance(config, ConfigDict):
+            config = config.convert_to_dict()
+        for key, value in config.items():
+            if isinstance(value, dict):
+                self[key].update(value)
+            else:
+                self[key] = value
+
+    def convert_to_dict(self):
+        """convert to dict"""
+        config = {}
+        for key, value in self.items():
+            if isinstance(value, ConfigDict):
+                config[key] = value.convert_to_dict()
+            else:
+                config[key] = value
+        return config
+
+def load_yaml_config(config_path):
+    with open(config_path, 'r') as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+    return config
+
+def get_yaml_config(config_path=None):
+    assert config_path and os.path.exists(config_path), f'config file does not exist ({config_path})'
+    config = load_yaml_config(config_path)
+    config = ConfigDict(config)
+    return config
+
+def eval_str_to_lst(query_str):
+    """
+    Parse a string in format [a, b, c] to a list
+    """
+    query_str = query_str.replace('[', '').replace(']', '')
+    query_lst = query_str.split(',')
+    query_lst = [q.strip() for q in query_lst]
+    return query_lst
+
+def get_command_line_args(argv, to_config_dict=True):
+    """
+    Utility function to parse all command line arguments and return them as a dictionary.
+    If argument is in format 'key1.key2=value', it will be parsed as a nested dictionary.
+    """
+    args_dict = {}
+    for arg in argv[1:]:  # Skip the first argument (script name)
+        if '=' in arg:
+            key, value = arg.split('=', 1)
+            key = key.lstrip('-')  # Remove leading dashes
+
+            # Try to convert value to int, float, bool, or leave as string
+            try:
+                value = json.loads(value)
+            except json.JSONDecodeError:
+                pass
+
+            if isinstance(value, str) and '[' in value and ']' in value:
+                value = eval_str_to_lst(value)
+
+            # Check for hierarchy in the key (indicated by '.')
+            if '.' in key:
+                sub_keys = key.split('.')
+                current_dict = args_dict
+                # Iterate through sub_keys to create nested dictionaries
+                for sub_key in sub_keys[:-1]:
+                    if sub_key not in current_dict:
+                        current_dict[sub_key] = {}
+                    current_dict = current_dict[sub_key]
+                current_dict[sub_keys[-1]] = value
+            else:
+                args_dict[key] = value
+
+    if to_config_dict:
+        args_dict = ConfigDict(args_dict)
+    return args_dict
+
+def save_config(config, config_path):
+    if not isinstance(config, dict):
+        raise ValueError("config must be a dictionary")
+    if type(config) != dict:
+        print("Converting config to dict")
+        config = config.convert_to_dict()
+    with open(config_path, 'w') as f:
+        yaml.dump(config, f, default_flow_style=False)
+
+# def main():
+#     config = get_yaml_config(config_path='./configs/sim_env/empty_scene_fetch.yaml')
+#     from IPython import embed; embed()
+
+if __name__ == '__main__':
+    from IPython import embed; embed(); exit(0)
+    # main()
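A short sketch of the two entry points defined here, ConfigDict attribute access and get_command_line_args-style 'key1.key2=value' parsing (run from the repository root):

from src.utils.argument_utils import ConfigDict, get_command_line_args

cfg = ConfigDict({"model": {"norm": "None", "filters": [256, 64, 1]}})
print(cfg.model.norm)             # None  ("none"/"None" strings are converted to Python None)
print(cfg.model.filters)          # [256, 64, 1]

# Dotted keys become nested dictionaries; values are parsed with json.loads when possible
args = get_command_line_args(["train.py", "model.lang_emb_dim=1024", "--film_mode=zero"])
print(args.model.lang_emb_dim)    # 1024 (int)
print(args.film_mode)             # 'zero'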
src/utils/img_utils.py ADDED
@@ -0,0 +1,136 @@
+"""
+Image related util functions
+"""
+
+import cv2
+import numpy as np
+import os
+
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+from PIL import Image
+
+from openai import OpenAI
+
+##############################################
+### Image processing
+##############################################
+
+def rescale_img(img, max_size=448):
+    """
+    Rescale image such that largest dimension is max_size
+    img: np.ndarray, (H, W, C)
+    """
+    h, w = img.shape[:2]
+    if max(h, w) <= max_size:
+        return img
+
+    scale = max_size / max(h, w)
+    new_h, new_w = int(h*scale), int(w*scale)
+    img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
+    return img
+
+
+##############################################
+### DINO features
+##############################################
+
+def load_pretrained_dino(model_type='dinov2_vitl14', device=None):
+    if device is None:
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    dinov2 = torch.hub.load('facebookresearch/dinov2', model_type).eval()
+    dinov2 = dinov2.to(device)
+    return dinov2
+
+def get_dino_features(dinov2, imgs, repeat_to_orig_size=False):
+    """
+    Get features from DINO model
+    ::param dinov2:: DINO model
+    ::param imgs:: tensor of shape (bs, C, H, W)
+    """
+    bs, C, H, W = imgs.shape
+    patch_h = H // 14
+    patch_w = W // 14
+
+    with torch.no_grad():
+        features_dict = dinov2.forward_features(imgs)
+        features = features_dict['x_norm_patchtokens']
+        features = features.reshape(bs, patch_h, patch_w, -1)
+
+    if not repeat_to_orig_size:
+        return features  # (bs, patch_h, patch_w, n_features)
+    else:
+        # repeat on batched dims to original size
+        ratio = H // (patch_h*2)
+        features = F.interpolate(features.permute(0, 3, 1, 2), scale_factor=ratio, mode='bilinear').permute(0, 2, 3, 1)
+        return features
+
+def transform_imgs_for_dino(imgs, blur=True):
+    """
+    Transform image before passing to DINO model
+    ::param imgs:: np.array of shape (H, W, C) or (bs, H, W, C)
+    ::param blur:: bool, whether to apply Gaussian blur before resizing
+
+    ::return:: list of transformed images
+    """
+    # handles both single image and batch of images
+    if len(imgs.shape) == 3:
+        H, W, C = imgs.shape
+        imgs = imgs[None, ...]
+        bs = 1
+    else:
+        bs, H, W, C = imgs.shape
+
+    H *= 2
+    W *= 2
+
+    patch_h = H // 14
+    patch_w = W // 14
+
+    if blur:
+        transform_lst = [T.GaussianBlur(9, sigma=(1.0, 2.0))]
+    else:
+        transform_lst = []
+    transform_lst += [
+        T.Resize((patch_h * 14, patch_w * 14)),
+        T.CenterCrop((patch_h * 14, patch_w * 14)),
+        T.ToTensor(),
+        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+    ]
+
+    transform = T.Compose(transform_lst)
+
+    transformed_imgs = []
+    for i in range(bs):
+        temp = imgs[i].copy()
+        if temp.max() <= 1.1:  # handle images with values in [0, 1]
+            temp = (temp * 255)
+        temp = temp.astype(np.uint8).clip(0, 255)
+        transformed_imgs.append(transform(Image.fromarray(temp)))
+
+    return transformed_imgs
+
+
+##############################################
+### Preparing data for model
+##############################################
+
+def get_text_embedding(text, model="text-embedding-3-large", dim=1024):
+    """
+    Get text embedding with specified dimension
+    """
+    client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
+    embedding = np.array(client.embeddings.create(input=[text], model=model, dimensions=dim).data[0].embedding)
+    return embedding
+
+def preprocess_data(img, text, lang_emb_dim=1024):
+    if len(img.shape) == 3 and img.shape[2] == 4:
+        img = img[:, :, :3]
+    elif len(img.shape) == 2:
+        img = np.stack([img] * 3, axis=-1)
+    processed_img = transform_imgs_for_dino(img)
+    lang_emb = torch.tensor(get_text_embedding(text, dim=lang_emb_dim)).to(torch.float32)
+
+    return processed_img, lang_emb
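A rough sketch of the image path used by AffordanceModel.inference, with a dummy image (assumes the listed requirements are installed; it skips get_text_embedding, which needs OPENAI_API_KEY, and get_dino_features, which downloads DINOv2 from torch.hub):

import numpy as np
from src.utils.img_utils import rescale_img, transform_imgs_for_dino

img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)   # dummy (H, W, 3) RGB image
img = rescale_img(img, max_size=672)                          # unchanged here: max(480, 640) <= 672
tensors = transform_imgs_for_dino(img)                        # list of normalized (3, H', W') tensors
print(tensors[0].shape)   # torch.Size([3, 952, 1274]) -- doubled, then snapped to multiples of 14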
test_img.png ADDED