From bd754d1507b155c5711d34316515c4b4d0e185d0 Mon Sep 17 00:00:00 2001 From: loshiqi <553578653@qq.com> Date: Tue, 15 Jul 2025 17:18:28 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=BC=E5=AE=B9=E9=AB=98=E6=96=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chain.go | 7 +++++-- prepare.go | 20 ++++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/chain.go b/chain.go index c424770..65720ad 100644 --- a/chain.go +++ b/chain.go @@ -715,7 +715,10 @@ func (this *Query) CreateAllStmt() error { if this.conn == nil { this.conn = DB } - + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + sql = helper.StringJoin(sql, " RETURNING id") + } stmt, err = this.conn.Prepare(sql) if err != nil { @@ -754,7 +757,7 @@ func (this *Query) CreateStmt() error { } } - sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")") + sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")", " RETURNING id") } else { sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , ")) } diff --git a/prepare.go b/prepare.go index 37dde8b..26955ed 100644 --- a/prepare.go +++ b/prepare.go @@ -234,11 +234,23 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) { * @return lastId error */ func StmtForInsertExec(stmt *sql.Stmt, valuelist []interface{}) (int64, error) { - res, err := stmt.Exec(valuelist...) - if err != nil { - return 0, errors.New("创建失败:" + err.Error()) + if DB_PROVIDER == "PgsqlDb" { + row := stmt.QueryRow(valuelist...) + var id int64 + err = row.Scan(&id) // 扫描 RETURNING 返回的 ID + if err != nil { + return 0, errors.New("创建失败:" + err.Error()) + } + return id, nil + + } else { + res, err := stmt.Exec(valuelist...) + if err != nil { + return 0, errors.New("创建失败:" + err.Error()) + } + return res.LastInsertId() } - return res.LastInsertId() + } /**