package MachineLearning;
 
import java.util.*;
 
/**
 * Klasa drzewa decyzyjnego. Zbudeje drzewo na podstawie przykładów uczących
 * @author Angello
 *
 */
public class DecisionTree
{
	private TermostatLeaf m_rootNode = null;
	private Object[][] m_learingTable = null;				// zbiór uczący
 
	/**
	 * Zawiera w tablicy nazwy atrybutów konieczne do rozszyfrowania, które Testy jeszcze zostały
	 */
	private String[] m_testNameList = null;
 
	/**
	 * Indeks pod jakim można w zbiorze uczącym znaleźć przydzieloną kategorię
	 */
	private Integer m_categoryIndex;
 
	private AttrChangeCollection m_attrChangeCollection;
 
	/**
	 * Próg entropii używany w agregacji atrybutów porządkowych
	 */
	private double m_entropyTreshhold = 0.091;
 
	/**
	 * Domyślny konstruktor
	 */
	public DecisionTree()
	{
		m_attrChangeCollection = new AttrChangeCollection();
	}
 
	/**
	 * Buduje drzewo decyzyjne 
	 * @param learningTable Zbiór uczący
	 * @param attributesNames Nazwy kolejnych atrybutów
	 */
	public void initialize(Object[][] learningTable, String[] attributesNames, boolean c45)
	{
		ArrayList<Integer> tresholds;
		ArrayList<Integer> S;
		m_learingTable = learningTable;					// Przypisanie zbioru uczącego
		m_categoryIndex = learningTable[0].length - 1;
		ArrayList<Integer> P = new ArrayList<Integer>();
		for (int i = 0; i < m_learingTable.length; i++)	// Zamiana zbioru na indeksy
			P.add(i);
 
		if(c45)
		{
			for (int i = 0; i < m_categoryIndex; i++)		// Próba zagregowania każdego atrybutu
			{
				tresholds = discreteDown(P, i);
				m_attrChangeCollection.Add(tresholds, i);
			}
			m_attrChangeCollection.Change(m_learingTable); 
		}
 
		m_testNameList = attributesNames;
		S = new ArrayList<Integer>();
		for (int i = 0; i < m_testNameList.length; i++)			// Testy tożsame z nowymi atrybutami
			S.add(i);
 
		Integer d = chooseCategory(P);
 
		m_rootNode = buildTree(P, d, S);
 
		ShowedNode showedNode = m_rootNode.showTree();
 
		WriteFile sw = new WriteFile();
		sw.open("Graph_file.txt");
		sw.writeLine("digraph G {"); //zapisuje do pliku tekst
		showedNode.getGraphvizName(sw);
		sw.close();
 
		sw.writeLine("}"); //zapisuje do pliku nastepny tekst
		sw.close(); // zamyka plik
 
		m_attrChangeCollection.WriteAggregateValues();
	}
 
