File size: 3,126 Bytes
fd4b932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image, ImageDraw
from scipy.ndimage import center_of_mass, label, sum as area


def nms_on_area(x, s):
    """Keep only the largest connected component of a binary image.

    Components are found with ``scipy.ndimage.label`` using the structuring
    element ``s``; the size of each one is the sum of ``x`` over it. When
    there is more than one component, every pixel outside the largest one
    is zeroed **in place** and ``x`` itself is returned. On a size tie the
    component with the highest label wins.

    Args:
        x: binary image (modified in place when it has >1 component).
        s: structuring element defining connectivity.

    Returns:
        ``x`` (the same object that was passed in).
    """
    labels, n_components = label(x, structure=s)
    if n_components <= 1:
        return x

    component_ids = np.arange(1, n_components + 1)
    sizes = np.asarray(area(x, labels, component_ids))
    # Last id among the maximal sizes -> same tie-break as max() over (size, id) tuples.
    keep = component_ids[np.flatnonzero(sizes == sizes.max())[-1]]
    x[labels != keep] = 0
    return x


def compute_metrics(p, thr=None, nms=False):
    """Compute the center of mass and area of a prediction map.

    Args:
        p: prediction map; singleton dimensions are squeezed away first.
        thr: optional threshold. When given (including ``0``) the map is
            binarized as ``p > thr`` before measuring. When ``None`` the raw
            values are used as weights.
        nms: if True and ``thr`` is given, keep only the largest connected
            component of the binarized map (full connectivity, i.e.
            diagonals included). Ignored when ``thr`` is ``None``.

    Returns:
        ``(center, area)``. ``center`` is the coordinate tuple returned by
        ``scipy.ndimage.center_of_mass``; scipy yields NaN coordinates for an
        all-zero map. ``area`` is the sum of the (possibly binarized) map,
        i.e. a pixel count when thresholded.
    """
    p = np.squeeze(p)

    # Compare against None so thr=0 still binarizes; a plain truthiness test
    # silently skipped a zero threshold (draw_predictions uses the same check).
    if thr is not None:
        p = p > thr
        if nms:  # non-maximum suppression: keep only the largest area
            s = np.ones((3,) * p.ndim)  # full connectivity for any rank
            p = nms_on_area(p, s)

    center = center_of_mass(p)
    total = p.sum()  # named to avoid shadowing the module-level `area` helper
    return center, total


def visualizable(x, y, alpha=(.5, .5), thr=0):
    """Blend a prediction map over a grayscale image for display.

    ``x`` is repeated three times along its last axis to form an RGB image.
    ``y`` is zero-padded along its last axis up to three channels. The
    padding channels are shaped like ``x``, so ``x`` is expected to carry a
    single trailing channel.

    Wherever the largest channel of the padded ``y`` exceeds ``thr``, the
    result is ``alpha[0] * rgb(x) + alpha[1] * padded(y)``. Elsewhere it is
    just ``rgb(x)``.
    """
    img_weight, pred_weight = alpha[0], alpha[1]
    rgb = np.concatenate([x, x, x], axis=-1)

    missing = 3 - y.shape[-1]
    filler = [np.zeros_like(x)] * missing  # empty when y already has >= 3 channels
    overlay = np.concatenate([y, *filler], axis=-1)

    has_prediction = overlay.max(axis=-1, keepdims=True) > thr
    blended = img_weight * rgb + pred_weight * overlay
    return np.where(has_prediction, blended, rgb)


def draw_predictions(image, predictions, thr=None):
    """Render prediction maps and eye/blink scores on top of a PIL image.

    Args:
        image: PIL image; converted to RGBA for compositing.
        predictions: ``(maps, tags)`` pair.
            ``maps`` is an ``(H, W, C)`` array with ``C <= 3``. It may carry
            a leading batch axis, in which case only the first item is
            drawn. Its values are assumed to lie in ``[0, 1]``.
            ``tags`` must squeeze to exactly two values, shown as eye and
            blink percentages.
        thr: optional threshold. If given, the overlay is opaque where any
            channel exceeds ``thr`` and transparent elsewhere. Otherwise the
            per-pixel channel maximum is used as a soft alpha.

    Returns:
        A new RGBA PIL image with the overlay and a text caption.
    """
    x = image.convert('RGBA')

    maps, tags = predictions
    maps = maps[0] if maps.ndim == 4 else maps
    eye, blink = tags.squeeze()
    alpha = maps.max(axis=-1, keepdims=True)
    alpha = alpha > thr if thr is not None else alpha

    # Pad to 3 colour channels. The padding must follow the array layout
    # (H, W); PIL's ``image.size`` is (W, H) and broke non-square images.
    n_pad = 3 - maps.shape[-1]
    zero_channels = np.zeros(maps.shape[:-1] + (n_pad,))
    y = np.concatenate((maps, zero_channels, alpha), axis=-1)  # RGB + alpha channel
    # Clip before the uint8 cast so out-of-range floats saturate instead of wrapping.
    y = (np.clip(y, 0, 1) * 255).astype(np.uint8)
    y = Image.fromarray(y).convert('RGBA')

    preview = Image.alpha_composite(x, y)
    draw = ImageDraw.Draw(preview)
    draw.text((5, 5), 'E: {: >3.1%}  B:{: >3.1%}'.format(eye, blink), fill=(0, 0, 255))

    return preview


def _tags_title(tags):
    """Return the caption format string for a panel with ``len(tags)`` scores."""
    if len(tags) == 2:
        return 'E: {:.1%} - B: {:.1%}'
    if len(tags) == 4:
        return 'pE: {:.1%} - pB: {:.1%}\ntE: {:.1%} - tB: {:.1%}'
    # Any other tag count: plain percentages rather than an unbound/stale title.
    return ' - '.join(['{:.1%}'] * len(tags))


def visualize(x, y, out=None, thr=0, n_cols=4, width=20):
    """Plot a grid of images with their prediction overlays and tag captions.

    Args:
        x: sequence of single-channel images, presumably each (H, W, 1).
        y: ``(masks, tags)`` pair aligned with ``x``.
            Each mask is overlaid with ``visualizable``.
            Each tag vector is printed under its panel: 2 or 4 scores get a
            labelled caption, and other lengths get plain percentages.
        out: if given, the figure is saved to this path and closed.
            Otherwise it is left open.
        thr: overlay threshold forwarded to ``visualizable``.
        n_cols: number of grid columns. Rows are added as needed so every
            image gets a panel.
        width: figure width in inches; the height keeps cells square.
    """
    # Ceil (not floor): floor crashed for len(x) < n_cols and dropped the last partial row.
    n_rows = max(1, (len(x) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(width, width * n_rows / n_cols),
                             squeeze=False)
    axes = axes.ravel()
    y_masks, y_tags = y

    n_used = 0
    for ax, (xi, yi_mask, yi_tags) in zip(axes, zip(x, y_masks, y_tags)):
        i = visualizable(xi, yi_mask, thr=thr)
        ax.imshow(i, cmap=plt.cm.gray)
        ax.grid(False)
        ax.text(x=0.5, y=-0.02, s=_tags_title(yi_tags).format(*yi_tags), transform=ax.transAxes,
                ha='center', va='top',
                fontsize=width * 4 / 5, fontfamily='monospace')
        ax.set_axis_off()
        n_used += 1

    # Hide the empty cells of a partially filled last row.
    for ax in axes[n_used:]:
        ax.set_axis_off()

    if out:
        # Act on the figure created here, not whatever pyplot considers "current".
        fig.savefig(out, bbox_inches='tight')
        plt.close(fig)