Compare commits
3 Commits
56356867ad
...
486032d2d3
| Author | SHA1 | Date | |
|---|---|---|---|
| | 486032d2d3 | | |
| | 0de42adaf8 | | |
| | c8b92cefb9 | | |
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -23,12 +24,13 @@ type ChatMessage struct {
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionResponse struct {
|
||||
@@ -89,6 +91,24 @@ func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*C
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
||||
resp, err := c.doChatRequest(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if shouldRetryWithMaxCompletionTokens(err, req) {
|
||||
tokens := *req.MaxTokens
|
||||
retryReq := *req
|
||||
retryReq.MaxTokens = nil
|
||||
retryReq.MaxCompletionTokens = &tokens
|
||||
fmt.Printf("OpenAI: retrying with max_completion_tokens=%d\n", tokens)
|
||||
return c.doChatRequest(&retryReq)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) doChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
fmt.Printf("OpenAI: Failed to marshal request: %v\n", err)
|
||||
@@ -192,7 +212,7 @@ func WithTemperature(temp float64) func(*ChatCompletionRequest) {
|
||||
|
||||
func WithMaxTokens(tokens int) func(*ChatCompletionRequest) {
|
||||
return func(req *ChatCompletionRequest) {
|
||||
req.MaxTokens = tokens
|
||||
req.MaxTokens = &tokens
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,3 +267,21 @@ func (c *OpenAIClient) TestConnection() error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func shouldRetryWithMaxCompletionTokens(err error, req *ChatCompletionRequest) bool {
|
||||
if err == nil || req == nil || req.MaxTokens == nil || req.MaxCompletionTokens != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
msg := err.Error()
|
||||
if strings.Contains(msg, "Unsupported parameter: 'max_tokens'") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(msg, "max_tokens is not supported") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(msg, "max_completion_tokens") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user