package MachineLearning;

import java.util.*;

/**
 * Decision tree classifier. Builds a tree from a table of labelled training
 * examples and then classifies new attribute vectors with it.
 *
 * <p>Every cell of the training table is read as an {@link Integer}; the last
 * column of each row holds the category (label) of that row.
 *
 * @author Angello
 */
public class DecisionTree
{
    /** Root of the built tree; {@code null} until {@link #initialize} has run. */
    private TermostatLeaf m_rootNode = null;

    /** Training set: one row per example, last column is the category. */
    private Object[][] m_learingTable = null;

    /**
     * Attribute names, indexed by attribute (column) number. Used to name the
     * tests that are still available while the tree is being built.
     */
    private String[] m_testNameList = null;

    /** Column index at which the assigned category is stored in the training set. */
    private Integer m_categoryIndex;

    /** Collects the thresholds found for each attribute and applies them to data rows. */
    private AttrChangeCollection m_attrChangeCollection;

    /** Minimum information gain required to accept a split while aggregating ordinal attributes. */
    private double m_entropyTreshhold = 0.091;

    /**
     * Default constructor.
     */
    public DecisionTree()
    {
        m_attrChangeCollection = new AttrChangeCollection();
    }

    /**
     * Builds the decision tree and writes its Graphviz description to
     * {@code Graph_file.txt}.
     *
     * @param learningTable   training set; must contain at least one row, the last column is the category
     * @param attributesNames names of the consecutive attributes; one test is created per entry,
     *                        so this presumably lists predictor attributes only (without the
     *                        category column) - TODO confirm against callers
     * @param c45             when {@code true}, every attribute is first discretised with
     *                        top-down thresholding and the training table is rewritten accordingly
     * @throws IllegalArgumentException if {@code learningTable} is {@code null} or has no rows
     */
    public void initialize(Object[][] learningTable, String[] attributesNames, boolean c45)
    {
        if (learningTable == null || learningTable.length == 0)
            throw new IllegalArgumentException("learningTable must contain at least one example");

        m_learingTable = learningTable;
        m_categoryIndex = learningTable[0].length - 1;

        // The algorithm works on row indices instead of copying rows around.
        ArrayList<Integer> P = new ArrayList<Integer>();
        for (int i = 0; i < m_learingTable.length; i++)
            P.add(i);

        if (c45)
        {
            // Try to aggregate every attribute (every column except the category).
            for (int i = 0; i < m_categoryIndex; i++)
            {
                ArrayList<Integer> tresholds = discreteDown(P, i);
                m_attrChangeCollection.Add(tresholds, i);
            }
            m_attrChangeCollection.Change(m_learingTable);
        }

        m_testNameList = attributesNames;

        // Tests are identified by the index of the attribute they examine.
        ArrayList<Integer> S = new ArrayList<Integer>();
        for (int i = 0; i < m_testNameList.length; i++)
            S.add(i);

        Integer d = chooseCategory(P);
        m_rootNode = buildTree(P, d, S);

        ShowedNode showedNode = m_rootNode.showTree();
        WriteFile sw = new WriteFile();
        sw.open("Graph_file.txt");
        try
        {
            sw.writeLine("digraph G {");
            showedNode.getGraphvizName(sw);
            // The closing brace has to be written before the file is closed.
            sw.writeLine("}");
        }
        finally
        {
            sw.close();
        }

        m_attrChangeCollection.WriteAggregateValues();
    }

