82 Commits

Author SHA1 Message Date
  lijianbin 4263f44027 优化 5 days ago
  lijianbin 55c6c7a6f9 优化 5 days ago
  lijianbin 322600d5ee 修复达梦8和pgsql中merge into的bug 5 days ago
  lijianbin f5ad88e70c 优化达梦关键字问题 4 weeks ago
  lijianbin c5dab75bff BuildSelectSql子查询兼容pgsql 1 month ago
  zhenghaorong 3ea2b59dda 调整接口 1 month ago
  lijianbin 45bcc98850 达梦--密码特殊字符需转义 1 month ago
  lijianbin a9383ba486 修正输出描述 1 month ago
  lijianbin fe779bf6ca 兼容高斯和达梦数据库 1 month ago
  lijianbin f7b0eb2df6 完善pgsql关键替换问题 2 months ago
  lijianbin 74aa5112a0 完善pgsql关键替换问题 2 months ago
  lijianbin 35782df825 pgsql关键字问题 2 months ago
  lijianbin 3b11f98fdc 修复pgsql中关键字使用问题 2 months ago
  lijianbin 60c84df410 修复pgsql中关键字使用问题 2 months ago
  loshiqi 01f3625f1e 兼容高斯 3 months ago
  loshiqi d16c3ca83c 兼容高斯 3 months ago
  loshiqi 16cff9c0ca 兼容高斯 3 months ago
  loshiqi bd754d1507 兼容高斯 3 months ago
  loshiqi 95a092325e values 3 months ago
  loshiqi 9bd66eefc9 插入和分页查询 3 months ago
  loshiqi 2a9f596ea8 兼容pgsql 3 months ago
  loshiqi ff5d48d51a 占位符 3 months ago
  loshiqi 82954508e0 驱动名称 更改 3 months ago
  loshiqi 2f467a0f92 增加pgsql链接 3 months ago
  loshiqi 14f6e5fc28 增加执行原始方法 4 months ago
  zhenghaorong c13da09c57 增加过滤 1 year ago
  zhenghaorong 7f9e42fed6 修复join无法使用with临时表问题 1 year ago
  zhenghaorong 2d94f24f43 修复join无法使用with临时表问题 1 year ago
  zhenghaorong bc664f29c5 修复join无法使用with临时表问题 1 year ago
  zhenghaorong 6d7835939e 增加with语句 1 year ago
  loshiqi 5db8b87a21 transaction bug修复 1 year ago
  guzeng 6cc83234ed Merge branch 'master' of ssh://git.tetele.net:4001/tgo/dbquery 2 years ago
  guzeng 628562e53e 增加调试 2 years ago
  zhenghaorong 396b0b8f86 增加生成表名的函数 2 years ago
  guzeng f8d6d88e9f 增加左/右连接方法 2 years ago
  guzeng 2a1e8837ad Merge branch 'master' of ssh://git.tetele.net:4001/tgo/dbquery 2 years ago
  guzeng 0765adfb79 修改构造查询,兼容mssql的page size 2 years ago
  loshiqi db904d18e2 join兼容子查询 2 years ago
  zhenghaorong 373438a283 兼容多表名的情况 2 years ago
  guzeng 16ae9b2dbe 修改查询条件,兼容不同服务商数据库 2 years ago
  guzeng 0fe53a0d4f 修改事务写入 3 years ago
  guzeng 89c478b953 修改事务写入 3 years ago
  guzeng ff798de0a6 查询增加主从判断逻辑 3 years ago
  zhenghaorong 155e124c63 增加查询表格信息 3 years ago
  zhenghaorong 3d486e6fb8 增加查询表格信息 3 years ago
  zhenghaorong 2c7fb178b3 解决append顺序错乱问题 3 years ago
  zhenghaorong 00943e75ed 事务方法增加批量更新 3 years ago
  zhenghaorong 33971bf841 修改提示语 3 years ago
  zhenghaorong 4ac2e8c30d 增加批量更新 3 years ago
  zhenghaorong 0ccf9774b7 解决append乱序 3 years ago
  zhenghaorong cd13573588 增加批量添加 3 years ago
  zhenghaorong dcf3856862 增加批量添加 3 years ago
  zhenghaorong 6a197add57 增加批量添加 3 years ago
  zhenghaorong bb5a253e43 修改子查询带条件查询 3 years ago
  zhenghaorong f90aec27a3 添加子查询 3 years ago
  zhenghaorong 4f6940cc58 添加子查询 3 years ago
  zhenghaorong 8a4277f0eb 添加分组功能 3 years ago
  zhenghaorong 44c03e0282 增加分组功能 3 years ago
  guzeng 07e9c8a562 修改GetRow查询方法 3 years ago
  guzeng dc051970d3 增加从库查询 3 years ago
  guzeng c565655527 交换Find、Get方法,交换Select、List方法 3 years ago
  guzeng 302482fa54 chain.go增加直接查询数据方法Get(),List() 3 years ago
  listen 36fdaf99a8 尝试从读 3 years ago
  guzeng 75e7036854 修改连接charset 4 years ago
  guzeng da0617d167 事务单条查询增加默认条件 4 years ago
  guzeng 13587615a1 修改find方法 4 years ago
  guzeng abb2bbf3b5 关闭sqlserver连接修改 4 years ago
  guzeng c8293a8936 修改库连接操作 4 years ago
  guzeng 87001a68d4 增加事务链式操作 4 years ago
  guzeng 74d90296a6 修改连接 4 years ago
  guzeng c1a82cf1b0 修改查询 4 years ago
  guzeng f0254c8655 修改连接方法 4 years ago
  guzeng 6df1432cc1 增加sqlserver连接方法 4 years ago
  guzeng cb4b8fe26d 修改分页计算 4 years ago
  guzeng 44f10248cc 查询方式增加wheres 4 years ago
  guzeng 53351218b4 修改查询方法 4 years ago
  guzeng 19c1efb26d 更新条件判断 4 years ago
  guzeng d46608bec8 链式操作增加批量 4 years ago
  guzeng 1c6146621a 补充readme 4 years ago
  guzeng be05d170be 更新说明 4 years ago
  guzeng e9b1b20c62 增加说明 4 years ago
  guzeng da5f7621da 增加链式操作 4 years ago
12 changed files with 3871 additions and 113 deletions
Unified View
  1. +30
    -1
      README.md
  2. +1142
    -0
      chain.go
  3. +255
    -0
      chain_test.go
  4. +132
    -0
      common.go
  5. +225
    -19
      conn.go
  6. +572
    -38
      db.go
  7. +80
    -4
      db_test.go
  8. +28
    -0
      go.mod
  9. +99
    -17
      prepare.go
  10. +70
    -0
      sqlserver.go
  11. +133
    -34
      transaction.go
  12. +1105
    -0
      transaction_chain.go

+ 30
- 1
README.md View File

@ -1,3 +1,32 @@
# dbquery # dbquery
数据库操作
数据库操作
## 链式查询使用
```
查询单条记录
map,err := new(Query).Db(dbname).Table(tablename).Where("id=?").Where("name=?").Value(1).Value("test").Find()
查询列表
list,err := new(Query).Db(dbname).Table(tablename).Where("id=?").Where("name=?").Value(1).Value("test").Select()
条件"或"
list,err := new(Query).Db(dbname).Table(tablename).Where("id=?").Where("name=?").WhereOr("mobile=?").Value(1).Value("test").Value("22").Select()
联表查
使用Join
list,err := new(Query).Db(dbname).Table(tablename).Join([]string{jointable,tablename.id=jointable.cid,"LEFT"}).Where("id=?").Where("name=?").Value(1).Value("test").Select()
更新
ret,err := new(Query).Db(dbname).Table(tablename).Data("name=?").Data("depart=?").Value("xxx").Value("test").Update()
插入
ret,err := new(Query).Db(dbname).Table(tablename).Data("name=?").Data("depart=?").Value("xxx").Value("test").Create()
删除
ret,err := new(Query).Db(dbname).Table(tablename).Where("name=?").Where("depart=?").Value("xxx").Value("test").Delete()
```

