diff --git a/chain.go b/chain.go index 0d7c8bd..9f05bdc 100644 --- a/chain.go +++ b/chain.go @@ -3,15 +3,15 @@ package dbquery import ( "database/sql" "errors" - "log" - - // "log" "strconv" "strings" "git.tetele.net/tgo/helper" ) +var stmt *sql.Stmt +var err error + type Query struct { dbname string table string @@ -99,14 +99,13 @@ func (this *Query) QueryStmt() error { if this.dbname == "" && this.table == "" { return errors.New("参数错误,没有数据表") } - if len(this.where)+len(this.where_or) < len(this.value) { - return errors.New("参数错误,条件值错误") - } + // if len(this.where)+len(this.where_or) < len(this.value) { + // return errors.New("参数错误,条件值错误") + // } table := getTableName(this.dbname, this.table) - var stmt *sql.Stmt - var err error + // var err error var sql, title string @@ -167,7 +166,17 @@ func (this *Query) QueryStmt() error { sql = helper.StringJoin(sql, " limit ", from, " , ", offset) } } - log.Println(sql) + // log.Println(sql) + condition_len := 0 //所有条件数 + for _, ch2 := range sql { + if string(ch2) == "?" { + condition_len++ + } + } + if condition_len != len(this.value) { + return errors.New("参数错误,条件值错误") + } + stmt, err = DB.Prepare(sql) if err != nil { @@ -191,8 +200,8 @@ func (this *Query) UpdateStmt() error { dbName := getTableName(this.dbname, this.table) - var stmt *sql.Stmt - var err error + // var stmt *sql.Stmt + // var err error var sql string @@ -200,6 +209,17 @@ func (this *Query) UpdateStmt() error { sql = helper.StringJoin(sql, " where ", strings.Join(this.where, " and ")) + condition_len := 0 //所有条件数 + + for _, ch2 := range sql { + if string(ch2) == "?" { + condition_len++ + } + } + if condition_len != len(this.value) { + return errors.New("参数错误,条件值错误") + } + stmt, err = DB.Prepare(sql) if err != nil { @@ -220,13 +240,24 @@ func (this *Query) CreateStmt() error { dbName := getTableName(this.dbname, this.table) - var stmt *sql.Stmt - var err error + // var stmt *sql.Stmt + // var err error var sql string sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , ")) + condition_len := 0 //所有条件数 + + for _, ch2 := range sql { + if string(ch2) == "?" { + condition_len++ + } + } + if condition_len != len(this.value) { + return errors.New("参数错误,条件值错误") + } + stmt, err = DB.Prepare(sql) if err != nil { @@ -247,14 +278,14 @@ func (this *Query) DeleteStmt() error { if len(this.where) < 1 { return errors.New("参数错误,缺少条件") } - if len(this.where) != len(this.value) { - return errors.New("参数错误,条件值错误") - } + // if len(this.where) != len(this.value) { + // return errors.New("参数错误,条件值错误") + // } dbName := getTableName(this.dbname, this.table) - var stmt *sql.Stmt - var err error + // var stmt *sql.Stmt + // var err error var sql string @@ -264,6 +295,17 @@ func (this *Query) DeleteStmt() error { sql = helper.StringJoin(sql, " limit ", strconv.Itoa(this.page_size)) } + condition_len := 0 //所有条件数 + + for _, ch2 := range sql { + if string(ch2) == "?" { + condition_len++ + } + } + if condition_len != len(this.value) { + return errors.New("参数错误,条件值错误") + } + stmt, err = DB.Prepare(sql) if err != nil { diff --git a/chain_test.go b/chain_test.go index ecf3af2..aef4f56 100644 --- a/chain_test.go +++ b/chain_test.go @@ -7,7 +7,7 @@ import ( func Test_Chain(t *testing.T) { Connect("127.0.0.1", "root", "123456", "test1_tetele_com", "3306") - ret, err := new(Query).Db("test1_tetele_com").Table("ttl_user").Title("id,username").WhereOr("id =?").WhereOr("id = ?").Value(2).Value(4).PageSize(4).Select() + ret, err := new(Query).Db("test1_tetele_com").Table("ttl_user").Title("id,username").WhereOr("id =?").WhereOr("id = ?").Value(2).Value(4).Value(4).PageSize(4).Select() t.Log(len(ret)) t.Log(ret) diff --git a/go.mod b/go.mod index cfc9706..be40e1b 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module git.tetele.net/tgo/dbquery go 1.14 require ( - git.tetele.net/tgo/helper v0.1.0 // indirect - github.com/go-sql-driver/mysql v1.5.0 // indirect + git.tetele.net/tgo/helper v0.1.0 + github.com/go-sql-driver/mysql v1.5.0 )