Browse Source

增加with语句

master
zhenghaorong 1 year ago
parent
commit
6d7835939e
3 changed files with 183 additions and 36 deletions
  1. +58
    -14
      chain.go
  2. +90
    -20
      db.go
  3. +35
    -2
      transaction_chain.go

+ 58
- 14
chain.go View File

@ -34,6 +34,7 @@ type Query struct {
conn *sql.DB conn *sql.DB
debug bool debug bool
dbtype string dbtype string
with [][]string //[[临时表的sql语句,临时表的名称]]
} }
func NewQuery(t ...string) *Query { func NewQuery(t ...string) *Query {
@ -103,6 +104,14 @@ func (this *Query) Groupby(groupby string) *Query {
this.groupby = groupby this.groupby = groupby
return this return this
} }
func (this *Query) With(with []string) *Query {
this.with = append(this.with, with)
return this
}
func (this *Query) Withs(withs [][]string) *Query {
this.with = append(this.with, withs...)
return this
}
func (this *Query) Where(where string) *Query { func (this *Query) Where(where string) *Query {
this.where = append(this.where, where) this.where = append(this.where, where)
return this return this
@ -198,6 +207,7 @@ func (this *Query) Clean() *Query {
this.upd_field = this.upd_field[0:0] this.upd_field = this.upd_field[0:0]
this.having = "" this.having = ""
this.alias = "" this.alias = ""
this.with = this.with[0:0]
return this return this
} }
@ -278,7 +288,7 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
return nil, errors.New("参数错误,没有数据表") return nil, errors.New("参数错误,没有数据表")
} }
var table = "" var table = ""
if strings.Contains(this.table, "select ") {
if strings.Contains(this.table, "select ") || strings.HasPrefix(this.table, "(") {
table = this.table table = this.table
} else { } else {
table = getTableName(this.dbname, this.table, this.dbtype) table = getTableName(this.dbname, this.table, this.dbtype)
@ -293,18 +303,40 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
} else { } else {
title = "*" title = "*"
} }
withSql := ""
if len(this.with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range this.with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if this.dbtype == "mssql" { if this.dbtype == "mssql" {
if this.page_size > 0 { if this.page_size > 0 {
sql = helper.StringJoin("select top ", helper.ToStr(this.page_size), " ")
sql = helper.StringJoin(withSql, "select top ", helper.ToStr(this.page_size), " ")
} else { } else {
sql = "select "
sql = helper.StringJoin(withSql, "select ")
} }
} else { } else {
if DB_PROVIDER == "TencentDB" { if DB_PROVIDER == "TencentDB" {
sql = "/*slave*/ select "
sql = helper.StringJoin("/*slave*/ ", withSql, " select ")
} else { } else {
sql = "select "
sql = helper.StringJoin(withSql, "select ")
} }
} }
@ -317,17 +349,29 @@ func (this *Query) BuildSelectSql() (map[string]interface{}, error) {
sql = helper.StringJoin(sql, " from ", table) sql = helper.StringJoin(sql, " from ", table)
if len(this.join) > 0 { if len(this.join) > 0 {
join_type := "left"
var builder strings.Builder
for _, joinitem := range this.join { for _, joinitem := range this.join {
if len(joinitem) < 2 { if len(joinitem) < 2 {
continue continue
} }
if len(joinitem) == 3 {
join_type = joinitem[2]
} else { //默认左连接
join_type = "left"
builder.WriteString(sql)
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
}
builder.WriteString(" join ")
if strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(this.dbname, joinitem[0]))
} }
sql = helper.StringJoin(sql, " ", join_type, " join ", getTableName(this.dbname, joinitem[0], this.dbtype), " on ", joinitem[1])
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
}
if builder.Len() > 0 {
sql = builder.String()
} }
} }
if len(this.where) > 0 || len(this.where_or) > 0 { if len(this.where) > 0 || len(this.where_or) > 0 {
@ -772,7 +816,7 @@ func (this *Query) DeleteStmt() error {
*/ */
func (this *Query) Select() ([]map[string]string, error) { func (this *Query) Select() ([]map[string]string, error) {
_, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.join,
_, rows, err := FetchRows(this.dbname, this.table, this.alias, this.title, this.with, this.join,
this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.page, this.page_size, this.debug) this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.page, this.page_size, this.debug)
return rows, err return rows, err
@ -803,7 +847,7 @@ func (this *Query) List() ([]map[string]string, error) {
*/ */
func (this *Query) Find() (map[string]string, error) { func (this *Query) Find() (map[string]string, error) {
_, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.join,
_, row, err := GetRow(this.dbname, this.table, this.alias, this.title, this.with, this.join,
this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.debug) this.where, this.where_or, this.value, this.orderby, this.groupby, this.having, this.debug)
return row, err return row, err


+ 90
- 20
db.go View File

@ -295,7 +295,7 @@ func GetData(dbName, table string, title string, where map[string]string, limit
* @param dbName 数据表名 * @param dbName 数据表名
* @param title 查询字段名 * @param title 查询字段名
*/ */
func GetRow(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) {
func GetRow(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, debug bool) (int, map[string]string, error) {
var count int = 0 var count int = 0
info := make(map[string]string) info := make(map[string]string)
@ -304,7 +304,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
} }
table := "" table := ""
if strings.Contains(table_name, "select ") {
if strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") {
table = table_name table = table_name
} else { } else {
table = getTableName(dbName, table_name) table = getTableName(dbName, table_name)
@ -316,10 +316,33 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
} else { } else {
title = "*" title = "*"
} }
withSql := ""
if len(with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if DB_PROVIDER == "TencentDB" { if DB_PROVIDER == "TencentDB" {
sql_str = helper.StringJoin("/*slave*/ select ", title)
sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title)
} else { } else {
sql_str = helper.StringJoin("select ", title)
sql_str = helper.StringJoin(withSql, "select ", title)
} }
if alias != "" { if alias != "" {
table = helper.StringJoin(table, " as ", alias) table = helper.StringJoin(table, " as ", alias)
@ -328,17 +351,29 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
sql_str = helper.StringJoin(sql_str, " from ", table) sql_str = helper.StringJoin(sql_str, " from ", table)
if len(join) > 0 { if len(join) > 0 {
var builder strings.Builder
for _, joinitem := range join { for _, joinitem := range join {
if len(joinitem) < 2 { if len(joinitem) < 2 {
continue continue
} }
if len(joinitem) == 4 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", joinitem[0], " on ", joinitem[1])
} else if len(joinitem) == 3 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql_str = helper.StringJoin(sql_str, " left join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
builder.WriteString(sql_str)
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
} }
builder.WriteString(" join ")
if strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) > 4 {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(dbName, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
}
if builder.Len() > 0 {
sql_str = builder.String()
} }
} }
if len(where) > 0 || len(where_or) > 0 { if len(where) > 0 || len(where_or) > 0 {
@ -449,7 +484,7 @@ func GetRow(dbName, table_name, alias string, titles string, join [][]string, wh
* @param dbName 数据表名 * @param dbName 数据表名
* @param title 查询字段名 * @param title 查询字段名
*/ */
func FetchRows(dbName, table_name, alias string, titles string, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) {
func FetchRows(dbName, table_name, alias string, titles string, with, join [][]string, where, where_or []string, valueList []interface{}, orderby, groupby, having string, page int, page_size int, debug bool) (int, []map[string]string, error) {
var count int = 0 var count int = 0
list := make([]map[string]string, 0) list := make([]map[string]string, 0)
@ -457,7 +492,7 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
return count, list, errors.New("没有数据表") return count, list, errors.New("没有数据表")
} }
table := "" table := ""
if strings.Contains(table_name, "select ") {
if strings.Contains(table_name, "select ") || strings.HasPrefix(table, "(") {
table = table_name table = table_name
} else { } else {
table = getTableName(dbName, table_name) table = getTableName(dbName, table_name)
@ -470,10 +505,33 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
} else { } else {
title = "*" title = "*"
} }
withSql := ""
if len(with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
if DB_PROVIDER == "TencentDB" { if DB_PROVIDER == "TencentDB" {
sql_str = helper.StringJoin("/*slave*/ select ", title)
sql_str = helper.StringJoin("/*slave*/ ", withSql, " select ", title)
} else { } else {
sql_str = helper.StringJoin("select ", title)
sql_str = helper.StringJoin(withSql, "select ", title)
} }
if alias != "" { if alias != "" {
table = helper.StringJoin(table, " as ", alias) table = helper.StringJoin(table, " as ", alias)
@ -482,17 +540,29 @@ func FetchRows(dbName, table_name, alias string, titles string, join [][]string,
sql_str = helper.StringJoin(sql_str, " from ", table) sql_str = helper.StringJoin(sql_str, " from ", table)
if len(join) > 0 { if len(join) > 0 {
var builder strings.Builder
for _, joinitem := range join { for _, joinitem := range join {
if len(joinitem) < 2 { if len(joinitem) < 2 {
continue continue
} }
if len(joinitem) == 4 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", joinitem[0], " on ", joinitem[1])
} else if len(joinitem) == 3 {
sql_str = helper.StringJoin(sql_str, " ", joinitem[2], " join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
} else { //默认左连接
sql_str = helper.StringJoin(sql_str, " left join ", getTableName(dbName, joinitem[0]), " on ", joinitem[1])
builder.WriteString(sql_str)
builder.WriteString(" ")
if len(joinitem) >= 3 {
builder.WriteString(joinitem[2])
} else {
builder.WriteString("left")
} }
builder.WriteString(" join ")
if strings.Contains(joinitem[0], "select ") || strings.HasPrefix(joinitem[0], "(") || len(joinitem) > 4 {
builder.WriteString(joinitem[0])
} else {
builder.WriteString(getTableName(dbName, joinitem[0]))
}
builder.WriteString(" on ")
builder.WriteString(joinitem[1])
}
if builder.Len() > 0 {
sql_str = builder.String()
} }
} }


+ 35
- 2
transaction_chain.go View File

@ -34,6 +34,7 @@ type TxQuery struct {
conn *sql.DB conn *sql.DB
tx *sql.Tx tx *sql.Tx
debug bool debug bool
with [][]string //[[临时表的sql语句,临时表的名称]]
} }
func NewTxQuery(t ...string) *TxQuery { func NewTxQuery(t ...string) *TxQuery {
@ -109,6 +110,14 @@ func (this *TxQuery) Where(where string) *TxQuery {
this.where = append(this.where, where) this.where = append(this.where, where)
return this return this
} }
func (this *TxQuery) With(with []string) *TxQuery {
this.with = append(this.with, with)
return this
}
func (this *TxQuery) Withs(withs [][]string) *TxQuery {
this.with = append(this.with, withs...)
return this
}
func (this *TxQuery) Wheres(wheres []string) *TxQuery { func (this *TxQuery) Wheres(wheres []string) *TxQuery {
if len(wheres) > 0 { if len(wheres) > 0 {
this.where = append(this.where, wheres...) this.where = append(this.where, wheres...)
@ -198,6 +207,7 @@ func (this *TxQuery) Clean() *TxQuery {
this.upd_field = this.upd_field[0:0] this.upd_field = this.upd_field[0:0]
this.having = "" this.having = ""
this.alias = "" this.alias = ""
this.with = this.with[0:0]
return this return this
} }
@ -212,7 +222,7 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
return nil, errors.New("参数错误,没有数据表") return nil, errors.New("参数错误,没有数据表")
} }
var table = "" var table = ""
if strings.Contains(this.table, "select ") {
if strings.Contains(this.table, "select ") || strings.HasPrefix(this.table, "(") {
table = this.table table = this.table
} else { } else {
table = getTableName(this.dbname, this.table) table = getTableName(this.dbname, this.table)
@ -225,7 +235,30 @@ func (this *TxQuery) BuildSelectSql() (map[string]interface{}, error) {
} else { } else {
title = "*" title = "*"
} }
sql = helper.StringJoin("select ", title)
withSql := ""
if len(this.with) > 0 {
var builder strings.Builder
builder.WriteString("WITH ")
boo := false
for k, v := range this.with {
if len(v) < 2 {
continue
}
if k != 0 {
builder.WriteString(", ")
}
builder.WriteString(v[1])
builder.WriteString(" as (")
builder.WriteString(v[0])
builder.WriteString(")")
boo = true
}
if boo {
builder.WriteString(" ")
withSql = builder.String()
}
}
sql = helper.StringJoin(withSql, "select ", title)
if this.alias != "" { if this.alias != "" {
table = helper.StringJoin(table, " as ", this.alias) table = helper.StringJoin(table, " as ", this.alias)


Loading…
Cancel
Save