26 Commits

Author SHA1 Message Date
  lijianbin c5dab75bff BuildSelectSql子查询兼容pgsql 1 week ago
  zhenghaorong 3ea2b59dda 调整接口 2 weeks ago
  lijianbin 45bcc98850 达梦--密码特殊字符需转义 2 weeks ago
  lijianbin a9383ba486 修正输出描述 2 weeks ago
  lijianbin fe779bf6ca 兼容高斯和达梦数据库 2 weeks ago
  lijianbin f7b0eb2df6 完善pgsql关键替换问题 1 month ago
  lijianbin 74aa5112a0 完善pgsql关键替换问题 1 month ago
  lijianbin 35782df825 pgsql关键字问题 1 month ago
  lijianbin 3b11f98fdc 修复pgsql中关键字使用问题 1 month ago
  lijianbin 60c84df410 修复pgsql中关键字使用问题 1 month ago
  loshiqi 01f3625f1e 兼容高斯 1 month ago
  loshiqi d16c3ca83c 兼容高斯 1 month ago
  loshiqi 16cff9c0ca 兼容高斯 2 months ago
  loshiqi bd754d1507 兼容高斯 2 months ago
  loshiqi 95a092325e values 2 months ago
  loshiqi 9bd66eefc9 插入和分页查询 2 months ago
  loshiqi 2a9f596ea8 兼容pgsql 2 months ago
  loshiqi ff5d48d51a 占位符 2 months ago
  loshiqi 82954508e0 驱动名称 更改 2 months ago
  loshiqi 2f467a0f92 增加pgsql链接 2 months ago
  loshiqi 14f6e5fc28 增加执行原始方法 3 months ago
  zhenghaorong c13da09c57 增加过滤 11 months ago
  zhenghaorong 7f9e42fed6 修复join无法使用with临时表问题 11 months ago
  zhenghaorong 2d94f24f43 修复join无法使用with临时表问题 11 months ago
  zhenghaorong bc664f29c5 修复join无法使用with临时表问题 11 months ago
  zhenghaorong 6d7835939e 增加with语句 11 months ago
11 changed files with 1509 additions and 180 deletions
Split View
  1. +227
    -22
      chain.go
  2. +144
    -7
      chain_test.go
  3. +130
    -0
      common.go
  4. +159
    -27
      conn.go
  5. +248
    -58
      db.go
  6. +40
    -4
      db_test.go
  7. +16
    -2
      go.mod
  8. +110
    -3
      go.sum
  9. +99
    -17
      prepare.go
  10. +99
    -27
      transaction.go
  11. +237
    -13
      transaction_chain.go

+ 227
- 22
chain.go View File

