Caffe2核心代码解析系列之五:Blob

介绍

Caffe2中Blob的概念应该来自于Caffe。它是有类型的内存抽象,主要包含两个成员,一为指向存储元素的指针,另一则为此元素的类型(TypeMeta)。这么说来它其实与Tensor好像,本质上它有些赘余,更像是来自Caffe的一种包袱。在笔者已知的框架设计里像Tensorflow/Pytorch/Mxnet等无不是只提供Tensor这么一种有类型的内存抽象。不过在Caffe2中,框架设计者可能是不想它太多余,于是将Serialization(从而将weights存成string)的功能给了它。

Anyway,这样我们在使用Caffe Operator时,会以Blob作为输入、输出(与Caffe一样),只是在Operator内部,一般需要使用Blob的data方法得到指向其元素的指针,然后再将它强制类型转换为合适的类型T(一般为Tensor),再使用它进行各种具体运算。

Caffe2中与Blob相关的代码如下。本节当中我们将重点介绍其中所涉及的Blob/BlobSerializaer/BlobStats等类及相关功能函数。

core]$ ls blob
blob_gpu_test.cc           blob_serialization_gpu.cc  blob_stats.cc
blob.h                     blob_serialization.h       blob_stats.h
blob_serialization.cc      blob_serializer_base.h     blob_test.cc

Blob

以下为Blob的基本描述,可见看出它只有两个成员meta_与pointer_,分别表示指向存储对象的指针以及此指针的类型。

/**
 * @brief Blob is a general container that hosts a typed pointer.
 *
 * A Blob hosts a pointer as well as its type, and takes charge of deleting it
 * properly when the blob is deallocated or re-allocated with a new type. A blob
 * could contain anything, although the most common case is to contain a Tensor.
 */
class CAFFE2_API Blob final {
 public:
  using DestroyCall = void(void*);

  /**
   * Initializes an empty Blob.
   */
  Blob() : meta_(), pointer_(nullptr) {}
  ~Blob() { Reset(); }

  Blob(Blob&& other) noexcept
      : meta_(std::move(other.meta_)),
        pointer_(std::move(other.pointer_)),
        destroy_(std::move(other.destroy_)) {
    other.meta_ = {};
    other.pointer_ = nullptr;
    other.destroy_ = nullptr;
  }
  ........
  ........
};

通过下面两个成员函数,我们可以检查Blob所包含的对象的类型及是否是某种Device tensor类型等。

  /**
   * Checks if the content stored in the blob is of type T.
   */
  template <class T>
  bool IsType() const {
    return meta_.Match<T>();
  }

  bool IsTensorType(DeviceType device_type) const {
    bool is_match = meta_.Match<Tensor>();
    auto* tensor = static_cast<Tensor*>(pointer_);
    if (is_match && tensor && tensor->GetDeviceType() == device_type) {
      return true;
    }
    return false;
  }

以下为两种得到有类型指针与裸对象指针的办法。

  /**
   * @brief Gets the const reference of the stored object. The code checks if
   * the stored object is of the desired type.
   */
  // TODO(jerryzh): add a Get(DeviceType) function?
  template <class T>
  const T& Get() const {
    CAFFE_ENFORCE(
        IsType<T>(),
        "wrong type for the Blob instance. Blob contains ",
        meta_.name(),
        " while caller expects ",
        TypeMeta::TypeName<T>());
    // TODO: after we add Get<Tensor>(DeviceType)
    // and changed all the callsites, we can add
    // a static assert here to enforce T != Tensor
    return *static_cast<const T*>(pointer_);
  }

  const void* GetRaw() const {
    return pointer_;
  }

若想要对其存储对象进行写操作,则需要调用mutable_data方法,如下所示。

  /**
   * @brief Gets a mutable pointer to the stored object.
   *
   * If the current object is not of the right type, a new object is created
   * and the old object is freed. Note that type T should have a default
   * constructor. Otherwise, create the object yourself first, and use
   * Reset().
   */
  template <class T>
  T* GetMutable() {
    static_assert(
        std::is_default_constructible<T>::value,
        "GetMutable can't be called with non-default-constructible types. "
        "Try using specialized methods");
    static_assert(
        !std::is_same<T, Tensor>::value,
        "Use GetMutableTensor(DeviceType) instead");
    if (IsType<T>()) {
      return static_cast<T*>(pointer_);
    } else {
      VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<T>();
      return Reset<T>(new T());
    }
  }

  inline Tensor* GetMutableTensor(DeviceType device_type) {
    if (IsTensorType(device_type)) {
      return static_cast<Tensor*>(pointer_);
    } else {
      VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
              << " DeviceType:" << device_type;
      return Reset<Tensor>(new Tensor(device_type));
    }
  }

