diff --git a/chain.go b/chain.go index 7198141..0cfb3e6 100644 --- a/chain.go +++ b/chain.go @@ -29,25 +29,31 @@ type Query struct { stmt *sql.Stmt conn *sql.DB debug bool + dbtype string } func NewQuery(t ...string) *Query { var conn_type *sql.DB = DB + var db_type string = "mysql" + if len(t) > 0 { switch t[0] { case "mysql": conn_type = DB + db_type = "mysql" case "mssql": //sql server conn_type = MSDB_CONN + db_type = "mssql" } } return &Query{ - conn: conn_type, + conn: conn_type, + dbtype: db_type, } } @@ -149,7 +155,7 @@ func (this *Query) QueryStmt() error { return errors.New("参数错误,没有数据表") } - table := getTableName(this.dbname, this.table) + table := getTableName(this.dbname, this.table, this.dbtype) // var err error @@ -174,9 +180,9 @@ func (this *Query) QueryStmt() error { continue } if len(joinitem) == 3 { - sql = helper.StringJoin(sql, " ", joinitem[2], " join ", getTableName(this.dbname, joinitem[0]), " on ", joinitem[1]) + sql = helper.StringJoin(sql, " ", joinitem[2], " join ", getTableName(this.dbname, joinitem[0], this.dbtype), " on ", joinitem[1]) } else { //默认左连接 - sql = helper.StringJoin(sql, " left join ", getTableName(this.dbname, joinitem[0]), " on ", joinitem[1]) + sql = helper.StringJoin(sql, " left join ", getTableName(this.dbname, joinitem[0], this.dbtype), " on ", joinitem[1]) } } } @@ -250,7 +256,7 @@ func (this *Query) UpdateStmt() error { return errors.New("参数错误,缺少条件") } - dbName := getTableName(this.dbname, this.table) + dbName := getTableName(this.dbname, this.table, this.dbtype) var sql string @@ -294,7 +300,7 @@ func (this *Query) CreateStmt() error { return errors.New("参数错误,没有数据表") } - dbName := getTableName(this.dbname, this.table) + dbName := getTableName(this.dbname, this.table, this.dbtype) var sql string @@ -339,7 +345,7 @@ func (this *Query) DeleteStmt() error { return errors.New("参数错误,缺少条件") } - dbName := getTableName(this.dbname, this.table) + dbName := getTableName(this.dbname, this.table, this.dbtype) var sql string diff --git a/conn.go b/conn.go index 54882b2..b010bd1 100755 --- a/conn.go +++ b/conn.go @@ -58,16 +58,36 @@ func CloseConn() error { /** * 检测表名 */ -func getTableName(dbName, table string) string { +func getTableName(dbName, table string, dbtype ...string) string { - if strings.Contains(table, ".") { - return table + var db_type string = "mysql" + + if len(dbtype) > 0 { + if dbtype[0] != "" { + db_type = dbtype[0] + } } - if dbName != "" { - return helper.StringJoin(dbName, ".", table) - } else { - return table + + var ret string + + switch db_type { + + case "mysql": + + if strings.Contains(table, ".") { + ret = table + } + if dbName != "" { + ret = helper.StringJoin(dbName, ".", table) + } else { + ret = table + } + + case "mssql": + ret = helper.StringJoin(dbName, ".", table) + } + return ret } func GetDbTableName(dbName, table string) string { diff --git a/sqlserver.go b/sqlserver.go index 8675d26..c8cc099 100644 --- a/sqlserver.go +++ b/sqlserver.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "strconv" "log" @@ -26,7 +27,9 @@ func MSConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error //连接字符串 - connString := fmt.Sprintf("server=%s;port%d;database=%s;user id=%s;password=%s", DBHOST, DBPORT, DBNAME, DBUSER, DBPWD) + db_port, _ := strconv.Atoi(DBPORT) + + connString := fmt.Sprintf("server=%s;port=%d;database=%s;user id=%s;password=%s", DBHOST, db_port, DBNAME, DBUSER, DBPWD) log.Println(connString) @@ -39,7 +42,9 @@ func MSConnect(DBHOST, DBUSER, DBPWD, DBNAME, DBPORT string, conns ...int) error time.Sleep(time.Second * 5) } else { - log.Println("msdb connected") + err = MSDB_CONN.Ping() + + log.Println("msdb connected", err) break }