package gpudemod import ( "math" "testing" ) type streamingValidationStep struct { name string iq []complex64 jobs []StreamingExtractJob } type streamingPreparedExecutor func(*BatchRunner, []StreamingGPUInvocation) ([]StreamingGPUExecutionResult, error) func makeToneNoiseIQ(n int, phaseInc float64) []complex64 { out := make([]complex64, n) phase := 0.0 for i := 0; i < n; i++ { tone := complex(math.Cos(phase), math.Sin(phase)) noiseI := 0.17*math.Cos(0.113*float64(i)+0.31) + 0.07*math.Sin(0.071*float64(i)) noiseQ := 0.13*math.Sin(0.097*float64(i)+0.11) - 0.05*math.Cos(0.043*float64(i)) out[i] = complex64(0.85*tone + 0.15*complex(noiseI, noiseQ)) phase += phaseInc } return out } func makeStreamingValidationSteps(iq []complex64, chunkSizes []int, jobs []StreamingExtractJob) []streamingValidationStep { steps := make([]streamingValidationStep, 0, len(chunkSizes)+1) pos := 0 for idx, n := range chunkSizes { if n < 0 { n = 0 } end := pos + n if end > len(iq) { end = len(iq) } steps = append(steps, streamingValidationStep{ name: "chunk", iq: append([]complex64(nil), iq[pos:end]...), jobs: append([]StreamingExtractJob(nil), jobs...), }) _ = idx pos = end } if pos < len(iq) { steps = append(steps, streamingValidationStep{ name: "remainder", iq: append([]complex64(nil), iq[pos:]...), jobs: append([]StreamingExtractJob(nil), jobs...), }) } return steps } func requirePhaseClose(t *testing.T, got float64, want float64, tol float64) { t.Helper() diff := got - want for diff > math.Pi { diff -= 2 * math.Pi } for diff < -math.Pi { diff += 2 * math.Pi } if math.Abs(diff) > tol { t.Fatalf("phase mismatch: got=%0.12f want=%0.12f diff=%0.12f tol=%0.12f", got, want, diff, tol) } } func requireStreamingExtractResultMatchesOracle(t *testing.T, got StreamingExtractResult, want StreamingExtractResult) { t.Helper() if got.SignalID != want.SignalID { t.Fatalf("signal id mismatch: got=%d want=%d", got.SignalID, want.SignalID) } if got.Rate != want.Rate { t.Fatalf("rate mismatch for signal %d: got=%d want=%d", got.SignalID, got.Rate, want.Rate) } if got.NOut != want.NOut { t.Fatalf("n_out mismatch for signal %d: got=%d want=%d", got.SignalID, got.NOut, want.NOut) } if got.PhaseCount != want.PhaseCount { t.Fatalf("phase count mismatch for signal %d: got=%d want=%d", got.SignalID, got.PhaseCount, want.PhaseCount) } if got.HistoryLen != want.HistoryLen { t.Fatalf("history len mismatch for signal %d: got=%d want=%d", got.SignalID, got.HistoryLen, want.HistoryLen) } } func requirePreparedExecutionResultMatchesOracle(t *testing.T, got StreamingGPUExecutionResult, want StreamingExtractResult, oracleState *CPUOracleState, sampleTol float64, phaseTol float64) { t.Helper() if oracleState == nil { t.Fatalf("missing oracle state for signal %d", got.SignalID) } if got.SignalID != want.SignalID { t.Fatalf("signal id mismatch: got=%d want=%d", got.SignalID, want.SignalID) } if got.Rate != want.Rate { t.Fatalf("rate mismatch for signal %d: got=%d want=%d", got.SignalID, got.Rate, want.Rate) } if got.NOut != want.NOut { t.Fatalf("n_out mismatch for signal %d: got=%d want=%d", got.SignalID, got.NOut, want.NOut) } if got.PhaseCountOut != oracleState.PhaseCount { t.Fatalf("phase count mismatch for signal %d: got=%d want=%d", got.SignalID, got.PhaseCountOut, oracleState.PhaseCount) } requirePhaseClose(t, got.NCOPhaseOut, oracleState.NCOPhase, phaseTol) if got.HistoryLenOut != len(oracleState.ShiftedHistory) { t.Fatalf("history len mismatch for signal %d: got=%d want=%d", got.SignalID, got.HistoryLenOut, len(oracleState.ShiftedHistory)) } requireComplexSlicesClose(t, got.IQ, want.IQ, sampleTol) requireComplexSlicesClose(t, got.HistoryOut, oracleState.ShiftedHistory, sampleTol) } func requireExtractStateMatchesOracle(t *testing.T, got *ExtractStreamState, want *CPUOracleState, phaseTol float64, sampleTol float64) { t.Helper() if got == nil || want == nil { t.Fatalf("state mismatch: got nil=%t want nil=%t", got == nil, want == nil) } if got.SignalID != want.SignalID { t.Fatalf("signal id mismatch: got=%d want=%d", got.SignalID, want.SignalID) } if got.ConfigHash != want.ConfigHash { t.Fatalf("config hash mismatch for signal %d: got=%d want=%d", got.SignalID, got.ConfigHash, want.ConfigHash) } if got.Decim != want.Decim { t.Fatalf("decim mismatch for signal %d: got=%d want=%d", got.SignalID, got.Decim, want.Decim) } if got.NumTaps != want.NumTaps { t.Fatalf("num taps mismatch for signal %d: got=%d want=%d", got.SignalID, got.NumTaps, want.NumTaps) } if got.PhaseCount != want.PhaseCount { t.Fatalf("phase count mismatch for signal %d: got=%d want=%d", got.SignalID, got.PhaseCount, want.PhaseCount) } requirePhaseClose(t, got.NCOPhase, want.NCOPhase, phaseTol) requireComplexSlicesClose(t, got.ShiftedHistory, want.ShiftedHistory, sampleTol) } func requireStateKeysMatchOracle(t *testing.T, got map[int64]*ExtractStreamState, want map[int64]*CPUOracleState) { t.Helper() if len(got) != len(want) { t.Fatalf("active state count mismatch: got=%d want=%d", len(got), len(want)) } for signalID := range want { if got[signalID] == nil { t.Fatalf("missing active state for signal %d", signalID) } } for signalID := range got { if want[signalID] == nil { t.Fatalf("unexpected active state for signal %d", signalID) } } } func runStreamingExecSequenceAgainstOracle(t *testing.T, runner *BatchRunner, steps []streamingValidationStep, sampleTol float64, phaseTol float64) { t.Helper() oracle := NewCPUOracleRunner(runner.eng.sampleRate) for idx, step := range steps { got, err := runner.StreamingExtractGPUExec(step.iq, step.jobs) if err != nil { t.Fatalf("step %d (%s): exec failed: %v", idx, step.name, err) } want, err := oracle.StreamingExtract(step.iq, step.jobs) if err != nil { t.Fatalf("step %d (%s): oracle failed: %v", idx, step.name, err) } if len(got) != len(want) { t.Fatalf("step %d (%s): result count mismatch: got=%d want=%d", idx, step.name, len(got), len(want)) } for i, job := range step.jobs { requireStreamingExtractResultMatchesOracle(t, got[i], want[i]) requireComplexSlicesClose(t, got[i].IQ, want[i].IQ, sampleTol) requireExtractStateMatchesOracle(t, runner.streamState[job.SignalID], oracle.States[job.SignalID], phaseTol, sampleTol) } requireStateKeysMatchOracle(t, runner.streamState, oracle.States) } } func runPreparedSequenceAgainstOracle(t *testing.T, runner *BatchRunner, exec streamingPreparedExecutor, steps []streamingValidationStep, sampleTol float64, phaseTol float64) { t.Helper() oracle := NewCPUOracleRunner(runner.eng.sampleRate) for idx, step := range steps { invocations, err := runner.buildStreamingGPUInvocations(step.iq, step.jobs) if err != nil { t.Fatalf("step %d (%s): build invocations failed: %v", idx, step.name, err) } got, err := exec(runner, invocations) if err != nil { t.Fatalf("step %d (%s): prepared exec failed: %v", idx, step.name, err) } want, err := oracle.StreamingExtract(step.iq, step.jobs) if err != nil { t.Fatalf("step %d (%s): oracle failed: %v", idx, step.name, err) } if len(got) != len(want) { t.Fatalf("step %d (%s): result count mismatch: got=%d want=%d", idx, step.name, len(got), len(want)) } applied := runner.applyStreamingGPUExecutionResults(got) if len(applied) != len(want) { t.Fatalf("step %d (%s): applied result count mismatch: got=%d want=%d", idx, step.name, len(applied), len(want)) } for i, job := range step.jobs { oracleState := oracle.States[job.SignalID] requirePreparedExecutionResultMatchesOracle(t, got[i], want[i], oracleState, sampleTol, phaseTol) requireStreamingExtractResultMatchesOracle(t, applied[i], want[i]) requireComplexSlicesClose(t, applied[i].IQ, want[i].IQ, sampleTol) requireExtractStateMatchesOracle(t, runner.streamState[job.SignalID], oracleState, phaseTol, sampleTol) } requireStateKeysMatchOracle(t, runner.streamState, oracle.States) } }