+ 1142
- 0
chain.go
File diff suppressed because it is too large
View File


+ 255
- 0
chain_test.go View File

@ -0,0 +1,255 @@
package dbquery
import (
"encoding/json"
"fmt"
"git.tetele.net/tgo/helper"
"log"
"strings"
"testing"
"time"
)
func CreateData(dbname string, data []map[string]interface{}) (int64, error) {
timestamp := time.Now().Unix()
rows := make([]map[string]interface{}, 0)
region_id_arr := []string{}
for _, v := range data {
region_id_arr = append(region_id_arr, helper.ToStr(v["regionId"]))
rows = append(rows, map[string]interface{}{
"region_id": v["regionId"],
"name": v["regionName"],
"createtime": timestamp,
"updatetime": timestamp,
"is_delete": 0,
"deletetime": 0,
})
}
where := "1=1"
if len(region_id_arr) > 0 {
where = "region_id not in('" + strings.Join(region_id_arr, "','") + "')"
}
_, err := new(Query).Db(dbname).Clean().Table("ttl_project").
Data("is_delete=?").Value(1).
Data("deletetime=?").Value(timestamp).
Where(where).
Update()
if err != nil {
log.Println("update project err", err)
}
if len(rows) > 0 {
_, err := new(Query).Db(dbname).Clean().Table("ttl_project").
SaveDatas(rows).
UpdFields([]string{"updatetime", "is_delete", "deletetime", "name"}).
MergeIntoWhereField([]string{"region_id"}).
UpdateAll()
if err != nil {
log.Println("insert to project err", err)
}
} else {
log.Println("rows is null", rows)
}
return 0, nil
}
func Test_ChainM(t *testing.T) {
err := DmConnect("192.168.233.155", "WUYE", "Bin123456", "", "5236")
if err != nil {
t.Log(err)
}
db_name := ""
var apiResp map[string]interface{}
test_json := `{"code":0,"message":"success","data":[{"regionId":"1","regionName":"项目 1"},{"regionId":"2","regionName":"项目 2"}]}`
err = json.Unmarshal([]byte(test_json), &apiResp)
if err != nil {
log.Println("------Get region queryType json Unmarshal err", err.Error())
return
}
if _, exist := apiResp["code"]; exist {
if helper.ToInt(apiResp["code"]) == 0 {
if _, exist = apiResp["data"]; exist {
data, err := helper.InterfaceToMapInterfaceArr(apiResp["data"])
if err != nil {
log.Println("数据转换失败,", err.Error())
} else {
if len(data) > 0 {
log.Println("data", data)
CreateData(db_name, data)
}
}
} else {
log.Println("请求无数据返回")
}
} else {
log.Println("请求异常message:", apiResp["message"])
}
} else {
log.Println("code is not exist", apiResp)
}
}
// 测试各数据库下各种情况
func Test_Chain(t *testing.T) {
//测试数据库连接
//err := Connect("127.0.0.1", "root", "root", "canyin", "3306")
//err := PgConnect("192.168.233.157", "bin", "Bin123456", "canyin", "5432")
//err := DmConnect("192.168.233.148", "SHOPV2", "Bin123456", "", "5236")
err := DmConnect("10.33.0.91", "ZYSG", "Zysg!#2025", "", "5236")
if err != nil {
t.Log(err)
}
db_name := ""
table_name := "ttl_user_log"
//time := time.Now().Unix()
//================查询表结构===========
ret, err := new(Query).Db(db_name).GetTableInfo(table_name)
if err != nil {
t.Log(err)
}
fmt.Println("===GetTableInfo:", ret)
//==========获取信息=================
/*query := new(Query).Db(db_name).Clean().Table("ttl_dorm_goods_reserve").Alias("a").
Join([]string{"ttl_dorm_goods_reserve_detail b", "a.id=b.reserve_id", "left"}).
Join([]string{"ttl_dorm_goods c", "c.id=b.goods_id", "left"}).
Join([]string{"ttl_dorm_room d", "d.id=a.room_id", "left"}).
Join([]string{"dorm_room_item e", "e.id=a.room_item_id", "left"}).
Where("a.user_id =?").Value(6006)
info, err := query.Groupby("a.id").Title("a.id").BuildSelectSql()*/
//info, err := new(Query).Db(db_name).Clean().Table(table_name).Clean().Alias("user").Join([]string{"ttl_user u", "u.id = user.user_id", "inner"}).Where("user.id=?").Value("3").Title("user.id,user.user_id,u.nickname").Find()
//info, err := GetDataByStmt(db_name, table_name, "*", []string{"id = ?"}, []interface{}{3}, nil)
/*if err != nil {
t.Log(err)
}
fmt.Println("===Find:", info)*/
//============获取列表==================
list, err := new(Query).Db("").Table("ttl_area").
Title("`first`,id,level,mergename,name,pid,shortname").
Select()
if err != nil {
t.Log(err)
}
fmt.Println("===List:", list)
//===========添加数据============
//insert_res, err := new(Query).Db(db_name).Clean().Table(table_name).Data("user_id=?").Value("1").Data("createtime=?").Value(time).Create()
//insert_res, err := InsertByStmt(db_name, table_name, []string{"user_id=?", "createtime=?"}, []interface{}{"1", time})
//insert_res, err := Insert(db_name, table_name, map[string]string{"user_id": "1", "createtime": helper.ToStr(time)})
//if err != nil {
// t.Log(err)
//}
//fmt.Println("===Insert:", insert_res)
//================更新数据=====================
//update_res, err := new(Query).Db(db_name).Clean().Table(table_name).Data("user_id=?").Value("2").Data("createtime=?").Value(time).Where("id=?").Value("6").Update()
//update_res, err := UpdateByStmt(db_name, table_name, []string{"createtime=?", "user_id=?"}, []string{"id=?"}, []interface{}{time, 3, 6})
//if err != nil {
// t.Log(err)
//}
//fmt.Println("===Update:", update_res)
//=============事务================
/*fmt.Println("================开启事务============")
tx, err := DB.Begin()
if err != nil {
t.Log(err)
}
update_log, err := TxPreUpdate(tx, db_name, table_name, []string{"createtime= ?"}, []string{"id=?"}, []interface{}{time, 2})
if err != nil {
tx.Rollback()
t.Log(err)
}
fmt.Println("===========事务执行:==================")
fmt.Println("===事务update:", update_log)
insert_log, err := TxPreInsert(tx, db_name, table_name, map[string]interface{}{"user_id": "1", "createtime": helper.ToStr(time)})
if err != nil {
tx.Rollback()
t.Log(err)
}
fmt.Println("===事务insert:", insert_log)
del_log, err := TxDelete(tx, db_name, table_name, map[string]string{"id": "2"})
if err != nil {
tx.Rollback()
t.Log(err)
}
fmt.Println("====事务delete:", del_log)
err = tx.Commit()
if err != nil {
t.Log(err)
tx.Rollback()
}
fmt.Println("=======事务执行完成==========")*/
/*trans := NewTxQuery().Db(db_name)
update_trans_res, err := trans.Clean().Table("ttl_dorm_check_in_apply").SaveData(map[string]interface{}{
"check_out_id": "48",
"id": "21",
"updatetime": time.Now().Unix(),
}).UpdateAll()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("=======事务update_trans", update_trans_res)
err = trans.Commit()
if err != nil {
trans.Rollback()
t.Log(err)
}*/
/*fmt.Println("====================执行事务trans============")
trans := NewTxQuery().Db(db_name)
info_trans, err := trans.Clean().Table(table_name).Where("id = ?").Value(5).Find()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("===事务Find_trans:", info_trans)
list_trans, err := trans.Clean().Table(table_name).Title("*").Select()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("=========事务List_trans:", list_trans)
data := map[string]interface{}{
"user_id": 5,
"memo": "test",
"createtime": time,
}
add_trans, err := trans.Clean().Table(table_name).SaveData(data).CreateAll()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("======事务Add_trans:", add_trans)
data["id"] = 15
update_trans_res, err := trans.Clean().Table(table_name).SaveData(data).UpdateAll()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("=======事务update_trans", update_trans_res)
err = trans.Commit()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("====================执行事务结束==================")*/
}

