A learning summary of Spark MLlib (Java version)
This post walks through how to build a machine learning model with Spark MLlib and use it to predict new data. The overall workflow is: load the data, convert it to a DataFrame, extract features (normalization followed by chi-squared feature selection), train and save the model, and finally apply the saved models to new samples.
Loading the data
For loading and saving data, MLlib provides the MLUtils class, described in the API docs as "Helper methods to load, save and pre-process data used in MLLib." The data used in this post is the sample_ data set that ships with Spark, which contains 100 samples with 658 features, stored in LibSVM format.
Loading the LibSVM file
JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();
The LabeledPoint type corresponds to the LibSVM file format: each record consists of a label (double) and a feature vector (Vector).
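For illustration (the values here are made up), a LibSVM line such as 1.0 3:2.5 10:0.7 parses into a LabeledPoint with label 1.0 and a sparse 658-dimensional feature vector; note that LibSVM indices are 1-based while Vectors.sparse uses 0-based indices:
// Hypothetical record equivalent to the LibSVM line "1.0 3:2.5 10:0.7"
LabeledPoint p = new LabeledPoint(1.0, Vectors.sparse(658, new int[]{2, 9}, new double[]{2.5, 0.7}));
System.out.println(p.label() + " -> " + p.features());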
Converting to the DataFrame type
JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, pty()),
new StructField("features", new VectorUDT(), false, pty()),
});
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);
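To sanity-check the conversion, the resulting DataFrame can be inspected directly (a quick sketch, not part of the original pipeline):
df.printSchema();  // should show two columns: label (double) and features (vector)
df.show(5);        // print the first five rows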
DataFrame: a DataFrame is a distributed collection of data organized into named columns. Conceptually it is equivalent to a table in a relational database or a data frame in Python or R, but with more optimization under the hood. A DataFrame can be built from structured data files, Hive tables, external databases, or an existing RDD.
SQLContext: the entry point to all Spark SQL functionality is the SQLContext class or one of its subclasses. Creating a basic SQLContext only requires a SparkContext.
Feature extraction
Normalizing the features
StandardScaler scaler = new StandardScaler().setInputCol("features").setOutputCol("normFeatures").setWithStd(true);
DataFrame scalerDF = scaler.fit(df).transform(df);
scaler.save(this.scalerModelPath);
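The code above saves the StandardScaler estimator itself and fits it again on the test data later. A common alternative (a sketch, assuming the Spark version in use supports ML persistence for StandardScalerModel, as it does here for ChiSqSelectorModel) is to persist the fitted StandardScalerModel, so that new samples are scaled with the mean and standard deviation computed on the training data:
// Alternative sketch: persist the fitted model (training statistics) instead of the estimator
StandardScalerModel scalerModel = scaler.fit(df);
DataFrame scaledDF = scalerModel.transform(df);
scalerModel.save(this.scalerModelPath);
// ...and at prediction time: StandardScalerModel.load(this.scalerModelPath).transform(testDF);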
Feature selection with the chi-squared statistic
ChiSqSelector selector = new ChiSqSelector().setNumTopFeatures(500).setFeaturesCol("normFeatures").setLabelCol("label").setOutputCol("selectedFeatures");
ChiSqSelectorModel chiModel = selector.fit(scalerDF);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
chiModel.save(this.featureSelectedModelPath);
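To see which columns survived the selection, the fitted model exposes the indices of the retained features (a small sketch; the exact indices depend on the data):
int[] kept = chiModel.selectedFeatures();
System.out.println("Number of selected features: " + kept.length);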
Training a machine learning model (SVM as an example)
// Convert back to the LabeledPoint type and train the model
JavaRDD<Row> selectedrows = selectedDF.javaRDD();
JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel());
// Train the SVM model and save it
int numIteration = 200;
SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration);
model.clearThreshold();
model.save(sc, this.mlModelPath);
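A note on clearThreshold(): once the threshold is cleared, model.predict returns the raw margin (a real-valued score) rather than a 0/1 label, which is why the evaluation code later compares the sign of the score with the label. If hard 0/1 predictions are preferred, the threshold can be restored (a sketch):
// Optional: restore a decision threshold so predict() returns 0.0 or 1.0 again
model.setThreshold(0.0);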
// Convert LabeledPoint to Row
static class LabeledPointToRow implements Function<LabeledPoint, Row> {
public Row call(LabeledPoint p) throws Exception {
double label = p.label();
Vector vector = p.features();
return RowFactory.create(label, vector);
}
}
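If the project is built with Java 8, the same converter can be written inline as a lambda instead of a named Function class (a sketch; the behavior is identical):
JavaRDD<Row> jrowLambda = lpdata.map(p -> RowFactory.create(p.label(), p.features()));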
// Convert Row back to LabeledPoint
static class RowToLabel implements Function<Row, LabeledPoint> {
public LabeledPoint call(Row r) throws Exception {
Vector features = r.getAs(1);
double label = r.getDouble(0);
return new LabeledPoint(label, features);
}
}
Testing new samples
Before testing new samples, the same data conversion and feature-extraction steps must be applied to them. That is why, while training the model, we saved not only the machine learning model but also the intermediate feature-extraction models. The code is as follows:
// Initialize Spark
SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");
conf.set("", "2147480000");
SparkContext sc = new SparkContext(conf);
// Load the test data
JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD();
// Convert to the DataFrame type
JavaRDD<Row> jrow = testData.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
new StructField("label", DataTypes.DoubleType, false, pty()),
new StructField("features", new VectorUDT(), false, pty()),
});
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);
// Normalize the data
StandardScaler scaler = StandardScaler.load(this.scalerModelPath);
DataFrame scalerDF = scaler.fit(df).transform(df);
// Feature selection
ChiSqSelectorModel chiModel = ChiSqSelectorModel.load(this.featureSelectedModelPath);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
Testing the data set
// Convert the selected features back to LabeledPoint
JavaRDD<LabeledPoint> testset = selectedDF.javaRDD().map(new RowToLabel());
// Load the SVM model and run the prediction
SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath);
JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel));
static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> {
SVMModel model;
public Prediction(SVMModel model){
this.model = model;
}
public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
Double score = model.predict(p.features());
return new Tuple2<Double , Double>(score, p.label());
}
}
Computing the accuracy
double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count();
System.out.println(accuracy);
static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> {
public Boolean call(Tuple2<Double, Double> t) throws Exception {
double score = t._1();
double label = t._2();
System.out.print("score:" + score + ", label:"+ label);
if(score >= 0.0 && label >= 0.0) return true;
else if(score < 0.0 && label < 0.0) return true;
else return false;
}
}
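Since predictResult already holds (raw score, label) pairs, a complementary way to evaluate the model is the area under the ROC curve via MLlib's BinaryClassificationMetrics (a sketch; BinaryClassificationMetrics lives in org.apache.spark.mllib.evaluation):
// Re-wrap the (score, label) pairs in the element type BinaryClassificationMetrics expects
JavaRDD<Tuple2<Object, Object>> scoreAndLabels = predictResult.map(
        new Function<Tuple2<Double, Double>, Tuple2<Object, Object>>() {
            public Tuple2<Object, Object> call(Tuple2<Double, Double> t) throws Exception {
                return new Tuple2<Object, Object>(t._1(), t._2());
            }
        });
BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels));
System.out.println("Area under ROC = " + metrics.areaUnderROC());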
