You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

330 lines
8.7KB

  1. package llmruntime
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/url"
  10. "strings"
  11. "time"
  12. )
  // Request describes a single text-generation call. Each provider client
// decides which fields it reads and how blank or nil values are defaulted.
type Request struct {
	// Provider names the backend family. Model selects the model, BaseURL
	// (when non-blank) replaces the client's default endpoint base, and
	// APIKey authenticates the call.
	Provider, Model, BaseURL, APIKey string
	// Temperature and MaxTokens are optional; nil lets the client apply its
	// own default.
	Temperature *float64
	MaxTokens   *int
	// SystemPrompt and UserPrompt are the instructions and the user turn
	// sent to the model.
	SystemPrompt, UserPrompt string
}

// Client is implemented by each provider backend. Generate returns the
// model's text reply for req, or an error.
type Client interface {
	Generate(ctx context.Context, req Request) (string, error)
}
  // Factory builds provider clients that all share one *http.Client.
type Factory struct {
	httpClient *http.Client
}

// NewFactory returns a Factory whose shared HTTP client uses timeout as its
// overall request timeout. A zero or negative timeout falls back to 45s.
func NewFactory(timeout time.Duration) *Factory {
	const defaultTimeout = 45 * time.Second
	effective := timeout
	if effective <= 0 {
		effective = defaultTimeout
	}
	shared := &http.Client{Timeout: effective}
	return &Factory{httpClient: shared}
}
  37. func (f *Factory) ClientFor(provider string) (Client, error) {
  38. normalized := strings.ToLower(strings.TrimSpace(provider))
  39. switch normalized {
  40. case "openai", "xai", "ollama":
  41. return &openAICompatibleClient{httpClient: f.httpClient}, nil
  42. case "anthropic":
  43. return &anthropicClient{httpClient: f.httpClient}, nil
  44. case "google":
  45. return &googleClient{httpClient: f.httpClient}, nil
  46. default:
  47. return nil, fmt.Errorf("unsupported llm provider: %s", normalized)
  48. }
  49. }
  50. type openAICompatibleClient struct {
  51. httpClient *http.Client
  52. }
  53. func (c *openAICompatibleClient) Generate(ctx context.Context, req Request) (string, error) {
  54. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  55. if baseURL == "" {
  56. switch strings.ToLower(strings.TrimSpace(req.Provider)) {
  57. case "xai":
  58. baseURL = "https://api.x.ai"
  59. case "ollama":
  60. baseURL = "http://localhost:11434"
  61. default:
  62. baseURL = "https://api.openai.com"
  63. }
  64. }
  65. payload := map[string]any{
  66. "model": strings.TrimSpace(req.Model),
  67. "temperature": optionalFloat64(req.Temperature, 0),
  68. "messages": []map[string]string{
  69. {"role": "system", "content": strings.TrimSpace(req.SystemPrompt)},
  70. {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
  71. },
  72. }
  73. payload[openAICompatibleMaxTokensField(req.Provider, req.Model)] = optionalInt(req.MaxTokens, 1200)
  74. body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/chat/completions", req.APIKey, nil, payload)
  75. if err != nil {
  76. return "", err
  77. }
  78. var response struct {
  79. Choices []struct {
  80. Message struct {
  81. Content string `json:"content"`
  82. } `json:"message"`
  83. } `json:"choices"`
  84. }
  85. if err := json.Unmarshal(body, &response); err != nil {
  86. return "", fmt.Errorf("decode openai-compatible response: %w", err)
  87. }
  88. if len(response.Choices) == 0 {
  89. return "", fmt.Errorf("empty openai-compatible response")
  90. }
  91. return strings.TrimSpace(response.Choices[0].Message.Content), nil
  92. }
  93. type anthropicClient struct {
  94. httpClient *http.Client
  95. }
  96. func (c *anthropicClient) Generate(ctx context.Context, req Request) (string, error) {
  97. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  98. if baseURL == "" {
  99. baseURL = "https://api.anthropic.com"
  100. }
  101. payload := map[string]any{
  102. "model": strings.TrimSpace(req.Model),
  103. "max_tokens": optionalInt(req.MaxTokens, 1200),
  104. "temperature": optionalFloat64(req.Temperature, 0),
  105. "system": strings.TrimSpace(req.SystemPrompt),
  106. "messages": []map[string]any{
  107. {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
  108. },
  109. }
  110. headers := map[string]string{"anthropic-version": "2023-06-01"}
  111. body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/messages", req.APIKey, headers, payload)
  112. if err != nil {
  113. return "", err
  114. }
  115. var response struct {
  116. Content []struct {
  117. Type string `json:"type"`
  118. Text string `json:"text"`
  119. } `json:"content"`
  120. }
  121. if err := json.Unmarshal(body, &response); err != nil {
  122. return "", fmt.Errorf("decode anthropic response: %w", err)
  123. }
  124. for _, item := range response.Content {
  125. if strings.EqualFold(strings.TrimSpace(item.Type), "text") && strings.TrimSpace(item.Text) != "" {
  126. return strings.TrimSpace(item.Text), nil
  127. }
  128. }
  129. return "", fmt.Errorf("empty anthropic response")
  130. }
  131. type googleClient struct {
  132. httpClient *http.Client
  133. }
  134. func (c *googleClient) Generate(ctx context.Context, req Request) (string, error) {
  135. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  136. if baseURL == "" {
  137. baseURL = "https://generativelanguage.googleapis.com"
  138. }
  139. model := strings.TrimSpace(req.Model)
  140. if model == "" {
  141. return "", fmt.Errorf("google model is required")
  142. }
  143. apiKey := strings.TrimSpace(req.APIKey)
  144. if apiKey == "" {
  145. return "", fmt.Errorf("google api key is required")
  146. }
  147. endpoint := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", baseURL, url.PathEscape(model), url.QueryEscape(apiKey))
  148. payload := map[string]any{
  149. "contents": []map[string]any{
  150. {"parts": []map[string]string{{"text": strings.TrimSpace(req.UserPrompt)}}},
  151. },
  152. "generationConfig": map[string]any{
  153. "temperature": optionalFloat64(req.Temperature, 0),
  154. "maxOutputTokens": optionalInt(req.MaxTokens, 1200),
  155. },
  156. }
  157. if strings.TrimSpace(req.SystemPrompt) != "" {
  158. payload["systemInstruction"] = map[string]any{
  159. "parts": []map[string]string{{"text": strings.TrimSpace(req.SystemPrompt)}},
  160. }
  161. }
  162. body, err := doJSON(ctx, c.httpClient, http.MethodPost, endpoint, "", nil, payload)
  163. if err != nil {
  164. return "", err
  165. }
  166. var response struct {
  167. Candidates []struct {
  168. Content struct {
  169. Parts []struct {
  170. Text string `json:"text"`
  171. } `json:"parts"`
  172. } `json:"content"`
  173. } `json:"candidates"`
  174. }
  175. if err := json.Unmarshal(body, &response); err != nil {
  176. return "", fmt.Errorf("decode google response: %w", err)
  177. }
  178. if len(response.Candidates) == 0 {
  179. return "", fmt.Errorf("empty google response")
  180. }
  181. parts := make([]string, 0, len(response.Candidates[0].Content.Parts))
  182. for _, part := range response.Candidates[0].Content.Parts {
  183. if text := strings.TrimSpace(part.Text); text != "" {
  184. parts = append(parts, text)
  185. }
  186. }
  187. if len(parts) == 0 {
  188. return "", fmt.Errorf("google response has no text parts")
  189. }
  190. return strings.Join(parts, "\n"), nil
  191. }
  192. func doJSON(ctx context.Context, httpClient *http.Client, method, endpoint, apiKey string, headers map[string]string, payload any) ([]byte, error) {
  193. body, err := json.Marshal(payload)
  194. if err != nil {
  195. return nil, fmt.Errorf("marshal request: %w", err)
  196. }
  197. req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewReader(body))
  198. if err != nil {
  199. return nil, fmt.Errorf("build request: %w", err)
  200. }
  201. req.Header.Set("Content-Type", "application/json")
  202. req.Header.Set("Accept", "application/json")
  203. if strings.TrimSpace(apiKey) != "" {
  204. req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(apiKey))
  205. req.Header.Set("x-api-key", strings.TrimSpace(apiKey))
  206. }
  207. for key, value := range headers {
  208. if strings.TrimSpace(key) == "" {
  209. continue
  210. }
  211. req.Header.Set(key, value)
  212. }
  213. resp, err := httpClient.Do(req)
  214. if err != nil {
  215. return nil, fmt.Errorf("do request: %w", err)
  216. }
  217. defer resp.Body.Close()
  218. respBody, err := io.ReadAll(resp.Body)
  219. if err != nil {
  220. return nil, fmt.Errorf("read response: %w", err)
  221. }
  222. if resp.StatusCode >= 400 {
  223. message := trimProviderErrorMessage(respBody)
  224. return nil, fmt.Errorf("provider http %d: %s", resp.StatusCode, message)
  225. }
  226. return respBody, nil
  227. }
  // optionalFloat64 dereferences value, substituting fallback when value is
