package MachineLearning; import java.util.*; /** * Klasa z której program rozpoczyna swoje wykonywanie * @author Angello * */ public class Program { /** * Początek programu * @param args */ public static void main(String[] args) { test(); } /** * Tworzy tablicę zawierająca wszyskie możliwe przypadki * @return */ public static Object[][] create() { Object[][] learning = new Object[2016][]; Integer index = -1; Integer day; Integer oper; Integer month; for (Integer i = 0; i < 7; i++) // Dni { day = i; for (Integer j = 0; j < 24; j++) // Godziny { oper = j; for (Integer k = 0; k < 12; k++) // Miesiące { month = k; learning[++index] = new Object[] { day, oper, month, 0 }; learning[index][3] = CountTemperature(learning[index]); } } } return learning; } /** * Zbudowanie drzewa i sprawdzenie działanie drzewa dla pozostałych przykładów * @param learningLength * @return */ public static int makeOneTest(Integer learningLength, boolean c45) { Integer randValue; Object[][] everything; Object[][] testing; Object[][] learning; DecisionTree decisionTree; Integer goodClassified = 0; everything = create();//= Common.BinaryDeSerialize("Thermostat_2016.bin"); String[] attributeNames = new String[] { "Day", "Hour", "Month" }; Random rand = new Random(); ArrayList learningIndexes = new ArrayList(); ArrayList testingIndexes = new ArrayList(); for (Integer i = 0; i < learningLength; i++) { randValue = rand.nextInt(2015); if (!learningIndexes.contains(randValue)) learningIndexes.add(randValue); else --i; } Collections.sort(learningIndexes); for (Integer i = 0; i < 2016; i++) { if (!learningIndexes.contains(i)) testingIndexes.add(i); } testing = indexesToArray(everything, testingIndexes); learning = indexesToArray(everything, learningIndexes); decisionTree = new DecisionTree(); decisionTree.initialize(learning, attributeNames, c45); for (Object[] row : testing) { if (decisionTree.findCategory(row) == (Integer)row[3]) goodClassified++; } System.out.println(goodClassified); return goodClassified; } /** * Tworzy tablicę Przykładów z podanej listy indeksów oraz tablicy wszystkich przykładów * @param everything * @param indexesList * @return */ private static Object[][] indexesToArray(Object[][] everything, ArrayList indexesList) { ArrayList testingList = new ArrayList(); for (Integer index : indexesList) { testingList.add(everything[index]); } Object[][] returnTable = new Object[indexesList.size()][everything[0].length]; testingList.toArray(returnTable); return returnTable; } /** * Oblicza jaka powinna być nastawa termostatu przy przedstawionych atrybutach (example) * @param example Zbiór atrybutów * @return */ public static int CountTemperature(Object[] example) { boolean workday = true; if (Days.saturday.ordinal() == (Integer)example[0] || Days.sunday.ordinal() == (Integer)example[0]) workday = false; boolean oper = false; if ((Integer)example[1] >= 9 && (Integer)example[1] < 17) oper = true; boolean business = false; if (workday && oper) business = true; Thermostat temperature = Thermostat.t14; switch ((Integer)example[2]) { case 11: case 0: case 1: if (business) temperature = Thermostat.t18; else temperature = Thermostat.t14; break; case 2: case 3: case 4: if (business) temperature = Thermostat.t20; else temperature = Thermostat.t15; break; case 5: case 6: case 7: if (business) temperature = Thermostat.t24; else temperature = Thermostat.t27; break; case 8: case 9: case 10: if (business) temperature = Thermostat.t20; else temperature = Thermostat.t16; break; } return temperature.ordinal(); } /** * Wykonuje 20 testów i wypisuje statystyki */ private static void test() { int learningLength = 150; ArrayList list = new ArrayList(); Integer sum = 0; for (int i = 0; i < 20; i++) { list.add(makeOneTest(learningLength, false)); } for (Integer var : list) { sum += var; } double percent = (double) sum / (double)(2016 - 150) / (double)list.size(); Collections.sort(list); ArrayList listc45 = new ArrayList(); Integer sumc45 = 0; for (int i = 0; i < 20; i++) { listc45.add(makeOneTest(learningLength, true)); } for (Integer var : listc45) { sumc45 += var; } double percentc45 = (double) sumc45 / (double)(2016 - 150) / (double)listc45.size(); Collections.sort(listc45); System.out.println("Średnia skuteczność dla ID3: "+percent); System.out.println("Najgorsza: "+list.get(0)); System.out.println("Najlepsza: "+list.get(listc45.size()-1)); System.out.println("Średnia skuteczność dla C4.5: "+percentc45); System.out.println("Najgorsza: "+listc45.get(0)); System.out.println("Najlepsza: "+listc45.get(listc45.size()-1)); } }