# Step 1: Feature Engineering

Processed training data (first rows; the last column, `Survived`, is the label):

```csv
SibSp,Parch,Cabin_No,Cabin_Yes,Embarked_C,Embarked_Q,Embarked_S,Sex_female,Sex_male,Pclass_1,Pclass_2,Pclass_3,Age_scaled,Fare_scaled,Survived
1,0,1,0,0,0,1,0,1,0,0,1,-0.56136323207,-0.502445171436,0
1,0,0,1,1,0,0,1,0,1,0,0,0.613181832266,0.786845293588,1
0,0,1,0,0,0,1,1,0,0,0,1,-0.267726965986,-0.488854257585,1
1,0,0,1,0,0,1,1,0,1,0,0,0.392954632703,0.420730236069,1
0,0,1,0,0,0,1,0,1,0,0,1,0.392954632703,-0.486337421687,0
0,0,1,0,0,1,0,0,1,0,0,1,-0.427101530014,-0.478116428909,0
```

Processed test data (same feature columns, no `Survived` label):

```csv
SibSp,Parch,Cabin_No,Cabin_Yes,Embarked_C,Embarked_Q,Embarked_S,Sex_female,Sex_male,Pclass_1,Pclass_2,Pclass_3,Age_scaled,Fare_scaled
0,0,1,0,0,1,0,0,1,0,0,1,0.307534608854,-0.496637106488
1,0,1,0,0,0,1,1,0,0,0,1,1.25623006816,-0.511497104137
0,0,1,0,0,1,0,0,1,0,1,0,2.39466461933,-0.463334726327
0,0,1,0,0,0,1,0,1,0,0,1,-0.261682666729,-0.481703633213
1,1,1,0,0,0,1,1,0,0,0,1,-0.641160850452,-0.416740425935
```
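The `Age_scaled` and `Fare_scaled` columns above follow the usual standardization formula: subtract the column mean, divide by the standard deviation. A dependency-free plain-Java sketch of that transform (the sample ages below are made up for illustration, not taken from the dataset):

```java
public class StandardScaleSketch {
    // Standardize a column: z = (x - mean) / std (population std, ddof = 0).
    static double[] scale(double[] xs) {
        double mean = 0, var = 0;
        for (double x : xs) mean += x;
        mean /= xs.length;
        for (double x : xs) var += (x - mean) * (x - mean);
        double std = Math.sqrt(var / xs.length);
        double[] out = new double[xs.length];
        for (int i = 0; i < xs.length; i++) out[i] = (xs[i] - mean) / std;
        return out;
    }

    public static void main(String[] args) {
        double[] ages = {22, 38, 26, 35, 35, 54};  // illustrative values
        for (double z : scale(ages)) System.out.printf("%.4f%n", z);
    }
}
```

After scaling, each column has mean 0 and standard deviation 1, which keeps the gradient-based optimizer in logistic regression well-conditioned.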

# Step 2: Modeling and Prediction with Spark MLlib's Built-in LogisticRegression

```java
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class TitanicLogisticRegressionWithElasticNet {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaLogisticRegressionWithElasticNetExample")
      .getOrCreate();

    // Load the preprocessed training csv (path passed on the command line);
    // treat the first line as the header and let Spark infer column types.
    Dataset<Row> training = spark.read().format("csv")
      .option("header", true)
      .option("inferSchema", true)
      .load(args[0]);

    // Assemble the 14 feature columns into a single "features" vector column.
    String origStr = "SibSp,Parch,Cabin_No,Cabin_Yes,Embarked_C,Embarked_Q,Embarked_S,Sex_female,Sex_male,Pclass_1,Pclass_2,Pclass_3,Age_scaled,Fare_scaled";
    String[] arrOrig = origStr.split(",");
    VectorAssembler vectorAssem = new VectorAssembler()
      .setInputCols(arrOrig).setOutputCol("features");
    Dataset<Row> feaTrain = vectorAssem.transform(training);
    feaTrain = feaTrain.select("features", "Survived");
    System.out.println("\n------- after select:");
    feaTrain.printSchema();
    feaTrain.show(5, false);

    LogisticRegression lr = new LogisticRegression()
      .setLabelCol("Survived")
      .setMaxIter(10000)
      .setRegParam(0.0)
      .setElasticNetParam(0.8);

    // Fit the model
    LogisticRegressionModel lrModel = lr.fit(feaTrain);

    // Print the coefficients and intercept for logistic regression
    System.out.println("\n+++++++++ Binomial logistic regression's coefficients: "
      + lrModel.coefficients() + "\nBinomial intercept: " + lrModel.intercept());

    // Apply the same assembler to the test data, then predict.
    Dataset<Row> testData = spark.read().format("csv")
      .option("header", true)
      .option("inferSchema", true)
      .load(args[1]);
    Dataset<Row> feaTest = vectorAssem.transform(testData);
    feaTest = feaTest.select("features");
    Dataset<Row> result = lrModel.transform(feaTest);
    result = result.withColumnRenamed("prediction", "Survived");
    System.out.println("\n====== after rename:");
    result.printSchema();
    result.show(5, false);

    spark.stop();
  }
}
```
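Under the hood, the fitted model's `transform` scores each row by applying the logistic function to the dot product of the printed coefficients with the feature vector, plus the intercept. A minimal plain-Java sketch of that scoring step (the weights and feature values below are made up for illustration):

```java
public class LogisticPredictSketch {
    // P(Survived = 1 | x) = sigmoid(w . x + b)
    static double predictProb(double[] w, double[] x, double b) {
        double z = b;
        for (int i = 0; i < w.length; i++) z += w[i] * x[i];
        return 1.0 / (1.0 + Math.exp(-z));
    }

    public static void main(String[] args) {
        double[] w = {1.5, -2.0};   // illustrative coefficients
        double[] x = {1.0, 0.5};    // one illustrative feature row
        double p = predictProb(w, x, 0.25);
        // Default decision threshold is 0.5
        System.out.println(p >= 0.5 ? 1 : 0);
    }
}
```

The `prediction` column Spark produces is exactly this thresholded probability, with 0.5 as the default cutoff.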

### Problems Encountered While Writing the Code

• Problem 1: reading a csv file
The DataFrame API docs do say that a DataFrame can be created from a csv file, but in practice I hit a few snags, in this order (the file path in the snippets below is a hypothetical placeholder):

First attempt: with no options, Spark reads the header line as an ordinary data row and types every column as a string:

```java
Dataset<Row> training = spark.read().format("csv")
    .load("train_processed.csv");  // hypothetical path
```

Adding `option("header", true)` makes Spark use the first line as the column names, but every column is still typed as a string:

```java
Dataset<Row> training = spark.read().format("csv")
    .option("header", true)
    .load("train_processed.csv");  // hypothetical path
```

Finally, adding `option("inferSchema", true)` makes Spark scan the data and infer numeric column types, which is what `LogisticRegression` needs:

```java
Dataset<Row> training = spark.read().format("csv")
    .option("header", true)
    .option("inferSchema", true)
    .load("train_processed.csv");  // hypothetical path
```

• Problem 2: how to produce the feature vector column that `LogisticRegression` expects
`VectorAssembler` packs the listed input columns into a single `features` vector column, which is the input format MLlib estimators expect:

```java
String origStr = "SibSp,Parch,Cabin_No,Cabin_Yes,Embarked_C,Embarked_Q,Embarked_S,Sex_female,Sex_male,Pclass_1,Pclass_2,Pclass_3,Age_scaled,Fare_scaled";
String[] arrOrig = origStr.split(",");
VectorAssembler vectorAssem = new VectorAssembler()
    .setInputCols(arrOrig).setOutputCol("features");
Dataset<Row> feaTrain = vectorAssem.transform(training);
feaTrain = feaTrain.select("features", "Survived");
```
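Conceptually, `VectorAssembler` just concatenates the chosen columns of each row into one dense vector. A dependency-free sketch of that idea (the column layout below is illustrative, not the real 15-column schema):

```java
import java.util.Arrays;

public class AssembleSketch {
    // Pick the input columns of one row and pack them into a single
    // dense feature vector, like VectorAssembler does per row.
    static double[] assemble(double[] row, int[] inputCols) {
        double[] features = new double[inputCols.length];
        for (int i = 0; i < inputCols.length; i++) {
            features[i] = row[inputCols[i]];
        }
        return features;
    }

    public static void main(String[] args) {
        double[] row = {1, 0, 1, 0, -0.56, -0.50, 0};  // last column = label
        int[] inputCols = {0, 1, 2, 3, 4, 5};           // every column but the label
        System.out.println(Arrays.toString(assemble(row, inputCols)));
    }
}
```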

# Step 3: Comparing Against Spark MLlib's LogisticRegression Results

```java
LogisticRegression lr = new LogisticRegression()
    .setLabelCol("Survived")
    .setMaxIter(10000)
    .setRegParam(0.0)
    .setElasticNetParam(0.8);
```
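For reference, Spark's elastic-net penalty is `regParam * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2)`, where `alpha` is `elasticNetParam`. So with `setRegParam(0.0)` the penalty vanishes no matter what `elasticNetParam` is, and the settings above fit an effectively unregularized model. A small plain-Java sketch of the penalty (the weight values are illustrative):

```java
public class ElasticNetSketch {
    // Elastic-net penalty as Spark defines it:
    // regParam * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2)
    static double penalty(double[] w, double regParam, double alpha) {
        double l1 = 0, l2sq = 0;
        for (double v : w) { l1 += Math.abs(v); l2sq += v * v; }
        return regParam * (alpha * l1 + (1 - alpha) / 2 * l2sq);
    }

    public static void main(String[] args) {
        double[] w = {1.0, -2.0};
        // regParam = 0.0: no regularization, whatever alpha is
        System.out.println(penalty(w, 0.0, 0.8));  // 0.0
        // regParam = 0.1: a mix of L1 (80%) and L2 (20%)
        System.out.println(penalty(w, 0.1, 0.8));
    }
}
```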

To my surprise, Spark's predictions and those of my own hand-written implementation were exactly identical.

# Summary

• Thanks to Google, to Stack Overflow, and to the Programming Guide and examples on the official Spark site.
• I was thoroughly unimpressed by http://spark.apache.org/docs/latest/api/java/index.html; unless I've been using it wrong, it felt nearly useless (if I have been using it wrong, please correct me in the comments, thanks in advance!).
• The source code was also very helpful, but I still need to brush up on Scala syntax before I can read it fluently.
• I'm not yet fluent with RDD and DataFrame transformations; that really is foundational, and I need to find opportunities to study and practice it more.