From 6d7835939e76e894d9eaefd78ac38be611978cd5 Mon Sep 17 00:00:00 2001 From: zhenghaorong Date: Tue, 24 Sep 2024 10:51:19 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0with=E8=AF=AD=E5=8F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chain.go | 72 ++++++++++++++++++++++------ db.go | 110 +++++++++++++++++++++++++++++++++++-------- transaction_chain.go | 37 ++++++++++++++- 3 files changed, 183 insertions(+), 36 deletions(-) diff --git a/chain.go b/chain.go index 66f99dc..5f96812 100644 --- a/chain.go +++ b/chain.go @@ -34,6 +34,7 @@ type Query struct { conn *sql.DB debug bool dbtype string + with [][]string //[[临时表的sql语句,临时表的名称]] } func NewQuery(t ...string) *Query { @@ -103,6 +104,14 @@ func (this *Query) Groupby(groupby string) *Query { this.groupby = groupby return this } +func (this *Query) With(with []string) *Query { + this.with = append(this.with, with) + return this +} +func (this *Query) Withs(withs [][]string) *Query { + this.with = append(this.with, withs...) + return this +} func (this *Query) Where(where string) *Query { this.where = append(this.where, where) return this @@ -198,6 +207,7 @@ func (this *Query) Clean() *Query { this.upd_field = this.upd_field[0:0] this.having = "" this.alias = "" + this.with = this.with[0:0] return this } @@ -278,7 +288,7 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) { return nil, errors.New("参数错误,没有数据表") } var table = "" - if strings.Contains(this.table, "select ") { + if strings.Contains(this.table, "select ") || strings.HasPrefix(this.table, "(") { table = this.table } else { table = getTableName(this.dbname, this.table, this.dbtype) @@ -293,18 +303,40 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) { } else { title = "*" } - + withSql := "" + if len(this.with) > 0 { + var builder strings.Builder + builder.WriteString("WITH ") + boo := false + for k, v := range this.with { + if len(v) < 2 { + continue + } + if k != 0 { + builder.WriteString(", ") + } + builder.WriteString(v[1]) + builder.WriteString(" as (") + builder.WriteString(v[0]) + builder.WriteString(")") + boo = true + } + if boo { + builder.WriteString(" ") + withSql = builder.String() + } + } if this.dbtype == "mssql" { if this.page_size > 0 { - sql = helper.StringJoin("select top ", helper.ToStr(this.page_size), " ") + sql = helper.StringJoin(withSql, "select top ", helper.ToStr(this.page_size), " ") } else { - sql = "select " + sql = helper.StringJoin(withSql, "select ") } } else { if DB_PROVIDER == "TencentDB" { - sql = "/*slave*/ select " + sql = helper.StringJoin("/*slave*/ ", withSql, " select ") } else { - sql = "select " + sql = helper.StringJoin(withSql, "select ") } } @@ -317,17 +349,29 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) { sql = helper.StringJoin(sql, " from ", table) if len(this.join) > 0 { - join_type := "left" + var builder strings.Builder for _, joinitem := range this.join { if len(joinitem) < 2 { continue } - if len(joinitem) == 3 { - join_type = joinitem[2] - } else { //默认左连接 - join_type = "left" + builder.WriteString(sql) + builder.WriteString(" ") + if len(joinitem) >= 3 { + builder.WriteString(joinitem[2]) + } else { + builder.WriteString("left") + } + builder.WriteString(" join ") + if strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") { + builder.WriteString(joinitem[0]) + } else { + builder.WriteString(getTableName(this.dbname, joinitem[0])) } - sql = helper.StringJoin(sql, " ", join_type, " join ", getTableName(this.dbname, joinitem[0], this.dbtype), " on ", joinitem[1]) + builder.WriteString(" on ") + builder.WriteString(joinitem[1]) + } + if builder.Len() > 0 { + sql = builder.String() } } if len(this.where) > 0 || len(this.where_or) > 0 { @@ -772,7 +816,7 @@ func (this *Query) DeleteStmt() error { */ func (this *Query) Select() ([]map[string]string, error) { - _, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.join, + _, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.with, this.join, this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.page, this.page_size, this.debug) return rows, err @@ -803,7 +847,7 @@ func (this *Query) List() ([]map[string]string, error) { */ func (this *Query) Find() (map[string]string, error) { - _, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.join, + _, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.with, this.join, this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.debug) return row, err diff --git a/db.go b/db.go index f636724..e2acdd2 100644 --- a/db.go +++ b/db.go @@ -295,7 +295,7 @@ func GetData(dbName, table string, title string, where map[string]string, limit * @param dbName 数据表名 * @param title 查询字段名 */ -func GetRow(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) { +func GetRow(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) { var count int = 0 info := make(map[string]string) @@ -304,7 +304,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh } table := "" - if strings.Contains(table_name, "select ") { + if strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") { table = table_name } else { table = getTableName(dbName, table_name) @@ -316,10 +316,33 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh } else { title = "*" } + withSql := "" + if len(with) > 0 { + var builder strings.Builder + builder.WriteString("WITH ") + boo := false + for k, v := range with { + if len(v) < 2 { + continue + } + if k != 0 { + builder.WriteString(", ") + } + builder.WriteString(v[1]) + builder.WriteString(" as (") + builder.WriteString(v[0]) + builder.WriteString(")") + boo = true + } + if boo { + builder.WriteString(" ") + withSql = builder.String() + } + } if DB_PROVIDER == "TencentDB" { - sql_str = helper.StringJoin("/*slave*/ select ", title) + sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title) } else { - sql_str = helper.StringJoin("select ", title) + sql_str = helper.StringJoin(withSql, "select ", title) } if alias != "" { table = helper.StringJoin(table, " as ", alias) @@ -328,17 +351,29 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh sql_str = helper.StringJoin(sql_str, " from ", table) if len(join) > 0 { + var builder strings.Builder for _, joinitem := range join { if len(joinitem) < 2 { continue } - if len(joinitem) == 4 { - sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", joinitem[0], " on ", joinitem[1]) - } else if len(joinitem) == 3 { - sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1]) - } else { //默认左连接 - sql_str = helper.StringJoin(sql_str, " left join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1]) + builder.WriteString(sql_str) + builder.WriteString(" ") + if len(joinitem) >= 3 { + builder.WriteString(joinitem[2]) + } else { + builder.WriteString("left") } + builder.WriteString(" join ") + if strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) > 4 { + builder.WriteString(joinitem[0]) + } else { + builder.WriteString(getTableName(dbName, joinitem[0])) + } + builder.WriteString(" on ") + builder.WriteString(joinitem[1]) + } + if builder.Len() > 0 { + sql_str = builder.String() } } if len(where) > 0 || len(where_or) > 0 { @@ -449,7 +484,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh * @param dbName 数据表名 * @param title 查询字段名 */ -func FetchRows(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) { +func FetchRows(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) { var count int = 0 list := make([]map[string]string, 0) @@ -457,7 +492,7 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string, return count, list, errors.New("没有数据表") } table := "" - if strings.Contains(table_name, "select ") { + if strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") { table = table_name } else { table = getTableName(dbName, table_name) @@ -470,10 +505,33 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string, } else { title = "*" } + withSql := "" + if len(with) > 0 { + var builder strings.Builder + builder.WriteString("WITH ") + boo := false + for k, v := range with { + if len(v) < 2 { + continue + } + if k != 0 { + builder.WriteString(", ") + } + builder.WriteString(v[1]) + builder.WriteString(" as (") + builder.WriteString(v[0]) + builder.WriteString(")") + boo = true + } + if boo { + builder.WriteString(" ") + withSql = builder.String() + } + } if DB_PROVIDER == "TencentDB" { - sql_str = helper.StringJoin("/*slave*/ select ", title) + sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title) } else { - sql_str = helper.StringJoin("select ", title) + sql_str = helper.StringJoin(withSql, "select ", title) } if alias != "" { table = helper.StringJoin(table, " as ", alias) @@ -482,17 +540,29 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string, sql_str = helper.StringJoin(sql_str, " from ", table) if len(join) > 0 { + var builder strings.Builder for _, joinitem := range join { if len(joinitem) < 2 { continue } - if len(joinitem) == 4 { - sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", joinitem[0], " on ", joinitem[1]) - } else if len(joinitem) == 3 { - sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1]) - } else { //默认左连接 - sql_str = helper.StringJoin(sql_str, " left join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1]) + builder.WriteString(sql_str) + builder.WriteString(" ") + if len(joinitem) >= 3 { + builder.WriteString(joinitem[2]) + } else { + builder.WriteString("left") } + builder.WriteString(" join ") + if strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) > 4 { + builder.WriteString(joinitem[0]) + } else { + builder.WriteString(getTableName(dbName, joinitem[0])) + } + builder.WriteString(" on ") + builder.WriteString(joinitem[1]) + } + if builder.Len() > 0 { + sql_str = builder.String() } } diff --git a/transaction_chain.go b/transaction_chain.go index 11d273a..2cc6934 100644 --- a/transaction_chain.go +++ b/transaction_chain.go @@ -34,6 +34,7 @@ type TxQuery struct { conn *sql.DB tx *sql.Tx debug bool + with [][]string //[[临时表的sql语句,临时表的名称]] } func NewTxQuery(t ...string) *TxQuery { @@ -109,6 +110,14 @@ func (this *TxQuery) Where(where string) *TxQuery { this.where = append(this.where, where) return this } +func (this *TxQuery) With(with []string) *TxQuery { + this.with = append(this.with, with) + return this +} +func (this *TxQuery) Withs(withs [][]string) *TxQuery { + this.with = append(this.with, withs...) + return this +} func (this *TxQuery) Wheres(wheres []string) *TxQuery { if len(wheres) > 0 { this.where = append(this.where, wheres...) @@ -198,6 +207,7 @@ func (this *TxQuery) Clean() *TxQuery { this.upd_field = this.upd_field[0:0] this.having = "" this.alias = "" + this.with = this.with[0:0] return this } @@ -212,7 +222,7 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) { return nil, errors.New("参数错误,没有数据表") } var table = "" - if strings.Contains(this.table, "select ") { + if strings.Contains(this.table, "select ") || strings.HasPrefix(this.table, "(") { table = this.table } else { table = getTableName(this.dbname, this.table) @@ -225,7 +235,30 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) { } else { title = "*" } - sql = helper.StringJoin("select ", title) + withSql := "" + if len(this.with) > 0 { + var builder strings.Builder + builder.WriteString("WITH ") + boo := false + for k, v := range this.with { + if len(v) < 2 { + continue + } + if k != 0 { + builder.WriteString(", ") + } + builder.WriteString(v[1]) + builder.WriteString(" as (") + builder.WriteString(v[0]) + builder.WriteString(")") + boo = true + } + if boo { + builder.WriteString(" ") + withSql = builder.String() + } + } + sql = helper.StringJoin(withSql, "select ", title) if this.alias != "" { table = helper.StringJoin(table, " as ", this.alias)