Reset 成员函数将使Blob得到此传入对象的ownership。在此前总要先释放指之前拥有的对象的ownership。

  /**
   * Sets the underlying object to the allocated one. The Blob then takes over
   * the ownership of the passed in pointer. If there is already an object in
   * the Blob, the old object is freed.
   *
   * This is used when the underlying class T does not have a default ctor, or
   * complex initializations needs to be done outside the blob.
   */
  template <class T>
  T* Reset(T* allocated) {
    if (pointer_ && destroy_) {
      destroy_(pointer_);
    }
    meta_ = TypeMeta::Make<T>();
    pointer_ = static_cast<void*>(allocated);
    destroy_ = &Destroy<T>;
    return allocated;
  }

  inline void*
  Reset(void* allocated, const TypeMeta& meta, DestroyCall* destroy) {
    if (pointer_ && destroy_) {
      destroy_(pointer_);
    }
    meta_ = meta;
    pointer_ = static_cast<void*>(allocated);
    destroy_ = destroy;
    return allocated;
  }

  /**
   * Resets the Blob to an empty one.
   */
  inline void Reset() {
    if (pointer_ && destroy_) {
      destroy_(pointer_);
    }
    pointer_ = nullptr;
    meta_ = TypeMeta();
    destroy_ = nullptr;
  }

ShareExternal与Reset相反,它只享用此传入对象,但并不负责释放它即并不需对它付责任。

  /**
   * Sets the underlying object to the allocated one, but does not take over
   * the ownership of the passed in pointer. If there is already an object in
   * the Blob, the old object is freed.
   *
   * Unlike Reset, this does not take over the ownership of the pointer and the
   * caller is responsible for making sure that the lifetime of the allocated
   * blob outlasts the lifetime of any access to this blob, until another Reset
   * call is made or the blob is destructed.
   */
  template <class T>
  typename std::remove_const<T>::type* ShareExternal(
      typename std::remove_const<T>::type* allocated) {
    return static_cast<T*>(ShareExternal(
        static_cast<void*>(allocated),
        TypeMeta::Make<typename std::remove_const<T>::type>()));
  }

  void* ShareExternal(void* allocated, const TypeMeta& meta) {
    if (pointer_ && destroy_) {
      destroy_(pointer_);
    }
    meta_ = meta;
    pointer_ = static_cast<void*>(allocated);
    destroy_ = nullptr;
    return allocated;
  }

Blob承担了部分Serialize的功能,可见所有的Weights需要放入到Blob里面正是需要仰仗它的这一功能来进行checkpoints存取。

  /**
   * Serializes the current blob, if possible. Note that this serialization uses
   * the registration mechanism and one has to implement specific serialization
   * approaches for specific classes. Acceptor should take care of writing data
   * to the actual storage.
   */
  void Serialize(
      const string& name,
      BlobSerializerBase::SerializationAcceptor acceptor,
      int chunk_size = kDefaultChunkSize) const;

  /**
   * @brief Convenience function to serialize a blob to a string.
   *
   * This is a conveinence function to serialize small Blobs that produce
   * manageable serialized strings. To serialize big blobs such as
   * large sparse tensors, use the fully-functional interface in
   * blob_serializer_base.h.
   *
   * NOTE: this function doesn't do chunking and might break with big tensors.
   */
  string Serialize(const string& name) const;

  /**
   * Deserializes from a string containing either BlobProto or TensorProto. If
   * the deserialization fails, the content in the blob should no longer be
   * trusted.
   */
  void Deserialize(const string& content);
  void Deserialize(const BlobProto& proto);

最后则为Blob私有空间的一些成员与公共函数。Destroy是一个static 模板成员函数,用在这里是再合适不过了。

 private:
  /**
   * @brief A destroy call that is used to properly deconstruct objects.
   */
  template <class T>
  static void Destroy(void* pointer) {
    delete static_cast<T*>(pointer);
  }
  TypeMeta meta_;
  void* pointer_ = nullptr;
  DestroyCall* destroy_ = nullptr;

  AT_DISABLE_COPY_AND_ASSIGN(Blob);
};

