package MachineLearning;
import java.util.*;
/**
* Klasa drzewa decyzyjnego. Zbudeje drzewo na podstawie przykładów uczących
* @author Angello
*
*/
public class DecisionTree
{
private TermostatLeaf m_rootNode = null;
private Object[][] m_learingTable = null; // zbiór uczący
/**
* Zawiera w tablicy nazwy atrybutów konieczne do rozszyfrowania, które Testy jeszcze zostały
*/
private String[] m_testNameList = null;
/**
* Indeks pod jakim można w zbiorze uczącym znaleźć przydzieloną kategorię
*/
private Integer m_categoryIndex;
private AttrChangeCollection m_attrChangeCollection;
/**
* Próg entropii używany w agregacji atrybutów porządkowych
*/
private double m_entropyTreshhold = 0.091;
/**
* Domyślny konstruktor
*/
public DecisionTree()
{
m_attrChangeCollection = new AttrChangeCollection();
}
/**
* Buduje drzewo decyzyjne
* @param learningTable Zbiór uczący
* @param attributesNames Nazwy kolejnych atrybutów
*/
public void initialize(Object[][] learningTable, String[] attributesNames, boolean c45)
{
ArrayList<Integer> tresholds;
ArrayList<Integer> S;
m_learingTable = learningTable; // Przypisanie zbioru uczącego
m_categoryIndex = learningTable[0].length - 1;
ArrayList<Integer> P = new ArrayList<Integer>();
for (int i = 0; i < m_learingTable.length; i++) // Zamiana zbioru na indeksy
P.add(i);
if(c45)
{
for (int i = 0; i < m_categoryIndex; i++) // Próba zagregowania każdego atrybutu
{
tresholds = discreteDown(P, i);
m_attrChangeCollection.Add(tresholds, i);
}
m_attrChangeCollection.Change(m_learingTable);
}
m_testNameList = attributesNames;
S = new ArrayList<Integer>();
for (int i = 0; i < m_testNameList.length; i++) // Testy tożsame z nowymi atrybutami
S.add(i);
Integer d = chooseCategory(P);
m_rootNode = buildTree(P, d, S);
ShowedNode showedNode = m_rootNode.showTree();
WriteFile sw = new WriteFile();
sw.open("Graph_file.txt");
sw.writeLine("digraph G {"); //zapisuje do pliku tekst
showedNode.getGraphvizName(sw);
sw.close();
sw.writeLine("}"); //zapisuje do pliku nastepny tekst
sw.close(); // zamyka plik
m_attrChangeCollection.WriteAggregateValues();
}
/**
* Funcja zwraca nowo stworzony węzęł lub liść(jeśli nie ma sensu dla węzła)
* @param P Zbiór uczący
* @param d Domyślna etykieta
* @param S Dostępne testy
* @return Liść lub węzeł
*/
private TermostatLeaf buildTree(ArrayList<Integer> P, Integer d, ArrayList<Integer> S)
{
if (P.size() == 0) // kryterium Stopu
return new TermostatLeaf(d, P); // Zostaje kategoria domyślna
TermostatLeaf tempLeaf = onlyOneCategory(P); // Jedyna kategoria
if (tempLeaf != null) // kryterium Stopu
return tempLeaf;
if (S.size() == 0) // kryterium Stopu
return new TermostatLeaf(chooseCategory(P), P); // Najczęściej pojawiająca się kategoria
Integer testIndex = chooseTest(P, S);
NominalTest t = new NominalTest(m_testNameList[S.get(testIndex)], S.get(testIndex)); // Określamy który atrybut jest dla tego testu
ArrayList<Integer> STemp = new ArrayList<Integer>(S);
STemp.remove(testIndex); // Wyrzucenie już użytego testu
for (Integer index : P)
{
t.AddResult((Integer)m_learingTable[index][t.AttributeIndex]);
}
d = chooseCategory(P);
TermostatNode node = new TermostatNode(t, d, P);
for (Integer r : t.Results())
{
ArrayList<Integer> PTemp = new ArrayList<Integer>(); // Utworzenie nowej listy tylko z tych danych które prowadzą do 'r'
for (Integer indexP : P) // Dla każdego przykładu etykietowanego
{
if ((Integer)m_learingTable[indexP][t.AttributeIndex] == r) // Czy atrybut przykładu zgadza się z naszym rezulatatem 'r'
{
PTemp.add(indexP); // Dodanie do P które przekażemy do niższego węzła
}
}
node.addNode(r, buildTree(PTemp, d, STemp)); // Najpierw zostanie zbudowany nowy węzeł dopiero później nie zostanie nigdzie dodany
}
return node;
}
/**
* Funkcja zakwalifikuje podany ciąg atrybutów do jednej kategorii
* @param attributes Zbiór atrybutów
* @return Kategoria
*/
public Integer findCategory(Object[] attributes)
{
m_attrChangeCollection.Change(attributes);
return m_rootNode.getCategory(attributes);
}
/**
* Wybiera najbardziej liczną kategorię w podanym zbiorze uczącym
* @param P Zbiór uczący
* @return Kategoria
*/
private Integer chooseCategory(ArrayList<Integer> P)
{
HashMap<Integer, Integer> categoryCounter = new HashMap<Integer, Integer>();
int category = 0;
for (Integer var : P)
{
category = (Integer)m_learingTable[var][m_categoryIndex]; // Kategoria tych uczących które jeszcze zostały w indexList
if (categoryCounter.containsKey(category))
{
categoryCounter.put(category, categoryCounter.get(category) + 1);
}
else
{
categoryCounter.put(category, 1);
}
}
Integer biggerAmount = 0;
for (Integer key : categoryCounter.keySet())
{
if (biggerAmount < categoryCounter.get(key))
{
category = key;
}
}
return category;
}
/**
* Wybiera najlepszy test na podstawie najmniejszej entropii
* @param P Zbiór uczący
* @param S Zbiór testów
* @return Wybrany test
*/
private Integer chooseTest(ArrayList<Integer> P, ArrayList<Integer> S)
{
if (S.size() > 0)
{
NominalTest test = null;
double minEntropy = Double.POSITIVE_INFINITY;
Integer bestIndex = 0;
for (Integer i = 0; i < S.size(); i++)
{
test = new NominalTest(m_testNameList[S.get(i)], S.get(i)); // Określamy który atrybut jest dla tego testu
double testEntropy = countTestEntropy(P, test);
if (testEntropy < minEntropy)
{
minEntropy = testEntropy;
bestIndex = i;
}
}
return bestIndex;
}
else
{
throw new RuntimeException("Brak sprawdzania czy testy sie nie skończyły");
}
}
/**
* Sprawdza czy w zbiorze uczącym jest tylko jedna kategoria
* @param P Zbiór uczący
* @return Liść lub null w przypadku braku pożadanego węzła
*/
private TermostatLeaf onlyOneCategory(ArrayList<Integer> P)
{
if (P.size() > 0)
{
boolean repeated = false;
Integer category = (Integer)m_learingTable[P.get(0)][m_categoryIndex];
for (Integer i = 1; i < P.size(); i++)
{
if (category != (Integer)m_learingTable[P.get(i)][m_categoryIndex]) // Jeśli któryś jest inny niż pierwszy
repeated = true;
}
if (!repeated) // Jeśli się nie powtórzył
{
return new TermostatLeaf(category, P); // Zwraca jedyną kategorię
}
return null;
}
else
{
throw new RuntimeException("Brak sprawdzania czy zbiór uczący nie jest pusty");
}
}
/**
* Wylicza entropię danego przykładu trenującego dla danego testu
* @param P Zbiór uczący
* @param test Test
* @return Entropia
*/
private double countTestEntropy(ArrayList<Integer> P, NominalTest test)
{
double allSum = 0;
Integer attrIndex = test.AttributeIndex;
ArrayList<Integer> result = new ArrayList<Integer>();
ArrayList<Integer> temp = new ArrayList<Integer>();
for (Integer indexP : P)
if (!result.contains((Integer)m_learingTable[indexP][attrIndex]))
result.add((Integer)m_learingTable[indexP][attrIndex]); // Pozbieranie możliwych rezultatów
for (Integer r : result) // Dla kolejnych rezultatów testu
{
temp.clear();
for (Integer indexP : P)
{
if ((Integer)m_learingTable[indexP][attrIndex] == r) // Czy przykład należy do tego rezultatu
{
temp.add(indexP);
}
}
allSum += countEntropy(temp) * temp.size() / P.size();; // Entropia Et(P) czyli suma po r |Ptr|/|P| * Etr(P)
}
return allSum;
}
/**
* Rekurencyjnie wywoływana funkcja dyskretyzacji-zstępującej
* @param P Zbiór uczący
* @param attrIndex Indeks opisujący atrybut
* @return Lista punktów dzielących wartości
*/
private ArrayList<Integer> discreteDown(ArrayList<Integer> P, Integer attrIndex)
{
if (onlyOneValue(P,attrIndex)) // Kryterium Stopu. Jeśli została tylko jedna wartość atrybutu w zbiorze uczącym
return new ArrayList<Integer>();
Integer treshold;
ArrayList<Integer> P0 = new ArrayList<Integer>(P);
ArrayList<Integer> P1 = new ArrayList<Integer>();
treshold = chooseTreshold(P, attrIndex); // Wybranie progu
splitList(P, attrIndex, treshold, P0, P1);
double information = countEntropy(P);
double entropy = (countEntropy(P0) * P0.size() + countEntropy(P1) * P1.size()) / P.size();
if (information - entropy <= m_entropyTreshhold) // Kryterium Stopu bazujące na wzrastaniu informacji
return new ArrayList<Integer>();
ArrayList<Integer> fi0 = discreteDown(P0, attrIndex); // Dzielenie zbioru < treshold
ArrayList<Integer> fi1 = discreteDown(P1, attrIndex); // Dzielenie zbioru >= treshold
ArrayList<Integer> ret = new ArrayList<Integer>(fi0);
ret.add(treshold);
ret.addAll(fi1);
return ret;
}
/**
* Wylicza entropię
* @param top
* @param bottom
* @return
*/
private double countEntropy(Integer top, Integer bottom)
{
double temp = (double)top / (double)bottom;
return - temp * Math.log(temp);
}
/**
* Oblicza entropię Etr(P) czyli sumuje entropie po wszystkiech kategoriach pojęcia docelowego
* @param P Zbiór uczący
* @return
*/
private double countEntropy(ArrayList<Integer> P)
{
Integer key;
double sum = 0;
HashMap<Integer, Integer> dict = new HashMap<Integer,Integer>(); // Kategoria -> ilość wystąpień Kategorii
for (Integer indexP : P)
{
key = (Integer)m_learingTable[indexP][m_categoryIndex];
dict.put(key,dict.containsKey(key) ? dict.get(key) + 1 : 1) ; // Zliczanie jak liczne są poszczególne kategorie
}
sum = 0;
for (Integer count : dict.values())
{
sum += countEntropy(count, P.size()); // Wyliczanie entropii Etr(P)
}
return sum;
}
/**
* Wybiera próg dzielący wartości
* @param P Zbiór uczący
* @param attrIndex Indeks opisujący atrybut
* @return Próg
*/
private Integer chooseTreshold(ArrayList<Integer> P, Integer attrIndex)
{
if (P.size() > 0)
{
double entropy;
double minEntropy = Double.POSITIVE_INFINITY;
Integer bestTreshold = 0;
ArrayList<Integer> P0;
ArrayList<Integer> P1;
ArrayList<Integer> tresholdList = new ArrayList<Integer>();
for (Integer indexP : P)
{
Integer tempa = (Integer)m_learingTable[indexP][attrIndex];
if (!tresholdList.contains(tempa))
tresholdList.add(tempa); // Wybranie możliwości podziału
}
Collections.sort(tresholdList);
for (Integer treshold : tresholdList) // Dla kolejnych wartości rozdzielających (progowych)
{
P0 = new ArrayList<Integer>(P);
P1 = new ArrayList<Integer>();
splitList(P, attrIndex, treshold, P0, P1);
entropy = (countEntropy(P0) * P0.size() + countEntropy(P1) * P1.size()) / P.size();
if (entropy < minEntropy)
{
minEntropy = entropy;
bestTreshold = treshold;
}
}
return bestTreshold;
}
else
{
throw new RuntimeException("P jest pusty, nie można wyznaczyć progu");
}
}
/**
* Dzieli zbiór przykładów na <= od progu oraz > od progu
* @param P Zbiór przykładów
* @param attrIndex Indeks opisujący atrybut
* @param treshold Próg podziału
* @param P0 Wartości poniżej progu
* @param P1 >Wartości powyżej progu
*/
private void splitList(ArrayList<Integer> P, Integer attrIndex, Integer treshold, ArrayList<Integer> P0, ArrayList<Integer> P1)
{
for (int i = 0; i < P0.size(); i++) // Jedziemy do końca listy która nam się skraca, może inny warunek
{
if ((Integer)m_learingTable[P0.get(i)][attrIndex] > treshold) // Jeśli większy niż treshold to:
{
P1.add(P0.get(i)); // Dodajemy do większego od treshold
P0.remove(i--); // Zdejmujemy z mniejszego lub równego od treshold
}
}
}
/**
* Sprawdza czy tylko jedna wartość występuje w danym zbiorze
* @param P Zbiór przykładów
* @param attrIndex Indeks opisujący atrybut
* @return
*/
private Boolean onlyOneValue(ArrayList<Integer> P, Integer attrIndex)
{
Integer value = (Integer)m_learingTable[P.get(0)][attrIndex];
for (Integer indexP : P)
{
if ((Integer)m_learingTable[indexP][attrIndex] != value)
return false;
}
return true;
}
}