This commit is contained in:
Liang Ding 2023-05-04 10:11:29 +08:00
parent 8ae3b354ab
commit a48154ba84
No known key found for this signature in database
GPG Key ID: 136F30F901A2231D
3 changed files with 91 additions and 8 deletions

View File

@ -21,6 +21,7 @@ import (
"github.com/88250/gulu" "github.com/88250/gulu"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/siyuan-note/siyuan/kernel/model"
"github.com/siyuan-note/siyuan/kernel/sql" "github.com/siyuan-note/siyuan/kernel/sql"
"github.com/siyuan-note/siyuan/kernel/util" "github.com/siyuan-note/siyuan/kernel/util"
) )
@ -35,7 +36,7 @@ func SQL(c *gin.Context) {
} }
stmt := arg["stmt"].(string) stmt := arg["stmt"].(string)
result, err := sql.Query(stmt) result, err := sql.Query(stmt, model.Conf.Search.Limit)
if nil != err { if nil != err {
ret.Code = 1 ret.Code = 1
ret.Msg = err.Error() ret.Msg = err.Error()

View File

@ -624,7 +624,7 @@ func searchBySQL(stmt string, beforeLen, page int) (ret []*Block, matchedBlockCo
stmt = strings.ReplaceAll(stmt, "select * ", "select COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` ") stmt = strings.ReplaceAll(stmt, "select * ", "select COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` ")
} }
stmt = removeLimitClause(stmt) stmt = removeLimitClause(stmt)
result, _ := sql.Query(stmt) result, _ := sql.QueryNoLimit(stmt)
if 1 > len(ret) { if 1 > len(ret) {
return return
} }
@ -745,7 +745,7 @@ func fullTextSearchCountByRegexp(exp, boxFilter, pathFilter, typeFilter string)
fieldFilter := fieldRegexp(exp) fieldFilter := fieldRegexp(exp)
stmt := "SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `blocks` WHERE " + fieldFilter + " AND type IN " + typeFilter stmt := "SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `blocks` WHERE " + fieldFilter + " AND type IN " + typeFilter
stmt += boxFilter + pathFilter stmt += boxFilter + pathFilter
result, _ := sql.Query(stmt) result, _ := sql.QueryNoLimit(stmt)
if 1 > len(result) { if 1 > len(result) {
return return
} }
@ -785,7 +785,7 @@ func fullTextSearchByFTS(query, boxFilter, pathFilter, typeFilter, orderBy strin
func fullTextSearchCount(query, boxFilter, pathFilter, typeFilter string) (matchedBlockCount, matchedRootCount int) { func fullTextSearchCount(query, boxFilter, pathFilter, typeFilter string) (matchedBlockCount, matchedRootCount int) {
query = gulu.Str.RemoveInvisible(query) query = gulu.Str.RemoveInvisible(query)
if ast.IsNodeIDPattern(query) { if ast.IsNodeIDPattern(query) {
ret, _ := sql.Query("SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `blocks` WHERE `id` = '" + query + "'") ret, _ := sql.QueryNoLimit("SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `blocks` WHERE `id` = '" + query + "'")
if 1 > len(ret) { if 1 > len(ret) {
return return
} }
@ -802,7 +802,7 @@ func fullTextSearchCount(query, boxFilter, pathFilter, typeFilter string) (match
stmt := "SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `" + table + "` WHERE (`" + table + "` MATCH '" + columnFilter() + ":(" + query + ")'" stmt := "SELECT COUNT(id) AS `matches`, COUNT(DISTINCT(root_id)) AS `docs` FROM `" + table + "` WHERE (`" + table + "` MATCH '" + columnFilter() + ":(" + query + ")'"
stmt += ") AND type IN " + typeFilter stmt += ") AND type IN " + typeFilter
stmt += boxFilter + pathFilter stmt += boxFilter + pathFilter
result, _ := sql.Query(stmt) result, _ := sql.QueryNoLimit(stmt)
if 1 > len(result) { if 1 > len(result) {
return return
} }

View File

@ -19,6 +19,7 @@ package sql
import ( import (
"bytes" "bytes"
"database/sql" "database/sql"
"math"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@ -378,7 +379,45 @@ func QueryBookmarkLabels() (ret []string) {
return return
} }
func Query(stmt string) (ret []map[string]interface{}, err error) { func QueryNoLimit(stmt string) (ret []map[string]interface{}, err error) {
return queryRawStmt(stmt, math.MaxInt)
}
func Query(stmt string, limit int) (ret []map[string]interface{}, err error) {
parsedStmt, err := sqlparser.Parse(stmt)
if nil != err {
return queryRawStmt(stmt, limit)
}
switch parsedStmt.(type) {
case *sqlparser.Select:
slct := parsedStmt.(*sqlparser.Select)
if nil == slct.Limit {
slct.Limit = &sqlparser.Limit{
Rowcount: &sqlparser.SQLVal{
Type: sqlparser.IntVal,
Val: []byte(strconv.Itoa(limit)),
},
}
} else {
if nil != slct.Limit.Rowcount && 0 < len(slct.Limit.Rowcount.(*sqlparser.SQLVal).Val) {
limit, _ = strconv.Atoi(string(slct.Limit.Rowcount.(*sqlparser.SQLVal).Val))
if 0 >= limit {
limit = 32
}
}
slct.Limit.Rowcount = &sqlparser.SQLVal{
Type: sqlparser.IntVal,
Val: []byte(strconv.Itoa(limit)),
}
}
stmt = sqlparser.String(slct)
default:
return
}
ret = []map[string]interface{}{} ret = []map[string]interface{}{}
rows, err := query(stmt) rows, err := query(stmt)
if nil != err { if nil != err {
@ -413,6 +452,49 @@ func Query(stmt string) (ret []map[string]interface{}, err error) {
return return
} }
func queryRawStmt(stmt string, limit int) (ret []map[string]interface{}, err error) {
rows, err := query(stmt)
if nil != err {
if strings.Contains(err.Error(), "syntax error") {
return
}
return
}
defer rows.Close()
cols, err := rows.Columns()
if nil != err || nil == cols {
return
}
noLimit := !strings.Contains(strings.ToLower(stmt), " limit ")
var count, errCount int
for rows.Next() {
columns := make([]interface{}, len(cols))
columnPointers := make([]interface{}, len(cols))
for i := range columns {
columnPointers[i] = &columns[i]
}
if err = rows.Scan(columnPointers...); nil != err {
return
}
m := make(map[string]interface{})
for i, colName := range cols {
val := columnPointers[i].(*interface{})
m[colName] = *val
}
ret = append(ret, m)
count++
if (noLimit && limit < count) || 0 < errCount {
break
}
}
return
}
func SelectBlocksRawStmtNoParse(stmt string, limit int) (ret []*Block) { func SelectBlocksRawStmtNoParse(stmt string, limit int) (ret []*Block) {
return selectBlocksRawStmt(stmt, limit) return selectBlocksRawStmt(stmt, limit)
} }
@ -491,7 +573,7 @@ func selectBlocksRawStmt(stmt string, limit int) (ret []*Block) {
} }
defer rows.Close() defer rows.Close()
confLimit := !strings.Contains(strings.ToLower(stmt), " limit ") noLimit := !strings.Contains(strings.ToLower(stmt), " limit ")
var count, errCount int var count, errCount int
for rows.Next() { for rows.Next() {
count++ count++
@ -502,7 +584,7 @@ func selectBlocksRawStmt(stmt string, limit int) (ret []*Block) {
errCount++ errCount++
} }
if (confLimit && limit < count) || 0 < errCount { if (noLimit && limit < count) || 0 < errCount {
break break
} }
} }