+ 132
- 0
common.go View File

@ -0,0 +1,132 @@
package dbquery
import (
"fmt"
"git.tetele.net/tgo/helper"
"log"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
// ===================达梦兼容===============
// 非关键字可以不添加标识符,关键字须添加
// 日期函数的使用TO_CHAR(TO_DATE('1970-01-01','yyyy-mm-dd') + (createtime / 86400), 'yyyy-mm-dd')
// group_concat替换成LISTAGG
// ========================================
// 关键字替换-支持达梦和高斯
func ReplaeByOtherSql(sql, sql_type, action string) string {
sql_type_arr := []string{"DmSql", "PgsqlDb"}
if !helper.IsInStringArray(sql_type_arr, sql_type) {
log.Println("sql_type error", sql_type)
return ""
}
// PgsqlDb用
if action == "add" {
sql = helper.StringJoin(sql, " RETURNING id")
}
// 定义需要处理的关键字列表
keywords := []string{"user", "order", "group", "table", "view", "admin", "new"}
//设置保护词组
excludePhrases := []string{
"order by", "group by", "GROUP BY", "ORDER BY", "WITHIN GROUP", "within group",
}
if sql_type == "PgsqlDb" {
keywords = []string{"user", "order", "group"}
excludePhrases = []string{
"order by", "group by", "GROUP BY", "ORDER BY",
}
// 移除所有反引号
sql = strings.ReplaceAll(sql, "`", "")
}
// 使用单词边界 \b 确保只匹配完整单词
pattern := `\b(` + strings.Join(keywords, "|") + `)\b`
re := regexp.MustCompile(pattern)
//保护排除短语
phraseMap := make(map[string]string)
for i, phrase := range excludePhrases {
placeholder := fmt.Sprintf("__EXCLUDE_%d__", i)
phraseMap[placeholder] = phrase
sql = strings.Replace(sql, phrase, placeholder, -1)
}
// 执行替换
sql = re.ReplaceAllStringFunc(sql, func(match string) string {
// 检查匹配是否在字符串常量中
if isInStringLiteral(sql, match) {
return match
}
if sql_type == "DmSql" {
return "`" + match + "`"
} else {
return `"` + match + `"`
}
})
// 恢复排除短语
for placeholder, phrase := range phraseMap {
sql = strings.Replace(sql, placeholder, phrase, -1)
}
return sql
}
// 检查匹配是否在字符串常量中
func isInStringLiteral(sql, match string) bool {
index := strings.Index(sql, match)
if index == -1 {
return false
}
// 检查匹配前的单引号数量
prefix := sql[:index]
singleQuotes := strings.Count(prefix, "'") - strings.Count(prefix, "\\'")
// 奇数表示在字符串常量中
return singleQuotes%2 != 0
}
// 字段值类型转换--针对达梦用
func ToString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
case int, int8, int16, int32, int64:
return strconv.FormatInt(reflect.ValueOf(value).Int(), 10)
case uint, uint8, uint16, uint32, uint64:
return strconv.FormatUint(reflect.ValueOf(value).Uint(), 10)
case float32, float64:
return strconv.FormatFloat(reflect.ValueOf(value).Float(), 'f', -1, 64)
case bool:
return strconv.FormatBool(v)
case time.Time:
return v.Format("2006-01-02 15:04:05")
default:
return fmt.Sprintf("%v", v)
}
}
// 字符串切片元素追加前缀
func addPrefixInField(slice []string, prefix string) []string {
new_slice := make([]string, len(slice))
for i, v := range slice {
new_slice[i] = prefix + v
}
return new_slice
}
func DmFieldDeal(fields string) string {
//移除所有反引号
title := strings.Replace(fields, "`", "", -1)
return title
}

+ 225
- 19
conn.go View File

