golang 创建DB连接池

字数 21阅读 1821

方式一：使用通道（channel）实现连接池

import(
    "database/sql"
    _"github.com/go-sql-driver/mysql"
    "log"
    "time"
    "math/rand"
)
// MAX_POOL_SIZE is the capacity of dbPoll (maximum number of idle handles).
// dbPoll is the pool itself: a buffered channel of idle *sql.DB handles; it
// stays nil until initDB or putDB creates it.
var (
	MAX_POOL_SIZE = 20
	dbPoll        chan *sql.DB
)

// MySQL account and schema used by initDB to build the DSN.
// NOTE(review): credentials are hard-coded; consider loading them from config.
const (
	user, pass = "root", "root" // MySQL account
	db         = "school"       // schema name
)
// putDB hands an idle *sql.DB back to the pool. If the pool is already full
// the handle is closed instead of blocking the caller; nil handles are ignored.
func putDB(db *sql.DB) {
	if db == nil {
		// Nothing to pool, and calling Close on a nil *sql.DB would panic.
		return
	}
	// Defensive: create the pool lazily in case putDB runs before initDB.
	if dbPoll == nil {
		dbPoll = make(chan *sql.DB, MAX_POOL_SIZE)
	}
	// "check len, then send" is not atomic: another goroutine can fill the
	// last slot in between, and the send would then block forever. A select
	// with a default branch makes "enqueue or discard" a single step.
	select {
	case dbPoll <- db:
	default:
		if err := db.Close(); err != nil {
			log.Println(err)
		}
	}
}
// initDB makes sure the pool channel exists. If the pool currently holds no
// idle handles, it starts a goroutine that opens MAX_POOL_SIZE/2 new *sql.DB
// handles and adds them through putDB.
//
// Notes:
//   - Per the database/sql docs, sql.Open may only validate its arguments
//     without connecting, so handles are not verified here. If sql.Open
//     fails on every attempt, nothing is added and a GetDB caller keeps
//     waiting.
//   - A *sql.DB is itself a pool of connections and is safe for concurrent
//     use, so a single shared handle is usually enough.
//   - NOTE(review): dbPoll itself is read and written without synchronization.
func initDB() {
	// len of a nil channel is 0, so a nil pool also passes this check.
	if len(dbPoll) != 0 {
		return
	}
	// Create the channel only once. Replacing an existing (empty) channel
	// would strand callers already blocked receiving on it, while the filler
	// goroutine writes into the new one.
	if dbPoll == nil {
		dbPoll = make(chan *sql.DB, MAX_POOL_SIZE)
	}
	dsn := user + ":" + pass + "@tcp(localhost:3306)/" + db + "?charset=utf8"
	go func() {
		for i := 0; i < MAX_POOL_SIZE/2; i++ {
			conn, err := sql.Open("mysql", dsn)
			if err != nil {
				// Never pool a nil handle: GetDB would return it to a caller.
				log.Println(err)
				continue
			}
			putDB(conn)
		}
	}()
}
// GetDB takes an idle *sql.DB from the pool, blocking until one is available.
// An empty (or not yet created) pool triggers initDB first.
func GetDB() *sql.DB {
	// len of a nil channel is 0, so this single check also covers dbPoll == nil.
	if len(dbPoll) == 0 {
		initDB()
	}
	conn := <-dbPoll
	return conn
}

方式二：自己实现的连接池（基于 gorm）

db.go

package db

import (
    "time"
    "github.com/Sirupsen/logrus"
    "github.com/jinzhu/gorm"
    _"github.com/go-sql-driver/mysql"
    "reflect"
    "fmt"
    "errors"
    "strconv"
    "youguoju/conf/config"
)

// dB is a pooled database handle: a gorm connection plus a pool-unique ID.
type dB interface {
	// Id returns the identifier the pool uses to track this handle.
	Id() uint32
	// Gorm returns the underlying *gorm.DB.
	Gorm() *gorm.DB
}

// myDB is the concrete dB implementation produced by initDb.
type myDB struct {
	id uint32   // pool-unique identifier
	db *gorm.DB // underlying gorm connection
}

