32 Commits

Author SHA1 Message Date
  lijianbin f5ad88e70c 优化达梦关键字问题 1 day ago
  lijianbin c5dab75bff BuildSelectSql子查询兼容pgsql 2 weeks ago
  zhenghaorong 3ea2b59dda 调整接口 2 weeks ago
  lijianbin 45bcc98850 达梦--密码特殊字符需转义 3 weeks ago
  lijianbin a9383ba486 修正输出描述 3 weeks ago
  lijianbin fe779bf6ca 兼容高斯和达梦数据库 3 weeks ago
  lijianbin f7b0eb2df6 完善pgsql关键替换问题 1 month ago
  lijianbin 74aa5112a0 完善pgsql关键替换问题 1 month ago
  lijianbin 35782df825 pgsql关键字问题 1 month ago
  lijianbin 3b11f98fdc 修复pgsql中关键字使用问题 1 month ago
  lijianbin 60c84df410 修复pgsql中关键字使用问题 1 month ago
  loshiqi 01f3625f1e 兼容高斯 2 months ago
  loshiqi d16c3ca83c 兼容高斯 2 months ago
  loshiqi 16cff9c0ca 兼容高斯 2 months ago
  loshiqi bd754d1507 兼容高斯 2 months ago
  loshiqi 95a092325e values 2 months ago
  loshiqi 9bd66eefc9 插入和分页查询 2 months ago
  loshiqi 2a9f596ea8 兼容pgsql 2 months ago
  loshiqi ff5d48d51a 占位符 2 months ago
  loshiqi 82954508e0 驱动名称 更改 2 months ago
  loshiqi 2f467a0f92 增加pgsql链接 2 months ago
  loshiqi 14f6e5fc28 增加执行原始方法 3 months ago
  zhenghaorong c13da09c57 增加过滤 11 months ago
  zhenghaorong 7f9e42fed6 修复join无法使用with临时表问题 11 months ago
  zhenghaorong 2d94f24f43 修复join无法使用with临时表问题 11 months ago
  zhenghaorong bc664f29c5 修复join无法使用with临时表问题 11 months ago
  zhenghaorong 6d7835939e 增加with语句 11 months ago
  loshiqi 5db8b87a21 transaction bug修复 1 year ago
  guzeng 6cc83234ed Merge branch 'master' of ssh://git.tetele.net:4001/tgo/dbquery 2 years ago
  guzeng 628562e53e 增加调试 2 years ago
  zhenghaorong 396b0b8f86 增加生成表名的函数 2 years ago
  guzeng f8d6d88e9f 增加左/右连接方法 2 years ago
11 changed files with 1483 additions and 191 deletions
Unified View
  1. +261
    -20
      chain.go
  2. +159
    -7
      chain_test.go
  3. +132
    -0
      common.go
  4. +159
    -27
      conn.go
  5. +250
    -58
      db.go
  6. +40
    -4
      db_test.go
  7. +16
    -2
      go.mod
  8. +0
    -11
      go.sum
  9. +99
    -17
      prepare.go
  10. +102
    -30
      transaction.go
  11. +265
    -15
      transaction_chain.go

+ 261
- 20
chain.go View File