// nil. An explicitly set zero is kept.
func optionalFloat64(value *float64, fallback float64) float64 {
	if value != nil {
		return *value
	}
	return fallback
}
  // optionalInt returns *value when value is non-nil and fallback otherwise.
func optionalInt(value *int, fallback int) int {
	result := fallback
	if value != nil {
		result = *value
	}
	return result
}
  // openAICompatibleMaxTokensField returns the JSON field name that carries
// the output-token limit in an OpenAI-compatible chat-completions request.
//
// OpenAI's GPT-5 and o-series reasoning models do not accept the legacy
// "max_tokens" parameter and need "max_completion_tokens". Every other model
// and provider (including "xai" and "ollama") keeps "max_tokens".
func openAICompatibleMaxTokensField(provider, model string) string {
	if isOpenAIGPT5Model(provider, model) || isOpenAIOSeriesModel(provider, model) {
		return "max_completion_tokens"
	}
	return "max_tokens"
}

// isOpenAIGPT5Model reports whether provider is "openai" (trimmed,
// case-insensitive) and the trimmed, lower-cased model starts with "gpt-5".
func isOpenAIGPT5Model(provider, model string) bool {
	if !strings.EqualFold(strings.TrimSpace(provider), "openai") {
		return false
	}
	normalizedModel := strings.ToLower(strings.TrimSpace(model))
	return strings.HasPrefix(normalizedModel, "gpt-5")
}

