-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathconfusionmatStats.m
57 lines (49 loc) · 2.16 KB
/
confusionmatStats.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
function stats = confusionmatStats(group,grouphat)
% CONFUSIONMATSTATS Per-class (one-vs-rest) statistics from a confusion matrix.
%
%   stats = confusionmatStats(C) computes the statistics from an existing
%   square confusion matrix C whose rows are actual classes and whose
%   columns are predicted classes.
%
%   stats = confusionmatStats(group,grouphat) first builds the confusion
%   matrix with confusionmat(group,grouphat) and additionally returns the
%   class order reported by confusionmat in stats.groupOrder.
%
% Confusion Matrix (one-vs-rest view for a single class)
%                     Predicted Classes
%                          p'    n'
%                    ___|_____|_____|
%       Actual     p |  TP | FN  |
%       Classes    n |  FP | TN  |
% TP: True Positive, FN: False Negative, FP: False Positive, TN: True Negative
%
% Returned scalar struct fields (metrics are column vectors, one row per class):
%   confusionMat - the confusion matrix the statistics were computed from
%   accuracy     = (TP + TN) / (TP + FP + FN + TN)
%   sensitivity  = TP / (TP + FN)
%   specificity  = TN / (FP + TN)
%   precision    = TP / (TP + FP)
%   recall       = sensitivity
%   Fscore       = 2*TP / (2*TP + FP + FN)
%   groupOrder   - class order from confusionmat (two-argument form only)
%
% Every numerator is part of its denominator, so a metric whose denominator
% is zero evaluates to NaN (0/0).
if nargin < 2
    confMat = group;
else
    [confMat,gorder] = confusionmat(group,grouphat);
end
if ndims(confMat) ~= 2 || size(confMat,1) ~= size(confMat,2)
    error('confusionmatStats:notSquare', ...
        'The confusion matrix must be a square 2-D matrix.');
end

% Compute in double so integer-typed matrices neither saturate when summed
% nor get rounded when divided; the matrix itself is returned unchanged.
C = double(confMat);
TP = reshape(diag(C), [], 1);   % samples of each class predicted correctly
FP = sum(C,1).' - TP;           % predicted as the class, actually another class
FN = sum(C,2) - TP;             % actually the class, predicted as another class
totalSamples = sum(C(:));
TN = totalSamples - TP - FP - FN; % everything outside the class's row and column

accuracy    = (TP + TN) ./ totalSamples;
sensitivity = TP ./ (TP + FN);
specificity = TN ./ (FP + TN);
precision   = TP ./ (TP + FP);
f_score     = 2*TP ./ (2*TP + FP + FN);

% Assign fields one by one rather than via struct(...): struct() expands a
% cell-array value (e.g. a cellstr group order) into a struct ARRAY.
stats.confusionMat = confMat;
stats.accuracy = accuracy;
stats.sensitivity = sensitivity;
stats.specificity = specificity;
stats.precision = precision;
stats.recall = sensitivity;
stats.Fscore = f_score;
if nargin >= 2
    stats.groupOrder = gorder;
end