% Solution to Exercise 18.6
% prunetree( Tree, PrunedTree): PrunedTree is optimally pruned Tree
% with respect to estimated classification error using Laplace estimate
% Assume trees are binary:
% Tree = leaf( Node, ClassFrequencyList), or
% Tree = tree( Root, LeftSubtree, RightSubtree)
% The error estimate and the root class frequencies computed by prune/4
% are not part of this predicate's interface, so they are discarded with
% anonymous variables (this also avoids singleton-variable warnings).
prunetree( Tree, PrunedTree) :-
    prune( Tree, PrunedTree, _, _).
% prune( Tree, PrunedTree, Error, FrequencyList):
%   PrunedTree is the optimally pruned version of Tree,
%   Error is the estimated classification error of PrunedTree,
%   FrequencyList is the list of class frequencies at the root of Tree.

% A leaf stays as it is; its error is the static error of its frequencies.
prune( leaf( Node, Freqs), leaf( Node, Freqs), Error, Freqs) :-
    static_error( Freqs, Error).

% An internal node: prune both subtrees first, then let decide/8 compare
% the static error at this node with the error backed up from the
% pruned subtrees.
prune( tree( Root, Left, Right), Pruned, Error, Freqs) :-
    prune( Left, PrunedLeft, ErrLeft, FreqsLeft),
    prune( Right, PrunedRight, ErrRight, FreqsRight),
    sumlists( FreqsLeft, FreqsRight, Freqs),      % Frequencies at this node
    static_error( Freqs, StaticErr),
    sum( FreqsLeft, CountLeft),                   % Examples falling left
    sum( FreqsRight, CountRight),                 % Examples falling right
    % Backed-up error: subtree errors weighted by their example counts
    WeightedErr is CountLeft*ErrLeft + CountRight*ErrRight,
    BackedUpErr is WeightedErr / ( CountLeft + CountRight),
    decide( StaticErr, BackedUpErr, Root, Freqs,
            PrunedLeft, PrunedRight, Error, Pruned).
% decide( StatErr, BackErr, Root, FreqL, Left, Right, Error, Tree):
%   Decide whether to prune the subtrees below Root.
%   StatErr is the static error at Root, BackErr the error backed up
%   from the (already pruned) subtrees Left and Right, FreqL the class
%   frequencies at Root.
%   If StatErr =< BackErr the subtrees are pruned:
%       Error = StatErr, Tree = leaf( Root, FreqL);
%   otherwise they are kept:
%       Error = BackErr, Tree = tree( Root, Left, Right).
%
% The comparison is made BEFORE the outputs are unified. With the outputs
% in the clause head and a cut in the body, a call whose Tree (or Error)
% argument was already bound could skip the "prune" clause by failing
% head unification and then wrongly succeed through the "do not prune"
% clause. Here such a call simply fails.
decide( StatErr, BackErr, Root, FreqL, Left, Right, Error, Tree) :-
    (   StatErr =< BackErr
    ->  Error = StatErr,                    % Static error not larger: prune
        Tree = leaf( Root, FreqL)
    ;   Error = BackErr,                    % Otherwise do not prune
        Tree = tree( Root, Left, Right)
    ).
% static_error( ClassFrequencyList, Error):
%   Error is the estimated classification error at a node whose class
%   frequencies are ClassFrequencyList, using the Laplace estimate
%       ( N - n + k - 1) / ( N + k)
%   where N is the sum of the frequencies, n the largest frequency and
%   k the number of classes given by number_of_classes/1.
static_error( Freqs, Error) :-
    max( Freqs, Majority),                  % Largest class frequency
    sum( Freqs, Total),                     % Number of all examples
    number_of_classes( K),
    Minority is Total - Majority,           % Examples outside majority class
    Error is ( Minority + K - 1) / ( Total + K).
% sum( Numbers, Sum): Sum is the sum of the numbers in list Numbers;
% the sum of the empty list is 0.
sum( [], 0).
sum( [Head | Tail], Total) :-
    sum( Tail, TailTotal),
    Total is Head + TailTotal.
% max( Numbers, Max): Max is the largest number in the non-empty list
% Numbers. The first two elements are compared, the smaller is dropped
% (on a tie the first is dropped), and the rest is processed recursively.
max( [Only], Only).
max( [First, Second | Rest], Max) :-
    (   First > Second
    ->  max( [First | Rest], Max)
    ;   max( [Second | Rest], Max)
    ).
% sumlists( List1, List2, SumList): SumList is the element-wise sum of
% the equally long number lists List1 and List2.
sumlists( [], [], []).
sumlists( [A | As], [B | Bs], [Total | Totals]) :-
    Total is A + B,
    sumlists( As, Bs, Totals).
% tree1( Tree): Tree is an example binary tree with root a; every leaf
% carries a list of two class frequencies.
tree1( tree( a, LeftSubtree, RightSubtree)) :-
    LeftSubtree  = tree( b,
                         leaf( e, [3,2]),
                         leaf( f, [1,0])),
    RightSubtree = tree( c,
                         tree( d,
                               leaf( g, [1,1]),
                               leaf( h, [0,1])),
                         leaf( i, [1,0])).
% number_of_classes( N): N is the number of classes; static_error/2
% reads it when computing the Laplace error estimate.
number_of_classes( 2).
% Test query: ?- tree1( Tree), prunetree( Tree, PrunedTree).