/** * DB transaction */ package dbquery import ( "database/sql" "errors" "git.tetele.net/tgo/helper" "github.com/jmoiron/sqlx" "log" "strings" "time" ) /** * 创建数据 */ func TxInsert(tx *sql.Tx, dbname, table string, data map[string]string) (int64, error) { var insertId int64 var err error if dbname == "" && table == "" { return 0, errors.New("参数错误,没有数据表") } dbName := "" if strings.Contains(table, "select ") { dbName = table } else { dbName = getTableName(dbname, table) } if len(data) < 1 { return 0, errors.New("参数错误,没有要写入的数据") } keyList := make([]string, len(data)) keyStr := make([]string, len(data)) valueList := make([]interface{}, len(data)) var i int = 0 for key, value := range data { keyList[i] = key keyStr[i] = "?" valueList[i] = value i++ } if DB_PROVIDER == "PgsqlDb" { var Sql string Sql = "insert into " + dbName + " (" + strings.Join(keyList, ",") + ") values (" + strings.Join(keyStr, ",") + ")" Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) Sql = strings.Replace(Sql, "`", `"`, -1) Sql = helper.StringJoin(Sql, " RETURNING id") stmt, err = tx.Prepare(Sql) if err != nil { return 0, errors.New("创建失败:" + err.Error()) } row := stmt.QueryRow(valueList...) var id int64 err = row.Scan(&id) // 扫描 RETURNING 返回的 ID if err != nil { return 0, errors.New("创建失败:" + err.Error()) } return id, nil } else { result, err := tx.Exec("insert into "+dbName+" ("+strings.Join(keyList, ",")+") value("+strings.Join(keyStr, ",")+")", valueList...) if err != nil { log.Println("ERROR", "insert into ", dbName, "error:", err) return insertId, err } else { insertId, _ = result.LastInsertId() time.Sleep(time.Second * 2) return insertId, nil } } } /** * 准备写入 * return Stmt error */ func TxPreInsert(tx *sql.Tx, dbname, table string, data map[string]interface{}) (int64, error) { if dbname == "" && table == "" { return 0, errors.New("params error,no db or table") } dbName := "" if strings.Contains(table, "select ") { dbName = table } else { dbName = getTableName(dbname, table) } if len(data) < 1 { return 0, errors.New("params error,no data to insert") } var err error var stmt *sql.Stmt var field []string = make([]string, len(data)) var valuelist []interface{} = make([]interface{}, len(data)) insert_data := []string{} value_data := []string{} var i int = 0 for key, item := range data { field[i] = key + "=?" valuelist[i] = item i++ insert_data = append(insert_data, key) value_data = append(value_data, "?") } if DB_PROVIDER == "PgsqlDb" { Sql := helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")", " RETURNING id") Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) Sql = strings.Replace(Sql, "`", `"`, -1) stmt, err = tx.Prepare(Sql) if err != nil { return 0, errors.New("创建失败:" + err.Error()) } row := stmt.QueryRow(valuelist...) var id int64 err = row.Scan(&id) // 扫描 RETURNING 返回的 ID if err != nil { return 0, errors.New("创建失败:" + err.Error()) } return id, nil } else { sql := "insert into " + dbName + " set " + strings.Join(field, " , ") stmt, err = tx.Prepare(sql) if err != nil { log.Println("insert prepare error:", sql, err) return 0, errors.New("insert prepare error:" + err.Error()) } result, err := stmt.Exec(valuelist...) if err != nil { log.Println("insert exec error:", sql, valuelist, err) return 0, errors.New("insert exec error:" + err.Error()) } insertId, _ := result.LastInsertId() return insertId, nil } } /** * 修改数据 */ func TxUpdate(tx *sql.Tx, dbname, table string, data map[string]string, where map[string]string) (int64, error) { var rowsAffected int64 var err error if dbname == "" && table == "" { return rowsAffected, errors.New("参数错误,没有数据表") } dbName := "" if strings.Contains(table, "select ") { dbName = table } else { dbName = getTableName(dbname, table) } if len(data) < 1 { return rowsAffected, errors.New("参数错误,没有要写入的数据") } if len(where) < 1 { return rowsAffected, errors.New("参数错误,没有修改条件") } keyList := make([]string, len(data)) valueList := make([]interface{}, len(data), len(data)+len(where)) whereStr := make([]string, len(where)) var i int = 0 empty := false for key, value := range data { keyList[i] = key + "=?" valueList[i] = value i++ } i = 0 for key, value := range where { if value == "" { empty = true break } whereStr[i] = key + "=?" valueList = append(valueList, value) i++ } if empty { log.Println("ERROR", "update", dbName, "error, params empty") return rowsAffected, errors.New("params empty") } var Sql string Sql = "update " + dbName + " set " + strings.Join(keyList, " , ") + " where " + strings.Join(whereStr, " and ") if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) Sql = strings.Replace(Sql, "`", `"`, -1) } result, err := tx.Exec(Sql, valueList...) if err != nil { log.Println("ERROR", "update", dbName, "error:", err) return rowsAffected, err } else { rowsAffected, _ = result.RowsAffected() return rowsAffected, nil } } /** * 准备更新 * return Stmt error */ func TxPreUpdate(tx *sql.Tx, dbname, table string, data []string, where []string, valuelist []interface{}) (int64, error) { if dbname == "" && table == "" { return 0, errors.New("params error,no db or table") } dbName := "" if strings.Contains(table, "select ") { dbName = table } else { dbName = getTableName(dbname, table) } if len(where) < 1 { return 0, errors.New("params error, no data for update") } var err error var stmt *sql.Stmt sql := "update " + dbName + " set " + strings.Join(data, " , ") + " where " + strings.Join(where, " and ") if DB_PROVIDER == "PgsqlDb" { sql = sqlx.Rebind(sqlx.DOLLAR, sql) sql = strings.Replace(sql, "`", `"`, -1) } stmt, err = tx.Prepare(sql) if err != nil { log.Println("update prepare error:", sql, err) return 0, errors.New("update prepare error:" + err.Error()) } res, err := stmt.Exec(valuelist...) if err != nil { log.Println("update exec error:", sql, valuelist, err) return 0, errors.New("update exec error:" + err.Error()) } return res.RowsAffected() } /** * 删除数据 * @param count 删除数量 */ func TxDelete(tx *sql.Tx, dbname, table string, where map[string]string, del_count ...string) (int64, error) { var count int64 var err error if dbname == "" && table == "" { return count, errors.New("参数错误,没有数据表") } dbName := "" if strings.Contains(table, "select ") { dbName = table } else { dbName = getTableName(dbname, table) } if len(where) < 1 { return count, errors.New("参数错误,没有删除条件") } keyList := make([]string, len(where)) valueList := make([]interface{}, len(where)) var i int = 0 empty := false for key, value := range where { if value == "" { empty = true break } keyList[i] = key + "=?" valueList[i] = value i++ } if empty { log.Println("ERROR", "delete from", dbName, "error, where:", where) return count, errors.New("params empty") } var limitStr string = "" if len(del_count) > 0 { limitStr = " limit " + del_count[0] } var Sql string Sql = "delete from " + dbName + " where " + strings.Join(keyList, " and ") + limitStr if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) Sql = strings.Replace(Sql, "`", `"`, -1) } result, err := tx.Exec(Sql, valueList...) if err != nil { log.Println("ERROR", "delete from", dbName, "error:", err) return count, err } else { count, _ = result.RowsAffected() return count, nil } } /** * 准备查询 * return Stmt error */ func TxForRead(tx *sql.Tx, dbName, table string, title string, where []string) (*sql.Stmt, error) { if dbName == "" && table == "" { return nil, errors.New("参数错误,没有数据表") } if strings.Contains(table, "select ") { dbName = table } else { dbName = getTableName(dbName, table) } if len(title) < 1 { return nil, errors.New("没有要查询内容") } var stmt *sql.Stmt var err error if len(where) > 0 { // log.Println("SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE") var Sql string Sql = "SELECT " + title + " FROM " + dbName + " where " + strings.Join(where, " and ") + " FOR UPDATE" if DB_PROVIDER == "PgsqlDb" { Sql = sqlx.Rebind(sqlx.DOLLAR, Sql) Sql = strings.Replace(Sql, "`", `"`, -1) } stmt, err = tx.Prepare(Sql) } else { // log.Println("SELECT " + title + " FROM " + dbName + " FOR UPDATE") stmt, err = tx.Prepare("SELECT " + title + " FROM " + dbName + " FOR UPDATE") } return stmt, err } /** * 使用db prepare方式查询单条数据 * @param dbName * @param title 查询的字段名 * @param where 查询条件 * @param valuelist 查询的条件值 * @param limit 查询排序 * GZ * 2020/05/19 */ func TxGetData(tx *sql.Tx, dbName string, table string, title string, where []string, valuelist []interface{}) (map[string]string, error) { stmt, err := TxForRead(tx, dbName, table, title, where) if err != nil { return nil, err } defer stmt.Close() return StmtForQueryRow(stmt, valuelist) }