diff --git a/kernel/model/ai.go b/kernel/model/ai.go index 401e69014..bc48dc5a9 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -19,7 +19,9 @@ package model import ( "bytes" - "github.com/siyuan-note/siyuan/kernel/sql" + "github.com/88250/lute/ast" + "github.com/88250/lute/parse" + "github.com/siyuan-note/siyuan/kernel/treenode" "github.com/siyuan-note/siyuan/kernel/util" ) @@ -70,12 +72,40 @@ func isOpenAIAPIEnabled() bool { } func getBlocksContent(ids []string) string { - sqlBlocks := sql.GetBlocks(ids) - buf := bytes.Buffer{} - for _, sqlBlock := range sqlBlocks { - buf.WriteString(sqlBlock.Content) - buf.WriteString("\n\n") + var nodes []*ast.Node + trees := map[string]*parse.Tree{} + for _, id := range ids { + bt := treenode.GetBlockTree(id) + if nil == bt { + continue + } + + var tree *parse.Tree + if tree = trees[bt.RootID]; nil == tree { + tree, _ = loadTreeByBlockID(bt.RootID) + if nil == tree { + continue + } + + trees[bt.RootID] = tree + } + + if node := treenode.GetNodeInTree(tree, id); nil != node { + if ast.NodeDocument == node.Type { + for child := node.FirstChild; nil != child; child = child.Next { + nodes = append(nodes, child) + } + } else { + nodes = append(nodes, node) + } + } } + buf := bytes.Buffer{} + for _, node := range nodes { + content := treenode.NodeStaticContent(node, nil, true) + buf.WriteString(content) + buf.WriteString("\n\n") + } return buf.String() }