Wideband autonomous SDR analysis engine forked from sdr-visual-suite
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

214 lines
7.9KB

  1. package gpudemod
  2. import (
  3. "math"
  4. "testing"
  5. )
  6. type streamingValidationStep struct {
  7. name string
  8. iq []complex64
  9. jobs []StreamingExtractJob
  10. }
  11. type streamingPreparedExecutor func(*BatchRunner, []StreamingGPUInvocation) ([]StreamingGPUExecutionResult, error)
  // makeToneNoiseIQ synthesizes n deterministic IQ samples. Each sample mixes
// 0.85 of a unit-amplitude complex tone with 0.15 of a fixed
// sum-of-sinusoids pseudo-noise term. The tone starts at phase 0 and advances
// by phaseInc radians per sample. No randomness is involved, so identical
// arguments always yield identical samples.
func makeToneNoiseIQ(n int, phaseInc float64) []complex64 {
	samples := make([]complex64, n)
	var theta float64
	for k := range samples {
		x := float64(k)
		carrier := complex(math.Cos(theta), math.Sin(theta))
		re := 0.17*math.Cos(0.113*x+0.31) + 0.07*math.Sin(0.071*x)
		im := 0.13*math.Sin(0.097*x+0.11) - 0.05*math.Cos(0.043*x)
		samples[k] = complex64(0.85*carrier + 0.15*complex(re, im))
		// Phase is accumulated by repeated addition rather than computed as
		// k*phaseInc; the two round differently in floating point.
		theta += phaseInc
	}
	return samples
}
  24. func makeStreamingValidationSteps(iq []complex64, chunkSizes []int, jobs []StreamingExtractJob) []streamingValidationStep {
  25. steps := make([]streamingValidationStep, 0, len(chunkSizes)+1)
  26. pos := 0
  27. for idx, n := range chunkSizes {
  28. if n < 0 {
  29. n = 0
  30. }
  31. end := pos + n
  32. if end > len(iq) {
  33. end = len(iq)
  34. }
  35. steps = append(steps, streamingValidationStep{
  36. name: "chunk",
  37. iq: append([]complex64(nil), iq[pos:end]...),
  38. jobs: append([]StreamingExtractJob(nil), jobs...),
  39. })
  40. _ = idx
  41. pos = end
  42. }
  43. if pos < len(iq) {
  44. steps = append(steps, streamingValidationStep{
  45. name: "remainder",
  46. iq: append([]complex64(nil), iq[pos:]...),
  47. jobs: append([]StreamingExtractJob(nil), jobs...),
  48. })
  49. }
  50. return steps
  51. }
  // requirePhaseClose fails the test unless the angles got and want (radians)
