Você não pode selecionar mais de 25 tópicos. Os tópicos devem começar com uma letra ou um número, podem incluir traços ('-') e podem ter até 35 caracteres.

600 linhas
17KB

  1. package llmruntime
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "log/slog"
  9. "net/http"
  10. "net/url"
  11. "sort"
  12. "strings"
  13. "time"
  14. )
  // Request carries everything a Client needs for one LLM call.
//
// Provider selects the backend, Model names the model, BaseURL optionally
// overrides the provider's default endpoint (empty means "use the default"),
// and APIKey authenticates the call. Temperature and MaxTokens are optional:
// nil asks the client to substitute its own fallback. SystemPrompt and
// UserPrompt are the prompt texts; the clients trim surrounding whitespace
// before sending them.
type Request struct {
	Provider, Model, BaseURL, APIKey string

	Temperature *float64
	MaxTokens   *int

	SystemPrompt, UserPrompt string
}
  // Rune budgets passed as the limit argument to snippet when this package
// logs excerpts of prompts, payloads and responses.
const (
	// runtimePromptSnippetLimit bounds logged system/user prompt excerpts.
	runtimePromptSnippetLimit = 1500
	// runtimeShapeSnippetLimit bounds the logged response-shape description.
	runtimeShapeSnippetLimit = 1200
	// runtimeSnippetLimit bounds logged raw-response and extracted-content excerpts.
	runtimeSnippetLimit = 4000
	// runtimePayloadSnippetLimit bounds the logged request-payload excerpt.
	runtimePayloadSnippetLimit = 5000
)
  31. type Client interface {
  32. Generate(ctx context.Context, req Request) (string, error)
  33. }
  // Factory hands out provider Clients that all share a single *http.Client,
// so every client it creates is subject to the same request timeout.
type Factory struct {
	httpClient *http.Client
}