// DBPoll is a fixed-capacity pool of dB handles.
type DBPoll interface {
	// Take removes a handle from the pool.
	Take() (dB, error)
	// Return puts a previously taken handle back into the pool.
	Return(entity dB) error
	// Total reports the pool capacity.
	Total() uint32
	// Used reports how many handles are currently taken.
	Used() uint32
}
// myDBPoll adapts the generic entity pool to the DBPoll interface and
// remembers the concrete dB type it was built with.
type myDBPoll struct {
	// pool is the underlying entity pool (built by NewPool).
	pool DBPoll
	// etype is the dynamic type of the entities stored in pool.
	etype reflect.Type
}
type (
	// genDB is a factory that produces one new dB handle; newDBPoll uses it
	// to learn the entity type and to fill the pool.
	genDB func() dB
)
// newDBPoll builds a DBPoll holding total handles produced by gen.
//
// The entity type is learned by calling gen once. That first handle is then
// used as the first pooled entity instead of being discarded, so no
// connection is opened (and leaked) just to inspect its type.
func newDBPoll(total uint32, gen genDB) (DBPoll, error) {
	probe := gen()
	etype := reflect.TypeOf(probe)
	// Hand out the probe first, then defer to gen. NewPool calls this
	// sequentially in its fill loop, so no locking is needed around probe.
	genEntity := func() dB {
		if probe != nil {
			e := probe
			probe = nil
			return e
		}
		return gen()
	}
	pool, err := NewPool(total, etype, genEntity)
	if err != nil {
		return nil, err
	}
	return &myDBPoll{pool, etype}, nil
}

// Id returns the handle's pool-unique identifier.
func (m *myDB) Id() uint32 {
	id := m.id
	return id
}

// Gorm exposes the wrapped *gorm.DB.
func (m *myDB) Gorm() *gorm.DB {
	conn := m.db
	return conn
}
// Take removes a handle from the underlying pool. If the pool hands back an
// entity whose dynamic type is not the one the pool was created with (or a
// nil entity), it returns an error rather than panicking.
func (pool *myDBPoll) Take() (dB, error) {
	entity, err := pool.pool.Take()
	if err != nil {
		return nil, err
	}
	// entity is already statically a dB, so asserting entity.(dB) could only
	// ever fail for nil. Compare the dynamic type against etype instead;
	// reflect.TypeOf(nil) is nil, so a nil entity is rejected as well.
	if reflect.TypeOf(entity) != pool.etype {
		errMsg := fmt.Sprintf("The type of entity is NOT %s", pool.etype)
		return nil, errors.New(errMsg)
	}
	return entity, nil
}
// Return gives a previously taken handle back to the underlying pool.
func (pool *myDBPoll) Return(entity dB) error {
	inner := pool.pool
	return inner.Return(entity)
}
// Total reports the capacity of the underlying pool.
func (pool *myDBPoll) Total() uint32 {
	capacity := pool.pool.Total()
	return capacity
}
// Used reports how many handles are currently taken from the underlying pool.
func (pool *myDBPoll) Used() uint32 {
	inUse := pool.pool.Used()
	return inUse
}

// dbPoll is the package-wide DB pool. It is nil until InitDB assigns it.
var (
	dbPoll DBPoll
)

// InitDB builds the package-wide pool, sized by config.Conf.Total.
// Invalid sizes and pool-construction failures are logged; in that case
// dbPoll is left unchanged (nil if InitDB has not succeeded before).
func InitDB() {
	total := config.Conf.Total
	size, err := strconv.Atoi(total)
	if err != nil {
		logrus.Error("invalid pool size in config: ", err)
		return
	}
	// A non-positive size would either be rejected by NewPool (0) or wrap
	// around to a huge uint32 (negative values).
	if size <= 0 {
		logrus.Errorf("invalid pool size in config: %d (must be > 0)", size)
		return
	}
	pool, err := newDBPoll(uint32(size), initDb)
	if err != nil {
		logrus.Error("creating DB pool: ", err)
		return
	}
	dbPoll = pool
}
//func GetDBPollInstance() DBPoll {
//  return dbPoll
//}
// GetDBInstance takes a handle from the package-wide pool. Every successful
// call should be paired with ReturnDB. It returns an error instead of
// panicking when the pool has not been initialized.
func GetDBInstance() (dB, error) {
	if dbPoll == nil {
		return nil, errors.New("db pool is not initialized (call InitDB first)")
	}
	db, err := dbPoll.Take()
	if err != nil {
		return nil, err
	}
	return db, nil
}
// ReturnDB gives a handle obtained from GetDBInstance back to the pool.
// It returns an error instead of panicking when the pool has not been
// initialized.
func ReturnDB(db dB) error {
	if dbPoll == nil {
		return errors.New("db pool is not initialized (call InitDB first)")
	}
	return dbPoll.Return(db)
}
// initDb opens a gorm "mysql" connection using config.Conf.DBURL. It retries
// every two seconds until the open succeeds, then wraps the connection in a
// myDB with a fresh ID.
// Note: it never gives up, so it blocks forever if the open keeps failing.
func initDb() dB {
	path := config.Conf.DBURL // DSN taken from the configuration
	// Do not log the DSN itself: a MySQL DSN has the form
	// user:password@tcp(host:port)/db, so printing it leaks credentials.
	logrus.Info("connecting to database")
	var conn *gorm.DB
	for {
		var err error
		conn, err = gorm.Open("mysql", string(path))
		if err == nil {
			break
		}
		logrus.Error(err, "Retry in 2 seconds!")
		time.Sleep(2 * time.Second)
	}
	logrus.Info("DB connect successful!")
	return &myDB{db: conn, id: idGenertor.GetUint32()}
}
var idGenertor IdGenertor = NewIdGenertor()

