|
|
pl:miw:miw08_rbs_ml:decisiontree [2008/06/10 14:19] miw |
pl:miw:miw08_rbs_ml:decisiontree [2017/07/17 10:08] |
<code="csharp"> | |
using System; | |
using System.Collections.Generic; | |
using System.Text; | |
using System.IO; | |
| |
namespace Morcinek.Machine_Learning | |
{ | |
/// <summary> | |
/// Klasa odpowiedzialna za zbudowanie drzewa decyzyjnego i na jego podstawie | |
/// do klasyfikowania wectora atrybutów do kategorii docelowych | |
/// </summary> | |
public class DecisionTree | |
{ | |
#region Members | |
| |
private ILeaf m_rootNode = null; | |
private object[][] m_learingTable = null; // zbiór uczący | |
/// <summary> | |
/// Zawiera w tablicy nazwy atrybutów konieczne do rozszyfrowania, które Testy jeszcze zostały | |
/// </summary> | |
private string[] m_testNameList = null; | |
| |
/// <summary> | |
/// Indeks pod jakim można w zbiorze uczącym znaleźć przydzieloną kategorię | |
/// </summary> | |
private int m_categoryIndex; | |
| |
private AttrChangeCollection m_attrChangeCollection; | |
| |
private double m_entropyTreshhold = 0.14; | |
| |
#endregion | |
| |
#region Constructor | |
| |
public DecisionTree() | |
{ | |
m_attrChangeCollection = new AttrChangeCollection(); | |
} | |
| |
#endregion | |
| |
#region Public Methods | |
| |
/// <summary> | |
/// Rozpoczyna proces budowania drzewa decyzyjnego | |
/// </summary> | |
/// <param name="learningTable">Zbiór uczący</param> | |
/// <param name="attributesNames">Nazwy kolejnych atrybutów</param> | |
public void Initialize(object[][] learningTable, string[] attributesNames) | |
{ | |
List<int> tresholds; | |
List<int> S; | |
m_learingTable = learningTable; // Przypisanie zbioru uczącego | |
m_categoryIndex = learningTable[0].Length - 1; | |
List<int> P = new List<int>(); | |
for (int i = 0; i < m_learingTable.Length; i++) // Zamiana zbioru na indeksy | |
P.Add(i); | |
| |
/*for (int i = 0; i < m_categoryIndex; i++) // Próba zagregowania każdego atrybutu | |
{ | |
tresholds = DiscreteDown(P, i); | |
m_attrChangeCollection.Add(tresholds, i); | |
} | |
m_attrChangeCollection.Change(m_learingTable); //*/ | |
| |
m_testNameList = attributesNames; | |
S = new List<int>(); | |
for (int i = 0; i < m_testNameList.Length; i++) // Testy tożsame z nowymi atrybutami | |
S.Add(i); | |
| |
int d = ChooseCategory(P); | |
| |
StreamWriter sw = new StreamWriter("nazwa_pliku.txt", false); //tworzy plik "nazwa_pliku.txt" jezeli taki nie istnieje, lub otwiera juz istniejacy plik nie niszczac zawartosci i bedzie dopisywac na koncu | |
sw.WriteLine("digraph G {"); //zapisuje do pliku tekst | |
| |
m_rootNode = BuildTree(P, d, S); | |
| |
string temp; | |
ShowedNode showedNode = m_rootNode.ShowTree(out temp); | |
| |
showedNode.GetGraphvizName(sw); | |
| |
| |
sw.WriteLine("}"); //zapisuje do pliku nastepny tekst | |
sw.Flush(); //czysci bufor, wszystko co bylo w buferze zostaje zapisane do pliku | |
sw.Close(); // zamyka plik | |
| |
m_attrChangeCollection.WriteAggregateValues(); | |
| |
} | |
| |
| |
/// <summary> | |
/// Funcja zwraca nowo stworzony węzęł lub liść(jeśli nie ma sensu dla węzła) | |
/// </summary> | |
/// <param name="P">Indeksy ze zbioru uczącego z których korzystamy</param> | |
/// <param name="d">Domyślna etykieta</param> | |
/// <param name="S">Indeksy dla nazw atrybutów z których można testy stworzyć</param> | |
/// <returns>Liść lub węzeł</returns> | |
private ILeaf BuildTree(List<int> P, int d, List<int> S) | |
{ | |
if (P.Count == 0) // kryterium Stopu | |
return new TermostatLeaf(d, P); // Zostaje kategoria domyślna | |
| |
ILeaf tempLeaf = OnlyOneCategory(P); // Jedyna kategoria | |
if (tempLeaf != null) // kryterium Stopu | |
return tempLeaf; | |
| |
if (S.Count == 0) // kryterium Stopu | |
return new TermostatLeaf(ChooseCategory(P), P); // Najczęściej pojawiająca się kategoria | |
| |
TermostatNode node = new TermostatNode(); | |
| |
int testIndex = ChooseTest(P, S); | |
NominalTest t = new NominalTest(m_testNameList[S[testIndex]], S[testIndex]); // Określamy który atrybut jest dla tego testu | |
List<int> STemp = new List<int>(S); | |
STemp.RemoveAt(testIndex); // Wyrzucenie już użytego testu | |
| |
foreach (int index in P) | |
{ | |
t.AddResult((int)m_learingTable[index][t.AttributeIndex]); | |
} | |
| |
node.Test = t; // Przypisanie testu do node'a | |
d = ChooseCategory(P); | |
node.Category = d; // Przypisanie większościowej kategorii do node'a | |
node.P = P; // Przypisanie pozostałych przykładów do node'a | |
foreach (int r in t.Results) | |
{ | |
List<int> PTemp = new List<int>(); // Utworzenie nowej listy tylko z tych danych które prowadzą do 'r' | |
foreach (int indexP in P) // Dla każdego przykładu etykietowanego | |
{ | |
if ((int)m_learingTable[indexP][t.AttributeIndex] == r) // Czy atrybut przykładu zgadza się z naszym rezulatatem 'r' | |
{ | |
PTemp.Add(indexP); // Dodanie do P które przekażemy do niższego węzła | |
} | |
} | |
| |
node.DictOfNodes[r] = BuildTree(PTemp, d, STemp); // Najpierw zostanie zbudowany nowy węzeł dopiero później nie zostanie nigdzie dodany | |
} | |
| |
return node; | |
} | |
| |
/// <summary> | |
/// Funkcja zakwalifikuje podany ciąg atrybutów do jednej kategorii | |
/// </summary> | |
/// <param name="attributes">Tablica atrybutów</param> | |
/// <returns>Kategoria</returns> | |
public int FindCategory(object[] attributes) | |
{ | |
m_attrChangeCollection.Change(attributes); | |
return m_rootNode.GetCategory(attributes); | |
} | |
| |
#endregion | |
| |
#region Private Methods | |
| |
/// <summary> | |
/// Wybiera najbardziej liczną kategorię w podanym zbiorze uczącym | |
/// </summary> | |
/// <param name="indexList">Lista indeksów mówiących które tablice z zbioru uczącego są w grze</param> | |
/// <param name="d"></param> | |
/// <returns>Kategoria</returns> | |
private int ChooseCategory(List<int> P) | |
{ | |
// Najbardziej liczną kategorię powinno się na razie zwracać. | |
Dictionary<int, int> categoryCounter = new Dictionary<int, int>(); | |
int category = 0; | |
foreach (int var in P) | |
{ | |
category = (int)m_learingTable[var][m_categoryIndex]; // Kategoria tych uczących które jeszcze zostały w indexList | |
if (categoryCounter.ContainsKey(category)) | |
{ | |
categoryCounter[category]++; | |
} | |
else | |
{ | |
categoryCounter.Add(category, 1); | |
} | |
} | |
| |
int biggerAmount = 0; | |
foreach (int key in categoryCounter.Keys) | |
{ | |
if (biggerAmount < categoryCounter[key]) | |
{ | |
category = key; | |
} | |
} | |
return (int)category; | |
} | |
| |
/// <summary> | |
/// Wybiera najlepszy test na podstawie najmniejszej entropii | |
/// </summary> | |
/// <param name="P">Lista indeksów pozostałego zbioru uczącego</param> | |
/// <param name="S">Lista pozostałych testów</param> | |
/// <returns>Indeks wybranego testu</returns> | |
private int ChooseTest(List<int> P, List<int> S) | |
{ | |
if (S.Count > 0) | |
{ | |
NominalTest test = null; | |
double minEntropy = Double.PositiveInfinity; | |
int bestIndex = 0; | |
| |
for (int i = 0; i < S.Count; i++) | |
{ | |
test = new NominalTest(m_testNameList[S[i]], S[i]); // Określamy który atrybut jest dla tego testu | |
double testEntropy = CountTestEntropy(P, test); | |
if (testEntropy < minEntropy) | |
{ | |
minEntropy = testEntropy; | |
bestIndex = i; | |
} | |
} | |
| |
return bestIndex; | |
} | |
else | |
{ | |
throw new Exception("Brak sprawdzania czy testy sie nie skończyły"); | |
} | |
} | |
| |
/// <summary> | |
/// Sprawdza czy w zbiorze uczącym jest tylko jedna kategoria | |
/// </summary> | |
/// <param name="P">Lista indeksów pozostałego zbioru uczącego</param> | |
/// <returns>Liść lub null w przypadku błędu</returns> | |
private ILeaf OnlyOneCategory(List<int> P) | |
{ | |
if (P.Count > 0) | |
{ | |
bool repeated = false; | |
int category = (int)m_learingTable[P[0]][m_categoryIndex]; | |
| |
for (int i = 1; i < P.Count; i++) | |
{ | |
if (category != (int)m_learingTable[P[i]][m_categoryIndex]) // Jeśli któryś jest inny niż pierwszy | |
repeated = true; | |
} | |
if (!repeated) // Jeśli się nie powtórzył | |
{ | |
return new TermostatLeaf(category, P); // Zwraca jedyną kategorię | |
} | |
return null; | |
} | |
else | |
{ | |
throw new Exception("Brak sprawdzania czy zbiór uczący nie jest pusty"); | |
} | |
} | |
| |
/// <summary> | |
/// Wylicza entropię danego przykładu trenującego dla danego testu | |
/// </summary> | |
/// <param name="P">Lista indeksów pozostałego zbioru uczącego</param> | |
/// <param name="test">Test (atrybut)</param> | |
/// <returns>Entropia</returns> | |
private double CountTestEntropy(List<int> P, ITest test) | |
{ | |
// Przez |P| nie dzielimy gdyż nie zmienia to nic | |
double allSum = 0; | |
int attrIndex = test.AttributeIndex; | |
Dictionary<int, int> dict; // Kategoria -> ilość wystąpień Kategorii | |
List<int> result = new List<int>(); | |
List<int> temp = new List<int>(); | |
| |
foreach (int indexP in P) | |
if (!result.Contains((int)m_learingTable[indexP][attrIndex])) | |
result.Add((int)m_learingTable[indexP][attrIndex]); // Pozbieranie możliwych rezultatów | |
| |
foreach (int r in result) // Dla kolejnych rezultatów testu | |
{ | |
dict = new Dictionary<int,int>(); | |
temp.Clear(); | |
foreach (int indexP in P) | |
{ | |
if ((int)m_learingTable[indexP][attrIndex] == r) // Czy przykład należy do tego rezultatu | |
{ | |
temp.Add(indexP); | |
} | |
} | |
allSum += CountEntropy(temp) * temp.Count / P.Count;; // Entropia Et(P) czyli suma po r |Ptr|/|P| * Etr(P) | |
} | |
return allSum; | |
} | |
| |
/// <summary> | |
/// Rekurencyjnie wywoływana funkcja dyskretyzacji-zstępującej | |
/// </summary> | |
/// <param name="P">Zbiór przykładów</param> | |
/// <param name="attrIndex">Indeks opisujący atrybut</param> | |
/// <returns>Lista punktów dzielących wartości</returns> | |
private List<int> DiscreteDown(List<int> P, int attrIndex) | |
{ | |
if (OnlyOneValue(P,attrIndex)) // Kryterium Stopu. Jeśli została tylko jedna wartość atrybutu w zbiorze uczącym | |
return new List<int>(); | |
| |
int treshold; | |
List<int> P0; | |
List<int> P1; | |
| |
treshold = ChooseTreshold(P, attrIndex); // Wybranie progu | |
SplitList(P, attrIndex, treshold, out P0, out P1); | |
| |
double information = CountEntropy(P); | |
double entropy = (CountEntropy(P0) * P0.Count + CountEntropy(P1) * P1.Count) / P.Count; | |
if (information - entropy <= m_entropyTreshhold) // Kryterium Stopu bazujące na wzrastaniu informacji | |
return new List<int>(); | |
| |
List<int> fi0 = DiscreteDown(P0, attrIndex); // Dzielenie zbioru < treshold | |
| |
List<int> fi1 = DiscreteDown(P1, attrIndex); // Dzielenie zbioru >= treshold | |
| |
List<int> ret = new List<int>(fi0); | |
ret.Add(treshold); | |
ret.AddRange(fi1); | |
return ret; | |
} | |
| |
/// <summary> | |
/// Wylicza tą entropię z logarytmem | |
/// </summary> | |
/// <param name="top"></param> | |
/// <param name="bottom"></param> | |
/// <returns></returns> | |
private double CountEntropy(int top, int bottom) | |
{ | |
double temp = (double)top / (double)bottom; | |
return - temp * Math.Log(temp, 2); | |
} | |
| |
/// <summary> | |
/// Oblicza entropię Etr(P) czyli sumuje entropie po wszystkiech kategoriach pojęcia docelowego | |
/// </summary> | |
/// <param name="P">Zbiór uczący</param> | |
/// <returns></returns> | |
private double CountEntropy(List<int> P) | |
{ | |
int value; | |
double sum = 0; | |
Dictionary<int, int> dict = new Dictionary<int,int>(); // Kategoria -> ilość wystąpień Kategorii | |
List<int> result = new List<int>(); | |
| |
foreach (int indexP in P) | |
{ | |
value = (int)m_learingTable[indexP][m_categoryIndex]; | |
if (dict.ContainsKey(value)) | |
dict[value]++; // Zliczanie jak liczne są poszczególne kategorie | |
else | |
dict[value] = 1; | |
} | |
sum = 0; | |
foreach (int count in dict.Values) | |
{ | |
sum += CountEntropy(count, P.Count); // Wyliczanie entropii Etr(P) | |
} | |
//sum *= P.Count; // Mnożenie przez wagę (ilość danego rezultatu) | |
| |
return sum; | |
} | |
| |
/// <summary> | |
/// Wybiera próg dzielący wartości | |
/// </summary> | |
/// <param name="P">Zbiór przykładów</param> | |
/// <param name="attrIndex">Indeks opisujący atrybut</param> | |
/// <returns>Próg</returns> | |
private int ChooseTreshold(List<int> P, int attrIndex) | |
{ | |
if (P.Count > 0) | |
{ | |
double entropy; | |
double minEntropy = double.PositiveInfinity; | |
int bestTreshold = 0; | |
List<int> P0; | |
List<int> P1; | |
| |
List<int> tresholdList = new List<int>(); | |
foreach (int indexP in P) | |
{ | |
int tempa = (int)m_learingTable[indexP][attrIndex]; | |
if (!tresholdList.Contains(tempa)) | |
tresholdList.Add(tempa); // Wybranie możliwości podziału | |
} | |
tresholdList.Sort(); | |
foreach (int treshold in tresholdList) // Dla kolejnych wartości rozdzielających (progowych) | |
{ | |
SplitList(P, attrIndex, treshold, out P0, out P1); | |
entropy = (CountEntropy(P0) * P0.Count + CountEntropy(P1) * P1.Count) / P.Count; | |
if (entropy < minEntropy) | |
{ | |
minEntropy = entropy; | |
bestTreshold = treshold; | |
} | |
} | |
return bestTreshold; | |
} | |
else | |
{ | |
throw new Exception("P jest pusty, nie można wyznaczyć progu"); | |
} | |
} | |
| |
/// <summary> | |
/// Dzieli zbiór przykładów na mniejszy i równy od progu oraz większy od progu | |
/// </summary> | |
/// <param name="P">Zbiór przykładów</param> | |
/// <param name="attrIndex">Indeks opisujący atrybut</param> | |
/// <param name="treshold">Próg podziału</param> | |
/// <param name="P0">Wartości poniżej progu</param> | |
/// <param name="P1">Wartości powyżej progu</param> | |
private void SplitList(List<int> P, int attrIndex, int treshold, out List<int> P0, out List<int> P1) | |
{ | |
P0 = new List<int>(P); | |
P1 = new List<int>(); | |
for (int i = 0; i < P0.Count; i++) // Jedziemy do końca listy która nam się skraca, może inny warunek | |
{ | |
if ((int)m_learingTable[P0[i]][attrIndex] > treshold) // Jeśli większy niż treshold to: | |
{ | |
P1.Add(P0[i]); // Dodajemy do większego od treshold | |
P0.RemoveAt(i); // Zdejmujemy z mniejszego lub równego od treshold | |
i--; | |
} | |
} | |
} | |
| |
/// <summary> | |
/// Sprawdza czy tylko jedna wartość występuje w danym zbiorze | |
/// </summary> | |
/// <param name="P">Zbiór przykładów</param> | |
/// <param name="attrIndex">Indeks opisujący atrybut</param> | |
/// <returns></returns> | |
private bool OnlyOneValue(List<int> P, int attrIndex) | |
{ | |
int value = (int)m_learingTable[P[0]][attrIndex]; | |
foreach (int indexP in P) | |
{ | |
if ((int)m_learingTable[indexP][attrIndex] != value) | |
return false; | |
} | |
| |
return true; | |
} | |
| |
#endregion | |
} | |
| |
} | |
| |
</code> | |