// NewFactory returns a Factory whose shared HTTP client uses timeout as its
// overall per-request limit (http.Client.Timeout). A zero or negative
// timeout selects the 45-second default.
func NewFactory(timeout time.Duration) *Factory {
	effective := timeout
	if effective <= 0 {
		effective = 45 * time.Second
	}
	shared := &http.Client{Timeout: effective}
	return &Factory{httpClient: shared}
}
  45. func (f *Factory) ClientFor(provider string) (Client, error) {
  46. normalized := strings.ToLower(strings.TrimSpace(provider))
  47. switch normalized {
  48. case "openai", "xai", "ollama":
  49. return &openAICompatibleClient{httpClient: f.httpClient}, nil
  50. case "anthropic":
  51. return &anthropicClient{httpClient: f.httpClient}, nil
  52. case "google":
  53. return &googleClient{httpClient: f.httpClient}, nil
  54. default:
  55. return nil, fmt.Errorf("unsupported llm provider: %s", normalized)
  56. }
  57. }
  58. type openAICompatibleClient struct {
  59. httpClient *http.Client
  60. }
  61. func (c *openAICompatibleClient) Generate(ctx context.Context, req Request) (string, error) {
  62. provider := strings.ToLower(strings.TrimSpace(req.Provider))
  63. model := strings.TrimSpace(req.Model)
  64. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  65. if baseURL == "" {
  66. switch provider {
  67. case "xai":
  68. baseURL = "https://api.x.ai"
  69. case "ollama":
  70. baseURL = "http://localhost:11434"
  71. default:
  72. baseURL = "https://api.openai.com"
  73. }
  74. }
  75. systemPrompt := strings.TrimSpace(req.SystemPrompt)
  76. userPrompt := strings.TrimSpace(req.UserPrompt)
  77. payload := map[string]any{
  78. "model": model,
  79. "temperature": optionalFloat64(req.Temperature, 0),
  80. "messages": []map[string]string{
  81. {"role": "system", "content": systemPrompt},
  82. {"role": "user", "content": userPrompt},
  83. },
  84. }
  85. maxTokensField := openAICompatibleMaxTokensField(provider, model)
  86. payload[maxTokensField] = optionalInt(req.MaxTokens, 1200)
  87. payloadRaw, _ := json.Marshal(payload)
  88. payloadSnippet := redactSecrets(snippet(string(payloadRaw), runtimePayloadSnippetLimit), req.APIKey)
  89. runtimeLogger().InfoContext(ctx, "llm runtime",
  90. "component", "autofill",
  91. "step", "provider_http_request",
  92. "provider", provider,
  93. "model", model,
  94. "base_url", safeBaseURL(baseURL),
  95. "max_tokens_field", maxTokensField,
  96. "system_prompt_chars", len(systemPrompt),
  97. "system_prompt_snippet", snippet(systemPrompt, runtimePromptSnippetLimit),
  98. "user_prompt_chars", len(userPrompt),
  99. "user_prompt_snippet", snippet(userPrompt, runtimePromptSnippetLimit),
  100. "request_payload_chars", len(payloadRaw),
  101. "request_payload_snippet", payloadSnippet,
  102. )
  103. body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/chat/completions", req.APIKey, nil, payload)
  104. if err != nil {
  105. return "", err
  106. }
  107. rawResponse := strings.TrimSpace(string(body))
  108. runtimeLogger().InfoContext(ctx, "llm runtime",
  109. "component", "autofill",
  110. "step", "provider_http_response",
  111. "provider", provider,
  112. "model", model,
  113. "raw_response_chars", len(rawResponse),
  114. "raw_response_snippet", redactSecrets(snippet(rawResponse, runtimeSnippetLimit), req.APIKey),
  115. )
  116. var response map[string]any
  117. if err := json.Unmarshal(body, &response); err != nil {
  118. return "", fmt.Errorf("decode openai-compatible response: %w", err)
  119. }
  120. shape := describeOpenAICompatibleShape(response)
  121. runtimeLogger().InfoContext(ctx, "llm runtime",
  122. "component", "autofill",
  123. "step", "provider_http_response_shape",
  124. "provider", provider,
  125. "model", model,
  126. "response_shape_hint", snippet(shape, runtimeShapeSnippetLimit),
  127. )
  128. content := extractOpenAICompatibleContent(response)
  129. runtimeLogger().InfoContext(ctx, "llm runtime",
  130. "component", "autofill",
  131. "step", "provider_extract",
  132. "provider", provider,
  133. "model", model,
  134. "extracted_content_chars", len(content),
  135. "extracted_content_snippet", redactSecrets(snippet(content, runtimeSnippetLimit), req.APIKey),
  136. )
  137. if content == "" {
  138. return "", fmt.Errorf("empty openai-compatible response content (%s)", describeOpenAICompatibleShape(response))
  139. }
  140. return content, nil
  141. }
  142. type anthropicClient struct {
  143. httpClient *http.Client
  144. }
  145. func (c *anthropicClient) Generate(ctx context.Context, req Request) (string, error) {
  146. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  147. if baseURL == "" {
  148. baseURL = "https://api.anthropic.com"
  149. }
  150. payload := map[string]any{
  151. "model": strings.TrimSpace(req.Model),
  152. "max_tokens": optionalInt(req.MaxTokens, 1200),
  153. "temperature": optionalFloat64(req.Temperature, 0),
  154. "system": strings.TrimSpace(req.SystemPrompt),
  155. "messages": []map[string]any{
  156. {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
  157. },
  158. }
  159. headers := map[string]string{"anthropic-version": "2023-06-01"}
  160. body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/messages", req.APIKey, headers, payload)
  161. if err != nil {
  162. return "", err
  163. }
  164. var response struct {
  165. Content []struct {
  166. Type string `json:"type"`
  167. Text string `json:"text"`
  168. } `json:"content"`
  169. }
  170. if err := json.Unmarshal(body, &response); err != nil {
  171. return "", fmt.Errorf("decode anthropic response: %w", err)
  172. }
  173. for _, item := range response.Content {
  174. if strings.EqualFold(strings.TrimSpace(item.Type), "text") && strings.TrimSpace(item.Text) != "" {
  175. return strings.TrimSpace(item.Text), nil
  176. }
  177. }
  178. return "", fmt.Errorf("empty anthropic response")
  179. }
  180. type googleClient struct {
  181. httpClient *http.Client
  182. }
  183. func (c *googleClient) Generate(ctx context.Context, req Request) (string, error) {
  184. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  185. if baseURL == "" {
  186. baseURL = "https://generativelanguage.googleapis.com"
  187. }
  188. model := strings.TrimSpace(req.Model)
  189. if model == "" {
  190. return "", fmt.Errorf("google model is required")
  191. }
  192. apiKey := strings.TrimSpace(req.APIKey)
  193. if apiKey == "" {
  194. return "", fmt.Errorf("google api key is required")
  195. }
  196. endpoint := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", baseURL, url.PathEscape(model), url.QueryEscape(apiKey))
  197. payload := map[string]any{
  198. "contents": []map[string]any{
  199. {"parts": []map[string]string{{"text": strings.TrimSpace(req.UserPrompt)}}},
  200. },
  201. "generationConfig": map[string]any{
  202. "temperature": optionalFloat64(req.Temperature, 0),
  203. "maxOutputTokens": optionalInt(req.MaxTokens, 1200),
  204. },
  205. }
  206. if strings.TrimSpace(req.SystemPrompt) != "" {
  207. payload["systemInstruction"] = map[string]any{
  208. "parts": []map[string]string{{"text": strings.TrimSpace(req.SystemPrompt)}},
  209. }
  210. }
  211. body, err := doJSON(ctx, c.httpClient, http.MethodPost, endpoint, "", nil, payload)
  212. if err != nil {
  213. return "", err
  214. }
  215. var response struct {
  216. Candidates []struct {
  217. Content struct {
  218. Parts []struct {
  219. Text string `json:"text"`
  220. } `json:"parts"`
  221. } `json:"content"`
  222. } `json:"candidates"`
  223. }
  224. if err := json.Unmarshal(body, &response); err != nil {
  225. return "", fmt.Errorf("decode google response: %w", err)
  226. }
  227. if len(response.Candidates) == 0 {
  228. return "", fmt.Errorf("empty google response")
  229. }
  230. parts := make([]string, 0, len(response.Candidates[0].Content.Parts))
  231. for _, part := range response.Candidates[0].Content.Parts {
  232. if text := strings.TrimSpace(part.Text); text != "" {
  233. parts = append(parts, text)
  234. }
  235. }
  236. if len(parts) == 0 {
  237. return "", fmt.Errorf("google response has no text parts")
  238. }
  239. return strings.Join(parts, "\n"), nil
  240. }
  241. func doJSON(ctx context.Context, httpClient *http.Client, method, endpoint, apiKey string, headers map[string]string, payload any) ([]byte, error) {
  242. body, err := json.Marshal(payload)
  243. if err != nil {
  244. return nil, fmt.Errorf("marshal request: %w", err)
  245. }
  246. req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewReader(body))
  247. if err != nil {
  248. return nil, fmt.Errorf("build request: %w", err)
  249. }
  250. req.Header.Set("Content-Type", "application/json")
  251. req.Header.Set("Accept", "application/json")
  252. if strings.TrimSpace(apiKey) != "" {
  253. req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(apiKey))
  254. req.Header.Set("x-api-key", strings.TrimSpace(apiKey))
  255. }
  256. for key, value := range headers {
  257. if strings.TrimSpace(key) == "" {
  258. continue
  259. }
  260. req.Header.Set(key, value)
  261. }
  262. resp, err := httpClient.Do(req)
  263. if err != nil {
  264. return nil, fmt.Errorf("do request: %w", err)
  265. }
  266. defer resp.Body.Close()
  267. respBody, err := io.ReadAll(resp.Body)
  268. if err != nil {
  269. return nil, fmt.Errorf("read response: %w", err)
  270. }
  271. if resp.StatusCode >= 400 {
  272. message := trimProviderErrorMessage(respBody)
  273. return nil, fmt.Errorf("provider http %d: %s", resp.StatusCode, message)
  274. }
  275. return respBody, nil
  276. }
  // optionalFloat64 dereferences value, or returns fallback when value is nil.