pool.go

package db


import (
"reflect"
"fmt"
"errors"
"sync"
)

// Pool is a fixed-capacity pool of dB entities.
type Pool interface {
	// Take removes an entity from the pool.
	Take() (dB, error)
	// Return puts a previously taken entity back.
	Return(entity dB) error
	// Total reports the pool capacity.
	Total() uint32
	// Used reports how many entities are currently taken.
	Used() uint32
}
////实体的接口类型
//type Entity1 interface {
//  Id()uint32//ID的获取方法
//}

// myPool is the channel-backed implementation of Pool.
type myPool struct {
	// total is the configured capacity.
	total uint32
	// etype is the dynamic type every entity must have.
	etype reflect.Type
	// genEntity creates a new entity.
	genEntity func() dB
	// container buffers the idle entities.
	container chan dB
	// idContainer maps entity ID -> true while the entity is idle in
	// container, false while it is taken.
	idContainer map[uint32]bool
	// mutex guards idContainer.
	mutex sync.Mutex
}

// NewPool creates a Pool pre-filled with total entities from genEntity.
// Every generated entity must be non-nil, of type entityType, and carry a
// distinct Id(); otherwise an error is returned.
func NewPool(total uint32, entityType reflect.Type, genEntity func() dB) (Pool, error) {
	if total == 0 {
		errMsg := fmt.Sprintf("The pool can not be initialized! (total=%d)\n", total)
		return nil, errors.New(errMsg)
	}
	size := int(total)
	container := make(chan dB, size)
	idContainer := make(map[uint32]bool, size)
	for i := 0; i < size; i++ {
		newEntity := genEntity()
		// Check nil explicitly so Id() below is never called on a nil
		// interface (reflect.TypeOf(nil) would match a nil entityType).
		if newEntity == nil || entityType != reflect.TypeOf(newEntity) {
			errMsg := fmt.Sprintf("The type of result of function gen Entity()is Not %s\n", entityType)
			return nil, errors.New(errMsg)
		}
		id := newEntity.Id()
		// Duplicate IDs would collapse into one idContainer entry and make
		// later Take/Return bookkeeping inconsistent.
		if _, dup := idContainer[id]; dup {
			return nil, fmt.Errorf("The entity(id=%d) is duplicated!", id)
		}
		container <- newEntity
		idContainer[id] = true
	}
	// Keyed literal: leaves mutex at its zero value instead of copying a
	// sync.Mutex value (which go vet's copylocks check reports).
	return &myPool{
		total:       total,
		etype:       entityType,
		genEntity:   genEntity,
		container:   container,
		idContainer: idContainer,
	}, nil
}
// Take removes an idle entity from the pool, blocking until one is available,
// and marks its ID as taken. It fails if the container channel is closed.
func (pool *myPool) Take() (dB, error) {
	entity, open := <-pool.container
	if !open {
		return nil, errors.New("The innercontainer is invalid")
	}
	id := entity.Id()
	pool.mutex.Lock()
	defer pool.mutex.Unlock()
	pool.idContainer[id] = false
	return entity, nil
}
// Return puts a previously taken entity back into the pool. It fails if the
// entity is nil, has the wrong type, was never part of this pool, or is
// already idle in the pool.
func (pool *myPool) Return(entity dB) error {
	if entity == nil {
		return errors.New("The returning entity is invalid")
	}
	if pool.etype != reflect.TypeOf(entity) {
		errMsg := fmt.Sprintf("The type of result of function gen Entity()is Not %s\n", pool.etype)
		return errors.New(errMsg)
	}
	entityId := entity.Id()
	switch pool.compareAndSetForIdContainer(entityId, false, true) {
	case -1:
		return fmt.Errorf("The entity(id=%d) is illegal!\n", entityId)
	case 0:
		return fmt.Errorf("The entity(id=%d) is already in the pool!\n", entityId)
	}
	// The compare-and-set above already flipped the ID to true while holding
	// the mutex. Writing idContainer again here without the lock would race
	// with Take's locked write (a concurrent map write).
	pool.container <- entity
	return nil
}
// compareAndSetForIdContainer, while holding pool.mutex, sets the idContainer
// value for entityId to newValue if it currently equals oldValue.
//
// Result:
//
//	 1 - the value was swapped
//	 0 - the ID exists but its value differs from oldValue
//	-1 - the ID is not present in idContainer
func (pool *myPool) compareAndSetForIdContainer(entityId uint32, oldValue, newValue bool) int8 {
	pool.mutex.Lock()
	defer pool.mutex.Unlock()
	current, known := pool.idContainer[entityId]
	switch {
	case !known:
		return -1
	case current != oldValue:
		return 0
	default:
		pool.idContainer[entityId] = newValue
		return 1
	}
}
// Total reports the pool capacity, i.e. the buffer size of container.
func (pool *myPool) Total() uint32 {
	size := cap(pool.container)
	return uint32(size)
}
// Used reports how many entities are currently taken: the container's
// capacity minus the number of idle entities still buffered in it.
func (pool *myPool) Used() uint32 {
	capacity, idle := cap(pool.container), len(pool.container)
	return uint32(capacity - idle)
}

