'''
@license: (C) Copyright 2021, Hey.
@author: Hey
@email: sanyuan.**@**.com
@tel: 137****6540
@datetime: 2022/11/26 21:05
@project: LucaOne
@file: metrics.py
@desc: metrics for binary classification, multi-class classification, and regression
'''
import csv
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 18})
plt.rcParams['axes.unicode_minus'] = False
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, \
    average_precision_score, confusion_matrix, mean_absolute_error, mean_squared_error, r2_score


def topk_accuracy_score(targets, probs, k=3):
    '''
    top-k accuracy: the fraction of samples whose true class is among the k highest-probability classes
    :param targets: 1d-array class index (n_samples, )
    :param probs: 2d-array probability (n_samples, m_classes)
    :param k: number of highest-probability classes to consider
    :return: top-k accuracy score
    '''
    # indices of the k largest probabilities per sample, highest first
    max_k_preds = probs.argsort(axis=1)[:, -k:][:, ::-1]
    a_real = np.resize(targets, (targets.shape[0], 1))
    # a sample counts as a hit if its true class appears anywhere in the top k
    match_array = np.logical_or.reduce(max_k_preds == a_real, axis=1)
    topk_acc_score = match_array.sum() / match_array.shape[0]
    return topk_acc_score
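
# A doctest-style illustration of topk_accuracy_score (illustrative values, not from
# the project): for targets [1, 0] the true class of sample 0 is only the 2nd-highest
# probability, so top-1 accuracy is 0.5 while top-2 accuracy is 1.0.
# >>> topk_accuracy_score(np.array([1, 0]), np.array([[0.6, 0.4], [0.7, 0.3]]), k=1)
# 0.5
# >>> topk_accuracy_score(np.array([1, 0]), np.array([[0.6, 0.4], [0.7, 0.3]]), k=2)
# 1.0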


def multi_class_acc(targets, probs, threshold=0.5):
    # threshold is unused here; it is kept so the metric functions share a uniform signature
    if targets.ndim == 2:
        targets = np.argmax(targets, axis=1)
    preds = np.argmax(probs, axis=1)
    return accuracy_score(targets, preds)


def multi_class_precision(targets, probs, average='macro'):
    if targets.ndim == 2:
        targets = np.argmax(targets, axis=1)
    preds = np.argmax(probs, axis=1)
    return precision_score(targets, preds, average=average)


def multi_class_recall(targets, probs, average='macro'):
    if targets.ndim == 2:
        targets = np.argmax(targets, axis=1)
    preds = np.argmax(probs, axis=1)
    return recall_score(targets, preds, average=average)


def multi_class_f1(targets, probs, average='macro'):
    if targets.ndim == 2:
        targets = np.argmax(targets, axis=1)
    preds = np.argmax(probs, axis=1)
    return f1_score(targets, preds, average=average)


def multi_class_roc_auc(targets, probs, average='macro'):
    if targets.ndim == 2:
        targets = np.argmax(targets, axis=1)
    return roc_auc_score(targets, probs, average=average, multi_class='ovr')


def multi_class_pr_auc(targets, probs, average='macro'):
    if targets.ndim == 2:
        targets = np.argmax(targets, axis=1)
    z = probs.shape[1]
    # one-hot encode the integer targets so PR-AUC can be averaged over classes
    new_targets = np.eye(z)[targets]
    pr_auc = average_precision_score(new_targets, probs, average=average)
    return pr_auc
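
# The np.eye trick used above one-hot encodes integer class indices, e.g.:
# >>> np.eye(3)[np.array([0, 2])]
# array([[1., 0., 0.],
#        [0., 0., 1.]])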


def metrics_multi_class(targets, probs, average="macro"):
    '''
    metrics of multi-class classification
    :param targets: 1d-array class index (n_samples, ) or 2d-array one-hot (n_samples, m_classes)
    :param probs: 2d-array probability (n_samples, m_classes)
    :param average: averaging strategy for precision/recall/f1
    :return: dict of metric name -> value
    '''
    if targets.ndim == 2 and targets.shape[1] > 1:
        targets = np.argmax(targets, axis=1)
    elif targets.ndim == 2 and targets.shape[1] == 1:
        targets = np.squeeze(targets, axis=1)

    preds = np.argmax(probs, axis=1)
    acc = accuracy_score(targets, preds)
    prec = precision_score(targets, preds, average=average)
    recall = recall_score(targets, preds, average=average)
    f1 = f1_score(targets, preds, average=average)
    result = {
        "acc": round(float(acc), 6),
        "prec": round(float(prec), 6),
        "recall": round(float(recall), 6),
        "f1": round(float(f1), 6)
    }
    # top-k accuracy; if k exceeds the number of classes, the score is trivially 1.0
    result.update({
        "top2_acc": round(float(topk_accuracy_score(targets, probs, k=2)), 6),
        "top3_acc": round(float(topk_accuracy_score(targets, probs, k=3)), 6),
        "top5_acc": round(float(topk_accuracy_score(targets, probs, k=5)), 6),
        "top10_acc": round(float(topk_accuracy_score(targets, probs, k=10)), 6)
    })
    try:
        roc_auc = roc_auc_score(targets, probs, average=average, multi_class='ovr')
        result.update({
            "roc_auc": round(float(roc_auc), 6)
        })
    except Exception:
        # roc_auc_score raises when some classes are absent from targets
        pass
    try:
        z = probs.shape[1]
        # one-hot encode the integer targets for the per-class PR-AUC
        new_targets = np.eye(z)[targets]
        pr_auc = average_precision_score(new_targets, probs, average=average)
        result.update({
            "pr_auc": round(float(pr_auc), 6)
        })
    except Exception:
        pass
    return result


def metrics_multi_class_for_pred(targets, preds, probs=None, average="macro", savepath=None):
    '''
    metrics of multi-class classification from hard predictions
    :param targets: 1d-array class index (n_samples, ) or 2d-array one-hot (n_samples, m_classes)
    :param preds: 1d-array class index (n_samples, )
    :param probs: optional 2d-array probability (n_samples, m_classes), required for roc_auc/pr_auc
    :param average: averaging strategy for precision/recall/f1
    :param savepath: unused; kept for signature compatibility with metrics_binary_for_pred
    :return: dict of metric name -> value
    '''
    if targets.ndim == 2 and targets.shape[1] > 1:
        targets = np.argmax(targets, axis=1)
    elif targets.ndim == 2 and targets.shape[1] == 1:
        targets = np.squeeze(targets, axis=1)

    acc = accuracy_score(targets, preds)
    prec = precision_score(targets, preds, average=average)
    recall = recall_score(targets, preds, average=average)
    f1 = f1_score(y_true=targets, y_pred=preds, average=average)
    result = {
        "acc": round(float(acc), 6),
        "prec": round(float(prec), 6),
        "recall": round(float(recall), 6),
        "f1": round(float(f1), 6)
    }
    try:
        roc_auc = roc_auc_score(targets, probs, average=average, multi_class='ovr')
        result.update({
            "roc_auc": round(float(roc_auc), 6)
        })
    except Exception:
        # raised when probs is None or some classes are absent from targets
        pass
    try:
        z = probs.shape[1]
        # one-hot encode the integer targets for the per-class PR-AUC
        new_targets = np.eye(z)[targets]
        pr_auc = average_precision_score(new_targets, probs, average=average)
        result.update({
            "pr_auc": round(float(pr_auc), 6)
        })
    except Exception:
        pass
    return result


def metrics_regression(targets, preds):
    '''
    metrics of regression
    :param targets: 1d-array true values (n_samples, )
    :param preds: 1d-array predicted values (n_samples, )
    :return: dict of metric name -> value
    '''
    mae = mean_absolute_error(targets, preds)
    mse = mean_squared_error(targets, preds)
    r2 = r2_score(targets, preds)
    return {
        "mae": round(float(mae), 6),
        "mse": round(float(mse), 6),
        "r2": round(float(r2), 6)
    }
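
# A quick sanity check for metrics_regression (illustrative values, not from the project):
# only the last prediction is off by 1, so mae = mse = 1/3 and r2 = 1 - (1 / 2) = 0.5.
# >>> metrics_regression(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 4.0]))
# {'mae': 0.333333, 'mse': 0.333333, 'r2': 0.5}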


def transform(targets, probs, threshold):
    '''
    normalize the targets/probs of binary classification into 1d-arrays and derive hard predictions
    :param targets: 1d-array class index (n_samples, ) or 2d-array (n_samples, 1 or 2)
    :param probs: 1d-array positive-class probability (n_samples, ) or 2d-array (n_samples, 1 or 2)
    :param threshold: 0-1 probability threshold for the positive class
    :return: (targets, probs, preds), each a 1d-array
    '''
    if targets.ndim == 2:
        if targets.shape[1] == 2:
            # one-hot targets: take the class index
            targets = np.argmax(targets, axis=1)
        else:
            targets = targets.flatten()
    if probs.ndim == 2:
        if probs.shape[1] == 2:
            # two-column probabilities: the prediction is the argmax,
            # and the positive-class probability is the second column
            preds = np.argmax(probs, axis=1)
            probs = probs[:, 1].flatten()
        else:
            preds = (probs >= threshold).astype(int).flatten()
            probs = probs.flatten()
    else:
        preds = (probs >= threshold).astype(int)
    return targets, probs, preds
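
# transform accepts the three probability layouts used across this file
# (illustrative values, not from the project):
# >>> transform(np.array([0, 1]), np.array([0.2, 0.8]), 0.5)                # 1d probs
# (array([0, 1]), array([0.2, 0.8]), array([0, 1]))
# >>> transform(np.array([0, 1]), np.array([[0.2], [0.8]]), 0.5)            # (n, 1) probs
# (array([0, 1]), array([0.2, 0.8]), array([0, 1]))
# >>> transform(np.array([0, 1]), np.array([[0.8, 0.2], [0.2, 0.8]]), 0.5)  # (n, 2) probs
# (array([0, 1]), array([0.2, 0.8]), array([0, 1]))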


def binary_acc(targets, probs, threshold=0.5):
    targets, probs, preds = transform(targets, probs, threshold)
    return accuracy_score(targets, preds)


def binary_precision(targets, probs, threshold=0.5, average='binary'):
    targets, probs, preds = transform(targets, probs, threshold)
    return precision_score(targets, preds, average=average)


def binary_recall(targets, probs, threshold=0.5, average='binary'):
    targets, probs, preds = transform(targets, probs, threshold)
    return recall_score(targets, preds, average=average)


def binary_f1(targets, probs, threshold=0.5, average='binary'):
    targets, probs, preds = transform(targets, probs, threshold)
    return f1_score(targets, preds, average=average)


def binary_roc_auc(targets, probs, threshold=0.5, average='macro'):
    # ROC-AUC is threshold-free; threshold is only used to normalize the inputs
    targets, probs, _ = transform(targets, probs, threshold)
    return roc_auc_score(targets, probs, average=average)


def binary_pr_auc(targets, probs, threshold=0.5, average='macro'):
    targets, probs, _ = transform(targets, probs, threshold)
    return average_precision_score(targets, probs, average=average)


def binary_confusion_matrix(targets, probs, threshold=0.5, savepath=None):
    targets, probs, preds = transform(targets, probs, threshold)
    cm_obj = confusion_matrix(targets, preds, labels=[0, 1])
    plot_confusion_matrix_for_binary_class(targets, preds, cm=cm_obj, savepath=savepath)
    tn, fp, fn, tp = cm_obj.ravel()
    cm = {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
    return cm
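
# binary_confusion_matrix returns the four cell counts as a dict and also renders
# the confusion-matrix plot as a side effect (illustrative values, not from the project):
# >>> binary_confusion_matrix(np.array([0, 0, 1, 1]), np.array([0.1, 0.9, 0.2, 0.8]))
# {'tn': 1, 'fp': 1, 'fn': 1, 'tp': 1}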


def metrics_binary(targets, probs, threshold=0.5, average="binary", savepath=None):
    '''
    metrics for binary classification
    :param targets: 1d-array class index (n_samples, )
    :param probs: 1d-array positive-class probability (n_samples, )
    :param threshold: 0-1 probability threshold
    :param average: averaging strategy for precision/recall/f1
    :param savepath: filepath to save the confusion-matrix plot; shown interactively if None
    :return: dict of metric name -> value
    '''
    targets, probs, preds = transform(targets, probs, threshold)
    acc = accuracy_score(targets, preds)
    prec = precision_score(targets, preds, average=average)
    recall = recall_score(targets, preds, average=average)
    f1 = f1_score(targets, preds, average=average)
    result = {
        "acc": round(float(acc), 6),
        "prec": round(float(prec), 6),
        "recall": round(float(recall), 6),
        "f1": round(float(f1), 6)
    }
    try:
        roc_auc = roc_auc_score(targets, probs, average="macro")
        result.update({
            "roc_auc": round(float(roc_auc), 6)
        })
    except Exception:
        # raised when targets contain a single class
        pass
    try:
        pr_auc = average_precision_score(targets, probs, average="macro")
        result.update({
            "pr_auc": round(float(pr_auc), 6)
        })
    except Exception:
        pass
    try:
        cm_obj = confusion_matrix(targets, preds, labels=[0, 1])
        plot_confusion_matrix_for_binary_class(targets, preds, cm=cm_obj, savepath=savepath)
        tn, fp, fn, tp = cm_obj.ravel()
        cm = {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
        result.update({
            "confusion_matrix": cm
        })
    except Exception:
        pass

    try:
        # Matthews correlation coefficient: (tn*tp - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn));
        # the except also covers a zero denominator and a missing cm from the block above
        tn, fp, fn, tp = cm["tn"], cm["fp"], cm["fn"], cm["tp"]
        mcc = (tn * tp - fp * fn) / (((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5)
        result.update({
            "mcc": round(float(mcc), 6)
        })
    except Exception:
        pass
    return result


def metrics_binary_for_pred(targets, preds, probs=None, average="binary", savepath=None):
    '''
    metrics for binary classification from hard predictions
    :param targets: 1d-array class index (n_samples, ) or 2d-array (n_samples, 1 or 2)
    :param preds: 1d-array class index (n_samples, ) or 2d-array (n_samples, 1 or 2)
    :param probs: optional 1d-array positive-class probability, required for roc_auc/pr_auc
    :param average: averaging strategy for precision/recall/f1
    :param savepath: filepath to save the confusion-matrix plot; shown interactively if None
    :return: dict of metric name -> value
    '''
    if targets.ndim == 2:
        if targets.shape[1] == 2:
            targets = np.argmax(targets, axis=1)
        else:
            targets = targets.flatten()
    if preds.ndim == 2:
        if preds.shape[1] == 2:
            preds = np.argmax(preds, axis=1)
        else:
            preds = preds.flatten()
    cm_obj = confusion_matrix(targets, preds, labels=[0, 1])
    plot_confusion_matrix_for_binary_class(targets, preds, cm=cm_obj, savepath=savepath)
    tn, fp, fn, tp = cm_obj.ravel()
    cm = {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
    if len(np.unique(targets)) > 1:
        acc = accuracy_score(targets, preds)
        prec = precision_score(targets, preds, average=average)
        recall = recall_score(targets, preds, average=average)
        f1 = f1_score(y_true=targets, y_pred=preds, average=average)
        result = {
            "acc": round(float(acc), 6),
            "prec": round(float(prec), 6),
            "recall": round(float(recall), 6),
            "f1": round(float(f1), 6)
        }
    else:
        # targets contain a single class, so compute the metrics from the confusion matrix directly
        result = {
            "acc": round(float((cm["tp"] + cm["tn"]) / (cm["tp"] + cm["tn"] + cm["fp"] + cm["fn"])), 6),
            "prec": round(float(cm["tp"] / (cm["tp"] + cm["fp"]) if cm["tp"] + cm["fp"] > 0 else 1.0), 6),
            "recall": round(float(cm["tp"] / (cm["tp"] + cm["fn"]) if cm["tp"] + cm["fn"] > 0 else 1.0), 6)
        }
        # guard against a zero denominator when both precision and recall are 0
        denominator = result["prec"] + result["recall"]
        result["f1"] = round(2 * result["prec"] * result["recall"] / denominator, 6) if denominator > 0 else 0.0

    try:
        pr_auc = average_precision_score(targets, probs, average="macro")
        result.update({
            "pr_auc": round(float(pr_auc), 6)
        })
    except Exception:
        # raised when probs is None or targets contain a single class
        pass
    try:
        roc_auc = roc_auc_score(targets, probs, average="macro")
        result.update({
            "roc_auc": round(float(roc_auc), 6)
        })
    except Exception:
        pass
    try:
        # Matthews correlation coefficient; the except also covers a zero denominator
        tn, fp, fn, tp = cm["tn"], cm["fp"], cm["fn"], cm["tp"]
        mcc = (tn * tp - fp * fn) / (((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) ** 0.5)
        result.update({
            "mcc": round(float(mcc), 6)
        })
    except Exception:
        pass
    result.update({
        "confusion_matrix": cm
    })
    return result


def write_error_samples_multi_class(filepath, samples, input_indexs, input_id_2_names, output_id_2_name, targets, probs,
                                    use_other_diags=False, use_other_operas=False, use_checkin_department=False):
    '''
    write the bad cases of multi-class classification
    :param filepath: csv filepath to write
    :param samples: the raw input samples
    :param input_indexs: column indices of the inputs to decode
    :param input_id_2_names: per-input mapping from id to name
    :param output_id_2_name: mapping from class id to class name
    :param targets: 2d-array one-hot ground truth (n_samples, m_classes)
    :param probs: 2d-array probability (n_samples, m_classes)
    :param use_other_diags:
    :param use_other_operas:
    :param use_checkin_department:
    :return:
    '''
    targets = np.argmax(targets, axis=1)
    preds = np.argmax(probs, axis=1)
    # newline='' prevents the csv module from writing blank rows on Windows
    with open(filepath, "w", newline='') as fp:
        writer = csv.writer(fp)
        writer.writerow(["score", "y_true", "y_pred", "inputs"])
        for i in range(len(targets)):
            target = targets[i]
            pred = preds[i]
            # score marks whether the prediction is correct (1) or a bad case (0)
            score = 1 if target == pred else 0
            if output_id_2_name:
                target_label = output_id_2_name[target]
                pred_label = output_id_2_name[pred]
            else:
                target_label = target
                pred_label = pred
            sample = samples[i]
            if input_id_2_names:
                new_sample = []
                for idx, input_index in enumerate(input_indexs):
                    # remap the input column when the check-in department is not used
                    if input_index == 3 and not use_checkin_department:
                        input_index = 12
                    new_sample.append([input_id_2_names[idx][v] for v in sample[input_index]])
                    if (input_index == 6 and use_other_diags) or (input_index == 8 and use_other_operas) or (input_index == 10 and use_other_diags):
                        new_sample.append([input_id_2_names[idx][v] for v in sample[input_index + 1]])
            else:
                new_sample = sample
            row = [score, target_label, pred_label, new_sample]
            writer.writerow(row)


def write_error_samples_binary(filepath, samples, input_indexs, input_id_2_names, targets, probs, threshold=0.5,
                               use_other_diags=False, use_other_operas=False, use_checkin_department=False):
    '''
    write the bad cases of binary classification
    :param filepath: csv filepath to write
    :param samples: the raw input samples
    :param input_indexs: column indices of the inputs to decode
    :param input_id_2_names: per-input mapping from id to name
    :param targets: 2d-array ground truth (n_samples, 1)
    :param probs: 2d-array positive-class probability (n_samples, 1)
    :param threshold: 0-1 probability threshold
    :param use_other_diags:
    :param use_other_operas:
    :param use_checkin_department:
    :return:
    '''
    # newline='' prevents the csv module from writing blank rows on Windows
    with open(filepath, "w", newline='') as fp:
        writer = csv.writer(fp)
        writer.writerow(["score", "y_true", "y_pred", "inputs"])
        for i in range(len(targets)):
            # binarize the target: any non-1 value counts as the negative class
            target = targets[i][0]
            if target != 1:
                target = 0
            prob = probs[i][0]
            pred = 1 if prob >= threshold else 0
            # score marks whether the prediction is correct (1) or a bad case (0)
            score = 1 if target == pred else 0
            target_label = "True" if target == 1 else "False"
            pred_label = "True" if pred == 1 else "False"
            sample = samples[i]
            if input_id_2_names:
                new_sample = []
                for idx, input_index in enumerate(input_indexs):
                    # remap the input column when the check-in department is not used
                    if input_index == 3 and not use_checkin_department:
                        input_index = 12
                    new_sample.append([input_id_2_names[idx][v] for v in sample[input_index]])
                    if (input_index == 6 and use_other_diags) or (input_index == 8 and use_other_operas) or (input_index == 10 and use_other_diags):
                        new_sample.append([input_id_2_names[idx][v] for v in sample[input_index + 1]])
            else:
                new_sample = sample
            row = [score, target_label, pred_label, new_sample]
            writer.writerow(row)


def plot_confusion_matrix_for_binary_class(targets, preds, cm=None, savepath=None):
    '''
    plot the confusion matrix of binary classification
    :param targets: 1d-array ground truth class index
    :param preds: 1d-array predicted class index
    :param cm: precomputed confusion matrix; computed from targets/preds if None
    :param savepath: filepath to save the confusion matrix picture; shown interactively if None
    '''
    plt.figure(figsize=(40, 20), dpi=100)
    if cm is None:
        cm = confusion_matrix(targets, preds, labels=[0, 1])

    plt.matshow(cm, cmap=plt.cm.Oranges)
    plt.colorbar()

    # annotate every cell with its count
    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x, y], xy=(y, x), verticalalignment='center', horizontalalignment='center')
    plt.ylabel('True')
    plt.xlabel('Prediction')
    if savepath:
        plt.savefig(savepath, dpi=100)
    else:
        plt.show()
    plt.close("all")


if __name__ == "__main__":
    '''multi_class'''
    targets = np.array([0, 1, 2, 1, 3])
    probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.55, 0.25, 0.1], [0.4, 0.25, 0.35, 0]])
    print(metrics_multi_class(targets, probs))

    targets = np.array([0, 1, 2, 3, 3])
    probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.25, 0.25, 0.4], [0.1, 0.25, 0.25, 0.4]])
    print(metrics_multi_class(targets, probs))

    targets = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1]])
    probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.25, 0.25, 0.4], [0.1, 0.25, 0.25, 0.4]])
    print(metrics_multi_class(targets, probs))

    '''binary'''
    targets = np.array([0, 0, 1, 1])
    probs = np.array([[0.1], [0.1], [0.1], [0.9]])
    print(metrics_binary(targets, probs))

    targets = np.array([[0], [0], [1], [1]])
    probs = np.array([[0.1], [0.1], [0.1], [0.9]])
    print(metrics_binary(targets, probs))

    targets = np.array([0, 0, 1, 1])
    probs = np.array([[0.1, 0.1, 0.1, 0.9]])
    print(metrics_binary(targets, probs))

    targets = np.array([0, 0, 1, 1])
    probs = np.array([0.1, 0.1, 0.1, 0.9])
    print(metrics_binary(targets, probs))

    '''one-hot encoding via np.eye'''
    targets = np.array([0, 1, 2, 1, 3])
    probs = np.array([[0.9, 0.05, 0.05, 0], [0.5, 0.45, 0.05, 0], [0.4, 0.05, 0.55, 0], [0.1, 0.55, 0.25, 0.1], [0.4, 0.25, 0.25, 0.1]])
    z = probs.shape[1]

    print(np.eye(z))
    new_targets = np.eye(z)[targets]
    print(new_targets)
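
    # Extra quick demos of the remaining entry points (illustrative values, not from the project):
    '''regression'''
    targets = np.array([1.0, 2.0, 3.0])
    preds = np.array([1.1, 1.9, 3.2])
    print(metrics_regression(targets, preds))

    '''binary, from hard predictions'''
    targets = np.array([0, 0, 1, 1])
    preds = np.array([0, 1, 0, 1])
    probs = np.array([0.1, 0.9, 0.2, 0.8])
    print(metrics_binary_for_pred(targets, preds, probs=probs))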