# base/base_wandb_model.py
import wandb
import torch
import numpy as np
from monai.visualize import blend_images


class WandBModel:
    """
    Mixin that adds WandB logging features to a model via multiple inheritance.
    """
def __init__(self, *args, **kwargs):
# the following attributes should be initialized by class `BaseSegmentationModel`
self.visual_pairs = None
self.train_loss = None
self.val_loss = None
self.metric_meter = None
self.name = None
# the following attributes should be initialized by the child class
self.val_table = None
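        # e.g., the child class is expected to create the table as
        #   self.val_table = wandb.Table(columns=['ID'] + [pair['type'] for pair in self.visual_pairs])
        # (see `log_val_visualization` below)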

    def volume2videos(self, time_dim=3, tag=''):
        """
        Convert 3D volumes into videos for WandB logging.
        Args:
            time_dim: the spatial dimension to use as the time dimension; defaults to the axial axis (dim 3)
            tag: extra information appended to the video caption
        """
videos = []
for image_pair in self.visual_pairs:
try:
pair_name = getattr(self, image_pair['name'])
image = getattr(self, image_pair['image'])
mask = getattr(self, image_pair['mask'])
vis_type = image_pair['type']
            except AttributeError:
                # skip pairs whose attributes are not (yet) set on the model
                continue
            for i in range(image.shape[0]):  # iterate over the batch dimension
image2save = image[i, ...]
mask2save = mask[i, ...]
item_name = pair_name[i]
                # detach from the autograd graph; still torch.Tensors of shape [C, H, W, D]
                image_numpy = image2save.detach()
                mask_numpy = mask2save.detach()
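                # collapse multi-channel (one-hot or per-class) masks into a single-channel label map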
if mask_numpy.shape[0] > 1:
mask_numpy = torch.argmax(mask_numpy, dim=0, keepdim=True)
                # blended overlay: (C, H, W, D) torch.Tensor on device, scaled to [0, 255] for video export
pair_blend = blend_images(image_numpy, mask_numpy, alpha=0.5) * 255
# permute the axes to (time, channel, height, width)
spatial_dim = list(range(1, len(pair_blend.shape[1:]) + 1))
spatial_dim.remove(time_dim)
pair_blend = pair_blend.permute([time_dim, 0] + spatial_dim).cpu().numpy().astype(np.uint8)
# record in the wandb.Video class
video = wandb.Video(pair_blend, fps=8, caption='{}_{}{}'.format(item_name, vis_type, tag))
videos.append(video)
return videos

    def log_scaler(self, key, value, step=None):
        """
        Log manually defined scalar data.
        """
wandb.log({key: np.round(value, decimals=4)}, step=step)

    def log_train_loss(self, step=None):
        """
        Log training losses.
        """
data_dict = self.train_loss.pop_data(True)
for key, value in data_dict.items():
wandb.log({'train/{}'.format(key): value}, step=step)

    def log_val_loss(self, step=None):
        """
        Log validation losses.
        """
data_dict = self.val_loss.pop_data(True)
for key, value in data_dict.items():
wandb.log({'val/{}'.format(key): value}, step=step)

    def log_metrics(self, step=None):
        """
        Log validation metrics as a wandb.Table.
        """
df = self.metric_meter.to_df()
wandb.log({'val/metrics': wandb.Table(dataframe=df)}, step=step)

    def log_vis(self, key, step=None, time_dim=3, tag=''):
        """
        Log intermediate training visualizations.
        """
videos = self.volume2videos(time_dim, tag)
wandb.log({key: videos}, step=step)

    def update_val_visualization(self, time_dim=3, tag=''):
        """
        Buffer validation visualizations into the table; called at every evaluation step.
        """
videos = self.volume2videos(time_dim, tag)
self.val_table.add_data(self.name, *videos)

    def log_val_visualization(self, step=None):
        """
        Log the buffered validation visualizations.
        """
wandb.log({'val/visualization': self.val_table}, step=step)
# re-initialize the table for next logging
del self.val_table
self.val_table = wandb.Table(columns=['ID'] + [pair['type'] for pair in self.visual_pairs])
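

# ------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). `BaseSegmentationModel`, the project
# name, and the names below are assumptions, not defined in this file; they only
# show how this mixin is meant to be combined with a concrete model class.
# ------------------------------------------------------------------------------
# class MySegModel(BaseSegmentationModel, WandBModel):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         # one table column per visual pair type, plus an ID column
#         self.val_table = wandb.Table(
#             columns=['ID'] + [pair['type'] for pair in self.visual_pairs])
#
# wandb.init(project='my_project')                        # hypothetical project
# model = MySegModel(...)
# model.log_train_loss(step=global_step)                  # logs 'train/<loss_name>' entries
# model.log_vis('train/visualization', step=global_step)  # logs blended volume videos
# model.log_val_visualization(step=global_step)           # logs and resets the val table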