    /**
     * Recursively creates a node, or a leaf when splitting further makes no sense.
     *
     * @param P indices of the training examples that reached this point
     * @param d default category, used when {@code P} is empty
     * @param S indices of the attributes (tests) that are still available
     * @return a leaf or an inner node
     */
    private TermostatLeaf buildTree(ArrayList<Integer> P, Integer d, ArrayList<Integer> S)
    {
        if (P.isEmpty()) // stop criterion: nothing to learn from, keep the default category
            return new TermostatLeaf(d, P);

        TermostatLeaf singleCategoryLeaf = onlyOneCategory(P);
        if (singleCategoryLeaf != null) // stop criterion: all examples agree
            return singleCategoryLeaf;

        if (S.isEmpty()) // stop criterion: no tests left, fall back to the majority category
            return new TermostatLeaf(chooseCategory(P), P);

        int testPosition = chooseTest(P, S);
        Integer attribute = S.get(testPosition);
        NominalTest t = new NominalTest(m_testNameList[attribute], attribute);

        // Remove the used test BY POSITION. Passing a boxed Integer would select
        // ArrayList.remove(Object) and remove by value instead.
        ArrayList<Integer> STemp = new ArrayList<Integer>(S);
        STemp.remove(testPosition);

        for (Integer index : P)
        {
            t.AddResult((Integer)m_learingTable[index][t.AttributeIndex]);
        }

        Integer majority = chooseCategory(P);
        TermostatNode node = new TermostatNode(t, majority, P);

        for (Integer r : t.Results())
        {
            // Collect only the examples whose attribute value leads to outcome 'r'.
            ArrayList<Integer> PTemp = new ArrayList<Integer>();
            for (Integer indexP : P)
            {
                if (sameValue(m_learingTable[indexP][t.AttributeIndex], r))
                {
                    PTemp.add(indexP);
                }
            }
            // The subtree is built first and only then attached to this node.
            node.addNode(r, buildTree(PTemp, majority, STemp));
        }
        return node;
    }

    /**
     * Classifies the given attribute vector.
     *
     * @param attributes attribute values of the example; passed through the attribute
     *                   change collection before the tree is consulted
     * @return the category chosen by the tree
     * @throws IllegalStateException if the tree has not been built yet
     */
    public Integer findCategory(Object[] attributes)
    {
        if (m_rootNode == null)
            throw new IllegalStateException("initialize() must be called before findCategory()");

        m_attrChangeCollection.Change(attributes);
        return m_rootNode.getCategory(attributes);
    }

    /**
     * Picks the most frequent category among the given training examples.
     *
     * @param P indices of training examples
     * @return the most frequent category, or {@code 0} when {@code P} is empty
     */
    private Integer chooseCategory(ArrayList<Integer> P)
    {
        HashMap<Integer, Integer> categoryCounter = new HashMap<Integer, Integer>();
        for (Integer example : P)
        {
            Integer category = (Integer)m_learingTable[example][m_categoryIndex];
            Integer seen = categoryCounter.get(category);
            categoryCounter.put(category, seen == null ? 1 : seen + 1);
        }

        int bestCategory = 0;
        int bestCount = 0;
        for (Map.Entry<Integer, Integer> entry : categoryCounter.entrySet())
        {
            if (entry.getValue() > bestCount)
            {
                // Track the running maximum so that the largest count really wins.
                bestCount = entry.getValue();
                bestCategory = entry.getKey();
            }
        }
        return bestCategory;
    }

