Merge remote-tracking branch 'upstream/master'

This commit is contained in:
empty
2026-01-24 01:11:32 +08:00

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strings"
"time" "time"
) )
@@ -23,12 +24,13 @@ type ChatMessage struct {
} }
type ChatCompletionRequest struct { type ChatCompletionRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []ChatMessage `json:"messages"` Messages []ChatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens *int `json:"max_tokens,omitempty"`
TopP float64 `json:"top_p,omitempty"` MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
Stream bool `json:"stream,omitempty"` TopP float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
} }
type ChatCompletionResponse struct { type ChatCompletionResponse struct {
@@ -89,6 +91,24 @@ func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*C
} }
func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) { func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
resp, err := c.doChatRequest(req)
if err == nil {
return resp, nil
}
if shouldRetryWithMaxCompletionTokens(err, req) {
tokens := *req.MaxTokens
retryReq := *req
retryReq.MaxTokens = nil
retryReq.MaxCompletionTokens = &tokens
fmt.Printf("OpenAI: retrying with max_completion_tokens=%d\n", tokens)
return c.doChatRequest(&retryReq)
}
return nil, err
}
func (c *OpenAIClient) doChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
fmt.Printf("OpenAI: Failed to marshal request: %v\n", err) fmt.Printf("OpenAI: Failed to marshal request: %v\n", err)
@@ -192,7 +212,7 @@ func WithTemperature(temp float64) func(*ChatCompletionRequest) {
func WithMaxTokens(tokens int) func(*ChatCompletionRequest) { func WithMaxTokens(tokens int) func(*ChatCompletionRequest) {
return func(req *ChatCompletionRequest) { return func(req *ChatCompletionRequest) {
req.MaxTokens = tokens req.MaxTokens = &tokens
} }
} }
@@ -247,3 +267,21 @@ func (c *OpenAIClient) TestConnection() error {
} }
return err return err
} }
func shouldRetryWithMaxCompletionTokens(err error, req *ChatCompletionRequest) bool {
if err == nil || req == nil || req.MaxTokens == nil || req.MaxCompletionTokens != nil {
return false
}
msg := err.Error()
if strings.Contains(msg, "Unsupported parameter: 'max_tokens'") {
return true
}
if strings.Contains(msg, "max_tokens is not supported") {
return true
}
if strings.Contains(msg, "max_completion_tokens") {
return true
}
return false
}