// isOpenAIOSeriesModel reports whether provider is "openai" and model is an
// o-series reasoning model such as "o1", "o3-mini" or "o4-mini", i.e. its
// trimmed, lower-cased name is "o" followed by a digit.
func isOpenAIOSeriesModel(provider, model string) bool {
	if !strings.EqualFold(strings.TrimSpace(provider), "openai") {
		return false
	}
	m := strings.ToLower(strings.TrimSpace(model))
	return len(m) >= 2 && m[0] == 'o' && m[1] >= '0' && m[1] <= '9'
}
  253. func trimProviderErrorMessage(respBody []byte) string {
  254. message := extractProviderErrorMessage(respBody)
  255. if len(message) > 500 {
  256. return message[:500]
  257. }
  258. return message
  259. }
  260. func extractProviderErrorMessage(respBody []byte) string {
  261. raw := strings.TrimSpace(string(respBody))
  262. if raw == "" {
  263. return "empty error response"
  264. }
  265. var parsed map[string]any
  266. if err := json.Unmarshal(respBody, &parsed); err == nil {
  267. if value := nestedString(parsed, "error", "message"); value != "" {
  268. return value
  269. }
  270. if value := nestedString(parsed, "error"); value != "" {
  271. return value
  272. }
  273. if value := nestedString(parsed, "message"); value != "" {
  274. return value
  275. }
  276. }
  277. return raw
  278. }
  // nestedString follows path (one map key per element) down from values and
// returns the trimmed string found at the end. It returns "" when path is
// empty, values is nil, an intermediate value is not a map[string]any, or
// the final value is neither a string nor a fmt.Stringer.
func nestedString(values map[string]any, path ...string) string {
	if values == nil || len(path) == 0 {
		return ""
	}
	node := any(values)
	for i := 0; i < len(path); i++ {
		obj, isMap := node.(map[string]any)
		if !isMap {
			return ""
		}
		node = obj[path[i]]
	}
	if s, ok := node.(string); ok {
		return strings.TrimSpace(s)
	}
	if st, ok := node.(fmt.Stringer); ok {
		return strings.TrimSpace(st.String())
	}
	return ""
}
  299. }