@ -2,46 +2,110 @@ package dbquery
import ( import (
"database/sql" "database/sql"
"fmt"
"log" "log"
"net/url"
"errors" "errors"
"strings" "strings"
"time" "time"
_ "dm"
"git.tetele.net/tgo/helper" "git.tetele.net/tgo/helper"
_ "gitee.com/opengauss/openGauss-connector-go-pq" // 高斯驱动(推荐)或 "github.com/lib/pq"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
//_ "github.com/lib/pq" // 关键驱动导入
) )
var DB *sql.DB var DB *sql.DB
var SLAVER_DB *sql.DB
// db类型,默认空,如TencentDB(腾讯),
var DB_PROVIDER string
func Connect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error { func Connect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
log.Println("mysql database connectting...")
log.Println("database connectting...")
var dbConnErr error var dbConnErr error
if DBHOST != "" && DBUSER != "" && DBPWD != "" && DBPORT != "" { //&& DBNAME != ""
const maxRetries = 10
for i := 0; i < 10; i++ {
if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" {
return errors.New("mysql DBconnection params errors")
}
dsn := DBUSER + ":" + DBPWD + "@tcp(" + DBHOST + ":" + DBPORT + ")/" + DBNAME + "?charset=utf8mb4"
for i := 0; i < maxRetries; i++ {
//Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题)
// 每次尝试创建新连接对象
DB, dbConnErr = sql.Open("mysql", dsn)
if dbConnErr != nil {
log.Println("sql open failed:", i+1, dbConnErr)
time.Sleep(time.Second * 5) time.Sleep(time.Second * 5)
DB, dbConnErr = sql.Open("mysql", DBUSER+":"+DBPWD+"@tcp("+DBHOST+":"+DBPORT+")/"+DBNAME+"?charset=utf8")
continue
}
// 验证连接有效性
if pingErr := DB.Ping(); pingErr != nil {
DB.Close()
// 记录真实错误信息
dbConnErr = pingErr
log.Println("ping failed:", i+1, pingErr)
time.Sleep(time.Second * 5)
continue
}
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
log.Println("mysql database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
return nil
}
return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr)
}
func CloseConn() error {
return DB.Close()
}
func ConnectSlaver(DBHOST, DBUSER_SLAVER, DBPWD_SLAVER, DBNAME, DBPORT string, conns ...int) error {
log.Println("database connectting with slaver...")
var dbConnErr error
if DBHOST != "" && DBUSER_SLAVER != "" && DBPWD_SLAVER != "" && DBPORT != "" { //&& DBNAME != ""
for i := 0; i < 10; i++ {
SLAVER_DB, dbConnErr = sql.Open("mysql", DBUSER_SLAVER+":"+DBPWD_SLAVER+"@tcp("+DBHOST+":"+DBPORT+")/"+DBNAME+"?charset=utf8mb4")
if dbConnErr != nil { if dbConnErr != nil {
log.Println("ERROR", "can not connect to Database, ", dbConnErr) log.Println("ERROR", "can not connect to Database, ", dbConnErr)
time.Sleep(time.Second * 5)
} else { } else {
if len(conns) > 0 { if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
SLAVER_DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else { } else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
SLAVER_DB.SetMaxOpenConns(200) //默认值为0表示不限制
} }
if len(conns) > 1 { if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
SLAVER_DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else { } else {
DB.SetMaxIdleConns(50)
SLAVER_DB.SetMaxIdleConns(50)
} }
DB.Ping()
SLAVER_DB.Ping()
log.Println("database connected") log.Println("database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
SLAVER_DB.SetConnMaxLifetime(time.Minute * 2)
break break
} }
} }
@ -51,23 +115,56 @@ func Connect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
return dbConnErr return dbConnErr
} }
func CloseConn() error {
return DB.Close()
func CloseSlaverConn() error {
return SLAVER_DB.Close()
} }
/** /**
* 检测表名 * 检测表名
*/ */
func getTableName(dbName, table string) string {
func getTableName(dbName, table string, dbtype ...string) string {
if strings.Contains(table, ".") {
return table
var db_type string = "mysql"
if DB_PROVIDER == "PgsqlDb" {
dbName = ""
} else if DB_PROVIDER == "DmSql" {
dbName = ""
} }
if dbName != "" {
return helper.StringJoin(dbName, ".", table)
} else {
return table
if len(dbtype) > 0 {
if dbtype[0] != "" {
db_type = dbtype[0]
}
}
var ret string
switch db_type {
case "mysql":
if strings.Contains(table, ".") {
ret = table
}
if dbName != "" {
if strings.Contains(table, ",") {
arr := strings.Split(table, ",")
arrStrs := make([]string, 0, len(arr))
for _, v := range arr {
arrStrs = append(arrStrs, helper.StringJoin(dbName, ".", v))
}
ret = strings.Join(arrStrs, ",")
} else {
ret = helper.StringJoin(dbName, ".", table)
}
} else {
ret = table
}
case "mssql":
ret = helper.StringJoin(dbName, ".", table)
} }
return ret
} }
func GetDbTableName(dbName, table string) string { func GetDbTableName(dbName, table string) string {
@ -77,3 +174,112 @@ func GetDbTableName(dbName, table string) string {
func judg() []string { func judg() []string {
return []string{"=", ">", "<", "!=", "<=", ">="} return []string{"=", ">", "<", "!=", "<=", ">="}
} }
// pgsql连接
func PgConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
log.Println("pg database connectting...")
var dbConnErr error
const maxRetries = 10
DB_PROVIDER = "PgsqlDb"
if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" {
return errors.New("pgsql DBconnection params errors")
}
dsn := "host=" + DBHOST + " port=" + DBPORT + " user=" + DBUSER + " password=" + DBPWD + " dbname=" + DBNAME + " sslmode=disable search_path=public"
log.Println("database dsn", dsn)
for i := 0; i < maxRetries; i++ {
//Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题)
// 每次尝试创建新连接对象
DB, dbConnErr = sql.Open("opengauss", dsn)
if dbConnErr != nil {
log.Println("sql open failed:", i+1, dbConnErr)
time.Sleep(time.Second * 5)
continue
}
// 验证连接有效性
if pingErr := DB.Ping(); pingErr != nil {
DB.Close()
// 记录真实错误信息
dbConnErr = pingErr
log.Println("ping failed:", i+1, pingErr)
time.Sleep(time.Second * 5)
continue
}
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
log.Println("pgsql database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
return nil
}
return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr)
}
// 达梦8连接
func DmConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
log.Println("DM database connectting...")
var dbConnErr error
const maxRetries = 10
DB_PROVIDER = "DmSql"
if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" {
return errors.New("dm DBconnection params errors")
}
//达梦8如果密码存在特殊字符,需使用 url.PathEscape 进行转义后再放入连接串
DBPWD = url.PathEscape(DBPWD)
dsn := "dm://" + DBUSER + ":" + DBPWD + "@" + DBHOST + ":" + DBPORT + "?charSet=utf8&compatibleMode=mysql"
log.Println("database dsn", dsn)
for i := 0; i < maxRetries; i++ {
//Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题)
// 每次尝试创建新连接对象
DB, dbConnErr = sql.Open("dm", dsn)
if dbConnErr != nil {
log.Println("sql open failed:", i+1, dbConnErr)
time.Sleep(time.Second * 5)
continue
}
// 验证连接有效性
if pingErr := DB.Ping(); pingErr != nil {
DB.Close()
// 记录真实错误信息
dbConnErr = pingErr
log.Println("ping failed:", i+1, pingErr)
time.Sleep(time.Second * 5)
continue
}
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
log.Println("dm database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
return nil
}
return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr)
}

+ 572
- 38
db.go View File

