diff --git a/chain_test.go b/chain_test.go index c761561..d62706b 100644 --- a/chain_test.go +++ b/chain_test.go @@ -122,9 +122,10 @@ func Test_Chain(t *testing.T) { Where("a.user_id =?").Value(6006) info, err := query.Groupby("a.id").Title("a.id").BuildSelectSql()*/ //info, err := new(Query).Db(db_name).Clean().Table(table_name).Clean().Alias("user").Join([]string{"ttl_user u", "u.id = user.user_id", "inner"}).Where("user.id=?").Value("3").Title("user.id,user.user_id,u.nickname").Find() - //info, err := GetDataByStmt(db_name, table_name, "*", []string{"id = ?"}, []interface{}{3}, nil) + //info, err := new(Query).Db(db_name).Clean().Table(table_name).Clean().Alias("p").Where("p.id=?").Value("1").Title("p.id,p.view_count,p.title,p.thumb,p.price,p.exchange_score_spend,p.exchange_score_value,p.tax_rate").Find() + /*info, err := GetDataByStmt(db_name, table_name, "id,exchange_score_value,exchange_score_spend,tax_rate", []string{"id = ?"}, []interface{}{1}, nil) - /*if err != nil { + if err != nil { t.Log(err) } fmt.Println("===Find:", info)*/ @@ -144,7 +145,8 @@ func Test_Chain(t *testing.T) { Where("p.id=?").Value(1). Orderby("p.view_count desc"). Page(1).PageSize(10). - Title("p.id,p.view_count,p.title,p.thumb,p.price,p.exchange_score_spend,p.exchange_score_value"). + Title("p.id,p.view_count,p.title,p.thumb,p.price,p.exchange_score_spend,p.exchange_score_value,p.tax_rate"). + //Title("p.*"). Select() if err != nil { t.Log(err) @@ -267,15 +269,3 @@ func Test_Chain(t *testing.T) { } fmt.Println("====================执行事务结束==================")*/ } - -/*func Test_Data(t *testing.T) { - aa := 0.50 - bb := 15.5 - - res := ToString(aa) - resb := ToString(bb) - - fmt.Println(res) - fmt.Println(resb) - fmt.Println(fmt.Sprintf("%v", aa)) -}*/ diff --git a/common.go b/common.go index 045f470..771a05f 100644 --- a/common.go +++ b/common.go @@ -1,6 +1,7 @@ package dbquery import ( + "database/sql" "fmt" "git.tetele.net/tgo/helper" "log" @@ -29,7 +30,7 @@ func ReplaeByOtherSql(sql, sql_type, action string) string { sql = helper.StringJoin(sql, " RETURNING id") } // 定义需要处理的关键字列表 - keywords := []string{"user", "order", "group", "table", "view", "admin", "new", "top"} + keywords := []string{"user", "order", "group", "table", "view", "admin", "new"} //设置保护词组 excludePhrases := []string{ "order by", "group by", "GROUP BY", "ORDER BY", "WITHIN GROUP", "within group", @@ -130,3 +131,24 @@ func DmFieldDeal(fields string) string { return title } + +func DmFormatDecimal(col interface{}) string { + res := "" + + switch v := col.(type) { + case *sql.RawBytes: + if v != nil { + strVal := string(*v) + // 处理前导零缺失的问题 + if strings.HasPrefix(strVal, ".") { + strVal = "0" + strVal + } + res = strVal + } + case *interface{}: + // 其他类型的字段 + res = ToString(*v) + } + + return res +} diff --git a/db.go b/db.go index 36fb16a..d05144d 100644 --- a/db.go +++ b/db.go @@ -301,37 +301,64 @@ func GetData(dbName, table string, title string, where map[string]string, limit return count, info, err } + var index string + var rowerr error columns, _ := rows.Columns() - scanArgs := make([]interface{}, len(columns)) - values := make([]interface{}, len(columns)) - for i := range values { - scanArgs[i] = &values[i] - } + if DB_PROVIDER == "DmSql" { + columnTypes, _ := rows.ColumnTypes() + scanArgs := make([]interface{}, len(columns)) - var index string - var rowerr error - for rows.Next() { - rowerr = rows.Scan(scanArgs...) - if rowerr == nil { - for i, col := range values { - if col != nil { - if DB_PROVIDER == "DmSql" { - //达梦返回全大写字段,需先转小写 + for i, colType := range columnTypes { + //fmt.Printf("字段: %s, 数据库类型: %s, 扫描类型: %v\n", columns[i], colType.DatabaseTypeName(), colType.ScanType()) + // 为 DECIMAL 类型创建专用的扫描变量,达梦8中如果只少于1时,前面0会丢失 + if colType.DatabaseTypeName() == "DECIMAL" { + var decimalStr sql.RawBytes + scanArgs[i] = &decimalStr + } else { + var val interface{} + scanArgs[i] = &val + } + } + + for rows.Next() { + rowerr = rows.Scan(scanArgs...) + if rowerr == nil { + for i, col := range scanArgs { + if col != nil { index = helper.StrFirstToUpper(strings.ToLower(columns[i])) - //达梦返回的字段类型比较细,比如:int16、int32 - info[index] = ToString(col) - } else { + info[index] = DmFormatDecimal(col) + } + } + count++ + } else { + log.Println("ERROR", "rows scan error", rowerr, dbName, keyList, valueList) + } + } + } else { + scanArgs := make([]interface{}, len(columns)) + values := make([]interface{}, len(columns)) + + for i := range values { + scanArgs[i] = &values[i] + } + + for rows.Next() { + rowerr = rows.Scan(scanArgs...) + if rowerr == nil { + for i, col := range values { + if col != nil { index = helper.StrFirstToUpper(columns[i]) info[index] = helper.ToString(col) } } + count++ + } else { + log.Println("ERROR", "rows scan error", rowerr, dbName, keyList, valueList) } - count++ - } else { - log.Println("ERROR", "rows scan error", rowerr, dbName, keyList, valueList) } } + if rowerr != nil { return count, info, rowerr } @@ -503,37 +530,63 @@ func GetRow(dbName, table_name, alias string, titles string, with, join [][]stri return count, info, err } + var index string + var rowerr error columns, _ := rows.Columns() - scanArgs := make([]interface{}, len(columns)) - values := make([]interface{}, len(columns)) - for i := range values { - scanArgs[i] = &values[i] - } + if DB_PROVIDER == "DmSql" { + columnTypes, _ := rows.ColumnTypes() + scanArgs := make([]interface{}, len(columns)) - var index string - var rowerr error - for rows.Next() { - rowerr = rows.Scan(scanArgs...) - if rowerr == nil { - for i, col := range values { - if col != nil { - if DB_PROVIDER == "DmSql" { - //达梦返回全大写字段,需先转小写 + for i, colType := range columnTypes { + //fmt.Printf("字段: %s, 数据库类型: %s, 扫描类型: %v\n", columns[i], colType.DatabaseTypeName(), colType.ScanType()) + // 为 DECIMAL 类型创建专用的扫描变量,达梦8中如果只少于1时,前面0会丢失 + if colType.DatabaseTypeName() == "DECIMAL" { + var decimalStr sql.RawBytes + scanArgs[i] = &decimalStr + } else { + var val interface{} + scanArgs[i] = &val + } + } + + for rows.Next() { + rowerr = rows.Scan(scanArgs...) + if rowerr == nil { + for i, col := range scanArgs { + if col != nil { index = helper.StrFirstToUpper(strings.ToLower(columns[i])) - //达梦返回的字段类型比较细,比如:int16、int32 - info[index] = ToString(col) - } else { + info[index] = DmFormatDecimal(col) + } + } + count++ + } else { + log.Println("ERROR", rowerr) + } + } + } else { + scanArgs := make([]interface{}, len(columns)) + values := make([]interface{}, len(columns)) + + for i := range values { + scanArgs[i] = &values[i] + } + for rows.Next() { + rowerr = rows.Scan(scanArgs...) + if rowerr == nil { + for i, col := range values { + if col != nil { index = helper.StrFirstToUpper(columns[i]) info[index] = helper.ToString(col) } } + count++ + } else { + log.Println("ERROR", rowerr) } - count++ - } else { - log.Println("ERROR", rowerr) } } + rows.Close() if rowerr != nil { log.Println("DB row error:", rowerr) @@ -657,7 +710,6 @@ func FetchRows(dbName, table_name, alias string, titles string, with, join [][]s } if page > 0 || page_size > 0 { - if page < 1 { page = 1 } @@ -720,41 +772,71 @@ func FetchRows(dbName, table_name, alias string, titles string, with, join [][]s return 0, list, err } - columns, _ := rows.Columns() - scanArgs := make([]interface{}, len(columns)) - values := make([]interface{}, len(columns)) - - for i := range values { - scanArgs[i] = &values[i] - } - var index string var rowerr error var info map[string]string - for rows.Next() { - rowerr = rows.Scan(scanArgs...) - info = make(map[string]string) - if rowerr == nil { - for i, col := range values { - if col != nil { - if DB_PROVIDER == "DmSql" { - //达梦返回全大写字段,需先转小写 + columns, _ := rows.Columns() + + if DB_PROVIDER == "DmSql" { + columnTypes, _ := rows.ColumnTypes() + scanArgs := make([]interface{}, len(columns)) + + for i, colType := range columnTypes { + //fmt.Printf("字段: %s, 数据库类型: %s, 扫描类型: %v\n", columns[i], colType.DatabaseTypeName(), colType.ScanType()) + // 为 DECIMAL 类型创建专用的扫描变量,达梦8中如果只少于1时,前面0会丢失 + if colType.DatabaseTypeName() == "DECIMAL" { + var decimalStr sql.RawBytes + scanArgs[i] = &decimalStr + } else { + var val interface{} + scanArgs[i] = &val + } + } + + for rows.Next() { + rowerr = rows.Scan(scanArgs...) + info = make(map[string]string) + if rowerr == nil { + for i, col := range scanArgs { + if col != nil { index = helper.StrFirstToUpper(strings.ToLower(columns[i])) - //达梦返回的字段类型比较细,比如:int16、int32 - info[index] = ToString(col) - } else { + info[index] = DmFormatDecimal(col) + } + } + count++ + } else { + log.Println("ERROR", rowerr) + } + if len(info) > 0 { + list = append(list, info) + } + } + } else { + scanArgs := make([]interface{}, len(columns)) + values := make([]interface{}, len(columns)) + + for i := range values { + scanArgs[i] = &values[i] + } + + for rows.Next() { + rowerr = rows.Scan(scanArgs...) + info = make(map[string]string) + if rowerr == nil { + for i, col := range values { + if col != nil { index = helper.StrFirstToUpper(columns[i]) info[index] = helper.ToString(col) } } + count++ + } else { + log.Println("ERROR", rowerr) + } + if len(info) > 0 { + list = append(list, info) } - count++ - } else { - log.Println("ERROR", rowerr) - } - if len(info) > 0 { - list = append(list, info) } } @@ -890,34 +972,59 @@ func GetList(dbName, table string, title string, where map[string]string, limit } defer rows.Close() - columns, _ := rows.Columns() - scanArgs := make([]interface{}, len(columns)) - values := make([]interface{}, len(columns)) - for i := range values { - scanArgs[i] = &values[i] - } - var record map[string]string var index string - for rows.Next() { - //将行数据保存到record字典 - err = rows.Scan(scanArgs...) - record = make(map[string]string) + columns, _ := rows.Columns() + + if DB_PROVIDER == "DmSql" { + columnTypes, _ := rows.ColumnTypes() + scanArgs := make([]interface{}, len(columns)) - for i, col := range values { - if col != nil { - if DB_PROVIDER == "DmSql" { - //达梦返回全大写字段,需先转小写 + for i, colType := range columnTypes { + //fmt.Printf("字段: %s, 数据库类型: %s, 扫描类型: %v\n", columns[i], colType.DatabaseTypeName(), colType.ScanType()) + // 为 DECIMAL 类型创建专用的扫描变量,达梦8中如果只少于1时,前面0会丢失 + if colType.DatabaseTypeName() == "DECIMAL" { + var decimalStr sql.RawBytes + scanArgs[i] = &decimalStr + } else { + var val interface{} + scanArgs[i] = &val + } + } + + for rows.Next() { + //将行数据保存到record字典 + err = rows.Scan(scanArgs...) + record = make(map[string]string) + + for i, col := range scanArgs { + if col != nil { index = helper.StrFirstToUpper(strings.ToLower(columns[i])) - //达梦返回的字段类型比较细,比如:int16、int32 - record[index] = ToString(col) - } else { + record[index] = DmFormatDecimal(col) + } + } + list = append(list, record) + } + } else { + scanArgs := make([]interface{}, len(columns)) + values := make([]interface{}, len(columns)) + for i := range values { + scanArgs[i] = &values[i] + } + + for rows.Next() { + //将行数据保存到record字典 + err = rows.Scan(scanArgs...) + record = make(map[string]string) + + for i, col := range values { + if col != nil { index = helper.StrFirstToUpper(columns[i]) record[index] = helper.ToString(col) } } + list = append(list, record) } - list = append(list, record) } return list, nil @@ -1129,34 +1236,59 @@ func DoQuery(args ...interface{}) ([]map[string]string, error) { } defer rows.Close() - columns, _ := rows.Columns() - scanArgs := make([]interface{}, len(columns)) - values := make([]interface{}, len(columns)) - for i := range values { - scanArgs[i] = &values[i] - } - var record map[string]string var index string - for rows.Next() { - //将行数据保存到record字典 - err = rows.Scan(scanArgs...) - record = make(map[string]string) + columns, _ := rows.Columns() + + if DB_PROVIDER == "DmSql" { + columnTypes, _ := rows.ColumnTypes() + scanArgs := make([]interface{}, len(columns)) - for i, col := range values { - if col != nil { - if DB_PROVIDER == "DmSql" { - //达梦返回全大写字段,需先转小写 + for i, colType := range columnTypes { + //fmt.Printf("字段: %s, 数据库类型: %s, 扫描类型: %v\n", columns[i], colType.DatabaseTypeName(), colType.ScanType()) + // 为 DECIMAL 类型创建专用的扫描变量,达梦8中如果只少于1时,前面0会丢失 + if colType.DatabaseTypeName() == "DECIMAL" { + var decimalStr sql.RawBytes + scanArgs[i] = &decimalStr + } else { + var val interface{} + scanArgs[i] = &val + } + } + + for rows.Next() { + //将行数据保存到record字典 + err = rows.Scan(scanArgs...) + record = make(map[string]string) + + for i, col := range scanArgs { + if col != nil { index = helper.StrFirstToUpper(strings.ToLower(columns[i])) - //达梦返回的字段类型比较细,比如:int16、int32 - record[index] = ToString(col) - } else { + record[index] = DmFormatDecimal(col) + } + } + list = append(list, record) + } + } else { + scanArgs := make([]interface{}, len(columns)) + values := make([]interface{}, len(columns)) + for i := range values { + scanArgs[i] = &values[i] + } + + for rows.Next() { + //将行数据保存到record字典 + err = rows.Scan(scanArgs...) + record = make(map[string]string) + + for i, col := range values { + if col != nil { index = helper.StrFirstToUpper(columns[i]) record[index] = helper.ToString(col) } } + list = append(list, record) } - list = append(list, record) } return list, nil diff --git a/prepare.go b/prepare.go index 6582c9a..dd2f509 100644 --- a/prepare.go +++ b/prepare.go @@ -95,43 +95,73 @@ func StmtForQueryList(stmt *sql.Stmt, valuelist []interface{}) ([]map[string]str } return nil, err } - columns, _ := rows.Columns() - scanArgs := make([]interface{}, len(columns)) - values := make([]interface{}, len(columns)) - - for i := range values { - scanArgs[i] = &values[i] - } var list []map[string]string var index string var rowerr error info := make(map[string]string) - for rows.Next() { + columns, _ := rows.Columns() + + if DB_PROVIDER == "DmSql" { + columnTypes, _ := rows.ColumnTypes() + scanArgs := make([]interface{}, len(columns)) + + for i, colType := range columnTypes { + //fmt.Printf("字段: %s, 数据库类型: %s, 扫描类型: %v\n", columns[i], colType.DatabaseTypeName(), colType.ScanType()) + // 为 DECIMAL 类型创建专用的扫描变量,达梦8中如果只少于1时,前面0会丢失 + if colType.DatabaseTypeName() == "DECIMAL" { + var decimalStr sql.RawBytes + scanArgs[i] = &decimalStr + } else { + var val interface{} + scanArgs[i] = &val + } + } - rowerr = rows.Scan(scanArgs...) + for rows.Next() { + rowerr = rows.Scan(scanArgs...) + info = make(map[string]string) - info = make(map[string]string) - if rowerr == nil { - for i, col := range values { - if col != nil { - if DB_PROVIDER == "DmSql" { - //达梦返回全大写字段,需先转小写 + if rowerr == nil { + for i, col := range scanArgs { + if col != nil { index = helper.StrFirstToUpper(strings.ToLower(columns[i])) - //达梦返回的字段类型比较细,比如:int16、int32 - info[index] = ToString(col) - } else { + info[index] = DmFormatDecimal(col) + } + } + } else { + log.Println("rows scan error", rowerr) + } + if len(info) > 0 { + list = append(list, info) + } + } + } else { + scanArgs := make([]interface{}, len(columns)) + values := make([]interface{}, len(columns)) + + for i := range values { + scanArgs[i] = &values[i] + } + + for rows.Next() { + rowerr = rows.Scan(scanArgs...) + info = make(map[string]string) + + if rowerr == nil { + for i, col := range values { + if col != nil { index = helper.StrFirstToUpper(columns[i]) info[index] = helper.ToString(col) } } + } else { + log.Println("rows scan error", rowerr) + } + if len(info) > 0 { + list = append(list, info) } - } else { - log.Println("rows scan error", rowerr) - } - if len(info) > 0 { - list = append(list, info) } } @@ -155,37 +185,65 @@ func StmtForQueryRow(stmt *sql.Stmt, valuelist []interface{}) (map[string]string } return nil, err } - columns, _ := rows.Columns() - scanArgs := make([]interface{}, len(columns)) - values := make([]interface{}, len(columns)) - - for i := range values { - scanArgs[i] = &values[i] - } var index string var rowerr error info := make(map[string]string) - for rows.Next() { - rowerr = rows.Scan(scanArgs...) - if rowerr == nil { - for i, col := range values { - if col != nil { - if DB_PROVIDER == "DmSql" { - //达梦返回全大写字段,需先转小写 + + columns, _ := rows.Columns() + + if DB_PROVIDER == "DmSql" { + columnTypes, _ := rows.ColumnTypes() + scanArgs := make([]interface{}, len(columns)) + + for i, colType := range columnTypes { + //fmt.Printf("字段: %s, 数据库类型: %s, 扫描类型: %v\n", columns[i], colType.DatabaseTypeName(), colType.ScanType()) + // 为 DECIMAL 类型创建专用的扫描变量,达梦8中如果只少于1时,前面0会丢失 + if colType.DatabaseTypeName() == "DECIMAL" { + var decimalStr sql.RawBytes + scanArgs[i] = &decimalStr + } else { + var val interface{} + scanArgs[i] = &val + } + } + + for rows.Next() { + rowerr = rows.Scan(scanArgs...) + if rowerr == nil { + for i, col := range scanArgs { + if col != nil { index = helper.StrFirstToUpper(strings.ToLower(columns[i])) - //达梦返回的字段类型比较细,比如:int16、int32 - info[index] = ToString(col) - } else { + info[index] = DmFormatDecimal(col) + } + } + } else { + log.Println("rows scan error", rowerr) + } + } + } else { + scanArgs := make([]interface{}, len(columns)) + values := make([]interface{}, len(columns)) + + for i := range values { + scanArgs[i] = &values[i] + } + + for rows.Next() { + rowerr = rows.Scan(scanArgs...) + if rowerr == nil { + for i, col := range values { + if col != nil { index = helper.StrFirstToUpper(columns[i]) info[index] = helper.ToString(col) } } + } else { + log.Println("rows scan error", rowerr) } - } else { - log.Println("rows scan error", rowerr) } } + if rowerr != nil { return info, errors.New("数据出错") }