Golang 学习笔记(12)—— ORM实现

本文为转载,原文:Golang 学习笔记(12)—— ORM实现

Golang

介绍

本文将利用之前所学习到的内容实现一个简单的orm,实现比较简单,没有考虑过多的设计原则,以及性能安全之类的,只是单纯的以学习为导向,做的一个练手的小工具。其中有不合理的地方还请看到的同学见谅,并指出,本人也好加以改正。

实现

首先看下完整的代码目录吧:


代码目录

数据库

数据库,这里我选择mysql数据库,可以用以下sql语句创建一个测试的表:

CREATE TABLE `userinfo` (
    `uid` INT(10) NOT NULL AUTO_INCREMENT,
    `username` VARCHAR(64) NULL DEFAULT NULL,
    `departname` VARCHAR(64) NULL DEFAULT NULL,
    `created` DATE NULL DEFAULT NULL,
    PRIMARY KEY (`uid`)
)

go

实体

建完表之后,我们在go中需要穿件一个与之对应的struct

type UserInfo struct{
    TableName orm.TableName "userinfo"
    UserName string `name:"username"`
    Uid int `name:"uid"PK:"true"auto:"true"`
    DepartName string `name:"departname"`
    Created string `name:"created"`
}

从struct中可以看到,我们的字段分为两部分,第一个就是TableName,该字段没有具体内容,只是使用后面的tag标记其对应的数据库表名,剩余的则是与数据库表字段一一对应的字段了。
每个字段的tag则说明了其在数据库的属性。name表示在数据库中的字段名,不写则与字段名一致,PK表示是否为主键,若为主键则标记为"ture",默认为false,auto表示是否为自增长,若为自增长则标记为"true",默认为false,由于我的设计比较简单,所以,先就标记这几种属性。

反射对象

既然有了实体对象,自然我们要通过反射去解析该对象。该代码写在orm/orm.go代码文件中。

package orm

import (
    "fmt"
    "database/sql"
    "errors"
    "strings"
    "reflect"
)

/*
  表信息
*/
type TableInfo struct{
    Name string  //表名
    Fields []FieldInfo //表字段信息
    TMMap map[string]string //表字段与实体字段名映射关系,key为表字段名,val为实体字段名
}

/*
  表字段详细信息
*/
type FieldInfo struct{
    Name string
    IsPrimaryKey bool
    IsAutoGenerate bool
    Valve reflect.Value
}

/*
  实体对象信息
*/
type ModelInfo struct{
    TableInfo  // 实体对应的表信息
    TbName string // 表名称
    Model interface{} //实体实例
}

//表名
type TableName string
//表名类型
var typeTableName TableName
var tableNameType reflect.Type = reflect.TypeOf(typeTableName)

//实体映射,key为表名,val为实体信息
var ModelMapping map[string]ModelInfo

/*
  注册实体,每当有一个实体时,需要调用该方法注册。
  注册到 ModelMapping
*/
func Register(model interface{}){
    if ModelMapping == nil{
        ModelMapping = make(map[string]ModelInfo)
    }
    tbInfo, _ := getTableInfo(model)
    ModelMapping[tbInfo.Name] = ModelInfo{TbName:tbInfo.Name, Model:model}
}

