diff --git a/chain.go b/chain.go index 65720ad..af8106a 100644 --- a/chain.go +++ b/chain.go @@ -69,6 +69,9 @@ func (this *Query) Conn(conn *sql.DB) *Query { } func (this *Query) Db(dbname string) *Query { this.dbname = dbname + if DB_PROVIDER == "PgsqlDb" { + this.dbname = "" + } return this } @@ -226,6 +229,9 @@ func (this *Query) GetTableInfo(table string) (map[string]interface{}, error) { if this.conn == nil { this.conn = DB } + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } stmtSql, err := this.conn.Prepare(sql) if err != nil { return nil, err @@ -414,7 +420,11 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) { from := strconv.Itoa((this.page - 1) * this.page_size) offset := strconv.Itoa(this.page_size) if from != "" && offset != "" { - sql = helper.StringJoin(sql, " limit ", from, " , ", offset) + if DB_PROVIDER == "PgsqlDb" { + sql = helper.StringJoin(sql, " limit ", offset, " OFFSET ", from) + } else { + sql = helper.StringJoin(sql, " limit ", from, " , ", offset) + } } } if this.debug { @@ -429,6 +439,9 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) { if condition_len != len(this.value) { return nil, errors.New("参数错误,条件值错误") } + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } return map[string]interface{}{ "sql": sql, "value": this.value, @@ -501,9 +514,7 @@ func (this *Query) UpdateStmt() error { this.conn = DB } if DB_PROVIDER == "PgsqlDb" { - log.Println("PgsqlDb sql", sql) sql = sqlx.Rebind(sqlx.DOLLAR, sql) - log.Println("PgsqlDb sql", sql) } stmt, err = this.conn.Prepare(sql) @@ -608,6 +619,9 @@ func (this *Query) UpdateAllStmt() error { if len(valSql) > 1 { setText = " value " } + if DB_PROVIDER == "PgsqlDb" { + setText = " values " + } sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , ")) if this.debug { @@ -626,7 +640,9 @@ func (this *Query) UpdateAllStmt() error { if this.conn == nil { this.conn = DB } - + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } stmt, err = this.conn.Prepare(sql) if err != nil { @@ -697,6 +713,9 @@ func (this *Query) CreateAllStmt() error { if len(valSql) > 1 { setText = " value " } + if DB_PROVIDER == "PgsqlDb" { + setText = " values " + } sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ",")) if this.debug { @@ -745,10 +764,10 @@ func (this *Query) CreateStmt() error { value_data := []string{} for _, rv := range this.data { dv := strings.Split(rv, "=") - if len(dv) != 2 { + if len(dv) < 2 { return errors.New("参数错误,条件值错误,=号不存在") } - if strings.Contains(rv, "=?") { + if strings.Contains(rv, "?") { insert_data = append(insert_data, dv[0]) value_data = append(value_data, "?") } else { @@ -780,9 +799,7 @@ func (this *Query) CreateStmt() error { this.conn = DB } if DB_PROVIDER == "PgsqlDb" { - log.Println("PgsqlDb sql", sql) sql = sqlx.Rebind(sqlx.DOLLAR, sql) - log.Println("PgsqlDb sql", sql) } stmt, err = this.conn.Prepare(sql) @@ -832,7 +849,9 @@ func (this *Query) DeleteStmt() error { if this.conn == nil { this.conn = DB } - + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } stmt, err = this.conn.Prepare(sql) if err != nil { diff --git a/conn.go b/conn.go index ca4bd88..003ca15 100644 --- a/conn.go +++ b/conn.go @@ -108,7 +108,9 @@ func CloseSlaverConn() error { func getTableName(dbName, table string, dbtype ...string) string { var db_type string = "mysql" - + if DB_PROVIDER == "PgsqlDb" { + dbName = "" + } if len(dbtype) > 0 { if dbtype[0] != "" { db_type = dbtype[0] diff --git a/db.go b/db.go index 8f3361c..3d596e4 100644 --- a/db.go +++ b/db.go @@ -45,16 +45,33 @@ func Insert(dbName, table string, data map[string]string) (int64, error) { valueList[i] = value i++ } - - result, err := DB.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...) - - if err != nil { - log.Println("ERROR|插入", dbName, "数据失败,", err) - return insertId, err + var Sql string + Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") values (" + strings.Join(keyStr, ",") + ")" + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + Sql = helper.StringJoin(Sql, " RETURNING id") + stmt, err = DB.Prepare(Sql) + if err != nil { + return 0, errors.New("创建失败:" + err.Error()) + } + row := stmt.QueryRow(valueList...) + var id int64 + err = row.Scan(&id) // 扫描 RETURNING 返回的 ID + if err != nil { + return 0, errors.New("创建失败:" + err.Error()) + } + return id, nil } else { - insertId, _ = result.LastInsertId() - time.Sleep(time.Second * 2) - return insertId, nil + result, err := DB.Exec(Sql, valueList...) + + if err != nil { + log.Println("ERROR|插入", dbName, "数据失败,", err) + return insertId, err + } else { + insertId, _ = result.LastInsertId() + time.Sleep(time.Second * 2) + return insertId, nil + } } } @@ -68,7 +85,6 @@ func Update(dbName, table string, data map[string]string, where map[string]strin if dbName == "" && table == "" { return rowsAffected, errors.New("没有数据表") } - if strings.Contains(table, "select ") { dbName = table } else { @@ -110,7 +126,12 @@ func Update(dbName, table string, data map[string]string, where map[string]strin log.Println("ERROR|修改数据表", dbName, "时条件中有空数据,条件:", where, "数据:", data) return rowsAffected, errors.New("条件中有空数据") } - result, err := DB.Exec("update "+dbName+" set "+strings.Join(keyList, " , ")+" where "+strings.Join(whereStr, " and "), valueList...) + var Sql string + Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ") + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + } + result, err := DB.Exec(Sql, valueList...) if err != nil { log.Println("ERROR|修改", dbName, "数据失败,", err) @@ -132,7 +153,6 @@ func Delete(dbName, table string, data map[string]string, del_count ...string) ( if dbName == "" && table == "" { return count, errors.New("没有数据表") } - if strings.Contains(table, "select ") { dbName = table } else { @@ -168,7 +188,12 @@ func Delete(dbName, table string, data map[string]string, del_count ...string) ( limitStr = " limit " + del_count[0] } - result, err := DB.Exec("delete from "+dbName+" where "+strings.Join(keyList, " and ")+limitStr, valueList...) + var Sql string + Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + } + result, err := DB.Exec(Sql, valueList...) if err != nil { log.Println("ERROR|删除", dbName, "数据失败,", err) @@ -192,7 +217,6 @@ func GetData(dbName, table string, title string, where map[string]string, limit if dbName == "" && table == "" { return count, info, errors.New("没有数据表") } - dbName = getTableName(dbName, table) if len(title) < 1 { @@ -211,7 +235,11 @@ func GetData(dbName, table string, title string, where map[string]string, limit if _, ok := limit["from"]; ok { from = limit["from"] } - limitStr += " limit " + from + ",1" + if DB_PROVIDER == "PgsqlDb" { + limitStr += " limit 1 OFFSET " + from + } else { + limitStr += " limit " + from + ",1" + } } else { limitStr = " limit 1" @@ -242,8 +270,12 @@ func GetData(dbName, table string, title string, where map[string]string, limit var err error var queryNum int = 0 for queryNum < 3 { //如发生错误,继续查询3次,防止数据库连接断开问题 - - rows, err = DB.Query("SELECT "+title+" FROM "+dbName+" where "+strings.Join(keyList, " and ")+" "+limitStr, valueList...) + var Sql string + Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(keyList, " and ") + " " + limitStr + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + } + rows, err = DB.Query(Sql, valueList...) if err == nil { break @@ -303,7 +335,6 @@ func GetRow(dbName, table_name, alias string, titles string, with, join [][]stri if dbName == "" && table_name == "" { return count, info, errors.New("没有数据表") } - table := "" withSql := "" if len(with) > 0 { @@ -432,9 +463,7 @@ func GetRow(dbName, table_name, alias string, titles string, with, join [][]stri for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题 if DB_PROVIDER == "PgsqlDb" { - log.Println("PgsqlDb sql_str", sql_str) sql_str = sqlx.Rebind(sqlx.DOLLAR, sql_str) - log.Println("PgsqlDb sql_str", sql_str) } rows, err = db.Query(sql_str, valueList...) @@ -642,9 +671,7 @@ func FetchRows(dbName, table_name, alias string, titles string, with, join [][]s var queryNum int = 0 for queryNum < 2 { //如发生错误,继续查询2次,防止数据库连接断开问题 if DB_PROVIDER == "PgsqlDb" { - log.Println("PgsqlDb sql_str", sql_str) sql_str = sqlx.Rebind(sqlx.DOLLAR, sql_str) - log.Println("PgsqlDb sql_str", sql_str) } rows, err = db.Query(sql_str, valueList...) @@ -723,7 +750,6 @@ func GetList(dbName, table string, title string, where map[string]string, limit if dbName == "" && table == "" { return list, errors.New("没有数据表") } - if strings.Contains(table, "select ") { dbName = table } else { @@ -751,7 +777,12 @@ func GetList(dbName, table string, title string, where map[string]string, limit from = limit["from"] } if offset != "0" && from != "" { - limitStr += " limit " + from + "," + offset + + if DB_PROVIDER == "PgsqlDb" { + limitStr += " limit " + offset + " OFFSET " + from + } else { + limitStr += " limit " + from + "," + offset + } } } @@ -785,8 +816,12 @@ func GetList(dbName, table string, title string, where map[string]string, limit } for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题 - - rows, err = DB.Query("select "+title+" from "+dbName+" where "+strings.Join(whereStr, " and ")+" "+limitStr, valueList...) + var Sql string + Sql = "select " + title + " from " + dbName + " where " + strings.Join(whereStr, " and ") + " " + limitStr + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + } + rows, err = DB.Query(Sql, valueList...) if err == nil { break @@ -850,7 +885,6 @@ func GetTotal(dbName, table string, args ...string) (total int) { if dbName == "" && table == "" { return } - if strings.Contains(table, "select ") { dbName = table } else { @@ -868,7 +902,6 @@ func GetTotal(dbName, table string, args ...string) (total int) { var queryNum int = 0 for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题 - rows, err = DB.Query("select count(" + title + ") as count from " + dbName + " limit 1") if err == nil { @@ -952,7 +985,12 @@ func GetCount(dbName, table string, where map[string]string, args ...string) (to for queryNum < 5 { //如发生错误,继续查询5次,防止数据库连接断开问题 - rows, err = DB.Query("select count("+title+") as count from "+dbName+" where "+strings.Join(whereStr, " and ")+" limit 1", valueList...) + var Sql string + Sql = "select count(" + title + ") as count from " + dbName + " where " + strings.Join(whereStr, " and ") + " limit 1" + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + } + rows, err = DB.Query(Sql, valueList...) if err == nil { break @@ -1011,6 +1049,9 @@ func DoQuery(args ...interface{}) ([]map[string]string, error) { for queryNum < 3 { //如发生错误,继续查询5次,防止数据库连接断开问题 if len(args) > 1 { + if DB_PROVIDER == "PgsqlDb" { + queryStr = sqlx.Rebind(sqlx.DOLLAR, queryStr) + } rows, err = DB.Query(queryStr, args[1:]...) //strings.Join(args[1:], ",") if err != nil { log.Println("ERROR|DoQuery error:", err) diff --git a/prepare.go b/prepare.go index 26955ed..6194066 100644 --- a/prepare.go +++ b/prepare.go @@ -3,6 +3,7 @@ package dbquery import ( "database/sql" "errors" + "github.com/jmoiron/sqlx" "log" "strings" @@ -41,8 +42,11 @@ func StmtForRead(dbName, table string, title string, where []string, limit map[s offset = limit["offset"] } if from != "" && offset != "" { - - limitStr += " limit " + from + "," + offset + if DB_PROVIDER == "PgsqlDb" { + limitStr += " limit " + offset + " OFFSET " + from + } else { + limitStr += " limit " + from + "," + offset + } } } @@ -52,7 +56,12 @@ func StmtForRead(dbName, table string, title string, where []string, limit map[s if len(where) > 0 { // log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr) - stmt, err = DB.Prepare("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr) + var Sql string + Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + limitStr + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + } + stmt, err = DB.Prepare(Sql) } else { // log.Println("SELECT " + title + " FROM " + dbName + limitStr) stmt, err = DB.Prepare("SELECT " + title + " FROM " + dbName + limitStr) @@ -186,8 +195,12 @@ func StmtForUpdate(dbName, table string, data []string, where []string) (*sql.St var stmt *sql.Stmt var err error - - stmt, err = DB.Prepare("update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ")) + var Sql string + Sql = "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ") + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + } + stmt, err = DB.Prepare(Sql) return stmt, err } @@ -224,7 +237,34 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) { var stmt *sql.Stmt var err error - stmt, err = DB.Prepare("insert into " + dbName + " set " + strings.Join(data, " , ")) + var sql string + if DB_PROVIDER == "PgsqlDb" { + insert_data := []string{} + value_data := []string{} + for _, rv := range data { + dv := strings.Split(rv, "=") + if len(dv) < 2 { + return nil, errors.New("参数错误,条件值错误,=号不存在") + } + if strings.Contains(rv, "?") { + insert_data = append(insert_data, dv[0]) + value_data = append(value_data, "?") + } else { + insert_data = append(insert_data, dv[0]) + value_data = append(value_data, dv[1]) + } + + } + sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")", " RETURNING id") + } else { + sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(data, " , ")) + } + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } + + //stmt, err = DB.Prepare("insert into " + dbName + " set " + strings.Join(data, " , ")) + stmt, err = DB.Prepare(sql) return stmt, err } @@ -362,7 +402,9 @@ func StmtForQuery(querysql string) (*sql.Stmt, error) { var stmt *sql.Stmt var err error - + if DB_PROVIDER == "PgsqlDb" { + querysql = sqlx.Rebind(sqlx.DOLLAR, querysql) + } stmt, err = DB.Prepare(querysql) return stmt, err diff --git a/transaction.go b/transaction.go index b755243..bd6cd01 100644 --- a/transaction.go +++ b/transaction.go @@ -6,6 +6,8 @@ package dbquery import ( "database/sql" "errors" + "git.tetele.net/tgo/helper" + "github.com/jmoiron/sqlx" "log" "strings" "time" @@ -43,16 +45,33 @@ func TxInsert(tx *sql.Tx, dbname, table string, data map[string]string) (int64, valueList[i] = value i++ } - - result, err := tx.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...) - - if err != nil { - log.Println("ERROR", "insert into ", dbName, "error:", err) - return insertId, err + if DB_PROVIDER == "PgsqlDb" { + var Sql string + Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") value(" + strings.Join(keyStr, ",") + ")" + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + Sql = helper.StringJoin(Sql, " RETURNING id") + stmt, err = tx.Prepare(Sql) + if err != nil { + return 0, errors.New("创建失败:" + err.Error()) + } + row := stmt.QueryRow(valueList...) + var id int64 + err = row.Scan(&id) // 扫描 RETURNING 返回的 ID + if err != nil { + return 0, errors.New("创建失败:" + err.Error()) + } + return id, nil } else { - insertId, _ = result.LastInsertId() - time.Sleep(time.Second * 2) - return insertId, nil + result, err := tx.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...) + + if err != nil { + log.Println("ERROR", "insert into ", dbName, "error:", err) + return insertId, err + } else { + insertId, _ = result.LastInsertId() + time.Sleep(time.Second * 2) + return insertId, nil + } } } @@ -84,28 +103,47 @@ func TxPreInsert(tx *sql.Tx, dbname, table string, data map[string]interface{}) var field []string = make([]string, len(data)) var valuelist []interface{} = make([]interface{}, len(data)) - + insert_data := []string{} + value_data := []string{} var i int = 0 for key, item := range data { field[i] = key + "=?" valuelist[i] = item i++ - } + insert_data = append(insert_data, key) + value_data = append(value_data, "?") + } + if DB_PROVIDER == "PgsqlDb" { + Sql := helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")", " RETURNING id") + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + stmt, err = tx.Prepare(Sql) + if err != nil { + return 0, errors.New("创建失败:" + err.Error()) + } + row := stmt.QueryRow(valuelist...) + var id int64 + err = row.Scan(&id) // 扫描 RETURNING 返回的 ID + if err != nil { + return 0, errors.New("创建失败:" + err.Error()) + } + return id, nil + } else { - sql := "insert into " + dbName + " set " + strings.Join(field, " , ") - stmt, err = tx.Prepare(sql) + sql := "insert into " + dbName + " set " + strings.Join(field, " , ") + stmt, err = tx.Prepare(sql) - if err != nil { - log.Println("insert prepare error:", sql, err) - return 0, errors.New("insert prepare error:" + err.Error()) - } - result, err := stmt.Exec(valuelist...) - if err != nil { - log.Println("insert exec error:", sql, valuelist, err) - return 0, errors.New("insert exec error:" + err.Error()) + if err != nil { + log.Println("insert prepare error:", sql, err) + return 0, errors.New("insert prepare error:" + err.Error()) + } + result, err := stmt.Exec(valuelist...) + if err != nil { + log.Println("insert exec error:", sql, valuelist, err) + return 0, errors.New("insert exec error:" + err.Error()) + } + insertId, _ := result.LastInsertId() + return insertId, nil } - insertId, _ := result.LastInsertId() - return insertId, nil } @@ -160,7 +198,12 @@ func TxUpdate(tx *sql.Tx, dbname, table string, data map[string]string, where ma log.Println("ERROR", "update", dbName, "error, params empty") return rowsAffected, errors.New("params empty") } - result, err := tx.Exec("update "+dbName+" set "+strings.Join(keyList, " , ")+" where "+strings.Join(whereStr, " and "), valueList...) + var Sql string + Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ") + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + } + result, err := tx.Exec(Sql, valueList...) if err != nil { log.Println("ERROR", "update", dbName, "error:", err) @@ -198,7 +241,9 @@ func TxPreUpdate(tx *sql.Tx, dbname, table string, data []string, where []string var stmt *sql.Stmt sql := "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ") - + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } stmt, err = tx.Prepare(sql) if err != nil { @@ -260,7 +305,12 @@ func TxDelete(tx *sql.Tx, dbname, table string, where map[string]string, del_cou limitStr = " limit " + del_count[0] } - result, err := tx.Exec("delete from "+dbName+" where "+strings.Join(keyList, " and ")+limitStr, valueList...) + var Sql string + Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + } + result, err := tx.Exec(Sql, valueList...) if err != nil { log.Println("ERROR", "delete from", dbName, "error:", err) @@ -296,7 +346,12 @@ func TxForRead(tx *sql.Tx, dbName, table string, title string, where []string) ( if len(where) > 0 { // log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE") - stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE") + var Sql string + Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE" + if DB_PROVIDER == "PgsqlDb" { + Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) + } + stmt, err = tx.Prepare(Sql) } else { // log.Println("SELECT " + title + " FROM " + dbName + " FOR UPDATE") stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " FOR UPDATE") diff --git a/transaction_chain.go b/transaction_chain.go index cce1127..03b51ae 100644 --- a/transaction_chain.go +++ b/transaction_chain.go @@ -6,6 +6,7 @@ package dbquery import ( "database/sql" "errors" + "github.com/jmoiron/sqlx" "log" "strconv" "strings" @@ -331,7 +332,12 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) { from := strconv.Itoa((this.page - 1) * this.page_size) offset := strconv.Itoa(this.page_size) if from != "" && offset != "" { - sql = helper.StringJoin(sql, " limit ", from, " , ", offset) + if DB_PROVIDER == "PgsqlDb" { + sql = helper.StringJoin(sql, " limit ", offset, " OFFSET ", from) + } else { + sql = helper.StringJoin(sql, " limit ", from, " , ", offset) + } + } } @@ -348,6 +354,9 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) { if condition_len != len(this.value) { return nil, errors.New("参数错误,条件值错误") } + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } return map[string]interface{}{ "sql": sql, "value": this.value, @@ -365,7 +374,9 @@ func (this *TxQuery) GetTableInfo(table string) (map[string]interface{}, error) "IS_NULLABLE", //是否为空 } sql := "select `" + strings.Join(field, "`,`") + "` from information_schema.COLUMNS where table_name = ? and table_schema = ?" - + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } stmtSql, err := this.tx.Prepare(sql) if err != nil { return nil, err @@ -471,7 +482,9 @@ func (this *TxQuery) UpdateStmt() error { if condition_len != len(this.value) { return errors.New("参数错误,条件值错误") } - + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } stmt, err = this.tx.Prepare(sql) if err != nil { @@ -574,6 +587,9 @@ func (this *TxQuery) UpdateAllStmt() error { if len(valSql) > 1 { setText = " value " } + if DB_PROVIDER == "PgsqlDb" { + setText = " values " + } sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , ")) if this.debug { @@ -589,6 +605,9 @@ func (this *TxQuery) UpdateAllStmt() error { if conditionLen != len(this.value) { return errors.New("参数错误,条件值数量不匹配") } + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } stmt, err = this.tx.Prepare(sql) if err != nil { @@ -610,8 +629,28 @@ func (this *TxQuery) CreateStmt() error { dbName := getTableName(this.dbname, this.table) var sql string + if DB_PROVIDER == "PgsqlDb" { + insert_data := []string{} + value_data := []string{} + for _, rv := range this.data { + dv := strings.Split(rv, "=") + if len(dv) < 2 { + return errors.New("参数错误,条件值错误,=号不存在") + } + if strings.Contains(rv, "?") { + insert_data = append(insert_data, dv[0]) + value_data = append(value_data, "?") + } else { + insert_data = append(insert_data, dv[0]) + value_data = append(value_data, dv[1]) + } - sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , ")) + } + sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")", " RETURNING id") + } else { + sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , ")) + } + //sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , ")) if this.debug { log.Println("insert sql:", sql, this.value) @@ -627,7 +666,9 @@ func (this *TxQuery) CreateStmt() error { if condition_len != len(this.value) { return errors.New("参数错误,条件值错误") } - + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } stmt, err = this.tx.Prepare(sql) if err != nil { @@ -698,6 +739,9 @@ func (this *TxQuery) CreateAllStmt() error { if len(valSql) > 1 { setText = " value " } + if DB_PROVIDER == "PgsqlDb" { + setText = " values " + } sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ",")) if len(this.value) == 0 { return errors.New("参数错误,条件值错误") @@ -716,7 +760,10 @@ func (this *TxQuery) CreateAllStmt() error { if conditionLen != len(this.value) { return errors.New("参数错误,条件值数量不匹配") } - + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + sql = helper.StringJoin(sql, " RETURNING id") + } stmt, err = this.tx.Prepare(sql) if err != nil { @@ -762,7 +809,9 @@ func (this *TxQuery) DeleteStmt() error { if condition_len != len(this.value) { return errors.New("参数错误,条件值错误") } - + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + } stmt, err = this.tx.Prepare(sql) if err != nil {