@ -3,6 +3,7 @@ package dbquery
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/jmoiron/sqlx"
"log" "log"
"strconv" "strconv"
"strings" "strings"
@ -34,6 +35,7 @@ type Query struct {
conn *sql.DB conn *sql.DB
debug bool debug bool
dbtype string dbtype string
with [][]string //[[临时表的sql语句,临时表的名称]]
} }
func NewQuery(t ...string) *Query { func NewQuery(t ...string) *Query {
@ -67,6 +69,11 @@ func (this *Query) Conn(conn *sql.DB) *Query {
} }
func (this *Query) Db(dbname string) *Query { func (this *Query) Db(dbname string) *Query {
this.dbname = dbname this.dbname = dbname
if DB_PROVIDER == "PgsqlDb" {
this.dbname = ""
} else if DB_PROVIDER == "DmSql" {
this.dbname = ""
}
return this return this
} }
@ -103,6 +110,14 @@ func (this *Query) Groupby(groupby string) *Query {
this.groupby = groupby this.groupby = groupby
return this return this
} }
func (this *Query) With(with []string) *Query {
this.with = append(this.with, with)
return this
}
func (this *Query) Withs(withs [][]string) *Query {
this.with = append(this.with, withs...)
return this
}
func (this *Query) Where(where string) *Query { func (this *Query) Where(where string) *Query {
this.where = append(this.where, where) this.where = append(this.where, where)
return this return this
@ -145,6 +160,27 @@ func (this *Query) Join(join []string) *Query {
this.join = append(this.join, join) this.join = append(this.join, join)
return this return this
} }
/**
* 左连接
* 2023/08/10
* gz
*/
func (this *Query) LeftJoin(table_name string, condition string) *Query {
this.join = append(this.join, []string{table_name, condition, "left"})
return this
}
/**
* 右连接
* 2023/08/10
* gz
*/
func (this *Query) RightJoin(table_name string, condition string) *Query {
this.join = append(this.join, []string{table_name, condition, "right"})
return this
}
func (this *Query) Data(data string) *Query { func (this *Query) Data(data string) *Query {
this.data = append(this.data, data) this.data = append(this.data, data)
return this return this
@ -176,6 +212,8 @@ func (this *Query) Clean() *Query {
this.save_data = this.save_data[0:0] this.save_data = this.save_data[0:0]
this.upd_field = this.upd_field[0:0] this.upd_field = this.upd_field[0:0]
this.having = "" this.having = ""
this.alias = ""
this.with = this.with[0:0]
return this return this
} }
@ -189,15 +227,55 @@ func (this *Query) GetTableInfo(table string) (map[string]interface{}, error) {
"COLUMN_COMMENT", //备注 "COLUMN_COMMENT", //备注
"IS_NULLABLE", //是否为空 "IS_NULLABLE", //是否为空
} }
sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ? and table_schema = ?"
sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ?"
if DB_PROVIDER == "PgsqlDb" {
//pgsql中,未加引号的标识符会被自动转换为小写
sql = `SELECT
c.column_name AS 'COLUMN_NAME',
c.column_default AS 'COLUMN_DEFAULT',
c.data_type AS 'DATA_TYPE',
c.udt_name AS 'COLUMN_TYPE',
pgdesc.description AS 'COLUMN_COMMENT',
c.is_nullable AS 'IS_NULLABLE'
FROM
information_schema.columns c
LEFT JOIN
pg_catalog.pg_statio_all_tables st ON st.schemaname = c.table_schema AND st.relname = c.table_name
LEFT JOIN
pg_catalog.pg_description pgdesc ON pgdesc.objoid = st.relid AND pgdesc.objsubid = c.ordinal_position
WHERE
c.table_name = ?`
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
} else if DB_PROVIDER == "DmSql" {
sql = `SELECT
COLUMN_NAME,
DATA_DEFAULT AS COLUMN_DEFAULT,
DATA_TYPE,
CASE
WHEN DATA_TYPE LIKE '%CHAR%' OR DATA_TYPE LIKE '%TEXT%' THEN DATA_TYPE || '(' || CHAR_LENGTH || ')'
WHEN DATA_TYPE IN ('NUMERIC', 'DECIMAL') THEN DATA_TYPE || '(' || DATA_PRECISION || ',' || DATA_SCALE || ')'
ELSE DATA_TYPE
END AS COLUMN_TYPE,
'' AS COLUMN_COMMENT,
CASE NULLABLE WHEN 'Y' THEN 'YES' WHEN 'N' THEN 'NO' END AS IS_NULLABLE
FROM
ALL_TAB_COLUMNS
WHERE
TABLE_NAME = ?`
}
if this.conn == nil { if this.conn == nil {
this.conn = DB this.conn = DB
} }
stmtSql, err := this.conn.Prepare(sql) stmtSql, err := this.conn.Prepare(sql)
if err != nil { if err != nil {
return nil, err return nil, err
} }
list, err := StmtForQueryList(stmtSql, []interface{}{table, this.dbname})
list, err := StmtForQueryList(stmtSql, []interface{}{table})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -214,6 +292,10 @@ func (this *Query) GetTableInfo(table string) (map[string]interface{}, error) {
} }
for _, k := range field { for _, k := range field {
index := helper.StrFirstToUpper(k) index := helper.StrFirstToUpper(k)
if DB_PROVIDER == "DmSql" {
index = helper.StrFirstToUpper(strings.ToLower(k))
}
if v, ok := item[index]; ok { if v, ok := item[index]; ok {
switch k { switch k {
case "COLUMN_NAME": case "COLUMN_NAME":
@ -245,13 +327,41 @@ func (this *Query) GetTableInfo(table string) (map[string]interface{}, error) {
}, nil }, nil
} }
// 返回表名
func (this *Query) GetTableName(table string) string {
return getTableName(this.dbname, table)
}
// 构造子查询 // 构造子查询
func (this *Query) BuildSelectSql() (map[string]interface{}, error) { func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
if this.dbname == "" && this.table == "" { if this.dbname == "" && this.table == "" {
return nil, errors.New("参数错误,没有数据表") return nil, errors.New("参数错误,没有数据表")
} }
var table = "" var table = ""
if strings.Contains(this.table, "select ") {
withSql := ""
if len(this.with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range this.with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(this.table, "select ") || strings.HasPrefix(this.table, "(") {
table = this.table table = this.table
} else { } else {
table = getTableName(this.dbname, this.table, this.dbtype) table = getTableName(this.dbname, this.table, this.dbtype)
@ -269,15 +379,15 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
if this.dbtype == "mssql" { if this.dbtype == "mssql" {
if this.page_size > 0 { if this.page_size > 0 {
sql = helper.StringJoin("select top ", helper.ToStr(this.page_size), " ")
sql = helper.StringJoin(withSql, "select top ", helper.ToStr(this.page_size), " ")
} else { } else {
sql = "select "
sql = helper.StringJoin(withSql, "select ")
} }
} else { } else {
if DB_PROVIDER == "TencentDB" { if DB_PROVIDER == "TencentDB" {
sql = "/*slave*/ select "
sql = helper.StringJoin("/*slave*/ ", withSql, " select ")
} else { } else {
sql = "select "
sql = helper.StringJoin(withSql, "select ")
} }
} }
@ -290,15 +400,31 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
sql = helper.StringJoin(sql, " from ", table) sql = helper.StringJoin(sql, " from ", table)
if len(this.join) > 0 { if len(this.join) > 0 {
var builder strings.Builder
builder.WriteString(sql)
boo := false
for _, joinitem := range this.join { for _, joinitem := range this.join {
if len(joinitem) < 2 { if len(joinitem) < 2 {
continue continue
} }
if len(joinitem) == 3 {
sql = helper.StringJoin(sql, " ", joinitem[2], " join ", getTableName(this.dbname, joinitem[0], this.dbtype), " on ", joinitem[1])
} else { //默认左连接
sql = helper.StringJoin(sql, " left join ", getTableName(this.dbname, joinitem[0], this.dbtype), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(this.dbname, joinitem[0]))
} }
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql = builder.String()
} }
} }
if len(this.where) > 0 || len(this.where_or) > 0 { if len(this.where) > 0 || len(this.where_or) > 0 {
@ -337,12 +463,18 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
from := strconv.Itoa((this.page - 1) * this.page_size) from := strconv.Itoa((this.page - 1) * this.page_size)
offset := strconv.Itoa(this.page_size) offset := strconv.Itoa(this.page_size)
if from != "" && offset != "" { if from != "" && offset != "" {
sql = helper.StringJoin(sql, " limit ", from, " , ", offset)
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " limit ", offset, " OFFSET ", from)
} else {
sql = helper.StringJoin(sql, " limit ", from, " , ", offset)
}
} }
} }
if this.debug { if this.debug {
log.Println("query sql:", sql, this.value) log.Println("query sql:", sql, this.value)
} }
condition_len := 0 //所有条件数 condition_len := 0 //所有条件数
for _, ch2 := range sql { for _, ch2 := range sql {
if string(ch2) == "?" { if string(ch2) == "?" {
@ -352,6 +484,7 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
if condition_len != len(this.value) { if condition_len != len(this.value) {
return nil, errors.New("参数错误,条件值错误") return nil, errors.New("参数错误,条件值错误")
} }
return map[string]interface{}{ return map[string]interface{}{
"sql": sql, "sql": sql,
"value": this.value, "value": this.value,
@ -377,6 +510,13 @@ func (this *Query) QueryStmt() error {
this.conn = DB this.conn = DB
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql) stmt, err = this.conn.Prepare(sql)
if err != nil { if err != nil {
@ -423,7 +563,12 @@ func (this *Query) UpdateStmt() error {
if this.conn == nil { if this.conn == nil {
this.conn = DB this.conn = DB
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql) stmt, err = this.conn.Prepare(sql)
if err != nil { if err != nil {
@ -448,8 +593,11 @@ func (this *Query) UpdateAllStmt() error {
var dataSql []string //一组用到的占位字符 var dataSql []string //一组用到的占位字符
var valSql []string //占位字符组 var valSql []string //占位字符组
var updSql []string //更新字段的sql var updSql []string //更新字段的sql
var updSql_dm []string //更新字段的sql--达梦和高斯用
var updFieldLen = len(this.upd_field) //需要更新的字段数量,为0时更新除id外添加值 var updFieldLen = len(this.upd_field) //需要更新的字段数量,为0时更新除id外添加值
dataLen := len(this.save_data) dataLen := len(this.save_data)
if dataLen > 0 { if dataLen > 0 {
//批量操作 //批量操作
this.data = this.data[0:0] this.data = this.data[0:0]
@ -467,11 +615,14 @@ func (this *Query) UpdateAllStmt() error {
case 0: case 0:
//预览创建数据的长度 //预览创建数据的长度
updSql = make([]string, 0, fieldLen) updSql = make([]string, 0, fieldLen)
updSql_dm = make([]string, 0, fieldLen)
default: default:
//按照需要更新字段数长度 //按照需要更新字段数长度
updSql = make([]string, 0, updFieldLen) updSql = make([]string, 0, updFieldLen)
updSql_dm = make([]string, 0, updFieldLen)
for _, k := range this.upd_field { for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段 updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段
updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段
} }
} }
for k := range this.save_data[i] { for k := range this.save_data[i] {
@ -479,6 +630,7 @@ func (this *Query) UpdateAllStmt() error {
dataSql = append(dataSql, "?") //存储需要的占位符 dataSql = append(dataSql, "?") //存储需要的占位符
if updFieldLen == 0 && k != "id" { if updFieldLen == 0 && k != "id" {
updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段 updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段
updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段
} }
} }
dataSqlText = strings.Join(dataSql, ",") //组成每组占位字符格式 dataSqlText = strings.Join(dataSql, ",") //组成每组占位字符格式
@ -496,21 +648,26 @@ func (this *Query) UpdateAllStmt() error {
switch updFieldLen { switch updFieldLen {
case 0: case 0:
updSql = make([]string, 0, fieldLen) updSql = make([]string, 0, fieldLen)
updSql_dm = make([]string, 0, fieldLen)
default: default:
updSql = make([]string, 0, updFieldLen) updSql = make([]string, 0, updFieldLen)
updSql_dm = make([]string, 0, updFieldLen)
for _, k := range this.upd_field { for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")") updSql = append(updSql, k+"=values("+k+")")
updSql_dm = append(updSql_dm, k+"=s."+k+"")
} }
} }
for i := 0; i < fieldLen; i++ { for i := 0; i < fieldLen; i++ {
dataSql = append(dataSql, "?") dataSql = append(dataSql, "?")
if updFieldLen == 0 && this.data[i] != "id" { if updFieldLen == 0 && this.data[i] != "id" {
updSql = append(updSql, this.data[i]+"=values("+this.data[i]+")") updSql = append(updSql, this.data[i]+"=values("+this.data[i]+")")
updSql_dm = append(updSql_dm, this.data[i]+"=s."+this.data[i]+"")
} }
} }
if updFieldLen > 0 { if updFieldLen > 0 {
for _, k := range this.upd_field { for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")") updSql = append(updSql, k+"=values("+k+")")
updSql_dm = append(updSql_dm, k+"=s."+k+"")
} }
} }
valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")") valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")")
@ -527,8 +684,44 @@ func (this *Query) UpdateAllStmt() error {
if len(valSql) > 1 { if len(valSql) > 1 {
setText = " value " setText = " value "
} }
sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , ")) sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , "))
if DB_PROVIDER == "PgsqlDb" {
setText = " values "
val_field := addPrefixInField(this.data, "s.")
sql = `merge into ` + dbName + ` as t
using (
` + setText + strings.Join(valSql, ",") + `
) s (` + strings.Join(this.data, " , ") + `)
on (t.id = s.id)
when matched then
update set
` + strings.Join(updSql_dm, " , ") + `
when NOT matched then
insert (` + strings.Join(this.data, " , ") + `)
values (` + strings.Join(val_field, " , ") + `)`
} else if DB_PROVIDER == "DmSql" {
setText = " values "
val_field := addPrefixInField(this.data, "s.")
title_field := addPrefixInField(this.data, "? AS ")
sql = `MERGE INTO ` + dbName + ` AS t
USING (
SELECT
` + strings.Join(title_field, " , ") + `
FROM DUAL
) s
ON (t.id = s.id)
WHEN MATCHED THEN
UPDATE SET
` + strings.Join(updSql_dm, " , ") + `
WHEN NOT MATCHED THEN
INSERT (` + strings.Join(this.data, " , ") + `)
VALUES (` + strings.Join(val_field, " , ") + `)`
}
if this.debug { if this.debug {
log.Println("insert on duplicate key update sql:", sql, this.value) log.Println("insert on duplicate key update sql:", sql, this.value)
} }
@ -545,7 +738,12 @@ func (this *Query) UpdateAllStmt() error {
if this.conn == nil { if this.conn == nil {
this.conn = DB this.conn = DB
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql) stmt, err = this.conn.Prepare(sql)
if err != nil { if err != nil {
@ -616,6 +814,11 @@ func (this *Query) CreateAllStmt() error {
if len(valSql) > 1 { if len(valSql) > 1 {
setText = " value " setText = " value "
} }
if DB_PROVIDER == "PgsqlDb" {
setText = " values "
} else if DB_PROVIDER == "DmSql" {
setText = " values "
}
sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ",")) sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","))
if this.debug { if this.debug {
@ -634,7 +837,12 @@ func (this *Query) CreateAllStmt() error {
if this.conn == nil { if this.conn == nil {
this.conn = DB this.conn = DB
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "add")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql) stmt, err = this.conn.Prepare(sql)
if err != nil { if err != nil {
@ -656,8 +864,31 @@ func (this *Query) CreateStmt() error {
dbName := getTableName(this.dbname, this.table, this.dbtype) dbName := getTableName(this.dbname, this.table, this.dbtype)
var sql string var sql string
if DB_PROVIDER == "PgsqlDb" || DB_PROVIDER == "DmSql" {
insert_data := []string{}
value_data := []string{}
for _, rv := range this.data {
dv := strings.Split(rv, "=")
if len(dv) < 2 {
return errors.New("参数错误,条件值错误,=号不存在")
}
if strings.Contains(rv, "?") {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, "?")
} else {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, dv[1])
}
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , "))
}
sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " RETURNING id")
}
} else {
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , "))
}
if this.debug { if this.debug {
log.Println("insert sql:", sql, this.value) log.Println("insert sql:", sql, this.value)
@ -676,7 +907,12 @@ func (this *Query) CreateStmt() error {
if this.conn == nil { if this.conn == nil {
this.conn = DB this.conn = DB
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql) stmt, err = this.conn.Prepare(sql)
if err != nil { if err != nil {
@ -725,7 +961,12 @@ func (this *Query) DeleteStmt() error {
if this.conn == nil { if this.conn == nil {
this.conn = DB this.conn = DB
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.conn.Prepare(sql) stmt, err = this.conn.Prepare(sql)
if err != nil { if err != nil {
@ -743,7 +984,7 @@ func (this *Query) DeleteStmt() error {
*/ */
func (this *Query) Select() ([]map[string]string, error) { func (this *Query) Select() ([]map[string]string, error) {
_, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.join,
_, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.with, this.join,
this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.page, this.page_size, this.debug) this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.page, this.page_size, this.debug)
return rows, err return rows, err
@ -774,7 +1015,7 @@ func (this *Query) List() ([]map[string]string, error) {
*/ */
func (this *Query) Find() (map[string]string, error) { func (this *Query) Find() (map[string]string, error) {
_, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.join,
_, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.with, this.join,
this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.debug) this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.debug)
return row, err return row, err


+ 159
- 7
chain_test.go View File

@ -1,17 +1,169 @@
package dbquery package dbquery
import ( import (
"fmt"
"testing" "testing"
) )
// 测试各数据库下各种情况
func Test_Chain(t *testing.T) { func Test_Chain(t *testing.T) {
Connect("127.0.0.1", "root", "123456", "shop", "3306")
//测试数据库连接
//err := Connect("127.0.0.1", "root", "root", "canyin", "3306")
//err := PgConnect("192.168.233.157", "bin", "Bin123456", "canyin", "5432")
//err := DmConnect("192.168.233.148", "SHOPV2", "Bin123456", "", "5236")
err := DmConnect("10.33.0.91", "ZYSG", "Zysg!#2025", "", "5236")
ret, err := new(Query).Db("shop").Table("ttl_order_product").Alias("op").
Join([]string{"ttl_product as p", "op.product_id=p.id"}).
Title("op.id,op.order_price,p.thumb").WhereOr("op.id =?").WhereOr("op.id = ?").Value(63).Value(64).Debug(true).List()
if err != nil {
t.Log(err)
}
db_name := ""
table_name := "ttl_user_log"
//time := time.Now().Unix()
t.Log(len(ret))
t.Log(ret)
t.Log(err)
//================查询表结构===========
ret, err := new(Query).Db(db_name).GetTableInfo(table_name)
if err != nil {
t.Log(err)
}
fmt.Println("===GetTableInfo:", ret)
//==========获取信息=================
/*query := new(Query).Db(db_name).Clean().Table("ttl_dorm_goods_reserve").Alias("a").
Join([]string{"ttl_dorm_goods_reserve_detail b", "a.id=b.reserve_id", "left"}).
Join([]string{"ttl_dorm_goods c", "c.id=b.goods_id", "left"}).
Join([]string{"ttl_dorm_room d", "d.id=a.room_id", "left"}).
Join([]string{"dorm_room_item e", "e.id=a.room_item_id", "left"}).
Where("a.user_id =?").Value(6006)
info, err := query.Groupby("a.id").Title("a.id").BuildSelectSql()*/
//info, err := new(Query).Db(db_name).Clean().Table(table_name).Clean().Alias("user").Join([]string{"ttl_user u", "u.id = user.user_id", "inner"}).Where("user.id=?").Value("3").Title("user.id,user.user_id,u.nickname").Find()
//info, err := GetDataByStmt(db_name, table_name, "*", []string{"id = ?"}, []interface{}{3}, nil)
/*if err != nil {
t.Log(err)
}
fmt.Println("===Find:", info)*/
//============获取列表==================
list, err := new(Query).Db("").Table("ttl_area").
Title("`first`,id,level,mergename,name,pid,shortname").
Select()
if err != nil {
t.Log(err)
}
fmt.Println("===List:", list)
//===========添加数据============
//insert_res, err := new(Query).Db(db_name).Clean().Table(table_name).Data("user_id=?").Value("1").Data("createtime=?").Value(time).Create()
//insert_res, err := InsertByStmt(db_name, table_name, []string{"user_id=?", "createtime=?"}, []interface{}{"1", time})
//insert_res, err := Insert(db_name, table_name, map[string]string{"user_id": "1", "createtime": helper.ToStr(time)})
//if err != nil {
// t.Log(err)
//}
//fmt.Println("===Insert:", insert_res)
//================更新数据=====================
//update_res, err := new(Query).Db(db_name).Clean().Table(table_name).Data("user_id=?").Value("2").Data("createtime=?").Value(time).Where("id=?").Value("6").Update()
//update_res, err := UpdateByStmt(db_name, table_name, []string{"createtime=?", "user_id=?"}, []string{"id=?"}, []interface{}{time, 3, 6})
//if err != nil {
// t.Log(err)
//}
//fmt.Println("===Update:", update_res)
//=============事务================
/*fmt.Println("================开启事务============")
tx, err := DB.Begin()
if err != nil {
t.Log(err)
}
update_log, err := TxPreUpdate(tx, db_name, table_name, []string{"createtime= ?"}, []string{"id=?"}, []interface{}{time, 2})
if err != nil {
tx.Rollback()
t.Log(err)
}
fmt.Println("===========事务执行:==================")
fmt.Println("===事务update:", update_log)
insert_log, err := TxPreInsert(tx, db_name, table_name, map[string]interface{}{"user_id": "1", "createtime": helper.ToStr(time)})
if err != nil {
tx.Rollback()
t.Log(err)
}
fmt.Println("===事务insert:", insert_log)
del_log, err := TxDelete(tx, db_name, table_name, map[string]string{"id": "2"})
if err != nil {
tx.Rollback()
t.Log(err)
}
fmt.Println("====事务delete:", del_log)
err = tx.Commit()
if err != nil {
t.Log(err)
tx.Rollback()
}
fmt.Println("=======事务执行完成==========")*/
/*trans := NewTxQuery().Db(db_name)
update_trans_res, err := trans.Clean().Table("ttl_dorm_check_in_apply").SaveData(map[string]interface{}{
"check_out_id": "48",
"id": "21",
"updatetime": time.Now().Unix(),
}).UpdateAll()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("=======事务update_trans", update_trans_res)
err = trans.Commit()
if err != nil {
trans.Rollback()
t.Log(err)
}*/
/*fmt.Println("====================执行事务trans============")
trans := NewTxQuery().Db(db_name)
info_trans, err := trans.Clean().Table(table_name).Where("id = ?").Value(5).Find()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("===事务Find_trans:", info_trans)
list_trans, err := trans.Clean().Table(table_name).Title("*").Select()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("=========事务List_trans:", list_trans)
data := map[string]interface{}{
"user_id": 5,
"memo": "test",
"createtime": time,
}
add_trans, err := trans.Clean().Table(table_name).SaveData(data).CreateAll()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("======事务Add_trans:", add_trans)
data["id"] = 15
update_trans_res, err := trans.Clean().Table(table_name).SaveData(data).UpdateAll()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("=======事务update_trans", update_trans_res)
err = trans.Commit()
if err != nil {
trans.Rollback()
t.Log(err)
}
fmt.Println("====================执行事务结束==================")*/
} }

+ 132
- 0
common.go View File

@ -0,0 +1,132 @@
package dbquery
import (
"fmt"
"git.tetele.net/tgo/helper"
"log"
"reflect"
"regexp"
"strconv"
"strings"
"time"
)
// ===================达梦兼容===============
// 非关键字可以不添加标识符,关键字须添加
// 日期函数的使用TO_CHAR(TO_DATE('1970-01-01','yyyy-mm-dd') + (createtime / 86400), 'yyyy-mm-dd')
// group_concat替换成LISTAGG
// ========================================
// 关键字替换-支持达梦和高斯
func ReplaeByOtherSql(sql, sql_type, action string) string {
sql_type_arr := []string{"DmSql", "PgsqlDb"}
if !helper.IsInStringArray(sql_type_arr, sql_type) {
log.Println("sql_type error", sql_type)
return ""
}
// PgsqlDb用
if action == "add" {
sql = helper.StringJoin(sql, " RETURNING id")
}
// 定义需要处理的关键字列表
keywords := []string{"user", "order", "group", "table", "view", "admin", "new"}
//设置保护词组
excludePhrases := []string{
"order by", "group by", "GROUP BY", "ORDER BY", "WITHIN GROUP", "within group",
}
if sql_type == "PgsqlDb" {
keywords = []string{"user", "order", "group"}
excludePhrases = []string{
"order by", "group by", "GROUP BY", "ORDER BY",
}
// 移除所有反引号
sql = strings.ReplaceAll(sql, "`", "")
}
// 使用单词边界 \b 确保只匹配完整单词
pattern := `\b(` + strings.Join(keywords, "|") + `)\b`
re := regexp.MustCompile(pattern)
//保护排除短语
phraseMap := make(map[string]string)
for i, phrase := range excludePhrases {
placeholder := fmt.Sprintf("__EXCLUDE_%d__", i)
phraseMap[placeholder] = phrase
sql = strings.Replace(sql, phrase, placeholder, -1)
}
// 执行替换
sql = re.ReplaceAllStringFunc(sql, func(match string) string {
// 检查匹配是否在字符串常量中
if isInStringLiteral(sql, match) {
return match
}
if sql_type == "DmSql" {
return "`" + match + "`"
} else {
return `"` + match + `"`
}
})
// 恢复排除短语
for placeholder, phrase := range phraseMap {
sql = strings.Replace(sql, placeholder, phrase, -1)
}
return sql
}
// 检查匹配是否在字符串常量中
func isInStringLiteral(sql, match string) bool {
index := strings.Index(sql, match)
if index == -1 {
return false
}
// 检查匹配前的单引号数量
prefix := sql[:index]
singleQuotes := strings.Count(prefix, "'") - strings.Count(prefix, "\\'")
// 奇数表示在字符串常量中
return singleQuotes%2 != 0
}
// 字段值类型转换--针对达梦用
func ToString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
case int, int8, int16, int32, int64:
return strconv.FormatInt(reflect.ValueOf(value).Int(), 10)
case uint, uint8, uint16, uint32, uint64:
return strconv.FormatUint(reflect.ValueOf(value).Uint(), 10)
case float32, float64:
return strconv.FormatFloat(reflect.ValueOf(value).Float(), 'f', -1, 64)
case bool:
return strconv.FormatBool(v)
case time.Time:
return v.Format("2006-01-02 15:04:05")
default:
return fmt.Sprintf("%v", v)
}
}
// 字符串切片元素追加前缀
func addPrefixInField(slice []string, prefix string) []string {
new_slice := make([]string, len(slice))
for i, v := range slice {
new_slice[i] = prefix + v
}
return new_slice
}
func DmFieldDeal(fields string) string {
//移除所有反引号
title := strings.Replace(fields, "`", "", -1)
return title
}

+ 159
- 27
conn.go View File

@ -2,15 +2,20 @@ package dbquery
import ( import (
"database/sql" "database/sql"
"fmt"
"log" "log"
"net/url"
"errors" "errors"
"strings" "strings"
"time" "time"
_ "dm"
"git.tetele.net/tgo/helper" "git.tetele.net/tgo/helper"
_ "gitee.com/opengauss/openGauss-connector-go-pq" // 高斯驱动(推荐)或 "github.com/lib/pq"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
//_ "github.com/lib/pq" // 关键驱动导入
) )
var DB *sql.DB var DB *sql.DB
@ -21,39 +26,53 @@ var SLAVER_DB *sql.DB
var DB_PROVIDER string var DB_PROVIDER string
func Connect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error { func Connect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
log.Println("mysql database connectting...")
log.Println("database connectting...")
var dbConnErr error var dbConnErr error
if DBHOST != "" && DBUSER != "" && DBPWD != "" && DBPORT != "" { //&& DBNAME != ""
const maxRetries = 10
for i := 0; i < 10; i++ {
DB, dbConnErr = sql.Open("mysql", DBUSER+":"+DBPWD+"@tcp("+DBHOST+":"+DBPORT+")/"+DBNAME+"?charset=utf8mb4")
if dbConnErr != nil {
log.Println("ERROR", "can not connect to Database, ", dbConnErr)
time.Sleep(time.Second * 5)
} else {
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" {
return errors.New("mysql DBconnection params errors")
}
DB.Ping()
dsn := DBUSER + ":" + DBPWD + "@tcp(" + DBHOST + ":" + DBPORT + ")/" + DBNAME + "?charset=utf8mb4"
log.Println("database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
break
}
for i := 0; i < maxRetries; i++ {
//Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题)
// 每次尝试创建新连接对象
DB, dbConnErr = sql.Open("mysql", dsn)
if dbConnErr != nil {
log.Println("sql open failed:", i+1, dbConnErr)
time.Sleep(time.Second * 5)
continue
} }
} else {
return errors.New("db connection params errors")
// 验证连接有效性
if pingErr := DB.Ping(); pingErr != nil {
DB.Close()
// 记录真实错误信息
dbConnErr = pingErr
log.Println("ping failed:", i+1, pingErr)
time.Sleep(time.Second * 5)
continue
}
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
log.Println("mysql database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
return nil
} }
return dbConnErr
return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr)
} }
func CloseConn() error { func CloseConn() error {
@ -106,7 +125,11 @@ func CloseSlaverConn() error {
func getTableName(dbName, table string, dbtype ...string) string { func getTableName(dbName, table string, dbtype ...string) string {
var db_type string = "mysql" var db_type string = "mysql"
if DB_PROVIDER == "PgsqlDb" {
dbName = ""
} else if DB_PROVIDER == "DmSql" {
dbName = ""
}
if len(dbtype) > 0 { if len(dbtype) > 0 {
if dbtype[0] != "" { if dbtype[0] != "" {
db_type = dbtype[0] db_type = dbtype[0]
@ -151,3 +174,112 @@ func GetDbTableName(dbName, table string) string {
func judg() []string { func judg() []string {
return []string{"=", ">", "<", "!=", "<=", ">="} return []string{"=", ">", "<", "!=", "<=", ">="}
} }
// pgsql连接
func PgConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
log.Println("pg database connectting...")
var dbConnErr error
const maxRetries = 10
DB_PROVIDER = "PgsqlDb"
if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" {
return errors.New("pgsql DBconnection params errors")
}
dsn := "host=" + DBHOST + " port=" + DBPORT + " user=" + DBUSER + " password=" + DBPWD + " dbname=" + DBNAME + " sslmode=disable search_path=public"
log.Println("database dsn", dsn)
for i := 0; i < maxRetries; i++ {
//Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题)
// 每次尝试创建新连接对象
DB, dbConnErr = sql.Open("opengauss", dsn)
if dbConnErr != nil {
log.Println("sql open failed:", i+1, dbConnErr)
time.Sleep(time.Second * 5)
continue
}
// 验证连接有效性
if pingErr := DB.Ping(); pingErr != nil {
DB.Close()
// 记录真实错误信息
dbConnErr = pingErr
log.Println("ping failed:", i+1, pingErr)
time.Sleep(time.Second * 5)
continue
}
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
log.Println("pgsql database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
return nil
}
return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr)
}
// 达梦8连接
func DmConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
log.Println("DM database connectting...")
var dbConnErr error
const maxRetries = 10
DB_PROVIDER = "DmSql"
if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" {
return errors.New("dm DBconnection params errors")
}
//达梦8如果密码存在特殊字符,需使用 url.PathEscape 进行转义后再放入连接串
DBPWD = url.PathEscape(DBPWD)
dsn := "dm://" + DBUSER + ":" + DBPWD + "@" + DBHOST + ":" + DBPORT + "?charSet=utf8&compatibleMode=mysql"
log.Println("database dsn", dsn)
for i := 0; i < maxRetries; i++ {
//Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题)
// 每次尝试创建新连接对象
DB, dbConnErr = sql.Open("dm", dsn)
if dbConnErr != nil {
log.Println("sql open failed:", i+1, dbConnErr)
time.Sleep(time.Second * 5)
continue
}
// 验证连接有效性
if pingErr := DB.Ping(); pingErr != nil {
DB.Close()
// 记录真实错误信息
dbConnErr = pingErr
log.Println("ping failed:", i+1, pingErr)
time.Sleep(time.Second * 5)
continue
}
if len(conns) > 0 {
DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制
} else {
DB.SetMaxOpenConns(200) //默认值为0表示不限制
}
if len(conns) > 1 {
DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数
} else {
DB.SetMaxIdleConns(50)
}
log.Println("dm database connected")
DB.SetConnMaxLifetime(time.Minute * 2)
return nil
}
return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr)
}

+ 250
- 58
db.go View File

@ -2,6 +2,7 @@ package dbquery
import ( import (
"database/sql" "database/sql"
"github.com/jmoiron/sqlx"
"log" "log"
"strconv" "strconv"
@ -44,16 +45,33 @@ func Insert(dbName, table string, data map[string]string) (int64, error) {
valueList[i] = value valueList[i] = value
i++ i++
} }
result, err := DB.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...)
if err != nil {
log.Println("ERROR|插入", dbName, "数据失败,", err)
return insertId, err
var Sql string
Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") values (" + strings.Join(keyStr, ",") + ")"
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add")
stmt, err = DB.Prepare(Sql)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
row := stmt.QueryRow(valueList...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else { } else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
result, err := DB.Exec(Sql, valueList...)
if err != nil {
log.Println("ERROR|插入", dbName, "数据失败,", err)
return insertId, err
} else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
}
} }
} }
@ -67,7 +85,6 @@ func Update(dbName, table string, data map[string]string, where map[string]strin
if dbName == "" && table == "" { if dbName == "" && table == "" {
return rowsAffected, errors.New("没有数据表") return rowsAffected, errors.New("没有数据表")
} }
if strings.Contains(table, "select ") { if strings.Contains(table, "select ") {
dbName = table dbName = table
} else { } else {
@ -109,7 +126,15 @@ func Update(dbName, table string, data map[string]string, where map[string]strin
log.Println("ERROR|修改数据表", dbName, "时条件中有空数据,条件:", where, "数据:", data) log.Println("ERROR|修改数据表", dbName, "时条件中有空数据,条件:", where, "数据:", data)
return rowsAffected, errors.New("条件中有空数据") return rowsAffected, errors.New("条件中有空数据")
} }
result, err := DB.Exec("update "+dbName+" set "+strings.Join(keyList, " , ")+" where "+strings.Join(whereStr, " and "), valueList...)
var Sql string
Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ")
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := DB.Exec(Sql, valueList...)
if err != nil { if err != nil {
log.Println("ERROR|修改", dbName, "数据失败,", err) log.Println("ERROR|修改", dbName, "数据失败,", err)
@ -131,7 +156,6 @@ func Delete(dbName, table string, data map[string]string, del_count ...string) (
if dbName == "" && table == "" { if dbName == "" && table == "" {
return count, errors.New("没有数据表") return count, errors.New("没有数据表")
} }
if strings.Contains(table, "select ") { if strings.Contains(table, "select ") {
dbName = table dbName = table
} else { } else {
@ -167,7 +191,15 @@ func Delete(dbName, table string, data map[string]string, del_count ...string) (
limitStr = " limit " + del_count[0] limitStr = " limit " + del_count[0]
} }
result, err := DB.Exec("delete from "+dbName+" where "+strings.Join(keyList, " and ")+limitStr, valueList...)
var Sql string
Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := DB.Exec(Sql, valueList...)
if err != nil { if err != nil {
log.Println("ERROR|删除", dbName, "数据失败,", err) log.Println("ERROR|删除", dbName, "数据失败,", err)
@ -191,7 +223,6 @@ func GetData(dbName, table string, title string, where map[string]string, limit
if dbName == "" && table == "" { if dbName == "" && table == "" {
return count, info, errors.New("没有数据表") return count, info, errors.New("没有数据表")
} }
dbName = getTableName(dbName, table) dbName = getTableName(dbName, table)
if len(title) < 1 { if len(title) < 1 {
@ -210,7 +241,11 @@ func GetData(dbName, table string, title string, where map[string]string, limit
if _, ok := limit["from"]; ok { if _, ok := limit["from"]; ok {
from = limit["from"] from = limit["from"]
} }
limitStr += " limit " + from + ",1"
if DB_PROVIDER == "PgsqlDb" {
limitStr += " limit 1 OFFSET " + from
} else {
limitStr += " limit " + from + ",1"
}
} else { } else {
limitStr = " limit 1" limitStr = " limit 1"
@ -241,8 +276,15 @@ func GetData(dbName, table string, title string, where map[string]string, limit
var err error var err error
var queryNum int = 0 var queryNum int = 0
for queryNum < 3 { //如发生错误,继续查询3次,防止数据库连接断开问题 for queryNum < 3 { //如发生错误,继续查询3次,防止数据库连接断开问题
rows, err = DB.Query("SELECT "+title+" FROM "+dbName+" where "+strings.Join(keyList, " and ")+" "+limitStr, valueList...)
var Sql string
Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(keyList, " and ") + " " + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
rows, err = DB.Query(Sql, valueList...)
if err == nil { if err == nil {
break break
@ -274,8 +316,15 @@ func GetData(dbName, table string, title string, where map[string]string, limit
if rowerr == nil { if rowerr == nil {
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
} }
} }
count++ count++
@ -295,16 +344,38 @@ func GetData(dbName, table string, title string, where map[string]string, limit
* @param dbName 数据表名 * @param dbName 数据表名
* @param title 查询字段名 * @param title 查询字段名
*/ */
func GetRow(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) {
func GetRow(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) {
var count int = 0 var count int = 0
info := make(map[string]string) info := make(map[string]string)
if dbName == "" && table_name == "" { if dbName == "" && table_name == "" {
return count, info, errors.New("没有数据表") return count, info, errors.New("没有数据表")
} }
table := "" table := ""
if strings.Contains(table_name, "select ") {
withSql := ""
if len(with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") {
table = table_name table = table_name
} else { } else {
table = getTableName(dbName, table_name) table = getTableName(dbName, table_name)
@ -316,10 +387,11 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
} else { } else {
title = "*" title = "*"
} }
if DB_PROVIDER == "TencentDB" { if DB_PROVIDER == "TencentDB" {
sql_str = helper.StringJoin("/*slave*/ select ", title)
sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title)
} else { } else {
sql_str = helper.StringJoin("select ", title)
sql_str = helper.StringJoin(withSql, "select ", title)
} }
if alias != "" { if alias != "" {
table = helper.StringJoin(table, " as ", alias) table = helper.StringJoin(table, " as ", alias)
@ -328,17 +400,31 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
sql_str = helper.StringJoin(sql_str, " from ", table) sql_str = helper.StringJoin(sql_str, " from ", table)
if len(join) > 0 { if len(join) > 0 {
var builder strings.Builder
builder.WriteString(sql_str)
boo := false
for _, joinitem := range join { for _, joinitem := range join {
if len(joinitem) < 2 { if len(joinitem) < 2 {
continue continue
} }
if len(joinitem) == 4 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", joinitem[0], " on ", joinitem[1])
} else if len(joinitem) == 3 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql_str = helper.StringJoin(sql_str, " left join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) > 4 {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(dbName, joinitem[0]))
} }
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql_str = builder.String()
} }
} }
if len(where) > 0 || len(where_or) > 0 { if len(where) > 0 || len(where_or) > 0 {
@ -363,10 +449,10 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
if orderby != "" { if orderby != "" {
sql_str = helper.StringJoin(sql_str, " order by ", orderby) sql_str = helper.StringJoin(sql_str, " order by ", orderby)
} }
if debug { if debug {
log.Println("query sql:", sql_str, valueList) log.Println("query sql:", sql_str, valueList)
} }
condition_len := 0 //所有条件数 condition_len := 0 //所有条件数
for _, ch2 := range sql_str { for _, ch2 := range sql_str {
if string(ch2) == "?" { if string(ch2) == "?" {
@ -392,6 +478,12 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
} }
for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题 for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题
if DB_PROVIDER == "PgsqlDb" {
sql_str = sqlx.Rebind(sqlx.DOLLAR, sql_str)
sql_str = ReplaeByOtherSql(sql_str, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql_str = ReplaeByOtherSql(sql_str, "DmSql", "")
}
rows, err = db.Query(sql_str, valueList...) rows, err = db.Query(sql_str, valueList...)
@ -406,6 +498,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
} }
if err != nil { if err != nil {
log.Println("DB error:", err)
rows.Close() rows.Close()
return count, info, err return count, info, err
} }
@ -425,8 +518,15 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
if rowerr == nil { if rowerr == nil {
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
} }
} }
count++ count++
@ -436,6 +536,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
} }
rows.Close() rows.Close()
if rowerr != nil { if rowerr != nil {
log.Println("DB row error:", rowerr)
return count, info, rowerr return count, info, rowerr
} }
return count, info, nil return count, info, nil
@ -447,7 +548,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
* @param dbName 数据表名 * @param dbName 数据表名
* @param title 查询字段名 * @param title 查询字段名
*/ */
func FetchRows(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) {
func FetchRows(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) {
var count int = 0 var count int = 0
list := make([]map[string]string, 0) list := make([]map[string]string, 0)
@ -455,7 +556,30 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
return count, list, errors.New("没有数据表") return count, list, errors.New("没有数据表")
} }
table := "" table := ""
if strings.Contains(table_name, "select ") {
withSql := ""
if len(with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") {
table = table_name table = table_name
} else { } else {
table = getTableName(dbName, table_name) table = getTableName(dbName, table_name)
@ -468,10 +592,11 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
} else { } else {
title = "*" title = "*"
} }
if DB_PROVIDER == "TencentDB" { if DB_PROVIDER == "TencentDB" {
sql_str = helper.StringJoin("/*slave*/ select ", title)
sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title)
} else { } else {
sql_str = helper.StringJoin("select ", title)
sql_str = helper.StringJoin(withSql, "select ", title)
} }
if alias != "" { if alias != "" {
table = helper.StringJoin(table, " as ", alias) table = helper.StringJoin(table, " as ", alias)
@ -480,17 +605,31 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
sql_str = helper.StringJoin(sql_str, " from ", table) sql_str = helper.StringJoin(sql_str, " from ", table)
if len(join) > 0 { if len(join) > 0 {
var builder strings.Builder
builder.WriteString(sql_str)
boo := false
for _, joinitem := range join { for _, joinitem := range join {
if len(joinitem) < 2 { if len(joinitem) < 2 {
continue continue
} }
if len(joinitem) == 4 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", joinitem[0], " on ", joinitem[1])
} else if len(joinitem) == 3 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql_str = helper.StringJoin(sql_str, " left join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) >= 4 {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(dbName, joinitem[0]))
} }
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql_str = builder.String()
} }
} }
@ -528,7 +667,11 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
from := strconv.Itoa((page - 1) * page_size) from := strconv.Itoa((page - 1) * page_size)
offset := strconv.Itoa(page_size) offset := strconv.Itoa(page_size)
if from != "" && offset != "" { if from != "" && offset != "" {
sql_str = helper.StringJoin(sql_str, " limit ", from, " , ", offset)
if DB_PROVIDER == "PgsqlDb" {
sql_str = helper.StringJoin(sql_str, " limit ", offset, " OFFSET ", from)
} else {
sql_str = helper.StringJoin(sql_str, " limit ", from, " , ", offset)
}
} }
} }
if debug { if debug {
@ -554,7 +697,12 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
var err error var err error
var queryNum int = 0 var queryNum int = 0
for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题 for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题
if DB_PROVIDER == "PgsqlDb" {
sql_str = sqlx.Rebind(sqlx.DOLLAR, sql_str)
sql_str = ReplaeByOtherSql(sql_str, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql_str = ReplaeByOtherSql(sql_str, "DmSql", "")
}
rows, err = db.Query(sql_str, valueList...) rows, err = db.Query(sql_str, valueList...)
if err == nil { if err == nil {
@ -590,8 +738,15 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
if rowerr == nil { if rowerr == nil {
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
} }
} }
count++ count++
@ -632,7 +787,6 @@ func GetList(dbName, table string, title string, where map[string]string, limit
if dbName == "" && table == "" { if dbName == "" && table == "" {
return list, errors.New("没有数据表") return list, errors.New("没有数据表")
} }
if strings.Contains(table, "select ") { if strings.Contains(table, "select ") {
dbName = table dbName = table
} else { } else {
@ -660,7 +814,12 @@ func GetList(dbName, table string, title string, where map[string]string, limit
from = limit["from"] from = limit["from"]
} }
if offset != "0" && from != "" { if offset != "0" && from != "" {
limitStr += " limit " + from + "," + offset
if DB_PROVIDER == "PgsqlDb" {
limitStr += " limit " + offset + " OFFSET " + from
} else {
limitStr += " limit " + from + "," + offset
}
} }
} }
@ -694,8 +853,15 @@ func GetList(dbName, table string, title string, where map[string]string, limit
} }
for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题 for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题
rows, err = DB.Query("select "+title+" from "+dbName+" where "+strings.Join(whereStr, " and ")+" "+limitStr, valueList...)
var Sql string
Sql = "select " + title + " from " + dbName + " where " + strings.Join(whereStr, " and ") + " " + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
rows, err = DB.Query(Sql, valueList...)
if err == nil { if err == nil {
break break
@ -740,8 +906,15 @@ func GetList(dbName, table string, title string, where map[string]string, limit
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
record[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
}
} }
} }
list = append(list, record) list = append(list, record)
@ -759,7 +932,6 @@ func GetTotal(dbName, table string, args ...string) (total int) {
if dbName == "" && table == "" { if dbName == "" && table == "" {
return return
} }
if strings.Contains(table, "select ") { if strings.Contains(table, "select ") {
dbName = table dbName = table
} else { } else {
@ -777,7 +949,6 @@ func GetTotal(dbName, table string, args ...string) (total int) {
var queryNum int = 0 var queryNum int = 0
for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题 for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题
rows, err = DB.Query("select count(" + title + ") as count from " + dbName + " limit 1") rows, err = DB.Query("select count(" + title + ") as count from " + dbName + " limit 1")
if err == nil { if err == nil {
@ -861,7 +1032,15 @@ func GetCount(dbName, table string, where map[string]string, args ...string) (to
for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题 for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题
rows, err = DB.Query("select count("+title+") as count from "+dbName+" where "+strings.Join(whereStr, " and ")+" limit 1", valueList...)
var Sql string
Sql = "select count(" + title + ") as count from " + dbName + " where " + strings.Join(whereStr, " and ") + " limit 1"
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
rows, err = DB.Query(Sql, valueList...)
if err == nil { if err == nil {
break break
@ -920,6 +1099,12 @@ func DoQuery(args ...interface{}) ([]map[string]string, error) {
for queryNum < 3 { //如发生错误,继续查询5次,防止数据库连接断开问题 for queryNum < 3 { //如发生错误,继续查询5次,防止数据库连接断开问题
if len(args) > 1 { if len(args) > 1 {
if DB_PROVIDER == "PgsqlDb" {
queryStr = sqlx.Rebind(sqlx.DOLLAR, queryStr)
queryStr = ReplaeByOtherSql(queryStr, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
queryStr = ReplaeByOtherSql(queryStr, "DmSql", "")
}
rows, err = DB.Query(queryStr, args[1:]...) //strings.Join(args[1:], ",") rows, err = DB.Query(queryStr, args[1:]...) //strings.Join(args[1:], ",")
if err != nil { if err != nil {
log.Println("ERROR|DoQuery error:", err) log.Println("ERROR|DoQuery error:", err)
@ -960,8 +1145,15 @@ func DoQuery(args ...interface{}) ([]map[string]string, error) {
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
record[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
record[index] = helper.ToString(col)
}
} }
} }
list = append(list, record) list = append(list, record)


+ 40
- 4
db_test.go View File

@ -6,7 +6,6 @@ import (
) )
func Test_Connet(t *testing.T) { func Test_Connet(t *testing.T) {
//go func() {
for i := 0; i < 1; i++ { for i := 0; i < 1; i++ {
dbhost := "localhost" dbhost := "localhost"
dbname := "shop" dbname := "shop"
@ -36,8 +35,8 @@ func Test_Connet(t *testing.T) {
orderby := "id desc" orderby := "id desc"
debug := true debug := true
// count, row, err := GetRow(dbname, table, alias, title, join, where, where_or, valueList, orderby, debug)
count, row, err := FetchRows(dbname, table, alias, title, join, where, where_or, valueList, orderby, "", 1, 10, debug)
//count, row, err := GetRow(dbname, table, alias, title, join, where, where_or, valueList, orderby, debug)
count, row, err := FetchRows(dbname, table, alias, title, join, join, where, where_or, valueList, orderby, "", "", 1, 10, debug)
log.Println(count) log.Println(count)
log.Println(row) log.Println(row)
@ -46,6 +45,43 @@ func Test_Connet(t *testing.T) {
log.Println(err.Error()) log.Println(err.Error())
} }
} }
//}()
}
func Test_Query(t *testing.T) {
token := "67c121aa-6e1c-011f-ebb6-976d855fd777"
dbhost := "192.168.233.134"
dbname := "canyin"
dbusername := "bin"
dbpassword := "Bin123456"
dbport := "5432"
table := "ttl_user_token"
err := PgConnect(dbhost, dbusername, dbpassword, dbname, dbport)
if err != nil {
log.Println(err.Error())
}
title := "user.*,ut.expiretime"
alias := "ut"
join := [][]string{}
join = append(join, []string{"ttl_user as user", "ut.user_id= user.id"})
where := []string{
"ut.token=?",
}
where_or := []string{}
valueList := []interface{}{
token,
}
orderby := ""
debug := true
count, row, err := GetRow(dbname, table, alias, title, [][]string{}, join, where, where_or, valueList, orderby, "", "", debug)
//count, row, err := FetchRows(dbname, table, alias, title, join, join, where, where_or, valueList, orderby, "", "", 1, 10, debug)
log.Println(count)
log.Println(row)
log.Println(err)
if err != nil {
log.Println(err.Error())
}
} }

+ 16
- 2
go.mod View File

@ -1,9 +1,23 @@
module git.tetele.net/tgo/dbquery module git.tetele.net/tgo/dbquery
go 1.14
go 1.23.0
toolchain go1.24.0
require ( require (
git.tetele.net/tgo/helper v0.1.0 git.tetele.net/tgo/helper v0.1.0
gitee.com/opengauss/openGauss-connector-go-pq v1.0.7
github.com/denisenkom/go-mssqldb v0.11.0 github.com/denisenkom/go-mssqldb v0.11.0
github.com/go-sql-driver/mysql v1.5.0
github.com/go-sql-driver/mysql v1.8.1
github.com/jmoiron/sqlx v1.4.0
)
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect
golang.org/x/crypto v0.40.0 // indirect
golang.org/x/text v0.28.0 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
) )

+ 0
- 11
go.sum View File

@ -1,11 +0,0 @@
git.tetele.net/tgo/helper v0.1.0 h1:ZdsBXUWX3+22ZzHTZRldBfBsQwu+CwUH8qScUvpgimE=
git.tetele.net/tgo/helper v0.1.0/go.mod h1:shYQE/hvMy3fOE8JXKGxvywOXiz3M5Nw4e+u7HR8+NY=
github.com/denisenkom/go-mssqldb v0.11.0 h1:9rHa233rhdOyrz2GcP9NM+gi2psgJZ4GWDpL/7ND8HI=
github.com/denisenkom/go-mssqldb v0.11.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

+ 99
- 17
prepare.go View File

@ -3,6 +3,7 @@ package dbquery
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/jmoiron/sqlx"
"log" "log"
"strings" "strings"
@ -41,8 +42,11 @@ func StmtForRead(dbName, table string, title string, where []string, limit map[s
offset = limit["offset"] offset = limit["offset"]
} }
if from != "" && offset != "" { if from != "" && offset != "" {
limitStr += " limit " + from + "," + offset
if DB_PROVIDER == "PgsqlDb" {
limitStr += " limit " + offset + " OFFSET " + from
} else {
limitStr += " limit " + from + "," + offset
}
} }
} }
@ -52,7 +56,15 @@ func StmtForRead(dbName, table string, title string, where []string, limit map[s
if len(where) > 0 { if len(where) > 0 {
// log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr) // log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr)
stmt, err = DB.Prepare("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr)
var Sql string
Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
stmt, err = DB.Prepare(Sql)
} else { } else {
// log.Println("SELECT " + title + " FROM " + dbName + limitStr) // log.Println("SELECT " + title + " FROM " + dbName + limitStr)
stmt, err = DB.Prepare("SELECT " + title + " FROM " + dbName + limitStr) stmt, err = DB.Prepare("SELECT " + title + " FROM " + dbName + limitStr)
@ -75,8 +87,6 @@ func StmtForQueryList(stmt *sql.Stmt, valuelist []interface{}) ([]map[string]str
return nil, errors.New("缺少必要参数") return nil, errors.New("缺少必要参数")
} }
// log.Println(valuelist...)
rows, err := stmt.Query(valuelist...) rows, err := stmt.Query(valuelist...)
defer stmt.Close() defer stmt.Close()
if err != nil { if err != nil {
@ -106,8 +116,15 @@ func StmtForQueryList(stmt *sql.Stmt, valuelist []interface{}) ([]map[string]str
if rowerr == nil { if rowerr == nil {
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
} }
} }
} else { } else {
@ -154,8 +171,15 @@ func StmtForQueryRow(stmt *sql.Stmt, valuelist []interface{}) (map[string]string
if rowerr == nil { if rowerr == nil {
for i, col := range values { for i, col := range values {
if col != nil { if col != nil {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
if DB_PROVIDER == "DmSql" {
//达梦返回全大写字段,需先转小写
index = helper.StrFirstToUpper(strings.ToLower(columns[i]))
//达梦返回的字段类型比较细,比如:int16、int32
info[index] = ToString(col)
} else {
index = helper.StrFirstToUpper(columns[i])
info[index] = helper.ToString(col)
}
} }
} }
} else { } else {
@ -186,8 +210,15 @@ func StmtForUpdate(dbName, table string, data []string, where []string) (*sql.St
var stmt *sql.Stmt var stmt *sql.Stmt
var err error var err error
stmt, err = DB.Prepare("update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and "))
var Sql string
Sql = "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ")
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
stmt, err = DB.Prepare(Sql)
return stmt, err return stmt, err
} }
@ -224,7 +255,41 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
var err error var err error
stmt, err = DB.Prepare("insert into " + dbName + " set " + strings.Join(data, " , "))
var sql string
if DB_PROVIDER == "PgsqlDb" || DB_PROVIDER == "DmSql" {
insert_data := []string{}
value_data := []string{}
for _, rv := range data {
dv := strings.Split(rv, "=")
if len(dv) < 2 {
return nil, errors.New("参数错误,条件值错误,=号不存在")
}
if strings.Contains(rv, "?") {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, "?")
} else {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, dv[1])
}
}
sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " RETURNING id")
}
} else {
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(data, " , "))
}
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
//stmt, err = DB.Prepare("insert into " + dbName + " set " + strings.Join(data, " , "))
stmt, err = DB.Prepare(sql)
return stmt, err return stmt, err
} }
@ -234,11 +299,23 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) {
* @return lastId error * @return lastId error
*/ */
func StmtForInsertExec(stmt *sql.Stmt, valuelist []interface{}) (int64, error) { func StmtForInsertExec(stmt *sql.Stmt, valuelist []interface{}) (int64, error) {
res, err := stmt.Exec(valuelist...)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
if DB_PROVIDER == "PgsqlDb" {
row := stmt.QueryRow(valuelist...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else {
res, err := stmt.Exec(valuelist...)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return res.LastInsertId()
} }
return res.LastInsertId()
} }
/** /**
@ -350,7 +427,12 @@ func StmtForQuery(querysql string) (*sql.Stmt, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
var err error var err error
if DB_PROVIDER == "PgsqlDb" {
querysql = sqlx.Rebind(sqlx.DOLLAR, querysql)
querysql = ReplaeByOtherSql(querysql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
querysql = ReplaeByOtherSql(querysql, "DmSql", "")
}
stmt, err = DB.Prepare(querysql) stmt, err = DB.Prepare(querysql)
return stmt, err return stmt, err


+ 102
- 30
transaction.go View File

@ -6,6 +6,8 @@ package dbquery
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"git.tetele.net/tgo/helper"
"github.com/jmoiron/sqlx"
"log" "log"
"strings" "strings"
"time" "time"
@ -25,7 +27,7 @@ func TxInsert(tx *sql.Tx, dbname, table string, data map[string]string) (int64,
if strings.Contains(table, "select ") { if strings.Contains(table, "select ") {
dbName = table dbName = table
} else { } else {
dbName = getTableName(dbName, table)
dbName = getTableName(dbname, table)
} }
if len(data) < 1 { if len(data) < 1 {
return 0, errors.New("参数错误,没有要写入的数据") return 0, errors.New("参数错误,没有要写入的数据")
@ -43,16 +45,33 @@ func TxInsert(tx *sql.Tx, dbname, table string, data map[string]string) (int64,
valueList[i] = value valueList[i] = value
i++ i++
} }
result, err := tx.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...)
if err != nil {
log.Println("ERROR", "insert into ", dbName, "error:", err)
return insertId, err
if DB_PROVIDER == "PgsqlDb" {
var Sql string
Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") values (" + strings.Join(keyStr, ",") + ")"
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add")
stmt, err = tx.Prepare(Sql)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
row := stmt.QueryRow(valueList...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else { } else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
result, err := tx.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...)
if err != nil {
log.Println("ERROR", "insert into ", dbName, "error:", err)
return insertId, err
} else {
insertId, _ = result.LastInsertId()
time.Sleep(time.Second * 2)
return insertId, nil
}
} }
} }
@ -84,28 +103,52 @@ func TxPreInsert(tx *sql.Tx, dbname, table string, data map[string]interface{})
var field []string = make([]string, len(data)) var field []string = make([]string, len(data))
var valuelist []interface{} = make([]interface{}, len(data)) var valuelist []interface{} = make([]interface{}, len(data))
insert_data := []string{}
value_data := []string{}
var i int = 0 var i int = 0
for key, item := range data { for key, item := range data {
field[i] = key + "=?" field[i] = key + "=?"
valuelist[i] = item valuelist[i] = item
i++ i++
}
insert_data = append(insert_data, key)
value_data = append(value_data, "?")
}
if DB_PROVIDER == "PgsqlDb" {
Sql := helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add")
stmt, err = tx.Prepare(Sql)
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
row := stmt.QueryRow(valuelist...)
var id int64
err = row.Scan(&id) // 扫描 RETURNING 返回的 ID
if err != nil {
return 0, errors.New("创建失败:" + err.Error())
}
return id, nil
} else {
sql := "insert into " + dbName + " set " + strings.Join(field, " , ")
stmt, err = tx.Prepare(sql)
sql := "insert into " + dbName + " set " + strings.Join(field, " , ")
if DB_PROVIDER == "DmSql" {
sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = tx.Prepare(sql)
if err != nil {
log.Println("insert prepare error:", sql, err)
return 0, errors.New("insert prepare error:" + err.Error())
}
result, err := stmt.Exec(valuelist...)
if err != nil {
log.Println("insert exec error:", sql, valuelist, err)
return 0, errors.New("insert exec error:" + err.Error())
if err != nil {
log.Println("insert prepare error:", sql, err)
return 0, errors.New("insert prepare error:" + err.Error())
}
result, err := stmt.Exec(valuelist...)
if err != nil {
log.Println("insert exec error:", sql, valuelist, err)
return 0, errors.New("insert exec error:" + err.Error())
}
insertId, _ := result.LastInsertId()
return insertId, nil
} }
insertId, _ := result.LastInsertId()
return insertId, nil
} }
@ -160,7 +203,15 @@ func TxUpdate(tx *sql.Tx, dbname, table string, data map[string]string, where ma
log.Println("ERROR", "update", dbName, "error, params empty") log.Println("ERROR", "update", dbName, "error, params empty")
return rowsAffected, errors.New("params empty") return rowsAffected, errors.New("params empty")
} }
result, err := tx.Exec("update "+dbName+" set "+strings.Join(keyList, " , ")+" where "+strings.Join(whereStr, " and "), valueList...)
var Sql string
Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ")
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := tx.Exec(Sql, valueList...)
if err != nil { if err != nil {
log.Println("ERROR", "update", dbName, "error:", err) log.Println("ERROR", "update", dbName, "error:", err)
@ -186,7 +237,7 @@ func TxPreUpdate(tx *sql.Tx, dbname, table string, data []string, where []string
if strings.Contains(table, "select ") { if strings.Contains(table, "select ") {
dbName = table dbName = table
} else { } else {
dbName = getTableName(dbName, table)
dbName = getTableName(dbname, table)
} }
if len(where) < 1 { if len(where) < 1 {
@ -198,7 +249,12 @@ func TxPreUpdate(tx *sql.Tx, dbname, table string, data []string, where []string
var stmt *sql.Stmt var stmt *sql.Stmt
sql := "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ") sql := "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ")
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = tx.Prepare(sql) stmt, err = tx.Prepare(sql)
if err != nil { if err != nil {
@ -228,7 +284,7 @@ func TxDelete(tx *sql.Tx, dbname, table string, where map[string]string, del_cou
if strings.Contains(table, "select ") { if strings.Contains(table, "select ") {
dbName = table dbName = table
} else { } else {
dbName = getTableName(dbName, table)
dbName = getTableName(dbname, table)
} }
if len(where) < 1 { if len(where) < 1 {
return count, errors.New("参数错误,没有删除条件") return count, errors.New("参数错误,没有删除条件")
@ -260,7 +316,15 @@ func TxDelete(tx *sql.Tx, dbname, table string, where map[string]string, del_cou
limitStr = " limit " + del_count[0] limitStr = " limit " + del_count[0]
} }
result, err := tx.Exec("delete from "+dbName+" where "+strings.Join(keyList, " and ")+limitStr, valueList...)
var Sql string
Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
result, err := tx.Exec(Sql, valueList...)
if err != nil { if err != nil {
log.Println("ERROR", "delete from", dbName, "error:", err) log.Println("ERROR", "delete from", dbName, "error:", err)
@ -296,7 +360,15 @@ func TxForRead(tx *sql.Tx, dbName, table string, title string, where []string) (
if len(where) > 0 { if len(where) > 0 {
// log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE") // log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE")
stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE")
var Sql string
Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE"
if DB_PROVIDER == "PgsqlDb" {
Sql = sqlx.Rebind(sqlx.DOLLAR, Sql)
Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
Sql = ReplaeByOtherSql(Sql, "DmSql", "")
}
stmt, err = tx.Prepare(Sql)
} else { } else {
// log.Println("SELECT " + title + " FROM " + dbName + " FOR UPDATE") // log.Println("SELECT " + title + " FROM " + dbName + " FOR UPDATE")
stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " FOR UPDATE") stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " FOR UPDATE")


+ 265
- 15
transaction_chain.go View File

@ -6,6 +6,7 @@ package dbquery
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/jmoiron/sqlx"
"log" "log"
"strconv" "strconv"
"strings" "strings"
@ -34,6 +35,7 @@ type TxQuery struct {
conn *sql.DB conn *sql.DB
tx *sql.Tx tx *sql.Tx
debug bool debug bool
with [][]string //[[临时表的sql语句,临时表的名称]]
} }
func NewTxQuery(t ...string) *TxQuery { func NewTxQuery(t ...string) *TxQuery {
@ -109,6 +111,14 @@ func (this *TxQuery) Where(where string) *TxQuery {
this.where = append(this.where, where) this.where = append(this.where, where)
return this return this
} }
func (this *TxQuery) With(with []string) *TxQuery {
this.with = append(this.with, with)
return this
}
func (this *TxQuery) Withs(withs [][]string) *TxQuery {
this.with = append(this.with, withs...)
return this
}
func (this *TxQuery) Wheres(wheres []string) *TxQuery { func (this *TxQuery) Wheres(wheres []string) *TxQuery {
if len(wheres) > 0 { if len(wheres) > 0 {
this.where = append(this.where, wheres...) this.where = append(this.where, wheres...)
@ -147,6 +157,26 @@ func (this *TxQuery) Join(join []string) *TxQuery {
this.join = append(this.join, join) this.join = append(this.join, join)
return this return this
} }
/**
* 左连接
* 2023/08/10
* gz
*/
func (this *TxQuery) LeftJoin(table_name string, condition string) *TxQuery {
this.join = append(this.join, []string{table_name, condition, "left"})
return this
}
/**
* 右连接
* 2023/08/10
* gz
*/
func (this *TxQuery) RightJoin(table_name string, condition string) *TxQuery {
this.join = append(this.join, []string{table_name, condition, "right"})
return this
}
func (this *TxQuery) Data(data string) *TxQuery { func (this *TxQuery) Data(data string) *TxQuery {
this.data = append(this.data, data) this.data = append(this.data, data)
return this return this
@ -177,16 +207,46 @@ func (this *TxQuery) Clean() *TxQuery {
this.save_data = this.save_data[0:0] this.save_data = this.save_data[0:0]
this.upd_field = this.upd_field[0:0] this.upd_field = this.upd_field[0:0]
this.having = "" this.having = ""
this.alias = ""
this.with = this.with[0:0]
return this return this
} }
//构造子查询
// 返回表名
func (this *TxQuery) GetTableName(table string) string {
return getTableName(this.dbname, table)
}
// 构造子查询
func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) { func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
if this.dbname == "" && this.table == "" { if this.dbname == "" && this.table == "" {
return nil, errors.New("参数错误,没有数据表") return nil, errors.New("参数错误,没有数据表")
} }
var table = "" var table = ""
if strings.Contains(this.table, "select ") {
withSql := ""
if len(this.with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range this.with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(this.table, "select ") || strings.HasPrefix(this.table, "(") {
table = this.table table = this.table
} else { } else {
table = getTableName(this.dbname, this.table) table = getTableName(this.dbname, this.table)
@ -199,7 +259,8 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
} else { } else {
title = "*" title = "*"
} }
sql = helper.StringJoin("select ", title)
sql = helper.StringJoin(withSql, "select ", title)
if this.alias != "" { if this.alias != "" {
table = helper.StringJoin(table, " as ", this.alias) table = helper.StringJoin(table, " as ", this.alias)
@ -208,15 +269,31 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
sql = helper.StringJoin(sql, " from ", table) sql = helper.StringJoin(sql, " from ", table)
if len(this.join) > 0 { if len(this.join) > 0 {
var builder strings.Builder
builder.WriteString(sql)
boo := false
for _, joinitem := range this.join { for _, joinitem := range this.join {
if len(joinitem) < 2 { if len(joinitem) < 2 {
continue continue
} }
if len(joinitem) == 3 {
sql = helper.StringJoin(sql, " ", joinitem[2], " join ", getTableName(this.dbname, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql = helper.StringJoin(sql, " left join ", getTableName(this.dbname, joinitem[0]), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(this.dbname, joinitem[0]))
} }
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql = builder.String()
} }
} }
if len(this.where) > 0 || len(this.where_or) > 0 { if len(this.where) > 0 || len(this.where_or) > 0 {
@ -255,7 +332,12 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
from := strconv.Itoa((this.page - 1) * this.page_size) from := strconv.Itoa((this.page - 1) * this.page_size)
offset := strconv.Itoa(this.page_size) offset := strconv.Itoa(this.page_size)
if from != "" && offset != "" { if from != "" && offset != "" {
sql = helper.StringJoin(sql, " limit ", from, " , ", offset)
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " limit ", offset, " OFFSET ", from)
} else {
sql = helper.StringJoin(sql, " limit ", from, " , ", offset)
}
} }
} }
@ -272,13 +354,19 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
if condition_len != len(this.value) { if condition_len != len(this.value) {
return nil, errors.New("参数错误,条件值错误") return nil, errors.New("参数错误,条件值错误")
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
return map[string]interface{}{ return map[string]interface{}{
"sql": sql, "sql": sql,
"value": this.value, "value": this.value,
}, nil }, nil
} }
//获取表格信息
// 获取表格信息
func (this *TxQuery) GetTableInfo(table string) (map[string]interface{}, error) { func (this *TxQuery) GetTableInfo(table string) (map[string]interface{}, error) {
field := []string{ field := []string{
"COLUMN_NAME", //字段名 "COLUMN_NAME", //字段名
@ -288,13 +376,50 @@ func (this *TxQuery) GetTableInfo(table string) (map[string]interface{}, error)
"COLUMN_COMMENT", //备注 "COLUMN_COMMENT", //备注
"IS_NULLABLE", //是否为空 "IS_NULLABLE", //是否为空
} }
sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ? and table_schema = ?"
sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ?"
if DB_PROVIDER == "PgsqlDb" {
//pgsql中,未加引号的标识符会被自动转换为小写
sql = `SELECT
c.column_name AS 'COLUMN_NAME',
c.column_default AS 'COLUMN_DEFAULT',
c.data_type AS 'DATA_TYPE',
c.udt_name AS 'COLUMN_TYPE',
pgdesc.description AS 'COLUMN_COMMENT',
c.is_nullable AS 'IS_NULLABLE'
FROM
information_schema.columns c
LEFT JOIN
pg_catalog.pg_statio_all_tables st ON st.schemaname = c.table_schema AND st.relname = c.table_name
LEFT JOIN
pg_catalog.pg_description pgdesc ON pgdesc.objoid = st.relid AND pgdesc.objsubid = c.ordinal_position
WHERE
c.table_name = ?`
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
} else if DB_PROVIDER == "DmSql" {
sql = `SELECT
COLUMN_NAME,
DATA_DEFAULT AS COLUMN_DEFAULT,
DATA_TYPE,
CASE
WHEN DATA_TYPE LIKE '%CHAR%' OR DATA_TYPE LIKE '%TEXT%' THEN DATA_TYPE || '(' || CHAR_LENGTH || ')'
WHEN DATA_TYPE IN ('NUMERIC', 'DECIMAL') THEN DATA_TYPE || '(' || DATA_PRECISION || ',' || DATA_SCALE || ')'
ELSE DATA_TYPE
END AS COLUMN_TYPE,
'' AS COLUMN_COMMENT,
CASE NULLABLE WHEN 'Y' THEN 'YES' WHEN 'N' THEN 'NO' END AS IS_NULLABLE
FROM
ALL_TAB_COLUMNS
WHERE
TABLE_NAME = ?`
}
stmtSql, err := this.tx.Prepare(sql) stmtSql, err := this.tx.Prepare(sql)
if err != nil { if err != nil {
return nil, err return nil, err
} }
list, err := StmtForQueryList(stmtSql, []interface{}{table, this.dbname})
list, err := StmtForQueryList(stmtSql, []interface{}{table})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -395,7 +520,12 @@ func (this *TxQuery) UpdateStmt() error {
if condition_len != len(this.value) { if condition_len != len(this.value) {
return errors.New("参数错误,条件值错误") return errors.New("参数错误,条件值错误")
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "MysqlDb" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.tx.Prepare(sql) stmt, err = this.tx.Prepare(sql)
if err != nil { if err != nil {
@ -420,8 +550,11 @@ func (this *TxQuery) UpdateAllStmt() error {
var dataSql []string //一组用到的占位字符 var dataSql []string //一组用到的占位字符
var valSql []string //占位字符组 var valSql []string //占位字符组
var updSql []string //更新字段的sql var updSql []string //更新字段的sql
var updSql_dm []string //更新字段的sql--达梦和高斯用
var updFieldLen = len(this.upd_field) //需要更新的字段数量,为0时更新除id外添加值 var updFieldLen = len(this.upd_field) //需要更新的字段数量,为0时更新除id外添加值
dataLen := len(this.save_data) dataLen := len(this.save_data)
if dataLen > 0 { if dataLen > 0 {
//批量操作 //批量操作
this.data = this.data[0:0] this.data = this.data[0:0]
@ -439,11 +572,14 @@ func (this *TxQuery) UpdateAllStmt() error {
case 0: case 0:
//预览创建数据的长度 //预览创建数据的长度
updSql = make([]string, 0, fieldLen) updSql = make([]string, 0, fieldLen)
updSql_dm = make([]string, 0, fieldLen)
default: default:
//按照需要更新字段数长度 //按照需要更新字段数长度
updSql = make([]string, 0, updFieldLen) updSql = make([]string, 0, updFieldLen)
updSql_dm = make([]string, 0, updFieldLen)
for _, k := range this.upd_field { for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段 updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段
updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段
} }
} }
for k := range this.save_data[i] { for k := range this.save_data[i] {
@ -451,6 +587,7 @@ func (this *TxQuery) UpdateAllStmt() error {
dataSql = append(dataSql, "?") //存储需要的占位符 dataSql = append(dataSql, "?") //存储需要的占位符
if updFieldLen == 0 && k != "id" { if updFieldLen == 0 && k != "id" {
updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段 updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段
updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段
} }
} }
dataSqlText = strings.Join(dataSql, ",") //组成每组占位字符格式 dataSqlText = strings.Join(dataSql, ",") //组成每组占位字符格式
@ -468,21 +605,26 @@ func (this *TxQuery) UpdateAllStmt() error {
switch updFieldLen { switch updFieldLen {
case 0: case 0:
updSql = make([]string, 0, fieldLen) updSql = make([]string, 0, fieldLen)
updSql_dm = make([]string, 0, fieldLen)
default: default:
updSql = make([]string, 0, updFieldLen) updSql = make([]string, 0, updFieldLen)
updSql_dm = make([]string, 0, updFieldLen)
for _, k := range this.upd_field { for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")") updSql = append(updSql, k+"=values("+k+")")
updSql_dm = append(updSql_dm, k+"=s."+k+"")
} }
} }
for i := 0; i < fieldLen; i++ { for i := 0; i < fieldLen; i++ {
dataSql = append(dataSql, "?") dataSql = append(dataSql, "?")
if updFieldLen == 0 && this.data[i] != "id" { if updFieldLen == 0 && this.data[i] != "id" {
updSql = append(updSql, this.data[i]+"=values("+this.data[i]+")") updSql = append(updSql, this.data[i]+"=values("+this.data[i]+")")
updSql_dm = append(updSql_dm, this.data[i]+"=s."+this.data[i]+"")
} }
} }
if updFieldLen > 0 { if updFieldLen > 0 {
for _, k := range this.upd_field { for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")") updSql = append(updSql, k+"=values("+k+")")
updSql_dm = append(updSql_dm, k+"=s."+k+"")
} }
} }
valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")") valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")")
@ -498,8 +640,44 @@ func (this *TxQuery) UpdateAllStmt() error {
if len(valSql) > 1 { if len(valSql) > 1 {
setText = " value " setText = " value "
} }
sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , ")) sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , "))
if DB_PROVIDER == "PgsqlDb" {
setText = " values "
val_field := addPrefixInField(this.data, "s.")
sql = `merge into ` + dbName + ` as t
using (
` + setText + strings.Join(valSql, ",") + `
) s (` + strings.Join(this.data, " , ") + `)
on (t.id = s.id)
when matched then
update set
` + strings.Join(updSql_dm, " , ") + `
when NOT matched then
insert (` + strings.Join(this.data, " , ") + `)
values (` + strings.Join(val_field, " , ") + `)`
} else if DB_PROVIDER == "DmSql" {
setText = " values "
val_field := addPrefixInField(this.data, "s.")
title_field := addPrefixInField(this.data, "? AS ")
sql = `MERGE INTO ` + dbName + ` AS t
USING (
SELECT
` + strings.Join(title_field, " , ") + `
FROM DUAL
) s
ON (t.id = s.id)
WHEN MATCHED THEN
UPDATE SET
` + strings.Join(updSql_dm, " , ") + `
WHEN NOT MATCHED THEN
INSERT (` + strings.Join(this.data, " , ") + `)
VALUES (` + strings.Join(val_field, " , ") + `)`
}
if this.debug { if this.debug {
log.Println("insert on duplicate key update sql:", sql, this.value) log.Println("insert on duplicate key update sql:", sql, this.value)
} }
@ -513,6 +691,13 @@ func (this *TxQuery) UpdateAllStmt() error {
if conditionLen != len(this.value) { if conditionLen != len(this.value) {
return errors.New("参数错误,条件值数量不匹配") return errors.New("参数错误,条件值数量不匹配")
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.tx.Prepare(sql) stmt, err = this.tx.Prepare(sql)
if err != nil { if err != nil {
@ -534,8 +719,32 @@ func (this *TxQuery) CreateStmt() error {
dbName := getTableName(this.dbname, this.table) dbName := getTableName(this.dbname, this.table)
var sql string var sql string
if DB_PROVIDER == "PgsqlDb" || DB_PROVIDER == "DmSql" {
insert_data := []string{}
value_data := []string{}
for _, rv := range this.data {
dv := strings.Split(rv, "=")
if len(dv) < 2 {
return errors.New("参数错误,条件值错误,=号不存在")
}
if strings.Contains(rv, "?") {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, "?")
} else {
insert_data = append(insert_data, dv[0])
value_data = append(value_data, dv[1])
}
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , "))
}
sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")")
if DB_PROVIDER == "PgsqlDb" {
sql = helper.StringJoin(sql, " RETURNING id")
}
} else {
sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , "))
}
//sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , "))
if this.debug { if this.debug {
log.Println("insert sql:", sql, this.value) log.Println("insert sql:", sql, this.value)
@ -551,7 +760,12 @@ func (this *TxQuery) CreateStmt() error {
if condition_len != len(this.value) { if condition_len != len(this.value) {
return errors.New("参数错误,条件值错误") return errors.New("参数错误,条件值错误")
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.tx.Prepare(sql) stmt, err = this.tx.Prepare(sql)
if err != nil { if err != nil {
@ -622,6 +836,11 @@ func (this *TxQuery) CreateAllStmt() error {
if len(valSql) > 1 { if len(valSql) > 1 {
setText = " value " setText = " value "
} }
if DB_PROVIDER == "PgsqlDb" {
setText = " values "
} else if DB_PROVIDER == "DmSql" {
setText = " values "
}
sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ",")) sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","))
if len(this.value) == 0 { if len(this.value) == 0 {
return errors.New("参数错误,条件值错误") return errors.New("参数错误,条件值错误")
@ -640,6 +859,12 @@ func (this *TxQuery) CreateAllStmt() error {
if conditionLen != len(this.value) { if conditionLen != len(this.value) {
return errors.New("参数错误,条件值数量不匹配") return errors.New("参数错误,条件值数量不匹配")
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "add")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.tx.Prepare(sql) stmt, err = this.tx.Prepare(sql)
@ -686,7 +911,12 @@ func (this *TxQuery) DeleteStmt() error {
if condition_len != len(this.value) { if condition_len != len(this.value) {
return errors.New("参数错误,条件值错误") return errors.New("参数错误,条件值错误")
} }
if DB_PROVIDER == "PgsqlDb" {
sql = sqlx.Rebind(sqlx.DOLLAR, sql)
sql = ReplaeByOtherSql(sql, "PgsqlDb", "")
} else if DB_PROVIDER == "DmSql" {
sql = ReplaeByOtherSql(sql, "DmSql", "")
}
stmt, err = this.tx.Prepare(sql) stmt, err = this.tx.Prepare(sql)
if err != nil { if err != nil {
@ -801,6 +1031,26 @@ func (this *TxQuery) CreateAll() (int64, error) {
return StmtForInsertExec(this.stmt, this.value) return StmtForInsertExec(this.stmt, this.value)
} }
/**
* 执行原生sql
* return error
*/
func (this *TxQuery) ExecSql(sql string) (int64, error) {
if this.debug {
log.Println("ExecSql sql:", sql)
}
stmt, err = this.tx.Prepare(sql)
if err != nil {
return 0, err
}
res, err := stmt.Exec()
if err != nil {
return 0, errors.New("执行失败:" + err.Error())
}
return res.RowsAffected()
}
/** /**
* 提交 * 提交
*/ */


Loading…
Cancel
Save