|
- package llmruntime
-
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"
	"unicode/utf8"
)
-
// Request carries everything needed for a single text-generation call,
// independent of which provider ultimately serves it.
type Request struct {
	Provider     string   // provider identifier, e.g. "openai", "anthropic", "google", "xai", "ollama"
	Model        string   // provider-specific model name
	BaseURL      string   // optional API base URL override; each client falls back to a provider default when empty
	APIKey       string   // credential sent with the request (ignored where the provider needs none)
	Temperature  *float64 // sampling temperature; nil means use the default (0)
	MaxTokens    *int     // output token cap; nil means use the default (1200)
	SystemPrompt string   // optional system/instruction prompt
	UserPrompt   string   // the user message body
}
-
// Client generates a text completion for a Request against one provider API.
type Client interface {
	Generate(ctx context.Context, req Request) (string, error)
}
-
// Factory builds provider clients that all share a single http.Client,
// so connection pooling and the timeout are configured in one place.
type Factory struct {
	httpClient *http.Client
}
-
- func NewFactory(timeout time.Duration) *Factory {
- if timeout <= 0 {
- timeout = 45 * time.Second
- }
- return &Factory{
- httpClient: &http.Client{Timeout: timeout},
- }
- }
-
- func (f *Factory) ClientFor(provider string) (Client, error) {
- normalized := strings.ToLower(strings.TrimSpace(provider))
- switch normalized {
- case "openai", "xai", "ollama":
- return &openAICompatibleClient{httpClient: f.httpClient}, nil
- case "anthropic":
- return &anthropicClient{httpClient: f.httpClient}, nil
- case "google":
- return &googleClient{httpClient: f.httpClient}, nil
- default:
- return nil, fmt.Errorf("unsupported llm provider: %s", normalized)
- }
- }
-
// openAICompatibleClient talks to any endpoint that implements the OpenAI
// chat-completions API shape (OpenAI, xAI, Ollama).
type openAICompatibleClient struct {
	httpClient *http.Client
}
-
- func (c *openAICompatibleClient) Generate(ctx context.Context, req Request) (string, error) {
- baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
- if baseURL == "" {
- switch strings.ToLower(strings.TrimSpace(req.Provider)) {
- case "xai":
- baseURL = "https://api.x.ai"
- case "ollama":
- baseURL = "http://localhost:11434"
- default:
- baseURL = "https://api.openai.com"
- }
- }
-
- payload := map[string]any{
- "model": strings.TrimSpace(req.Model),
- "temperature": optionalFloat64(req.Temperature, 0),
- "messages": []map[string]string{
- {"role": "system", "content": strings.TrimSpace(req.SystemPrompt)},
- {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
- },
- }
- payload[openAICompatibleMaxTokensField(req.Provider, req.Model)] = optionalInt(req.MaxTokens, 1200)
-
- body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/chat/completions", req.APIKey, nil, payload)
- if err != nil {
- return "", err
- }
-
- var response struct {
- Choices []struct {
- Message struct {
- Content string `json:"content"`
- } `json:"message"`
- } `json:"choices"`
- }
- if err := json.Unmarshal(body, &response); err != nil {
- return "", fmt.Errorf("decode openai-compatible response: %w", err)
- }
- if len(response.Choices) == 0 {
- return "", fmt.Errorf("empty openai-compatible response")
- }
- return strings.TrimSpace(response.Choices[0].Message.Content), nil
- }
-
// anthropicClient talks to the Anthropic Messages API.
type anthropicClient struct {
	httpClient *http.Client
}
-
- func (c *anthropicClient) Generate(ctx context.Context, req Request) (string, error) {
- baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
- if baseURL == "" {
- baseURL = "https://api.anthropic.com"
- }
- payload := map[string]any{
- "model": strings.TrimSpace(req.Model),
- "max_tokens": optionalInt(req.MaxTokens, 1200),
- "temperature": optionalFloat64(req.Temperature, 0),
- "system": strings.TrimSpace(req.SystemPrompt),
- "messages": []map[string]any{
- {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
- },
- }
- headers := map[string]string{"anthropic-version": "2023-06-01"}
- body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/messages", req.APIKey, headers, payload)
- if err != nil {
- return "", err
- }
-
- var response struct {
- Content []struct {
- Type string `json:"type"`
- Text string `json:"text"`
- } `json:"content"`
- }
- if err := json.Unmarshal(body, &response); err != nil {
- return "", fmt.Errorf("decode anthropic response: %w", err)
- }
- for _, item := range response.Content {
- if strings.EqualFold(strings.TrimSpace(item.Type), "text") && strings.TrimSpace(item.Text) != "" {
- return strings.TrimSpace(item.Text), nil
- }
- }
- return "", fmt.Errorf("empty anthropic response")
- }
-
// googleClient talks to the Google Generative Language (Gemini) API.
type googleClient struct {
	httpClient *http.Client
}
-
- func (c *googleClient) Generate(ctx context.Context, req Request) (string, error) {
- baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
- if baseURL == "" {
- baseURL = "https://generativelanguage.googleapis.com"
- }
- model := strings.TrimSpace(req.Model)
- if model == "" {
- return "", fmt.Errorf("google model is required")
- }
- apiKey := strings.TrimSpace(req.APIKey)
- if apiKey == "" {
- return "", fmt.Errorf("google api key is required")
- }
-
- endpoint := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", baseURL, url.PathEscape(model), url.QueryEscape(apiKey))
- payload := map[string]any{
- "contents": []map[string]any{
- {"parts": []map[string]string{{"text": strings.TrimSpace(req.UserPrompt)}}},
- },
- "generationConfig": map[string]any{
- "temperature": optionalFloat64(req.Temperature, 0),
- "maxOutputTokens": optionalInt(req.MaxTokens, 1200),
- },
- }
- if strings.TrimSpace(req.SystemPrompt) != "" {
- payload["systemInstruction"] = map[string]any{
- "parts": []map[string]string{{"text": strings.TrimSpace(req.SystemPrompt)}},
- }
- }
-
- body, err := doJSON(ctx, c.httpClient, http.MethodPost, endpoint, "", nil, payload)
- if err != nil {
- return "", err
- }
-
- var response struct {
- Candidates []struct {
- Content struct {
- Parts []struct {
- Text string `json:"text"`
- } `json:"parts"`
- } `json:"content"`
- } `json:"candidates"`
- }
- if err := json.Unmarshal(body, &response); err != nil {
- return "", fmt.Errorf("decode google response: %w", err)
- }
- if len(response.Candidates) == 0 {
- return "", fmt.Errorf("empty google response")
- }
- parts := make([]string, 0, len(response.Candidates[0].Content.Parts))
- for _, part := range response.Candidates[0].Content.Parts {
- if text := strings.TrimSpace(part.Text); text != "" {
- parts = append(parts, text)
- }
- }
- if len(parts) == 0 {
- return "", fmt.Errorf("google response has no text parts")
- }
- return strings.Join(parts, "\n"), nil
- }
-
- func doJSON(ctx context.Context, httpClient *http.Client, method, endpoint, apiKey string, headers map[string]string, payload any) ([]byte, error) {
- body, err := json.Marshal(payload)
- if err != nil {
- return nil, fmt.Errorf("marshal request: %w", err)
- }
-
- req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewReader(body))
- if err != nil {
- return nil, fmt.Errorf("build request: %w", err)
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- if strings.TrimSpace(apiKey) != "" {
- req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(apiKey))
- req.Header.Set("x-api-key", strings.TrimSpace(apiKey))
- }
- for key, value := range headers {
- if strings.TrimSpace(key) == "" {
- continue
- }
- req.Header.Set(key, value)
- }
-
- resp, err := httpClient.Do(req)
- if err != nil {
- return nil, fmt.Errorf("do request: %w", err)
- }
- defer resp.Body.Close()
-
- respBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("read response: %w", err)
- }
- if resp.StatusCode >= 400 {
- message := trimProviderErrorMessage(respBody)
- return nil, fmt.Errorf("provider http %d: %s", resp.StatusCode, message)
- }
- return respBody, nil
- }
-
// optionalFloat64 dereferences value, substituting fallback when it is nil.
func optionalFloat64(value *float64, fallback float64) float64 {
	if value != nil {
		return *value
	}
	return fallback
}
-
// optionalInt dereferences value, substituting fallback when it is nil.
func optionalInt(value *int, fallback int) int {
	if value != nil {
		return *value
	}
	return fallback
}
-
- func openAICompatibleMaxTokensField(provider, model string) string {
- if isOpenAIGPT5Model(provider, model) {
- return "max_completion_tokens"
- }
- return "max_tokens"
- }
-
// isOpenAIGPT5Model reports whether the (provider, model) pair names a
// GPT-5-family model on the "openai" provider. Comparison is
// case-insensitive and tolerant of surrounding whitespace.
func isOpenAIGPT5Model(provider, model string) bool {
	providerName := strings.TrimSpace(provider)
	if !strings.EqualFold(providerName, "openai") {
		return false
	}
	modelName := strings.ToLower(strings.TrimSpace(model))
	return strings.HasPrefix(modelName, "gpt-5")
}
-
- func trimProviderErrorMessage(respBody []byte) string {
- message := extractProviderErrorMessage(respBody)
- if len(message) > 500 {
- return message[:500]
- }
- return message
- }
-
- func extractProviderErrorMessage(respBody []byte) string {
- raw := strings.TrimSpace(string(respBody))
- if raw == "" {
- return "empty error response"
- }
- var parsed map[string]any
- if err := json.Unmarshal(respBody, &parsed); err == nil {
- if value := nestedString(parsed, "error", "message"); value != "" {
- return value
- }
- if value := nestedString(parsed, "error"); value != "" {
- return value
- }
- if value := nestedString(parsed, "message"); value != "" {
- return value
- }
- }
- return raw
- }
-
// nestedString walks the given key path through nested map[string]any values
// and returns the trimmed string found at the end. It returns "" when the
// path is empty, the input map is nil, any intermediate value is not a map,
// or the final value is neither a string nor a fmt.Stringer.
func nestedString(values map[string]any, path ...string) string {
	if values == nil || len(path) == 0 {
		return ""
	}
	var node any = values
	for _, key := range path {
		asMap, ok := node.(map[string]any)
		if !ok {
			return ""
		}
		node = asMap[key]
	}
	switch v := node.(type) {
	case string:
		return strings.TrimSpace(v)
	case fmt.Stringer:
		return strings.TrimSpace(v.String())
	}
	return ""
}
|