Browse Source

增加批量更新

master
zhenghaorong 1 year ago
parent
commit
4ac2e8c30d
3 changed files with 244 additions and 8 deletions
  1. +121
    -4
      chain.go
  2. +8
    -2
      db.go
  3. +115
    -2
      transaction_chain.go

+ 121
- 4
chain.go View File

@ -21,11 +21,13 @@ type Query struct {
where []string
where_or []string
join [][]string //[["tablea as a","a.id=b.id","left"]]
save_data []map[string]interface{} //[["title":"a","num":1,],["title":"a","num":1,]]
save_data []map[string]interface{} //批量操作的数据[["title":"a","num":1,],["title":"a","num":1,]]
upd_field []string // 批量更新时需要更新的字段,为空时按除id外的字段进行更新
data []string
value []interface{}
orderby string
groupby string
having string
page int
page_size int
stmt *sql.Stmt
@ -89,7 +91,10 @@ func (this *Query) PageSize(page_num int) *Query {
this.page_size = page_num
return this
}
func (this *Query) Having(having string) *Query {
this.having = having
return this
}
func (this *Query) Orderby(orderby string) *Query {
this.orderby = orderby
return this
@ -120,6 +125,14 @@ func (this *Query) SaveDatas(value []map[string]interface{}) *Query {
this.save_data = append(this.save_data, value...)
return this
}
func (this *Query) UpdField(value string) *Query {
this.upd_field = append(this.upd_field, value)
return this
}
func (this *Query) UpdFields(value []string) *Query {
this.upd_field = append(this.upd_field, value...)
return this
}
func (this *Query) Value(value interface{}) *Query {
this.value = append(this.value, value)
return this
@ -160,6 +173,9 @@ func (this *Query) Clean() *Query {
this.groupby = ""
this.page = 0
this.page_size = 0
this.save_data = this.save_data[0:0]
this.upd_field = this.upd_field[0:0]
this.having = ""
return this
}
@ -220,6 +236,10 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
if this.groupby != "" {
sql = helper.StringJoin(sql, " group by ", this.groupby)
}
if this.having != "" {
sql = helper.StringJoin(sql, " having ", this.having)
}
if this.orderby != "" {
sql = helper.StringJoin(sql, " order by ", this.orderby)
@ -327,6 +347,92 @@ func (this *Query) UpdateStmt() error {
return nil
}
// 拼批量存在更新不存在插入sql
func (this *Query) UpdateAllStmt() error {
if this.dbname == "" && this.table == "" {
return errors.New("参数错误,没有数据表")
}
dbName := getTableName(this.dbname, this.table)
var sql string
var dataSql = []string{}
var valSql = []string{}
var updSql = []string{}
var updFieldLen = len(this.upd_field)
if len(this.save_data) > 0 {
//批量操作
this.data = []string{}
this.value = []interface{}{}
for i, datum := range this.save_data {
if i == 0 {
for k, _ := range datum {
this.data = append(this.data, k)
dataSql = append(dataSql, "?")
if updFieldLen == 0 && k != "id" {
updSql = append(updSql, k+"=values("+k+")")
}
}
if updFieldLen > 0 {
for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")")
}
}
}
for _, k := range this.data {
this.value = append(this.value, datum[k])
}
valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")")
}
} else {
//添加一条
for _, datum := range this.data {
dataSql = append(dataSql, "?")
if updFieldLen == 0 && datum != "id" {
updSql = append(updSql, datum+"=values("+datum+")")
}
}
if updFieldLen > 0 {
for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")")
}
}
valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")")
}
if len(this.data) == 0 {
return errors.New("参数错误,没有数据表")
}
if len(this.value) == 0 {
return errors.New("参数错误,条件值错误")
}
setText := " values "
if len(valSql) > 1 {
setText = " value "
}
sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","), " ON DUPLICATE KEY UPDATE ", strings.Join(updSql, " , "))
if this.debug {
log.Println("insert on duplicate key update sql:", sql, this.value)
}
if this.conn == nil {
this.conn = DB
}
stmt, err = this.conn.Prepare(sql)
if err != nil {
return err
}
this.stmt = stmt
return nil
}
// 拼批量插入sql
func (this *Query) CreateAllStmt() error {
@ -492,7 +598,7 @@ func (this *Query) DeleteStmt() error {
func (this *Query) Select() ([]map[string]string, error) {
_, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.join,
this.where, this.where_or, this.value, this.orderby, this.groupby, this.page, this.page_size, this.debug)
this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.page, this.page_size, this.debug)
return rows, err
}
@ -523,7 +629,7 @@ func (this *Query) List() ([]map[string]string, error) {
func (this *Query) Find() (map[string]string, error) {
_, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.join,
this.where, this.where_or, this.value, this.orderby, this.groupby, this.debug)
this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.debug)
return row, err
}
@ -561,6 +667,17 @@ func (this *Query) Update() (int64, error) {
return StmtForUpdateExec(this.stmt, this.value)
}
//批量更新
func (this *Query) UpdateAll() (int64, error) {
err := this.UpdateAllStmt()
if err != nil {
return 0, err
}
return StmtForUpdateExec(this.stmt, this.value)
}
/**
* 执行删除
* return is_delete error


+ 8
- 2
db.go View File

@ -295,7 +295,7 @@ func GetData(dbName, table string, title string, where map[string]string, limit
* @param dbName 数据表名
* @param title 查询字段名
*/
func GetRow(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby string, debug bool) (int, map[string]string, error) {
func GetRow(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) {
var count int = 0
info := make(map[string]string)
@ -352,6 +352,9 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
if groupby != "" {
sql_str = helper.StringJoin(sql_str, " group by ", groupby)
}
if having != "" {
sql_str = helper.StringJoin(sql_str, " having ", having)
}
if orderby != "" {
sql_str = helper.StringJoin(sql_str, " order by ", orderby)
}
@ -439,7 +442,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
* @param dbName 数据表名
* @param title 查询字段名
*/
func FetchRows(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby string, page int, page_size int, debug bool) (int, []map[string]string, error) {
func FetchRows(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) {
var count int = 0
list := make([]map[string]string, 0)
@ -497,6 +500,9 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
if groupby != "" {
sql_str = helper.StringJoin(sql_str, " group by ", groupby)
}
if having != "" {
sql_str = helper.StringJoin(sql_str, " HAVING ", having)
}
if orderby != "" {
sql_str = helper.StringJoin(sql_str, " order by ", orderby)
}


+ 115
- 2
transaction_chain.go View File

@ -23,9 +23,11 @@ type TxQuery struct {
join [][]string //[["tablea as a","a.id=b.id","left"]]
data []string
value []interface{}
save_data []map[string]interface{}
save_data []map[string]interface{} //批量操作的数据[["title":"a","num":1,],["title":"a","num":1,]]
upd_field []string // 批量更新时需要更新的字段,为空时按除id外的字段进行更新
orderby string
groupby string
having string
page int
page_size int
stmt *sql.Stmt
@ -99,7 +101,10 @@ func (this *TxQuery) Groupby(groupby string) *TxQuery {
this.groupby = groupby
return this
}
func (this *TxQuery) Having(having string) *TxQuery {
this.having = having
return this
}
func (this *TxQuery) Where(where string) *TxQuery {
this.where = append(this.where, where)
return this
@ -126,6 +131,14 @@ func (this *TxQuery) SaveDatas(value []map[string]interface{}) *TxQuery {
this.save_data = append(this.save_data, value...)
return this
}
func (this *TxQuery) UpdField(value string) *TxQuery {
this.upd_field = append(this.upd_field, value)
return this
}
func (this *TxQuery) UpdFields(value []string) *TxQuery {
this.upd_field = append(this.upd_field, value...)
return this
}
func (this *TxQuery) Values(values []interface{}) *TxQuery {
this.value = append(this.value, values...)
return this
@ -161,6 +174,9 @@ func (this *TxQuery) Clean() *TxQuery {
this.groupby = ""
this.page = 0
this.page_size = 0
this.save_data = this.save_data[0:0]
this.upd_field = this.upd_field[0:0]
this.having = ""
return this
}
@ -219,6 +235,10 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
if this.groupby != "" {
sql = helper.StringJoin(sql, " group by ", this.groupby)
}
if this.having != "" {
sql = helper.StringJoin(sql, " having ", this.having)
}
if this.orderby != "" {
sql = helper.StringJoin(sql, " order by ", this.orderby)
@ -323,6 +343,88 @@ func (this *TxQuery) UpdateStmt() error {
return nil
}
// 拼批量存在更新不存在插入sql
func (this *TxQuery) UpdateAllStmt() error {
if this.dbname == "" && this.table == "" {
return errors.New("参数错误,没有数据表")
}
dbName := getTableName(this.dbname, this.table)
var sql string
var dataSql = []string{}
var valSql = []string{}
var updSql = []string{}
var updFieldLen = len(this.upd_field)
if len(this.save_data) > 0 {
this.data = []string{}
this.value = []interface{}{}
for i, datum := range this.save_data {
if i == 0 {
for k, _ := range datum {
this.data = append(this.data, k)
dataSql = append(dataSql, "?")
if updFieldLen == 0 && k != "id" {
updSql = append(updSql, k+"=values("+k+")")
}
}
if updFieldLen > 0 {
for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")")
}
}
}
for _, k := range this.data {
this.value = append(this.value, datum[k])
}
valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")")
}
} else {
for _, datum := range this.data {
dataSql = append(dataSql, "?")
if updFieldLen == 0 && datum != "id" {
updSql = append(updSql, datum+"=values("+datum+")")
}
}
if updFieldLen > 0 {
for _, k := range this.upd_field {
updSql = append(updSql, k+"=values("+k+")")
}
}
valSql = append(valSql, "("+strings.Join(dataSql, " , ")+")")
}
if len(this.data) == 0 {
return errors.New("参数错误,没有数据表")
}
if len(this.value) == 0 {
return errors.New("参数错误,条件值错误")
}
setText := " values "
if len(valSql) > 1 {
setText = " value "
}
sql = helper.StringJoin("insert into ", dbName, " (", strings.Join(this.data, " , "), ")", setText, strings.Join(valSql, ","))
if len(this.value) == 0 {
return errors.New("参数错误,条件值错误")
}
if this.debug {
log.Println("insert on duplicate key update sql:", sql, this.value)
}
stmt, err = this.tx.Prepare(sql)
if err != nil {
return err
}
this.stmt = stmt
return nil
}
// 拼插入sql
func (this *TxQuery) CreateStmt() error {
@ -522,6 +624,17 @@ func (this *TxQuery) Update() (int64, error) {
return StmtForUpdateExec(this.stmt, this.value)
}
//批量更新
func (this *TxQuery) UpdateAll() (int64, error) {
err := this.UpdateAllStmt()
if err != nil {
return 0, err
}
return StmtForUpdateExec(this.stmt, this.value)
}
/**
* 执行删除
* return is_delete error


Loading…
Cancel
Save