func optionalFloat64(value *float64, fallback float64) float64 {
	if value != nil {
		return *value
	}
	return fallback
}
  // optionalInt yields *value when value is non-nil, otherwise fallback.
func optionalInt(value *int, fallback int) int {
	result := fallback
	if value != nil {
		result = *value
	}
	return result
}
  // openAICompatibleMaxTokensField returns the JSON field name that caps
// completion length in an OpenAI-compatible chat request. OpenAI's reasoning
// models (the GPT-5 family and the o1/o3/o4 families) reject "max_tokens" and
// require "max_completion_tokens". Every other provider/model combination
// keeps "max_tokens".
func openAICompatibleMaxTokensField(provider, model string) string {
	if isOpenAIGPT5Model(provider, model) || isOpenAIOSeriesModel(provider, model) {
		return "max_completion_tokens"
	}
	return "max_tokens"
}

// isOpenAIGPT5Model reports whether provider is "openai" and model starts
// with "gpt-5". Both comparisons are case-insensitive and ignore surrounding
// whitespace.
func isOpenAIGPT5Model(provider, model string) bool {
	if !strings.EqualFold(strings.TrimSpace(provider), "openai") {
		return false
	}
	normalizedModel := strings.ToLower(strings.TrimSpace(model))
	return strings.HasPrefix(normalizedModel, "gpt-5")
}

