diff --git a/package-lock.json b/package-lock.json index 142a22f..22d745e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "droid2api", - "version": "1.0.0", + "version": "1.3.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "droid2api", - "version": "1.0.0", + "version": "1.3.1", "license": "MIT", "dependencies": { "express": "^4.18.2", diff --git a/routes.js b/routes.js index bb88f7b..921f092 100644 --- a/routes.js +++ b/routes.js @@ -501,9 +501,93 @@ async function handleDirectMessages(req, res) { } catch (error) { logError('Error in /v1/messages', error); - res.status(500).json({ + res.status(500).json({ error: 'Internal server error', - message: error.message + message: error.message + }); + } +} + +// 处理 Anthropic count_tokens 请求 +async function handleCountTokens(req, res) { + logInfo('POST /v1/messages/count_tokens'); + + try { + const anthropicRequest = req.body; + const modelId = anthropicRequest.model; + + if (!modelId) { + return res.status(400).json({ error: 'model is required' }); + } + + const model = getModelById(modelId); + if (!model) { + return res.status(404).json({ error: `Model ${modelId} not found` }); + } + + // 只允许 anthropic 类型端点 + if (model.type !== 'anthropic') { + return res.status(400).json({ + error: 'Invalid endpoint type', + message: `/v1/messages/count_tokens 接口只支持 anthropic 类型端点,当前模型 ${modelId} 是 ${model.type} 类型` + }); + } + + const endpoint = getEndpointByType('anthropic'); + if (!endpoint) { + return res.status(500).json({ error: 'Endpoint type anthropic not found' }); + } + + // Get API key + let authHeader; + try { + const clientAuthFromXApiKey = req.headers['x-api-key'] + ? `Bearer ${req.headers['x-api-key']}` + : null; + authHeader = await getApiKey(req.headers.authorization || clientAuthFromXApiKey); + } catch (error) { + logError('Failed to get API key', error); + return res.status(500).json({ + error: 'API key not available', + message: 'Failed to get or refresh API key. Please check server logs.' + }); + } + + const clientHeaders = req.headers; + const headers = getAnthropicHeaders(authHeader, clientHeaders, false, modelId); + + // 构建 count_tokens 端点 URL + const countTokensUrl = endpoint.base_url.replace('/v1/messages', '/v1/messages/count_tokens'); + + logInfo(`Forwarding to count_tokens endpoint: ${countTokensUrl}`); + logRequest('POST', countTokensUrl, headers, anthropicRequest); + + const response = await fetch(countTokensUrl, { + method: 'POST', + headers, + body: JSON.stringify(anthropicRequest) + }); + + logInfo(`Response status: ${response.status}`); + + if (!response.ok) { + const errorText = await response.text(); + logError(`Count tokens error: ${response.status}`, new Error(errorText)); + return res.status(response.status).json({ + error: `Endpoint returned ${response.status}`, + details: errorText + }); + } + + const data = await response.json(); + logResponse(200, null, data); + res.json(data); + + } catch (error) { + logError('Error in /v1/messages/count_tokens', error); + res.status(500).json({ + error: 'Internal server error', + message: error.message }); } } @@ -512,5 +596,6 @@ async function handleDirectMessages(req, res) { router.post('/v1/chat/completions', handleChatCompletions); router.post('/v1/responses', handleDirectResponses); router.post('/v1/messages', handleDirectMessages); +router.post('/v1/messages/count_tokens', handleCountTokens); export default router; diff --git a/server.js b/server.js index 2152110..a0e52a8 100644 --- a/server.js +++ b/server.js @@ -31,7 +31,8 @@ app.get('/', (req, res) => { 'GET /v1/models', 'POST /v1/chat/completions', 'POST /v1/responses', - 'POST /v1/messages' + 'POST /v1/messages', + 'POST /v1/messages/count_tokens' ] }); }); @@ -90,7 +91,8 @@ app.use((req, res, next) => { 'GET /v1/models', 'POST /v1/chat/completions', 'POST /v1/responses', - 'POST /v1/messages' + 'POST /v1/messages', + 'POST /v1/messages/count_tokens' ] }); }); @@ -125,6 +127,7 @@ app.use((err, req, res, next) => { logInfo(' POST /v1/chat/completions'); logInfo(' POST /v1/responses'); logInfo(' POST /v1/messages'); + logInfo(' POST /v1/messages/count_tokens'); }) .on('error', (err) => { if (err.code === 'EADDRINUSE') {