Yihe Tang committed on
Commit 17b2682 · 1 Parent(s): 5396c1d

Initial deployment with model and dependencies

app.py CHANGED
@@ -1,7 +1,56 @@
- import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ """
+ Interface for HuggingFace deployment
+ """
+
+ import gradio as gr
+ import numpy as np
+ from src.model import AffordanceModel
+ from src.utils.argument_utils import get_yaml_config
+ import cv2
+
+ print("Loading config...")
+ config = get_yaml_config("checkpoints/gemini/config.yaml")
+ print("Building model...")
+ model = AffordanceModel(config)
+ print("Model built successfully!")
+
+ def predict(image, text):
+     """
+     Gradio inference function
+     Args:
+         image: PIL Image (Gradio's default image input format)
+         text: str
+     Returns:
+         visualization of the heatmap
+     """
+     # Convert PIL image to numpy array
+     image = np.array(image)
+
+     # Run model inference
+     heatmap = model.inference(image, text)  # Returns (H, W) array
+
+     # Visualize heatmap (convert to RGB for display)
+     # Scale to 0-255 and apply colormap
+     heatmap_vis = (heatmap * 255).astype(np.uint8)
+     heatmap_colored = cv2.applyColorMap(heatmap_vis, cv2.COLORMAP_JET)
+     heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
+
+     return heatmap_colored
+
+ # Create Gradio interface
+ demo = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Image(type="pil", label="Input Image"),  # Accepts uploaded images
+         gr.Textbox(label="Text Query", placeholder="Enter text description...")
+     ],
+     outputs=gr.Image(label="Affordance Heatmap"),
+     title="Affordance Detection",
+     description="Upload an image and provide a text query to detect affordances.",
+     examples=[
+         ["test_img.png", "rim"]  # test image and query (test_img.png is added in this commit)
+     ]
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
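A minimal local sketch of exercising the handler above, assuming the repo is checked out with its LFS checkpoint, an OPENAI_API_KEY is set (the language embedding calls the OpenAI API), and torch.hub can download DINOv2 on first run:

    from PIL import Image
    from app import predict

    # predict() converts the PIL image to numpy, runs AffordanceModel.inference,
    # and returns a JET-colormapped RGB visualization of the (H, W) heatmap.
    out = predict(Image.open("test_img.png"), "rim")
    print(out.shape, out.dtype)  # expected: (H, W, 3) uint8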
checkpoints/gemini/checkpoint.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ffb327d0b8aa9610430e851c740620c38b09c4ffcabca889d5eacec63627f08a
+ size 38003927
checkpoints/gemini/config.yaml ADDED
@@ -0,0 +1,10 @@
+ model:
+   in_channels: 1024
+   filters: [256, 64, 1]
+   kernel_sizes: [3, 3, 1]
+   strides: [1, 1, 1]
+   norm: None
+   activation: lrelu
+   lang_emb_dim: 1024
+   film_mode: zero
+ checkpoint_path: checkpoints/gemini/checkpoint.pth
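For reference, get_yaml_config (src/utils/argument_utils.py in this commit) wraps the parsed YAML in a ConfigDict, so nested keys are attribute-accessible and the string "None" is mapped to Python None; a small sketch:

    from src.utils.argument_utils import get_yaml_config

    config = get_yaml_config("checkpoints/gemini/config.yaml")
    print(config.model.in_channels)  # 1024
    print(config.model.norm)         # None (the YAML string "None" is converted)
    print(config.checkpoint_path)    # checkpoints/gemini/checkpoint.pth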
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch
+ torchvision
+ numpy
+ opencv-python-headless  # using headless version for server deployment
+ Pillow  # for PIL
+ openai  # for OpenAI API
+ gradio
+ pyyaml  # for yaml
+ matplotlib
+ IPython
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes).
 
src/__pycache__/model.cpython-310.pyc ADDED
Binary file (2.79 kB).
 
src/__pycache__/network.cpython-310.pyc ADDED
Binary file (6.93 kB).
 
src/model.py ADDED
@@ -0,0 +1,92 @@
+ """
+ Affordance model definition
+ """
+ import torch
+ import torch.nn.functional as F
+
+ import os
+ import sys
+ import os.path as osp
+
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+
+ from .network import Conv2DFiLMNet
+ from .utils.img_utils import *
+ from .utils.argument_utils import get_yaml_config
+
+ use_cuda = True
+ device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
+
+ def build_network(config):
+     """
+     Build model
+     """
+     net = Conv2DFiLMNet(**config)
+
+     net.build()
+     net = net.to(device)
+
+     return net
+
+ class AffordanceModel:
+
+     def __init__(self, config):
+
+         print("============ Building network and loading checkpoint =============")
+         # build network
+         self.model = build_network(config.model)
+
+         # load checkpoint
+         self.load_checkpoint(config.checkpoint_path)
+
+         print("============ Building DINO model =============")
+         torch_path = osp.join(osp.dirname(osp.dirname(__file__)), "data/torch_home")
+         self.dinov2 = load_pretrained_dino(torch_path=torch_path)
+
+     def load_checkpoint(self, checkpoint_path):
+         """
+         Load checkpoint
+         """
+         checkpoint = torch.load(checkpoint_path, map_location=device)
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.model.eval()
+
+     @torch.no_grad()
+     def inference(self, img, text, keep_orig_size=True, temp=1.0):
+         """
+         Inference model output on query image and text.
+         img: np.ndarray, (H, W, C)
+         text: str
+
+         Returns:
+             out: np array, (H, W)
+         """
+         self.model.eval()
+
+         # prepare input for model
+         img = rescale_img(img, max_size=672)  # rescale image in case it is too large
+         processed_img, lang_emb = preprocess_data(img, text, self.model._lang_emb_dim)
+         img = torch.stack(processed_img, dim=0).to(device)
+         img_feat = get_dino_features(self.dinov2, img, repeat_to_orig_size=keep_orig_size).permute(0, 3, 1, 2)
+
+         lang_emb = lang_emb.to(device)
+
+         # forward pass
+         out = self.model(img_feat, lang_emb)  # (1, 1, h, w)
+         out = out.squeeze()
+
+         # post-process output
+         out = torch.sigmoid(out)
+         out = out.detach().cpu().numpy()
+         return out
+
+
+ if __name__ == "__main__":
+     config_path = osp.join(osp.dirname(osp.dirname(__file__)), "checkpoints/gemini/config.yaml")
+     config = get_yaml_config(config_path)
+
+     affordance_model = AffordanceModel(config)
+
+     from IPython import embed; embed(); exit(0)
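A sketch of using the model outside Gradio, mirroring the contract documented in inference() (RGB uint8 image in, (H, W) float heatmap in [0, 1] out); it assumes the same checkpoint/config layout and OPENAI_API_KEY as app.py:

    import cv2
    from src.model import AffordanceModel
    from src.utils.argument_utils import get_yaml_config

    model = AffordanceModel(get_yaml_config("checkpoints/gemini/config.yaml"))

    img = cv2.cvtColor(cv2.imread("test_img.png"), cv2.COLOR_BGR2RGB)  # (H, W, 3) uint8
    heatmap = model.inference(img, "rim")                              # (H, W) float in [0, 1]
    print(heatmap.shape, heatmap.min(), heatmap.max())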
src/network.py ADDED
@@ -0,0 +1,271 @@
+ """
+ Defines basic network building blocks and network architecture
+ Some code adapted from PerAct: https://github.com/peract/peract
+ """
+
+ import torch
+ import torch.nn as nn
+
+ from typing import List
+
+ ########################################
+ ### layers
+ ########################################
+ LRELU_SLOPE = 0.02
+
+ def act_layer(act):
+     if act == 'relu':
+         return nn.ReLU()
+     elif act == 'lrelu':
+         return nn.LeakyReLU(LRELU_SLOPE)
+     elif act == 'elu':
+         return nn.ELU()
+     elif act == 'tanh':
+         return nn.Tanh()
+     elif act == 'prelu':
+         return nn.PReLU()
+     else:
+         raise ValueError('%s not recognized.' % act)
+
+ def norm_layer2d(norm, channels):
+     if norm == 'batch':
+         return nn.BatchNorm2d(channels)
+     elif norm == 'instance':
+         return nn.InstanceNorm2d(channels, affine=True)
+     elif norm == 'layer':
+         return nn.GroupNorm(1, channels, affine=True)
+     elif norm == 'group':
+         return nn.GroupNorm(4, channels, affine=True)
+     else:
+         raise ValueError('%s not recognized.' % norm)
+
+
+ ########################################
+ ### network blocks
+ ########################################
+
+ class FiLMBlockRand(nn.Module):
+     """
+     FiLM block with random init gamma and beta.
+     x = gamma * x + beta
+     Adapted from PerAct (and original FiLM paper)
+     """
+     def __init__(self, lang_emb_dim, num_channels):
+         super(FiLMBlockRand, self).__init__()
+
+         self.fc_gamma = nn.Linear(lang_emb_dim, num_channels)
+         self.fc_beta = nn.Linear(lang_emb_dim, num_channels)
+
+     def forward(self, x, lang_emb):
+         gamma = self.fc_gamma(lang_emb)
+         beta = self.fc_beta(lang_emb)
+
+         beta = beta.view(x.size(0), x.size(1), 1, 1)
+         gamma = gamma.view(x.size(0), x.size(1), 1, 1)
+
+         x = gamma * x + beta
+
+         return x
+
+
+ class FiLMBlockZero(nn.Module):
+     """
+     FiLM block with zero init gamma and beta.
+     x = (1 + gamma) * x + beta
+     Adapted from RT-1 https://github.com/google-research/robotics_transformer/blob/master/film_efficientnet/film_conditioning_layer.py
+     """
+     def __init__(self, lang_emb_dim, num_channels):
+         super(FiLMBlockZero, self).__init__()
+
+         self.fc_gamma = nn.Linear(lang_emb_dim, num_channels)
+         self.fc_beta = nn.Linear(lang_emb_dim, num_channels)
+
+         nn.init.zeros_(self.fc_gamma.weight)
+         nn.init.zeros_(self.fc_gamma.bias)
+         nn.init.zeros_(self.fc_beta.weight)
+         nn.init.zeros_(self.fc_beta.bias)
+
+     def forward(self, x, lang_emb):
+         gamma = self.fc_gamma(lang_emb)
+         beta = self.fc_beta(lang_emb)
+
+         beta = beta.view(x.size(0), x.size(1), 1, 1)
+         gamma = gamma.view(x.size(0), x.size(1), 1, 1)
+
+         x = (1 + gamma) * x + beta
+
+         return x
+
+
+ class Conv2DBlock(nn.Module):
+
+     def __init__(self, in_channels, out_channels, kernel_sizes, strides,
+                  norm=None, activation=None, padding_mode='replicate'):
+         super(Conv2DBlock, self).__init__()
+         padding = kernel_sizes // 2 if isinstance(kernel_sizes, int) else (
+             kernel_sizes[0] // 2, kernel_sizes[1] // 2)
+         self.conv2d = nn.Conv2d(
+             in_channels, out_channels, kernel_sizes, strides, padding=padding,
+             padding_mode=padding_mode)
+
+         if activation is None:
+             nn.init.xavier_uniform_(self.conv2d.weight,
+                                     gain=nn.init.calculate_gain('linear'))
+             nn.init.zeros_(self.conv2d.bias)
+         elif activation == 'tanh':
+             nn.init.xavier_uniform_(self.conv2d.weight,
+                                     gain=nn.init.calculate_gain('tanh'))
+             nn.init.zeros_(self.conv2d.bias)
+         elif activation == 'lrelu':
+             nn.init.kaiming_uniform_(self.conv2d.weight, a=LRELU_SLOPE,
+                                      nonlinearity='leaky_relu')
+             nn.init.zeros_(self.conv2d.bias)
+         elif activation == 'relu':
+             nn.init.kaiming_uniform_(self.conv2d.weight, nonlinearity='relu')
+             nn.init.zeros_(self.conv2d.bias)
+         else:
+             raise ValueError()
+
+         self.activation = activation
+         self.norm = norm
+         if norm is not None:
+             self.norm = norm_layer2d(norm, out_channels)
+         if activation is not None:
+             self.activation = act_layer(activation)
+
+     def forward(self, x):
+         x = self.conv2d(x)
+         x = self.norm(x) if self.norm is not None else x
+         x = self.activation(x) if self.activation is not None else x
+         return x
+
+ class Conv2DFiLMBlock(Conv2DBlock):
+
+     def __init__(self, in_channels, out_channels, kernel_sizes, strides,
+                  lang_emb_dim,
+                  norm=None, activation=None, padding_mode='replicate',
+                  film_mode='rand', film_place='after'
+                  ):
+         super(Conv2DFiLMBlock, self).__init__(
+             in_channels, out_channels, kernel_sizes, strides, norm, activation,
+             padding_mode)
+
+         self.film_place = film_place
+         if film_place == 'after':
+             film_channels = out_channels
+         elif film_place == 'before':
+             film_channels = in_channels
+         else:
+             raise ValueError(f"film_place {film_place} not recognized")
+
+         if film_mode == 'rand':
+             self.film = FiLMBlockRand(lang_emb_dim, film_channels)
+         elif film_mode == 'zero':
+             self.film = FiLMBlockZero(lang_emb_dim, film_channels)
+         else:
+             raise ValueError(f"film_mode {film_mode} not recognized")
+
+     def forward(self, x, lang_emb):
+         if self.film_place == 'before':
+             x = self.film(x, lang_emb)
+             x = self.conv2d(x)
+             x = self.norm(x) if self.norm is not None else x
+             x = self.activation(x) if self.activation is not None else x
+
+         elif self.film_place == 'after':
+             x = self.conv2d(x)  # (B, C, H, W)
+             x = self.norm(x) if self.norm is not None else x
+             x = self.film(x, lang_emb)  # lang_emb: (B, lang_emb_dim), output: (B, C, H, W)
+             x = self.activation(x) if self.activation is not None else x
+
+         else:
+             raise ValueError(f"film_place {self.film_place} not recognized")
+
+         return x
+
+
+ ##############################################
+ #### Network
+ ##############################################
+
+ class Conv2DFiLMNet(nn.Module):
+
+     def __init__(self,
+                  in_channels: int,
+                  filters: List[int],  # num of output channels for each Conv2D layer
+                  kernel_sizes: List[int],
+                  strides: List[int],
+                  norm: str = None,
+                  activation: str = 'relu',
+                  lang_emb_dim: int = 256,
+                  film_mode: str = 'zero',
+                  film_place: str = 'after'
+                  ):
+         super(Conv2DFiLMNet, self).__init__()
+
+         self._in_channels = in_channels
+         self._filters = filters
+         self._kernel_sizes = kernel_sizes
+         self._strides = strides
+         self._norm = norm
+         self._activation = activation
+
+         self._lang_emb_dim = lang_emb_dim
+         self._film_mode = film_mode
+         self._film_place = film_place
+
+     def build(self):
+         self.conv_blocks = nn.ModuleList()
+         for i in range(len(self._filters)):
+             in_channels = self._in_channels if i == 0 else self._filters[i-1]
+             out_channels = self._filters[i]
+             kernel_sizes = self._kernel_sizes[i]
+             strides = self._strides[i]
+             norm = self._norm
+             activation = self._activation if i < len(self._filters) - 1 else None  # no activation for the last layer
+             conv_block = Conv2DFiLMBlock(
+                 in_channels, out_channels, kernel_sizes, strides,
+                 self._lang_emb_dim,
+                 norm=norm, activation=activation,
+                 film_mode=self._film_mode,
+                 film_place=self._film_place
+             )
+             self.conv_blocks.append(conv_block)
+
+     def forward(self, x, lang_emb):
+         """
+         Args:
+             x: (B, C, H, W)
+             lang_emb: (B, lang_emb_dim)
+         """
+         for conv_block in self.conv_blocks:
+             x = conv_block(x, lang_emb)
+         return x
+
+
+ if __name__ == "__main__":
+
+     use_cuda = False
+     device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")
+     # from IPython import embed; embed(); exit(0)
+
+     # Test Conv2DFiLMNet
+     in_channels = 1024
+     filters = [256, 64, 1]
+     kernel_sizes = [3, 3, 1]
+     strides = [1, 1, 1]
+     norm = None
+     activation = 'lrelu'
+     lang_emb_dim = 1536
+     film_mode = 'zero'
+
+     net = Conv2DFiLMNet(
+         in_channels, filters, kernel_sizes, strides, norm, activation,
+         lang_emb_dim, film_mode
+     )
+
+     net.build()
+
+     from IPython import embed; embed(); exit(0)
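A shape-check sketch for the head as configured in checkpoints/gemini/config.yaml; the 48x64 grid is an arbitrary example of a DINOv2 patch-feature map:

    import torch
    from src.network import Conv2DFiLMNet

    net = Conv2DFiLMNet(in_channels=1024, filters=[256, 64, 1], kernel_sizes=[3, 3, 1],
                        strides=[1, 1, 1], norm=None, activation='lrelu',
                        lang_emb_dim=1024, film_mode='zero')
    net.build()

    feats = torch.randn(1, 1024, 48, 64)  # DINOv2 patch features, (B, C, H, W)
    lang = torch.randn(1, 1024)           # language embedding, (B, lang_emb_dim)
    print(net(feats, lang).shape)         # torch.Size([1, 1, 48, 64]): per-pixel affordance logits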
src/utils/__init__.py ADDED
File without changes
src/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (160 Bytes).
 
src/utils/__pycache__/argument_utils.cpython-310.pyc ADDED
Binary file (4.17 kB).
 
src/utils/__pycache__/img_utils.cpython-310.pyc ADDED
Binary file (4.16 kB).
 
src/utils/argument_utils.py ADDED
@@ -0,0 +1,128 @@
+ """load and save YAML config file. Originally from VoxPoser"""
+ import os
+ import sys
+ import yaml
+ import json
+
+ class ConfigDict(dict):
+     def __init__(self, config):
+         """recursively build config"""
+         # self.config = config
+         for key, value in config.items():
+             if isinstance(value, str) and value.lower() == 'none':
+                 value = None
+             if isinstance(value, dict):
+                 self[key] = ConfigDict(value)
+             else:
+                 self[key] = value
+
+     def __getattr__(self, key):
+         if key in self:
+             return self[key]
+         elif key == 'config':
+             return self
+         else:
+             raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
+
+     def __setattr__(self, key, value):
+         self[key] = value
+
+     def __delattr__(self, key):
+         del self[key]
+
+     def __getstate__(self):
+         return self.config
+
+     def __setstate__(self, state):
+         self.config = state
+         self.__init__(state)
+
+     def update(self, config):
+         """update with another dict"""
+         if isinstance(config, ConfigDict):
+             config = config.convert_to_dict()
+         for key, value in config.items():
+             if isinstance(value, dict):
+                 self[key].update(value)
+             else:
+                 self[key] = value
+
+     def convert_to_dict(self):
+         """convert to dict"""
+         config = {}
+         for key, value in self.items():
+             if isinstance(value, ConfigDict):
+                 config[key] = value.convert_to_dict()
+             else:
+                 config[key] = value
+         return config
+
+ def load_yaml_config(config_path):
+     with open(config_path, 'r') as f:
+         config = yaml.load(f, Loader=yaml.FullLoader)
+     return config
+
+ def get_yaml_config(config_path=None):
+     assert config_path and os.path.exists(config_path), f'config file does not exist ({config_path})'
+     config = load_yaml_config(config_path)
+     config = ConfigDict(config)
+     return config
+
+ def eval_str_to_lst(query_str):
+     """
+     Parse a string in format [a, b, c] to a list
+     """
+     query_str = query_str.replace('[', '').replace(']', '')
+     query_lst = query_str.split(',')
+     query_lst = [q.strip() for q in query_lst]
+     return query_lst
+
+ def get_command_line_args(argv, to_config_dict=True):
+     """
+     Utility function to parse all command line arguments and return them as a dictionary.
+     If argument is in format 'key1.key2=value', it will be parsed as a nested dictionary.
+     """
+     args_dict = {}
+     for arg in argv[1:]:  # Skip the first argument (script name)
+         if '=' in arg:
+             key, value = arg.split('=', 1)
+             key = key.lstrip('-')  # Remove leading dashes
+
+             # Try to convert value to int, float, bool, or leave as string
+             try:
+                 value = json.loads(value)
+             except json.JSONDecodeError:
+                 pass
+
+             if isinstance(value, str) and '[' in value and ']' in value:
+                 value = eval_str_to_lst(value)
+
+             # Check for hierarchy in the key (indicated by '.')
+             if '.' in key:
+                 sub_keys = key.split('.')
+                 current_dict = args_dict
+                 # Iterate through sub_keys to create nested dictionaries
+                 for sub_key in sub_keys[:-1]:
+                     if sub_key not in current_dict:
+                         current_dict[sub_key] = {}
+                     current_dict = current_dict[sub_key]
+                 current_dict[sub_keys[-1]] = value
+             else:
+                 args_dict[key] = value
+
+     if to_config_dict:
+         args_dict = ConfigDict(args_dict)
+     return args_dict
+
+ def save_config(config, config_path):
+     if not isinstance(config, dict):
+         raise ValueError("config must be a dictionary")
+     if type(config) != dict:
+         print("Converting config to dict")
+         config = config.convert_to_dict()
+     with open(config_path, 'w') as f:
+         yaml.dump(config, f, default_flow_style=False)
+
+ # def main():
+ #     config = get_yaml_config(config_path='./configs/sim_env/empty_scene_fetch.yaml')
+ #     from IPython import embed; embed()
+
+ if __name__ == '__main__':
+     from IPython import embed; embed(); exit(0)
+     # main()
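A small illustration of the dotted-key parsing in get_command_line_args (the argv values here are made up for the example):

    args = get_command_line_args(["train.py", "--model.in_channels=512", "tag=demo", "filters=[a, b, c]"])
    print(args.model.in_channels)  # 512 (json-decoded, nested under 'model')
    print(args.tag)                # 'demo' (not valid JSON, kept as a string)
    print(args.filters)            # ['a', 'b', 'c'] (bracketed non-JSON values become a list of strings)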
src/utils/img_utils.py ADDED
@@ -0,0 +1,136 @@
+ """
+ Image related util functions
+ """
+
+ import cv2
+ import numpy as np
+ import os
+
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ from PIL import Image
+
+ from openai import OpenAI
+
+ ##############################################
+ ### Image processing
+ ##############################################
+
+ def rescale_img(img, max_size=448):
+     """
+     Rescale image such that largest dimension is max_size
+     img: np.ndarray, (H, W, C)
+     """
+     h, w = img.shape[:2]
+     if max(h, w) <= max_size:
+         return img
+
+     scale = max_size / max(h, w)
+     new_h, new_w = int(h*scale), int(w*scale)
+     img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
+     return img
+
+
+ ##############################################
+ ### DINO features
+ ##############################################
+
+ def load_pretrained_dino(model_type='dinov2_vitl14', device=None, torch_path=None):
+     if device is None:
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     if torch_path is not None:
+         torch.hub.set_dir(torch_path)  # cache hub weights locally (model.py passes data/torch_home)
+
+     dinov2 = torch.hub.load('facebookresearch/dinov2', model_type).eval()
+     dinov2 = dinov2.to(device)
+     return dinov2
+
+ def get_dino_features(dinov2, imgs, repeat_to_orig_size=False):
+     """
+     Get features from DINO model
+     :param dinov2: DINO model
+     :param imgs: tensor of shape (bs, C, H, W)
+     """
+     bs, C, H, W = imgs.shape
+     patch_h = H // 14
+     patch_w = W // 14
+
+     with torch.no_grad():
+         features_dict = dinov2.forward_features(imgs)
+         features = features_dict['x_norm_patchtokens']
+         features = features.reshape(bs, patch_h, patch_w, -1)
+
+     if not repeat_to_orig_size:
+         return features  # (bs, patch_h, patch_w, n_features)
+     else:
+         # interpolate patch features back to the original (pre-doubling) image size
+         ratio = H // (patch_h*2)
+         features = F.interpolate(features.permute(0, 3, 1, 2), scale_factor=ratio, mode='bilinear').permute(0, 2, 3, 1)
+         return features
+
+ def transform_imgs_for_dino(imgs, blur=True):
+     """
+     Transform image before passing to DINO model
+     :param imgs: np.array of shape (H, W, C) or (bs, H, W, C)
+     :param blur: bool, whether to apply Gaussian blur before resizing
+
+     :return: list of transformed images
+     """
+     # handles both single image and batch of images
+     if len(imgs.shape) == 3:
+         H, W, C = imgs.shape
+         imgs = imgs[None, ...]
+         bs = 1
+     else:
+         bs, H, W, C = imgs.shape
+
+     H *= 2
+     W *= 2
+
+     patch_h = H // 14
+     patch_w = W // 14
+
+     if blur:
+         transform_lst = [T.GaussianBlur(9, sigma=(1.0, 2.0))]
+     else:
+         transform_lst = []
+     transform_lst += [
+         T.Resize((patch_h * 14, patch_w * 14)),
+         T.CenterCrop((patch_h * 14, patch_w * 14)),
+         T.ToTensor(),
+         T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+     ]
+
+     transform = T.Compose(transform_lst)
+
+     transformed_imgs = []
+     for i in range(bs):
+         temp = imgs[i].copy()
+         if temp.max() <= 1.1:  # handle images with values in [0, 1]
+             temp = (temp * 255)
+         temp = temp.astype(np.uint8).clip(0, 255)
+         transformed_imgs.append(transform(Image.fromarray(temp)))
+
+     return transformed_imgs
+
+
+ ##############################################
+ ### Preparing data for model
+ ##############################################
+
+ def get_text_embedding(text, model="text-embedding-3-large", dim=1024):
+     """
+     Get text embedding with specified dimension
+     """
+     client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
+     embedding = np.array(client.embeddings.create(input=[text], model=model, dimensions=dim).data[0].embedding)
+     return embedding
+
+ def preprocess_data(img, text, lang_emb_dim=1024):
+     if len(img.shape) == 3 and img.shape[2] == 4:
+         img = img[:, :, :3]  # drop alpha channel
+     elif len(img.shape) == 2:
+         img = np.stack([img] * 3, axis=-1)  # grayscale -> 3-channel
+     processed_img = transform_imgs_for_dino(img)
+     lang_emb = torch.tensor(get_text_embedding(text, dim=lang_emb_dim)).to(torch.float32)
+
+     return processed_img, lang_emb
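For orientation, a shape walk-through of the DINO feature path for a hypothetical 336x448 input (the numbers follow the code above):

    # transform_imgs_for_dino: (336, 448, 3) -> doubled and snapped to the 14-px patch grid -> tensor (3, 672, 896)
    # get_dino_features:       patch grid 672//14 x 896//14 = 48 x 64 -> features (1, 48, 64, 1024)
    # repeat_to_orig_size=True: scale_factor = 672 // (48*2) = 7 -> (1, 336, 448, 1024)
    # model.py then permutes to (1, 1024, 336, 448) before the Conv2DFiLMNet head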
test_img.png ADDED