You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

197 lines
5.9KB

  1. package llmruntime
  2. import (
  3. "context"
  4. "encoding/json"
  5. "net/http"
  6. "net/http/httptest"
  7. "strings"
  8. "testing"
  9. "time"
  10. )
  11. func TestOpenAICompatibleClient_ForwardsTemperatureAndMaxTokens(t *testing.T) {
  12. t.Parallel()
  13. var got map[string]any
  14. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  15. _ = json.NewDecoder(r.Body).Decode(&got)
  16. _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"}}]}`))
  17. }))
  18. defer server.Close()
  19. factory := NewFactory(2 * time.Second)
  20. client, err := factory.ClientFor("openai")
  21. if err != nil {
  22. t.Fatalf("client creation failed: %v", err)
  23. }
  24. temperature := 0.77
  25. maxTokens := 777
  26. _, err = client.Generate(context.Background(), Request{
  27. Provider: "openai",
  28. BaseURL: server.URL,
  29. Model: "gpt-5.4",
  30. APIKey: "key",
  31. Temperature: &temperature,
  32. MaxTokens: &maxTokens,
  33. SystemPrompt: "system",
  34. UserPrompt: "user",
  35. })
  36. if err != nil {
  37. t.Fatalf("generate failed: %v", err)
  38. }
  39. gotTemperature, _ := got["temperature"].(float64)
  40. if gotTemperature != 0.77 {
  41. t.Fatalf("unexpected temperature: %v", gotTemperature)
  42. }
  43. if _, exists := got["max_tokens"]; exists {
  44. t.Fatalf("did not expect max_tokens for openai gpt-5 models")
  45. }
  46. gotMaxCompletionTokens, _ := got["max_completion_tokens"].(float64)
  47. if gotMaxCompletionTokens != 777 {
  48. t.Fatalf("unexpected max_completion_tokens: %v", gotMaxCompletionTokens)
  49. }
  50. }
  51. func TestOpenAICompatibleClient_UsesMaxTokensForOlderOpenAIModels(t *testing.T) {
  52. t.Parallel()
  53. var got map[string]any
  54. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  55. _ = json.NewDecoder(r.Body).Decode(&got)
  56. _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"}}]}`))
  57. }))
  58. defer server.Close()
  59. factory := NewFactory(2 * time.Second)
  60. client, err := factory.ClientFor("openai")
  61. if err != nil {
  62. t.Fatalf("client creation failed: %v", err)
  63. }
  64. maxTokens := 512
  65. _, err = client.Generate(context.Background(), Request{
  66. Provider: "openai",
  67. BaseURL: server.URL,
  68. Model: "gpt-4.1",
  69. APIKey: "key",
  70. MaxTokens: &maxTokens,
  71. SystemPrompt: "system",
  72. UserPrompt: "user",
  73. })
  74. if err != nil {
  75. t.Fatalf("generate failed: %v", err)
  76. }
  77. if _, exists := got["max_completion_tokens"]; exists {
  78. t.Fatalf("did not expect max_completion_tokens for non-gpt-5 model")
  79. }
  80. gotMaxTokens, _ := got["max_tokens"].(float64)
  81. if gotMaxTokens != 512 {
  82. t.Fatalf("unexpected max_tokens: %v", gotMaxTokens)
  83. }
  84. }
  85. func TestOpenAICompatibleClient_ExtractsMessageContentParts(t *testing.T) {
  86. t.Parallel()
  87. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  88. _, _ = w.Write([]byte(`{"choices":[{"message":{"content":[{"type":"text","text":"{\"suggestions\":["},{"type":"output_text","text":"{\"fieldPath\":\"hero.title\",\"value\":\"Hello\"}"}]}}]}`))
  89. }))
  90. defer server.Close()
  91. factory := NewFactory(2 * time.Second)
  92. client, err := factory.ClientFor("openai")
  93. if err != nil {
  94. t.Fatalf("client creation failed: %v", err)
  95. }
  96. got, err := client.Generate(context.Background(), Request{
  97. Provider: "openai",
  98. BaseURL: server.URL,
  99. Model: "gpt-5.4-mini",
  100. APIKey: "key",
  101. SystemPrompt: "system",
  102. UserPrompt: "user",
  103. })
  104. if err != nil {
  105. t.Fatalf("generate failed: %v", err)
  106. }
  107. if !strings.Contains(got, "hero.title") {
  108. t.Fatalf("unexpected extracted content: %q", got)
  109. }
  110. }
  111. func TestOpenAICompatibleClient_ExtractsResponsesOutputShape(t *testing.T) {
  112. t.Parallel()
  113. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  114. _, _ = w.Write([]byte(`{"id":"resp_123","object":"response","output":[{"type":"message","content":[{"type":"output_text","text":"{\"suggestions\":[{\"fieldPath\":\"hero.subtitle\",\"value\":\"World\"}]}"}]}]}`))
  115. }))
  116. defer server.Close()
  117. factory := NewFactory(2 * time.Second)
  118. client, err := factory.ClientFor("openai")
  119. if err != nil {
  120. t.Fatalf("client creation failed: %v", err)
  121. }
  122. got, err := client.Generate(context.Background(), Request{
  123. Provider: "openai",
  124. BaseURL: server.URL,
  125. Model: "gpt-5.4-mini",
  126. APIKey: "key",
  127. SystemPrompt: "system",
  128. UserPrompt: "user",
  129. })
  130. if err != nil {
  131. t.Fatalf("generate failed: %v", err)
  132. }
  133. if !strings.Contains(got, "hero.subtitle") {
  134. t.Fatalf("unexpected extracted content: %q", got)
  135. }
  136. }
  137. func TestOpenAICompatibleClient_EmptyContentIncludesShapeDiagnostics(t *testing.T) {
  138. t.Parallel()
  139. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  140. _, _ = w.Write([]byte(`{"id":"chatcmpl_x","choices":[{"index":0,"finish_reason":"stop","message":{"role":"assistant","content":[]}}]}`))
  141. }))
  142. defer server.Close()
  143. factory := NewFactory(2 * time.Second)
  144. client, err := factory.ClientFor("openai")
  145. if err != nil {
  146. t.Fatalf("client creation failed: %v", err)
  147. }
  148. _, err = client.Generate(context.Background(), Request{
  149. Provider: "openai",
  150. BaseURL: server.URL,
  151. Model: "gpt-5.4-mini",
  152. APIKey: "key",
  153. SystemPrompt: "system",
  154. UserPrompt: "user",
  155. })
  156. if err == nil {
  157. t.Fatalf("expected generate error")
  158. }
  159. if !strings.Contains(err.Error(), "empty openai-compatible response content") {
  160. t.Fatalf("unexpected error: %v", err)
  161. }
  162. if !strings.Contains(err.Error(), "message_content_type=array") {
  163. t.Fatalf("expected shape diagnostics in error: %v", err)
  164. }
  165. if !strings.Contains(err.Error(), "message_content_len=0") {
  166. t.Fatalf("expected message content length diagnostics in error: %v", err)
  167. }
  168. if !strings.Contains(err.Error(), "choices0_finish_reason=stop") {
  169. t.Fatalf("expected finish reason diagnostics in error: %v", err)
  170. }
  171. }
  172. func TestExtractProviderErrorMessage(t *testing.T) {
  173. t.Parallel()
  174. msg := extractProviderErrorMessage([]byte(`{"error":{"message":"invalid key"}}`))
  175. if !strings.Contains(msg, "invalid key") {
  176. t.Fatalf("unexpected message: %q", msg)
  177. }
  178. }