Compare commits
3 commits
56356867ad...486032d2d3
| Author | SHA1 | Date |
|---|---|---|
| | 486032d2d3 | |
| | 0de42adaf8 | |
| | c8b92cefb9 | |
```diff
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 	"time"
 )
 
```
```diff
@@ -23,12 +24,13 @@ type ChatMessage struct {
 }
 
 type ChatCompletionRequest struct {
 	Model       string        `json:"model"`
 	Messages    []ChatMessage `json:"messages"`
 	Temperature float64       `json:"temperature,omitempty"`
-	MaxTokens   int           `json:"max_tokens,omitempty"`
+	MaxTokens           *int          `json:"max_tokens,omitempty"`
+	MaxCompletionTokens *int          `json:"max_completion_tokens,omitempty"`
 	TopP        float64       `json:"top_p,omitempty"`
 	Stream      bool          `json:"stream,omitempty"`
 }
 
 type ChatCompletionResponse struct {
```
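Note: switching `MaxTokens` from `int` to `*int` matters for serialization. With `omitempty`, a nil pointer drops the field from the JSON body entirely, while a pointer to any value (including 0) is always sent, so the client can distinguish "unset" from 0 and emit exactly one of the two token fields. A minimal sketch of that behavior, where `req` is a trimmed stand-in for `ChatCompletionRequest` keeping only the two fields this diff touches:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-in for ChatCompletionRequest (hypothetical name `req`),
// keeping only the two token fields this diff touches.
type req struct {
	MaxTokens           *int `json:"max_tokens,omitempty"`
	MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
}

func main() {
	n := 512

	unset, _ := json.Marshal(req{})                         // {}
	legacy, _ := json.Marshal(req{MaxTokens: &n})           // {"max_tokens":512}
	modern, _ := json.Marshal(req{MaxCompletionTokens: &n}) // {"max_completion_tokens":512}

	fmt.Println(string(unset), string(legacy), string(modern))
}
```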
```diff
@@ -89,6 +91,24 @@ func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*C
 }
 
 func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
+	resp, err := c.doChatRequest(req)
+	if err == nil {
+		return resp, nil
+	}
+
+	if shouldRetryWithMaxCompletionTokens(err, req) {
+		tokens := *req.MaxTokens
+		retryReq := *req
+		retryReq.MaxTokens = nil
+		retryReq.MaxCompletionTokens = &tokens
+		fmt.Printf("OpenAI: retrying with max_completion_tokens=%d\n", tokens)
+		return c.doChatRequest(&retryReq)
+	}
+
+	return nil, err
+}
+
+func (c *OpenAIClient) doChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
 	jsonData, err := json.Marshal(req)
 	if err != nil {
 		fmt.Printf("OpenAI: Failed to marshal request: %v\n", err)
```
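Note: the retry wrapper copies the request by value (`retryReq := *req`), so the caller's `ChatCompletionRequest` is never mutated; only the copy has `MaxTokens` cleared and `MaxCompletionTokens` set before the second attempt. A self-contained sketch of that flow, where `fakeDo` and `request` are hypothetical stand-ins for `doChatRequest` and the real request type, and the error text is illustrative:

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// Hypothetical stand-in for ChatCompletionRequest.
type request struct {
	MaxTokens           *int
	MaxCompletionTokens *int
}

// fakeDo mimics an API that rejects max_tokens in favour of
// max_completion_tokens, the failure mode this diff works around.
func fakeDo(r *request) (string, error) {
	if r.MaxTokens != nil {
		return "", errors.New("Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead.")
	}
	return fmt.Sprintf("ok, max_completion_tokens=%d", *r.MaxCompletionTokens), nil
}

func main() {
	n := 512
	req := &request{MaxTokens: &n}

	_, err := fakeDo(req)
	if err != nil && strings.Contains(err.Error(), "max_completion_tokens") {
		tokens := *req.MaxTokens
		retry := *req // shallow copy: the caller's request stays untouched
		retry.MaxTokens = nil
		retry.MaxCompletionTokens = &tokens
		out, _ := fakeDo(&retry)
		fmt.Println(out) // ok, max_completion_tokens=512
	}
}
```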
```diff
@@ -192,7 +212,7 @@ func WithTemperature(temp float64) func(*ChatCompletionRequest) {
 
 func WithMaxTokens(tokens int) func(*ChatCompletionRequest) {
 	return func(req *ChatCompletionRequest) {
-		req.MaxTokens = tokens
+		req.MaxTokens = &tokens
 	}
 }
 
```
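Note: `WithMaxTokens` now takes the address of its own parameter. That is safe in Go: each call allocates a fresh `tokens` variable (it escapes to the heap), so two options built from the same helper never share storage. A minimal sketch of the options pattern as used in this file; the trimmed request type and the body of `WithTemperature` are assumptions based only on the signatures shown in this diff:

```go
package main

import "fmt"

// Trimmed stand-in for the request type in this file.
type ChatCompletionRequest struct {
	MaxTokens   *int
	Temperature float64
}

// Same shape as the helper in this diff: the option closes over its own
// parameter, so &tokens points at a per-call allocation.
func WithMaxTokens(tokens int) func(*ChatCompletionRequest) {
	return func(req *ChatCompletionRequest) { req.MaxTokens = &tokens }
}

func WithTemperature(t float64) func(*ChatCompletionRequest) {
	return func(req *ChatCompletionRequest) { req.Temperature = t }
}

func main() {
	req := &ChatCompletionRequest{}
	opts := []func(*ChatCompletionRequest){
		WithTemperature(0.2),
		WithMaxTokens(512),
	}
	for _, opt := range opts {
		opt(req)
	}
	fmt.Println(*req.MaxTokens, req.Temperature) // 512 0.2
}
```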
```diff
@@ -247,3 +267,21 @@ func (c *OpenAIClient) TestConnection() error {
 	}
 	return err
 }
+
+func shouldRetryWithMaxCompletionTokens(err error, req *ChatCompletionRequest) bool {
+	if err == nil || req == nil || req.MaxTokens == nil || req.MaxCompletionTokens != nil {
+		return false
+	}
+
+	msg := err.Error()
+	if strings.Contains(msg, "Unsupported parameter: 'max_tokens'") {
+		return true
+	}
+	if strings.Contains(msg, "max_tokens is not supported") {
+		return true
+	}
+	if strings.Contains(msg, "max_completion_tokens") {
+		return true
+	}
+	return false
+}
```
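Note: the guard clause only lets a retry fire when it can change the outcome: there must be an error, a non-nil request, a set `MaxTokens`, and no `MaxCompletionTokens` already present, so a second attempt can never loop back into the predicate. A small harness exercising it against sample error strings; the helper is a condensed copy of the one above (three `if` blocks folded into one boolean expression), and the error texts are illustrative:

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// Trimmed stand-in for the request type, keeping only the fields read here.
type ChatCompletionRequest struct {
	MaxTokens           *int
	MaxCompletionTokens *int
}

// Condensed copy of shouldRetryWithMaxCompletionTokens from this diff.
func shouldRetryWithMaxCompletionTokens(err error, req *ChatCompletionRequest) bool {
	if err == nil || req == nil || req.MaxTokens == nil || req.MaxCompletionTokens != nil {
		return false
	}
	msg := err.Error()
	return strings.Contains(msg, "Unsupported parameter: 'max_tokens'") ||
		strings.Contains(msg, "max_tokens is not supported") ||
		strings.Contains(msg, "max_completion_tokens")
}

func main() {
	n := 512
	req := &ChatCompletionRequest{MaxTokens: &n}

	for _, msg := range []string{
		"Unsupported parameter: 'max_tokens' is not supported with this model.", // retries
		"rate limit exceeded",                                                   // does not
	} {
		fmt.Println(msg, "=>", shouldRetryWithMaxCompletionTokens(errors.New(msg), req))
	}
}
```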