/*
  根据实体通过反射获取表信息
  返回表信息
*/
func getTableInfo(model interface{})(tabInfo *TableInfo, err error){
    defer func(){
        if e := recover(); err != nil{
            tabInfo = nil
            err = e.(error)
        }
    }()

    err = nil
    tabInfo = &TableInfo{}
    tabInfo.TMMap = make(map[string]string)
    rt := reflect.TypeOf(model)
    rv := reflect.ValueOf(model)

    tabInfo.Name = rt.Name()
    if rt.Kind() == reflect.Ptr{
        rt = rt.Elem()
        rv = rv.Elem()
    }
    //字段解析
    for i, j := 0, rt.NumField(); i < j; i++{
        rtf := rt.Field(i)
        rvf := rv.Field(i)
        if rtf.Type == tableNameType{
            tabInfo.Name = string(rtf.Tag)
            continue
        }
        if rtf.Tag == "-"{
            continue
        }
        //解析字段的tag
        var f FieldInfo
        //没有tag,表字段名与实体字段ing一致
        if rtf.Tag == ""{
            f = FieldInfo{Name:rtf.Name, IsAutoGenerate:false, IsPrimaryKey:false, Valve:rvf}
            tabInfo.TMMap[rtf.Name] = rtf.Name
        }else{
            strTag := string(rtf.Tag)
            if strings.Index(strTag, ":") == -1{
                //tag中没有":"时,表字段名与实体字段ing一致
                f = FieldInfo{Name:rtf.Name, IsAutoGenerate:false, IsPrimaryKey:false, Valve:rvf}
                tabInfo.TMMap[rtf.Name] = rtf.Name
            }else{
                //解析tag中的name值为表字段名
                strName := rtf.Tag.Get("name")
                if strName == ""{
                    strName = rtf.Name
                }
                //解析tag中的PK
                isPk := false
                strIspk := rtf.Tag.Get("PK")
                if strIspk == "true"{
                    isPk = true
                }
                //解析tag中的auto
                isAuto := false
                strIsauto := rtf.Tag.Get("auto")
                if strIsauto == "true"{
                    isAuto = true
                }
                f = FieldInfo{Name:strName, IsPrimaryKey:isPk, IsAutoGenerate:isAuto, Valve:rvf}
                tabInfo.TMMap[strName] = rtf.Name
            }
        }
        tabInfo.Fields = append(tabInfo.Fields, f)
    }
    return
}

/*
  根据实体生成插入语句
*/
func generateInsertSql(model interface{})(string, []interface{}, *TableInfo, error){
    //获取表信息
    tbInfo, err := getTableInfo(model)
    if err != nil{
        return "", nil, nil, err
    }
    if len(tbInfo.Fields) == 0 {
        return "", nil, nil, errors.New(tbInfo.Name + "结构体中没有字段")
    }

    //根据字段信息拼Sql语句,以及参数值
    strSql := "insert into " + tbInfo.Name
    strFileds := ""
    strValues := ""
    var params []interface{}
    for _, v := range tbInfo.Fields{
        if v.IsAutoGenerate {
            continue
        }
        strFileds += v.Name + ","
        strValues += "?,"
        params = append(params, v.Valve.Interface())
    }
    if strFileds == ""{
        return "", nil, nil, errors.New(tbInfo.Name + "结构体中没有字段,或只有自增字段")
    }
    strFileds = strings.TrimRight(strFileds, ",")
    strValues = strings.TrimRight(strValues, ",")
    strSql += " (" + strFileds + ") values(" + strValues + ")"
    fmt.Println("sql: ",strSql)
    fmt.Println("params: ",params)
    return strSql, params, tbInfo, nil
}

/*
  根据实体生成修改的sql语句
*/
func generateUpdateSql(model interface{})(string, []interface{}, error){
    //获取表信息
    tbInfo, err := getTableInfo(model)
    if err != nil{
        return "", nil, err
    }
    if len(tbInfo.Fields) == 0 {
        return "", nil, errors.New(tbInfo.Name + "结构体中没有字段")
    }
    //根据字段信息拼Sql语句,以及参数值
    strSql := "update " + tbInfo.Name + " set "
    strFileds := ""
    strWhere := ""
    var p interface{}
    var params []interface{}
    for _, v := range tbInfo.Fields{
        if v.IsAutoGenerate && !v.IsPrimaryKey{
            continue
        }
        if v.IsPrimaryKey{
            strWhere += v.Name + "=?"
            p = v.Valve.Interface()
            continue
        }
        strFileds += v.Name + "=?,"
        params = append(params, v.Valve.Interface())
    }
    params = append(params, p)
    strFileds = strings.TrimRight(strFileds, ",")
    strSql += strFileds + " where " + strWhere
    fmt.Println("update sql: ", strSql)
    fmt.Println("update params: ", params)
    return strSql, params, nil
}