// isOpenAIOSeriesModel reports whether provider is "openai" and model is in
// the o1, o3 or o4 family: either the bare family name or the name followed
// by "-" (e.g. "o1", "o3-mini", "o4-mini-2025-04-16"). Matching is
// case-insensitive and ignores surrounding whitespace.
func isOpenAIOSeriesModel(provider, model string) bool {
	if !strings.EqualFold(strings.TrimSpace(provider), "openai") {
		return false
	}
	normalizedModel := strings.ToLower(strings.TrimSpace(model))
	for _, family := range []string{"o1", "o3", "o4"} {
		if normalizedModel == family || strings.HasPrefix(normalizedModel, family+"-") {
			return true
		}
	}
	return false
}
  // trimProviderErrorMessage extracts a human-readable message from a provider
// error body (see extractProviderErrorMessage) and caps it at 500 bytes. The
// cut is moved back to a UTF-8 rune boundary, so a multi-byte character is
// never split into invalid bytes.
func trimProviderErrorMessage(respBody []byte) string {
	const maxMessageBytes = 500
	message := extractProviderErrorMessage(respBody)
	if len(message) <= maxMessageBytes {
		return message
	}
	// Ranging over a string yields the byte offset of each rune start; keep
	// the last start that still fits, so message[:cut] holds only whole runes.
	cut := 0
	for i := range message {
		if i > maxMessageBytes {
			break
		}
		cut = i
	}
	return message[:cut]
}

// extractProviderErrorMessage returns the most specific message it can find
// in respBody. It returns "empty error response" for a blank body. For a JSON
// object it tries error.message, then a string-valued "error", then a
// top-level "message". Otherwise it falls back to the trimmed raw body.
func extractProviderErrorMessage(respBody []byte) string {
	raw := strings.TrimSpace(string(respBody))
	if raw == "" {
		return "empty error response"
	}
	var parsed map[string]any
	if err := json.Unmarshal(respBody, &parsed); err == nil {
		if value := nestedString(parsed, "error", "message"); value != "" {
			return value
		}
		if value := nestedString(parsed, "error"); value != "" {
			return value
		}
		if value := nestedString(parsed, "message"); value != "" {
			return value
		}
	}
	return raw
}

// nestedString follows path through nested map[string]any values and returns
// the trimmed string at the end. It returns "" when path is empty, values is
// nil, an intermediate step is not a map, or the leaf is neither a string nor
// a fmt.Stringer.
func nestedString(values map[string]any, path ...string) string {
	if len(path) == 0 || values == nil {
		return ""
	}
	current := any(values)
	for _, key := range path {
		nextMap, ok := current.(map[string]any)
		if !ok {
			return ""
		}
		current = nextMap[key]
	}
	switch value := current.(type) {
	case string:
		return strings.TrimSpace(value)
	case fmt.Stringer:
		return strings.TrimSpace(value.String())
	default:
		return ""
	}
}
  // extractOpenAICompatibleContent pulls assistant text out of a decoded
// OpenAI-compatible response. It tries, in order, the "choices" array, a
// top-level "output_text" value, and the "output" array. It returns "" when
// none of them yields non-blank text.
func extractOpenAICompatibleContent(response map[string]any) string {
	if response == nil {
		return ""
	}
	sources := []func() string{
		func() string { return extractOpenAICompatibleChoicesContent(response["choices"]) },
		func() string { return extractTextFromContentValue(response["output_text"]) },
		func() string { return extractOpenAICompatibleOutputContent(response["output"]) },
	}
	for _, source := range sources {
		if text := strings.TrimSpace(source()); text != "" {
			return text
		}
	}
	return ""
}

