Ви не можете вибрати більше 25 тем. Теми мають розпочинатися з літери або цифри, можуть містити дефіси (-) і не повинні перевищувати 35 символів.

485 рядків
13KB

  1. package llmruntime
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/url"
  10. "sort"
  11. "strings"
  12. "time"
  13. )
  14. type Request struct {
  15. Provider string
  16. Model string
  17. BaseURL string
  18. APIKey string
  19. Temperature *float64
  20. MaxTokens *int
  21. SystemPrompt string
  22. UserPrompt string
  23. }
  24. type Client interface {
  25. Generate(ctx context.Context, req Request) (string, error)
  26. }
  27. type Factory struct {
  28. httpClient *http.Client
  29. }
  30. func NewFactory(timeout time.Duration) *Factory {
  31. if timeout <= 0 {
  32. timeout = 45 * time.Second
  33. }
  34. return &Factory{
  35. httpClient: &http.Client{Timeout: timeout},
  36. }
  37. }
  38. func (f *Factory) ClientFor(provider string) (Client, error) {
  39. normalized := strings.ToLower(strings.TrimSpace(provider))
  40. switch normalized {
  41. case "openai", "xai", "ollama":
  42. return &openAICompatibleClient{httpClient: f.httpClient}, nil
  43. case "anthropic":
  44. return &anthropicClient{httpClient: f.httpClient}, nil
  45. case "google":
  46. return &googleClient{httpClient: f.httpClient}, nil
  47. default:
  48. return nil, fmt.Errorf("unsupported llm provider: %s", normalized)
  49. }
  50. }
  51. type openAICompatibleClient struct {
  52. httpClient *http.Client
  53. }
  54. func (c *openAICompatibleClient) Generate(ctx context.Context, req Request) (string, error) {
  55. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  56. if baseURL == "" {
  57. switch strings.ToLower(strings.TrimSpace(req.Provider)) {
  58. case "xai":
  59. baseURL = "https://api.x.ai"
  60. case "ollama":
  61. baseURL = "http://localhost:11434"
  62. default:
  63. baseURL = "https://api.openai.com"
  64. }
  65. }
  66. payload := map[string]any{
  67. "model": strings.TrimSpace(req.Model),
  68. "temperature": optionalFloat64(req.Temperature, 0),
  69. "messages": []map[string]string{
  70. {"role": "system", "content": strings.TrimSpace(req.SystemPrompt)},
  71. {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
  72. },
  73. }
  74. payload[openAICompatibleMaxTokensField(req.Provider, req.Model)] = optionalInt(req.MaxTokens, 1200)
  75. body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/chat/completions", req.APIKey, nil, payload)
  76. if err != nil {
  77. return "", err
  78. }
  79. var response map[string]any
  80. if err := json.Unmarshal(body, &response); err != nil {
  81. return "", fmt.Errorf("decode openai-compatible response: %w", err)
  82. }
  83. content := extractOpenAICompatibleContent(response)
  84. if content == "" {
  85. return "", fmt.Errorf("empty openai-compatible response content (%s)", describeOpenAICompatibleShape(response))
  86. }
  87. return content, nil
  88. }
  89. type anthropicClient struct {
  90. httpClient *http.Client
  91. }
  92. func (c *anthropicClient) Generate(ctx context.Context, req Request) (string, error) {
  93. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  94. if baseURL == "" {
  95. baseURL = "https://api.anthropic.com"
  96. }
  97. payload := map[string]any{
  98. "model": strings.TrimSpace(req.Model),
  99. "max_tokens": optionalInt(req.MaxTokens, 1200),
  100. "temperature": optionalFloat64(req.Temperature, 0),
  101. "system": strings.TrimSpace(req.SystemPrompt),
  102. "messages": []map[string]any{
  103. {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
  104. },
  105. }
  106. headers := map[string]string{"anthropic-version": "2023-06-01"}
  107. body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/messages", req.APIKey, headers, payload)
  108. if err != nil {
  109. return "", err
  110. }
  111. var response struct {
  112. Content []struct {
  113. Type string `json:"type"`
  114. Text string `json:"text"`
  115. } `json:"content"`
  116. }
  117. if err := json.Unmarshal(body, &response); err != nil {
  118. return "", fmt.Errorf("decode anthropic response: %w", err)
  119. }
  120. for _, item := range response.Content {
  121. if strings.EqualFold(strings.TrimSpace(item.Type), "text") && strings.TrimSpace(item.Text) != "" {
  122. return strings.TrimSpace(item.Text), nil
  123. }
  124. }
  125. return "", fmt.Errorf("empty anthropic response")
  126. }
  127. type googleClient struct {
  128. httpClient *http.Client
  129. }
  130. func (c *googleClient) Generate(ctx context.Context, req Request) (string, error) {
  131. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  132. if baseURL == "" {
  133. baseURL = "https://generativelanguage.googleapis.com"
  134. }
  135. model := strings.TrimSpace(req.Model)
  136. if model == "" {
  137. return "", fmt.Errorf("google model is required")
  138. }
  139. apiKey := strings.TrimSpace(req.APIKey)
  140. if apiKey == "" {
  141. return "", fmt.Errorf("google api key is required")
  142. }
  143. endpoint := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", baseURL, url.PathEscape(model), url.QueryEscape(apiKey))
  144. payload := map[string]any{
  145. "contents": []map[string]any{
  146. {"parts": []map[string]string{{"text": strings.TrimSpace(req.UserPrompt)}}},
  147. },
  148. "generationConfig": map[string]any{
  149. "temperature": optionalFloat64(req.Temperature, 0),
  150. "maxOutputTokens": optionalInt(req.MaxTokens, 1200),
  151. },
  152. }
  153. if strings.TrimSpace(req.SystemPrompt) != "" {
  154. payload["systemInstruction"] = map[string]any{
  155. "parts": []map[string]string{{"text": strings.TrimSpace(req.SystemPrompt)}},
  156. }
  157. }
  158. body, err := doJSON(ctx, c.httpClient, http.MethodPost, endpoint, "", nil, payload)
  159. if err != nil {
  160. return "", err
  161. }
  162. var response struct {
  163. Candidates []struct {
  164. Content struct {
  165. Parts []struct {
  166. Text string `json:"text"`
  167. } `json:"parts"`
  168. } `json:"content"`
  169. } `json:"candidates"`
  170. }
  171. if err := json.Unmarshal(body, &response); err != nil {
  172. return "", fmt.Errorf("decode google response: %w", err)
  173. }
  174. if len(response.Candidates) == 0 {
  175. return "", fmt.Errorf("empty google response")
  176. }
  177. parts := make([]string, 0, len(response.Candidates[0].Content.Parts))
  178. for _, part := range response.Candidates[0].Content.Parts {
  179. if text := strings.TrimSpace(part.Text); text != "" {
  180. parts = append(parts, text)
  181. }
  182. }
  183. if len(parts) == 0 {
  184. return "", fmt.Errorf("google response has no text parts")
  185. }
  186. return strings.Join(parts, "\n"), nil
  187. }
  188. func doJSON(ctx context.Context, httpClient *http.Client, method, endpoint, apiKey string, headers map[string]string, payload any) ([]byte, error) {
  189. body, err := json.Marshal(payload)
  190. if err != nil {
  191. return nil, fmt.Errorf("marshal request: %w", err)
  192. }
  193. req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewReader(body))
  194. if err != nil {
  195. return nil, fmt.Errorf("build request: %w", err)
  196. }
  197. req.Header.Set("Content-Type", "application/json")
  198. req.Header.Set("Accept", "application/json")
  199. if strings.TrimSpace(apiKey) != "" {
  200. req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(apiKey))
  201. req.Header.Set("x-api-key", strings.TrimSpace(apiKey))
  202. }
  203. for key, value := range headers {
  204. if strings.TrimSpace(key) == "" {
  205. continue
  206. }
  207. req.Header.Set(key, value)
  208. }
  209. resp, err := httpClient.Do(req)
  210. if err != nil {
  211. return nil, fmt.Errorf("do request: %w", err)
  212. }
  213. defer resp.Body.Close()
  214. respBody, err := io.ReadAll(resp.Body)
  215. if err != nil {
  216. return nil, fmt.Errorf("read response: %w", err)
  217. }
  218. if resp.StatusCode >= 400 {
  219. message := trimProviderErrorMessage(respBody)
  220. return nil, fmt.Errorf("provider http %d: %s", resp.StatusCode, message)
  221. }
  222. return respBody, nil
  223. }
  224. func optionalFloat64(value *float64, fallback float64) float64 {
  225. if value == nil {
  226. return fallback
  227. }
  228. return *value
  229. }
  230. func optionalInt(value *int, fallback int) int {
  231. if value == nil {
  232. return fallback
  233. }
  234. return *value
  235. }
  236. func openAICompatibleMaxTokensField(provider, model string) string {
  237. if isOpenAIGPT5Model(provider, model) {
  238. return "max_completion_tokens"
  239. }
  240. return "max_tokens"
  241. }
  242. func isOpenAIGPT5Model(provider, model string) bool {
  243. if !strings.EqualFold(strings.TrimSpace(provider), "openai") {
  244. return false
  245. }
  246. normalizedModel := strings.ToLower(strings.TrimSpace(model))
  247. return strings.HasPrefix(normalizedModel, "gpt-5")
  248. }
  249. func trimProviderErrorMessage(respBody []byte) string {
  250. message := extractProviderErrorMessage(respBody)
  251. if len(message) > 500 {
  252. return message[:500]
  253. }
  254. return message
  255. }
  256. func extractProviderErrorMessage(respBody []byte) string {
  257. raw := strings.TrimSpace(string(respBody))
  258. if raw == "" {
  259. return "empty error response"
  260. }
  261. var parsed map[string]any
  262. if err := json.Unmarshal(respBody, &parsed); err == nil {
  263. if value := nestedString(parsed, "error", "message"); value != "" {
  264. return value
  265. }
  266. if value := nestedString(parsed, "error"); value != "" {
  267. return value
  268. }
  269. if value := nestedString(parsed, "message"); value != "" {
  270. return value
  271. }
  272. }
  273. return raw
  274. }
  275. func nestedString(values map[string]any, path ...string) string {
  276. if len(path) == 0 || values == nil {
  277. return ""
  278. }
  279. current := any(values)
  280. for _, key := range path {
  281. nextMap, ok := current.(map[string]any)
  282. if !ok {
  283. return ""
  284. }
  285. current = nextMap[key]
  286. }
  287. switch value := current.(type) {
  288. case string:
  289. return strings.TrimSpace(value)
  290. case fmt.Stringer:
  291. return strings.TrimSpace(value.String())
  292. default:
  293. return ""
  294. }
  295. }
  296. func extractOpenAICompatibleContent(response map[string]any) string {
  297. if response == nil {
  298. return ""
  299. }
  300. if text := strings.TrimSpace(extractOpenAICompatibleChoicesContent(response["choices"])); text != "" {
  301. return text
  302. }
  303. if text := strings.TrimSpace(extractTextFromContentValue(response["output_text"])); text != "" {
  304. return text
  305. }
  306. return strings.TrimSpace(extractOpenAICompatibleOutputContent(response["output"]))
  307. }
  308. func extractOpenAICompatibleChoicesContent(raw any) string {
  309. choices, ok := raw.([]any)
  310. if !ok {
  311. return ""
  312. }
  313. for _, rawChoice := range choices {
  314. choice, ok := rawChoice.(map[string]any)
  315. if !ok {
  316. continue
  317. }
  318. if text := strings.TrimSpace(extractTextFromContentValue(choice["message"])); text != "" {
  319. return text
  320. }
  321. if text := strings.TrimSpace(extractTextFromContentValue(choice["delta"])); text != "" {
  322. return text
  323. }
  324. if text := strings.TrimSpace(extractTextFromContentValue(choice["text"])); text != "" {
  325. return text
  326. }
  327. }
  328. return ""
  329. }
  330. func extractOpenAICompatibleOutputContent(raw any) string {
  331. output, ok := raw.([]any)
  332. if !ok {
  333. return ""
  334. }
  335. for _, rawItem := range output {
  336. item, ok := rawItem.(map[string]any)
  337. if !ok {
  338. continue
  339. }
  340. if text := strings.TrimSpace(extractTextFromContentValue(item["content"])); text != "" {
  341. return text
  342. }
  343. if text := strings.TrimSpace(extractTextFromContentValue(item["text"])); text != "" {
  344. return text
  345. }
  346. }
  347. return ""
  348. }
  349. func extractTextFromContentValue(raw any) string {
  350. switch value := raw.(type) {
  351. case string:
  352. return strings.TrimSpace(value)
  353. case []any:
  354. parts := make([]string, 0, len(value))
  355. for _, item := range value {
  356. if text := strings.TrimSpace(extractTextFromContentValue(item)); text != "" {
  357. parts = append(parts, text)
  358. }
  359. }
  360. return strings.TrimSpace(strings.Join(parts, "\n"))
  361. case map[string]any:
  362. if text := strings.TrimSpace(extractTextFromContentValue(value["content"])); text != "" {
  363. return text
  364. }
  365. if text := strings.TrimSpace(extractTextFromContentValue(value["text"])); text != "" {
  366. return text
  367. }
  368. if text := strings.TrimSpace(extractTextFromContentValue(value["value"])); text != "" {
  369. return text
  370. }
  371. if text := strings.TrimSpace(extractTextFromContentValue(value["output_text"])); text != "" {
  372. return text
  373. }
  374. return ""
  375. default:
  376. return ""
  377. }
  378. }
  379. func describeOpenAICompatibleShape(response map[string]any) string {
  380. parts := make([]string, 0, 8)
  381. parts = append(parts, "top="+describeMapKeys(response))
  382. if choices, ok := response["choices"].([]any); ok {
  383. parts = append(parts, fmt.Sprintf("choices_len=%d", len(choices)))
  384. if len(choices) > 0 {
  385. if choice, ok := choices[0].(map[string]any); ok {
  386. parts = append(parts, "choices0="+describeMapKeys(choice))
  387. if message, ok := choice["message"].(map[string]any); ok {
  388. parts = append(parts, "message="+describeMapKeys(message))
  389. parts = append(parts, "message_content_type="+valueType(message["content"]))
  390. }
  391. }
  392. }
  393. } else if _, exists := response["choices"]; exists {
  394. parts = append(parts, "choices_type="+valueType(response["choices"]))
  395. }
  396. if _, exists := response["output_text"]; exists {
  397. parts = append(parts, "output_text_type="+valueType(response["output_text"]))
  398. }
  399. if output, ok := response["output"].([]any); ok {
  400. parts = append(parts, fmt.Sprintf("output_len=%d", len(output)))
  401. if len(output) > 0 {
  402. if first, ok := output[0].(map[string]any); ok {
  403. parts = append(parts, "output0="+describeMapKeys(first))
  404. parts = append(parts, "output0_content_type="+valueType(first["content"]))
  405. }
  406. }
  407. } else if _, exists := response["output"]; exists {
  408. parts = append(parts, "output_type="+valueType(response["output"]))
  409. }
  410. return strings.Join(parts, "; ")
  411. }
  412. func describeMapKeys(raw map[string]any) string {
  413. if len(raw) == 0 {
  414. return "{}"
  415. }
  416. keys := make([]string, 0, len(raw))
  417. for key := range raw {
  418. keys = append(keys, key)
  419. }
  420. sort.Strings(keys)
  421. described := make([]string, 0, len(keys))
  422. for _, key := range keys {
  423. described = append(described, fmt.Sprintf("%s:%s", key, valueType(raw[key])))
  424. }
  425. return "{" + strings.Join(described, ",") + "}"
  426. }
  427. func valueType(raw any) string {
  428. switch raw.(type) {
  429. case nil:
  430. return "null"
  431. case string:
  432. return "string"
  433. case bool:
  434. return "bool"
  435. case float64:
  436. return "number"
  437. case []any:
  438. return "array"
  439. case map[string]any:
  440. return "object"
  441. default:
  442. return fmt.Sprintf("%T", raw)
  443. }
  444. }