/*
  自动生成删除的sql语句,以主键为删除条件
*/
func generateDeleteSql(model interface{})(string, []interface{}, error){
    //获取表信息
    tbInfo, err := getTableInfo(model)
    if err != nil{
        return "", nil, err
    }
    //根据字段信息拼Sql语句,以及参数值
    strSql := "delete from " + tbInfo.Name + " where "
    var idVal interface{}
    for _, v := range tbInfo.Fields{
        if v.IsPrimaryKey{
            strSql += v.Name + "=?"
            idVal = v.Valve.Interface()
        }
    }
    params := []interface{}{idVal}
    fmt.Println("update sql: ", strSql)
    fmt.Println("update params: ", params)
    return strSql, params, nil
}

/*
  设置自增长字段的值
*/
func setAuto(result sql.Result, tbInfo *TableInfo)(err error){
    defer func(){
        if e := recover(); e != nil{
            err = e.(error)
        }
    }()
    id, err := result.LastInsertId()
    if id == 0{
        return
    }
    if err != nil{
        return
    }
    for _, v := range tbInfo.Fields{
        if v.IsAutoGenerate && v.Valve.CanSet(){
            v.Valve.SetInt(id)
            break
        }
    }
    return
}

这里面,我们实现了通过model实体,来生成新增,修改,删除的sql语句,以及参数。但是查询的怎么办呢?

MyRows

在orm/MyRows.go代码文件中,实现了一个自己的Rows,来处理查询:

package orm

import (
    "strconv"
    "reflect"
    "database/sql"
)

type MyRows struct{
    * sql.Rows
    Values map[string]interface{} //表字段和值的映射
    ColumnNames []string //表字段名集合
}

/*
  获取数据
*/
func (this *MyRows)Next()bool{
    bResult := this.Rows.Next()
    if bResult{
        //获取表字段名称集合
        if this.ColumnNames == nil || len(this.ColumnNames) == 0{
            this.ColumnNames, _ = this.Rows.Columns()
        }
        //初始化表字段和值的映射
        if this.Values == nil{
            this.Values = make(map[string]interface{})
        }
        //调用scan函数的参数
        scanArgs := make([]interface{}, len(this.ColumnNames))
        //scan函数的值
        values := make([][]byte, len(this.ColumnNames))
        for i := range values{
            scanArgs[i] = &values[i]
        }
        this.Rows.Scan(scanArgs...)
        //将结果存放到Values中
        for i := 0; i < len(this.ColumnNames); i++{
            this.Values[this.ColumnNames[i]] = values[i]
        }
    }
    return bResult
}

/*
  将数据映射到实体切片
  tbname:U对应的数据表名
*/
func (this *MyRows)To(tbname string) ([]interface{},error){
    mi := ModelMapping[tbname]
    ti, _ := getTableInfo(mi.Model)
    var models []interface{}
    for this.Next(){
            v := reflect.New(reflect.TypeOf(mi.Model).Elem()).Elem()
            for k, val := range this.Values{
                f := v.FieldByName(ti.TMMap[k])
                var strVal string
                if bt, ok := val.([]byte); ok{
                    strVal = string(bt)
                    switch f.Type().Name(){
                    case "int":
                        i, _ := strconv.ParseInt(strVal, 10, 64)
                        f.SetInt(i)
                        break
                    case "string":
                        f.SetString(strVal)
                        break
                    }
                }
            }
            models = append(models, v.Interface())
    }
    return models, nil
}

该代码实现了查询结果到实体的映射

MysqlDB

orm/MysqlDB.go是实现数据库方法的代码文件。

