diff --git a/chain.go b/chain.go index c424770..65720ad 100644 --- a/chain.go +++ b/chain.go @@ -715,7 +715,10 @@ func (this *Query) CreateAllStmt() error { if this.conn == nil { this.conn = DB } - + if DB_PROVIDER == "PgsqlDb" { + sql = sqlx.Rebind(sqlx.DOLLAR, sql) + sql = helper.StringJoin(sql, " RETURNING id") + } stmt, err = this.conn.Prepare(sql) if err != nil { @@ -754,7 +757,7 @@ func (this *Query) CreateStmt() error { } } - sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")") + sql = helper.StringJoin("insert into ", dbName, " ("+strings.Join(insert_data, " , ")+")", " VALUES ", "("+strings.Join(value_data, " , ")+")", " RETURNING id") } else { sql = helper.StringJoin("insert into ", dbName, " set ", strings.Join(this.data, " , ")) } diff --git a/prepare.go b/prepare.go index 37dde8b..26955ed 100644 --- a/prepare.go +++ b/prepare.go @@ -234,11 +234,23 @@ func StmtForInsert(dbName, table string, data []string) (*sql.Stmt, error) { * @return lastId error */ func StmtForInsertExec(stmt *sql.Stmt, valuelist []interface{}) (int64, error) { - res, err := stmt.Exec(valuelist...) - if err != nil { - return 0, errors.New("创建失败:" + err.Error()) + if DB_PROVIDER == "PgsqlDb" { + row := stmt.QueryRow(valuelist...) + var id int64 + err = row.Scan(&id) // 扫描 RETURNING 返回的 ID + if err != nil { + return 0, errors.New("创建失败:" + err.Error()) + } + return id, nil + + } else { + res, err := stmt.Exec(valuelist...) + if err != nil { + return 0, errors.New("创建失败:" + err.Error()) + } + return res.LastInsertId() } - return res.LastInsertId() + } /**