diff --git a/kernel/api/ai.go b/kernel/api/ai.go index 2aa5b1005..13717e233 100644 --- a/kernel/api/ai.go +++ b/kernel/api/ai.go @@ -54,3 +54,38 @@ func chatGPTContinueWriteBlocks(c *gin.Context) { } ret.Data = model.ChatGPTContinueWriteBlocks(ids) } + +func chatGPTTranslate(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + arg, ok := util.JsonArg(c, ret) + if !ok { + return + } + + idsArg := arg["ids"].([]interface{}) + var ids []string + for _, id := range idsArg { + ids = append(ids, id.(string)) + } + lang := arg["lang"].(string) + ret.Data = model.ChatGPTTranslate(ids, lang) +} + +func chatGPTSummary(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + arg, ok := util.JsonArg(c, ret) + if !ok { + return + } + + idsArg := arg["ids"].([]interface{}) + var ids []string + for _, id := range idsArg { + ids = append(ids, id.(string)) + } + ret.Data = model.ChatGPTSummary(ids) +} diff --git a/kernel/api/router.go b/kernel/api/router.go index 2030ed49c..de415a515 100644 --- a/kernel/api/router.go +++ b/kernel/api/router.go @@ -330,4 +330,6 @@ func ServeAPI(ginServer *gin.Engine) { ginServer.Handle("POST", "/api/ai/chatGPT", model.CheckAuth, chatGPT) ginServer.Handle("POST", "/api/ai/chatGPTContinueWriteBlocks", model.CheckAuth, chatGPTContinueWriteBlocks) + ginServer.Handle("POST", "/api/ai/chatGPTTranslate", model.CheckAuth, chatGPTTranslate) + ginServer.Handle("POST", "/api/ai/chatGPTSummary", model.CheckAuth, chatGPTSummary) } diff --git a/kernel/model/ai.go b/kernel/model/ai.go index de3f1872a..401e69014 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -23,20 +23,32 @@ import ( "github.com/siyuan-note/siyuan/kernel/util" ) +func ChatGPTSummary(ids []string) (ret string) { + if !isOpenAIAPIEnabled() { + return + } + + msg := getBlocksContent(ids) + ret = util.ChatGPTSummary(msg, Conf.Lang) + return +} + +func ChatGPTTranslate(ids []string, lang string) (ret string) { + if !isOpenAIAPIEnabled() { + return + } + + msg := getBlocksContent(ids) + ret = util.ChatGPTTranslate(msg, lang) + return +} + func ChatGPTContinueWriteBlocks(ids []string) (ret string) { if !isOpenAIAPIEnabled() { return } - sqlBlocks := sql.GetBlocks(ids) - - buf := bytes.Buffer{} - for _, sqlBlock := range sqlBlocks { - buf.WriteString(sqlBlock.Content) - buf.WriteString("\n\n") - } - - msg := buf.String() + msg := getBlocksContent(ids) ret, _ = util.ChatGPTContinueWrite(msg, nil) return } @@ -56,3 +68,14 @@ func isOpenAIAPIEnabled() bool { } return true } + +func getBlocksContent(ids []string) string { + sqlBlocks := sql.GetBlocks(ids) + buf := bytes.Buffer{} + for _, sqlBlock := range sqlBlocks { + buf.WriteString(sqlBlock.Content) + buf.WriteString("\n\n") + } + + return buf.String() +} diff --git a/kernel/util/openai.go b/kernel/util/openai.go index db71654c4..ca98d6528 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -45,6 +45,18 @@ func ChatGPT(msg string) (ret string) { return } +func ChatGPTTranslate(msg string, lang string) (ret string) { + msg = "Translate to " + lang + ":\n" + msg + ret, _ = ChatGPTContinueWrite(msg, nil) + return +} + +func ChatGPTSummary(msg string, lang string) (ret string) { + msg = "Summarized as follows, the result is in {" + lang + "}:\n" + msg + ret, _ = ChatGPTContinueWrite(msg, nil) + return +} + func ChatGPTContinueWrite(msg string, contextMsgs []string) (ret string, retContextMsgs []string) { if "" == OpenAIAPIKey { return