package orm

import (
    "strconv"
    "reflect"
    "database/sql"
)

type MyRows struct{
    * sql.Rows
    Values map[string]interface{} //表字段和值的映射
    ColumnNames []string //表字段名集合
}

/*
  获取数据
*/
func (this *MyRows)Next()bool{
    bResult := this.Rows.Next()
    if bResult{
        //获取表字段名称集合
        if this.ColumnNames == nil || len(this.ColumnNames) == 0{
            this.ColumnNames, _ = this.Rows.Columns()
        }
        //初始化表字段和值的映射
        if this.Values == nil{
            this.Values = make(map[string]interface{})
        }
        //调用scan函数的参数
        scanArgs := make([]interface{}, len(this.ColumnNames))
        //scan函数的值
        values := make([][]byte, len(this.ColumnNames))
        for i := range values{
            scanArgs[i] = &values[i]
        }
        this.Rows.Scan(scanArgs...)
        //将结果存放到Values中
        for i := 0; i < len(this.ColumnNames); i++{
            this.Values[this.ColumnNames[i]] = values[i]
        }
    }
    return bResult
}

/*
  将数据映射到实体切片
  tbname:U对应的数据表名
*/
func (this *MyRows)To(tbname string) ([]interface{},error){
    mi := ModelMapping[tbname]
    ti, _ := getTableInfo(mi.Model)
    var models []interface{}
    for this.Next(){
            v := reflect.New(reflect.TypeOf(mi.Model).Elem()).Elem()
            for k, val := range this.Values{
                f := v.FieldByName(ti.TMMap[k])
                var strVal string
                if bt, ok := val.([]byte); ok{
                    strVal = string(bt)
                    switch f.Type().Name(){
                    case "int":
                        i, _ := strconv.ParseInt(strVal, 10, 64)
                        f.SetInt(i)
                        break
                    case "string":
                        f.SetString(strVal)
                        break
                    }
                }
            }
            models = append(models, v.Interface())
    }
    return models, nil
}

详情请见注释

调用

package main

import (
    "time"
    "fmt"
    "stu_demo/orm"
    _ "github.com/go-sql-driver/mysql"
)

type UserInfo struct{
    TableName orm.TableName "userinfo"
    UserName string `name:"username"`
    Uid int `name:"uid"PK:"true"auto:"true"`
    DepartName string `name:"departname"`
    Created string `name:"created"`
}

func main(){
    ui := UserInfo{UserName:"CHAIN", DepartName:"TEST", Created:time.Now().String()}
    orm.Register(new(UserInfo))
    db, err := orm.NewDb("mysql", "root:pwd@tcp(xxx.xxx.xxx.xxx:x3306/demo?charset=utf8")
    if err != nil {
        fmt.Println("打开SQL时出错:", err.Error())
        return
    }
    defer db.Close()
    
    //插入测试
    err = db.Insert(&ui)
    if err != nil {
        fmt.Println("插入时错误:", err.Error())
    }
    fmt.Println("插入成功")
    //修改测试
    ui.UserName = "BBBB"
    err = db.Update(ui)
    if err != nil {
        fmt.Println("修改时错误:", err.Error())
    }
    fmt.Println("修改成功")
    //删除测试
    err = db.Delete(ui)
    if err != nil {
        fmt.Println("删除时错误:", err.Error())
    }
    fmt.Println("删除成功")
    //查询测试
    res, err := db.From("userinfo").
    Select("username", "departname", "uid").
    Where("uid__gt", 20).
    Where("username", "chain").Get()
    if err != nil{
        fmt.Println("err: ", err.Error())
    }
    fmt.Println(res)
}
运行结果

源码

github 源码地址

转载请注明出处:
Golang 学习笔记(12)—— ORM实现

目录
上一节:Golang 学习笔记(11)—— 反射
下一节:Golang Web学习(13)—— 搭建简单的Web服务器