为Spark Deep Learning 集成TFoS

字数 517阅读 907

前言

昨晚睡了12小时,早上起来神清气爽,索性把之前提的一个Issue:Is there any plan to port TensorframeOnSpark(From yahoo) 给尝试着集成进来。 前两天已经添加了一个 TFTextEstimator:为Spark Deep Learning 添加NLP处理实现,不过只能做hyper parameter tuning,做不了真正的分布式训练,所以正好把这个特性加到了这个Estimator里。

使用方法

建议看这篇文章之前,先看为Spark Deep Learning 添加NLP处理实现。 我给TFTextFileEstimator 添加了一个新的参数叫做 runningMode。目前只有两个值: Normal 和 TFoS。

# create a estimator to training where map_fun contains tensorflow's code
estimator = TFTextFileEstimator(inputCol="sentence_matrix", outputCol="sentence_matrix", labelCol="preds",
                                fitParam=[{"epochs": 1, "cluster_size": 2, "batch_size": 1, "model": "/tmp/model"}],
                                runningMode="TFoS",
                                mapFnParam=map_fun)

如果使用TFoS model参数是必须的。并且 map_fun方法也需要做些改造。主要是tensorflow 分布式training 和 单机多device 还是有区别的。

原理

在TFTextEstimator里,通过参数runningMode控制:

        if self.getRunningMode() == "TFoS":
            return self._fitInCluster(dataset, paramMaps)
        else:
            return self._fitInParallel(dataset, paramMaps)

如果是,则走集群模式,否则走并行训练。 我们来看看_fitInCluster:

def _fitInCluster(self, dataset, paramMaps):
        sc = JVMAPI._curr_sc()

        temp_item = dataset.take(1)[0]
        vocab_s = temp_item["vocab_size"]
        embedding_size = temp_item["embedding_size"]

        baseParamMap = self.extractParamMap()
        baseParamDict = dict([(param.name, val) for param, val in baseParamMap.items()])

        args = self._clusterModelDefaultValue(sc, paramMaps[0])
        args["feature"] = self.getInputCol()
        args["label"] = self.getLabelCol()
        args["vacab_size"] = vocab_s
        args["embedding_size"] = embedding_size
        args["params"] = baseParamDict

        cluster = TFCluster.run(sc, self.getMapFnParam(), args, args['cluster_size'], args['num_ps'],
                                args['tensorboard'],
                                TFCluster.InputMode.SPARK)
        cluster.train(dataset.rdd, args["epochs"])
        cluster.shutdown()

很简单,创建 TFCluster对象,并且调用其train方法。 最核心的还是 map_fun函数,这里实现了所有的tf逻辑(除了数据以外)。我后面会单独一个篇幅来讲。在做实现的过程,发现两个问题:

  1. TFoS 最好一个批次的数据会丢失 ,我对应提了一个IssueWhen training, the data of last batch will not be trained
  2. TFoS 没有办法跑在Local模式,所以调试麻烦些,需要跑在spark standalone模式下。

可运行的实例代码在: TFoSTest.py

mapfun函数解析

TFoSTest.py 里的代码兼容单机和集群模式运行。

def map_fun(args={}, ctx=None, _read_data=None):

如果ctx为None,则是单机模式,否则为集群模式。如果是集群模式则直接使用

TFNode.DataFeed(ctx.mgr, True)

获取数据,否则使用 _read_data 获取数据。具体详细参看示例中代码。

结束语

这个只是Demo性质的,单机和集群模式的融合度还不够好,map_fun编写难度还有些。

推荐阅读更多精彩内容