三种特征选择方法及Spark MLlib调用实例(Scala Java python)

整理文档很辛苦,赏杯茶钱您下走!

免费阅读已结束,点击下载阅读编辑剩下 ...

阅读已结束,您可以下载文档离线阅读编辑

资源描述

三种特征选择方法及SparkMLlib调用实例(Scala/Java/python)VectorSlicer算法介绍:VectorSlicer是一个转换器输入特征向量,输出原始特征向量子集。VectorSlicer接收带有特定索引的向量列,通过对这些索引的值进行筛选得到新的向量集。可接受如下两种索引1.整数索引,setIndices()。2.字符串索引代表向量中特征的名字,此类要求向量列有AttributeGroup,因为该工具根据Attribute来匹配名字字段。指定整数或者字符串类型都是可以的。另外,同时使用整数索引和字符串名字也是可以的。不允许使用重复的特征,所以所选的索引或者名字必须是没有独一的。注意如果使用名字特征,当遇到空值的时候将会报错。输出将会首先按照所选的数字索引排序(按输入顺序),其次按名字排序(按输入顺序)。示例:假设我们有一个DataFrame含有userFeatures列:userFeatures------------------[0.0,10.0,0.5]userFeatures是一个向量列包含3个用户特征。假设userFeatures的第一列全为0,我们希望删除它并且只选择后两项。我们可以通过索引setIndices(1,2)来选择后两项并产生一个新的features列:userFeatures|features------------------|-----------------------------[0.0,10.0,0.5]|[10.0,0.5]假设我们还有如同[f1,f2,f3]的属性,那可以通过名字setNames(f2,f3)的形式来选择:userFeatures|features------------------|-----------------------------[0.0,10.0,0.5]|[10.0,0.5][f1,f2,f3]|[f2,f3]调用示例:Scala:[plain]viewplaincopyimportjava.util.Arraysimportorg.apache.spark.ml.attribute.{Attribute,AttributeGroup,NumericAttribute}importorg.apache.spark.ml.feature.VectorSlicerimportorg.apache.spark.ml.linalg.Vectorsimportorg.apache.spark.sql.Rowimportorg.apache.spark.sql.types.StructTypevaldata=Arrays.asList(Row(Vectors.dense(-2.0,2.3,0.0)))valdefaultAttr=NumericAttribute.defaultAttrvalattrs=Array(f1,f2,f3).map(defaultAttr.withName)valattrGroup=newAttributeGroup(userFeatures,attrs.asInstanceOf[Array[Attribute]])valdataset=spark.createDataFrame(data,StructType(Array(attrGroup.toStructField())))valslicer=newVectorSlicer().setInputCol(userFeatures).setOutputCol(features)slicer.setIndices(Array(1)).setNames(Array(f3))//orslicer.setIndices(Array(1,2)),orslicer.setNames(Array(f2,f3))valoutput=slicer.transform(dataset)println(output.select(userFeatures,features).first())Java:[java]viewplaincopyimportjava.util.List;importcom.google.common.collect.Lists;importorg.apache.spark.ml.attribute.Attribute;importorg.apache.spark.ml.attribute.AttributeGroup;importorg.apache.spark.ml.attribute.NumericAttribute;importorg.apache.spark.ml.feature.VectorSlicer;importorg.apache.spark.ml.linalg.Vectors;importorg.apache.spark.sql.Dataset;importorg.apache.spark.sql.Row;importorg.apache.spark.sql.RowFactory;importorg.apache.spark.sql.types.*;Attribute[]attrs=newAttribute[]{NumericAttribute.defaultAttr().withName(f1),NumericAttribute.defaultAttr().withName(f2),NumericAttribute.defaultAttr().withName(f3)};AttributeGroupgroup=newAttributeGroup(userFeatures,attrs);ListRowdata=Lists.newArrayList(RowFactory.create(Vectors.sparse(3,newint[]{0,1},newdouble[]{-2.0,2.3})),RowFactory.create(Vectors.dense(-2.0,2.3,0.0)));DatasetRowdataset=spark.createDataFrame(data,(newStructType()).add(group.toStructField()));VectorSlicervectorSlicer=newVectorSlicer().setInputCol(userFeatures).setOutputCol(features);vectorSlicer.setIndices(newint[]{1}).setNames(newString[]{f3});//orslicer.setIndices(newint[]{1,2}),orslicer.setNames(newString[]{f2,f3})DatasetRowoutput=vectorSlicer.transform(dataset);System.out.println(output.select(userFeatures,features).first());Python:[python]viewplaincopyfrompyspark.ml.featureimportVectorSlicerfrompyspark.ml.linalgimportVectorsfrompyspark.sql.typesimportRowdf=spark.createDataFrame([Row(userFeatures=Vectors.sparse(3,{0:-2.0,1:2.3}),),Row(userFeatures=Vectors.dense([-2.0,2.3,0.0]),)])slicer=VectorSlicer(inputCol=userFeatures,outputCol=features,indices=[1])output=slicer.transform(df)output.select(userFeatures,features).show()RFormula算法介绍:RFormula通过R模型公式来选择列。支持R操作中的部分操作,包括‘~’,‘.’,‘:’,‘+’以及‘-‘,基本操作如下:1.~分隔目标和对象2.+合并对象,“+0”意味着删除空格3.:交互(数值相乘,类别二值化)4..除了目标外的全部列假设a和b为两列:1.y~a+b表示模型y~w0+w1*a+w2*b其中w0为截距,w1和w2为相关系数。2.y~a+b+a:b–1表示模型y~w1*a+w2*b+w3*a*b,其中w1,w2,w3是相关系数。RFormula产生一个向量特征列以及一个double或者字符串标签列。如果类别列是字符串类型,它将通过StringIndexer转换为double类型。如果标签列不存在,则输出中将通过规定的响应变量创造一个标签列。示例:假设我们有一个DataFrame含有id,country,hour和clicked四列:id|country|hour|clicked---|---------|------|---------7|US|18|1.08|CA|12|0.09|NZ|15|0.0如果我们使用RFormula公式clicked~country+hour,则表明我们希望基于country和hour预测clicked,通过转换我们可以得到如下DataFrame:id|country|hour|clicked|features|label---|---------|------|---------|------------------|-------7|US|18|1.0|[0.0,0.0,18.0]|1.08|CA|12|0.0|[0.0,1.0,12.0]|0.09|NZ|15|0.0|[1.0,0.0,15.0]|0.0调用示例:Scala:[plain]viewplaincopyimportorg.apache.spark.ml.feature.RFormulavaldataset=spark.createDataFrame(Seq((7,US,18,1.0),(8,CA,12,0.0),(9,NZ,15,0.0))).toDF(id,country,hour,clicked)valformula=newRFormula().setFormula(clicked~country+hour).setFeaturesCol(features).setLabelCol(label)valoutput=formula.fit(dataset).transform(dataset)output.select(features,label).show()Java:[java]viewplaincopyimportjava.util.Arrays;importjava.util.List;importorg.apache.spark.ml.feature.RFormula;importorg.apache.spark.sql.Dataset;importorg.apache.spark.sql.Row;importorg.apache.spark.sql.RowFactory;importorg.apache.spark.sql.types.StructField;importorg.apache.spark.sql.types.StructType;importstaticorg.apache.spark.sql.types.DataTypes.*;StructTypeschema=createStructType(newStructField[]{createStructField(id,IntegerType,false),createStructField(country,StringType,false),createStructF

1 / 7
下载文档,编辑使用

©2015-2020 m.777doc.com 三七文档.

备案号:鲁ICP备2024069028号-1 客服联系 QQ:2149211541

×
保存成功