ID 生成器

package db

import (
    "sync"
    "math"
)

// IdGenertor produces identifiers. (The misspelled name is kept because it is
// referenced elsewhere in the package.)
type IdGenertor interface {
	GetUint32() uint32 // returns the next uint32 ID
}

// cyclicIdGenertor hands out 0, 1, ..., math.MaxUint32 and then starts again
// from 0. GetUint32 is guarded by mutex.
type cyclicIdGenertor struct {
	id    uint32 // next ID to hand out
	mutex sync.Mutex
}

// NewIdGenertor returns a cyclic generator whose first ID is 0.
func NewIdGenertor() IdGenertor {
	return &cyclicIdGenertor{}
}

// GetUint32 returns the next ID in the cycle 0..math.MaxUint32.
// The previous version wrapped at math.MaxInt32, wasting half of the uint32
// range and contradicting its own "max value of the type" comment.
func (gen *cyclicIdGenertor) GetUint32() uint32 {
	gen.mutex.Lock()
	defer gen.mutex.Unlock()
	id := gen.id
	// Unsigned arithmetic wraps in Go, so after math.MaxUint32 the counter
	// goes back to 0 by itself; no separate "ended" flag is needed.
	gen.id++
	return id
}

// cyclicIdGenertor2 extends cyclicIdGenertor to uint64 IDs by counting how
// many full uint32 cycles the base generator has completed.
type cyclicIdGenertor2 struct {
	base       cyclicIdGenertor // underlying uint32 generator
	cycleCount uint64           // completed uint32 cycles of base
	mutex      sync.Mutex       // guards cycleCount
}

// GetUint64 returns cycleCount<<32 | id32, so IDs keep increasing across
// uint32 cycles and stay unique for the first 2^32 cycles. The previous
// version only used the parity of cycleCount and an offset of MaxUint32, and
// it updated cycleCount without synchronization.
func (gen *cyclicIdGenertor2) GetUint64() uint64 {
	gen.mutex.Lock()
	defer gen.mutex.Unlock()
	id32 := gen.base.GetUint32()
	id64 := gen.cycleCount<<32 | uint64(id32)
	if id32 == math.MaxUint32 {
		// The next base ID starts a new uint32 cycle.
		gen.cycleCount++
	}
	return id64
}

使用示例

// GetUserByTokenFromDB loads the user whose Token column equals token into
// user. It reports whether a pooled handle was obtained and the query
// completed without error. NOTE(review): with jinzhu/gorm, Find into a
// single struct is expected to report an error when no row matches — confirm
// for the gorm version in use.
func (user *User) GetUserByTokenFromDB(token string) bool {
	instance, err := db.GetDBInstance()
	if err != nil {
		return false
	}
	// Defer the return only once a handle is actually held; deferring before
	// the error check would hand a nil entity back to the pool.
	defer db.ReturnDB(instance)
	// user is already a pointer; passing &user would hand gorm a **User.
	if err := instance.Gorm().Where("Token = ?", token).Find(user).Error; err != nil {
		return false
	}
	return true
}

推荐阅读更多精彩内容