`
cloudtech
  • 浏览: 4611050 次
  • 性别: Icon_minigender_1
  • 来自: 武汉
文章分类
社区版块
存档分类
最新评论

mahout决策树之Partial Implementation源码分析 part3

 
阅读更多

part3主要分析下名为“decisionforestbuilder”的Job的操作,上篇说到这个Job只有Mapper,所以也就是针对这个Mapper,即Step1Mapper类的分析:

Step1Mapper.java在org.apache.mahout.classifier.df.mapreduce.partial中,打开这个源文件,可以看到其操作主要有以下三点:

1. setup():主要是一些参数的设置,包括seed(随机化输入数据时用到的随机数种子),numTrees(要生成的决策树个数);

2. map(): 把原始的Text转为所需的格式,用到的类有DataConverter,Instance比如下面的数据转换:

因为转换的时候需要用到dataset,所以先贴出生成dataset的原始数据:

vhigh,vhigh,3,4,big,low,unacc
vhigh,vhigh,3,4,big,high,unacc
vhigh,vhigh,5more,2,big,med,unacc
vhigh,vhigh,5more,4,med,high,unacc
vhigh,high,2,4,med,high,unacc
vhigh,high,3,2,small,low,unacc
vhigh,high,3,4,small,med,unacc
vhigh,high,4,2,big,med,unacc
vhigh,high,4,2,big,high,unacc
vhigh,high,4,4,med,med,unacc
vhigh,high,5more,4,small,med,unacc
vhigh,high,5more,4,med,med,unacc
vhigh,med,2,2,med,low,unacc
vhigh,med,2,4,small,high,acc
vhigh,med,2,4,big,low,unacc
vhigh,med,4,more,small,high,acc
vhigh,med,5more,more,med,low,unacc
vhigh,low,2,2,small,high,unacc
vhigh,low,2,4,med,med,unacc
vhigh,low,5more,more,small,low,unacc
high,vhigh,2,2,big,med,unacc
high,vhigh,2,more,small,low,unacc
high,vhigh,3,2,small,med,unacc
high,vhigh,3,2,med,high,unacc
high,vhigh,5more,4,med,low,unacc
high,vhigh,5more,more,small,low,unacc
high,high,3,more,med,low,unacc
high,high,4,more,big,high,acc
high,high,5more,more,small,med,unacc
high,med,2,2,small,high,unacc
high,med,3,4,med,high,acc
high,med,4,2,med,low,unacc
high,med,4,2,med,med,unacc
med,vhigh,4,2,small,high,unacc
med,vhigh,4,more,big,high,acc
med,vhigh,5more,2,small,high,unacc
med,vhigh,5more,more,med,high,acc
med,high,2,2,small,low,unacc
med,high,2,2,big,high,unacc
med,high,2,4,big,high,acc
med,high,2,more,med,low,unacc
med,high,3,4,small,high,acc
med,high,4,2,small,med,unacc
med,med,2,4,med,med,acc
med,med,4,2,small,high,unacc
low,vhigh,2,4,med,med,unacc
low,vhigh,2,4,big,high,acc
low,vhigh,3,more,med,low,unacc
low,high,2,2,med,med,unacc
low,high,4,2,med,high,unacc
low,high,5more,more,small,high,acc
low,med,4,2,med,low,unacc
low,med,5more,2,small,high,unacc
low,low,3,more,med,low,unacc
生成的dataset如下:(查看dataset定义可以链接:http://blog.csdn.net/fansy1990/article/details/8443342

原始数据为:"vhigh,vhigh,2,2,small,med,unacc";"vhigh,high,3,4,med,low,unacc";"vhigh,low,3,2,med,high,unacc",则生成的List<Instance> instances 为:

{0:1.0,1:1.0,2:1.0,4:2.0,6:1.0};{0:1.0,1:2.0,3:2.0,5:2.0,6:1.0};{0:1.0,1:3.0,5:1.0,6:1.0}

注意:当属性的取值范围编码为0.0时,则不会显示出来,比如第一个数据的属性5为med,其编码为0.0,在instances[0]中就没有5:0.0的数据;可以用下面的代码进行测试:

package org.fansy.forest.test;

import java.io.IOException;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.df.data.DataConverter;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;

import com.google.common.collect.Lists;

public class TestDataConverter {

	/**
	 * 测试 DataConverter类的convert() method
	 * 对应的car_test.info 为
	 * { med,vhigh,high,low
	 * 	 med,vhigh,high,low
	 * 	 3,	2,	5more, 4
	 * 	 2,more,4
	 * 	 med,big,small
	 *   med,high,low
	 *   good,unacc,acc,vgood
	 * }
	 * 若为第一个则不显示,且对应编码为0.0,1.0,2.0,3.0,4.0 
	 * @param args
	 * @throws IOException 
	 */
	public static void main(String[] args) throws IOException {
		Path dsPath=new Path("/home/fansy/workspace/MahTestDemo/car_test.info");
		Dataset ds=Dataset.load(new Configuration(), dsPath);
		System.out.println(ds.getLabelString(0.0));
		DataConverter converter=new DataConverter(ds);
		String[] values={"vhigh,vhigh,2,2,small,med,unacc",   
				"vhigh,high,3,4,med,low,unacc",		
				"vhigh,low,3,2,med,high,unacc"};                          
		List<Instance> list=Lists.newArrayList();
		for(int i=0;i<values.length;i++){
			list.add(converter.convert(values[i]));
		}
		String str=toStr(list);
			System.out.println(str);
	 /* output :
		
		good
		{0:1.0,1:1.0,2:1.0,4:2.0,6:1.0}
		{0:1.0,1:2.0,3:2.0,5:2.0,6:1.0}
		{0:1.0,1:3.0,5:1.0,6:1.0}
		*/
	}
	
	public static String toStr(List<Instance> list){
		StringBuffer sb=new StringBuffer();

		for(int i=0;i<list.size();i++){
			sb.append(list.get(i).toString()+"\n");
		}
		return sb.toString().substring(0,sb.length()-1);
	}

}
不过需要在eclipse建立的MR工程导入这个jar包http://download.csdn.net/detail/fansy1990/5030740

3. cleanup():

(1)创建data:Data data = new Data(getDataset(), instances);

(2)创建Bagging(此类后面会有细说):Bagging bagging = new Bagging(getTreeBuilder(), data);

(3)根据设定的numTrees个数依次输出每棵树(非准确代码,只代表大概意思):

for(int index=0;index<numTrees;index++){
      Node tree=bagging.build();
      write(index,tree);
}
查看Bagging的build方法如下:

public Node build(Random rng) {
    log.debug("Bagging...");
    Arrays.fill(sampled, false);
    Data bag = data.bagging(rng, sampled);   
    log.debug("Building...");
    return treeBuilder.build(rng, bag);
  }
这里有两个操作:

<1> data 重新赋值了,使用了data.bagging()方法,打开此方法可以看到:

public Data bagging(Random rng, boolean[] sampled) {
    int datasize = size();
    List<Instance> bag = Lists.newArrayListWithCapacity(datasize);
    
    for (int i = 0; i < datasize; i++) {
      int index = rng.nextInt(datasize);
      bag.add(instances.get(index));
      sampled[index] = true;
    }
    return new Data(dataset, bag);
  }
上面的操作就是随机的选出原始data中的若干个进行复制到其大小和原始数据一样的数据进行返回;
<2>treeBuilder.build()方法建树(不得不说这个方法最长了):

(首先说明一点,这里使用的数据全部是离散的,所以在代码中也只考虑离散的部分)

方法进入后首先判断两个参数,可以忽略,下面也只是主要的操作:

<2.1>使用isIdentical(data)进行判断,如果符合则直接输出叶子节点。这个方法判断的是否所有data的属性(不包括已被选择的属性)值是否和data[0]一样;

<2.2>使用data.identicalLabel()进行判断,如果符合条件,则同样直接输出叶子节点。这个方法判断的是否全部的data都是属于同一个类;

<2.3>int[] attributes=randomAttributes(rng,selected,m)从全部data中的属性中随机选择m个属性;

igSplit =new OptIgSplit() % 初始化data数据统计类

 for (int attr : attributes) {
      Split split = igSplit.computeSplit(data, attr);
      if (best == null || best.getIg() < split.getIg()) {
        best = split;
      }
    }
计算“误差”最小的属性值(即m个属性中的一个),何为“误差”,这个就要看igSplit.computeSplit()方法了,这个方法调用了categoricalSplit()方法,

所以看categoricalSplit()方法。这个方法其实也就是计算了一个ig(double)值而已

ig=hy-hyx,

hy=entropy(count[],dataSize)

hy代表所有data的熵,每个label的概率使用每个label(最终的类标号)的记录数除以总记录数;

hyxi=(size(xi)/size(Y))*entropy(counxi[],size(xi));

总的来说可以理解为这m个属性中一个对最终的label分类的影响最大的属性即可,比如为mm;

<2.4>double[] values =data.values(best.getAttr),即上面的mm属性的取值范围,比如对于属性5,则values[] 对应为{0.0,2.0,1.0}(我自己做测试的时候是这个)为什么这里不是按顺序呢?因为在data里面存储这个的时候使用了HashSet,导致顺序不一定,但是一旦顺序被确定了,后面都是按照这个顺序了。

<2.5>Data[] subsets = new Data[values.length];subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index])); 这两句的意思就是把data按mm属性分为n部分,n值为mm属性的取值范围的个数,比如第5个属性的范围数目为3,则n就是3;

<2.6>继续判断subsets[i]

for(int i=0;i<subsets.size();i++){

childen[i]=build(rng,subsets[i])

}


继续参考下篇:mahout决策树之Partial Implementation源码分析 part3_


分享,快乐,成长


转载请注明出处:http://blog.csdn.net/fansy1990


分享到:
评论

相关推荐

    mahout Algorithms源码分析

    mahoutAlgorithms源码分析 mahout代码解析

    mahout-core-0.7-job.jar

    用于测试mahout中的决策树 ,即Partial Implementation用到的测试jar包。所谓的测试其实也只是把相应的数据可以打印出来,方便单机调试,理解算法实现原理而已。

    Mahout源码

    Mahout是一个Java的机器学习库。Mahout的完整源代码,基于maven,可以轻易导入工程中

    mahout源码

    mahout,朴素贝叶斯分类,中文分词,mahout,朴素贝叶斯分类,中文分词,

    Mahout_In_Action(源码)

    Mahout in Action 源码,结合Mahout in Action 学习数据挖掘,比较容易理解

    Mahout算法调用展示平台2.1-part3

    第三部分 功能主要包括四个方面:集群配置、集群算法监控、Hadoop模块、Mahout模块。 详情参考《Mahout算法调用展示平台2.1》

    mahout in action中的源码

    该资源是mahout in action 中的源码,适用于自学,可在github下载:https://github.com/tdunning/MiA

    Mahout算法调用展示平台2.1-part2

    第二部分 功能主要包括四个方面:集群配置、集群算法监控、Hadoop模块、Mahout模块。 详情参考《Mahout算法调用展示平台2.1》

    mahout-distribution-0.5-src.zip mahout 源码包

    mahout-distribution-0.5-src.zip mahout 源码包

    Apache.Mahout.Beyond.MapReduce.1523775785

    Apache Mahout: Beyond MapReduce. Distributed algorithm design This book is about designing mathematical and Machine Learning algorithms using the Apache Mahout "Samsara" platform. The material takes...

    svd mahout算法

    svd算法的工具类,直接调用出结果,调用及设置方式参考http://blog.csdn.net/fansy1990 &lt;mahout源码分析之DistributedLanczosSolver(七)&gt;

    mahout 0.7 src

    mahout 0.7 src, mahout 源码包, hadoop 机器学习子项目 mahout 源码包

    mahout-distribution-0.5.tar.gz + 源码

    mahout实战 源码 mahout实战 配套 mahout-distribution-0.5.tar.gz 版本

    MAHOUT源码包

    Mahout 是 Apache Software Foundation(ASF) 旗下的一个开源项目,提供一些可扩展的机器学习领域经典算法的实现,旨在帮助开发人员更加方便快捷地创建智能应用程序。Mahout包含许多实现,包括聚类、分类、推荐过滤...

    mahout api 学习资料

    mahout_help,mahout的java api帮助文档,可以帮你更轻松掌握mahout

    mahout0.9源码(支持hadoop2)

    mahout0.9的源码,支持hadoop2,需要自行使用mvn编译。mvn编译使用命令: mvn clean install -Dhadoop2 -Dhadoop.2.version=2.2.0 -DskipTests

    hadoop 2.4.1+mahout0.9环境搭建

    最新的HADOOP2.4.1版本不支持MAHOUT 0.9,本MAHOUT 0.9是经过修改官方MAHOUT 0.9源代码后的源码包,可直接导入ECLIPS中编译、安装,也可通过命令行进行。

    Mahout in Action

    PART 3 CLASSIFICATION ........................................................225 13 ■ Introduction to classification 227 14 ■ Training a classifier 255 15 ■ Evaluating and tuning a classifier 281 ...

    maven_mahout_template-mahout-0.8

    maven_mahout_template-mahout-0.8

Global site tag (gtag.js) - Google Analytics