package com.algorithm;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ejml.data.DenseMatrix64F;
/**
 * Information-gain (ID3-style) decision tree over a numeric feature matrix.
 *
 * <p>Each row of the {@code DenseMatrix64F} is one sample and each column one
 * feature; feature values are treated as discrete categories (compared for
 * exact equality). Class labels are carried in a parallel {@code String[]}
 * whose i-th entry belongs to row i.
 */
public class tree {
    /** Root node stored on this instance; only read/written by the accessors below. */
    private node root;

    /** @return the root node previously stored with {@link #setRoot(node)}, or {@code null} */
    public node getRoot() {
        return root;
    }

    /** @param root the node to store as this tree's root */
    public void setRoot(node root) {
        this.root = root;
    }

    /**
     * Logarithm of {@code value} in an arbitrary base, via the change-of-base rule.
     *
     * @param value the number whose logarithm is taken
     * @param base  the logarithm base
     * @return {@code ln(value) / ln(base)}
     */
    public static double log(double value, double base) {
        return Math.log(value) / Math.log(base);
    }

    /**
     * Shannon entropy (in bits) of a label distribution.
     *
     * @param labels class labels, one per sample
     * @return {@code -sum(p * log2(p))} over the distinct labels; 0 for an empty array
     */
    public static double calcShannonEnt(String[] labels) {
        double shannonEnt = 0;
        double total = labels.length;
        for (Integer count : countLabels(labels).values()) {
            double prob = count / total;
            shannonEnt -= prob * log(prob, 2);
        }
        return shannonEnt;
    }

    /**
     * Selects the rows whose feature {@code axis} equals {@code value} and drops
     * that feature column from the result.
     *
     * @param datas  feature matrix, one sample per row
     * @param labels labels parallel to the rows of {@code datas}
     * @param axis   column index of the feature to filter on (and remove)
     * @param value  feature value a row must have (exact {@code ==} match) to be kept
     * @return the matching rows (with {@code numCols - 1} columns) and their labels
     */
    public static DataInfo splitDataSet(DenseMatrix64F datas, String[] labels, int axis, double value) {
        // First pass: find the matching rows so the result can be allocated once
        // at its final size instead of being reshaped repeatedly.
        List<Integer> matches = new ArrayList<Integer>();
        for (int row = 0; row < datas.numRows; row++) {
            if (datas.get(row, axis) == value) {
                matches.add(row);
            }
        }

        DenseMatrix64F subset = new DenseMatrix64F(matches.size(), datas.numCols - 1);
        String[] subsetLabels = new String[matches.size()];
        for (int out = 0; out < matches.size(); out++) {
            int row = matches.get(out);
            // The output column index must persist across the column loop so
            // that each kept feature lands in its own column.
            int outCol = 0;
            for (int col = 0; col < datas.numCols; col++) {
                if (col != axis) {
                    subset.set(out, outCol, datas.get(row, col));
                    outCol++;
                }
            }
            subsetLabels[out] = labels[row];
        }

        DataInfo di = new DataInfo();
        di.setDatas(subset);
        di.setLabels(subsetLabels);
        return di;
    }

    /**
     * Finds the feature whose split yields the largest information gain.
     *
     * @param datas  feature matrix, one sample per row
     * @param labels labels parallel to the rows of {@code datas}
     * @return the column index of the best feature, or {@code -1} when no
     *         feature gives a strictly positive gain
     */
    public static int chooseBestFeatureToSplit(DenseMatrix64F datas, String[] labels) {
        double baseEntropy = calcShannonEnt(labels);
        double bestInfoGain = 0.0;
        int bestFeature = -1;
        for (int col = 0; col < datas.numCols; col++) {
            double newEntropy = 0.0;
            for (Double featureValue : distinctValues(datas, col)) {
                DataInfo part = splitDataSet(datas, labels, col, featureValue);
                // Weight by the number of samples in the subset; the label count
                // is used because it is correct even when the subset has no
                // remaining feature columns.
                double prob = ((double) part.getLabels().length) / ((double) datas.numRows);
                newEntropy += prob * calcShannonEnt(part.getLabels());
            }
            double infoGain = baseEntropy - newEntropy;
            if (infoGain > bestInfoGain) {
                bestInfoGain = infoGain;
                bestFeature = col;
            }
        }
        return bestFeature;
    }

    /**
     * Copies {@code src} without {@code item}, also collapsing repeated entries
     * so that each remaining string appears once (first-occurrence order).
     *
     * @param src  source strings; elements must be non-null
     * @param item the string to leave out
     * @return a new array with {@code item} and duplicates removed
     */
    public static String[] removeItem(String[] src, String item) {
        List<String> kept = new ArrayList<String>();
        for (String candidate : src) {
            if (!candidate.equals(item) && !kept.contains(candidate)) {
                kept.add(candidate);
            }
        }
        return kept.toArray(new String[kept.size()]);
    }

