diff --git a/chain.go b/chain.go index 8ce986d..2f0eb93 100644 --- a/chain.go +++ b/chain.go @@ -71,6 +71,8 @@ func (this *Query) Db(dbname string) *Query { this.dbname = dbname if DB_PROVIDER == "PgsqlDb" { this.dbname = "" + } else if DB_PROVIDER == "DmSql" { + this.dbname = "" } return this } @@ -225,19 +227,55 @@ func (this *Query) GetTableInfo(table string) (map[string]interface{}, error) { "COLUMN_COMMENT", //备注 "IS_NULLABLE", //是否为空 } - sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ? and table_schema = ?" - if this.conn == nil { - this.conn = DB - } + sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ?" + if DB_PROVIDER == "PgsqlDb" { + //pgsql中,未加引号的标识符会被自动转换为小写 + sql = `SELECT + c.column_name AS 'COLUMN_NAME', + c.column_default AS 'COLUMN_DEFAULT', + c.data_type AS 'DATA_TYPE', + c.udt_name AS 'COLUMN_TYPE', + pgdesc.description AS 'COLUMN_COMMENT', + c.is_nullable AS 'IS_NULLABLE' + FROM + information_schema.columns c + LEFT JOIN + pg_catalog.pg_statio_all_tables st ON st.schemaname = c.table_schema AND st.relname = c.table_name + LEFT JOIN + pg_catalog.pg_description pgdesc ON pgdesc.objoid = st.relid AND pgdesc.objsubid = c.ordinal_position + WHERE + c.table_name = ?` + sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + } else if DB_PROVIDER == "DmSql" { + sql = `SELECT + COLUMN_NAME, + DATA_DEFAULT AS COLUMN_DEFAULT, + DATA_TYPE, + CASE + WHEN DATA_TYPE LIKE '%CHAR%' OR DATA_TYPE LIKE '%TEXT%' THEN DATA_TYPE || '(' || CHAR_LENGTH || ')' + WHEN DATA_TYPE IN ('NUMERIC', 'DECIMAL') THEN DATA_TYPE || '(' || DATA_PRECISION || ',' || DATA_SCALE || ')' + ELSE DATA_TYPE + END AS COLUMN_TYPE, + '' AS COLUMN_COMMENT, + CASE NULLABLE WHEN 'Y' THEN 'YES' WHEN 'N' THEN 'NO' END AS IS_NULLABLE + FROM + ALL_TAB_COLUMNS + WHERE + TABLE_NAME = ?` } + + if this.conn == nil { + this.conn = DB + } + stmtSql, err := this.conn.Prepare(sql) if err != nil { return nil, err } - list, err := StmtForQueryList(stmtSql, []interface{}{table, this.dbname}) + + list, err := StmtForQueryList(stmtSql, []interface{}{table}) if err != nil { return nil, err } @@ -254,6 +292,10 @@ func (this *Query) GetTableInfo(table string) (map[string]interface{}, error) { } for _, k := range field { index := helper.StrFirstToUpper(k) + if DB_PROVIDER == "DmSql" { + index = helper.StrFirstToUpper(strings.ToLower(k)) + } + if v, ok := item[index]; ok { switch k { case "COLUMN_NAME": @@ -442,7 +484,9 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } return map[string]interface{}{ "sql": sql, @@ -517,7 +561,9 @@ func (this *Query) UpdateStmt() error { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } stmt, err = this.conn.Prepare(sql) @@ -543,8 +589,11 @@ func (this *Query) UpdateAllStmt() error { var dataSql []string //一组用到的占位字符 var valSql []string //占位字符组 var updSql []string //更新字段的sql + var updSql_dm []string //更新字段的sql--达梦和高斯用 var updFieldLen = len(this.upd_field) //需要更新的字段数量,为0时更新除id外添加值 + dataLen := len(this.save_data) + if dataLen > 0 { //批量操作 this.data = this.data[0:0] @@ -562,11 +611,14 @@ func (this *Query) UpdateAllStmt() error { case 0: //预览创建数据的长度 updSql = make([]string, 0, fieldLen) + updSql_dm = make([]string, 0, fieldLen) default: //按照需要更新字段数长度 updSql = make([]string, 0, updFieldLen) + updSql_dm = make([]string, 0, updFieldLen) for _, k := range this.upd_field { updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段 + updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段 } } for k := range this.save_data[i] { @@ -574,6 +626,7 @@ func (this *Query) UpdateAllStmt() error { dataSql = append(dataSql, "?") //存储需要的占位符 if updFieldLen == 0 && k != "id" { updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段 + updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段 } } dataSqlText = strings.Join(dataSql, ",") //组成每组占位字符格式 @@ -591,21 +644,26 @@ func (this *Query) UpdateAllStmt() error { switch updFieldLen { case 0: updSql = make([]string, 0, fieldLen) + updSql_dm = make([]string, 0, fieldLen) default: updSql = make([]string, 0, updFieldLen) + updSql_dm = make([]string, 0, updFieldLen) for _, k := range this.upd_field { updSql = append(updSql, k+"=values("+k+")") + updSql_dm = append(updSql_dm, k+"=s."+k+"") } } for i := 0; i < fieldLen; i++ { dataSql = append(dataSql, "?") if updFieldLen == 0 && this.data[i] != "id" { updSql = append(updSql, this.data[i]+"=values("+this.data[i]+")") + updSql_dm = append(updSql_dm, this.data[i]+"=s."+this.data[i]+"") } } if updFieldLen > 0 { for _, k := range this.upd_field { updSql = append(updSql, k+"=values("+k+")") + updSql_dm = append(updSql_dm, k+"=s."+k+"") } } valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")") @@ -622,10 +680,43 @@ func (this *Query) UpdateAllStmt() error { if len(valSql) > 1 { setText = " value " } + + sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , ")) + if DB_PROVIDER == "PgsqlDb" { setText = " values " + val_field := addPrefixInField(this.data, "s.") + + sql = `merge into ` + dbName + ` as t + using ( + ` + setText + strings.Join(valSql, ",") + ` + ) s (` + strings.Join(this.data, " , ") + `) + on (t.id = s.id) + when matched then + update set + ` + strings.Join(updSql_dm, " , ") + ` + when NOT matched then + insert (` + strings.Join(this.data, " , ") + `) + values (` + strings.Join(val_field, " , ") + `)` + } else if DB_PROVIDER == "DmSql" { + setText = " values " + val_field := addPrefixInField(this.data, "s.") + title_field := addPrefixInField(this.data, "? AS ") + + sql = `MERGE INTO ` + dbName + ` AS t + USING ( + SELECT + ` + strings.Join(title_field, " , ") + ` + FROM DUAL + ) s + ON (t.id = s.id) + WHEN MATCHED THEN + UPDATE SET + ` + strings.Join(updSql_dm, " , ") + ` + WHEN NOT MATCHED THEN + INSERT (` + strings.Join(this.data, " , ") + `) + VALUES (` + strings.Join(val_field, " , ") + `)` } - sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , ")) if this.debug { log.Println("insert on duplicate key update sql:", sql, this.value) @@ -645,7 +736,9 @@ func (this *Query) UpdateAllStmt() error { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } stmt, err = this.conn.Prepare(sql) @@ -719,6 +812,8 @@ func (this *Query) CreateAllStmt() error { } if DB_PROVIDER == "PgsqlDb" { setText = " values " + } else if DB_PROVIDER == "DmSql" { + setText = " values " } sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ",")) @@ -740,7 +835,9 @@ func (this *Query) CreateAllStmt() error { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "add") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "add") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } stmt, err = this.conn.Prepare(sql) @@ -763,7 +860,7 @@ func (this *Query) CreateStmt() error { dbName := getTableName(this.dbname, this.table, this.dbtype) var sql string - if DB_PROVIDER == "PgsqlDb" { + if DB_PROVIDER == "PgsqlDb" || DB_PROVIDER == "DmSql" { insert_data := []string{} value_data := []string{} for _, rv := range this.data { @@ -780,7 +877,11 @@ func (this *Query) CreateStmt() error { } } - sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")", " RETURNING id") + sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")") + + if DB_PROVIDER == "PgsqlDb" { + sql = helper.StringJoin(sql, " RETURNING id") + } } else { sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , ")) } @@ -804,7 +905,9 @@ func (this *Query) CreateStmt() error { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } stmt, err = this.conn.Prepare(sql) @@ -856,7 +959,9 @@ func (this *Query) DeleteStmt() error { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } stmt, err = this.conn.Prepare(sql) diff --git a/chain_test.go b/chain_test.go index 4d7d2d7..212af8c 100644 --- a/chain_test.go +++ b/chain_test.go @@ -1,17 +1,144 @@ package dbquery import ( + "fmt" "testing" ) +// 测试各数据库下各种情况 func Test_Chain(t *testing.T) { - Connect("127.0.0.1", "root", "123456", "shop", "3306") + //测试数据库连接 + //err := Connect("127.0.0.1", "root", "root", "canyin", "3306") + err := PgConnect("192.168.233.145", "bin", "Bin123456", "canyin", "5432") + //err := DmConnect("192.168.233.141", "SHOPV2", "Bin123456", "", "5236") + if err != nil { + t.Log(err) + } + db_name := "" + table_name := "ttl_user_log" + //time := time.Now().Unix() - ret, err := new(Query).Db("shop").Table("ttl_order_product").Alias("op"). - Join([]string{"ttl_product as p", "op.product_id=p.id"}). - Title("op.id,op.order_price,p.thumb").WhereOr("op.id =?").WhereOr("op.id = ?").Value(63).Value(64).Debug(true).List() + //================查询表结构=========== + ret, err := new(Query).Db(db_name).GetTableInfo(table_name) + if err != nil { + t.Log(err) + } + fmt.Println("===GetTableInfo:", ret) - t.Log(len(ret)) - t.Log(ret) - t.Log(err) + //==========获取信息================= + //info, err := new(Query).Db(db_name).Clean().Table(table_name).Where("id=?").Value("3").Title("*").Find() + //info, err := new(Query).Db(db_name).Clean().Table(table_name).Clean().Alias("user").Join([]string{"ttl_user u", "u.id = user.user_id", "inner"}).Where("user.id=?").Value("3").Title("user.id,user.user_id,u.nickname").Find() + info, err := GetDataByStmt(db_name, table_name, "*", []string{"id = ?"}, []interface{}{3}, nil) + if err != nil { + t.Log(err) + } + fmt.Println("===Find:", info) + + //============获取列表================== + list, err := new(Query).Db(db_name).Clean().Table(table_name).Alias("u").Join([]string{"ttl_user user", "user.id = u.user_id", "inner"}).Where("user.id=?").Value(1).Title("user.id,user.nickname,user.status").Orderby("user.id desc").Page(1).PageSize(10).Select() + //list, err := DoQuery("select user.nickname,log.* from ttl_user user left join ttl_user_log log on user.id = log.user_id where user.id = ?", "1") + //list, err := GetListByStmt(db_name, table_name, "*", []string{"id = ?"}, []interface{}{3}, map[string]string{"id": "id asc"}) + //list, err := QueryByStmt("select * from "+table_name+" where id < ?", []interface{}{10}) + if err != nil { + t.Log(err) + } + fmt.Println("===List:", list) + + //===========添加数据============ + //insert_res, err := new(Query).Db(db_name).Clean().Table(table_name).Data("user_id=?").Value("1").Data("createtime=?").Value(time).Create() + //insert_res, err := InsertByStmt(db_name, table_name, []string{"user_id=?", "createtime=?"}, []interface{}{"1", time}) + //insert_res, err := Insert(db_name, table_name, map[string]string{"user_id": "1", "createtime": helper.ToStr(time)}) + //if err != nil { + // t.Log(err) + //} + //fmt.Println("===Insert:", insert_res) + + //================更新数据===================== + //update_res, err := new(Query).Db(db_name).Clean().Table(table_name).Data("user_id=?").Value("2").Data("createtime=?").Value(time).Where("id=?").Value("6").Update() + //update_res, err := UpdateByStmt(db_name, table_name, []string{"createtime=?", "user_id=?"}, []string{"id=?"}, []interface{}{time, 3, 6}) + //if err != nil { + // t.Log(err) + //} + //fmt.Println("===Update:", update_res) + + //=============事务================ + /*fmt.Println("================开启事务============") + tx, err := DB.Begin() + if err != nil { + t.Log(err) + } + update_log, err := TxPreUpdate(tx, db_name, table_name, []string{"createtime= ?"}, []string{"id=?"}, []interface{}{time, 2}) + if err != nil { + tx.Rollback() + t.Log(err) + } + fmt.Println("===========事务执行:==================") + fmt.Println("===事务update:", update_log) + + insert_log, err := TxPreInsert(tx, db_name, table_name, map[string]interface{}{"user_id": "1", "createtime": helper.ToStr(time)}) + if err != nil { + tx.Rollback() + t.Log(err) + } + fmt.Println("===事务insert:", insert_log) + + del_log, err := TxDelete(tx, db_name, table_name, map[string]string{"id": "2"}) + if err != nil { + tx.Rollback() + t.Log(err) + } + fmt.Println("====事务delete:", del_log) + + err = tx.Commit() + if err != nil { + t.Log(err) + tx.Rollback() + } + fmt.Println("=======事务执行完成==========")*/ + + /*fmt.Println("====================执行事务trans============") + trans := NewTxQuery().Db(db_name) + + info_trans, err := trans.Clean().Table(table_name).Where("id = ?").Value(5).Find() + if err != nil { + trans.Rollback() + t.Log(err) + } + fmt.Println("===事务Find_trans:", info_trans) + + list_trans, err := trans.Clean().Table(table_name).Title("*").Select() + if err != nil { + trans.Rollback() + t.Log(err) + } + fmt.Println("=========事务List_trans:", list_trans) + + data := map[string]interface{}{ + "user_id": 5, + "memo": "test", + "createtime": time, + } + + add_trans, err := trans.Clean().Table(table_name).SaveData(data).CreateAll() + if err != nil { + trans.Rollback() + t.Log(err) + } + + fmt.Println("======事务Add_trans:", add_trans) + + data["id"] = 15 + + update_trans_res, err := trans.Clean().Table(table_name).SaveData(data).UpdateAll() + if err != nil { + trans.Rollback() + t.Log(err) + } + fmt.Println("=======事务update_trans", update_trans_res) + err = trans.Commit() + if err != nil { + trans.Rollback() + t.Log(err) + } + fmt.Println("====================执行事务结束==================")*/ } diff --git a/common.go b/common.go index dd3c320..da8f997 100644 --- a/common.go +++ b/common.go @@ -1,21 +1,130 @@ package dbquery import ( + "fmt" "git.tetele.net/tgo/helper" + "log" + "reflect" + "regexp" + "strconv" "strings" + "time" ) -// 对执行前的sql语句进行处理--针对pgsql用 -func SqlReplace(sql_s, sql_type string) string { - sql := strings.Replace(sql_s, "`", `"`, -1) +// ===================达梦兼容=============== +// 非关键字可以不添加标识符,关键字须添加 +// 日期函数的使用TO_CHAR(TO_DATE('1970-01-01','yyyy-mm-dd') + (createtime / 86400), 'yyyy-mm-dd') +// group_concat替换成LISTAGG +// ======================================== - if sql_type == "add" { +// 关键字替换-支持达梦和高斯 +func ReplaeByOtherSql(sql, sql_type, action string) string { + sql_type_arr := []string{"DmSql", "PgsqlDb"} + if !helper.IsInStringArray(sql_type_arr, sql_type) { + log.Println("sql_type error", sql_type) + return "" + } + // PgsqlDb用 + if action == "add" { sql = helper.StringJoin(sql, " RETURNING id") } + // 定义需要处理的关键字列表 + keywords := []string{"user", "order", "group", "table", "view", "admin", "new"} + if sql_type == "PgsqlDb" { + keywords = []string{"user", "order", "group"} + } + + // 移除所有反引号 + sql = strings.ReplaceAll(sql, "`", "") + + // 使用单词边界 \b 确保只匹配完整单词 + pattern := `\b(` + strings.Join(keywords, "|") + `)\b` + re := regexp.MustCompile(pattern) + + //设置保护词组 + excludePhrases := []string{ + "order by", "group by", + } + //保护排除短语 + phraseMap := make(map[string]string) + for i, phrase := range excludePhrases { + placeholder := fmt.Sprintf("__EXCLUDE_%d__", i) + phraseMap[placeholder] = phrase + sql = strings.Replace(sql, phrase, placeholder, -1) + } - sql = strings.Replace(sql, " user ", ` "user" `, -1) - sql = strings.Replace(sql, " user.", ` "user".`, -1) - sql = strings.Replace(sql, "=user.", `="user".`, -1) + // 执行替换 + sql = re.ReplaceAllStringFunc(sql, func(match string) string { + // 检查匹配是否在字符串常量中 + if isInStringLiteral(sql, match) { + return match + } + + if sql_type == "DmSql" { + return "`" + match + "`" + } else { + return `"` + match + `"` + } + }) + + // 恢复排除短语 + for placeholder, phrase := range phraseMap { + sql = strings.Replace(sql, placeholder, phrase, -1) + } return sql } + +// 检查匹配是否在字符串常量中 +func isInStringLiteral(sql, match string) bool { + index := strings.Index(sql, match) + if index == -1 { + return false + } + + // 检查匹配前的单引号数量 + prefix := sql[:index] + singleQuotes := strings.Count(prefix, "'") - strings.Count(prefix, "\\'") + + // 奇数表示在字符串常量中 + return singleQuotes%2 != 0 +} + +// 字段值类型转换--针对达梦用 +func ToString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case []byte: + return string(v) + case int, int8, int16, int32, int64: + return strconv.FormatInt(reflect.ValueOf(value).Int(), 10) + case uint, uint8, uint16, uint32, uint64: + return strconv.FormatUint(reflect.ValueOf(value).Uint(), 10) + case float32, float64: + return strconv.FormatFloat(reflect.ValueOf(value).Float(), 'f', -1, 64) + case bool: + return strconv.FormatBool(v) + case time.Time: + return v.Format("2006-01-02 15:04:05") + default: + return fmt.Sprintf("%v", v) + } +} + +// 字符串切片元素追加前缀 +func addPrefixInField(slice []string, prefix string) []string { + new_slice := make([]string, len(slice)) + for i, v := range slice { + new_slice[i] = prefix + v + } + + return new_slice +} + +func DmFieldDeal(fields string) string { + //移除所有反引号 + title := strings.Replace(fields, "`", "", -1) + + return title +} diff --git a/conn.go b/conn.go index 003ca15..b70bdd2 100644 --- a/conn.go +++ b/conn.go @@ -2,12 +2,14 @@ package dbquery import ( "database/sql" + "fmt" "log" "errors" "strings" "time" + _ "dm" "git.tetele.net/tgo/helper" _ "gitee.com/opengauss/openGauss-connector-go-pq" // 高斯驱动(推荐)或 "github.com/lib/pq" @@ -23,39 +25,53 @@ var SLAVER_DB *sql.DB var DB_PROVIDER string func Connect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error { + log.Println("mysql database connectting...") - log.Println("database connectting...") var dbConnErr error - if DBHOST != "" && DBUSER != "" && DBPWD != "" && DBPORT != "" { //&& DBNAME != "" + const maxRetries = 10 - for i := 0; i < 10; i++ { - DB, dbConnErr = sql.Open("mysql", DBUSER+":"+DBPWD+"@tcp("+DBHOST+":"+DBPORT+")/"+DBNAME+"?charset=utf8mb4") - if dbConnErr != nil { - log.Println("ERROR", "can not connect to Database, ", dbConnErr) - time.Sleep(time.Second * 5) - } else { - if len(conns) > 0 { - DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制 - } else { - DB.SetMaxOpenConns(200) //默认值为0表示不限制 - } - if len(conns) > 1 { - DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数 - } else { - DB.SetMaxIdleConns(50) - } + if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" { + return errors.New("dm DBconnection params errors") + } - DB.Ping() + dsn := DBUSER + ":" + DBPWD + "@tcp(" + DBHOST + ":" + DBPORT + ")/" + DBNAME + "?charset=utf8mb4" - log.Println("database connected") - DB.SetConnMaxLifetime(time.Minute * 2) - break - } + for i := 0; i < maxRetries; i++ { + //Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题) + // 每次尝试创建新连接对象 + DB, dbConnErr = sql.Open("mysql", dsn) + if dbConnErr != nil { + log.Println("sql open failed:", i+1, dbConnErr) + time.Sleep(time.Second * 5) + continue } - } else { - return errors.New("db connection params errors") + + // 验证连接有效性 + if pingErr := DB.Ping(); pingErr != nil { + DB.Close() + // 记录真实错误信息 + dbConnErr = pingErr + log.Println("ping failed:", i+1, pingErr) + time.Sleep(time.Second * 5) + continue + } + + if len(conns) > 0 { + DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制 + } else { + DB.SetMaxOpenConns(200) //默认值为0表示不限制 + } + if len(conns) > 1 { + DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数 + } else { + DB.SetMaxIdleConns(50) + } + + log.Println("dm database connected") + DB.SetConnMaxLifetime(time.Minute * 2) + return nil } - return dbConnErr + return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr) } func CloseConn() error { @@ -110,6 +126,8 @@ func getTableName(dbName, table string, dbtype ...string) string { var db_type string = "mysql" if DB_PROVIDER == "PgsqlDb" { dbName = "" + } else if DB_PROVIDER == "DmSql" { + dbName = "" } if len(dbtype) > 0 { if dbtype[0] != "" { @@ -156,39 +174,110 @@ func judg() []string { return []string{"=", ">", "<", "!=", "<=", ">="} } +// pgsql连接 func PgConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error { - DB_PROVIDER = "PgsqlDb" log.Println("pg database connectting...") + var dbConnErr error - if DBHOST != "" && DBUSER != "" && DBPWD != "" && DBPORT != "" { //&& DBNAME != "" - dsn := "host=" + DBHOST + " port=" + DBPORT + " user=" + DBUSER + " password=" + DBPWD + " dbname=" + DBNAME + " sslmode=disable search_path=public" - log.Println("database dsn", dsn) - for i := 0; i < 10; i++ { - DB, dbConnErr = sql.Open("opengauss", dsn) - if dbConnErr != nil { - log.Println("ERROR", "can not connect to pg Database, ", dbConnErr) - time.Sleep(time.Second * 5) - } else { - if len(conns) > 0 { - DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制 - } else { - DB.SetMaxOpenConns(200) //默认值为0表示不限制 - } - if len(conns) > 1 { - DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数 - } else { - DB.SetMaxIdleConns(50) - } + const maxRetries = 10 + DB_PROVIDER = "PgsqlDb" - DB.Ping() + if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" { + return errors.New("dm DBconnection params errors") + } - log.Println("pg database connected") - DB.SetConnMaxLifetime(time.Minute * 2) - break - } + dsn := "host=" + DBHOST + " port=" + DBPORT + " user=" + DBUSER + " password=" + DBPWD + " dbname=" + DBNAME + " sslmode=disable search_path=public" + log.Println("database dsn", dsn) + + for i := 0; i < maxRetries; i++ { + //Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题) + // 每次尝试创建新连接对象 + DB, dbConnErr = sql.Open("opengauss", dsn) + if dbConnErr != nil { + log.Println("sql open failed:", i+1, dbConnErr) + time.Sleep(time.Second * 5) + continue } - } else { - return errors.New("db connection params errors") + + // 验证连接有效性 + if pingErr := DB.Ping(); pingErr != nil { + DB.Close() + // 记录真实错误信息 + dbConnErr = pingErr + log.Println("ping failed:", i+1, pingErr) + time.Sleep(time.Second * 5) + continue + } + + if len(conns) > 0 { + DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制 + } else { + DB.SetMaxOpenConns(200) //默认值为0表示不限制 + } + if len(conns) > 1 { + DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数 + } else { + DB.SetMaxIdleConns(50) + } + + log.Println("dm database connected") + DB.SetConnMaxLifetime(time.Minute * 2) + return nil } - return dbConnErr + + return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr) +} + +// 达梦8连接 +func DmConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error { + log.Println("DM database connectting...") + + var dbConnErr error + const maxRetries = 10 + DB_PROVIDER = "DmSql" + + if DBHOST == "" || DBUSER == "" || DBPWD == "" || DBPORT == "" { + return errors.New("dm DBconnection params errors") + } + + dsn := "dm://" + DBUSER + ":" + DBPWD + "@" + DBHOST + ":" + DBPORT + "?charSet=utf8&compatibleMode=mysql" + log.Println("database dsn", dsn) + + for i := 0; i < maxRetries; i++ { + //Open并不真正建立连接,它只是创建一个连接对象,实际的连接是延迟建立的。通过ping来检测(账号密码和网络问题) + // 每次尝试创建新连接对象 + DB, dbConnErr = sql.Open("dm", dsn) + if dbConnErr != nil { + log.Println("sql open failed:", i+1, dbConnErr) + time.Sleep(time.Second * 5) + continue + } + + // 验证连接有效性 + if pingErr := DB.Ping(); pingErr != nil { + DB.Close() + // 记录真实错误信息 + dbConnErr = pingErr + log.Println("ping failed:", i+1, pingErr) + time.Sleep(time.Second * 5) + continue + } + + if len(conns) > 0 { + DB.SetMaxOpenConns(conns[0]) //用于设置最大打开的连接数,默认值为0表示不限制 + } else { + DB.SetMaxOpenConns(200) //默认值为0表示不限制 + } + if len(conns) > 1 { + DB.SetMaxIdleConns(conns[1]) //用于设置闲置的连接数 + } else { + DB.SetMaxIdleConns(50) + } + + log.Println("dm database connected") + DB.SetConnMaxLifetime(time.Minute * 2) + return nil + } + + return fmt.Errorf("after %d attempts: %w", maxRetries, dbConnErr) } diff --git a/db.go b/db.go index ef2c30a..36fb16a 100644 --- a/db.go +++ b/db.go @@ -49,7 +49,7 @@ func Insert(dbName, table string, data map[string]string) (int64, error) { Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") values (" + strings.Join(keyStr, ",") + ")" if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "add") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add") stmt, err = DB.Prepare(Sql) if err != nil { return 0, errors.New("创建失败:" + err.Error()) @@ -130,7 +130,9 @@ func Update(dbName, table string, data map[string]string, where map[string]strin Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ") if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + Sql = ReplaeByOtherSql(Sql, "DmSql", "") } result, err := DB.Exec(Sql, valueList...) @@ -193,7 +195,9 @@ func Delete(dbName, table string, data map[string]string, del_count ...string) ( Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + Sql = ReplaeByOtherSql(Sql, "DmSql", "") } result, err := DB.Exec(Sql, valueList...) @@ -276,7 +280,9 @@ func GetData(dbName, table string, title string, where map[string]string, limit Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(keyList, " and ") + " " + limitStr if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + Sql = ReplaeByOtherSql(Sql, "DmSql", "") } rows, err = DB.Query(Sql, valueList...) @@ -310,8 +316,15 @@ func GetData(dbName, table string, title string, where map[string]string, limit if rowerr == nil { for i, col := range values { if col != nil { - index = helper.StrFirstToUpper(columns[i]) - info[index] = helper.ToString(col) + if DB_PROVIDER == "DmSql" { + //达梦返回全大写字段,需先转小写 + index = helper.StrFirstToUpper(strings.ToLower(columns[i])) + //达梦返回的字段类型比较细,比如:int16、int32 + info[index] = ToString(col) + } else { + index = helper.StrFirstToUpper(columns[i]) + info[index] = helper.ToString(col) + } } } count++ @@ -467,7 +480,9 @@ func GetRow(dbName, table_name, alias string, titles string, with, join [][]stri for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题 if DB_PROVIDER == "PgsqlDb" { sql_str = sqlx.Rebind(sqlx.DOLLAR, sql_str) - sql_str = SqlReplace(sql_str, "") + sql_str = ReplaeByOtherSql(sql_str, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql_str = ReplaeByOtherSql(sql_str, "DmSql", "") } rows, err = db.Query(sql_str, valueList...) @@ -503,8 +518,15 @@ func GetRow(dbName, table_name, alias string, titles string, with, join [][]stri if rowerr == nil { for i, col := range values { if col != nil { - index = helper.StrFirstToUpper(columns[i]) - info[index] = helper.ToString(col) + if DB_PROVIDER == "DmSql" { + //达梦返回全大写字段,需先转小写 + index = helper.StrFirstToUpper(strings.ToLower(columns[i])) + //达梦返回的字段类型比较细,比如:int16、int32 + info[index] = ToString(col) + } else { + index = helper.StrFirstToUpper(columns[i]) + info[index] = helper.ToString(col) + } } } count++ @@ -677,7 +699,9 @@ func FetchRows(dbName, table_name, alias string, titles string, with, join [][]s for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题 if DB_PROVIDER == "PgsqlDb" { sql_str = sqlx.Rebind(sqlx.DOLLAR, sql_str) - sql_str = SqlReplace(sql_str, "") + sql_str = ReplaeByOtherSql(sql_str, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql_str = ReplaeByOtherSql(sql_str, "DmSql", "") } rows, err = db.Query(sql_str, valueList...) @@ -714,8 +738,15 @@ func FetchRows(dbName, table_name, alias string, titles string, with, join [][]s if rowerr == nil { for i, col := range values { if col != nil { - index = helper.StrFirstToUpper(columns[i]) - info[index] = helper.ToString(col) + if DB_PROVIDER == "DmSql" { + //达梦返回全大写字段,需先转小写 + index = helper.StrFirstToUpper(strings.ToLower(columns[i])) + //达梦返回的字段类型比较细,比如:int16、int32 + info[index] = ToString(col) + } else { + index = helper.StrFirstToUpper(columns[i]) + info[index] = helper.ToString(col) + } } } count++ @@ -826,7 +857,9 @@ func GetList(dbName, table string, title string, where map[string]string, limit Sql = "select " + title + " from " + dbName + " where " + strings.Join(whereStr, " and ") + " " + limitStr if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + Sql = ReplaeByOtherSql(Sql, "DmSql", "") } rows, err = DB.Query(Sql, valueList...) @@ -873,8 +906,15 @@ func GetList(dbName, table string, title string, where map[string]string, limit for i, col := range values { if col != nil { - index = helper.StrFirstToUpper(columns[i]) - record[index] = helper.ToString(col) + if DB_PROVIDER == "DmSql" { + //达梦返回全大写字段,需先转小写 + index = helper.StrFirstToUpper(strings.ToLower(columns[i])) + //达梦返回的字段类型比较细,比如:int16、int32 + record[index] = ToString(col) + } else { + index = helper.StrFirstToUpper(columns[i]) + record[index] = helper.ToString(col) + } } } list = append(list, record) @@ -996,7 +1036,9 @@ func GetCount(dbName, table string, where map[string]string, args ...string) (to Sql = "select count(" + title + ") as count from " + dbName + " where " + strings.Join(whereStr, " and ") + " limit 1" if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + Sql = ReplaeByOtherSql(Sql, "DmSql", "") } rows, err = DB.Query(Sql, valueList...) @@ -1059,7 +1101,9 @@ func DoQuery(args ...interface{}) ([]map[string]string, error) { if len(args) > 1 { if DB_PROVIDER == "PgsqlDb" { queryStr = sqlx.Rebind(sqlx.DOLLAR, queryStr) - queryStr = SqlReplace(queryStr, "") + queryStr = ReplaeByOtherSql(queryStr, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + queryStr = ReplaeByOtherSql(queryStr, "DmSql", "") } rows, err = DB.Query(queryStr, args[1:]...) //strings.Join(args[1:], ",") if err != nil { @@ -1101,8 +1145,15 @@ func DoQuery(args ...interface{}) ([]map[string]string, error) { for i, col := range values { if col != nil { - index = helper.StrFirstToUpper(columns[i]) - record[index] = helper.ToString(col) + if DB_PROVIDER == "DmSql" { + //达梦返回全大写字段,需先转小写 + index = helper.StrFirstToUpper(strings.ToLower(columns[i])) + //达梦返回的字段类型比较细,比如:int16、int32 + record[index] = ToString(col) + } else { + index = helper.StrFirstToUpper(columns[i]) + record[index] = helper.ToString(col) + } } } list = append(list, record) diff --git a/go.mod b/go.mod index b0231aa..8f336fe 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module git.tetele.net/tgo/dbquery -go 1.14 +go 1.23.0 + +toolchain go1.24.0 require ( git.tetele.net/tgo/helper v0.1.0 @@ -9,3 +11,52 @@ require ( github.com/go-sql-driver/mysql v1.8.1 github.com/jmoiron/sqlx v1.4.0 ) + +require ( + cloud.google.com/go v0.26.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect + github.com/BurntSushi/toml v0.3.1 // indirect + github.com/census-instrumentation/opencensus-proto v0.2.1 // indirect + github.com/client9/misspell v0.3.4 // indirect + github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/envoyproxy/go-control-plane v0.9.4 // indirect + github.com/envoyproxy/protoc-gen-validate v0.1.0 // indirect + github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe // indirect + github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b // indirect + github.com/golang/mock v1.1.1 // indirect + github.com/golang/protobuf v1.4.2 // indirect + github.com/golang/snappy v1.0.0 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/kr/pretty v0.1.0 // indirect + github.com/kr/pty v1.1.1 // indirect + github.com/kr/text v0.1.0 // indirect + github.com/lib/pq v1.10.9 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4 // indirect + github.com/stretchr/objx v0.1.0 // indirect + github.com/stretchr/testify v1.7.0 // indirect + github.com/tjfoc/gmsm v1.4.1 // indirect + github.com/yuin/goldmark v1.4.13 // indirect + golang.org/x/crypto v0.40.0 // indirect + golang.org/x/exp v0.0.0-20190121172915-509febef88a4 // indirect + golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3 // indirect + golang.org/x/mod v0.26.0 // indirect + golang.org/x/net v0.42.0 // indirect + golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.34.0 // indirect + golang.org/x/telemetry v0.0.0-20250710130107-8d8967aff50b // indirect + golang.org/x/term v0.33.0 // indirect + golang.org/x/text v0.28.0 // indirect + golang.org/x/tools v0.35.0 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + google.golang.org/appengine v1.4.0 // indirect + google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55 // indirect + google.golang.org/grpc v1.31.0 // indirect + google.golang.org/protobuf v1.23.0 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect + honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc // indirect +) diff --git a/go.sum b/go.sum index 6c187de..aa22d14 100644 --- a/go.sum +++ b/go.sum @@ -32,10 +32,13 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= +github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -55,16 +58,20 @@ github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5Cc github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tjfoc/gmsm v1.4.1 h1:aMe1GlZb+0bLjn+cKTPEvvn9oUEBlJitaZiiBwsbgho= github.com/tjfoc/gmsm v1.4.1/go.mod h1:j4INPkHWMrhJb38G+J6W4Tw0AbuN8Thu3PbdVYhVcTE= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201012173705-84dcc777aaee/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -72,25 +79,33 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/telemetry v0.0.0-20250710130107-8d8967aff50b/go.mod h1:4ZwOYna0/zsOKwuR5X/m0QFOJpSZvAxFfkQT+Erd9D4= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/prepare.go b/prepare.go index ffd46ac..6582c9a 100644 --- a/prepare.go +++ b/prepare.go @@ -60,7 +60,9 @@ func StmtForRead(dbName, table string, title string, where []string, limit map[s Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + Sql = ReplaeByOtherSql(Sql, "DmSql", "") } stmt, err = DB.Prepare(Sql) } else { @@ -85,8 +87,6 @@ func StmtForQueryList(stmt *sql.Stmt, valuelist []interface{}) ([]map[string]str return nil, errors.New("缺少必要参数") } - // log.Println(valuelist...) - rows, err := stmt.Query(valuelist...) defer stmt.Close() if err != nil { @@ -116,8 +116,15 @@ func StmtForQueryList(stmt *sql.Stmt, valuelist []interface{}) ([]map[string]str if rowerr == nil { for i, col := range values { if col != nil { - index = helper.StrFirstToUpper(columns[i]) - info[index] = helper.ToString(col) + if DB_PROVIDER == "DmSql" { + //达梦返回全大写字段,需先转小写 + index = helper.StrFirstToUpper(strings.ToLower(columns[i])) + //达梦返回的字段类型比较细,比如:int16、int32 + info[index] = ToString(col) + } else { + index = helper.StrFirstToUpper(columns[i]) + info[index] = helper.ToString(col) + } } } } else { @@ -164,8 +171,15 @@ func StmtForQueryRow(stmt *sql.Stmt, valuelist []interface{}) (map[string]string if rowerr == nil { for i, col := range values { if col != nil { - index = helper.StrFirstToUpper(columns[i]) - info[index] = helper.ToString(col) + if DB_PROVIDER == "DmSql" { + //达梦返回全大写字段,需先转小写 + index = helper.StrFirstToUpper(strings.ToLower(columns[i])) + //达梦返回的字段类型比较细,比如:int16、int32 + info[index] = ToString(col) + } else { + index = helper.StrFirstToUpper(columns[i]) + info[index] = helper.ToString(col) + } } } } else { @@ -200,7 +214,9 @@ func StmtForUpdate(dbName, table string, data []string, where []string) (*sql.St Sql = "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ") if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + Sql = ReplaeByOtherSql(Sql, "DmSql", "") } stmt, err = DB.Prepare(Sql) @@ -240,7 +256,7 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) { var err error var sql string - if DB_PROVIDER == "PgsqlDb" { + if DB_PROVIDER == "PgsqlDb" || DB_PROVIDER == "DmSql" { insert_data := []string{} value_data := []string{} for _, rv := range data { @@ -257,13 +273,19 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) { } } - sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")", " RETURNING id") + sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")") + + if DB_PROVIDER == "PgsqlDb" { + sql = helper.StringJoin(sql, " RETURNING id") + } } else { sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(data, " , ")) } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } //stmt, err = DB.Prepare("insert into " + dbName + " set " + strings.Join(data, " , ")) @@ -407,7 +429,9 @@ func StmtForQuery(querysql string) (*sql.Stmt, error) { var err error if DB_PROVIDER == "PgsqlDb" { querysql = sqlx.Rebind(sqlx.DOLLAR, querysql) - querysql = SqlReplace(querysql, "") + querysql = ReplaeByOtherSql(querysql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + querysql = ReplaeByOtherSql(querysql, "DmSql", "") } stmt, err = DB.Prepare(querysql) diff --git a/transaction.go b/transaction.go index 60794c8..4ed6899 100644 --- a/transaction.go +++ b/transaction.go @@ -49,7 +49,7 @@ func TxInsert(tx *sql.Tx, dbname, table string, data map[string]string) (int64, var Sql string Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") values (" + strings.Join(keyStr, ",") + ")" Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "add") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add") stmt, err = tx.Prepare(Sql) if err != nil { return 0, errors.New("创建失败:" + err.Error()) @@ -114,9 +114,9 @@ func TxPreInsert(tx *sql.Tx, dbname, table string, data map[string]interface{}) value_data = append(value_data, "?") } if DB_PROVIDER == "PgsqlDb" { - Sql := helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")", " RETURNING id") + Sql := helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")") Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "add") stmt, err = tx.Prepare(Sql) if err != nil { return 0, errors.New("创建失败:" + err.Error()) @@ -131,6 +131,10 @@ func TxPreInsert(tx *sql.Tx, dbname, table string, data map[string]interface{}) } else { sql := "insert into " + dbName + " set " + strings.Join(field, " , ") + if DB_PROVIDER == "DmSql" { + sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")") + sql = ReplaeByOtherSql(sql, "DmSql", "") + } stmt, err = tx.Prepare(sql) if err != nil { @@ -203,7 +207,9 @@ func TxUpdate(tx *sql.Tx, dbname, table string, data map[string]string, where ma Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ") if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + Sql = ReplaeByOtherSql(Sql, "DmSql", "") } result, err := tx.Exec(Sql, valueList...) @@ -245,7 +251,9 @@ func TxPreUpdate(tx *sql.Tx, dbname, table string, data []string, where []string sql := "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ") if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } stmt, err = tx.Prepare(sql) @@ -312,7 +320,9 @@ func TxDelete(tx *sql.Tx, dbname, table string, where map[string]string, del_cou Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + Sql = ReplaeByOtherSql(Sql, "DmSql", "") } result, err := tx.Exec(Sql, valueList...) @@ -354,7 +364,9 @@ func TxForRead(tx *sql.Tx, dbName, table string, title string, where []string) ( Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE" if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) - Sql = SqlReplace(Sql, "") + Sql = ReplaeByOtherSql(Sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + Sql = ReplaeByOtherSql(Sql, "DmSql", "") } stmt, err = tx.Prepare(Sql) } else { diff --git a/transaction_chain.go b/transaction_chain.go index 9c68cc8..1924ef8 100644 --- a/transaction_chain.go +++ b/transaction_chain.go @@ -356,7 +356,9 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } return map[string]interface{}{ "sql": sql, @@ -374,16 +376,50 @@ func (this *TxQuery) GetTableInfo(table string) (map[string]interface{}, error) "COLUMN_COMMENT", //备注 "IS_NULLABLE", //是否为空 } - sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ? and table_schema = ?" + sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ?" + if DB_PROVIDER == "PgsqlDb" { + //pgsql中,未加引号的标识符会被自动转换为小写 + sql = `SELECT + c.column_name AS 'COLUMN_NAME', + c.column_default AS 'COLUMN_DEFAULT', + c.data_type AS 'DATA_TYPE', + c.udt_name AS 'COLUMN_TYPE', + pgdesc.description AS 'COLUMN_COMMENT', + c.is_nullable AS 'IS_NULLABLE' + FROM + information_schema.columns c + LEFT JOIN + pg_catalog.pg_statio_all_tables st ON st.schemaname = c.table_schema AND st.relname = c.table_name + LEFT JOIN + pg_catalog.pg_description pgdesc ON pgdesc.objoid = st.relid AND pgdesc.objsubid = c.ordinal_position + WHERE + c.table_name = ?` + sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + } else if DB_PROVIDER == "DmSql" { + sql = `SELECT + COLUMN_NAME, + DATA_DEFAULT AS COLUMN_DEFAULT, + DATA_TYPE, + CASE + WHEN DATA_TYPE LIKE '%CHAR%' OR DATA_TYPE LIKE '%TEXT%' THEN DATA_TYPE || '(' || CHAR_LENGTH || ')' + WHEN DATA_TYPE IN ('NUMERIC', 'DECIMAL') THEN DATA_TYPE || '(' || DATA_PRECISION || ',' || DATA_SCALE || ')' + ELSE DATA_TYPE + END AS COLUMN_TYPE, + '' AS COLUMN_COMMENT, + CASE NULLABLE WHEN 'Y' THEN 'YES' WHEN 'N' THEN 'NO' END AS IS_NULLABLE + FROM + ALL_TAB_COLUMNS + WHERE + TABLE_NAME = ?` } + stmtSql, err := this.tx.Prepare(sql) if err != nil { return nil, err } - list, err := StmtForQueryList(stmtSql, []interface{}{table, this.dbname}) + list, err := StmtForQueryList(stmtSql, []interface{}{table}) if err != nil { return nil, err } @@ -486,7 +522,9 @@ func (this *TxQuery) UpdateStmt() error { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "MysqlDb" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } stmt, err = this.tx.Prepare(sql) @@ -512,8 +550,11 @@ func (this *TxQuery) UpdateAllStmt() error { var dataSql []string //一组用到的占位字符 var valSql []string //占位字符组 var updSql []string //更新字段的sql + var updSql_dm []string //更新字段的sql--达梦和高斯用 var updFieldLen = len(this.upd_field) //需要更新的字段数量,为0时更新除id外添加值 + dataLen := len(this.save_data) + if dataLen > 0 { //批量操作 this.data = this.data[0:0] @@ -531,11 +572,14 @@ func (this *TxQuery) UpdateAllStmt() error { case 0: //预览创建数据的长度 updSql = make([]string, 0, fieldLen) + updSql_dm = make([]string, 0, fieldLen) default: //按照需要更新字段数长度 updSql = make([]string, 0, updFieldLen) + updSql_dm = make([]string, 0, updFieldLen) for _, k := range this.upd_field { updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段 + updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段 } } for k := range this.save_data[i] { @@ -543,6 +587,7 @@ func (this *TxQuery) UpdateAllStmt() error { dataSql = append(dataSql, "?") //存储需要的占位符 if updFieldLen == 0 && k != "id" { updSql = append(updSql, k+"=values("+k+")") //存储需要更新的字段 + updSql_dm = append(updSql_dm, k+"=s."+k+"") //存储需要更新的字段 } } dataSqlText = strings.Join(dataSql, ",") //组成每组占位字符格式 @@ -560,21 +605,26 @@ func (this *TxQuery) UpdateAllStmt() error { switch updFieldLen { case 0: updSql = make([]string, 0, fieldLen) + updSql_dm = make([]string, 0, fieldLen) default: updSql = make([]string, 0, updFieldLen) + updSql_dm = make([]string, 0, updFieldLen) for _, k := range this.upd_field { updSql = append(updSql, k+"=values("+k+")") + updSql_dm = append(updSql_dm, k+"=s."+k+"") } } for i := 0; i < fieldLen; i++ { dataSql = append(dataSql, "?") if updFieldLen == 0 && this.data[i] != "id" { updSql = append(updSql, this.data[i]+"=values("+this.data[i]+")") + updSql_dm = append(updSql_dm, this.data[i]+"=s."+this.data[i]+"") } } if updFieldLen > 0 { for _, k := range this.upd_field { updSql = append(updSql, k+"=values("+k+")") + updSql_dm = append(updSql_dm, k+"=s."+k+"") } } valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")") @@ -590,10 +640,43 @@ func (this *TxQuery) UpdateAllStmt() error { if len(valSql) > 1 { setText = " value " } + + sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , ")) + if DB_PROVIDER == "PgsqlDb" { setText = " values " + val_field := addPrefixInField(this.data, "s.") + + sql = `merge into ` + dbName + ` as t + using ( + ` + setText + strings.Join(valSql, ",") + ` + ) s (` + strings.Join(this.data, " , ") + `) + on (t.id = s.id) + when matched then + update set + ` + strings.Join(updSql_dm, " , ") + ` + when NOT matched then + insert (` + strings.Join(this.data, " , ") + `) + values (` + strings.Join(val_field, " , ") + `)` + } else if DB_PROVIDER == "DmSql" { + setText = " values " + val_field := addPrefixInField(this.data, "s.") + title_field := addPrefixInField(this.data, "? AS ") + + sql = `MERGE INTO ` + dbName + ` AS t + USING ( + SELECT + ` + strings.Join(title_field, " , ") + ` + FROM DUAL + ) s + ON (t.id = s.id) + WHEN MATCHED THEN + UPDATE SET + ` + strings.Join(updSql_dm, " , ") + ` + WHEN NOT MATCHED THEN + INSERT (` + strings.Join(this.data, " , ") + `) + VALUES (` + strings.Join(val_field, " , ") + `)` } - sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , ")) if this.debug { log.Println("insert on duplicate key update sql:", sql, this.value) @@ -610,8 +693,11 @@ func (this *TxQuery) UpdateAllStmt() error { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } + stmt, err = this.tx.Prepare(sql) if err != nil { @@ -633,7 +719,7 @@ func (this *TxQuery) CreateStmt() error { dbName := getTableName(this.dbname, this.table) var sql string - if DB_PROVIDER == "PgsqlDb" { + if DB_PROVIDER == "PgsqlDb" || DB_PROVIDER == "DmSql" { insert_data := []string{} value_data := []string{} for _, rv := range this.data { @@ -650,7 +736,11 @@ func (this *TxQuery) CreateStmt() error { } } - sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")", " RETURNING id") + sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")") + + if DB_PROVIDER == "PgsqlDb" { + sql = helper.StringJoin(sql, " RETURNING id") + } } else { sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , ")) } @@ -672,7 +762,9 @@ func (this *TxQuery) CreateStmt() error { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } stmt, err = this.tx.Prepare(sql) @@ -746,6 +838,8 @@ func (this *TxQuery) CreateAllStmt() error { } if DB_PROVIDER == "PgsqlDb" { setText = " values " + } else if DB_PROVIDER == "DmSql" { + setText = " values " } sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ",")) if len(this.value) == 0 { @@ -767,8 +861,11 @@ func (this *TxQuery) CreateAllStmt() error { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "add") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "add") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } + stmt, err = this.tx.Prepare(sql) if err != nil { @@ -816,7 +913,9 @@ func (this *TxQuery) DeleteStmt() error { } if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) - sql = SqlReplace(sql, "") + sql = ReplaeByOtherSql(sql, "PgsqlDb", "") + } else if DB_PROVIDER == "DmSql" { + sql = ReplaeByOtherSql(sql, "DmSql", "") } stmt, err = this.tx.Prepare(sql)