@ -2,7 +2,9 @@ package dbquery
import ( import (
"database/sql" "database/sql"
"github.com/jmoiron/sqlx"
"log" "log"
"strconv"
"errors" "errors"
"strings" "strings"
@ -22,7 +24,11 @@ func Insert(dbName, table string, data map[string]string) (int64, error) {
if dbName == "" && table == "" { if dbName == "" && table == "" {
return insertId, errors.New("没有数据表") return insertId, errors.New("没有数据表")
} }
dbName = getTableName(dbName, table)
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
}
if len(data) < 1 { if len(data) < 1 {
return insertId, errors.New("没有要写入的数据") return insertId, errors.New("没有要写入的数据")
@ -39,16 +45,33 @@ func Insert(dbName, table string, data map[string]string) (int64, error) {
valueList[i] = value valueList[i] = value
i++ i++
} }
result, err := DB.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...)
if err != nil {
log.Println("ERROR|插入", dbName, "数据失败,", err)
return insertId, err
var Sql string
Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") values (" + strings.Join(keyStr, ",") + ")"
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add")
stmt, err = DB.Prepare(Sql)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
row := stmt.QueryRow(valueList...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else { } else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
result, err := DB.Exec(Sql, valueList...)
if err != nil {
log.Println("ERROR|插入", dbName, "数据失败,", err)
return insertId, err
} else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
}
} }
} }
@ -62,8 +85,11 @@ func Update(dbName, table string, data map[string]string, where map[string]strin
if dbName == "" && table == "" { if dbName == "" && table == "" {
return rowsAffected, errors.New("没有数据表") return rowsAffected, errors.New("没有数据表")
} }
dbName = getTableName(dbName, table)
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
}
if len(data) < 1 { if len(data) < 1 {
return rowsAffected, errors.New("同有更新的数据") return rowsAffected, errors.New("同有更新的数据")
} }
@ -100,7 +126,15 @@ func Update(dbName, table string, data map[string]string, where map[string]strin
log.Println("ERROR|修改数据表", dbName, "时条件中有空数据,条件:", where, "数据:", data) log.Println("ERROR|修改数据表", dbName, "时条件中有空数据,条件:", where, "数据:", data)
return rowsAffected, errors.New("条件中有空数据") return rowsAffected, errors.New("条件中有空数据")
} }
result, err := DB.Exec("update "+dbName+" set "+strings.Join(keyList, " , ")+" where "+strings.Join(whereStr, " and "), valueList...)
var Sql string
Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ")
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := DB.Exec(Sql, valueList...)
if err != nil { if err != nil {
log.Println("ERROR|修改", dbName, "数据失败,", err) log.Println("ERROR|修改", dbName, "数据失败,", err)
@ -122,9 +156,11 @@ func Delete(dbName, table string, data map[string]string, del_count ...string) (
if dbName == "" && table == "" { if dbName == "" && table == "" {
return count, errors.New("没有数据表") return count, errors.New("没有数据表")
} }
dbName = getTableName(dbName, table)
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
}
if len(data) < 1 { if len(data) < 1 {
return count, errors.New("没有要删除的数据") return count, errors.New("没有要删除的数据")
} }
@ -155,7 +191,15 @@ func Delete(dbName, table string, data map[string]string, del_count ...string) (
limitStr = " limit " + del_count[0] limitStr = " limit " + del_count[0]
} }
result, err := DB.Exec("delete from "+dbName+" where "+strings.Join(keyList, " and ")+limitStr, valueList...)
var Sql string
Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := DB.Exec(Sql, valueList...)
if err != nil { if err != nil {
log.Println("ERROR|删除", dbName, "数据失败,", err) log.Println("ERROR|删除", dbName, "数据失败,", err)
@ -179,7 +223,6 @@ func GetData(dbName, table string, title string, where map[string]string, limit
if dbName == "" && table == "" { if dbName == "" && table == "" {
return count, info, errors.New("没有数据表") return count, info, errors.New("没有数据表")
} }
dbName = getTableName(dbName, table) dbName = getTableName(dbName, table)
if len(title) < 1 { if len(title) < 1 {
@ -198,7 +241,11 @@ func GetData(dbName, table string, title string, where map[string]string, limit
if _, ok := limit["from"]; ok { if _, ok := limit["from"]; ok {
from = limit["from"] from = limit["from"]
} }
limitStr += " limit " + from + ",1"
if DB_PROVIDER == "PgsqlDb" {
limitStr += " limit 1 OFFSET " + from
} else {
limitStr += " limit " + from + ",1"
}
} else { } else {
limitStr = " limit 1" limitStr = " limit 1"
@ -229,8 +276,15 @@ func GetData(dbName, table string, title string, where map[string]string, limit
var err error var err error
var queryNum int = 0 var queryNum int = 0
for queryNum < 3 { //如发生错误,继续查询3次,防止数据库连接断开问题 for queryNum < 3 { //如发生错误,继续查询3次,防止数据库连接断开问题
rows, err = DB.Query("SELECT "+title+" FROM "+dbName+" where "+strings.Join(keyList, " and ")+" "+limitStr, valueList...)
var Sql string
Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(keyList, " and ") + " " + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
rows, err = DB.Query(Sql, valueList...)
if err == nil { if err == nil {
break break
@ -262,8 +316,15 @@ func GetData(dbName, table string, title string, where map[string]string, limit
if rowerr == nil { if rowerr == nil {
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
} }
} }
count++ count++
@ -277,6 +338,431 @@ func GetData(dbName, table string, title string, where map[string]string, limit
return count, info, nil return count, info, nil
} }
/**
* 查找一条记录
* @param dbName 数据表名
* @param title 查询字段名
*/
func GetRow(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) {
var count int = 0
info := make(map[string]string)
if dbName == "" && table_name == "" {
return count, info, errors.New("没有数据表")
}
table := ""
withSql := ""
if len(with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") {
table = table_name
} else {
table = getTableName(dbName, table_name)
}
var sql_str, title string
if titles != "" {
title = titles
} else {
title = "*"
}
if DB_PROVIDER == "TencentDB" {
sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title)
} else {
sql_str = helper.StringJoin(withSql, "select ", title)
}
if alias != "" {
table = helper.StringJoin(table, " as ", alias)
}
sql_str = helper.StringJoin(sql_str, " from ", table)
if len(join) > 0 {
var builder strings.Builder
builder.WriteString(sql_str)
boo := false
for _, joinitem := range join {
if len(joinitem) < 2 {
continue
}
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) > 4 {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(dbName, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql_str = builder.String()
}
}
if len(where) > 0 || len(where_or) > 0 {
sql_str = helper.StringJoin(sql_str, " where ")
}
if len(where) > 0 {
sql_str = helper.StringJoin(sql_str, " (", strings.Join(where, " and "), " ) ")
}
if len(where_or) > 0 {
if len(where) > 0 {
sql_str = helper.StringJoin(sql_str, " or ", strings.Join(where_or, " or "))
} else {
sql_str = helper.StringJoin(sql_str, strings.Join(where_or, " or "))
}
}
if groupby != "" {
sql_str = helper.StringJoin(sql_str, " group by ", groupby)
}
if having != "" {
sql_str = helper.StringJoin(sql_str, " having ", having)
}
if orderby != "" {
sql_str = helper.StringJoin(sql_str, " order by ", orderby)
}
if debug {
log.Println("query sql:", sql_str, valueList)
}
condition_len := 0 //所有条件数
for _, ch2 := range sql_str {
if string(ch2) == "?" {
condition_len++
}
}
if condition_len != len(valueList) {
return 0, nil, errors.New("参数错误,条件值错误")
}
var rows *sql.Rows
var err error
var queryNum int = 0
sql_str = helper.StringJoin(sql_str, " limit 1")
var db *sql.DB
if SLAVER_DB != nil {
db = SLAVER_DB
} else {
db = DB
}
for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题
if DB_PROVIDER == "PgsqlDb" {
sql_str = sqlx.Rebind(sqlx.DOLLAR, sql_str)
sql_str = ReplaeByOtherSql(sql_str, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql_str = ReplaeByOtherSql(sql_str, "DmSql", "")
}
rows, err = db.Query(sql_str, valueList...)
if err == nil {
break
} else {
log.Println(err)
time.Sleep(time.Millisecond * 500)
}
queryNum++
}
if err != nil {
log.Println("DB error:", err)
rows.Close()
return count, info, err
}
columns, _ := rows.Columns()
scanArgs := make([]interface{}, len(columns))
values := make([]interface{}, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}
var index string
var rowerr error
for rows.Next() {
rowerr = rows.Scan(scanArgs...)
if rowerr == nil {
for i, col := range values {
if col != nil {
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
}
}
count++
} else {
log.Println("ERROR", rowerr)
}
}
rows.Close()
if rowerr != nil {
log.Println("DB row error:", rowerr)
return count, info, rowerr
}
return count, info, nil
}
/**
* 查找多条记录
* @param dbName 数据表名
* @param title 查询字段名
*/
func FetchRows(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) {
var count int = 0
list := make([]map[string]string, 0)
if dbName == "" && table_name == "" {
return count, list, errors.New("没有数据表")
}
table := ""
withSql := ""
if len(with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") {
table = table_name
} else {
table = getTableName(dbName, table_name)
}
var sql_str, title string
if titles != "" {
title = titles
} else {
title = "*"
}
if DB_PROVIDER == "TencentDB" {
sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title)
} else {
sql_str = helper.StringJoin(withSql, "select ", title)
}
if alias != "" {
table = helper.StringJoin(table, " as ", alias)
}
sql_str = helper.StringJoin(sql_str, " from ", table)
if len(join) > 0 {
var builder strings.Builder
builder.WriteString(sql_str)
boo := false
for _, joinitem := range join {
if len(joinitem) < 2 {
continue
}
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) >= 4 {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(dbName, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql_str = builder.String()
}
}
if len(where) > 0 || len(where_or) > 0 {
sql_str = helper.StringJoin(sql_str, " where ")
}
if len(where) > 0 {
sql_str = helper.StringJoin(sql_str, " (", strings.Join(where, " and "), " ) ")
}
if len(where_or) > 0 {
if len(where) > 0 {
sql_str = helper.StringJoin(sql_str, " or ", strings.Join(where_or, " or "))
} else {
sql_str = helper.StringJoin(sql_str, strings.Join(where_or, " or "))
}
}
if groupby != "" {
sql_str = helper.StringJoin(sql_str, " group by ", groupby)
}
if having != "" {
sql_str = helper.StringJoin(sql_str, " HAVING ", having)
}
if orderby != "" {
sql_str = helper.StringJoin(sql_str, " order by ", orderby)
}
if page > 0 || page_size > 0 {
if page < 1 {
page = 1
}
if page_size < 1 {
page_size = 10
}
from := strconv.Itoa((page - 1) * page_size)
offset := strconv.Itoa(page_size)
if from != "" && offset != "" {
if DB_PROVIDER == "PgsqlDb" {
sql_str = helper.StringJoin(sql_str, " limit ", offset, " OFFSET ", from)
} else {
sql_str = helper.StringJoin(sql_str, " limit ", from, " , ", offset)
}
}
}
if debug {
log.Println("query sql:", sql_str, valueList)
}
condition_len := 0 //所有条件数
for _, ch2 := range sql_str {
if string(ch2) == "?" {
condition_len++
}
}
if condition_len != len(valueList) {
return 0, list, errors.New("参数错误,条件值错误")
}
var db *sql.DB
if SLAVER_DB != nil {
db = SLAVER_DB
} else {
db = DB
}
var rows *sql.Rows
var err error
var queryNum int = 0
for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题
if DB_PROVIDER == "PgsqlDb" {
sql_str = sqlx.Rebind(sqlx.DOLLAR, sql_str)
sql_str = ReplaeByOtherSql(sql_str, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql_str = ReplaeByOtherSql(sql_str, "DmSql", "")
}
rows, err = db.Query(sql_str, valueList...)
if err == nil {
break
} else {
log.Println(err)
time.Sleep(time.Millisecond * 500)
}
queryNum++
}
if err != nil {
rows.Close()
return 0, list, err
}
columns, _ := rows.Columns()
scanArgs := make([]interface{}, len(columns))
values := make([]interface{}, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}
var index string
var rowerr error
var info map[string]string
for rows.Next() {
rowerr = rows.Scan(scanArgs...)
info = make(map[string]string)
if rowerr == nil {
for i, col := range values {
if col != nil {
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
}
}
count++
} else {
log.Println("ERROR", rowerr)
}
if len(info) > 0 {
list = append(list, info)
}
}
rows.Close()
return count, list, nil
}
func GetInfo(dbName, table string, title string, where map[string]string) (map[string]string, error) { func GetInfo(dbName, table string, title string, where map[string]string) (map[string]string, error) {
count, info, gzErr := GetData(dbName, table, title, where, nil) count, info, gzErr := GetData(dbName, table, title, where, nil)
@ -301,9 +787,11 @@ func GetList(dbName, table string, title string, where map[string]string, limit
if dbName == "" && table == "" { if dbName == "" && table == "" {
return list, errors.New("没有数据表") return list, errors.New("没有数据表")
} }
dbName = getTableName(dbName, table)
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
}
var rows *sql.Rows var rows *sql.Rows
var err error var err error
var queryNum int = 0 var queryNum int = 0
@ -326,7 +814,12 @@ func GetList(dbName, table string, title string, where map[string]string, limit
from = limit["from"] from = limit["from"]
} }
if offset != "0" && from != "" { if offset != "0" && from != "" {
limitStr += " limit " + from + "," + offset
if DB_PROVIDER == "PgsqlDb" {
limitStr += " limit " + offset + " OFFSET " + from
} else {
limitStr += " limit " + from + "," + offset
}
} }
} }
@ -360,8 +853,15 @@ func GetList(dbName, table string, title string, where map[string]string, limit
} }
for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题 for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题
rows, err = DB.Query("select "+title+" from "+dbName+" where "+strings.Join(whereStr, " and ")+" "+limitStr, valueList...)
var Sql string
Sql = "select " + title + " from " + dbName + " where " + strings.Join(whereStr, " and ") + " " + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
rows, err = DB.Query(Sql, valueList...)
if err == nil { if err == nil {
break break
@ -406,8 +906,15 @@ func GetList(dbName, table string, title string, where map[string]string, limit
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
record[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
}
} }
} }
list = append(list, record) list = append(list, record)
@ -425,8 +932,11 @@ func GetTotal(dbName, table string, args ...string) (total int) {
if dbName == "" && table == "" { if dbName == "" && table == "" {
return return
} }
dbName = getTableName(dbName, table)
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
}
var title string = "*" var title string = "*"
@ -439,7 +949,6 @@ func GetTotal(dbName, table string, args ...string) (total int) {
var queryNum int = 0 var queryNum int = 0
for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题 for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题
rows, err = DB.Query("select count(" + title + ") as count from " + dbName + " limit 1") rows, err = DB.Query("select count(" + title + ") as count from " + dbName + " limit 1")
if err == nil { if err == nil {
@ -480,7 +989,11 @@ func GetCount(dbName, table string, where map[string]string, args ...string) (to
if dbName == "" && table == "" { if dbName == "" && table == "" {
return return
} }
dbName = getTableName(dbName, table)
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
}
var title string = "*" var title string = "*"
@ -519,7 +1032,15 @@ func GetCount(dbName, table string, where map[string]string, args ...string) (to
for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题 for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题
rows, err = DB.Query("select count("+title+") as count from "+dbName+" where "+strings.Join(whereStr, " and ")+" limit 1", valueList...)
var Sql string
Sql = "select count(" + title + ") as count from " + dbName + " where " + strings.Join(whereStr, " and ") + " limit 1"
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
rows, err = DB.Query(Sql, valueList...)
if err == nil { if err == nil {
break break
@ -578,6 +1099,12 @@ func DoQuery(args ...interface{}) ([]map[string]string, error) {
for queryNum < 3 { //如发生错误,继续查询5次,防止数据库连接断开问题 for queryNum < 3 { //如发生错误,继续查询5次,防止数据库连接断开问题
if len(args) > 1 { if len(args) > 1 {
if DB_PROVIDER == "PgsqlDb" {
queryStr = sqlx.Rebind(sqlx.DOLLAR, queryStr)
queryStr = ReplaeByOtherSql(queryStr, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
queryStr = ReplaeByOtherSql(queryStr, "DmSql", "")
}
rows, err = DB.Query(queryStr, args[1:]...) //strings.Join(args[1:], ",") rows, err = DB.Query(queryStr, args[1:]...) //strings.Join(args[1:], ",")
if err != nil { if err != nil {
log.Println("ERROR|DoQuery error:", err) log.Println("ERROR|DoQuery error:", err)
@ -618,8 +1145,15 @@ func DoQuery(args ...interface{}) ([]map[string]string, error) {
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
record[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
}
} }
} }
list = append(list, record) list = append(list, record)


+ 80
- 4
db_test.go View File

@ -1,11 +1,87 @@
package dbquery package dbquery
import ( import (
"log"
"testing" "testing"
) )
func Test_GetTotal(t *testing.T) {
Connect()
total := GetTotal("dev_tetel_net", TABLE_CONFIG, "id")
t.Log(total)
func Test_Connet(t *testing.T) {
for i := 0; i < 1; i++ {
dbhost := "localhost"
dbname := "shop"
dbusername := "tetele"
dbpassword := "fly123456"
dbport := "3306"
table := "ttl_order_product"
err := Connect(dbhost, dbusername, dbpassword, dbname, dbport)
if err != nil {
log.Println(err.Error())
}
//_,err = new(Query).Db(dbname).Table("ttl_user").Where("id > 0").Select()
// _, err = new(Query).Db(dbname).Table("ttl_news").
// Datas([]string{"title=?", "content=?"}).
// Values([]interface{}{"aaaaaaaaa", "bbbbbb"}).Create()
title := "op.id,op.sn,op.order_price"
alias := "op"
join := [][]string{}
join = append(join, []string{"ttl_product as p", "op.product_id=p.id"})
where := []string{}
where_or := []string{}
valueList := []interface{}{}
orderby := "id desc"
debug := true
//count, row, err := GetRow(dbname, table, alias, title, join, where, where_or, valueList, orderby, debug)
count, row, err := FetchRows(dbname, table, alias, title, join, join, where, where_or, valueList, orderby, "", "", 1, 10, debug)
log.Println(count)
log.Println(row)
log.Println(err)
if err != nil {
log.Println(err.Error())
}
}
}
func Test_Query(t *testing.T) {
token := "67c121aa-6e1c-011f-ebb6-976d855fd777"
dbhost := "192.168.233.134"
dbname := "canyin"
dbusername := "bin"
dbpassword := "Bin123456"
dbport := "5432"
table := "ttl_user_token"
err := PgConnect(dbhost, dbusername, dbpassword, dbname, dbport)
if err != nil {
log.Println(err.Error())
}
title := "user.*,ut.expiretime"
alias := "ut"
join := [][]string{}
join = append(join, []string{"ttl_user as user", "ut.user_id= user.id"})
where := []string{
"ut.token=?",
}
where_or := []string{}
valueList := []interface{}{
token,
}
orderby := ""
debug := true
count, row, err := GetRow(dbname, table, alias, title, [][]string{}, join, where, where_or, valueList, orderby, "", "", debug)
//count, row, err := FetchRows(dbname, table, alias, title, join, join, where, where_or, valueList, orderby, "", "", 1, 10, debug)
log.Println(count)
log.Println(row)
log.Println(err)
if err != nil {
log.Println(err.Error())
}
} }

+ 28
- 0
go.mod View File

@ -0,0 +1,28 @@
module git.tetele.net/tgo/dbquery
go 1.23.0
toolchain go1.24.0
require (
git.tetele.net/tgo/helper v0.8.0
gitee.com/opengauss/openGauss-connector-go-pq v1.0.7
github.com/denisenkom/go-mssqldb v0.11.0
github.com/go-sql-driver/mysql v1.8.1
github.com/jmoiron/sqlx v1.4.0
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
git.tetele.net/tgo/crypter v0.2.2 // indirect
github.com/ZZMarquis/gm v1.3.2 // indirect
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect
golang.org/x/crypto v0.40.0 // indirect
golang.org/x/text v0.28.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
)

+ 99
- 17
prepare.go View File

@ -3,6 +3,7 @@ package dbquery
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/jmoiron/sqlx"
"log" "log"
"strings" "strings"
@ -41,8 +42,11 @@ func StmtForRead(dbName, table string, title string, where []string, limit map[s
offset = limit["offset"] offset = limit["offset"]
} }
if from != "" && offset != "" { if from != "" && offset != "" {
limitStr += " limit " + from + "," + offset
if DB_PROVIDER == "PgsqlDb" {
limitStr += " limit " + offset + " OFFSET " + from
} else {
limitStr += " limit " + from + "," + offset
}
} }
} }
@ -52,7 +56,15 @@ func StmtForRead(dbName, table string, title string, where []string, limit map[s
if len(where) > 0 { if len(where) > 0 {
// log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr) // log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr)
stmt, err = DB.Prepare("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr)
var Sql string
Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
stmt, err = DB.Prepare(Sql)
} else { } else {
// log.Println("SELECT " + title + " FROM " + dbName + limitStr) // log.Println("SELECT " + title + " FROM " + dbName + limitStr)
stmt, err = DB.Prepare("SELECT " + title + " FROM " + dbName + limitStr) stmt, err = DB.Prepare("SELECT " + title + " FROM " + dbName + limitStr)
@ -75,8 +87,6 @@ func StmtForQueryList(stmt *sql.Stmt, valuelist []interface{}) ([]map[string]str
return nil, errors.New("缺少必要参数") return nil, errors.New("缺少必要参数")
} }
// log.Println(valuelist...)
rows, err := stmt.Query(valuelist...) rows, err := stmt.Query(valuelist...)
defer stmt.Close() defer stmt.Close()
if err != nil { if err != nil {
@ -106,8 +116,15 @@ func StmtForQueryList(stmt *sql.Stmt, valuelist []interface{}) ([]map[string]str
if rowerr == nil { if rowerr == nil {
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
} }
} }
} else { } else {
@ -154,8 +171,15 @@ func StmtForQueryRow(stmt *sql.Stmt, valuelist []interface{}) (map[string]string
if rowerr == nil { if rowerr == nil {
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
} }
} }
} else { } else {
@ -186,8 +210,15 @@ func StmtForUpdate(dbName, table string, data []string, where []string) (*sql.St
var stmt *sql.Stmt var stmt *sql.Stmt
var err error var err error
stmt, err = DB.Prepare("update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and "))
var Sql string
Sql = "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ")
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
stmt, err = DB.Prepare(Sql)
return stmt, err return stmt, err
} }
@ -224,7 +255,41 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
var err error var err error
stmt, err = DB.Prepare("insert into " + dbName + " set " + strings.Join(data, " , "))
var sql string
if DB_PROVIDER == "PgsqlDb" || DB_PROVIDER == "DmSql" {
insert_data := []string{}
value_data := []string{}
for _, rv := range data {
dv := strings.Split(rv, "=")
if len(dv) < 2 {
return nil, errors.New("参数错误,条件值错误,=号不存在")
}
if strings.Contains(rv, "?") {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, "?")
} else {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, dv[1])
}
}
sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " RETURNING id")
}
} else {
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(data, " , "))
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
//stmt, err = DB.Prepare("insert into " + dbName + " set " + strings.Join(data, " , "))
stmt, err = DB.Prepare(sql)
return stmt, err return stmt, err
} }
@ -234,11 +299,23 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) {
* @return lastId error * @return lastId error
*/ */
func StmtForInsertExec(stmt *sql.Stmt, valuelist []interface{}) (int64, error) { func StmtForInsertExec(stmt *sql.Stmt, valuelist []interface{}) (int64, error) {
res, err := stmt.Exec(valuelist...)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
if DB_PROVIDER == "PgsqlDb" {
row := stmt.QueryRow(valuelist...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else {
res, err := stmt.Exec(valuelist...)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return res.LastInsertId()
} }
return res.LastInsertId()
} }
/** /**
@ -350,7 +427,12 @@ func StmtForQuery(querysql string) (*sql.Stmt, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
var err error var err error
if DB_PROVIDER == "PgsqlDb" {
querysql = sqlx.Rebind(sqlx.DOLLAR, querysql)
querysql = ReplaeByOtherSql(querysql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
querysql = ReplaeByOtherSql(querysql, "DmSql", "")
}
stmt, err = DB.Prepare(querysql) stmt, err = DB.Prepare(querysql)
return stmt, err return stmt, err


+ 70
- 0
sqlserver.go View File

@ -0,0 +1,70 @@
package dbquery
import (
"database/sql"
"errors"
"fmt"
"strconv"
"log"
"time"
_ "github.com/denisenkom/go-mssqldb"
)
var MSDB_CONN *sql.DB
func MSConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT, encrypt string, conns ...int) error {
log.Println("msdb connectting...")
var dbConnErr error
if DBHOST != "" && DBUSER != "" && DBPWD != "" && DBPORT != "" { //&& DBNAME != ""
for i := 0; i < 10; i++ {
//连接字符串
db_port, _ := strconv.Atoi(DBPORT)
params := "server=%s;port=%d;database=%s;user id=%s;password=%s"
if encrypt != "" {
params = params + ";encrypt=" + encrypt
}
connString := fmt.Sprintf(params, DBHOST, db_port, DBNAME, DBUSER, DBPWD)
log.Println(connString)
//建立连接
MSDB_CONN, dbConnErr = sql.Open("mssql", connString)
if dbConnErr != nil {
log.Println("ERROR", "can not connect to Database, ", dbConnErr)
time.Sleep(time.Second * 5)
} else {
err = MSDB_CONN.Ping()
log.Println("msdb connected", err)
break
}
}
} else {
return errors.New("msdb connection params errors")
}
return dbConnErr
}
func CloseMSConn() error {
if MSDB_CONN != nil {
return MSDB_CONN.Close()
}
return nil
}

+ 133
- 34
transaction.go View File

@ -6,6 +6,8 @@ package dbquery
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"git.tetele.net/tgo/helper"
"github.com/jmoiron/sqlx"
"log" "log"
"strings" "strings"
"time" "time"
@ -21,7 +23,12 @@ func TxInsert(tx *sql.Tx, dbname, table string, data map[string]string) (int64,
if dbname == "" && table == "" { if dbname == "" && table == "" {
return 0, errors.New("参数错误,没有数据表") return 0, errors.New("参数错误,没有数据表")
} }
dbName := getTableName(dbname, table)
dbName := ""
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbname, table)
}
if len(data) < 1 { if len(data) < 1 {
return 0, errors.New("参数错误,没有要写入的数据") return 0, errors.New("参数错误,没有要写入的数据")
} }
@ -38,16 +45,33 @@ func TxInsert(tx *sql.Tx, dbname, table string, data map[string]string) (int64,
valueList[i] = value valueList[i] = value
i++ i++
} }
result, err := tx.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...)
if err != nil {
log.Println("ERROR", "insert into ", dbName, "error:", err)
return insertId, err
if DB_PROVIDER == "PgsqlDb" {
var Sql string
Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") values (" + strings.Join(keyStr, ",") + ")"
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add")
stmt, err = tx.Prepare(Sql)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
row := stmt.QueryRow(valueList...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else { } else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
result, err := tx.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...)
if err != nil {
log.Println("ERROR", "insert into ", dbName, "error:", err)
return insertId, err
} else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
}
} }
} }
@ -62,7 +86,12 @@ func TxPreInsert(tx *sql.Tx, dbname, table string, data map[string]interface{})
return 0, errors.New("params error,no db or table") return 0, errors.New("params error,no db or table")
} }
dbName := getTableName(dbname, table)
dbName := ""
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbname, table)
}
if len(data) < 1 { if len(data) < 1 {
return 0, errors.New("params error,no data to insert") return 0, errors.New("params error,no data to insert")
@ -74,28 +103,52 @@ func TxPreInsert(tx *sql.Tx, dbname, table string, data map[string]interface{})
var field []string = make([]string, len(data)) var field []string = make([]string, len(data))
var valuelist []interface{} = make([]interface{}, len(data)) var valuelist []interface{} = make([]interface{}, len(data))
insert_data := []string{}
value_data := []string{}
var i int = 0 var i int = 0
for key, item := range data { for key, item := range data {
field[i] = key + "=?" field[i] = key + "=?"
valuelist[i] = item valuelist[i] = item
i++ i++
insert_data = append(insert_data, key)
value_data = append(value_data, "?")
} }
if DB_PROVIDER == "PgsqlDb" {
Sql := helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add")
stmt, err = tx.Prepare(Sql)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
row := stmt.QueryRow(valuelist...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else {
sql := "insert into " + dbName + " set " + strings.Join(field, " , ")
stmt, err = tx.Prepare(sql)
sql := "insert into " + dbName + " set " + strings.Join(field, " , ")
if DB_PROVIDER == "DmSql" {
sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = tx.Prepare(sql)
if err != nil {
log.Println("insert prepare error:", sql, err)
return 0, errors.New("insert prepare error:" + err.Error())
}
result, err := stmt.Exec(valuelist...)
if err != nil {
log.Println("insert exec error:", sql, valuelist, err)
return 0, errors.New("insert exec error:" + err.Error())
if err != nil {
log.Println("insert prepare error:", sql, err)
return 0, errors.New("insert prepare error:" + err.Error())
}
result, err := stmt.Exec(valuelist...)
if err != nil {
log.Println("insert exec error:", sql, valuelist, err)
return 0, errors.New("insert exec error:" + err.Error())
}
insertId, _ := result.LastInsertId()
return insertId, nil
} }
insertId, _ := result.LastInsertId()
return insertId, nil
} }
@ -108,7 +161,12 @@ func TxUpdate(tx *sql.Tx, dbname, table string, data map[string]string, where ma
if dbname == "" && table == "" { if dbname == "" && table == "" {
return rowsAffected, errors.New("参数错误,没有数据表") return rowsAffected, errors.New("参数错误,没有数据表")
} }
dbName := getTableName(dbname, table)
dbName := ""
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbname, table)
}
if len(data) < 1 { if len(data) < 1 {
return rowsAffected, errors.New("参数错误,没有要写入的数据") return rowsAffected, errors.New("参数错误,没有要写入的数据")
} }
@ -145,7 +203,15 @@ func TxUpdate(tx *sql.Tx, dbname, table string, data map[string]string, where ma
log.Println("ERROR", "update", dbName, "error, params empty") log.Println("ERROR", "update", dbName, "error, params empty")
return rowsAffected, errors.New("params empty") return rowsAffected, errors.New("params empty")
} }
result, err := tx.Exec("update "+dbName+" set "+strings.Join(keyList, " , ")+" where "+strings.Join(whereStr, " and "), valueList...)
var Sql string
Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ")
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := tx.Exec(Sql, valueList...)
if err != nil { if err != nil {
log.Println("ERROR", "update", dbName, "error:", err) log.Println("ERROR", "update", dbName, "error:", err)
@ -167,7 +233,12 @@ func TxPreUpdate(tx *sql.Tx, dbname, table string, data []string, where []string
return 0, errors.New("params error,no db or table") return 0, errors.New("params error,no db or table")
} }
dbName := getTableName(dbname, table)
dbName := ""
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbname, table)
}
if len(where) < 1 { if len(where) < 1 {
return 0, errors.New("params error, no data for update") return 0, errors.New("params error, no data for update")
@ -178,7 +249,12 @@ func TxPreUpdate(tx *sql.Tx, dbname, table string, data []string, where []string
var stmt *sql.Stmt var stmt *sql.Stmt
sql := "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ") sql := "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ")
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = tx.Prepare(sql) stmt, err = tx.Prepare(sql)
if err != nil { if err != nil {
@ -204,8 +280,12 @@ func TxDelete(tx *sql.Tx, dbname, table string, where map[string]string, del_cou
if dbname == "" && table == "" { if dbname == "" && table == "" {
return count, errors.New("参数错误,没有数据表") return count, errors.New("参数错误,没有数据表")
} }
dbName := getTableName(dbname, table)
dbName := ""
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbname, table)
}
if len(where) < 1 { if len(where) < 1 {
return count, errors.New("参数错误,没有删除条件") return count, errors.New("参数错误,没有删除条件")
} }
@ -236,7 +316,15 @@ func TxDelete(tx *sql.Tx, dbname, table string, where map[string]string, del_cou
limitStr = " limit " + del_count[0] limitStr = " limit " + del_count[0]
} }
result, err := tx.Exec("delete from "+dbName+" where "+strings.Join(keyList, " and ")+limitStr, valueList...)
var Sql string
Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := tx.Exec(Sql, valueList...)
if err != nil { if err != nil {
log.Println("ERROR", "delete from", dbName, "error:", err) log.Println("ERROR", "delete from", dbName, "error:", err)
@ -257,8 +345,11 @@ func TxForRead(tx *sql.Tx, dbName, table string, title string, where []string) (
if dbName == "" && table == "" { if dbName == "" && table == "" {
return nil, errors.New("参数错误,没有数据表") return nil, errors.New("参数错误,没有数据表")
} }
dbName = getTableName(dbName, table)
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
}
if len(title) < 1 { if len(title) < 1 {
return nil, errors.New("没有要查询内容") return nil, errors.New("没有要查询内容")
@ -269,7 +360,15 @@ func TxForRead(tx *sql.Tx, dbName, table string, title string, where []string) (
if len(where) > 0 { if len(where) > 0 {
// log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE") // log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE")
stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE")
var Sql string
Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE"
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
stmt, err = tx.Prepare(Sql)
} else { } else {
// log.Println("SELECT " + title + " FROM " + dbName + " FOR UPDATE") // log.Println("SELECT " + title + " FROM " + dbName + " FOR UPDATE")
stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " FOR UPDATE") stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " FOR UPDATE")


+ 1105
- 0
transaction_chain.go
File diff suppressed because it is too large
View File


Loading…
Cancel
Save