	/**
	 * Funcja zwraca nowo stworzony węzęł lub liść(jeśli nie ma sensu dla węzła)
	 * @param P Zbiór uczący
	 * @param d Domyślna etykieta
	 * @param S Dostępne testy
	 * @return Liść lub węzeł
	 */
	private TermostatLeaf buildTree(ArrayList<Integer> P, Integer d, ArrayList<Integer> S)
	{
		if (P.size() == 0)					// kryterium Stopu
			return new TermostatLeaf(d, P); // Zostaje kategoria domyślna
 
		TermostatLeaf tempLeaf = onlyOneCategory(P);					// Jedyna kategoria
		if (tempLeaf != null)				// kryterium Stopu
			return tempLeaf;
 
		if (S.size() == 0)					// kryterium Stopu
			return new TermostatLeaf(chooseCategory(P), P);		// Najczęściej pojawiająca się kategoria
 
		Integer testIndex = chooseTest(P, S);
		NominalTest t = new NominalTest(m_testNameList[S.get(testIndex)], S.get(testIndex));			// Określamy który atrybut jest dla tego testu
		ArrayList<Integer> STemp = new ArrayList<Integer>(S);
		STemp.remove(testIndex);									// Wyrzucenie już użytego testu
 
		for (Integer index : P)
		{
			t.AddResult((Integer)m_learingTable[index][t.AttributeIndex]);
		}
 
		d = chooseCategory(P);
		TermostatNode node = new TermostatNode(t, d, P);
 
		for (Integer r : t.Results())
		{
			ArrayList<Integer> PTemp = new ArrayList<Integer>();			// Utworzenie nowej listy tylko z tych danych które prowadzą do 'r'
			for (Integer indexP : P)					// Dla każdego przykładu etykietowanego
			{
				if ((Integer)m_learingTable[indexP][t.AttributeIndex] == r)	// Czy atrybut przykładu zgadza się z naszym rezulatatem 'r'
				{
					PTemp.add(indexP);					// Dodanie do P które przekażemy do niższego węzła
				}
			}
			node.addNode(r, buildTree(PTemp, d, STemp));	// Najpierw zostanie zbudowany nowy węzeł dopiero później nie zostanie nigdzie dodany
		}
 
		return node;
	}
 
	/**
	 * Funkcja zakwalifikuje podany ciąg atrybutów do jednej kategorii
	 * @param attributes Zbiór atrybutów
	 * @return Kategoria
	 */
	public Integer findCategory(Object[] attributes)
	{
		m_attrChangeCollection.Change(attributes);
		return m_rootNode.getCategory(attributes);
	}
 
	/**
	 * Wybiera najbardziej liczną kategorię w podanym zbiorze uczącym
	 * @param P Zbiór uczący
	 * @return Kategoria
	 */
	private Integer chooseCategory(ArrayList<Integer> P)
	{
		HashMap<Integer, Integer> categoryCounter = new HashMap<Integer, Integer>();
		int category = 0;
		for (Integer var : P)
		{
			category = (Integer)m_learingTable[var][m_categoryIndex];		// Kategoria tych uczących które jeszcze zostały w indexList
			if (categoryCounter.containsKey(category))
			{
				categoryCounter.put(category, categoryCounter.get(category) + 1); 
			}
			else
			{
				categoryCounter.put(category, 1);
			}
		}
 
		Integer biggerAmount = 0;
		for (Integer key : categoryCounter.keySet())
		{
			if (biggerAmount < categoryCounter.get(key))
			{
				category = key;
			}
		}
		return category;
	}
 
	/**
	 * Wybiera najlepszy test na podstawie najmniejszej entropii
	 * @param P Zbiór uczący
	 * @param S Zbiór testów
	 * @return Wybrany test
	 */
	private Integer chooseTest(ArrayList<Integer> P, ArrayList<Integer> S)
	{
		if (S.size() > 0)
		{
			NominalTest test = null;
			double minEntropy = Double.POSITIVE_INFINITY;
			Integer bestIndex = 0;
 
			for (Integer i = 0; i < S.size(); i++)
			{
				test = new NominalTest(m_testNameList[S.get(i)], S.get(i));			// Określamy który atrybut jest dla tego testu
				double testEntropy = countTestEntropy(P, test);
				if (testEntropy < minEntropy)
				{
					minEntropy = testEntropy;
					bestIndex = i;
				}
			}
 
			return bestIndex;
		}
		else
		{
			throw new RuntimeException("Brak sprawdzania czy testy sie nie skończyły");
		}
	}
 