@ -3,6 +3,7 @@ package dbquery
import (
"database/sql"
"errors"
"github.com/jmoiron/sqlx"
"log"
"strconv"
"strings"
@ -34,6 +35,7 @@ type Query struct {
conn *sql.DB
debug bool
dbtype string
with [][]string //[[临时表的sql语句,临时表的名称]]
}
func NewQuery(t ...string) *Query {
@ -67,6 +69,11 @@ func (this *Query) Conn(conn *sql.DB) *Query {
}
func (this *Query) Db(dbname string) *Query {
this.dbname = dbname
if DB_PROVIDER == "PgsqlDb" {
this.dbname = ""
} else if DB_PROVIDER == "DmSql" {
this.dbname = ""
}
return this
}
@ -103,6 +110,14 @@ func (this *Query) Groupby(groupby string) *Query {
this.groupby = groupby
return this
}
func (this *Query) With(with []string) *Query {
this.with = append(this.with, with)
return this
}
func (this *Query) Withs(withs [][]string) *Query {
this.with = append(this.with, withs...)
return this
}
func (this *Query) Where(where string) *Query {
this.where = append(this.where, where)
return this
@ -198,6 +213,7 @@ func (this *Query) Clean() *Query {
this.upd_field = this.upd_field[0:0]
this.having = ""
this.alias = ""
this.with = this.with[0:0]
return this
}
@ -211,15 +227,55 @@ func (this *Query) GetTableInfo(table string) (map[string]interface{}, error) {
"COLUMN_COMMENT", //备注
"IS_NULLABLE", //是否为空
}
sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ? and table_schema = ?"
sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ?"
if DB_PROVIDER == "PgsqlDb" {
//pgsql中,未加引号的标识符会被自动转换为小写
sql = `SELECT
c.column_name AS 'COLUMN_NAME',
c.column_default AS 'COLUMN_DEFAULT',
c.data_type AS 'DATA_TYPE',
c.udt_name AS 'COLUMN_TYPE',
pgdesc.description AS 'COLUMN_COMMENT',
c.is_nullable AS 'IS_NULLABLE'
FROM
information_schema.columns c
LEFT JOIN
pg_catalog.pg_statio_all_tables st ON st.schemaname = c.table_schema AND st.relname = c.table_name
LEFT JOIN
pg_catalog.pg_description pgdesc ON pgdesc.objoid = st.relid AND pgdesc.objsubid = c.ordinal_position
WHERE
c.table_name = ?`
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
} else if DB_PROVIDER == "DmSql" {
sql = `SELECT
COLUMN_NAME,
DATA_DEFAULT AS COLUMN_DEFAULT,
DATA_TYPE,
CASE
WHEN DATA_TYPE LIKE '%CHAR%' OR DATA_TYPE LIKE '%TEXT%' THEN DATA_TYPE || '(' || CHAR_LENGTH || ')'
WHEN DATA_TYPE IN ('NUMERIC', 'DECIMAL') THEN DATA_TYPE || '(' || DATA_PRECISION || ',' || DATA_SCALE || ')'
ELSE DATA_TYPE
END AS COLUMN_TYPE,
'' AS COLUMN_COMMENT,
CASE NULLABLE WHEN 'Y' THEN 'YES' WHEN 'N' THEN 'NO' END AS IS_NULLABLE
FROM
ALL_TAB_COLUMNS
WHERE
TABLE_NAME = ?`
}
if this.conn == nil {
this.conn = DB
}
stmtSql, err := this.conn.Prepare(sql)
if err != nil {
return nil, err
}
list, err := StmtForQueryList(stmtSql, []interface{}{table, this.dbname})
list, err := StmtForQueryList(stmtSql, []interface{}{table})
if err != nil {
return nil, err
}
@ -236,6 +292,10 @@ func (this *Query) GetTableInfo(table string) (map[string]interface{}, error) {
}
for _, k := range field {
index := helper.StrFirstToUpper(k)
if DB_PROVIDER == "DmSql" {
index = helper.StrFirstToUpper(strings.ToLower(k))
}
if v, ok := item[index]; ok {
switch k {
case "COLUMN_NAME":
@ -278,7 +338,30 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
return nil, errors.New("参数错误,没有数据表")
}
var table = ""
if strings.Contains(this.table, "select ") {
withSql := ""
if len(this.with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range this.with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(this.table, "select ") || strings.HasPrefix(this.table, "(") {
table = this.table
} else {
table = getTableName(this.dbname, this.table, this.dbtype)
@ -296,15 +379,15 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
if this.dbtype == "mssql" {
if this.page_size > 0 {
sql = helper.StringJoin("select top ", helper.ToStr(this.page_size), " ")
sql = helper.StringJoin(withSql, "select top ", helper.ToStr(this.page_size), " ")
} else {
sql = "select "
sql = helper.StringJoin(withSql, "select ")
}
} else {
if DB_PROVIDER == "TencentDB" {
sql = "/*slave*/ select "
sql = helper.StringJoin("/*slave*/ ", withSql, " select ")
} else {
sql = "select "
sql = helper.StringJoin(withSql, "select ")
}
}
@ -317,17 +400,31 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
sql = helper.StringJoin(sql, " from ", table)
if len(this.join) > 0 {
join_type := "left"
var builder strings.Builder
builder.WriteString(sql)
boo := false
for _, joinitem := range this.join {
if len(joinitem) < 2 {
continue
}
if len(joinitem) == 3 {
join_type = joinitem[2]
} else { //默认左连接
join_type = "left"
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(this.dbname, joinitem[0]))
}
sql = helper.StringJoin(sql, " ", join_type, " join ", getTableName(this.dbname, joinitem[0], this.dbtype), " on ", joinitem[1])
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql = builder.String()
}
}
if len(this.where) > 0 || len(this.where_or) > 0 {
@ -366,12 +463,18 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
from := strconv.Itoa((this.page - 1) * this.page_size)
offset := strconv.Itoa(this.page_size)
if from != "" && offset != "" {
sql = helper.StringJoin(sql, " limit ", from, " , ", offset)
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " limit ", offset, " OFFSET ", from)
} else {
sql = helper.StringJoin(sql, " limit ", from, " , ", offset)
}
}
}
if this.debug {
log.Println("query sql:", sql, this.value)
}
condition_len := 0 //所有条件数
for _, ch2 := range sql {
if string(ch2) == "?" {
@ -381,6 +484,7 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
if condition_len != len(this.value) {
return nil, errors.New("参数错误,条件值错误")
}
return map[string]interface{}{
"sql": sql,
"value": this.value,
@ -452,7 +556,12 @@ func (this *Query) UpdateStmt() error {
if this.conn == nil {
this.conn = DB
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql)
if err != nil {
@ -477,8 +586,11 @@ func (this *Query) UpdateAllStmt() error {
var dataSql []string //一组用到的占位字符
var valSql []string //占位字符组
var updSql []string //更新字段的sql
var updSql_dm []string //更新字段的sql--达梦和高斯用
var updFieldLen = len(this.upd_field) //需要更新的字段数量,为0时更新除id外添加值
dataLen := len(this.save_data)
if dataLen > 0 {
//批量操作
this.data = this.data[0:0]
@ -496,11 +608,14 @@ func (this *Query) UpdateAllStmt() error {
case 0:
//预览创建数据的长度
updSql = make([]string, 0, fieldLen)
updSql_dm = make([]string, 0, fieldLen)
default:
//按照需要更新字段数长度
updSql = make([]string, 0, updFieldLen)
updSql_dm = make([]string, 0, updFieldLen)
for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段
updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段
}
}
for k := range this.save_data[i] {
@ -508,6 +623,7 @@ func (this *Query) UpdateAllStmt() error {
dataSql = append(dataSql, "?") //存储需要的占位符
if updFieldLen == 0 && k != "id" {
updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段
updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段
}
}
dataSqlText = strings.Join(dataSql, ",") //组成每组占位字符格式
@ -525,21 +641,26 @@ func (this *Query) UpdateAllStmt() error {
switch updFieldLen {
case 0:
updSql = make([]string, 0, fieldLen)
updSql_dm = make([]string, 0, fieldLen)
default:
updSql = make([]string, 0, updFieldLen)
updSql_dm = make([]string, 0, updFieldLen)
for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")")
updSql_dm = append(updSql_dm, k+"=s."+k+"")
}
}
for i := 0; i < fieldLen; i++ {
dataSql = append(dataSql, "?")
if updFieldLen == 0 && this.data[i] != "id" {
updSql = append(updSql, this.data[i]+"=values("+this.data[i]+")")
updSql_dm = append(updSql_dm, this.data[i]+"=s."+this.data[i]+"")
}
}
if updFieldLen > 0 {
for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")")
updSql_dm = append(updSql_dm, k+"=s."+k+"")
}
}
valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")")
@ -556,8 +677,44 @@ func (this *Query) UpdateAllStmt() error {
if len(valSql) > 1 {
setText = " value "
}
sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , "))
if DB_PROVIDER == "PgsqlDb" {
setText = " values "
val_field := addPrefixInField(this.data, "s.")
sql = `merge into ` + dbName + ` as t
using (
` + setText + strings.Join(valSql, ",") + `
) s (` + strings.Join(this.data, " , ") + `)
on (t.id = s.id)
when matched then
update set
` + strings.Join(updSql_dm, " , ") + `
when NOT matched then
insert (` + strings.Join(this.data, " , ") + `)
values (` + strings.Join(val_field, " , ") + `)`
} else if DB_PROVIDER == "DmSql" {
setText = " values "
val_field := addPrefixInField(this.data, "s.")
title_field := addPrefixInField(this.data, "? AS ")
sql = `MERGE INTO ` + dbName + ` AS t
USING (
SELECT
` + strings.Join(title_field, " , ") + `
FROM DUAL
) s
ON (t.id = s.id)
WHEN MATCHED THEN
UPDATE SET
` + strings.Join(updSql_dm, " , ") + `
WHEN NOT MATCHED THEN
INSERT (` + strings.Join(this.data, " , ") + `)
VALUES (` + strings.Join(val_field, " , ") + `)`
}
if this.debug {
log.Println("insert on duplicate key update sql:", sql, this.value)
}
@ -574,7 +731,12 @@ func (this *Query) UpdateAllStmt() error {
if this.conn == nil {
this.conn = DB
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql)
if err != nil {
@ -645,6 +807,11 @@ func (this *Query) CreateAllStmt() error {
if len(valSql) > 1 {
setText = " value "
}
if DB_PROVIDER == "PgsqlDb" {
setText = " values "
} else if DB_PROVIDER == "DmSql" {
setText = " values "
}
sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","))
if this.debug {
@ -663,7 +830,12 @@ func (this *Query) CreateAllStmt() error {
if this.conn == nil {
this.conn = DB
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "add")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql)
if err != nil {
@ -685,8 +857,31 @@ func (this *Query) CreateStmt() error {
dbName := getTableName(this.dbname, this.table, this.dbtype)
var sql string
if DB_PROVIDER == "PgsqlDb" || DB_PROVIDER == "DmSql" {
insert_data := []string{}
value_data := []string{}
for _, rv := range this.data {
dv := strings.Split(rv, "=")
if len(dv) < 2 {
return errors.New("参数错误,条件值错误,=号不存在")
}
if strings.Contains(rv, "?") {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, "?")
} else {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, dv[1])
}
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , "))
}
sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " RETURNING id")
}
} else {
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , "))
}
if this.debug {
log.Println("insert sql:", sql, this.value)
@ -705,7 +900,12 @@ func (this *Query) CreateStmt() error {
if this.conn == nil {
this.conn = DB
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql)
if err != nil {
@ -754,7 +954,12 @@ func (this *Query) DeleteStmt() error {
if this.conn == nil {
this.conn = DB
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql)
if err != nil {
@ -772,7 +977,7 @@ func (this *Query) DeleteStmt() error {
*/
func (this *Query) Select() ([]map[string]string, error) {
_, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.join,
_, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.with, this.join,
this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.page, this.page_size, this.debug)
return rows, err
@ -803,7 +1008,7 @@ func (this *Query) List() ([]map[string]string, error) {
*/
func (this *Query) Find() (map[string]string, error) {
_, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.join,
_, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.with, this.join,
this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.debug)
return row, err


+ 144
- 7
chain_test.go View File

@ -1,17 +1,154 @@
package dbquery
import (
"fmt"
"testing"
)
// 测试各数据库下各种情况
func Test_Chain(t *testing.T) {
Connect("127.0.0.1", "root", "123456", "shop", "3306")
//测试数据库连接
//err := Connect("127.0.0.1", "root", "root", "canyin", "3306")
err := PgConnect("192.168.233.151", "bin", "Bin123456", "canyin", "5432")
//err := DmConnect("192.168.233.148", "SHOPV2", "Bin123456", "", "5236")
//err := DmConnect("10.33.0.91", "ZYSG", "Zysg!#2025", "", "5236")
//err := DmConnect("10.33.0.91", "dbck", "dskk!#555", "", "5236")
ret, err := new(Query).Db("shop").Table("ttl_order_product").Alias("op").
Join([]string{"ttl_product as p", "op.product_id=p.id"}).
Title("op.id,op.order_price,p.thumb").WhereOr("op.id =?").WhereOr("op.id = ?").Value(63).Value(64).Debug(true).List()
if err != nil {
t.Log(err)
}
db_name := ""
table_name := "ttl_user_log"
//time := time.Now().Unix()
t.Log(len(ret))
t.Log(ret)
t.Log(err)
//================查询表结构===========
ret, err := new(Query).Db(db_name).GetTableInfo(table_name)
if err != nil {
t.Log(err)
}
fmt.Println("===GetTableInfo:", ret)
//==========获取信息=================
query := new(Query).Db(db_name).Clean().Table("ttl_dorm_goods_reserve").Alias("a").
Join([]string{"ttl_dorm_goods_reserve_detail b", "a.id=b.reserve_id", "left"}).
Join([]string{"ttl_dorm_goods c", "c.id=b.goods_id", "left"}).
Join([]string{"ttl_dorm_room d", "d.id=a.room_id", "left"}).
Join([]string{"dorm_room_item e", "e.id=a.room_item_id", "left"}).
Where("a.user_id =?").Value(6006)
info, err := query.Groupby("a.id").Title("a.id").BuildSelectSql()
//info, err := new(Query).Db(db_name).Clean().Table(table_name).Where("id=?").Value("3").Title("*").Find()
//info, err := new(Query).Db(db_name).Clean().Table(table_name).Clean().Alias("user").Join([]string{"ttl_user u", "u.id = user.user_id", "inner"}).Where("user.id=?").Value("3").Title("user.id,user.user_id,u.nickname").Find()
//info, err := GetDataByStmt(db_name, table_name, "*", []string{"id = ?"}, []interface{}{3}, nil)
if err != nil {
t.Log(err)
}
fmt.Println("===Find:", info)
//============获取列表==================
list, err := new(Query).Db(db_name).Clean().Table(table_name).Alias("u").Join([]string{"ttl_user user", "user.id = u.user_id", "inner"}).Where("user.id=?").Value(1).Title("user.id,user.nickname,user.status").Orderby("user.id desc").Page(1).PageSize(10).Select()
//list, err := DoQuery("select user.nickname,log.* from ttl_user user left join ttl_user_log log on user.id = log.user_id where user.id = ?", "1")
//list, err := GetListByStmt(db_name, table_name, "*", []string{"id = ?"}, []interface{}{3}, map[string]string{"id": "id asc"})
//list, err := QueryByStmt("select * from "+table_name+" where id < ?", []interface{}{10})
if err != nil {
t.Log(err)
}
fmt.Println("===List:", list)
//===========添加数据============
//insert_res, err := new(Query).Db(db_name).Clean().Table(table_name).Data("user_id=?").Value("1").Data("createtime=?").Value(time).Create()
//insert_res, err := InsertByStmt(db_name, table_name, []string{"user_id=?", "createtime=?"}, []interface{}{"1", time})
//insert_res, err := Insert(db_name, table_name, map[string]string{"user_id": "1", "createtime": helper.ToStr(time)})
//if err != nil {
// t.Log(err)
//}
//fmt.Println("===Insert:", insert_res)
//================更新数据=====================
//update_res, err := new(Query).Db(db_name).Clean().Table(table_name).Data("user_id=?").Value("2").Data("createtime=?").Value(time).Where("id=?").Value("6").Update()
//update_res, err := UpdateByStmt(db_name, table_name, []string{"createtime=?", "user_id=?"}, []string{"id=?"}, []interface{}{time, 3, 6})
//if err != nil {
// t.Log(err)
//}
//fmt.Println("===Update:", update_res)
//=============事务================
/*fmt.Println("================开启事务============")
tx, err := DB.Begin()
if err != nil {
t.Log(err)
}
update_log, err := TxPreUpdate(tx, db_name, table_name, []string{"createtime= ?"}, []string{"id=?"}, []interface{}{time, 2})
if err != nil {
tx.Rollback()
t.Log(err)
}
fmt.Println("===========事务执行:==================")
fmt.Println("===事务update:", update_log)
insert_log, err := TxPreInsert(tx, db_name, table_name, map[string]interface{}{"user_id": "1", "createtime": helper.ToStr(time)})
if err != nil {
tx.Rollback()
t.Log(err)
}
fmt.Println("===事务insert:", insert_log)
del_log, err := TxDelete(tx, db_name, table_name, map[string]string{"id": "2"})
if err != nil {
tx.Rollback()
t.Log(err)
}
fmt.Println("====事务delete:", del_log)
err = tx.Commit()
if err != nil {
t.Log(err)
tx.Rollback()
}
fmt.Println("=======事务执行完成==========")*/
/*fmt.Println("====================执行事务trans============")
trans := NewTxQuery().Db(db_name)
info_trans, err := trans.Clean().Table(table_name).Where("id = ?").Value(5).Find()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("===事务Find_trans:", info_trans)
list_trans, err := trans.Clean().Table(table_name).Title("*").Select()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("=========事务List_trans:", list_trans)
data := map[string]interface{}{
"user_id": 5,
"memo": "test",
"createtime": time,
}
add_trans, err := trans.Clean().Table(table_name).SaveData(data).CreateAll()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("======事务Add_trans:", add_trans)
data["id"] = 15
update_trans_res, err := trans.Clean().Table(table_name).SaveData(data).UpdateAll()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("=======事务update_trans", update_trans_res)
err = trans.Commit()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("====================执行事务结束==================")*/
}

+ 130
- 0
common.go View File

@ -0,0 +1,130 @@
package dbquery
import (
"fmt"
"git.tetele.net/tgo/helper"
"log"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
// ===================达梦兼容===============
// 非关键字可以不添加标识符,关键字须添加
// 日期函数的使用TO_CHAR(TO_DATE('1970-01-01','yyyy-mm-dd') + (createtime / 86400), 'yyyy-mm-dd')
// group_concat替换成LISTAGG
// ========================================
// 关键字替换-支持达梦和高斯
func ReplaeByOtherSql(sql, sql_type, action string) string {
sql_type_arr := []string{"DmSql", "PgsqlDb"}
if !helper.IsInStringArray(sql_type_arr, sql_type) {
log.Println("sql_type error", sql_type)
return ""
}
// PgsqlDb用
if action == "add" {
sql = helper.StringJoin(sql, " RETURNING id")
}
// 定义需要处理的关键字列表
keywords := []string{"user", "order", "group", "table", "view", "admin", "new"}
if sql_type == "PgsqlDb" {
keywords = []string{"user", "order", "group"}
}
// 移除所有反引号
sql = strings.ReplaceAll(sql, "`", "")
// 使用单词边界 \b 确保只匹配完整单词
pattern := `\b(` + strings.Join(keywords, "|") + `)\b`
re := regexp.MustCompile(pattern)
//设置保护词组
excludePhrases := []string{
"order by", "group by",
}
//保护排除短语
phraseMap := make(map[string]string)
for i, phrase := range excludePhrases {
placeholder := fmt.Sprintf("__EXCLUDE_%d__", i)
phraseMap[placeholder] = phrase
sql = strings.Replace(sql, phrase, placeholder, -1)
}
// 执行替换
sql = re.ReplaceAllStringFunc(sql, func(match string) string {
// 检查匹配是否在字符串常量中
if isInStringLiteral(sql, match) {
return match
}
if sql_type == "DmSql" {
return "`" + match + "`"
} else {
return `"` + match + `"`
}
})
// 恢复排除短语
for placeholder, phrase := range phraseMap {
sql = strings.Replace(sql, placeholder, phrase, -1)
}
return sql
}
// 检查匹配是否在字符串常量中
func isInStringLiteral(sql, match string) bool {
index := strings.Index(sql, match)
if index == -1 {
return false
}
// 检查匹配前的单引号数量
prefix := sql[:index]
singleQuotes := strings.Count(prefix, "'") - strings.Count(prefix, "\\'")
// 奇数表示在字符串常量中
return singleQuotes%2 != 0
}
// 字段值类型转换--针对达梦用
func ToString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
case int, int8, int16, int32, int64:
return strconv.FormatInt(reflect.ValueOf(value).Int(), 10)
case uint, uint8, uint16, uint32, uint64:
return strconv.FormatUint(reflect.ValueOf(value).Uint(), 10)
case float32, float64:
return strconv.FormatFloat(reflect.ValueOf(value).Float(), 'f', -1, 64)
case bool:
return strconv.FormatBool(v)
case time.Time:
return v.Format("2006-01-02 15:04:05")
default:
return fmt.Sprintf("%v", v)
}
}
// 字符串切片元素追加前缀
func addPrefixInField(slice []string, prefix string) []string {
new_slice := make([]string, len(slice))
for i, v := range slice {
new_slice[i] = prefix + v
}
return new_slice
}
func DmFieldDeal(fields string) string {
//移除所有反引号
title := strings.Replace(fields, "`", "", -1)
return title
}

+ 159
- 27
conn.go View File

@ -2,15 +2,20 @@ package dbquery
import (
"database/sql"
"fmt"
"log"
"net/url"
"errors"
"strings"
"time"
_ "dm"
"git.tetele.net/tgo/helper"
_ "gitee.com/opengauss/openGauss-connector-go-pq" // 高斯驱动(推荐)或 "github.com/lib/pq"
_ "github.com/go-sql-driver/mysql"
//_ "github.com/lib/pq" // 关键驱动导入
)
var DB *sql.DB
@ -21,39 +26,53 @@ var SLAVER_DB *sql.DB
var DB_PROVIDER string
func Connect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
log.Println("mysql database connectting...")
log.Println("database connectting...")
var dbConnErr error
if DBHOST != "" && DBUSER != "" && DBPWD != "" && DBPORT != "" { //&& DBNAME != ""
const maxRetries = 10
for i := 0; i < 10; i++ {
DB, dbConnErr = sql.Open("mysql", DBUSER+":"+DBPWD+"@tcp("+DBHOST+":"+DBPORT+")/"+DBNAME+"?charset=utf8mb4")
if dbConnErr != nil {
log.Println("ERROR", "can not connect to Database, ", dbConnErr)
time.Sleep(time.Second * 5)
} else {
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" {
return errors.New("mysql DBconnection params errors")
}
DB.Ping()
dsn := DBUSER + ":" + DBPWD + "@tcp(" + DBHOST + ":" + DBPORT + ")/" + DBNAME + "?charset=utf8mb4"
log.Println("database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
break
}
for i := 0; i < maxRetries; i++ {
//Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题)
// 每次尝试创建新连接对象
DB, dbConnErr = sql.Open("mysql", dsn)
if dbConnErr != nil {
log.Println("sql open failed:", i+1, dbConnErr)
time.Sleep(time.Second * 5)
continue
}
} else {
return errors.New("db connection params errors")
// 验证连接有效性
if pingErr := DB.Ping(); pingErr != nil {
DB.Close()
// 记录真实错误信息
dbConnErr = pingErr
log.Println("ping failed:", i+1, pingErr)
time.Sleep(time.Second * 5)
continue
}
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
log.Println("mysql database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
return nil
}
return dbConnErr
return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr)
}
func CloseConn() error {
@ -106,7 +125,11 @@ func CloseSlaverConn() error {
func getTableName(dbName, table string, dbtype ...string) string {
var db_type string = "mysql"
if DB_PROVIDER == "PgsqlDb" {
dbName = ""
} else if DB_PROVIDER == "DmSql" {
dbName = ""
}
if len(dbtype) > 0 {
if dbtype[0] != "" {
db_type = dbtype[0]
@ -151,3 +174,112 @@ func GetDbTableName(dbName, table string) string {
func judg() []string {
return []string{"=", ">", "<", "!=", "<=", ">="}
}
// pgsql连接
func PgConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
log.Println("pg database connectting...")
var dbConnErr error
const maxRetries = 10
DB_PROVIDER = "PgsqlDb"
if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" {
return errors.New("pgsql DBconnection params errors")
}
dsn := "host=" + DBHOST + " port=" + DBPORT + " user=" + DBUSER + " password=" + DBPWD + " dbname=" + DBNAME + " sslmode=disable search_path=public"
log.Println("database dsn", dsn)
for i := 0; i < maxRetries; i++ {
//Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题)
// 每次尝试创建新连接对象
DB, dbConnErr = sql.Open("opengauss", dsn)
if dbConnErr != nil {
log.Println("sql open failed:", i+1, dbConnErr)
time.Sleep(time.Second * 5)
continue
}
// 验证连接有效性
if pingErr := DB.Ping(); pingErr != nil {
DB.Close()
// 记录真实错误信息
dbConnErr = pingErr
log.Println("ping failed:", i+1, pingErr)
time.Sleep(time.Second * 5)
continue
}
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
log.Println("pgsql database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
return nil
}
return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr)
}
// 达梦8连接
func DmConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
log.Println("DM database connectting...")
var dbConnErr error
const maxRetries = 10
DB_PROVIDER = "DmSql"
if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" {
return errors.New("dm DBconnection params errors")
}
//达梦8如果密码存在特殊字符,需使用 url.PathEscape 进行转义后再放入连接串
DBPWD = url.PathEscape(DBPWD)
dsn := "dm://" + DBUSER + ":" + DBPWD + "@" + DBHOST + ":" + DBPORT + "?charSet=utf8&compatibleMode=mysql"
log.Println("database dsn", dsn)
for i := 0; i < maxRetries; i++ {
//Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题)
// 每次尝试创建新连接对象
DB, dbConnErr = sql.Open("dm", dsn)
if dbConnErr != nil {
log.Println("sql open failed:", i+1, dbConnErr)
time.Sleep(time.Second * 5)
continue
}
// 验证连接有效性
if pingErr := DB.Ping(); pingErr != nil {
DB.Close()
// 记录真实错误信息
dbConnErr = pingErr
log.Println("ping failed:", i+1, pingErr)
time.Sleep(time.Second * 5)
continue
}
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
log.Println("dm database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
return nil
}
return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr)
}

+ 248
- 58
db.go View File

@ -2,6 +2,7 @@ package dbquery
import (
"database/sql"
"github.com/jmoiron/sqlx"
"log"
"strconv"
@ -44,16 +45,33 @@ func Insert(dbName, table string, data map[string]string) (int64, error) {
valueList[i] = value
i++
}
result, err := DB.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...)
if err != nil {
log.Println("ERROR|插入", dbName, "数据失败,", err)
return insertId, err
var Sql string
Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") values (" + strings.Join(keyStr, ",") + ")"
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add")
stmt, err = DB.Prepare(Sql)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
row := stmt.QueryRow(valueList...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
result, err := DB.Exec(Sql, valueList...)
if err != nil {
log.Println("ERROR|插入", dbName, "数据失败,", err)
return insertId, err
} else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
}
}
}
@ -67,7 +85,6 @@ func Update(dbName, table string, data map[string]string, where map[string]strin
if dbName == "" && table == "" {
return rowsAffected, errors.New("没有数据表")
}
if strings.Contains(table, "select ") {
dbName = table
} else {
@ -109,7 +126,15 @@ func Update(dbName, table string, data map[string]string, where map[string]strin
log.Println("ERROR|修改数据表", dbName, "时条件中有空数据,条件:", where, "数据:", data)
return rowsAffected, errors.New("条件中有空数据")
}
result, err := DB.Exec("update "+dbName+" set "+strings.Join(keyList, " , ")+" where "+strings.Join(whereStr, " and "), valueList...)
var Sql string
Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ")
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := DB.Exec(Sql, valueList...)
if err != nil {
log.Println("ERROR|修改", dbName, "数据失败,", err)
@ -131,7 +156,6 @@ func Delete(dbName, table string, data map[string]string, del_count ...string) (
if dbName == "" && table == "" {
return count, errors.New("没有数据表")
}
if strings.Contains(table, "select ") {
dbName = table
} else {
@ -167,7 +191,15 @@ func Delete(dbName, table string, data map[string]string, del_count ...string) (
limitStr = " limit " + del_count[0]
}
result, err := DB.Exec("delete from "+dbName+" where "+strings.Join(keyList, " and ")+limitStr, valueList...)
var Sql string
Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := DB.Exec(Sql, valueList...)
if err != nil {
log.Println("ERROR|删除", dbName, "数据失败,", err)
@ -191,7 +223,6 @@ func GetData(dbName, table string, title string, where map[string]string, limit
if dbName == "" && table == "" {
return count, info, errors.New("没有数据表")
}
dbName = getTableName(dbName, table)
if len(title) < 1 {
@ -210,7 +241,11 @@ func GetData(dbName, table string, title string, where map[string]string, limit
if _, ok := limit["from"]; ok {
from = limit["from"]
}
limitStr += " limit " + from + ",1"
if DB_PROVIDER == "PgsqlDb" {
limitStr += " limit 1 OFFSET " + from
} else {
limitStr += " limit " + from + ",1"
}
} else {
limitStr = " limit 1"
@ -241,8 +276,15 @@ func GetData(dbName, table string, title string, where map[string]string, limit
var err error
var queryNum int = 0
for queryNum < 3 { //如发生错误,继续查询3次,防止数据库连接断开问题
rows, err = DB.Query("SELECT "+title+" FROM "+dbName+" where "+strings.Join(keyList, " and ")+" "+limitStr, valueList...)
var Sql string
Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(keyList, " and ") + " " + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
rows, err = DB.Query(Sql, valueList...)
if err == nil {
break
@ -274,8 +316,15 @@ func GetData(dbName, table string, title string, where map[string]string, limit
if rowerr == nil {
for i, col := range values {
if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
}
}
count++
@ -295,16 +344,38 @@ func GetData(dbName, table string, title string, where map[string]string, limit
* @param dbName 数据表名
* @param title 查询字段名
*/
func GetRow(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) {
func GetRow(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) {
var count int = 0
info := make(map[string]string)
if dbName == "" && table_name == "" {
return count, info, errors.New("没有数据表")
}
table := ""
if strings.Contains(table_name, "select ") {
withSql := ""
if len(with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") {
table = table_name
} else {
table = getTableName(dbName, table_name)
@ -316,10 +387,11 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
} else {
title = "*"
}
if DB_PROVIDER == "TencentDB" {
sql_str = helper.StringJoin("/*slave*/ select ", title)
sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title)
} else {
sql_str = helper.StringJoin("select ", title)
sql_str = helper.StringJoin(withSql, "select ", title)
}
if alias != "" {
table = helper.StringJoin(table, " as ", alias)
@ -328,17 +400,31 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
sql_str = helper.StringJoin(sql_str, " from ", table)
if len(join) > 0 {
var builder strings.Builder
builder.WriteString(sql_str)
boo := false
for _, joinitem := range join {
if len(joinitem) < 2 {
continue
}
if len(joinitem) == 4 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", joinitem[0], " on ", joinitem[1])
} else if len(joinitem) == 3 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql_str = helper.StringJoin(sql_str, " left join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) > 4 {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(dbName, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql_str = builder.String()
}
}
if len(where) > 0 || len(where_or) > 0 {
@ -363,10 +449,10 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
if orderby != "" {
sql_str = helper.StringJoin(sql_str, " order by ", orderby)
}
if debug {
log.Println("query sql:", sql_str, valueList)
}
condition_len := 0 //所有条件数
for _, ch2 := range sql_str {
if string(ch2) == "?" {
@ -392,6 +478,12 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
}
for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题
if DB_PROVIDER == "PgsqlDb" {
sql_str = sqlx.Rebind(sqlx.DOLLAR, sql_str)
sql_str = ReplaeByOtherSql(sql_str, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql_str = ReplaeByOtherSql(sql_str, "DmSql", "")
}
rows, err = db.Query(sql_str, valueList...)
@ -426,8 +518,15 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
if rowerr == nil {
for i, col := range values {
if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
}
}
count++
@ -449,7 +548,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
* @param dbName 数据表名
* @param title 查询字段名
*/
func FetchRows(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) {
func FetchRows(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) {
var count int = 0
list := make([]map[string]string, 0)
@ -457,7 +556,30 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
return count, list, errors.New("没有数据表")
}
table := ""
if strings.Contains(table_name, "select ") {
withSql := ""
if len(with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") {
table = table_name
} else {
table = getTableName(dbName, table_name)
@ -470,10 +592,11 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
} else {
title = "*"
}
if DB_PROVIDER == "TencentDB" {
sql_str = helper.StringJoin("/*slave*/ select ", title)
sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title)
} else {
sql_str = helper.StringJoin("select ", title)
sql_str = helper.StringJoin(withSql, "select ", title)
}
if alias != "" {
table = helper.StringJoin(table, " as ", alias)
@ -482,17 +605,31 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
sql_str = helper.StringJoin(sql_str, " from ", table)
if len(join) > 0 {
var builder strings.Builder
builder.WriteString(sql_str)
boo := false
for _, joinitem := range join {
if len(joinitem) < 2 {
continue
}
if len(joinitem) == 4 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", joinitem[0], " on ", joinitem[1])
} else if len(joinitem) == 3 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql_str = helper.StringJoin(sql_str, " left join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) >= 4 {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(dbName, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql_str = builder.String()
}
}
@ -530,7 +667,11 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
from := strconv.Itoa((page - 1) * page_size)
offset := strconv.Itoa(page_size)
if from != "" && offset != "" {
sql_str = helper.StringJoin(sql_str, " limit ", from, " , ", offset)
if DB_PROVIDER == "PgsqlDb" {
sql_str = helper.StringJoin(sql_str, " limit ", offset, " OFFSET ", from)
} else {
sql_str = helper.StringJoin(sql_str, " limit ", from, " , ", offset)
}
}
}
if debug {
@ -556,7 +697,12 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
var err error
var queryNum int = 0
for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题
if DB_PROVIDER == "PgsqlDb" {
sql_str = sqlx.Rebind(sqlx.DOLLAR, sql_str)
sql_str = ReplaeByOtherSql(sql_str, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql_str = ReplaeByOtherSql(sql_str, "DmSql", "")
}
rows, err = db.Query(sql_str, valueList...)
if err == nil {
@ -592,8 +738,15 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
if rowerr == nil {
for i, col := range values {
if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
}
}
count++
@ -634,7 +787,6 @@ func GetList(dbName, table string, title string, where map[string]string, limit
if dbName == "" && table == "" {
return list, errors.New("没有数据表")
}
if strings.Contains(table, "select ") {
dbName = table
} else {
@ -662,7 +814,12 @@ func GetList(dbName, table string, title string, where map[string]string, limit
from = limit["from"]
}
if offset != "0" && from != "" {
limitStr += " limit " + from + "," + offset
if DB_PROVIDER == "PgsqlDb" {
limitStr += " limit " + offset + " OFFSET " + from
} else {
limitStr += " limit " + from + "," + offset
}
}
}
@ -696,8 +853,15 @@ func GetList(dbName, table string, title string, where map[string]string, limit
}
for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题
rows, err = DB.Query("select "+title+" from "+dbName+" where "+strings.Join(whereStr, " and ")+" "+limitStr, valueList...)
var Sql string
Sql = "select " + title + " from " + dbName + " where " + strings.Join(whereStr, " and ") + " " + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
rows, err = DB.Query(Sql, valueList...)
if err == nil {
break
@ -742,8 +906,15 @@ func GetList(dbName, table string, title string, where map[string]string, limit
for i, col := range values {
if col != nil {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
record[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
}
}
}
list = append(list, record)
@ -761,7 +932,6 @@ func GetTotal(dbName, table string, args ...string) (total int) {
if dbName == "" && table == "" {
return
}
if strings.Contains(table, "select ") {
dbName = table
} else {
@ -779,7 +949,6 @@ func GetTotal(dbName, table string, args ...string) (total int) {
var queryNum int = 0
for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题
rows, err = DB.Query("select count(" + title + ") as count from " + dbName + " limit 1")
if err == nil {
@ -863,7 +1032,15 @@ func GetCount(dbName, table string, where map[string]string, args ...string) (to
for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题
rows, err = DB.Query("select count("+title+") as count from "+dbName+" where "+strings.Join(whereStr, " and ")+" limit 1", valueList...)
var Sql string
Sql = "select count(" + title + ") as count from " + dbName + " where " + strings.Join(whereStr, " and ") + " limit 1"
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
rows, err = DB.Query(Sql, valueList...)
if err == nil {
break
@ -922,6 +1099,12 @@ func DoQuery(args ...interface{}) ([]map[string]string, error) {
for queryNum < 3 { //如发生错误,继续查询5次,防止数据库连接断开问题
if len(args) > 1 {
if DB_PROVIDER == "PgsqlDb" {
queryStr = sqlx.Rebind(sqlx.DOLLAR, queryStr)
queryStr = ReplaeByOtherSql(queryStr, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
queryStr = ReplaeByOtherSql(queryStr, "DmSql", "")
}
rows, err = DB.Query(queryStr, args[1:]...) //strings.Join(args[1:], ",")
if err != nil {
log.Println("ERROR|DoQuery error:", err)
@ -962,8 +1145,15 @@ func DoQuery(args ...interface{}) ([]map[string]string, error) {
for i, col := range values {
if col != nil {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
record[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
}
}
}
list = append(list, record)


+ 40
- 4
db_test.go View File

@ -6,7 +6,6 @@ import (
)
func Test_Connet(t *testing.T) {
//go func() {
for i := 0; i < 1; i++ {
dbhost := "localhost"
dbname := "shop"
@ -36,8 +35,8 @@ func Test_Connet(t *testing.T) {
orderby := "id desc"
debug := true
// count, row, err := GetRow(dbname, table, alias, title, join, where, where_or, valueList, orderby, debug)
count, row, err := FetchRows(dbname, table, alias, title, join, where, where_or, valueList, orderby, "", 1, 10, debug)
//count, row, err := GetRow(dbname, table, alias, title, join, where, where_or, valueList, orderby, debug)
count, row, err := FetchRows(dbname, table, alias, title, join, join, where, where_or, valueList, orderby, "", "", 1, 10, debug)
log.Println(count)
log.Println(row)
@ -46,6 +45,43 @@ func Test_Connet(t *testing.T) {
log.Println(err.Error())
}
}
//}()
}
func Test_Query(t *testing.T) {
token := "67c121aa-6e1c-011f-ebb6-976d855fd777"
dbhost := "192.168.233.134"
dbname := "canyin"
dbusername := "bin"
dbpassword := "Bin123456"
dbport := "5432"
table := "ttl_user_token"
err := PgConnect(dbhost, dbusername, dbpassword, dbname, dbport)
if err != nil {
log.Println(err.Error())
}
title := "user.*,ut.expiretime"
alias := "ut"
join := [][]string{}
join = append(join, []string{"ttl_user as user", "ut.user_id= user.id"})
where := []string{
"ut.token=?",
}
where_or := []string{}
valueList := []interface{}{
token,
}
orderby := ""
debug := true
count, row, err := GetRow(dbname, table, alias, title, [][]string{}, join, where, where_or, valueList, orderby, "", "", debug)
//count, row, err := FetchRows(dbname, table, alias, title, join, join, where, where_or, valueList, orderby, "", "", 1, 10, debug)
log.Println(count)
log.Println(row)
log.Println(err)
if err != nil {
log.Println(err.Error())
}
}

+ 16
- 2
go.mod View File

@ -1,9 +1,23 @@
module git.tetele.net/tgo/dbquery
go 1.14
go 1.23.0
toolchain go1.24.0
require (
git.tetele.net/tgo/helper v0.1.0
gitee.com/opengauss/openGauss-connector-go-pq v1.0.7
github.com/denisenkom/go-mssqldb v0.11.0
github.com/go-sql-driver/mysql v1.5.0
github.com/go-sql-driver/mysql v1.8.1
github.com/jmoiron/sqlx v1.4.0
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect
golang.org/x/crypto v0.40.0 // indirect
golang.org/x/text v0.28.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
)

+ 110
- 3
go.sum View File

@ -1,11 +1,118 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
git.tetele.net/tgo/helper v0.1.0 h1:ZdsBXUWX3+22ZzHTZRldBfBsQwu+CwUH8qScUvpgimE=
git.tetele.net/tgo/helper v0.1.0/go.mod h1:shYQE/hvMy3fOE8JXKGxvywOXiz3M5Nw4e+u7HR8+NY=
gitee.com/opengauss/openGauss-connector-go-pq v1.0.7 h1:plLidoldV5RfMU6i/I+tvRKtP3sfDyUzQ//HGXLLsZo=
gitee.com/opengauss/openGauss-connector-go-pq v1.0.7/go.mod h1:2UEp+ug6ls6C0pLfZgBn7VBzBntFUzxJuy+6FlQ7qyI=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisenkom/go-mssqldb v0.11.0 h1:9rHa233rhdOyrz2GcP9NM+gi2psgJZ4GWDpL/7ND8HI=
github.com/denisenkom/go-mssqldb v0.11.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8=
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA=
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs=
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w=
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o=
github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho=
github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak=
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM=
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE=
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

+ 99
- 17
prepare.go View File

@ -3,6 +3,7 @@ package dbquery
import (
"database/sql"
"errors"
"github.com/jmoiron/sqlx"
"log"
"strings"
@ -41,8 +42,11 @@ func StmtForRead(dbName, table string, title string, where []string, limit map[s
offset = limit["offset"]
}
if from != "" && offset != "" {
limitStr += " limit " + from + "," + offset
if DB_PROVIDER == "PgsqlDb" {
limitStr += " limit " + offset + " OFFSET " + from
} else {
limitStr += " limit " + from + "," + offset
}
}
}
@ -52,7 +56,15 @@ func StmtForRead(dbName, table string, title string, where []string, limit map[s
if len(where) > 0 {
// log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr)
stmt, err = DB.Prepare("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr)
var Sql string
Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
stmt, err = DB.Prepare(Sql)
} else {
// log.Println("SELECT " + title + " FROM " + dbName + limitStr)
stmt, err = DB.Prepare("SELECT " + title + " FROM " + dbName + limitStr)
@ -75,8 +87,6 @@ func StmtForQueryList(stmt *sql.Stmt, valuelist []interface{}) ([]map[string]str
return nil, errors.New("缺少必要参数")
}
// log.Println(valuelist...)
rows, err := stmt.Query(valuelist...)
defer stmt.Close()
if err != nil {
@ -106,8 +116,15 @@ func StmtForQueryList(stmt *sql.Stmt, valuelist []interface{}) ([]map[string]str
if rowerr == nil {
for i, col := range values {
if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
}
}
} else {
@ -154,8 +171,15 @@ func StmtForQueryRow(stmt *sql.Stmt, valuelist []interface{}) (map[string]string
if rowerr == nil {
for i, col := range values {
if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
}
}
} else {
@ -186,8 +210,15 @@ func StmtForUpdate(dbName, table string, data []string, where []string) (*sql.St
var stmt *sql.Stmt
var err error
stmt, err = DB.Prepare("update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and "))
var Sql string
Sql = "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ")
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
stmt, err = DB.Prepare(Sql)
return stmt, err
}
@ -224,7 +255,41 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) {
var stmt *sql.Stmt
var err error
stmt, err = DB.Prepare("insert into " + dbName + " set " + strings.Join(data, " , "))
var sql string
if DB_PROVIDER == "PgsqlDb" || DB_PROVIDER == "DmSql" {
insert_data := []string{}
value_data := []string{}
for _, rv := range data {
dv := strings.Split(rv, "=")
if len(dv) < 2 {
return nil, errors.New("参数错误,条件值错误,=号不存在")
}
if strings.Contains(rv, "?") {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, "?")
} else {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, dv[1])
}
}
sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " RETURNING id")
}
} else {
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(data, " , "))
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
//stmt, err = DB.Prepare("insert into " + dbName + " set " + strings.Join(data, " , "))
stmt, err = DB.Prepare(sql)
return stmt, err
}
@ -234,11 +299,23 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) {
* @return lastId error
*/
func StmtForInsertExec(stmt *sql.Stmt, valuelist []interface{}) (int64, error) {
res, err := stmt.Exec(valuelist...)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
if DB_PROVIDER == "PgsqlDb" {
row := stmt.QueryRow(valuelist...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else {
res, err := stmt.Exec(valuelist...)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return res.LastInsertId()
}
return res.LastInsertId()
}
/**
@ -350,7 +427,12 @@ func StmtForQuery(querysql string) (*sql.Stmt, error) {
var stmt *sql.Stmt
var err error
if DB_PROVIDER == "PgsqlDb" {
querysql = sqlx.Rebind(sqlx.DOLLAR, querysql)
querysql = ReplaeByOtherSql(querysql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
querysql = ReplaeByOtherSql(querysql, "DmSql", "")
}
stmt, err = DB.Prepare(querysql)
return stmt, err


+ 99
- 27
transaction.go View File

@ -6,6 +6,8 @@ package dbquery
import (
"database/sql"
"errors"
"git.tetele.net/tgo/helper"
"github.com/jmoiron/sqlx"
"log"
"strings"
"time"
@ -43,16 +45,33 @@ func TxInsert(tx *sql.Tx, dbname, table string, data map[string]string) (int64,
valueList[i] = value
i++
}
result, err := tx.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...)
if err != nil {
log.Println("ERROR", "insert into ", dbName, "error:", err)
return insertId, err
if DB_PROVIDER == "PgsqlDb" {
var Sql string
Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") values (" + strings.Join(keyStr, ",") + ")"
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add")
stmt, err = tx.Prepare(Sql)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
row := stmt.QueryRow(valueList...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
result, err := tx.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...)
if err != nil {
log.Println("ERROR", "insert into ", dbName, "error:", err)
return insertId, err
} else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
}
}
}
@ -84,28 +103,52 @@ func TxPreInsert(tx *sql.Tx, dbname, table string, data map[string]interface{})
var field []string = make([]string, len(data))
var valuelist []interface{} = make([]interface{}, len(data))
insert_data := []string{}
value_data := []string{}
var i int = 0
for key, item := range data {
field[i] = key + "=?"
valuelist[i] = item
i++
}
insert_data = append(insert_data, key)
value_data = append(value_data, "?")
}
if DB_PROVIDER == "PgsqlDb" {
Sql := helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add")
stmt, err = tx.Prepare(Sql)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
row := stmt.QueryRow(valuelist...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else {
sql := "insert into " + dbName + " set " + strings.Join(field, " , ")
stmt, err = tx.Prepare(sql)
sql := "insert into " + dbName + " set " + strings.Join(field, " , ")
if DB_PROVIDER == "DmSql" {
sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = tx.Prepare(sql)
if err != nil {
log.Println("insert prepare error:", sql, err)
return 0, errors.New("insert prepare error:" + err.Error())
}
result, err := stmt.Exec(valuelist...)
if err != nil {
log.Println("insert exec error:", sql, valuelist, err)
return 0, errors.New("insert exec error:" + err.Error())
if err != nil {
log.Println("insert prepare error:", sql, err)
return 0, errors.New("insert prepare error:" + err.Error())
}
result, err := stmt.Exec(valuelist...)
if err != nil {
log.Println("insert exec error:", sql, valuelist, err)
return 0, errors.New("insert exec error:" + err.Error())
}
insertId, _ := result.LastInsertId()
return insertId, nil
}
insertId, _ := result.LastInsertId()
return insertId, nil
}
@ -160,7 +203,15 @@ func TxUpdate(tx *sql.Tx, dbname, table string, data map[string]string, where ma
log.Println("ERROR", "update", dbName, "error, params empty")
return rowsAffected, errors.New("params empty")
}
result, err := tx.Exec("update "+dbName+" set "+strings.Join(keyList, " , ")+" where "+strings.Join(whereStr, " and "), valueList...)
var Sql string
Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ")
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := tx.Exec(Sql, valueList...)
if err != nil {
log.Println("ERROR", "update", dbName, "error:", err)
@ -198,7 +249,12 @@ func TxPreUpdate(tx *sql.Tx, dbname, table string, data []string, where []string
var stmt *sql.Stmt
sql := "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ")
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = tx.Prepare(sql)
if err != nil {
@ -260,7 +316,15 @@ func TxDelete(tx *sql.Tx, dbname, table string, where map[string]string, del_cou
limitStr = " limit " + del_count[0]
}
result, err := tx.Exec("delete from "+dbName+" where "+strings.Join(keyList, " and ")+limitStr, valueList...)
var Sql string
Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := tx.Exec(Sql, valueList...)
if err != nil {
log.Println("ERROR", "delete from", dbName, "error:", err)
@ -296,7 +360,15 @@ func TxForRead(tx *sql.Tx, dbName, table string, title string, where []string) (
if len(where) > 0 {
// log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE")
stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE")
var Sql string
Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE"
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
stmt, err = tx.Prepare(Sql)
} else {
// log.Println("SELECT " + title + " FROM " + dbName + " FOR UPDATE")
stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " FOR UPDATE")


+ 237
- 13
transaction_chain.go View File

@ -6,6 +6,7 @@ package dbquery
import (
"database/sql"
"errors"
"github.com/jmoiron/sqlx"
"log"
"strconv"
"strings"
@ -34,6 +35,7 @@ type TxQuery struct {
conn *sql.DB
tx *sql.Tx
debug bool
with [][]string //[[临时表的sql语句,临时表的名称]]
}
func NewTxQuery(t ...string) *TxQuery {
@ -109,6 +111,14 @@ func (this *TxQuery) Where(where string) *TxQuery {
this.where = append(this.where, where)
return this
}
func (this *TxQuery) With(with []string) *TxQuery {
this.with = append(this.with, with)
return this
}
func (this *TxQuery) Withs(withs [][]string) *TxQuery {
this.with = append(this.with, withs...)
return this
}
func (this *TxQuery) Wheres(wheres []string) *TxQuery {
if len(wheres) > 0 {
this.where = append(this.where, wheres...)
@ -198,6 +208,7 @@ func (this *TxQuery) Clean() *TxQuery {
this.upd_field = this.upd_field[0:0]
this.having = ""
this.alias = ""
this.with = this.with[0:0]
return this
}
@ -212,7 +223,30 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
return nil, errors.New("参数错误,没有数据表")
}
var table = ""
if strings.Contains(this.table, "select ") {
withSql := ""
if len(this.with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range this.with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(this.table, "select ") || strings.HasPrefix(this.table, "(") {
table = this.table
} else {
table = getTableName(this.dbname, this.table)
@ -225,7 +259,8 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
} else {
title = "*"
}
sql = helper.StringJoin("select ", title)
sql = helper.StringJoin(withSql, "select ", title)
if this.alias != "" {
table = helper.StringJoin(table, " as ", this.alias)
@ -234,15 +269,31 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
sql = helper.StringJoin(sql, " from ", table)
if len(this.join) > 0 {
var builder strings.Builder
builder.WriteString(sql)
boo := false
for _, joinitem := range this.join {
if len(joinitem) < 2 {
continue
}
if len(joinitem) == 3 {
sql = helper.StringJoin(sql, " ", joinitem[2], " join ", getTableName(this.dbname, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql = helper.StringJoin(sql, " left join ", getTableName(this.dbname, joinitem[0]), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(this.dbname, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql = builder.String()
}
}
if len(this.where) > 0 || len(this.where_or) > 0 {
@ -281,7 +332,12 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
from := strconv.Itoa((this.page - 1) * this.page_size)
offset := strconv.Itoa(this.page_size)
if from != "" && offset != "" {
sql = helper.StringJoin(sql, " limit ", from, " , ", offset)
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " limit ", offset, " OFFSET ", from)
} else {
sql = helper.StringJoin(sql, " limit ", from, " , ", offset)
}
}
}
@ -298,6 +354,12 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
if condition_len != len(this.value) {
return nil, errors.New("参数错误,条件值错误")
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
return map[string]interface{}{
"sql": sql,
"value": this.value,
@ -314,13 +376,50 @@ func (this *TxQuery) GetTableInfo(table string) (map[string]interface{}, error)
"COLUMN_COMMENT", //备注
"IS_NULLABLE", //是否为空
}
sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ? and table_schema = ?"
sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ?"
if DB_PROVIDER == "PgsqlDb" {
//pgsql中,未加引号的标识符会被自动转换为小写
sql = `SELECT
c.column_name AS 'COLUMN_NAME',
c.column_default AS 'COLUMN_DEFAULT',
c.data_type AS 'DATA_TYPE',
c.udt_name AS 'COLUMN_TYPE',
pgdesc.description AS 'COLUMN_COMMENT',
c.is_nullable AS 'IS_NULLABLE'
FROM
information_schema.columns c
LEFT JOIN
pg_catalog.pg_statio_all_tables st ON st.schemaname = c.table_schema AND st.relname = c.table_name
LEFT JOIN
pg_catalog.pg_description pgdesc ON pgdesc.objoid = st.relid AND pgdesc.objsubid = c.ordinal_position
WHERE
c.table_name = ?`
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
} else if DB_PROVIDER == "DmSql" {
sql = `SELECT
COLUMN_NAME,
DATA_DEFAULT AS COLUMN_DEFAULT,
DATA_TYPE,
CASE
WHEN DATA_TYPE LIKE '%CHAR%' OR DATA_TYPE LIKE '%TEXT%' THEN DATA_TYPE || '(' || CHAR_LENGTH || ')'
WHEN DATA_TYPE IN ('NUMERIC', 'DECIMAL') THEN DATA_TYPE || '(' || DATA_PRECISION || ',' || DATA_SCALE || ')'
ELSE DATA_TYPE
END AS COLUMN_TYPE,
'' AS COLUMN_COMMENT,
CASE NULLABLE WHEN 'Y' THEN 'YES' WHEN 'N' THEN 'NO' END AS IS_NULLABLE
FROM
ALL_TAB_COLUMNS
WHERE
TABLE_NAME = ?`
}
stmtSql, err := this.tx.Prepare(sql)
if err != nil {
return nil, err
}
list, err := StmtForQueryList(stmtSql, []interface{}{table, this.dbname})
list, err := StmtForQueryList(stmtSql, []interface{}{table})
if err != nil {
return nil, err
}
@ -421,7 +520,12 @@ func (this *TxQuery) UpdateStmt() error {
if condition_len != len(this.value) {
return errors.New("参数错误,条件值错误")
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "MysqlDb" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.tx.Prepare(sql)
if err != nil {
@ -446,8 +550,11 @@ func (this *TxQuery) UpdateAllStmt() error {
var dataSql []string //一组用到的占位字符
var valSql []string //占位字符组
var updSql []string //更新字段的sql
var updSql_dm []string //更新字段的sql--达梦和高斯用
var updFieldLen = len(this.upd_field) //需要更新的字段数量,为0时更新除id外添加值
dataLen := len(this.save_data)
if dataLen > 0 {
//批量操作
this.data = this.data[0:0]
@ -465,11 +572,14 @@ func (this *TxQuery) UpdateAllStmt() error {
case 0:
//预览创建数据的长度
updSql = make([]string, 0, fieldLen)
updSql_dm = make([]string, 0, fieldLen)
default:
//按照需要更新字段数长度
updSql = make([]string, 0, updFieldLen)
updSql_dm = make([]string, 0, updFieldLen)
for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段
updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段
}
}
for k := range this.save_data[i] {
@ -477,6 +587,7 @@ func (this *TxQuery) UpdateAllStmt() error {
dataSql = append(dataSql, "?") //存储需要的占位符
if updFieldLen == 0 && k != "id" {
updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段
updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段
}
}
dataSqlText = strings.Join(dataSql, ",") //组成每组占位字符格式
@ -494,21 +605,26 @@ func (this *TxQuery) UpdateAllStmt() error {
switch updFieldLen {
case 0:
updSql = make([]string, 0, fieldLen)
updSql_dm = make([]string, 0, fieldLen)
default:
updSql = make([]string, 0, updFieldLen)
updSql_dm = make([]string, 0, updFieldLen)
for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")")
updSql_dm = append(updSql_dm, k+"=s."+k+"")
}
}
for i := 0; i < fieldLen; i++ {
dataSql = append(dataSql, "?")
if updFieldLen == 0 && this.data[i] != "id" {
updSql = append(updSql, this.data[i]+"=values("+this.data[i]+")")
updSql_dm = append(updSql_dm, this.data[i]+"=s."+this.data[i]+"")
}
}
if updFieldLen > 0 {
for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")")
updSql_dm = append(updSql_dm, k+"=s."+k+"")
}
}
valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")")
@ -524,8 +640,44 @@ func (this *TxQuery) UpdateAllStmt() error {
if len(valSql) > 1 {
setText = " value "
}
sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , "))
if DB_PROVIDER == "PgsqlDb" {
setText = " values "
val_field := addPrefixInField(this.data, "s.")
sql = `merge into ` + dbName + ` as t
using (
` + setText + strings.Join(valSql, ",") + `
) s (` + strings.Join(this.data, " , ") + `)
on (t.id = s.id)
when matched then
update set
` + strings.Join(updSql_dm, " , ") + `
when NOT matched then
insert (` + strings.Join(this.data, " , ") + `)
values (` + strings.Join(val_field, " , ") + `)`
} else if DB_PROVIDER == "DmSql" {
setText = " values "
val_field := addPrefixInField(this.data, "s.")
title_field := addPrefixInField(this.data, "? AS ")
sql = `MERGE INTO ` + dbName + ` AS t
USING (
SELECT
` + strings.Join(title_field, " , ") + `
FROM DUAL
) s
ON (t.id = s.id)
WHEN MATCHED THEN
UPDATE SET
` + strings.Join(updSql_dm, " , ") + `
WHEN NOT MATCHED THEN
INSERT (` + strings.Join(this.data, " , ") + `)
VALUES (` + strings.Join(val_field, " , ") + `)`
}
if this.debug {
log.Println("insert on duplicate key update sql:", sql, this.value)
}
@ -539,6 +691,13 @@ func (this *TxQuery) UpdateAllStmt() error {
if conditionLen != len(this.value) {
return errors.New("参数错误,条件值数量不匹配")
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.tx.Prepare(sql)
if err != nil {
@ -560,8 +719,32 @@ func (this *TxQuery) CreateStmt() error {
dbName := getTableName(this.dbname, this.table)
var sql string
if DB_PROVIDER == "PgsqlDb" || DB_PROVIDER == "DmSql" {
insert_data := []string{}
value_data := []string{}
for _, rv := range this.data {
dv := strings.Split(rv, "=")
if len(dv) < 2 {
return errors.New("参数错误,条件值错误,=号不存在")
}
if strings.Contains(rv, "?") {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, "?")
} else {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, dv[1])
}
}
sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , "))
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " RETURNING id")
}
} else {
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , "))
}
//sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , "))
if this.debug {
log.Println("insert sql:", sql, this.value)
@ -577,7 +760,12 @@ func (this *TxQuery) CreateStmt() error {
if condition_len != len(this.value) {
return errors.New("参数错误,条件值错误")
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.tx.Prepare(sql)
if err != nil {
@ -648,6 +836,11 @@ func (this *TxQuery) CreateAllStmt() error {
if len(valSql) > 1 {
setText = " value "
}
if DB_PROVIDER == "PgsqlDb" {
setText = " values "
} else if DB_PROVIDER == "DmSql" {
setText = " values "
}
sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","))
if len(this.value) == 0 {
return errors.New("参数错误,条件值错误")
@ -666,6 +859,12 @@ func (this *TxQuery) CreateAllStmt() error {
if conditionLen != len(this.value) {
return errors.New("参数错误,条件值数量不匹配")
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "add")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.tx.Prepare(sql)
@ -712,7 +911,12 @@ func (this *TxQuery) DeleteStmt() error {
if condition_len != len(this.value) {
return errors.New("参数错误,条件值错误")
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.tx.Prepare(sql)
if err != nil {
@ -827,6 +1031,26 @@ func (this *TxQuery) CreateAll() (int64, error) {
return StmtForInsertExec(this.stmt, this.value)
}
/**
* 执行原生sql
* return error
*/
func (this *TxQuery) ExecSql(sql string) (int64, error) {
if this.debug {
log.Println("ExecSql sql:", sql)
}
stmt, err = this.tx.Prepare(sql)
if err != nil {
return 0, err
}
res, err := stmt.Exec()
if err != nil {
return 0, errors.New("执行失败:" + err.Error())
}
return res.RowsAffected()
}
/**
* 提交
*/


Loading…
Cancel
Save