// agree to within tol. Their difference is first wrapped into [-pi, pi], so
// values that differ by whole turns of 2*pi compare equal.
//
// The wrap uses math.Remainder, which is exact and constant-time for any
// finite difference. A NaN or +/-Inf difference becomes NaN and is reported
// as a mismatch, rather than passing silently (NaN) or looping forever (Inf).
func requirePhaseClose(t *testing.T, got float64, want float64, tol float64) {
	t.Helper()
	diff := math.Remainder(got-want, 2*math.Pi)
	// Negated <= so that a NaN diff (or a NaN tol) counts as a failure.
	if !(math.Abs(diff) <= tol) {
		t.Fatalf("phase mismatch: got=%0.12f want=%0.12f diff=%0.12f tol=%0.12f", got, want, diff, tol)
	}
}
  65. func requireStreamingExtractResultMatchesOracle(t *testing.T, got StreamingExtractResult, want StreamingExtractResult) {
  66. t.Helper()
  67. if got.SignalID != want.SignalID {
  68. t.Fatalf("signal id mismatch: got=%d want=%d", got.SignalID, want.SignalID)
  69. }
  70. if got.Rate != want.Rate {
  71. t.Fatalf("rate mismatch for signal %d: got=%d want=%d", got.SignalID, got.Rate, want.Rate)
  72. }
  73. if got.NOut != want.NOut {
  74. t.Fatalf("n_out mismatch for signal %d: got=%d want=%d", got.SignalID, got.NOut, want.NOut)
  75. }
  76. if got.PhaseCount != want.PhaseCount {
  77. t.Fatalf("phase count mismatch for signal %d: got=%d want=%d", got.SignalID, got.PhaseCount, want.PhaseCount)
  78. }
  79. if got.HistoryLen != want.HistoryLen {
  80. t.Fatalf("history len mismatch for signal %d: got=%d want=%d", got.SignalID, got.HistoryLen, want.HistoryLen)
  81. }
  82. }
  83. func requirePreparedExecutionResultMatchesOracle(t *testing.T, got StreamingGPUExecutionResult, want StreamingExtractResult, oracleState *CPUOracleState, sampleTol float64, phaseTol float64) {
  84. t.Helper()
  85. if oracleState == nil {
  86. t.Fatalf("missing oracle state for signal %d", got.SignalID)
  87. }
  88. if got.SignalID != want.SignalID {
  89. t.Fatalf("signal id mismatch: got=%d want=%d", got.SignalID, want.SignalID)
  90. }
  91. if got.Rate != want.Rate {
  92. t.Fatalf("rate mismatch for signal %d: got=%d want=%d", got.SignalID, got.Rate, want.Rate)
  93. }
  94. if got.NOut != want.NOut {
  95. t.Fatalf("n_out mismatch for signal %d: got=%d want=%d", got.SignalID, got.NOut, want.NOut)
  96. }
  97. if got.PhaseCountOut != oracleState.PhaseCount {
  98. t.Fatalf("phase count mismatch for signal %d: got=%d want=%d", got.SignalID, got.PhaseCountOut, oracleState.PhaseCount)
  99. }
  100. requirePhaseClose(t, got.NCOPhaseOut, oracleState.NCOPhase, phaseTol)
  101. if got.HistoryLenOut != len(oracleState.ShiftedHistory) {
  102. t.Fatalf("history len mismatch for signal %d: got=%d want=%d", got.SignalID, got.HistoryLenOut, len(oracleState.ShiftedHistory))
  103. }
  104. requireComplexSlicesClose(t, got.IQ, want.IQ, sampleTol)
  105. requireComplexSlicesClose(t, got.HistoryOut, oracleState.ShiftedHistory, sampleTol)
  106. }
  107. func requireExtractStateMatchesOracle(t *testing.T, got *ExtractStreamState, want *CPUOracleState, phaseTol float64, sampleTol float64) {
  108. t.Helper()
  109. if got == nil || want == nil {
  110. t.Fatalf("state mismatch: got nil=%t want nil=%t", got == nil, want == nil)
  111. }
  112. if got.SignalID != want.SignalID {
  113. t.Fatalf("signal id mismatch: got=%d want=%d", got.SignalID, want.SignalID)
  114. }
  115. if got.ConfigHash != want.ConfigHash {
  116. t.Fatalf("config hash mismatch for signal %d: got=%d want=%d", got.SignalID, got.ConfigHash, want.ConfigHash)
  117. }
  118. if got.Decim != want.Decim {
  119. t.Fatalf("decim mismatch for signal %d: got=%d want=%d", got.SignalID, got.Decim, want.Decim)
  120. }
  121. if got.NumTaps != want.NumTaps {
  122. t.Fatalf("num taps mismatch for signal %d: got=%d want=%d", got.SignalID, got.NumTaps, want.NumTaps)
  123. }
  124. if got.PhaseCount != want.PhaseCount {
  125. t.Fatalf("phase count mismatch for signal %d: got=%d want=%d", got.SignalID, got.PhaseCount, want.PhaseCount)
  126. }
  127. requirePhaseClose(t, got.NCOPhase, want.NCOPhase, phaseTol)
  128. requireComplexSlicesClose(t, got.ShiftedHistory, want.ShiftedHistory, sampleTol)
  129. }
  130. func requireStateKeysMatchOracle(t *testing.T, got map[int64]*ExtractStreamState, want map[int64]*CPUOracleState) {
  131. t.Helper()
  132. if len(got) != len(want) {
  133. t.Fatalf("active state count mismatch: got=%d want=%d", len(got), len(want))
  134. }
  135. for signalID := range want {
  136. if got[signalID] == nil {
  137. t.Fatalf("missing active state for signal %d", signalID)
  138. }
  139. }
  140. for signalID := range got {
  141. if want[signalID] == nil {
  142. t.Fatalf("unexpected active state for signal %d", signalID)
  143. }
  144. }
  145. }
  146. func runStreamingExecSequenceAgainstOracle(t *testing.T, runner *BatchRunner, steps []streamingValidationStep, sampleTol float64, phaseTol float64) {
  147. t.Helper()
  148. oracle := NewCPUOracleRunner(runner.eng.sampleRate)
  149. for idx, step := range steps {
  150. got, err := runner.StreamingExtractGPUExec(step.iq, step.jobs)
  151. if err != nil {
  152. t.Fatalf("step %d (%s): exec failed: %v", idx, step.name, err)
  153. }
  154. want, err := oracle.StreamingExtract(step.iq, step.jobs)
  155. if err != nil {
  156. t.Fatalf("step %d (%s): oracle failed: %v", idx, step.name, err)
  157. }
  158. if len(got) != len(want) {
  159. t.Fatalf("step %d (%s): result count mismatch: got=%d want=%d", idx, step.name, len(got), len(want))
  160. }
  161. for i, job := range step.jobs {
  162. requireStreamingExtractResultMatchesOracle(t, got[i], want[i])
  163. requireComplexSlicesClose(t, got[i].IQ, want[i].IQ, sampleTol)
  164. requireExtractStateMatchesOracle(t, runner.streamState[job.SignalID], oracle.States[job.SignalID], phaseTol, sampleTol)
  165. }
  166. requireStateKeysMatchOracle(t, runner.streamState, oracle.States)
  167. }
  168. }
  169. func runPreparedSequenceAgainstOracle(t *testing.T, runner *BatchRunner, exec streamingPreparedExecutor, steps []streamingValidationStep, sampleTol float64, phaseTol float64) {
  170. t.Helper()
  171. oracle := NewCPUOracleRunner(runner.eng.sampleRate)
  172. for idx, step := range steps {
  173. invocations, err := runner.buildStreamingGPUInvocations(step.iq, step.jobs)
  174. if err != nil {
  175. t.Fatalf("step %d (%s): build invocations failed: %v", idx, step.name, err)
  176. }
  177. got, err := exec(runner, invocations)
  178. if err != nil {
  179. t.Fatalf("step %d (%s): prepared exec failed: %v", idx, step.name, err)
  180. }
  181. want, err := oracle.StreamingExtract(step.iq, step.jobs)
  182. if err != nil {
  183. t.Fatalf("step %d (%s): oracle failed: %v", idx, step.name, err)
  184. }
  185. if len(got) != len(want) {
  186. t.Fatalf("step %d (%s): result count mismatch: got=%d want=%d", idx, step.name, len(got), len(want))
  187. }
  188. applied := runner.applyStreamingGPUExecutionResults(got)
  189. if len(applied) != len(want) {
  190. t.Fatalf("step %d (%s): applied result count mismatch: got=%d want=%d", idx, step.name, len(applied), len(want))
  191. }
  192. for i, job := range step.jobs {
  193. oracleState := oracle.States[job.SignalID]
  194. requirePreparedExecutionResultMatchesOracle(t, got[i], want[i], oracleState, sampleTol, phaseTol)
  195. requireStreamingExtractResultMatchesOracle(t, applied[i], want[i])
  196. requireComplexSlicesClose(t, applied[i].IQ, want[i].IQ, sampleTol)
  197. requireExtractStateMatchesOracle(t, runner.streamState[job.SignalID], oracleState, phaseTol, sampleTol)
  198. }
  199. requireStateKeysMatchOracle(t, runner.streamState, oracle.States)
  200. }
  201. }