diff --git a/app/appearance/langs/en_US.json b/app/appearance/langs/en_US.json index cec4b3e76..2ea378014 100644 --- a/app/appearance/langs/en_US.json +++ b/app/appearance/langs/en_US.json @@ -273,6 +273,7 @@ "apiMaxTokensTip": "The max_tokens parameter passed in when requesting the API is used to control the length of the generated text", "apiBaseURL": "API Base URL", "apiBaseURLTip": "The base address of the request, such as https://api.openai.com/v1", + "apiUserAgentTip": "The user agent that initiated the request, that is, the HTTP header User-Agent", "skip": "Skip", "nextRound": "Next round", "save": "Save", diff --git a/app/appearance/langs/es_ES.json b/app/appearance/langs/es_ES.json index 798039a4e..927a24ada 100644 --- a/app/appearance/langs/es_ES.json +++ b/app/appearance/langs/es_ES.json @@ -273,6 +273,7 @@ "apiMaxTokensTip": "El parámetro max_tokens que se pasa al solicitar la API se usa para controlar la longitud del texto generado", "apiBaseURL": "URL base de la API", "apiBaseURLTip": "La dirección base de la solicitud, como https://api.openai.com/v1", + "apiUserAgentTip": "El agente de usuario que inició la solicitud, es decir, el encabezado HTTP User-Agent", "skip": "barco", "nextRound": "Siguiente ronda", "save": "Ahorrar", diff --git a/app/appearance/langs/fr_FR.json b/app/appearance/langs/fr_FR.json index 5f52fc70c..8eee1e94a 100644 --- a/app/appearance/langs/fr_FR.json +++ b/app/appearance/langs/fr_FR.json @@ -273,6 +273,7 @@ "apiMaxTokensTip": "Le paramètre max_tokens transmis lors de la demande de l'API est utilisé pour contrôler la longueur du texte généré", "apiBaseURL": "URL de base de l'API", "apiBaseURLTip": "L'adresse de base de la requête, telle que https://api.openai.com/v1", + "apiUserAgentTip": "L'agent utilisateur qui a initié la requête, c'est-à-dire l'en-tête HTTP User-Agent", "skip": "Navire", "nextRound": "Prochain tour", "save": "Sauvegarder", diff --git a/app/appearance/langs/zh_CHT.json b/app/appearance/langs/zh_CHT.json index b8fd5de11..adb08de6b 100644 --- a/app/appearance/langs/zh_CHT.json +++ b/app/appearance/langs/zh_CHT.json @@ -273,6 +273,7 @@ "apiMaxTokensTip": "請求 API 時傳入的 max_tokens 參數,用於控制生成的文字長度", "apiBaseURL": "API 基礎地址", "apiBaseURLTip": "發起請求的基礎地址,如 https://api.openai.com/v1", + "apiUserAgentTip": "發起請求的使用者代理,即 HTTP 標頭 User-Agent", "skip": "跳過", "nextRound": "下一輪", "save": "保存", diff --git a/app/appearance/langs/zh_CN.json b/app/appearance/langs/zh_CN.json index 1dab2295e..486e2c737 100644 --- a/app/appearance/langs/zh_CN.json +++ b/app/appearance/langs/zh_CN.json @@ -273,6 +273,7 @@ "apiMaxTokensTip": "请求 API 时传入的 max_tokens 参数,用于控制生成的文本长度", "apiBaseURL": "API 基础地址", "apiBaseURLTip": "发起请求的基础地址,如 https://api.openai.com/v1", + "apiUserAgentTip": "发起请求的用户代理,即 HTTP 标头 User-Agent", "skip": "跳过", "nextRound": "下一轮", "save": "保存", diff --git a/app/src/config/ai.ts b/app/src/config/ai.ts index cd5cbc904..19d19d112 100644 --- a/app/src/config/ai.ts +++ b/app/src/config/ai.ts @@ -50,6 +50,12 @@ export const ai = {
${window.siyuan.languages.apiBaseURLTip}
+ +
+ User-Agent +
+ +
${window.siyuan.languages.apiUserAgentTip}
`; /// #else responsiveHTML = `
@@ -106,6 +112,14 @@ export const ai = {
+ +
+
+ User-Agent +
${window.siyuan.languages.apiUserAgentTip}
+ + +
`; /// #endif return `
@@ -124,6 +138,7 @@ export const ai = { item.addEventListener("change", () => { fetchPost("/api/setting/setAI", { openAI: { + apiUserAgent: (ai.element.querySelector("#apiUserAgent") as HTMLInputElement).value, apiBaseURL: (ai.element.querySelector("#apiBaseURL") as HTMLInputElement).value, apiKey: (ai.element.querySelector("#apiKey") as HTMLInputElement).value, apiModel: (ai.element.querySelector("#apiModel") as HTMLSelectElement).value, diff --git a/app/src/types/index.d.ts b/app/src/types/index.d.ts index 409069342..ebd85a02e 100644 --- a/app/src/types/index.d.ts +++ b/app/src/types/index.d.ts @@ -722,6 +722,7 @@ interface IConfig { } ai: { openAI: { + apiUserAgent: string apiBaseURL: string apiKey: string apiModel: string diff --git a/kernel/conf/ai.go b/kernel/conf/ai.go index 9d9ee5f85..4a2abbb67 100644 --- a/kernel/conf/ai.go +++ b/kernel/conf/ai.go @@ -17,6 +17,7 @@ package conf import ( + "github.com/siyuan-note/siyuan/kernel/util" "os" "strconv" @@ -34,13 +35,15 @@ type OpenAI struct { APIModel string `json:"apiModel"` APIMaxTokens int `json:"apiMaxTokens"` APIBaseURL string `json:"apiBaseURL"` + APIUserAgent string `json:"apiUserAgent"` } func NewAI() *AI { openAI := &OpenAI{ - APITimeout: 30, - APIModel: openai.GPT3Dot5Turbo, - APIBaseURL: "https://api.openai.com/v1", + APITimeout: 30, + APIModel: openai.GPT3Dot5Turbo, + APIBaseURL: "https://api.openai.com/v1", + APIUserAgent: util.UserAgent, } openAI.APIKey = os.Getenv("SIYUAN_OPENAI_API_KEY") @@ -67,5 +70,8 @@ func NewAI() *AI { openAI.APIBaseURL = baseURL } + if userAgent := os.Getenv("SIYUAN_OPENAI_API_USER_AGENT"); "" != userAgent { + openAI.APIUserAgent = userAgent + } return &AI{OpenAI: openAI} } diff --git a/kernel/model/ai.go b/kernel/model/ai.go index 97308bf9d..5c5035839 100644 --- a/kernel/model/ai.go +++ b/kernel/model/ai.go @@ -92,7 +92,7 @@ func chatGPTContinueWrite(msg string, contextMsgs []string, cloud bool) (ret str if cloud { gpt = &CloudGPT{} } else { - gpt = &OpenAIGPT{c: util.NewOpenAIClient(Conf.AI.OpenAI.APIKey, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIBaseURL)} + gpt = &OpenAIGPT{c: util.NewOpenAIClient(Conf.AI.OpenAI.APIKey, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APIUserAgent)} } buf := &bytes.Buffer{} diff --git a/kernel/model/conf.go b/kernel/model/conf.go index aed770dfd..024f7446a 100644 --- a/kernel/model/conf.go +++ b/kernel/model/conf.go @@ -405,15 +405,24 @@ func InitConf() { if "" == Conf.AI.OpenAI.APIModel { Conf.AI.OpenAI.APIModel = openai.GPT3Dot5Turbo } + if "" == Conf.AI.OpenAI.APIUserAgent { + Conf.AI.OpenAI.APIUserAgent = util.UserAgent + } if "" != Conf.AI.OpenAI.APIKey { logging.LogInfof("OpenAI API enabled\n"+ + " userAgent=%s\n"+ " baseURL=%s\n"+ " timeout=%ds\n"+ " proxy=%s\n"+ " model=%s\n"+ " maxTokens=%d", - Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APITimeout, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIModel, Conf.AI.OpenAI.APIMaxTokens) + Conf.AI.OpenAI.APIUserAgent, + Conf.AI.OpenAI.APIBaseURL, + Conf.AI.OpenAI.APITimeout, + Conf.AI.OpenAI.APIProxy, + Conf.AI.OpenAI.APIModel, + Conf.AI.OpenAI.APIMaxTokens) } Conf.ReadOnly = util.ReadOnly diff --git a/kernel/util/openai.go b/kernel/util/openai.go index 388404c89..e7c90f629 100644 --- a/kernel/util/openai.go +++ b/kernel/util/openai.go @@ -75,17 +75,32 @@ func ChatGPT(msg string, contextMsgs []string, c *openai.Client, model string, m return } -func NewOpenAIClient(apiKey, apiProxy, apiBaseURL string) *openai.Client { +func NewOpenAIClient(apiKey, apiProxy, apiBaseURL, apiUserAgent string) *openai.Client { config := openai.DefaultConfig(apiKey) + transport := &http.Transport{} if "" != apiProxy { proxyUrl, err := url.Parse(apiProxy) if nil != err { logging.LogErrorf("OpenAI API proxy failed: %v", err) } else { - config.HTTPClient = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}} + transport.Proxy = http.ProxyURL(proxyUrl) } } - + config.HTTPClient = &http.Client{Transport: newAddHeaderTransport(transport, apiUserAgent)} config.BaseURL = apiBaseURL return openai.NewClientWithConfig(config) } + +type AddHeaderTransport struct { + RoundTripper http.RoundTripper + UserAgent string +} + +func (adt *AddHeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Add("User-Agent", adt.UserAgent) + return adt.RoundTripper.RoundTrip(req) +} + +func newAddHeaderTransport(transport *http.Transport, userAgent string) *AddHeaderTransport { + return &AddHeaderTransport{RoundTripper: transport, UserAgent: userAgent} +}