    /**
     * Recursively builds a decision tree.
     *
     * <p>A leaf is produced when all labels agree, when no attributes remain,
     * or when no remaining feature has positive information gain (e.g.
     * identical feature rows carrying different labels); in the latter two
     * cases the leaf is named after the most frequent label.
     *
     * @param datas  feature matrix, one sample per row
     * @param labels labels parallel to the rows of {@code datas}
     * @param attrs  attribute names parallel to the columns of {@code datas}
     * @return the root node of the (sub)tree; never {@code null}
     */
    public static node createTree(DenseMatrix64F datas, String[] labels, String[] attrs) {
        node nd = new node();
        Map<String, Integer> labelCounts = countLabels(labels);

        // Pure subset: every sample carries the same label.
        if (labelCounts.size() == 1) {
            nd.setName(labels[0]);
            return nd;
        }
        // Attributes exhausted: fall back to a majority vote.
        if (attrs.length == 0) {
            nd.setName(majorityLabel(labelCounts));
            return nd;
        }

        int bestFeat = chooseBestFeatureToSplit(datas, labels);
        // No feature separates the labels any further; indexing attrs with -1
        // would throw, so stop here with a majority-vote leaf instead.
        if (bestFeat < 0) {
            nd.setName(majorityLabel(labelCounts));
            return nd;
        }

        nd.setName(attrs[bestFeat]);
        String[] subAttrs = removeItem(attrs, attrs[bestFeat]);
        for (Double featureValue : distinctValues(datas, bestFeat)) {
            DataInfo part = splitDataSet(datas, labels, bestFeat, featureValue);
            node child = createTree(part.getDatas(), part.getLabels(), subAttrs);
            nd.getChilds().put(featureValue, child);
        }
        return nd;
    }

    /**
     * Classifies the first row of {@code textDatas} by walking the tree.
     *
     * @param tree      root of the (sub)tree to evaluate
     * @param attrs     attribute names parallel to the columns of {@code textDatas}
     * @param textDatas matrix whose row 0 holds the sample to classify
     * @return the label of the leaf reached, or {@code ""} when the sample's
     *         value for a tested attribute has no matching branch
     * @throws IllegalArgumentException if an inner node tests an attribute that
     *         is not present in {@code attrs}
     */
    public static String classify(node tree, String[] attrs, DenseMatrix64F textDatas) {
        // A node without children is a leaf whose name is the class label.
        if (tree.getChilds().size() == 0) {
            return tree.getName();
        }

        int attrInx = Arrays.asList(attrs).indexOf(tree.getName());
        if (attrInx < 0) {
            throw new IllegalArgumentException(
                    "Attribute '" + tree.getName() + "' tested by the tree is not present in attrs");
        }

        double sampleValue = textDatas.get(0, attrInx);
        for (Map.Entry<Double, node> entry : tree.getChilds().entrySet()) {
            if (sampleValue == entry.getKey().doubleValue()) {
                return classify(entry.getValue(), attrs, textDatas);
            }
        }
        return "";
    }

    /** Counts how many times each label occurs. */
    private static Map<String, Integer> countLabels(String[] labels) {
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String label : labels) {
            Integer seen = counts.get(label);
            counts.put(label, seen == null ? 1 : seen + 1);
        }
        return counts;
    }

    /** Returns the label with the highest count, or {@code ""} when the map is empty. */
    private static String majorityLabel(Map<String, Integer> labelCounts) {
        int bestCount = 0;
        String bestLabel = "";
        for (Map.Entry<String, Integer> entry : labelCounts.entrySet()) {
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                bestLabel = entry.getKey();
            }
        }
        return bestLabel;
    }

    /** Collects the distinct values of one column, in first-occurrence order. */
    private static List<Double> distinctValues(DenseMatrix64F datas, int col) {
        List<Double> distinct = new ArrayList<Double>();
        for (int row = 0; row < datas.numRows; row++) {
            Double cell = datas.get(row, col);
            if (!distinct.contains(cell)) {
                distinct.add(cell);
            }
        }
        return distinct;
    }

    /**
     * Demo: trains on a five-sample, two-feature data set and prints the
     * predicted label for the sample (1, 1).
     */
    public static void main(String[] args) {
        DenseMatrix64F datas = new DenseMatrix64F(5, 2);
        datas.set(0, 0, 1);
        datas.set(0, 1, 1);
        datas.set(1, 0, 1);
        datas.set(1, 1, 1);
        datas.set(2, 0, 1);
        datas.set(2, 1, 0);
        datas.set(3, 0, 0);
        datas.set(3, 1, 1);
        datas.set(4, 0, 0);
        datas.set(4, 1, 1);

        DenseMatrix64F textDatas = new DenseMatrix64F(1, 2);
        textDatas.set(0, 0, 1);
        textDatas.set(0, 1, 1);

        String[] labels = {"yes", "yes", "no", "no", "no"};
        String[] attrs = {"no surfacing", "flippers"};

        node root = createTree(datas, labels, attrs);
        System.out.println(classify(root, attrs, textDatas));
    }
}