// firstTextUnderKeys returns the first non-blank text that
// extractTextFromContentValue finds under fields[key], trying keys in order.
func firstTextUnderKeys(fields map[string]any, keys ...string) string {
	for _, key := range keys {
		if text := strings.TrimSpace(extractTextFromContentValue(fields[key])); text != "" {
			return text
		}
	}
	return ""
}

// extractOpenAICompatibleChoicesContent scans a "choices" array. For each
// object entry it looks at "message", then "delta", then "text", and returns
// the first non-blank text found. Anything that is not an array yields "".
func extractOpenAICompatibleChoicesContent(raw any) string {
	choices, _ := raw.([]any)
	for _, entry := range choices {
		if choice, isMap := entry.(map[string]any); isMap {
			if text := firstTextUnderKeys(choice, "message", "delta", "text"); text != "" {
				return text
			}
		}
	}
	return ""
}

// extractOpenAICompatibleOutputContent scans an "output" array. For each
// object item it looks at "content", then "text", and returns the first
// non-blank text found. Anything that is not an array yields "".
func extractOpenAICompatibleOutputContent(raw any) string {
	items, _ := raw.([]any)
	for _, entry := range items {
		item, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		if text := firstTextUnderKeys(item, "content", "text"); text != "" {
			return text
		}
	}
	return ""
}

// extractTextFromContentValue turns a decoded JSON value into text:
//   - strings are trimmed;
//   - arrays contribute every non-blank element (recursively), joined with "\n";
//   - objects yield the first non-blank text under "content", "text", "value"
//     or "output_text" (recursively);
//   - any other value yields "".
func extractTextFromContentValue(raw any) string {
	if str, isString := raw.(string); isString {
		return strings.TrimSpace(str)
	}
	if list, isList := raw.([]any); isList {
		collected := make([]string, 0, len(list))
		for _, element := range list {
			if piece := extractTextFromContentValue(element); piece != "" {
				collected = append(collected, piece)
			}
		}
		return strings.TrimSpace(strings.Join(collected, "\n"))
	}
	if object, isObject := raw.(map[string]any); isObject {
		return firstTextUnderKeys(object, "content", "text", "value", "output_text")
	}
	return ""
}
  // describeOpenAICompatibleShape summarizes the structure (not the content) of
// a decoded OpenAI-compatible response as "; "-separated key=value hints. It
// covers the choices array and its first entry's message, any output_text,
// the output array and its first item, and finally the sorted top-level keys.
// It is meant for diagnostics when no text could be extracted.
func describeOpenAICompatibleShape(response map[string]any) string {
	var parts []string
	parts = append(parts, describeChoicesShapeParts(response)...)
	if rawText, present := response["output_text"]; present {
		parts = append(parts, "output_text_type="+valueType(rawText))
	}
	parts = append(parts, describeOutputShapeParts(response)...)
	parts = append(parts, "top="+describeMapKeys(response))
	return strings.Join(parts, "; ")
}

// describeChoicesShapeParts returns the shape hints for response["choices"]:
// the array length plus details of the first choice and its message. It
// returns a single type hint when "choices" is not an array, and nothing when
// the key is absent.
func describeChoicesShapeParts(response map[string]any) []string {
	rawChoices, present := response["choices"]
	if !present {
		return nil
	}
	choices, isList := rawChoices.([]any)
	if !isList {
		return []string{"choices_type=" + valueType(rawChoices)}
	}
	parts := []string{fmt.Sprintf("choices_len=%d", len(choices))}
	if len(choices) == 0 {
		return parts
	}
	choice, isMap := choices[0].(map[string]any)
	if !isMap {
		return parts
	}
	parts = append(parts, "choices0="+describeMapKeys(choice))
	if finishReason, present := choice["finish_reason"]; present {
		if reason, isString := finishReason.(string); isString {
			parts = append(parts, "choices0_finish_reason="+strings.TrimSpace(reason))
		} else {
			parts = append(parts, "choices0_finish_reason_type="+valueType(finishReason))
		}
	}
	rawMessage, present := choice["message"]
	if !present {
		return parts
	}
	message, isMap := rawMessage.(map[string]any)
	if !isMap {
		return append(parts, "message_type="+valueType(rawMessage))
	}
	parts = append(parts,
		"message="+describeMapKeys(message),
		"message_content_type="+valueType(message["content"]),
	)
	content, isList := message["content"].([]any)
	if !isList {
		return parts
	}
	parts = append(parts, fmt.Sprintf("message_content_len=%d", len(content)))
	if len(content) == 0 {
		return parts
	}
	parts = append(parts, "message_content0_type="+valueType(content[0]))
	if first, isFirstMap := content[0].(map[string]any); isFirstMap {
		parts = append(parts, "message_content0="+describeMapKeys(first))
	}
	return parts
}