    /**
     * Chooses the best test, i.e. the one with the lowest weighted entropy.
     *
     * @param P indices of training examples
     * @param S indices of available attributes (tests)
     * @return position inside {@code S} of the chosen test
     * @throws RuntimeException if {@code S} is empty
     */
    private Integer chooseTest(ArrayList<Integer> P, ArrayList<Integer> S)
    {
        if (S.isEmpty())
            throw new RuntimeException("Brak sprawdzania czy testy sie nie skończyły");

        double minEntropy = Double.POSITIVE_INFINITY;
        int bestIndex = 0;
        for (int i = 0; i < S.size(); i++)
        {
            Integer attribute = S.get(i);
            NominalTest candidate = new NominalTest(m_testNameList[attribute], attribute);
            double testEntropy = countTestEntropy(P, candidate);
            if (testEntropy < minEntropy)
            {
                minEntropy = testEntropy;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Checks whether all given examples belong to a single category.
     *
     * @param P indices of training examples; must not be empty
     * @return a leaf for that category, or {@code null} when more than one category occurs
     * @throws RuntimeException if {@code P} is empty
     */
    private TermostatLeaf onlyOneCategory(ArrayList<Integer> P)
    {
        if (P.isEmpty())
            throw new RuntimeException("Brak sprawdzania czy zbiór uczący nie jest pusty");

        Integer category = (Integer)m_learingTable[P.get(0)][m_categoryIndex];
        for (int i = 1; i < P.size(); i++)
        {
            // Compare by value; '!=' on boxed Integers compares references.
            if (!sameValue(category, m_learingTable[P.get(i)][m_categoryIndex]))
                return null;
        }
        return new TermostatLeaf(category, P);
    }

    /**
     * Computes the entropy of the training examples with respect to a test:
     * the sum over outcomes r of |P_r| / |P| * E(P_r).
     *
     * @param P    indices of training examples
     * @param test test whose attribute partitions {@code P}
     * @return weighted entropy of the partition
     */
    private double countTestEntropy(ArrayList<Integer> P, NominalTest test)
    {
        Integer attrIndex = test.AttributeIndex;

        // Group the examples by attribute value in a single pass. Insertion order
        // is kept so the partial entropies are summed in first-seen order.
        LinkedHashMap<Integer, ArrayList<Integer>> groups = new LinkedHashMap<Integer, ArrayList<Integer>>();
        for (Integer indexP : P)
        {
            Integer value = (Integer)m_learingTable[indexP][attrIndex];
            ArrayList<Integer> members = groups.get(value);
            if (members == null)
            {
                members = new ArrayList<Integer>();
                groups.put(value, members);
            }
            members.add(indexP);
        }

        double allSum = 0;
        for (ArrayList<Integer> members : groups.values())
        {
            allSum += countEntropy(members) * members.size() / P.size();
        }
        return allSum;
    }

    /**
     * Recursive top-down discretisation of one attribute.
     *
     * @param P         indices of training examples
     * @param attrIndex index of the attribute being discretised
     * @return list of split points, ordered from the lower part to the upper part
     */
    private ArrayList<Integer> discreteDown(ArrayList<Integer> P, Integer attrIndex)
    {
        if (onlyOneValue(P, attrIndex)) // stop criterion: a single attribute value is left
            return new ArrayList<Integer>();

        ArrayList<Integer> P0 = new ArrayList<Integer>(P);
        ArrayList<Integer> P1 = new ArrayList<Integer>();
        Integer treshold = chooseTreshold(P, attrIndex);
        splitList(P, attrIndex, treshold, P0, P1);

        double information = countEntropy(P);
        double entropy = (countEntropy(P0) * P0.size() + countEntropy(P1) * P1.size()) / P.size();
        if (information - entropy <= m_entropyTreshhold) // stop criterion: information gain too small
            return new ArrayList<Integer>();

        ArrayList<Integer> fi0 = discreteDown(P0, attrIndex); // examples with value <= treshold
        ArrayList<Integer> fi1 = discreteDown(P1, attrIndex); // examples with value > treshold

        ArrayList<Integer> ret = new ArrayList<Integer>(fi0);
        ret.add(treshold);
        ret.addAll(fi1);
        return ret;
    }

    /**
     * Computes a single entropy term {@code -p * ln(p)} for {@code p = top / bottom}.
     *
     * @param top    number of occurrences of one category; expected to be positive
     *               (a zero would yield NaN)
     * @param bottom total number of examples
     * @return the entropy term, using the natural logarithm
     */
    private double countEntropy(Integer top, Integer bottom)
    {
        double temp = (double)top / (double)bottom;
        return - temp * Math.log(temp);
    }

    /**
     * Computes the entropy of a set of examples, i.e. the sum of the entropy
     * terms over all categories of the target concept.
     *
     * @param P indices of training examples
     * @return entropy of {@code P}; {@code 0} for an empty set
     */
    private double countEntropy(ArrayList<Integer> P)
    {
        // category -> number of occurrences of that category
        HashMap<Integer, Integer> dict = new HashMap<Integer, Integer>();
        for (Integer indexP : P)
        {
            Integer key = (Integer)m_learingTable[indexP][m_categoryIndex];
            Integer seen = dict.get(key);
            dict.put(key, seen == null ? 1 : seen + 1);
        }

        double sum = 0;
        for (Integer count : dict.values())
        {
            sum += countEntropy(count, P.size());
        }
        return sum;
    }

    /**
     * Chooses the threshold that splits the attribute values with the lowest
     * weighted entropy. Ties are resolved in favour of the smallest threshold.
     *
     * @param P         indices of training examples; must not be empty
     * @param attrIndex index of the attribute
     * @return the chosen threshold
     * @throws RuntimeException if {@code P} is empty
     */
    private Integer chooseTreshold(ArrayList<Integer> P, Integer attrIndex)
    {
        if (P.isEmpty())
            throw new RuntimeException("P jest pusty, nie można wyznaczyć progu");

        // Distinct attribute values in ascending order are the candidate split points.
        TreeSet<Integer> candidates = new TreeSet<Integer>();
        for (Integer indexP : P)
        {
            candidates.add((Integer)m_learingTable[indexP][attrIndex]);
        }

        double minEntropy = Double.POSITIVE_INFINITY;
        Integer bestTreshold = 0;
        for (Integer treshold : candidates)
        {
            ArrayList<Integer> P0 = new ArrayList<Integer>(P);
            ArrayList<Integer> P1 = new ArrayList<Integer>();
            splitList(P, attrIndex, treshold, P0, P1);
            double entropy = (countEntropy(P0) * P0.size() + countEntropy(P1) * P1.size()) / P.size();
            if (entropy < minEntropy)
            {
                minEntropy = entropy;
                bestTreshold = treshold;
            }
        }
        return bestTreshold;
    }

    /**
     * Splits examples into those with attribute value {@code <= treshold} and
     * those with value {@code > treshold}.
     *
     * @param P         full set of examples (not read; {@code P0} is expected to start as a copy of it)
     * @param attrIndex index of the attribute
     * @param treshold  split threshold
     * @param P0        in/out: on entry the examples to split, on exit only those {@code <= treshold}
     * @param P1        out: receives the examples {@code > treshold}, in their original order
     */
    private void splitList(ArrayList<Integer> P, Integer attrIndex, Integer treshold,
                           ArrayList<Integer> P0, ArrayList<Integer> P1)
    {
        for (Iterator<Integer> it = P0.iterator(); it.hasNext(); )
        {
            Integer example = it.next();
            if ((Integer)m_learingTable[example][attrIndex] > treshold)
            {
                P1.add(example); // goes to the "greater than treshold" part
                it.remove();     // and leaves the "less than or equal" part
            }
        }
    }

    /**
     * Checks whether only one value of the attribute occurs in the given set.
     *
     * @param P         indices of examples
     * @param attrIndex index of the attribute
     * @return {@code true} when all examples share the same value (also for an empty set)
     */
    private Boolean onlyOneValue(ArrayList<Integer> P, Integer attrIndex)
    {
        if (P.isEmpty())
            return true;

        Integer value = (Integer)m_learingTable[P.get(0)][attrIndex];
        for (Integer indexP : P)
        {
            if (!sameValue(m_learingTable[indexP][attrIndex], value))
                return false;
        }
        return true;
    }

    /**
     * Null-safe value equality for table cells, used instead of reference
     * comparison of boxed integers.
     */
    private static boolean sameValue(Object a, Object b)
    {
        return a == null ? b == null : a.equals(b);
    }
}