From 4ac2e8c30d34d7383f60b8fd9813ef89c24934a7 Mon Sep 17 00:00:00 2001 From: zhenghaorong Date: Sat, 3 Sep 2022 10:09:43 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=89=B9=E9=87=8F=E6=9B=B4?= =?UTF-8?q?=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chain.go | 125 +++++++++++++++++++++++++++++++++++++++++-- db.go | 10 +++- transaction_chain.go | 117 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 244 insertions(+), 8 deletions(-) diff --git a/chain.go b/chain.go index 7b42714..b3f5170 100644 --- a/chain.go +++ b/chain.go @@ -21,11 +21,13 @@ type Query struct { where []string where_or []string join [][]string //[["tablea as a","a.id=b.id","left"]] - save_data []map[string]interface{} //[["title":"a","num":1,],["title":"a","num":1,]] + save_data []map[string]interface{} //批量操作的数据[["title":"a","num":1,],["title":"a","num":1,]] + upd_field []string // 批量更新时需要更新的字段,为空时按除id外的字段进行更新 data []string value []interface{} orderby string groupby string + having string page int page_size int stmt *sql.Stmt @@ -89,7 +91,10 @@ func (this *Query) PageSize(page_num int) *Query { this.page_size = page_num return this } - +func (this *Query) Having(having string) *Query { + this.having = having + return this +} func (this *Query) Orderby(orderby string) *Query { this.orderby = orderby return this @@ -120,6 +125,14 @@ func (this *Query) SaveDatas(value []map[string]interface{}) *Query { this.save_data = append(this.save_data, value...) return this } +func (this *Query) UpdField(value string) *Query { + this.upd_field = append(this.upd_field, value) + return this +} +func (this *Query) UpdFields(value []string) *Query { + this.upd_field = append(this.upd_field, value...) + return this +} func (this *Query) Value(value interface{}) *Query { this.value = append(this.value, value) return this @@ -160,6 +173,9 @@ func (this *Query) Clean() *Query { this.groupby = "" this.page = 0 this.page_size = 0 + this.save_data = this.save_data[0:0] + this.upd_field = this.upd_field[0:0] + this.having = "" return this } @@ -220,6 +236,10 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) { if this.groupby != "" { sql = helper.StringJoin(sql, " group by ", this.groupby) + } + if this.having != "" { + sql = helper.StringJoin(sql, " having ", this.having) + } if this.orderby != "" { sql = helper.StringJoin(sql, " order by ", this.orderby) @@ -327,6 +347,92 @@ func (this *Query) UpdateStmt() error { return nil } +// 拼批量存在更新不存在插入sql +func (this *Query) UpdateAllStmt() error { + + if this.dbname == "" && this.table == "" { + return errors.New("参数错误,没有数据表") + } + + dbName := getTableName(this.dbname, this.table) + + var sql string + var dataSql = []string{} + var valSql = []string{} + var updSql = []string{} + var updFieldLen = len(this.upd_field) + if len(this.save_data) > 0 { + //批量操作 + this.data = []string{} + this.value = []interface{}{} + for i, datum := range this.save_data { + if i == 0 { + for k, _ := range datum { + this.data = append(this.data, k) + dataSql = append(dataSql, "?") + if updFieldLen == 0 && k != "id" { + updSql = append(updSql, k+"=values("+k+")") + } + } + if updFieldLen > 0 { + for _, k := range this.upd_field { + updSql = append(updSql, k+"=values("+k+")") + } + } + } + for _, k := range this.data { + this.value = append(this.value, datum[k]) + } + valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")") + } + } else { + //添加一条 + for _, datum := range this.data { + dataSql = append(dataSql, "?") + if updFieldLen == 0 && datum != "id" { + updSql = append(updSql, datum+"=values("+datum+")") + } + } + if updFieldLen > 0 { + for _, k := range this.upd_field { + updSql = append(updSql, k+"=values("+k+")") + } + } + valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")") + } + + if len(this.data) == 0 { + return errors.New("参数错误,没有数据表") + } + if len(this.value) == 0 { + return errors.New("参数错误,条件值错误") + } + + setText := " values " + if len(valSql) > 1 { + setText = " value " + } + sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , ")) + + if this.debug { + log.Println("insert on duplicate key update sql:", sql, this.value) + } + + if this.conn == nil { + this.conn = DB + } + + stmt, err = this.conn.Prepare(sql) + + if err != nil { + return err + } + + this.stmt = stmt + + return nil +} + // 拼批量插入sql func (this *Query) CreateAllStmt() error { @@ -492,7 +598,7 @@ func (this *Query) DeleteStmt() error { func (this *Query) Select() ([]map[string]string, error) { _, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.join, - this.where, this.where_or, this.value, this.orderby, this.groupby, this.page, this.page_size, this.debug) + this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.page, this.page_size, this.debug) return rows, err } @@ -523,7 +629,7 @@ func (this *Query) List() ([]map[string]string, error) { func (this *Query) Find() (map[string]string, error) { _, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.join, - this.where, this.where_or, this.value, this.orderby, this.groupby, this.debug) + this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.debug) return row, err } @@ -561,6 +667,17 @@ func (this *Query) Update() (int64, error) { return StmtForUpdateExec(this.stmt, this.value) } +//批量更新 +func (this *Query) UpdateAll() (int64, error) { + + err := this.UpdateAllStmt() + if err != nil { + return 0, err + } + + return StmtForUpdateExec(this.stmt, this.value) +} + /** * 执行删除 * return is_delete error diff --git a/db.go b/db.go index 0a22955..145c49b 100644 --- a/db.go +++ b/db.go @@ -295,7 +295,7 @@ func GetData(dbName, table string, title string, where map[string]string, limit * @param dbName 数据表名 * @param title 查询字段名 */ -func GetRow(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby string, debug bool) (int, map[string]string, error) { +func GetRow(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) { var count int = 0 info := make(map[string]string) @@ -352,6 +352,9 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh if groupby != "" { sql_str = helper.StringJoin(sql_str, " group by ", groupby) } + if having != "" { + sql_str = helper.StringJoin(sql_str, " having ", having) + } if orderby != "" { sql_str = helper.StringJoin(sql_str, " order by ", orderby) } @@ -439,7 +442,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh * @param dbName 数据表名 * @param title 查询字段名 */ -func FetchRows(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby string, page int, page_size int, debug bool) (int, []map[string]string, error) { +func FetchRows(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) { var count int = 0 list := make([]map[string]string, 0) @@ -497,6 +500,9 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string, if groupby != "" { sql_str = helper.StringJoin(sql_str, " group by ", groupby) } + if having != "" { + sql_str = helper.StringJoin(sql_str, " HAVING ", having) + } if orderby != "" { sql_str = helper.StringJoin(sql_str, " order by ", orderby) } diff --git a/transaction_chain.go b/transaction_chain.go index 94e73c9..808626d 100644 --- a/transaction_chain.go +++ b/transaction_chain.go @@ -23,9 +23,11 @@ type TxQuery struct { join [][]string //[["tablea as a","a.id=b.id","left"]] data []string value []interface{} - save_data []map[string]interface{} + save_data []map[string]interface{} //批量操作的数据[["title":"a","num":1,],["title":"a","num":1,]] + upd_field []string // 批量更新时需要更新的字段,为空时按除id外的字段进行更新 orderby string groupby string + having string page int page_size int stmt *sql.Stmt @@ -99,7 +101,10 @@ func (this *TxQuery) Groupby(groupby string) *TxQuery { this.groupby = groupby return this } - +func (this *TxQuery) Having(having string) *TxQuery { + this.having = having + return this +} func (this *TxQuery) Where(where string) *TxQuery { this.where = append(this.where, where) return this @@ -126,6 +131,14 @@ func (this *TxQuery) SaveDatas(value []map[string]interface{}) *TxQuery { this.save_data = append(this.save_data, value...) return this } +func (this *TxQuery) UpdField(value string) *TxQuery { + this.upd_field = append(this.upd_field, value) + return this +} +func (this *TxQuery) UpdFields(value []string) *TxQuery { + this.upd_field = append(this.upd_field, value...) + return this +} func (this *TxQuery) Values(values []interface{}) *TxQuery { this.value = append(this.value, values...) return this @@ -161,6 +174,9 @@ func (this *TxQuery) Clean() *TxQuery { this.groupby = "" this.page = 0 this.page_size = 0 + this.save_data = this.save_data[0:0] + this.upd_field = this.upd_field[0:0] + this.having = "" return this } @@ -219,6 +235,10 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) { if this.groupby != "" { sql = helper.StringJoin(sql, " group by ", this.groupby) + } + if this.having != "" { + sql = helper.StringJoin(sql, " having ", this.having) + } if this.orderby != "" { sql = helper.StringJoin(sql, " order by ", this.orderby) @@ -323,6 +343,88 @@ func (this *TxQuery) UpdateStmt() error { return nil } +// 拼批量存在更新不存在插入sql +func (this *TxQuery) UpdateAllStmt() error { + + if this.dbname == "" && this.table == "" { + return errors.New("参数错误,没有数据表") + } + + dbName := getTableName(this.dbname, this.table) + + var sql string + var dataSql = []string{} + var valSql = []string{} + var updSql = []string{} + var updFieldLen = len(this.upd_field) + if len(this.save_data) > 0 { + this.data = []string{} + this.value = []interface{}{} + for i, datum := range this.save_data { + if i == 0 { + for k, _ := range datum { + this.data = append(this.data, k) + dataSql = append(dataSql, "?") + if updFieldLen == 0 && k != "id" { + updSql = append(updSql, k+"=values("+k+")") + } + } + if updFieldLen > 0 { + for _, k := range this.upd_field { + updSql = append(updSql, k+"=values("+k+")") + } + } + } + for _, k := range this.data { + this.value = append(this.value, datum[k]) + } + valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")") + } + } else { + for _, datum := range this.data { + dataSql = append(dataSql, "?") + if updFieldLen == 0 && datum != "id" { + updSql = append(updSql, datum+"=values("+datum+")") + } + } + if updFieldLen > 0 { + for _, k := range this.upd_field { + updSql = append(updSql, k+"=values("+k+")") + } + } + valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")") + } + + if len(this.data) == 0 { + return errors.New("参数错误,没有数据表") + } + if len(this.value) == 0 { + return errors.New("参数错误,条件值错误") + } + setText := " values " + if len(valSql) > 1 { + setText = " value " + } + sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ",")) + if len(this.value) == 0 { + return errors.New("参数错误,条件值错误") + } + + if this.debug { + log.Println("insert on duplicate key update sql:", sql, this.value) + } + + stmt, err = this.tx.Prepare(sql) + + if err != nil { + return err + } + + this.stmt = stmt + + return nil +} + // 拼插入sql func (this *TxQuery) CreateStmt() error { @@ -522,6 +624,17 @@ func (this *TxQuery) Update() (int64, error) { return StmtForUpdateExec(this.stmt, this.value) } +//批量更新 +func (this *TxQuery) UpdateAll() (int64, error) { + + err := this.UpdateAllStmt() + if err != nil { + return 0, err + } + + return StmtForUpdateExec(this.stmt, this.value) +} + /** * 执行删除 * return is_delete error