BlobSerializerBase和BlobDeserializerBase

下面为BlobSerializerBase的概况,它是一个实现Blob serialization功能的虚基类。不同类型的Blob需要分别继承它来实现自己的Serialization操作。

/**
 * @brief BlobSerializerBase is an abstract class that serializes a blob to a
 * string.
 *
 * This class exists purely for the purpose of registering type-specific
 * serialization code. If you need to serialize a specific type, you should
 * write your own Serializer class, and then register it using
 * REGISTER_BLOB_SERIALIZER. For a detailed example, see TensorSerializer for
 * details.
 */
class BlobSerializerBase {
 public:
  virtual ~BlobSerializerBase() {}
  using SerializationAcceptor =
     std::function<void(const std::string& blobName, const std::string& data)>;

下面为它的两个主要的Serialization功能函数。

   * @brief The virtual function that returns a serialized string for the input
   * blob.
   * @param blob
   *     the input blob to be serialized.
   * @param name
   *     the blob name to be used in the serialization implementation. It is up
   *     to the implementation whether this name field is going to be used or
   *     not.
   * @param acceptor
   *     a lambda which accepts key value pairs to save them to storage.
   *     serailizer can use it to save blob in several chunks
   *     acceptor should be thread-safe
   */
  virtual void Serialize(const Blob& blob, const std::string& name,
                        SerializationAcceptor acceptor) = 0;

  virtual void SerializeWithChunkSize(
      const Blob& blob,
      const std::string& name,
      SerializationAcceptor acceptor,
      int /*chunk_size*/) {
    // Base implementation.
    Serialize(blob, name, acceptor);
  }
};

我们需要对每个类型Blob生成其特定的BlobSerializer子类。

// The Blob serialization registry and serializer creator functions.
CAFFE_DECLARE_TYPED_REGISTRY(
    BlobSerializerRegistry,
    TypeIdentifier,
    BlobSerializerBase,
    std::unique_ptr);
#define REGISTER_BLOB_SERIALIZER(id, ...) \
  CAFFE_REGISTER_TYPED_CLASS(BlobSerializerRegistry, id, __VA_ARGS__)
// Creates an operator with the given operator definition.
inline unique_ptr<BlobSerializerBase> CreateSerializer(TypeIdentifier id) {
  return BlobSerializerRegistry()->Create(id);
}

相对应的有个Deserializer虚基类提供了Deserialization需要的函数接口。

/**
 * @brief BlobDeserializerBase is an abstract class that deserializes a blob
 * from a BlobProto or a TensorProto.
 */
class CAFFE2_API BlobDeserializerBase {
 public:
  virtual ~BlobDeserializerBase() {}

  // Deserializes from a BlobProto object.
  virtual void Deserialize(const BlobProto& proto, Blob* blob) = 0;
};

CAFFE_DECLARE_REGISTRY(BlobDeserializerRegistry, BlobDeserializerBase);
#define REGISTER_BLOB_DESERIALIZER(name, ...) \
  CAFFE_REGISTER_CLASS(BlobDeserializerRegistry, name, __VA_ARGS__)
// Creates an operator with the given operator definition.
inline unique_ptr<BlobDeserializerBase> CreateDeserializer(const string& type) {
  return BlobDeserializerRegistry()->Create(type);
}

TensorSerializer和TensorDeserializer

TensorSerializer为BlobSerializerBase的一个子类,顾名思义,它主要用来实现Tensor类型的Serialization操作。同样还有一个为TensorDeserializer,它是BlobDeserializerBase的子类。

下面为在进行Serialization时的细节实现。可见主要是将需要的类型数据存到Protocol buffer里面,然后再使用它的功能来进行serialization/deserialization。

