Custom Accumulator in Spark 2.1

Custom Accumulator in Spark 2.1

Accumulator can sum or count number in spark tasks over all nodes, and then return the final result. For example, LongAccumulator.
But when I want to accumulate value by custom way, it can be implemented by extends AccumulatorV2 in spark 2.1.

Below, a MapLongAccumulator is implemented to count the numbers of several values seperately.

import java.util
import java.util.Collections

import org.apache.spark.util.AccumulatorV2
import scala.collection.JavaConversions._


class MapLongAccumulator[T] extends AccumulatorV2[T, java.util.Map[T, java.lang.Long]] {
    private val _map: java.util.Map[T, java.lang.Long] = Collections.synchronizedMap(new util.HashMap[T, java.lang.Long]())

    override def isZero: Boolean = _map.isEmpty

    override def copyAndReset(): MapLongAccumulator[T] = new MapLongAccumulator

    override def copy(): MapLongAccumulator[T] = {
        val newAcc = new MapLongAccumulator[T]
        _map.synchronized {
            newAcc._map.putAll(_map)
        }
        newAcc
    }

    override def reset(): Unit = _map.clear()

    override def add(v: T): Unit = _map.synchronized {
        val old = _map.put(v, 1l)
        if (old != null) {
            _map.put(v, 1 + old)
        }
    }

    override def merge(other: AccumulatorV2[T, java.util.Map[T, java.lang.Long]]): Unit = other match {
        case o: MapLongAccumulator[T] => {
            for ((k,v) <- o.value) {
                val old = _map.put(k, v)
                if(old != null){
                    _map.put(k, old.longValue() + v)
                }
            }
        }
        case _ => throw new UnsupportedOperationException(
            s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
    }

    override def value: java.util.Map[T, java.lang.Long] = _map.synchronized {
        java.util.Collections.unmodifiableMap(new util.HashMap[T, java.lang.Long](_map))
    }

    def setValue(newValue: java.util.Map[T, java.lang.Long]): Unit = {
        _map.clear()
        _map.putAll(newValue)
    }
}

Use case:

val accumulator = new MapLongAccumulator[String]()
sc.register(accumulator, "CustomAccumulator")
someRdd.map(a => {
            ...
            val convertResult = convert(a)
            convertResultAccumulator.add(convertResult.toString)
            ...
        }).repartition(1).saveAsTextFile(outputPath)
System.out.println(accumulator.value)   // the several convertResults will be counted seperately, and then get the output value.

AccumulatorV2 is new accumulator class since spark 2.0.0.

public abstract class AccumulatorV2<IN,OUT>
extends Object
implements scala.Serializable

The base class for accumulators, that can accumulate inputs of type IN, and produce output of type OUT.
OUT should be a type that can be read atomically (e.g., Int, Long), or thread-safely (e.g., synchronized collections) because it will be read from other threads.

Methods of AccumulatorV2:
merge method doesn't need thread-safe, it is called only in one thread when task completion.(handleTaskCompletion in DAGScheduler)

org.apache.spark.scheduler.DAGScheduler.scala
/**
   * Responds to a task finishing. This is called inside the event loop so it assumes that it can
   * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
   */
  private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
    val task = event.task
    val taskId = event.taskInfo.id
    val stageId = task.stageId
    val taskType = Utils.getFormattedClassName(task)

    outputCommitCoordinator.taskCompleted(
      stageId,
      task.partitionId,
      event.taskInfo.attemptNumber, // this is a task attempt number
      event.reason)

    // Reconstruct task metrics. Note: this may be null if the task has failed.
    val taskMetrics: TaskMetrics =
      if (event.accumUpdates.nonEmpty) {
        try {
          TaskMetrics.fromAccumulators(event.accumUpdates)
        } catch {
          case NonFatal(e) =>
            logError(s"Error when attempting to reconstruct metrics for task $taskId", e)
            null
        }
      } else {
        null
      }

    ......
    val stage = stageIdToStage(task.stageId)
    event.reason match {
      case Success =>
        stage.pendingPartitions -= task.partitionId
        task match {
          case rt: ResultTask[_, _] =>
            // Cast to ResultStage here because it's part of the ResultTask
            // TODO Refactor this out to a function that accepts a ResultStage
            val resultStage = stage.asInstanceOf[ResultStage]
            resultStage.activeJob match {
              case Some(job) =>
                if (!job.finished(rt.outputId)) {
                  updateAccumulators(event)
                  ......
           case smt: ShuffleMapTask =>
            val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
            updateAccumulators(event)
            val status = event.result.asInstanceOf[MapStatus]
            val execId = status.location.executorId
            logDebug("ShuffleMapTask finished on " + execId)
            if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
              logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
            } else {
              shuffleStage.addOutputLoc(smt.partitionId, status)
            }
            ......
    case exceptionFailure: ExceptionFailure =>
        // Tasks failed with exceptions might still have accumulator updates.
        updateAccumulators(event)
    ......
}


/**
   * Merge local values from a task into the corresponding accumulators previously registered
   * here on the driver.
   *
   * Although accumulators themselves are not thread-safe, this method is called only from one
   * thread, the one that runs the scheduling loop. This means we only handle one task
   * completion event at a time so we don't need to worry about locking the accumulators.
   * This still doesn't stop the caller from updating the accumulator outside the scheduler,
   * but that's not our problem since there's nothing we can do about that.
   */
  private def updateAccumulators(event: CompletionEvent): Unit = {
    val task = event.task
    val stage = stageIdToStage(task.stageId)
    try {
      event.accumUpdates.foreach { updates =>
        val id = updates.id
        // Find the corresponding accumulator on the driver and update it
        val acc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id) match {
          case Some(accum) => accum.asInstanceOf[AccumulatorV2[Any, Any]]
          case None =>
            throw new SparkException(s"attempted to access non-existent accumulator $id")
        }
        acc.merge(updates.asInstanceOf[AccumulatorV2[Any, Any]])
        // To avoid UI cruft, ignore cases where value wasn't updated
        if (acc.name.isDefined && !updates.isZero) {
          stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value))
          event.taskInfo.accumulables += acc.toInfo(Some(updates.value), Some(acc.value))
        }
      }
    } catch {
      case NonFatal(e) =>
        logError(s"Failed to update accumulators for task ${task.partitionId}", e)
    }
  }

org.apache.spark.executor.TaskMetrics.scala
  /**
   * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only.
   */
  def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
    val tm = new TaskMetrics
    val (internalAccums, externalAccums) =
      accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get))

    internalAccums.foreach { acc =>
      val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]]
      tmAcc.metadata = acc.metadata
      tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
    }

    tm.externalAccums ++= externalAccums
    tm
  }

推荐阅读更多精彩内容