// describeOutputShapeParts returns the shape hints for response["output"]:
// the array length plus the keys and content type of its first object item.
// It returns a single type hint when "output" is not an array, and nothing
// when the key is absent.
func describeOutputShapeParts(response map[string]any) []string {
	rawOutput, present := response["output"]
	if !present {
		return nil
	}
	output, isList := rawOutput.([]any)
	if !isList {
		return []string{"output_type=" + valueType(rawOutput)}
	}
	parts := []string{fmt.Sprintf("output_len=%d", len(output))}
	if len(output) > 0 {
		if first, isMap := output[0].(map[string]any); isMap {
			parts = append(parts,
				"output0="+describeMapKeys(first),
				"output0_content_type="+valueType(first["content"]),
			)
		}
	}
	return parts
}

// describeMapKeys renders raw as "{key:type,...}" with keys sorted
// lexically and types named by valueType. An empty or nil map renders as "{}".
func describeMapKeys(raw map[string]any) string {
	if len(raw) == 0 {
		return "{}"
	}
	names := make([]string, 0, len(raw))
	for name := range raw {
		names = append(names, name)
	}
	sort.Strings(names)
	var b strings.Builder
	b.WriteByte('{')
	for i, name := range names {
		if i > 0 {
			b.WriteByte(',')
		}
		b.WriteString(name)
		b.WriteByte(':')
		b.WriteString(valueType(raw[name]))
	}
	b.WriteByte('}')
	return b.String()
}

// valueType names the JSON kind of a value produced by encoding/json
// decoding into any: "object", "array", "string", "number", "bool" or
// "null". Any other Go type is reported via its %T name.
func valueType(raw any) string {
	switch raw.(type) {
	case map[string]any:
		return "object"
	case []any:
		return "array"
	case string:
		return "string"
	case float64:
		return "number"
	case bool:
		return "bool"
	case nil:
		return "null"
	}
	return fmt.Sprintf("%T", raw)
}
  // runtimeLogger returns the logger used for this package's diagnostics,
// which is whatever slog.Default() is at call time.
func runtimeLogger() *slog.Logger {
	logger := slog.Default()
	return logger
}
  // snippet trims value and caps it at limit runes, appending "...(truncated)"
// (after re-trimming the kept prefix) when it had to cut. A blank value or a
// non-positive limit yields "".
func snippet(value string, limit int) string {
	trimmed := strings.TrimSpace(value)
	if limit > 0 && trimmed != "" {
		if runes := []rune(trimmed); len(runes) > limit {
			head := string(runes[:limit])
			return strings.TrimSpace(head) + "...(truncated)"
		}
		return trimmed
	}
	return ""
}
  // redactSecrets replaces every occurrence of each secret (whitespace-trimmed)
// in value with "[REDACTED]". Blank secrets are ignored.
func redactSecrets(value string, secrets ...string) string {
	result := value
	for i := range secrets {
		if needle := strings.TrimSpace(secrets[i]); needle != "" {
			result = strings.ReplaceAll(result, needle, "[REDACTED]")
		}
	}
	return result
}
  // safeBaseURL returns a loggable form of a base URL. For an absolute URL
// (scheme and host present) it drops the userinfo, query and fragment, then
// trims trailing slashes. Anything else is returned trimmed but otherwise
// unchanged; blank input yields "".
func safeBaseURL(value string) string {
	raw := strings.TrimSpace(value)
	if raw == "" {
		return ""
	}
	if u, err := url.Parse(raw); err == nil && u.Scheme != "" && u.Host != "" {
		u.User, u.RawQuery, u.Fragment = nil, "", ""
		return strings.TrimRight(u.String(), "/")
	}
	return raw
}