選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

250 行
6.7KB

  1. package llmruntime
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "net/url"
  10. "strings"
  11. "time"
  12. )
  13. type Request struct {
  14. Provider string
  15. Model string
  16. BaseURL string
  17. APIKey string
  18. SystemPrompt string
  19. UserPrompt string
  20. }
  21. type Client interface {
  22. Generate(ctx context.Context, req Request) (string, error)
  23. }
  24. type Factory struct {
  25. httpClient *http.Client
  26. }
  27. func NewFactory(timeout time.Duration) *Factory {
  28. if timeout <= 0 {
  29. timeout = 45 * time.Second
  30. }
  31. return &Factory{
  32. httpClient: &http.Client{Timeout: timeout},
  33. }
  34. }
  35. func (f *Factory) ClientFor(provider string) (Client, error) {
  36. normalized := strings.ToLower(strings.TrimSpace(provider))
  37. switch normalized {
  38. case "openai", "xai", "ollama":
  39. return &openAICompatibleClient{httpClient: f.httpClient}, nil
  40. case "anthropic":
  41. return &anthropicClient{httpClient: f.httpClient}, nil
  42. case "google":
  43. return &googleClient{httpClient: f.httpClient}, nil
  44. default:
  45. return nil, fmt.Errorf("unsupported llm provider: %s", normalized)
  46. }
  47. }
  48. type openAICompatibleClient struct {
  49. httpClient *http.Client
  50. }
  51. func (c *openAICompatibleClient) Generate(ctx context.Context, req Request) (string, error) {
  52. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  53. if baseURL == "" {
  54. switch strings.ToLower(strings.TrimSpace(req.Provider)) {
  55. case "xai":
  56. baseURL = "https://api.x.ai"
  57. case "ollama":
  58. baseURL = "http://localhost:11434"
  59. default:
  60. baseURL = "https://api.openai.com"
  61. }
  62. }
  63. payload := map[string]any{
  64. "model": strings.TrimSpace(req.Model),
  65. "temperature": 0,
  66. "messages": []map[string]string{
  67. {"role": "system", "content": strings.TrimSpace(req.SystemPrompt)},
  68. {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
  69. },
  70. }
  71. body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/chat/completions", req.APIKey, nil, payload)
  72. if err != nil {
  73. return "", err
  74. }
  75. var response struct {
  76. Choices []struct {
  77. Message struct {
  78. Content string `json:"content"`
  79. } `json:"message"`
  80. } `json:"choices"`
  81. }
  82. if err := json.Unmarshal(body, &response); err != nil {
  83. return "", fmt.Errorf("decode openai-compatible response: %w", err)
  84. }
  85. if len(response.Choices) == 0 {
  86. return "", fmt.Errorf("empty openai-compatible response")
  87. }
  88. return strings.TrimSpace(response.Choices[0].Message.Content), nil
  89. }
  90. type anthropicClient struct {
  91. httpClient *http.Client
  92. }
  93. func (c *anthropicClient) Generate(ctx context.Context, req Request) (string, error) {
  94. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  95. if baseURL == "" {
  96. baseURL = "https://api.anthropic.com"
  97. }
  98. payload := map[string]any{
  99. "model": strings.TrimSpace(req.Model),
  100. "max_tokens": 1200,
  101. "temperature": 0,
  102. "system": strings.TrimSpace(req.SystemPrompt),
  103. "messages": []map[string]any{
  104. {"role": "user", "content": strings.TrimSpace(req.UserPrompt)},
  105. },
  106. }
  107. headers := map[string]string{"anthropic-version": "2023-06-01"}
  108. body, err := doJSON(ctx, c.httpClient, http.MethodPost, baseURL+"/v1/messages", req.APIKey, headers, payload)
  109. if err != nil {
  110. return "", err
  111. }
  112. var response struct {
  113. Content []struct {
  114. Type string `json:"type"`
  115. Text string `json:"text"`
  116. } `json:"content"`
  117. }
  118. if err := json.Unmarshal(body, &response); err != nil {
  119. return "", fmt.Errorf("decode anthropic response: %w", err)
  120. }
  121. for _, item := range response.Content {
  122. if strings.EqualFold(strings.TrimSpace(item.Type), "text") && strings.TrimSpace(item.Text) != "" {
  123. return strings.TrimSpace(item.Text), nil
  124. }
  125. }
  126. return "", fmt.Errorf("empty anthropic response")
  127. }
  128. type googleClient struct {
  129. httpClient *http.Client
  130. }
  131. func (c *googleClient) Generate(ctx context.Context, req Request) (string, error) {
  132. baseURL := strings.TrimRight(strings.TrimSpace(req.BaseURL), "/")
  133. if baseURL == "" {
  134. baseURL = "https://generativelanguage.googleapis.com"
  135. }
  136. model := strings.TrimSpace(req.Model)
  137. if model == "" {
  138. return "", fmt.Errorf("google model is required")
  139. }
  140. apiKey := strings.TrimSpace(req.APIKey)
  141. if apiKey == "" {
  142. return "", fmt.Errorf("google api key is required")
  143. }
  144. endpoint := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", baseURL, url.PathEscape(model), url.QueryEscape(apiKey))
  145. payload := map[string]any{
  146. "contents": []map[string]any{
  147. {"parts": []map[string]string{{"text": strings.TrimSpace(req.UserPrompt)}}},
  148. },
  149. "generationConfig": map[string]any{
  150. "temperature": 0,
  151. },
  152. }
  153. if strings.TrimSpace(req.SystemPrompt) != "" {
  154. payload["systemInstruction"] = map[string]any{
  155. "parts": []map[string]string{{"text": strings.TrimSpace(req.SystemPrompt)}},
  156. }
  157. }
  158. body, err := doJSON(ctx, c.httpClient, http.MethodPost, endpoint, "", nil, payload)
  159. if err != nil {
  160. return "", err
  161. }
  162. var response struct {
  163. Candidates []struct {
  164. Content struct {
  165. Parts []struct {
  166. Text string `json:"text"`
  167. } `json:"parts"`
  168. } `json:"content"`
  169. } `json:"candidates"`
  170. }
  171. if err := json.Unmarshal(body, &response); err != nil {
  172. return "", fmt.Errorf("decode google response: %w", err)
  173. }
  174. if len(response.Candidates) == 0 {
  175. return "", fmt.Errorf("empty google response")
  176. }
  177. parts := make([]string, 0, len(response.Candidates[0].Content.Parts))
  178. for _, part := range response.Candidates[0].Content.Parts {
  179. if text := strings.TrimSpace(part.Text); text != "" {
  180. parts = append(parts, text)
  181. }
  182. }
  183. if len(parts) == 0 {
  184. return "", fmt.Errorf("google response has no text parts")
  185. }
  186. return strings.Join(parts, "\n"), nil
  187. }
  188. func doJSON(ctx context.Context, httpClient *http.Client, method, endpoint, apiKey string, headers map[string]string, payload any) ([]byte, error) {
  189. body, err := json.Marshal(payload)
  190. if err != nil {
  191. return nil, fmt.Errorf("marshal request: %w", err)
  192. }
  193. req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewReader(body))
  194. if err != nil {
  195. return nil, fmt.Errorf("build request: %w", err)
  196. }
  197. req.Header.Set("Content-Type", "application/json")
  198. req.Header.Set("Accept", "application/json")
  199. if strings.TrimSpace(apiKey) != "" {
  200. req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(apiKey))
  201. req.Header.Set("x-api-key", strings.TrimSpace(apiKey))
  202. }
  203. for key, value := range headers {
  204. if strings.TrimSpace(key) == "" {
  205. continue
  206. }
  207. req.Header.Set(key, value)
  208. }
  209. resp, err := httpClient.Do(req)
  210. if err != nil {
  211. return nil, fmt.Errorf("do request: %w", err)
  212. }
  213. defer resp.Body.Close()
  214. respBody, err := io.ReadAll(resp.Body)
  215. if err != nil {
  216. return nil, fmt.Errorf("read response: %w", err)
  217. }
  218. if resp.StatusCode >= 400 {
  219. message := strings.TrimSpace(string(respBody))
  220. if len(message) > 500 {
  221. message = message[:500]
  222. }
  223. return nil, fmt.Errorf("provider http %d: %s", resp.StatusCode, message)
  224. }
  225. return respBody, nil
  226. }