From 87001a68d48ceafd76b3ee7f1265814716a3c709 Mon Sep 17 00:00:00 2001 From: guzeng Date: Thu, 11 Mar 2021 13:43:46 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=BA=8B=E5=8A=A1=E9=93=BE?= =?UTF-8?q?=E5=BC=8F=E6=93=8D=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chain.go | 54 +++-- transaction_chain.go | 463 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 497 insertions(+), 20 deletions(-) create mode 100644 transaction_chain.go diff --git a/chain.go b/chain.go index bb96d62..7198141 100644 --- a/chain.go +++ b/chain.go @@ -3,6 +3,7 @@ package dbquery import ( "database/sql" "errors" + "log" "strconv" "strings" @@ -27,6 +28,7 @@ type Query struct { page_size int stmt *sql.Stmt conn *sql.DB + debug bool } func NewQuery(t ...string) *Query { @@ -119,10 +121,26 @@ func (this *Query) Datas(datas []string) *Query { return this } -// func (this *Query) Insert(where string) *Query { -// this.insert = append(this.insert, where) -// return this -// } +func (this *Query) Debug(debug bool) *Query { + this.debug = debug + return this +} + +/* + * 清理上次查询 + */ +func (this *Query) Clean() *Query { + this.title = "" + this.where = this.where[0:0] + this.where_or = this.where_or[0:0] + this.join = this.join[0:0] + this.data = this.data[0:0] + this.value = this.value[0:0] + this.orderby = "" + this.page = 0 + this.page_size = 0 + return this +} // 拼查询sql func (this *Query) QueryStmt() error { @@ -130,9 +148,6 @@ func (this *Query) QueryStmt() error { if this.dbname == "" && this.table == "" { return errors.New("参数错误,没有数据表") } - // if len(this.where)+len(this.where_or) < len(this.value) { - // return errors.New("参数错误,条件值错误") - // } table := getTableName(this.dbname, this.table) @@ -197,7 +212,9 @@ func (this *Query) QueryStmt() error { sql = helper.StringJoin(sql, " limit ", from, " , ", offset) } } - // log.Println(sql) + if this.debug { + log.Println("query sql:", sql, this.value) + } condition_len := 0 //所有条件数 for _, ch2 := range sql { if string(ch2) == "?" { @@ -235,15 +252,15 @@ func (this *Query) UpdateStmt() error { dbName := getTableName(this.dbname, this.table) - // var stmt *sql.Stmt - // var err error - var sql string sql = helper.StringJoin("update ", dbName, " set ", strings.Join(this.data, " , ")) sql = helper.StringJoin(sql, " where ", strings.Join(this.where, " and ")) + if this.debug { + log.Println("update sql:", sql, this.value) + } condition_len := 0 //所有条件数 for _, ch2 := range sql { @@ -279,13 +296,13 @@ func (this *Query) CreateStmt() error { dbName := getTableName(this.dbname, this.table) - // var stmt *sql.Stmt - // var err error - var sql string sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , ")) + if this.debug { + log.Println("insert sql:", sql, this.value) + } condition_len := 0 //所有条件数 for _, ch2 := range sql { @@ -321,15 +338,9 @@ func (this *Query) DeleteStmt() error { if len(this.where) < 1 { return errors.New("参数错误,缺少条件") } - // if len(this.where) != len(this.value) { - // return errors.New("参数错误,条件值错误") - // } dbName := getTableName(this.dbname, this.table) - // var stmt *sql.Stmt - // var err error - var sql string sql = helper.StringJoin("delete from ", dbName, " where ", strings.Join(this.where, " and ")) @@ -338,6 +349,9 @@ func (this *Query) DeleteStmt() error { sql = helper.StringJoin(sql, " limit ", strconv.Itoa(this.page_size)) } + if this.debug { + log.Println("delete sql:", sql, this.value) + } condition_len := 0 //所有条件数 for _, ch2 := range sql { diff --git a/transaction_chain.go b/transaction_chain.go new file mode 100644 index 0000000..6830455 --- /dev/null +++ b/transaction_chain.go @@ -0,0 +1,463 @@ +package dbquery + +/** + * 事务操作 + */ +import ( + "database/sql" + "errors" + "log" + "strconv" + "strings" + + "git.tetele.net/tgo/helper" +) + +type TxQuery struct { + dbname string + table string + alias string + title string + where []string + where_or []string + join [][]string //[["tablea as a","a.id=b.id","left"]] + data []string + value []interface{} + orderby string + page int + page_size int + stmt *sql.Stmt + conn *sql.DB + tx *sql.Tx + debug bool +} + +func NewTxQuery(t ...string) *TxQuery { + + var conn_type *sql.DB = DB + + if len(t) > 0 { + switch t[0] { + case "mysql": + conn_type = DB + + case "mssql": //sql server + conn_type = MSDB_CONN + + } + } + + tx, err := conn_type.Begin() + if err != nil { + log.Println("start tx begin error", err) + } + + return &TxQuery{ + conn: conn_type, + tx: tx, + } +} + +func (this *TxQuery) Conn(conn *sql.DB) *TxQuery { + this.conn = conn + return this +} +func (this *TxQuery) Db(dbname string) *TxQuery { + this.dbname = dbname + return this +} + +func (this *TxQuery) Table(tablename string) *TxQuery { + this.table = tablename + return this +} +func (this *TxQuery) Alias(tablename string) *TxQuery { + this.alias = tablename + return this +} + +func (this *TxQuery) Title(title string) *TxQuery { + this.title = title + return this +} +func (this *TxQuery) Page(page int) *TxQuery { + this.page = page + return this +} +func (this *TxQuery) PageSize(page_num int) *TxQuery { + this.page_size = page_num + return this +} + +func (this *TxQuery) Orderby(orderby string) *TxQuery { + this.orderby = orderby + return this +} +func (this *TxQuery) Where(where string) *TxQuery { + this.where = append(this.where, where) + return this +} +func (this *TxQuery) Wheres(wheres []string) *TxQuery { + if len(wheres) > 0 { + this.where = append(this.where, wheres...) + } + return this +} +func (this *TxQuery) WhereOr(where string) *TxQuery { + this.where_or = append(this.where_or, where) + return this +} +func (this *TxQuery) Value(value interface{}) *TxQuery { + this.value = append(this.value, value) + return this +} +func (this *TxQuery) Values(values []interface{}) *TxQuery { + this.value = append(this.value, values...) + return this +} +func (this *TxQuery) Join(join []string) *TxQuery { + this.join = append(this.join, join) + return this +} +func (this *TxQuery) Data(data string) *TxQuery { + this.data = append(this.data, data) + return this +} +func (this *TxQuery) Datas(datas []string) *TxQuery { + this.data = append(this.data, datas...) + return this +} +func (this *TxQuery) Debug(debug bool) *TxQuery { + this.debug = debug + return this +} + +/* + * 清理上次查询 + */ +func (this *TxQuery) Clean() *TxQuery { + this.title = "" + this.where = this.where[0:0] + this.where_or = this.where_or[0:0] + this.join = this.join[0:0] + this.data = this.data[0:0] + this.value = this.value[0:0] + this.orderby = "" + this.page = 0 + this.page_size = 0 + return this +} + +// 拼查询sql +func (this *TxQuery) QueryStmt() error { + + if this.dbname == "" && this.table == "" { + return errors.New("参数错误,没有数据表") + } + + table := getTableName(this.dbname, this.table) + + var sql, title string + + if this.title != "" { + title = this.title + } else { + title = "*" + } + sql = helper.StringJoin("select ", title) + + if this.alias != "" { + table = helper.StringJoin(table, " as ", this.alias) + } + + sql = helper.StringJoin(sql, " from ", table) + + if len(this.join) > 0 { + for _, joinitem := range this.join { + if len(joinitem) < 2 { + continue + } + if len(joinitem) == 3 { + sql = helper.StringJoin(sql, " ", joinitem[2], " join ", getTableName(this.dbname, joinitem[0]), " on ", joinitem[1]) + } else { //默认左连接 + sql = helper.StringJoin(sql, " left join ", getTableName(this.dbname, joinitem[0]), " on ", joinitem[1]) + } + } + } + if len(this.where) > 0 || len(this.where_or) > 0 { + sql = helper.StringJoin(sql, " where ") + } + if len(this.where) > 0 { + sql = helper.StringJoin(sql, " (", strings.Join(this.where, " and "), " ) ") + } + if len(this.where_or) > 0 { + if len(this.where) > 0 { + sql = helper.StringJoin(sql, " or ", strings.Join(this.where_or, " or ")) + } else { + sql = helper.StringJoin(sql, strings.Join(this.where_or, " or ")) + } + } + + if this.orderby != "" { + sql = helper.StringJoin(sql, " order by ", this.orderby) + } + + if this.page > 0 || this.page_size > 0 { + + if this.page < 1 { + this.page = 1 + } + if this.page_size < 1 { + this.page_size = 10 + } + from := strconv.Itoa((this.page - 1) * this.page_size) + offset := strconv.Itoa(this.page_size) + if from != "" && offset != "" { + sql = helper.StringJoin(sql, " limit ", from, " , ", offset) + } + } + + if this.debug { + log.Println("query sql:", sql, this.value) + } + + condition_len := 0 //所有条件数 + for _, ch2 := range sql { + if string(ch2) == "?" { + condition_len++ + } + } + if condition_len != len(this.value) { + return errors.New("参数错误,条件值错误") + } + + stmt, err = this.tx.Prepare(sql + " FOR UPDATE") + + if err != nil { + return err + } + + this.stmt = stmt + + return nil +} + +// 拼更新sql +func (this *TxQuery) UpdateStmt() error { + + if this.dbname == "" && this.table == "" { + return errors.New("参数错误,没有数据表") + } + if len(this.where) < 1 { + return errors.New("参数错误,缺少条件") + } + + dbName := getTableName(this.dbname, this.table) + + var sql string + + sql = helper.StringJoin("update ", dbName, " set ", strings.Join(this.data, " , ")) + + sql = helper.StringJoin(sql, " where ", strings.Join(this.where, " and ")) + + if this.debug { + log.Println("update sql:", sql, this.value) + } + + condition_len := 0 //所有条件数 + + for _, ch2 := range sql { + if string(ch2) == "?" { + condition_len++ + } + } + if condition_len != len(this.value) { + return errors.New("参数错误,条件值错误") + } + + stmt, err = this.tx.Prepare(sql) + + if err != nil { + return err + } + + this.stmt = stmt + + return nil +} + +// 拼插入sql +func (this *TxQuery) CreateStmt() error { + + if this.dbname == "" && this.table == "" { + return errors.New("参数错误,没有数据表") + } + + dbName := getTableName(this.dbname, this.table) + + var sql string + + sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , ")) + + if this.debug { + log.Println("insert sql:", sql, this.value) + } + + condition_len := 0 //所有条件数 + + for _, ch2 := range sql { + if string(ch2) == "?" { + condition_len++ + } + } + if condition_len != len(this.value) { + return errors.New("参数错误,条件值错误") + } + + stmt, err = this.tx.Prepare(sql) + + if err != nil { + return err + } + + this.stmt = stmt + + return nil +} + +// 拼删除sql +func (this *TxQuery) DeleteStmt() error { + + if this.dbname == "" && this.table == "" { + return errors.New("参数错误,没有数据表") + } + if len(this.where) < 1 { + return errors.New("参数错误,缺少条件") + } + + dbName := getTableName(this.dbname, this.table) + + var sql string + + sql = helper.StringJoin("delete from ", dbName, " where ", strings.Join(this.where, " and ")) + + if this.page_size > 0 { + sql = helper.StringJoin(sql, " limit ", strconv.Itoa(this.page_size)) + } + + if this.debug { + log.Println("delete sql:", sql, this.value) + } + + condition_len := 0 //所有条件数 + + for _, ch2 := range sql { + if string(ch2) == "?" { + condition_len++ + } + } + if condition_len != len(this.value) { + return errors.New("参数错误,条件值错误") + } + + stmt, err = this.tx.Prepare(sql) + + if err != nil { + return err + } + + this.stmt = stmt + + return nil +} + +/** + * 执行查询列表 + * return list error + */ +func (this *TxQuery) Select() ([]map[string]string, error) { + + err := this.QueryStmt() + if err != nil { + return []map[string]string{}, err + } + + if this.stmt == nil { + return []map[string]string{}, errors.New("缺少必要参数") + } + + return StmtForQueryList(this.stmt, this.value) +} + +/** + * 执行查询一条数据 + * return row error + */ +func (this *TxQuery) Find() (map[string]string, error) { + + err := this.QueryStmt() + if err != nil { + return map[string]string{}, err + } + + if this.stmt == nil { + return nil, errors.New("缺少必要参数") + } + return StmtForQueryRow(this.stmt, this.value) +} + +/** + * 执行更新 + * return is_updated error + */ +func (this *TxQuery) Update() (int64, error) { + + err := this.UpdateStmt() + if err != nil { + return 0, err + } + + return StmtForUpdateExec(this.stmt, this.value) +} + +/** + * 执行删除 + * return is_delete error + */ +func (this *TxQuery) Delete() (int64, error) { + + err := this.DeleteStmt() + if err != nil { + return 0, err + } + + return StmtForUpdateExec(this.stmt, this.value) +} + +/** + * 执行写入 + * return is_insert error + */ +func (this *TxQuery) Create() (int64, error) { + + err := this.CreateStmt() + if err != nil { + return 0, err + } + + return StmtForInsertExec(this.stmt, this.value) +} + +/** + * 提交 + */ +func (this *TxQuery) Commit() error { + return this.tx.Commit() +} + +/** + * 回滚 + */ +func (this *TxQuery) Rollback() error { + return this.tx.Rollback() +}