Spark性能优化:基于分区进行操作

我的原创地址:https://dongkelun.com/2018/09/02/sparkMapPartitions/

前言(摘自Spark快速大数据分析)

基于分区对数据进行操作可以让我们避免为每个数据元素进行重复的配置工作。诸如打开数据库连接或创建随机数生成器等操作,都是我们应当尽量避免为每个元素都配置一次的工作。Spark 提供基于分区的map 和foreach,让你的部分代码只对RDD 的每个分区运行一次,这样可以帮助降低这些操作的代价。
当基于分区操作RDD 时,Spark 会为函数提供该分区中的元素的迭代器。返回值方面,也返回一个迭代器。除mapPartitions() 外,Spark 还有一些别的基于分区的操作符,见下表:

函数名 调用所提供的 返回的 对于RDD[T]的函数签名
mapPartitions() 该分区中元素的迭代器 返回的元素的迭代器 f: (Iterator[T]) → Iterator[U]
mapPartitionsWithIndex() 分区序号,以及每个分区中的元素的迭代器 返回的元素的迭代器 f: (Int, Iterator[T]) → Iterator[U]
foreachPartitions() 元素迭代器 f: (Iterator[T]) → Unit

首先给出上面三个算子的具体代码示例。

1、mapPartitions

与map类似,不同点是map是对RDD的里的每一个元素进行操作,而mapPartitions是对每一个分区的数据(迭代器)进行操作,具体可以看上面的表格。
下面同时用map和mapPartitions实现WordCount,看一下mapPartitions的用法以及与map的区别

package com.dkl.leanring.spark.test

import org.apache.spark.sql.SparkSession

object WordCount {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("WordCount").getOrCreate()
    val sc = spark.sparkContext

    val input = sc.parallelize(Seq("Spark Hive Kafka", "Hadoop Kafka Hive Hbase", "Java Scala Spark"))
    val words = input.flatMap(line => line.split(" "))
    val counts = words.map(word => (word, 1)).reduceByKey { (x, y) => x + y }
    println(counts.collect().mkString(","))
    val counts1 = words.mapPartitions(it => it.map(word => (word, 1))).reduceByKey { (x, y) => x + y }
    println(counts1.collect().mkString(","))

    spark.stop()

  }
}
image

2、mapPartitionsWithIndex

和mapPartitions一样,只是多了一个分区的序号,下面的代码实现了将Rdd的元素数字n变为(分区序号,n*n)

val rdd = sc.parallelize(1 to 10, 5)
val res = rdd.mapPartitionsWithIndex((index, it) => {
  it.map(n => (index, n * n))
})
println(res.collect().mkString(" "))
image

3、foreachPartitions

foreachPartitions和foreach类似,不同点也是foreachPartitions基于分区进行操作的

rdd.foreachPartition(it => it.foreach(println))

4、关于如何避免重复配置

下面以打开数据库连接举例,需求是这样的:
读取mysql表里的数据,做了一系列数据处理得到结果之后,需要修改我们mysql表里的每一条数据的状态,代表程序已经处理过了,下次不需要处理了。

4.1 表

以最简单表结构示例

字段名 注释
ID 主键、唯一标识
ISDEAL 程序是否处理过

建表语句

CREATE TABLE test (
    id INTEGER NOT NULL AUTO_INCREMENT,
    isdeal INTEGER DEFAULT 0 NOT NULL,
    primary key(id) 
)
ENGINE=InnoDB
DEFAULT CHARSET=utf8
COLLATE=utf8_general_ci;
image

4.2 不基于分区操作

一共用两种方法

4.2.1 第一种

package com.dkl.leanring.spark.sql.mysql

import org.apache.spark.sql.SparkSession

object UpdateMysqlDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("UpdateMysqlDemo").master("local").getOrCreate()

    val database_url = "jdbc:mysql://192.168.44.128:3306/test?useUnicode=true&characterEncoding=utf-8&useSSL=false"
    val user = "root"
    val password = "Root-123456"
    val df = spark.read
      .format("jdbc")
      .option("url", database_url)
      .option("dbtable", "(select * from test where isDeal=0 limit 5)a")
      .option("user", user)
      .option("password", password)
      .option("driver", "com.mysql.jdbc.Driver")
      .option("numPartitions", "5")
      .option("partitionColumn", "ID")
      .option("lowerBound", "1")
      .option("upperBound", "10")
      .load()

    import java.sql.{ Connection, DriverManager, ResultSet };
    df.rdd.foreach(row => {
      val conn = DriverManager.getConnection(database_url, user, password)
      try {
        // Configure to be Read Only
        val statement = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
        val prep = conn.prepareStatement(s"update test set isDeal=1 where id=?")

        val id = row.getAs[Int]("id")
        prep.setInt(1, id)
        prep.executeUpdate

      } catch {
        case e: Exception => e.printStackTrace
      } finally {
        conn.close()
      }

    })

    spark.stop()
  }
}


  • 上面的代码,取isDeal=0的前五条,因为造的数据量少,所以只取了前五条,然后指定了五个分区,这里只是一个代码示例,实际工作中应该数据量很大,每个分区肯定不止一条数据

根据上面的代码,看到用这种方式的缺点是每一个元素都要创建一个数据库连接,这样频繁创建连接、关闭连接,在数据量很大的情况下,势必会对性能产生影响,但是优点是不用担心内存不够。

4.2.2 第二种

val conn = DriverManager.getConnection(database_url, user, password)
try {
  val statement = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
  val prep = conn.prepareStatement(s"update test set isDeal=1 where id=?")

  df.select("id").collect().foreach(row => {
    val id = row.getAs[Int]("id")
    prep.setInt(1, id)
    prep.executeUpdate

 })

} catch {
  case e: Exception => e.printStackTrace
}

这种方式的缺点是把要操作的数据全部转成scala数组,仅在Driver端执行,但是如果数据量很大的话,可能因为Driver内存不够大而抛出异常,优点是只建立一次数据库连接,在数据量不是特别大,且确定Driver的内存足够的时候,可以采取这种方式。

4.3 基于分区的方式

df.rdd.foreachPartition(it => {
  val conn = DriverManager.getConnection(database_url, user, password)
  try {
    val statement = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
    val prep = conn.prepareStatement(s"update test set isDeal=1 where id=?")
    it.foreach(row => {
      val id = row.getAs[Int]("id")
      prep.setInt(1, id)
      prep.executeUpdate
    })

  } catch {
    case e: Exception => e.printStackTrace
  } finally {
    conn.close()
  }

})

这种方式就结合了上面两种方式的优点,基于分区的方式使得创建连接的次数不会那么多,然后每个分区的数据也可以平均分到每个节点的executor上,避免了内存不足产生的异常,当然前提是要合理的分配分区数,既不能让分区数太多,也不能让每个分区的数据太多,还有要注意数据倾斜的问题,因为当数据倾斜造成某个分区数据量太大同样造成OOM(内存溢出)。

4.4 其他

上面只是列举了一个例子,且只是在foreach这样的action算子里体现的,当然肯定也有需求需是在transformation里进行如数据库的连接这样的操作,大家可类比的使用mapPartitions即可

5、其他优点(未证实)

网上有很多博客提到mapPartitions还有其他优点,就是mapPartitions比map快,性能高,原因是因为map的function会执行rdd.count次,而mapPartitions的function则执行rdd.numPartitions次。
但我并这么认为,因mapPartitions的function和map的function是不一样的,mapPartitions里的迭代器的每个元素还是都要执行一遍的,实际上也是执行rdd.count次。
下面以其中一篇博客举例(只列出优点,大部分博客上的写的都一样的,应该出自同一篇博客吧~)

image

博客地址:Spark---算子调优之MapPartitions提升Map类操作性能

  • 至于mapPartitions是否真的比map处理速度快,如果我有时间验证得到结果的话,我再更新一下这个地方~

相关阅读

推荐阅读更多精彩内容