Spaces:
Sleeping
Sleeping
import wandb | |
import torch | |
import numpy as np | |
from monai.visualize import blend_images | |
class WandBModel:
    """
    Mixin that adds Weights & Biases (WandB) logging helpers to a model.

    Enable WandB features to the model using multiple inheritance. The
    methods read the following attributes from ``self``:

    - ``visual_pairs``: iterable of dicts with the keys ``'name'``,
      ``'image'``, ``'mask'`` and ``'type'``. The first three hold *attribute
      names* that are looked up on ``self``; ``'type'`` is a label used in
      video captions and as a column name of the validation table.
    - ``train_loss`` / ``val_loss``: objects exposing ``pop_data(True)`` whose
      result is iterated with ``.items()`` (loss name -> value).
    - ``metric_meter``: object exposing ``to_df()``; the result is handed to
      ``wandb.Table(dataframe=...)``.
    - ``name``: value stored in the ``'ID'`` column of the validation table.
    - ``val_table``: table object exposing ``add_data``; it is replaced by a
      fresh ``wandb.Table`` after every ``log_val_visualization`` call.
    """

    def __init__(self, *args, **kwargs):
        """
        Declare every attribute used by the logging helpers as ``None``.

        Positional and keyword arguments are accepted and ignored.
        """
        # the following attributes should be initialized by class `BaseSegmentationModel`
        self.visual_pairs = None
        self.train_loss = None
        self.val_loss = None
        self.metric_meter = None
        self.name = None
        # the following attributes should be initialized by the child class
        self.val_table = None

    @staticmethod
    def _blend_to_frames(image, mask, time_dim):
        """
        Blend one image/mask sample and return it as a uint8 frame array.

        Args:
            image: detached image tensor of a single sample
                (presumably [C, H, W, D] -- confirm against the caller)
            mask: detached mask tensor of the same sample; a mask with more
                than one channel is reduced with ``argmax`` over dim 0
            time_dim: axis of the blended volume that becomes the leading
                (time) axis of the result

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` with axes ordered as
            (time_dim, 0, remaining axes in their original order).

        Raises:
            ValueError: if ``time_dim`` is not one of the axes 1..ndim-1 of
                the blended volume.
        """
        if mask.shape[0] > 1:
            # several channels: collapse them into a single-channel index map
            mask = torch.argmax(mask, dim=0, keepdim=True)

        blended = blend_images(image, mask, alpha=0.5) * 255

        # every axis except the channel axis (dim 0) is a candidate time axis
        spatial_dims = list(range(1, len(blended.shape)))
        if time_dim not in spatial_dims:
            raise ValueError(
                'time_dim={} is not a spatial axis of the blended volume with shape {}'.format(
                    time_dim, tuple(blended.shape)
                )
            )
        spatial_dims.remove(time_dim)

        # permute the axes to (time, channel, remaining spatial axes)
        frames = blended.permute([time_dim, 0] + spatial_dims)
        return frames.cpu().numpy().astype(np.uint8)

    @staticmethod
    def _log_prefixed(data_dict, prefix, step):
        """
        Log all entries of ``data_dict`` under ``'<prefix>/<key>'`` at once.

        A single ``wandb.log`` call is used so that all values share one step
        even when ``step`` is ``None``. Nothing is logged for an empty dict.
        """
        payload = {'{}/{}'.format(prefix, key): value for key, value in data_dict.items()}
        if payload:
            wandb.log(payload, step=step)

    def volume2videos(self, time_dim: int = 3, tag: str = '') -> list:
        """
        Convert 3D volumes to video in favor of WandB logging

        Args:
            time_dim: the spatial dimension to be converted as the time dimension, default is the axial axis (dim 3)
            tag: extra information for logging, appended to each caption

        Returns:
            A list of ``wandb.Video`` objects, one per batch item of every
            entry in ``self.visual_pairs`` that could be resolved.

        Raises:
            ValueError: if ``time_dim`` is not a valid axis of a blended volume.
        """
        videos = []
        for image_pair in self.visual_pairs:
            try:
                pair_name = getattr(self, image_pair['name'])
                image = getattr(self, image_pair['image'])
                mask = getattr(self, image_pair['mask'])
                vis_type = image_pair['type']
            except (AttributeError, KeyError, TypeError):
                # Best effort: skip pairs whose keys or attributes are not
                # available. Only lookup failures are swallowed, so that
                # KeyboardInterrupt / SystemExit still propagate.
                continue

            for i in range(image.shape[0]):  # deallocate the batch dim
                item_name = pair_name[i]
                # detach the tensors before blending
                frames = self._blend_to_frames(
                    image[i, ...].detach(), mask[i, ...].detach(), time_dim
                )
                # record in the wandb.Video class
                caption = '{}_{}{}'.format(item_name, vis_type, tag)
                videos.append(wandb.Video(frames, fps=8, caption=caption))
        return videos

    def log_scaler(self, key, value, step=None):
        """
        Log manually defined scalar data, rounded to 4 decimals
        """
        wandb.log({key: np.round(value, decimals=4)}, step=step)

    def log_train_loss(self, step=None):
        """
        Log train loss: every entry popped from ``self.train_loss`` is logged as ``train/<key>``
        """
        self._log_prefixed(self.train_loss.pop_data(True), 'train', step)

    def log_val_loss(self, step=None):
        """
        Log val loss: every entry popped from ``self.val_loss`` is logged as ``val/<key>``
        """
        self._log_prefixed(self.val_loss.pop_data(True), 'val', step)

    def log_metrics(self, step=None):
        """
        Log validation metrics as wandb.Table under the key ``val/metrics``
        """
        df = self.metric_meter.to_df()
        wandb.log({'val/metrics': wandb.Table(dataframe=df)}, step=step)

    def log_vis(self, key, step=None, time_dim=3, tag=''):
        """
        Log training intermediate visualizations as videos under ``key``
        """
        videos = self.volume2videos(time_dim, tag)
        wandb.log({key: videos}, step=step)

    def update_val_visualization(self, time_dim=3, tag=''):
        """
        Update the validation visualization to buffer, called every step of evaluation

        Adds one row to ``self.val_table``: ``self.name`` followed by the
        videos produced by ``volume2videos``.
        """
        videos = self.volume2videos(time_dim, tag)
        self.val_table.add_data(self.name, *videos)

    def log_val_visualization(self, step=None):
        """
        Log validation visualization under ``val/visualization`` and reset the buffer
        """
        wandb.log({'val/visualization': self.val_table}, step=step)
        # re-initialize the table for next logging; rebinding the attribute
        # releases the previous table, so no explicit `del` is needed
        self.val_table = wandb.Table(columns=['ID'] + [pair['type'] for pair in self.visual_pairs])