package mapping

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"

	"qctextbuilder/internal/domain"
	"qctextbuilder/internal/llmruntime"
)
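
// SettingsReader provides read access to the application settings.
type SettingsReader interface {
	GetSettings(ctx context.Context) (*domain.AppSettings, error)
}

// ProviderAwareSuggestionGenerator generates suggestions with whichever
// LLM provider and model are currently active in the settings.
type ProviderAwareSuggestionGenerator struct {
	settings       SettingsReader
	runtimeFactory *llmruntime.Factory
}

// NewProviderAwareSuggestionGenerator wires a generator to its settings
// source and runtime factory.
func NewProviderAwareSuggestionGenerator(settings SettingsReader, runtimeFactory *llmruntime.Factory) *ProviderAwareSuggestionGenerator {
	return &ProviderAwareSuggestionGenerator{
		settings:       settings,
		runtimeFactory: runtimeFactory,
	}
}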

// Generate asks the active provider for suggestions and keeps only those
// that address a requested field path and carry a non-empty value.
func (g *ProviderAwareSuggestionGenerator) Generate(ctx context.Context, req SuggestionRequest) (SuggestionResult, error) {
	if g == nil || g.settings == nil || g.runtimeFactory == nil {
		return SuggestionResult{}, fmt.Errorf("provider-aware generator is not configured")
	}

	settings, err := g.settings.GetSettings(ctx)
	if err != nil {
		// Wrap the cause so callers can see why the settings could not be read.
		return SuggestionResult{}, fmt.Errorf("llm settings are not available: %w", err)
	}
	if settings == nil {
		return SuggestionResult{}, fmt.Errorf("llm settings are not available")
	}

	provider := domain.NormalizeLLMProvider(settings.LLMActiveProvider)
	model := domain.NormalizeLLMModel(provider, settings.LLMActiveModel)
	if strings.TrimSpace(model) == "" {
		return SuggestionResult{}, fmt.Errorf("no active model configured")
	}

	// Ollama is the only provider allowed to run without an API key.
	apiKey := apiKeyForProvider(provider, *settings)
	if provider != domain.LLMProviderOllama && strings.TrimSpace(apiKey) == "" {
		return SuggestionResult{}, fmt.Errorf("api key for provider %s is not configured", provider)
	}

	targets := collectSuggestionTargets(req.Fields, req.Existing, req.IncludeFilled)
	if len(targets) == 0 {
		return SuggestionResult{Suggestions: []Suggestion{}, ByFieldPath: map[string]Suggestion{}}, nil
	}

	// Index the targets by field path so provider output can be validated.
	allowed := make(map[string]SemanticSlotTarget, len(targets))
	for _, target := range targets {
		allowed[target.FieldPath] = target
	}

	providerClient, err := g.runtimeFactory.ClientFor(provider)
	if err != nil {
		return SuggestionResult{}, err
	}

	systemPrompt, userPrompt := buildProviderPrompts(req, targets)
	raw, err := providerClient.Generate(ctx, llmruntime.Request{
		Provider:     provider,
		Model:        model,
		BaseURL:      strings.TrimSpace(settings.LLMBaseURL),
		APIKey:       strings.TrimSpace(apiKey),
		SystemPrompt: systemPrompt,
		UserPrompt:   userPrompt,
	})
	if err != nil {
		return SuggestionResult{}, err
	}

	parsed, err := parseProviderSuggestions(raw)
	if err != nil {
		return SuggestionResult{}, err
	}

	out := SuggestionResult{
		Suggestions: make([]Suggestion, 0, len(parsed)),
		ByFieldPath: map[string]Suggestion{},
	}
	for _, item := range parsed {
		// Drop suggestions for field paths that were not requested.
		fieldPath := strings.TrimSpace(item.FieldPath)
		target, ok := allowed[fieldPath]
		if !ok {
			continue
		}
		value := strings.TrimSpace(item.Value)
		if value == "" {
			continue
		}
		suggestion := Suggestion{
			FieldPath: fieldPath,
			Slot:      firstNonEmpty(strings.TrimSpace(item.Slot), target.Slot),
			Value:     value,
			Reason:    firstNonEmpty(strings.TrimSpace(item.Reason), "provider suggestion"),
			Source:    provider,
		}
		// The first suggestion for a field path wins; later duplicates are ignored.
		if _, exists := out.ByFieldPath[fieldPath]; exists {
			continue
		}
		out.Suggestions = append(out.Suggestions, suggestion)
		out.ByFieldPath[fieldPath] = suggestion
	}
	return out, nil
}

// providerSuggestion is the JSON shape of one suggestion in a provider response.
type providerSuggestion struct {
	FieldPath string `json:"fieldPath"`
	Slot      string `json:"slot,omitempty"`
	Value     string `json:"value"`
	Reason    string `json:"reason,omitempty"`
}

// parseProviderSuggestions extracts suggestions from a raw provider response.
// It tries, in order: the first fenced JSON block, the whole response, and the
// outermost bracketed span, and returns the first candidate that decodes.
func parseProviderSuggestions(raw string) ([]providerSuggestion, error) {
	content := strings.TrimSpace(raw)
	if content == "" {
		return nil, fmt.Errorf("empty provider response")
	}

	candidates := []string{content}
	if fence := extractFencedJSON(content); fence != "" {
		// A fenced block is the most specific candidate, so it goes first.
		candidates = append([]string{fence}, candidates...)
	}
	if object := extractJSONObject(content); object != "" {
		candidates = append(candidates, object)
	}

	for _, candidate := range candidates {
		items, ok := parseSuggestionsCandidate(candidate)
		if ok {
			return items, nil
		}
	}
	return nil, fmt.Errorf("provider response is not valid suggestions json")
}

// parseSuggestionsCandidate decodes raw either as {"suggestions":[...]} or as
// a bare JSON array. It reports false when neither form yields any items.
func parseSuggestionsCandidate(raw string) ([]providerSuggestion, bool) {
	var objectPayload struct {
		Suggestions []providerSuggestion `json:"suggestions"`
	}
	if err := json.Unmarshal([]byte(raw), &objectPayload); err == nil && len(objectPayload.Suggestions) > 0 {
		return objectPayload.Suggestions, true
	}

	var listPayload []providerSuggestion
	if err := json.Unmarshal([]byte(raw), &listPayload); err == nil && len(listPayload) > 0 {
		return listPayload, true
	}
	return nil, false
}

// extractFencedJSON returns the body of the first fenced block whose content
// starts with '{' or '[', with any "json"/"JSON" language tag removed.
// It returns "" when no such block exists or a fence is left unclosed.
func extractFencedJSON(value string) string {
	const fence = "```"
	start := strings.Index(value, fence)
	for start >= 0 {
		rest := value[start+len(fence):]
		end := strings.Index(rest, fence)
		if end < 0 {
			return ""
		}
		block := strings.TrimSpace(rest[:end])
		block = strings.TrimPrefix(block, "json")
		block = strings.TrimPrefix(block, "JSON")
		block = strings.TrimSpace(block)
		if strings.HasPrefix(block, "{") || strings.HasPrefix(block, "[") {
			return block
		}
		// Not JSON: continue searching after this block's closing fence.
		nextOffset := start + len(fence) + end + len(fence)
		nextStart := strings.Index(value[nextOffset:], fence)
		if nextStart < 0 {
			break
		}
		start = nextOffset + nextStart
	}
	return ""
}
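
// extractJSONObject returns the span from the first '{' or '[' to the last
// '}' or ']' in value. Despite its name it also covers arrays, and it does not
// check that the opening and closing brackets match.
func extractJSONObject(value string) string {
	start := strings.IndexAny(value, "{[")
	if start < 0 {
		return ""
	}
	end := strings.LastIndexAny(value, "}]")
	if end <= start {
		return ""
	}
	return strings.TrimSpace(value[start : end+1])
}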

// buildProviderPrompts returns the system and user prompts for a request.
// The user prompt embeds the request context and the target fields as JSON.
func buildProviderPrompts(req SuggestionRequest, targets []SemanticSlotTarget) (string, string) {
	targetPayload := make([]map[string]string, 0, len(targets))
	for _, target := range targets {
		targetPayload = append(targetPayload, map[string]string{
			"fieldPath": strings.TrimSpace(target.FieldPath),
			"slot":      strings.TrimSpace(target.Slot),
		})
	}
	contextPayload := map[string]any{
		"globalData":   req.GlobalData,
		"draftContext": llmDraftContextMap(req.DraftContext),
		"masterPrompt": strings.TrimSpace(req.MasterPrompt),
		"promptBlocks": enabledPromptBlocks(req.PromptBlocks),
		"targets":      targetPayload,
	}
	// The marshal error is discarded: if encoding fails, the user prompt is
	// sent without the context JSON.
	contextJSON, _ := json.MarshalIndent(contextPayload, "", " ")
	system := "You generate website text suggestions. Return JSON only. Format: {\"suggestions\":[{\"fieldPath\":\"...\",\"slot\":\"...\",\"value\":\"...\",\"reason\":\"...\"}]}. Use only provided field paths. Keep values concise and in input language."
	user := "Generate suggestions for each target field using the provided context. Do not include markdown.\n\n" + string(contextJSON)
	return system, user
}

// apiKeyForProvider returns the trimmed API key stored in settings for the
// given provider, or "" for an unknown provider.
func apiKeyForProvider(provider string, settings domain.AppSettings) string {
	switch provider {
	case domain.LLMProviderOpenAI:
		return strings.TrimSpace(settings.OpenAIAPIKeyEncrypted)
	case domain.LLMProviderAnthropic:
		return strings.TrimSpace(settings.AnthropicAPIKeyEncrypted)
	case domain.LLMProviderGoogle:
		return strings.TrimSpace(settings.GoogleAPIKeyEncrypted)
	case domain.LLMProviderXAI:
		return strings.TrimSpace(settings.XAIAPIKeyEncrypted)
	case domain.LLMProviderOllama:
		return strings.TrimSpace(settings.OllamaAPIKeyEncrypted)
	default:
		return ""
	}
}
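
// Example (not part of the package): the response-parsing fallback above can be
// tried in isolation with the standalone sketch below. It is a simplified
// approximation under stated assumptions, not the package's own code: the names
// suggestion, firstFencedJSON, bracketSpan, decode and parseTolerant are
// hypothetical, only two of the four suggestion fields are modeled, and the
// fence search uses strings.Split instead of the index arithmetic used above.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// suggestion is a hypothetical, reduced stand-in for providerSuggestion.
type suggestion struct {
	FieldPath string `json:"fieldPath"`
	Value     string `json:"value"`
}

// firstFencedJSON returns the body of the first closed fenced block that
// starts with '{' or '[' once an optional json/JSON tag is removed.
func firstFencedJSON(s string) string {
	fence := strings.Repeat("`", 3)
	parts := strings.Split(s, fence)
	// Odd indexes lie between an opening and a closing fence. The i+1 bound
	// skips a trailing block whose fence was never closed.
	for i := 1; i+1 < len(parts); i += 2 {
		block := strings.TrimSpace(parts[i])
		block = strings.TrimPrefix(block, "json")
		block = strings.TrimPrefix(block, "JSON")
		block = strings.TrimSpace(block)
		if strings.HasPrefix(block, "{") || strings.HasPrefix(block, "[") {
			return block
		}
	}
	return ""
}

// bracketSpan returns the text from the first '{' or '[' to the last '}' or ']'.
func bracketSpan(s string) string {
	start := strings.IndexAny(s, "{[")
	end := strings.LastIndexAny(s, "}]")
	if start < 0 || end <= start {
		return ""
	}
	return s[start : end+1]
}

// decode accepts either {"suggestions":[...]} or a bare JSON array.
func decode(candidate string) ([]suggestion, bool) {
	var wrapped struct {
		Suggestions []suggestion `json:"suggestions"`
	}
	if err := json.Unmarshal([]byte(candidate), &wrapped); err == nil && len(wrapped.Suggestions) > 0 {
		return wrapped.Suggestions, true
	}
	var list []suggestion
	if err := json.Unmarshal([]byte(candidate), &list); err == nil && len(list) > 0 {
		return list, true
	}
	return nil, false
}

// parseTolerant tries the fenced block, then the whole text, then the
// outermost bracketed span, and returns the first candidate that decodes.
func parseTolerant(raw string) ([]suggestion, error) {
	content := strings.TrimSpace(raw)
	for _, candidate := range []string{firstFencedJSON(content), content, bracketSpan(content)} {
		if candidate == "" {
			continue
		}
		if items, ok := decode(candidate); ok {
			return items, nil
		}
	}
	return nil, errors.New("no decodable suggestions found")
}

func main() {
	fence := strings.Repeat("`", 3)
	raw := "Here you go:\n" + fence + "json\n" +
		`{"suggestions":[{"fieldPath":"hero.title","value":"Welcome"}]}` +
		"\n" + fence
	items, err := parseTolerant(raw)
	fmt.Println(items, err)
}
```

// Trying the fenced block first matters when a model wraps valid JSON in
// commentary that itself contains brackets: the bracket-span candidate would
// then cover too much text and fail to decode, while the fenced body still
// decodes on its own.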