19 Commits

8 changed files with 341 additions and 42 deletions
Split View
  1. +110
    -12
      chain.go
  2. +13
    -1
      conn.go
  3. +102
    -14
      db.go
  4. +0
    -0
      db_test.go
  5. +0
    -0
      prepare.go
  6. +7
    -2
      sqlserver.go
  7. +5
    -5
      transaction.go
  8. +104
    -8
      transaction_chain.go

+ 110
- 12
chain.go View File

@ -34,6 +34,7 @@ type Query struct {
conn *sql.DB
debug bool
dbtype string
with [][]string //[[临时表的sql语句,临时表的名称]]
}
func NewQuery(t ...string) *Query {
@ -103,6 +104,14 @@ func (this *Query) Groupby(groupby string) *Query {
this.groupby = groupby
return this
}
func (this *Query) With(with []string) *Query {
this.with = append(this.with, with)
return this
}
func (this *Query) Withs(withs [][]string) *Query {
this.with = append(this.with, withs...)
return this
}
func (this *Query) Where(where string) *Query {
this.where = append(this.where, where)
return this
@ -145,6 +154,27 @@ func (this *Query) Join(join []string) *Query {
this.join = append(this.join, join)
return this
}
/**
* 左连接
* 2023/08/10
* gz
*/
func (this *Query) LeftJoin(table_name string, condition string) *Query {
this.join = append(this.join, []string{table_name, condition, "left"})
return this
}
/**
* 右连接
* 2023/08/10
* gz
*/
func (this *Query) RightJoin(table_name string, condition string) *Query {
this.join = append(this.join, []string{table_name, condition, "right"})
return this
}
func (this *Query) Data(data string) *Query {
this.data = append(this.data, data)
return this
@ -176,10 +206,12 @@ func (this *Query) Clean() *Query {
this.save_data = this.save_data[0:0]
this.upd_field = this.upd_field[0:0]
this.having = ""
this.alias = ""
this.with = this.with[0:0]
return this
}
//获取表格信息
// 获取表格信息
func (this *Query) GetTableInfo(table string) (map[string]interface{}, error) {
field := []string{
"COLUMN_NAME", //字段名
@ -245,13 +277,41 @@ func (this *Query) GetTableInfo(table string) (map[string]interface{}, error) {
}, nil
}
//构造子查询
// 返回表名
func (this *Query) GetTableName(table string) string {
return getTableName(this.dbname, table)
}
// 构造子查询
func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
if this.dbname == "" && this.table == "" {
return nil, errors.New("参数错误,没有数据表")
}
var table = ""
if strings.Contains(this.table, "select ") {
withSql := ""
if len(this.with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range this.with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(this.table, "select ") || strings.HasPrefix(this.table, "(") {
table = this.table
} else {
table = getTableName(this.dbname, this.table, this.dbtype)
@ -266,7 +326,22 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
} else {
title = "*"
}
sql = helper.StringJoin("/*slave*/ select ", title)
if this.dbtype == "mssql" {
if this.page_size > 0 {
sql = helper.StringJoin(withSql, "select top ", helper.ToStr(this.page_size), " ")
} else {
sql = helper.StringJoin(withSql, "select ")
}
} else {
if DB_PROVIDER == "TencentDB" {
sql = helper.StringJoin("/*slave*/ ", withSql, " select ")
} else {
sql = helper.StringJoin(withSql, "select ")
}
}
sql = helper.StringJoin(sql, title)
if this.alias != "" {
table = helper.StringJoin(table, " as ", this.alias)
@ -275,15 +350,31 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
sql = helper.StringJoin(sql, " from ", table)
if len(this.join) > 0 {
var builder strings.Builder
builder.WriteString(sql)
boo := false
for _, joinitem := range this.join {
if len(joinitem) < 2 {
continue
}
if len(joinitem) == 3 {
sql = helper.StringJoin(sql, " ", joinitem[2], " join ", getTableName(this.dbname, joinitem[0], this.dbtype), " on ", joinitem[1])
} else { //默认左连接
sql = helper.StringJoin(sql, " left join ", getTableName(this.dbname, joinitem[0], this.dbtype), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(this.dbname, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql = builder.String()
}
}
if len(this.where) > 0 || len(this.where_or) > 0 {
@ -311,7 +402,7 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
sql = helper.StringJoin(sql, " order by ", this.orderby)
}
if this.page > 0 || this.page_size > 0 {
if this.dbtype == "mysql" && (this.page > 0 || this.page_size > 0) {
if this.page < 1 {
this.page = 1
@ -351,6 +442,13 @@ func (this *Query) QueryStmt() error {
return err
}
sql := helper.ToStr(res["sql"])
if SLAVER_DB != nil {
this.conn = SLAVER_DB
}
// else {
// this.conn = DB
// }
if this.conn == nil {
this.conn = DB
}
@ -721,7 +819,7 @@ func (this *Query) DeleteStmt() error {
*/
func (this *Query) Select() ([]map[string]string, error) {
_, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.join,
_, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.with, this.join,
this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.page, this.page_size, this.debug)
return rows, err
@ -752,7 +850,7 @@ func (this *Query) List() ([]map[string]string, error) {
*/
func (this *Query) Find() (map[string]string, error) {
_, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.join,
_, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.with, this.join,
this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.debug)
return row, err
@ -791,7 +889,7 @@ func (this *Query) Update() (int64, error) {
return StmtForUpdateExec(this.stmt, this.value)
}
//批量更新
// 批量更新
func (this *Query) UpdateAll() (int64, error) {
err := this.UpdateAllStmt()


+ 13
- 1
conn.go View File

@ -17,6 +17,9 @@ var DB *sql.DB
var SLAVER_DB *sql.DB
// db类型,默认空,如TencentDB(腾讯),
var DB_PROVIDER string
func Connect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
log.Println("database connectting...")
@ -120,7 +123,16 @@ func getTableName(dbName, table string, dbtype ...string) string {
ret = table
}
if dbName != "" {
ret = helper.StringJoin(dbName, ".", table)
if strings.Contains(table, ",") {
arr := strings.Split(table, ",")
arrStrs := make([]string, 0, len(arr))
for _, v := range arr {
arrStrs = append(arrStrs, helper.StringJoin(dbName, ".", v))
}
ret = strings.Join(arrStrs, ",")
} else {
ret = helper.StringJoin(dbName, ".", table)
}
} else {
ret = table
}


+ 102
- 14
db.go View File

@ -295,7 +295,7 @@ func GetData(dbName, table string, title string, where map[string]string, limit
* @param dbName 数据表名
* @param title 查询字段名
*/
func GetRow(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) {
func GetRow(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) {
var count int = 0
info := make(map[string]string)
@ -304,7 +304,30 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
}
table := ""
if strings.Contains(table_name, "select ") {
withSql := ""
if len(with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") {
table = table_name
} else {
table = getTableName(dbName, table_name)
@ -316,8 +339,12 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
} else {
title = "*"
}
sql_str = helper.StringJoin("/*slave*/ select ", title)
if DB_PROVIDER == "TencentDB" {
sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title)
} else {
sql_str = helper.StringJoin(withSql, "select ", title)
}
if alias != "" {
table = helper.StringJoin(table, " as ", alias)
}
@ -325,15 +352,31 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
sql_str = helper.StringJoin(sql_str, " from ", table)
if len(join) > 0 {
var builder strings.Builder
builder.WriteString(sql_str)
boo := false
for _, joinitem := range join {
if len(joinitem) < 2 {
continue
}
if len(joinitem) == 3 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql_str = helper.StringJoin(sql_str, " left join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) > 4 {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(dbName, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql_str = builder.String()
}
}
if len(where) > 0 || len(where_or) > 0 {
@ -401,6 +444,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
}
if err != nil {
log.Println("DB error:", err)
rows.Close()
return count, info, err
}
@ -431,6 +475,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
}
rows.Close()
if rowerr != nil {
log.Println("DB row error:", rowerr)
return count, info, rowerr
}
return count, info, nil
@ -442,7 +487,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
* @param dbName 数据表名
* @param title 查询字段名
*/
func FetchRows(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) {
func FetchRows(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) {
var count int = 0
list := make([]map[string]string, 0)
@ -450,7 +495,30 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
return count, list, errors.New("没有数据表")
}
table := ""
if strings.Contains(table_name, "select ") {
withSql := ""
if len(with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") {
table = table_name
} else {
table = getTableName(dbName, table_name)
@ -463,8 +531,12 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
} else {
title = "*"
}
sql_str = helper.StringJoin("/*slave*/ select ", title)
if DB_PROVIDER == "TencentDB" {
sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title)
} else {
sql_str = helper.StringJoin(withSql, "select ", title)
}
if alias != "" {
table = helper.StringJoin(table, " as ", alias)
}
@ -472,15 +544,31 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
sql_str = helper.StringJoin(sql_str, " from ", table)
if len(join) > 0 {
var builder strings.Builder
builder.WriteString(sql_str)
boo := false
for _, joinitem := range join {
if len(joinitem) < 2 {
continue
}
if len(joinitem) == 3 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql_str = helper.StringJoin(sql_str, " left join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) >= 4 {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(dbName, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql_str = builder.String()
}
}


+ 0
- 0
db_test.go View File


+ 0
- 0
prepare.go View File


+ 7
- 2
sqlserver.go View File

@ -15,7 +15,7 @@ import (
var MSDB_CONN *sql.DB
func MSConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error {
func MSConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT, encrypt string, conns ...int) error {
log.Println("msdb connectting...")
@ -29,7 +29,12 @@ func MSConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error
db_port, _ := strconv.Atoi(DBPORT)
connString := fmt.Sprintf("server=%s;port=%d;database=%s;user id=%s;password=%s", DBHOST, db_port, DBNAME, DBUSER, DBPWD)
params := "server=%s;port=%d;database=%s;user id=%s;password=%s"
if encrypt != "" {
params = params + ";encrypt=" + encrypt
}
connString := fmt.Sprintf(params, DBHOST, db_port, DBNAME, DBUSER, DBPWD)
log.Println(connString)


+ 5
- 5
transaction.go View File

@ -25,7 +25,7 @@ func TxInsert(tx *sql.Tx, dbname, table string, data map[string]string) (int64,
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
dbName = getTableName(dbname, table)
}
if len(data) < 1 {
return 0, errors.New("参数错误,没有要写入的数据")
@ -71,7 +71,7 @@ func TxPreInsert(tx *sql.Tx, dbname, table string, data map[string]interface{})
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
dbName = getTableName(dbname, table)
}
if len(data) < 1 {
@ -122,7 +122,7 @@ func TxUpdate(tx *sql.Tx, dbname, table string, data map[string]string, where ma
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
dbName = getTableName(dbname, table)
}
if len(data) < 1 {
return rowsAffected, errors.New("参数错误,没有要写入的数据")
@ -186,7 +186,7 @@ func TxPreUpdate(tx *sql.Tx, dbname, table string, data []string, where []string
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
dbName = getTableName(dbname, table)
}
if len(where) < 1 {
@ -228,7 +228,7 @@ func TxDelete(tx *sql.Tx, dbname, table string, where map[string]string, del_cou
if strings.Contains(table, "select ") {
dbName = table
} else {
dbName = getTableName(dbName, table)
dbName = getTableName(dbname, table)
}
if len(where) < 1 {
return count, errors.New("参数错误,没有删除条件")


+ 104
- 8
transaction_chain.go View File

@ -34,6 +34,7 @@ type TxQuery struct {
conn *sql.DB
tx *sql.Tx
debug bool
with [][]string //[[临时表的sql语句,临时表的名称]]
}
func NewTxQuery(t ...string) *TxQuery {
@ -109,6 +110,14 @@ func (this *TxQuery) Where(where string) *TxQuery {
this.where = append(this.where, where)
return this
}
func (this *TxQuery) With(with []string) *TxQuery {
this.with = append(this.with, with)
return this
}
func (this *TxQuery) Withs(withs [][]string) *TxQuery {
this.with = append(this.with, withs...)
return this
}
func (this *TxQuery) Wheres(wheres []string) *TxQuery {
if len(wheres) > 0 {
this.where = append(this.where, wheres...)
@ -147,6 +156,26 @@ func (this *TxQuery) Join(join []string) *TxQuery {
this.join = append(this.join, join)
return this
}
/**
* 左连接
* 2023/08/10
* gz
*/
func (this *TxQuery) LeftJoin(table_name string, condition string) *TxQuery {
this.join = append(this.join, []string{table_name, condition, "left"})
return this
}
/**
* 右连接
* 2023/08/10
* gz
*/
func (this *TxQuery) RightJoin(table_name string, condition string) *TxQuery {
this.join = append(this.join, []string{table_name, condition, "right"})
return this
}
func (this *TxQuery) Data(data string) *TxQuery {
this.data = append(this.data, data)
return this
@ -177,16 +206,46 @@ func (this *TxQuery) Clean() *TxQuery {
this.save_data = this.save_data[0:0]
this.upd_field = this.upd_field[0:0]
this.having = ""
this.alias = ""
this.with = this.with[0:0]
return this
}
//构造子查询
// 返回表名
func (this *TxQuery) GetTableName(table string) string {
return getTableName(this.dbname, table)
}
// 构造子查询
func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
if this.dbname == "" && this.table == "" {
return nil, errors.New("参数错误,没有数据表")
}
var table = ""
if strings.Contains(this.table, "select ") {
withSql := ""
if len(this.with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range this.with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if withSql != "" || strings.Contains(this.table, "select ") || strings.HasPrefix(this.table, "(") {
table = this.table
} else {
table = getTableName(this.dbname, this.table)
@ -199,7 +258,8 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
} else {
title = "*"
}
sql = helper.StringJoin("select ", title)
sql = helper.StringJoin(withSql, "select ", title)
if this.alias != "" {
table = helper.StringJoin(table, " as ", this.alias)
@ -208,15 +268,31 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
sql = helper.StringJoin(sql, " from ", table)
if len(this.join) > 0 {
var builder strings.Builder
builder.WriteString(sql)
boo := false
for _, joinitem := range this.join {
if len(joinitem) < 2 {
continue
}
if len(joinitem) == 3 {
sql = helper.StringJoin(sql, " ", joinitem[2], " join ", getTableName(this.dbname, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql = helper.StringJoin(sql, " left join ", getTableName(this.dbname, joinitem[0]), " on ", joinitem[1])
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if withSql != "" || strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(this.dbname, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
boo = true
}
if boo {
sql = builder.String()
}
}
if len(this.where) > 0 || len(this.where_or) > 0 {
@ -278,7 +354,7 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
}, nil
}
//获取表格信息
// 获取表格信息
func (this *TxQuery) GetTableInfo(table string) (map[string]interface{}, error) {
field := []string{
"COLUMN_NAME", //字段名
@ -801,6 +877,26 @@ func (this *TxQuery) CreateAll() (int64, error) {
return StmtForInsertExec(this.stmt, this.value)
}
/**
* 执行原生sql
* return error
*/
func (this *TxQuery) ExecSql(sql string) (int64, error) {
if this.debug {
log.Println("ExecSql sql:", sql)
}
stmt, err = this.tx.Prepare(sql)
if err != nil {
return 0, err
}
res, err := stmt.Exec()
if err != nil {
return 0, errors.New("执行失败:" + err.Error())
}
return res.RowsAffected()
}
/**
* 提交
*/


Loading…
Cancel
Save