Merge remote-tracking branch 'upstream/master'

This commit is contained in:
empty
2026-01-24 01:11:32 +08:00

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
"time"
)
@@ -26,7 +27,8 @@ type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
}
@@ -89,6 +91,24 @@ func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*C
}
func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
resp, err := c.doChatRequest(req)
if err == nil {
return resp, nil
}
if shouldRetryWithMaxCompletionTokens(err, req) {
tokens := *req.MaxTokens
retryReq := *req
retryReq.MaxTokens = nil
retryReq.MaxCompletionTokens = &tokens
fmt.Printf("OpenAI: retrying with max_completion_tokens=%d\n", tokens)
return c.doChatRequest(&retryReq)
}
return nil, err
}
func (c *OpenAIClient) doChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
jsonData, err := json.Marshal(req)
if err != nil {
fmt.Printf("OpenAI: Failed to marshal request: %v\n", err)
@@ -192,7 +212,7 @@ func WithTemperature(temp float64) func(*ChatCompletionRequest) {
func WithMaxTokens(tokens int) func(*ChatCompletionRequest) {
return func(req *ChatCompletionRequest) {
req.MaxTokens = tokens
req.MaxTokens = &tokens
}
}
@@ -247,3 +267,21 @@ func (c *OpenAIClient) TestConnection() error {
}
return err
}
func shouldRetryWithMaxCompletionTokens(err error, req *ChatCompletionRequest) bool {
if err == nil || req == nil || req.MaxTokens == nil || req.MaxCompletionTokens != nil {
return false
}
msg := err.Error()
if strings.Contains(msg, "Unsupported parameter: 'max_tokens'") {
return true
}
if strings.Contains(msg, "max_tokens is not supported") {
return true
}
if strings.Contains(msg, "max_completion_tokens") {
return true
}
return false
}