	/**
	 * Sprawdza czy w zbiorze uczącym jest tylko jedna kategoria
	 * @param P Zbiór uczący
	 * @return Liść lub null w przypadku braku pożadanego węzła
	 */
	private TermostatLeaf onlyOneCategory(ArrayList<Integer> P)
	{
		if (P.size() > 0)
		{
			boolean repeated = false;
			Integer category = (Integer)m_learingTable[P.get(0)][m_categoryIndex];
 
			for (Integer i = 1; i < P.size(); i++)
			{
				if (category != (Integer)m_learingTable[P.get(i)][m_categoryIndex])		// Jeśli któryś jest inny niż pierwszy
					repeated = true;
			}
			if (!repeated)				// Jeśli się nie powtórzył
			{
				return new TermostatLeaf(category, P); // Zwraca jedyną kategorię
			}
			return null;
		}
		else
		{
			throw new RuntimeException("Brak sprawdzania czy zbiór uczący nie jest pusty");
		}
	}
 
	/**
	 * Wylicza entropię danego przykładu trenującego dla danego testu
	 * @param P Zbiór uczący
	 * @param test Test
	 * @return Entropia
	 */
	private double countTestEntropy(ArrayList<Integer> P, NominalTest test)
	{
		double allSum = 0;
		Integer attrIndex = test.AttributeIndex;
		ArrayList<Integer> result = new ArrayList<Integer>();
		ArrayList<Integer> temp = new ArrayList<Integer>();
 
		for (Integer indexP : P)
			if (!result.contains((Integer)m_learingTable[indexP][attrIndex]))
				result.add((Integer)m_learingTable[indexP][attrIndex]);				// Pozbieranie możliwych rezultatów
 
		for (Integer r : result)												// Dla kolejnych rezultatów testu
		{
			temp.clear();
			for (Integer indexP : P)
			{
				if ((Integer)m_learingTable[indexP][attrIndex] == r)				// Czy przykład należy do tego rezultatu
				{
					temp.add(indexP);
				}
			}
			allSum += countEntropy(temp) * temp.size() / P.size();;				// Entropia Et(P) czyli suma po r |Ptr|/|P| * Etr(P)
		}
		return allSum;
	}
 
	/**
	 * Rekurencyjnie wywoływana funkcja dyskretyzacji-zstępującej
	 * @param P Zbiór uczący
	 * @param attrIndex Indeks opisujący atrybut
	 * @return Lista punktów dzielących wartości
	 */
	private ArrayList<Integer> discreteDown(ArrayList<Integer> P, Integer attrIndex)
	{
		if (onlyOneValue(P,attrIndex))								// Kryterium Stopu. Jeśli została tylko jedna wartość atrybutu w zbiorze uczącym
			return new ArrayList<Integer>();
 
		Integer treshold;
		ArrayList<Integer> P0 = new ArrayList<Integer>(P);
		ArrayList<Integer> P1 = new ArrayList<Integer>();
 
		treshold = chooseTreshold(P, attrIndex);					// Wybranie progu
		splitList(P, attrIndex, treshold, P0, P1);
 
		double information = countEntropy(P);
		double entropy = (countEntropy(P0) * P0.size() + countEntropy(P1) * P1.size()) / P.size();
		if (information - entropy <= m_entropyTreshhold)							// Kryterium Stopu bazujące na wzrastaniu informacji
			return new ArrayList<Integer>();
 
		ArrayList<Integer> fi0 = discreteDown(P0, attrIndex);			// Dzielenie zbioru < treshold
 
		ArrayList<Integer> fi1 = discreteDown(P1, attrIndex);			// Dzielenie zbioru >= treshold
 
		ArrayList<Integer> ret = new ArrayList<Integer>(fi0);
		ret.add(treshold);
		ret.addAll(fi1);
		return ret;
	}
 
	/**
	 * Wylicza entropię
	 * @param top
	 * @param bottom
	 * @return
	 */
	private double countEntropy(Integer top, Integer bottom)
	{
		double temp = (double)top / (double)bottom;
		return - temp * Math.log(temp);
	}
 