namespace detail {
template <typename SrcType, typename DstType>
inline void CopyToProtoAsIs(
    const size_t size,
    const SrcType* src,
    google::protobuf::RepeatedField<DstType>* field,
    BaseContext* context) {
  static_assert(
      sizeof(SrcType) == sizeof(DstType),
      "The source type and dest type cannot be copied as-is. Did "
      "you mean CopyToProtoWithCast?");
  field->Reserve(size);
  for (int i = 0; i < size; ++i) {
    field->Add(0);
  }
  context->template CopyToCPU<SrcType>(
      size, src, reinterpret_cast<SrcType*>(field->mutable_data()));
  // Make sure that we finish the copy into the protobuf.
  context->FinishDeviceComputation();
}

template <typename SrcType, typename DstType>
inline void CopyToProtoWithCast(
    const size_t size,
    const SrcType* src,
    google::protobuf::RepeatedField<DstType>* field,
    BaseContext* context) {
  // TODO: we are having one unnecessary copy here if the context is already
  // CPUContext. Remove it if it is performance critical.
  unique_ptr<SrcType[]> buffer(new SrcType[size]);
  context->template CopyToCPU<SrcType>(size, src, buffer.get());
  context->FinishDeviceComputation();
  field->Reserve(size);
  for (int i = 0; i < size; ++i) {
    field->Add(static_cast<DstType>(buffer[i]));
  }
}

以下为Blob里面的两个Serialize函数实现,可以看出它主要是借助不同类型的BlobSerializer来完成此功能。Deserialize函数的实现与此类似,在此不再赘述。

// The blob serialization member function implementation.
void Blob::Serialize(
    const string& name,
    BlobSerializerBase::SerializationAcceptor acceptor,
    int chunk_size) const {
  std::unique_ptr<BlobSerializerBase> serializer(CreateSerializer(meta_.id()));
  CAFFE_ENFORCE(serializer, "No known serializer for ", meta_.name());
  serializer->SerializeWithChunkSize(*this, name, acceptor, chunk_size);
}

// The blob serialization member function implementation.
std::string Blob::Serialize(const string& name) const {
  std::string data;
  BlobSerializerBase::SerializationAcceptor acceptor = [&data](
      const std::string&, const std::string& blob) {
    DCHECK(data.empty()); // should be called once with kNoChunking
    data = blob;
  };
  this->Serialize(name, acceptor, kNoChunking);
  return data;
}

BlobStatGetter

Blob里面提供了些辅助类来提供些统计等功能如BlobStatGetter类。由下面代码可以看出,它亦是通过Type Id来选择使用不同的子类StatGetter的。

struct BlobStatGetter {
  virtual size_t sizeBytes(const Blob& blob) const = 0;
  virtual ~BlobStatGetter() {}
};

struct BlobStatRegistry {
 private:
  std::unordered_map<TypeIdentifier, std::unique_ptr<BlobStatGetter>> map_;
  void doRegister(TypeIdentifier id, std::unique_ptr<BlobStatGetter>&& v);

 public:
  template <typename T, typename Getter>
  struct Registrar {
    Registrar() {
      BlobStatRegistry::instance().doRegister(
          TypeMeta::Id<T>(), std::unique_ptr<Getter>(new Getter));
    }
  };

  const BlobStatGetter* get(TypeIdentifier id);
  static BlobStatRegistry& instance();
};

以下为具体的对其使用。

const BlobStatGetter* BlobStatRegistry::get(TypeIdentifier id) {
  auto it = map_.find(id);
  if (it == map_.end()) {
    return nullptr;
  }
  return it->second.get();
}

BlobStatRegistry& BlobStatRegistry::instance() {
  static BlobStatRegistry registry;
  return registry;
}

void BlobStatRegistry::doRegister(
    TypeIdentifier id,
    std::unique_ptr<BlobStatGetter>&& v) {
  // don't use CAFFE_ENFORCE_EQ to avoid static initialization order fiasco.
  if (map_.count(id) > 0) {
    throw std::runtime_error("BlobStatRegistry: Type already registered.");
  }
  map_[id] = std::move(v);
}

namespace BlobStat {

size_t sizeBytes(const Blob& blob) {
  auto* p = BlobStatRegistry::instance().get(blob.meta().id());
  return p ? p->sizeBytes(blob) : 0;
}

} // namespace BlobStats

参考文献

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 156,757评论 4 359
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 66,478评论 1 289
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 106,540评论 0 237
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,593评论 0 203
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 51,903评论 3 285
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,329评论 1 210
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,659评论 2 309
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,383评论 0 195
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,055评论 1 238
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,337评论 2 241
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 31,864评论 1 256
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,227评论 2 251
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 32,820评论 3 231
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 25,999评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,750评论 0 192
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,365评论 2 269
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,260评论 2 258

推荐阅读更多精彩内容