|
- package llmruntime
-
- import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "sort"
- "strings"
- "time"
- )
-
// Request carries everything needed for one text-generation call.
//
// Temperature and MaxTokens are pointers so that "not set" (nil) can be
// distinguished from an explicit zero; the provider clients in this package
// substitute their own defaults for nil values.
type Request struct {
	// Provider is the provider identifier and Model the model name to run.
	Provider, Model string
	// BaseURL optionally overrides the provider's default API root; APIKey is
	// the credential sent with the call.
	BaseURL, APIKey string
	// Temperature is the optional sampling temperature.
	Temperature *float64
	// MaxTokens is the optional cap on generated tokens.
	MaxTokens *int
	// SystemPrompt and UserPrompt hold the instruction text and the user
	// message, respectively.
	SystemPrompt, UserPrompt string
}
-
- type Client interface {
- Generate(ctx context.Context, req Request) (string, error)
- }
-
- type Factory struct {
- httpClient *http.Client
- }
-
- func NewFactory(timeout time.Duration) *Factory {
- if timeout <= 0 {
- timeout = 45 * time.Second
- }
- return &Factory{
- httpClient: &http.Client{Timeout: timeout},
- }
- }
-
- func (f *Factory) ClientFor(provider string) (Client, error) {
- normalized := strings.ToLower(strings.TrimSpace(provider))
- switch normalized {
- case "openai", "xai", "ollama":
- return &openAICompatibleClient{httpClient: f.httpClient}, nil
- case "anthropic":
- return &anthropicClient{httpClient: f.httpClient}, nil
- case "google":
- return &googleClient{httpClient: f.httpClient}, nil
- default:
- return nil, fmt.Errorf("unsupported llm provider: %s", normalized)
- }
- }
-
- type openAICompatibleClient struct {
- httpClient *http.Client
- }
-
- func (c *openAICompatibleClient) Generate(ctx context.Context, req Request) (string, error) {
- baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
- if baseURL == "" {
- switch strings.ToLower(strings.TrimSpace(req.Provider)) {
- case "xai":
- baseURL = "https://api.x.ai"
- case "ollama":
- baseURL = "http://localhost:11434"
- default:
- baseURL = "https://api.openai.com"
- }
- }
-
- payload := map[string]any{
- "model": strings.TrimSpace(req.Model),
- "temperature": optionalFloat64(req.Temperature, 0),
- "messages": []map[string]string{
- {"role": "system", "content": strings.TrimSpace(req.SystemPrompt)},
- {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
- },
- }
- payload[openAICompatibleMaxTokensField(req.Provider, req.Model)] = optionalInt(req.MaxTokens, 1200)
-
- body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/chat/completions", req.APIKey, nil, payload)
- if err != nil {
- return "", err
- }
-
- var response map[string]any
- if err := json.Unmarshal(body, &response); err != nil {
- return "", fmt.Errorf("decode openai-compatible response: %w", err)
- }
- content := extractOpenAICompatibleContent(response)
- if content == "" {
- return "", fmt.Errorf("empty openai-compatible response content (%s)", describeOpenAICompatibleShape(response))
- }
- return content, nil
- }
-
- type anthropicClient struct {
- httpClient *http.Client
- }
-
- func (c *anthropicClient) Generate(ctx context.Context, req Request) (string, error) {
- baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
- if baseURL == "" {
- baseURL = "https://api.anthropic.com"
- }
- payload := map[string]any{
- "model": strings.TrimSpace(req.Model),
- "max_tokens": optionalInt(req.MaxTokens, 1200),
- "temperature": optionalFloat64(req.Temperature, 0),
- "system": strings.TrimSpace(req.SystemPrompt),
- "messages": []map[string]any{
- {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
- },
- }
- headers := map[string]string{"anthropic-version": "2023-06-01"}
- body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/messages", req.APIKey, headers, payload)
- if err != nil {
- return "", err
- }
-
- var response struct {
- Content []struct {
- Type string `json:"type"`
- Text string `json:"text"`
- } `json:"content"`
- }
- if err := json.Unmarshal(body, &response); err != nil {
- return "", fmt.Errorf("decode anthropic response: %w", err)
- }
- for _, item := range response.Content {
- if strings.EqualFold(strings.TrimSpace(item.Type), "text") && strings.TrimSpace(item.Text) != "" {
- return strings.TrimSpace(item.Text), nil
- }
- }
- return "", fmt.Errorf("empty anthropic response")
- }
-
- type googleClient struct {
- httpClient *http.Client
- }
-
- func (c *googleClient) Generate(ctx context.Context, req Request) (string, error) {
- baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
- if baseURL == "" {
- baseURL = "https://generativelanguage.googleapis.com"
- }
- model := strings.TrimSpace(req.Model)
- if model == "" {
- return "", fmt.Errorf("google model is required")
- }
- apiKey := strings.TrimSpace(req.APIKey)
- if apiKey == "" {
- return "", fmt.Errorf("google api key is required")
- }
-
- endpoint := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", baseURL, url.PathEscape(model), url.QueryEscape(apiKey))
- payload := map[string]any{
- "contents": []map[string]any{
- {"parts": []map[string]string{{"text": strings.TrimSpace(req.UserPrompt)}}},
- },
- "generationConfig": map[string]any{
- "temperature": optionalFloat64(req.Temperature, 0),
- "maxOutputTokens": optionalInt(req.MaxTokens, 1200),
- },
- }
- if strings.TrimSpace(req.SystemPrompt) != "" {
- payload["systemInstruction"] = map[string]any{
- "parts": []map[string]string{{"text": strings.TrimSpace(req.SystemPrompt)}},
- }
- }
-
- body, err := doJSON(ctx, c.httpClient, http.MethodPost, endpoint, "", nil, payload)
- if err != nil {
- return "", err
- }
-
- var response struct {
- Candidates []struct {
- Content struct {
- Parts []struct {
- Text string `json:"text"`
- } `json:"parts"`
- } `json:"content"`
- } `json:"candidates"`
- }
- if err := json.Unmarshal(body, &response); err != nil {
- return "", fmt.Errorf("decode google response: %w", err)
- }
- if len(response.Candidates) == 0 {
- return "", fmt.Errorf("empty google response")
- }
- parts := make([]string, 0, len(response.Candidates[0].Content.Parts))
- for _, part := range response.Candidates[0].Content.Parts {
- if text := strings.TrimSpace(part.Text); text != "" {
- parts = append(parts, text)
- }
- }
- if len(parts) == 0 {
- return "", fmt.Errorf("google response has no text parts")
- }
- return strings.Join(parts, "\n"), nil
- }
-
- func doJSON(ctx context.Context, httpClient *http.Client, method, endpoint, apiKey string, headers map[string]string, payload any) ([]byte, error) {
- body, err := json.Marshal(payload)
- if err != nil {
- return nil, fmt.Errorf("marshal request: %w", err)
- }
-
- req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewReader(body))
- if err != nil {
- return nil, fmt.Errorf("build request: %w", err)
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", "application/json")
- if strings.TrimSpace(apiKey) != "" {
- req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(apiKey))
- req.Header.Set("x-api-key", strings.TrimSpace(apiKey))
- }
- for key, value := range headers {
- if strings.TrimSpace(key) == "" {
- continue
- }
- req.Header.Set(key, value)
- }
-
- resp, err := httpClient.Do(req)
- if err != nil {
- return nil, fmt.Errorf("do request: %w", err)
- }
- defer resp.Body.Close()
-
- respBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, fmt.Errorf("read response: %w", err)
- }
- if resp.StatusCode >= 400 {
- message := trimProviderErrorMessage(respBody)
- return nil, fmt.Errorf("provider http %d: %s", resp.StatusCode, message)
- }
- return respBody, nil
- }
-
// optionalFloat64 dereferences value, or returns fallback when value is nil.
func optionalFloat64(value *float64, fallback float64) float64 {
	if value != nil {
		return *value
	}
	return fallback
}
-
// optionalInt dereferences value, or returns fallback when value is nil.
func optionalInt(value *int, fallback int) int {
	result := fallback
	if value != nil {
		result = *value
	}
	return result
}
-
- func openAICompatibleMaxTokensField(provider, model string) string {
- if isOpenAIGPT5Model(provider, model) {
- return "max_completion_tokens"
- }
- return "max_tokens"
- }
-
// isOpenAIGPT5Model reports whether provider is "openai" and model starts
// with "gpt-5". Both comparisons ignore case and surrounding whitespace.
func isOpenAIGPT5Model(provider, model string) bool {
	isOpenAI := strings.EqualFold(strings.TrimSpace(provider), "openai")
	modelName := strings.ToLower(strings.TrimSpace(model))
	return isOpenAI && strings.HasPrefix(modelName, "gpt-5")
}
-
- func trimProviderErrorMessage(respBody []byte) string {
- message := extractProviderErrorMessage(respBody)
- if len(message) > 500 {
- return message[:500]
- }
- return message
- }
-
- func extractProviderErrorMessage(respBody []byte) string {
- raw := strings.TrimSpace(string(respBody))
- if raw == "" {
- return "empty error response"
- }
- var parsed map[string]any
- if err := json.Unmarshal(respBody, &parsed); err == nil {
- if value := nestedString(parsed, "error", "message"); value != "" {
- return value
- }
- if value := nestedString(parsed, "error"); value != "" {
- return value
- }
- if value := nestedString(parsed, "message"); value != "" {
- return value
- }
- }
- return raw
- }
-
// nestedString walks values along path, descending through nested
// map[string]any objects, and returns the trimmed string found at the end.
//
// It returns "" when values is nil, path is empty, an intermediate element is
// not a map[string]any, or the final element is neither a string nor a
// fmt.Stringer.
func nestedString(values map[string]any, path ...string) string {
	if values == nil || len(path) == 0 {
		return ""
	}

	var node any = values
	for i := 0; i < len(path); i++ {
		object, isObject := node.(map[string]any)
		if !isObject {
			return ""
		}
		// A missing key yields nil, which fails the next assertion (or the
		// final string checks) and so ends in "".
		node = object[path[i]]
	}

	if text, isString := node.(string); isString {
		return strings.TrimSpace(text)
	}
	if stringer, isStringer := node.(fmt.Stringer); isStringer {
		return strings.TrimSpace(stringer.String())
	}
	return ""
}
-
- func extractOpenAICompatibleContent(response map[string]any) string {
- if response == nil {
- return ""
- }
- if text := strings.TrimSpace(extractOpenAICompatibleChoicesContent(response["choices"])); text != "" {
- return text
- }
- if text := strings.TrimSpace(extractTextFromContentValue(response["output_text"])); text != "" {
- return text
- }
- return strings.TrimSpace(extractOpenAICompatibleOutputContent(response["output"]))
- }
-
- func extractOpenAICompatibleChoicesContent(raw any) string {
- choices, ok := raw.([]any)
- if !ok {
- return ""
- }
- for _, rawChoice := range choices {
- choice, ok := rawChoice.(map[string]any)
- if !ok {
- continue
- }
- if text := strings.TrimSpace(extractTextFromContentValue(choice["message"])); text != "" {
- return text
- }
- if text := strings.TrimSpace(extractTextFromContentValue(choice["delta"])); text != "" {
- return text
- }
- if text := strings.TrimSpace(extractTextFromContentValue(choice["text"])); text != "" {
- return text
- }
- }
- return ""
- }
-
- func extractOpenAICompatibleOutputContent(raw any) string {
- output, ok := raw.([]any)
- if !ok {
- return ""
- }
- for _, rawItem := range output {
- item, ok := rawItem.(map[string]any)
- if !ok {
- continue
- }
- if text := strings.TrimSpace(extractTextFromContentValue(item["content"])); text != "" {
- return text
- }
- if text := strings.TrimSpace(extractTextFromContentValue(item["text"])); text != "" {
- return text
- }
- }
- return ""
- }
-
// extractTextFromContentValue recursively digs text out of a decoded JSON
// value:
//
//   - a string is returned trimmed;
//   - an array yields the non-empty texts of its elements joined by "\n";
//   - an object yields the first non-empty text among its "content", "text",
//     "value" and "output_text" fields, tried in that order;
//   - anything else (numbers, booleans, null) yields "".
func extractTextFromContentValue(raw any) string {
	switch typed := raw.(type) {
	case string:
		return strings.TrimSpace(typed)

	case []any:
		collected := make([]string, 0, len(typed))
		for i := range typed {
			text := strings.TrimSpace(extractTextFromContentValue(typed[i]))
			if text == "" {
				continue
			}
			collected = append(collected, text)
		}
		return strings.TrimSpace(strings.Join(collected, "\n"))

	case map[string]any:
		for _, key := range [...]string{"content", "text", "value", "output_text"} {
			if text := strings.TrimSpace(extractTextFromContentValue(typed[key])); text != "" {
				return text
			}
		}
	}
	return ""
}
-
- func describeOpenAICompatibleShape(response map[string]any) string {
- parts := make([]string, 0, 8)
- parts = append(parts, "top="+describeMapKeys(response))
-
- if choices, ok := response["choices"].([]any); ok {
- parts = append(parts, fmt.Sprintf("choices_len=%d", len(choices)))
- if len(choices) > 0 {
- if choice, ok := choices[0].(map[string]any); ok {
- parts = append(parts, "choices0="+describeMapKeys(choice))
- if message, ok := choice["message"].(map[string]any); ok {
- parts = append(parts, "message="+describeMapKeys(message))
- parts = append(parts, "message_content_type="+valueType(message["content"]))
- }
- }
- }
- } else if _, exists := response["choices"]; exists {
- parts = append(parts, "choices_type="+valueType(response["choices"]))
- }
-
- if _, exists := response["output_text"]; exists {
- parts = append(parts, "output_text_type="+valueType(response["output_text"]))
- }
- if output, ok := response["output"].([]any); ok {
- parts = append(parts, fmt.Sprintf("output_len=%d", len(output)))
- if len(output) > 0 {
- if first, ok := output[0].(map[string]any); ok {
- parts = append(parts, "output0="+describeMapKeys(first))
- parts = append(parts, "output0_content_type="+valueType(first["content"]))
- }
- }
- } else if _, exists := response["output"]; exists {
- parts = append(parts, "output_type="+valueType(response["output"]))
- }
-
- return strings.Join(parts, "; ")
- }
-
- func describeMapKeys(raw map[string]any) string {
- if len(raw) == 0 {
- return "{}"
- }
- keys := make([]string, 0, len(raw))
- for key := range raw {
- keys = append(keys, key)
- }
- sort.Strings(keys)
- described := make([]string, 0, len(keys))
- for _, key := range keys {
- described = append(described, fmt.Sprintf("%s:%s", key, valueType(raw[key])))
- }
- return "{" + strings.Join(described, ",") + "}"
- }
-
// valueType names the JSON type of a decoded value: "null", "string", "bool",
// "number" (float64), "array" ([]any) or "object" (map[string]any). Any other
// Go value is reported by its Go type as formatted with %T.
func valueType(raw any) string {
	if raw == nil {
		return "null"
	}

	var kind string
	switch raw.(type) {
	case map[string]any:
		kind = "object"
	case []any:
		kind = "array"
	case float64:
		kind = "number"
	case bool:
		kind = "bool"
	case string:
		kind = "string"
	default:
		kind = fmt.Sprintf("%T", raw)
	}
	return kind
}
|