	/**
	 * Oblicza entropię Etr(P) czyli sumuje entropie po wszystkiech kategoriach pojęcia docelowego 
	 * @param P Zbiór uczący
	 * @return
	 */
	private double countEntropy(ArrayList<Integer> P)
	{
		Integer key;
		double sum = 0;
		HashMap<Integer, Integer> dict = new HashMap<Integer,Integer>();			// Kategoria -> ilość wystąpień Kategorii
 
		for (Integer indexP : P)
		{
			key = (Integer)m_learingTable[indexP][m_categoryIndex];
				dict.put(key,dict.containsKey(key) ? dict.get(key) + 1 : 1) ;											// Zliczanie jak liczne są poszczególne kategorie
		}
		sum = 0;
		for (Integer count : dict.values())
		{
			sum += countEntropy(count, P.size());						// Wyliczanie entropii Etr(P)
		}
 
		return sum;
	}
 
	/**
	 * Wybiera próg dzielący wartości
	 * @param P Zbiór uczący
	 * @param attrIndex Indeks opisujący atrybut
	 * @return Próg
	 */
	private Integer chooseTreshold(ArrayList<Integer> P, Integer attrIndex)
	{
		if (P.size() > 0)
		{
			double entropy;
			double minEntropy = Double.POSITIVE_INFINITY;
			Integer bestTreshold = 0;
			ArrayList<Integer> P0;
			ArrayList<Integer> P1;
 
			ArrayList<Integer> tresholdList = new ArrayList<Integer>();						
			for (Integer indexP : P)
			{
				Integer tempa = (Integer)m_learingTable[indexP][attrIndex];			
				if (!tresholdList.contains(tempa))
					tresholdList.add(tempa);								// Wybranie możliwości podziału
			}
 
			Collections.sort(tresholdList);
			for (Integer treshold : tresholdList)							// Dla kolejnych wartości rozdzielających (progowych)
			{
				P0  = new ArrayList<Integer>(P);
				P1 = new ArrayList<Integer>();
				splitList(P, attrIndex, treshold, P0, P1);
				entropy = (countEntropy(P0) * P0.size() + countEntropy(P1) * P1.size()) / P.size();
				if (entropy < minEntropy)
				{
					minEntropy = entropy;
					bestTreshold = treshold;
				}
			}
			return bestTreshold;
		}
		else
		{
			throw new RuntimeException("P jest pusty, nie można wyznaczyć progu");
		}
	}
 
	/**
	 * Dzieli zbiór przykładów na <= od progu oraz > od progu
	 * @param P Zbiór przykładów
	 * @param attrIndex Indeks opisujący atrybut
	 * @param treshold Próg podziału
	 * @param P0 Wartości poniżej progu
	 * @param P1 >Wartości powyżej progu
	 */
	private void splitList(ArrayList<Integer> P, Integer attrIndex, Integer treshold, ArrayList<Integer> P0, ArrayList<Integer> P1)
	{
		for (int i = 0; i < P0.size(); i++)							// Jedziemy do końca listy która nam się skraca, może inny warunek
		{
			if ((Integer)m_learingTable[P0.get(i)][attrIndex] > treshold)	// Jeśli większy niż treshold to:
			{
				P1.add(P0.get(i));										// Dodajemy do większego od treshold
				P0.remove(i--);										// Zdejmujemy z mniejszego lub równego od treshold
			}
		}
	}
 
	/**
	 * Sprawdza czy tylko jedna wartość występuje w danym zbiorze 
	 * @param P Zbiór przykładów
	 * @param attrIndex Indeks opisujący atrybut
	 * @return
	 */
	private Boolean onlyOneValue(ArrayList<Integer> P, Integer attrIndex)
	{
		Integer value = (Integer)m_learingTable[P.get(0)][attrIndex];
		for (Integer indexP : P)
		{
			if ((Integer)m_learingTable[indexP][attrIndex] != value)
				return false;
		}
 
		return true;
	}	
}
pl/miw/miw08_rbs_ml/decisiontree.txt · ostatnio zmienione: 2019/06/27 15:50 (edycja zewnętrzna)
www.chimeric.de Valid CSS Driven by DokuWiki do yourself a favour and use a real browser - get firefox!! Recent changes RSS feed Valid XHTML 1.0