From 9a47140b89875d56e611025372b216076338b195 Mon Sep 17 00:00:00 2001 From: Jan Svabenik Date: Tue, 17 Mar 2026 21:45:54 +0100 Subject: [PATCH] Harden runtime validation and event/websocket handling --- cmd/sdrd/main.go | 34 +++++++++++++++++++++++++++++----- internal/detector/detector.go | 14 ++++++++++++-- internal/fft/gpufft/gpu.go | 9 ++++++--- internal/mock/source.go | 26 ++++++++++++++++++++------ internal/runtime/runtime.go | 3 +++ internal/sdrplay/sdrplay.go | 13 ++++++++----- 6 files changed, 78 insertions(+), 21 deletions(-) diff --git a/cmd/sdrd/main.go b/cmd/sdrd/main.go index 40fb7f9..1e0d575 100644 --- a/cmd/sdrd/main.go +++ b/cmd/sdrd/main.go @@ -265,6 +265,7 @@ func main() { log.Fatalf("open events: %v", err) } defer eventFile.Close() + eventMu := &sync.Mutex{} det := detector.New(cfg.Detector.ThresholdDb, cfg.SampleRate, cfg.FFTSize, time.Duration(cfg.Detector.MinDurationMs)*time.Millisecond, @@ -277,9 +278,12 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - go runDSP(ctx, srcMgr, cfg, det, window, h, eventFile, dspUpdates, gpuState) + go runDSP(ctx, srcMgr, cfg, det, window, h, eventFile, eventMu, dspUpdates, gpuState) - upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { + origin := r.Header.Get("Origin") + return origin == "" || strings.HasPrefix(origin, "http://localhost") || strings.HasPrefix(origin, "http://127.0.0.1") +}} http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { c, err := upgrader.Upgrade(w, r, nil) if err != nil { @@ -290,6 +294,21 @@ func main() { h.remove(c) _ = c.Close() }() + c.SetReadDeadline(time.Now().Add(60 * time.Second)) + c.SetPongHandler(func(string) error { + c.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for range ticker.C { + _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + }() for { _, _, err := c.ReadMessage() if err != nil { @@ -414,7 +433,10 @@ func main() { return } } - evs, err := events.ReadRecent(cfg.EventPath, limit, since) + snap := cfgManager.Snapshot() + eventMu.Lock() + evs, err := events.ReadRecent(snap.EventPath, limit, since) + eventMu.Unlock() if err != nil { http.Error(w, "failed to read events", http.StatusInternalServerError) return @@ -440,7 +462,7 @@ func main() { _ = server.Shutdown(ctxTimeout) } -func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det *detector.Detector, window []float64, h *hub, eventFile *os.File, updates <-chan dspUpdate, gpuState *gpuStatus) { +func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det *detector.Detector, window []float64, h *hub, eventFile *os.File, eventMu *sync.Mutex, updates <-chan dspUpdate, gpuState *gpuStatus) { ticker := time.NewTicker(cfg.FrameInterval()) defer ticker.Stop() logTicker := time.NewTicker(5 * time.Second) @@ -543,7 +565,7 @@ func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det * gpuState.set(false, err) } useGPU = false - spectrum = fftutil.Spectrum(iq, window) + spectrum = fftutil.SpectrumWithPlan(iq, nil, plan) } else { spectrum = fftutil.SpectrumFromFFT(out) } @@ -552,9 +574,11 @@ func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det * } now := time.Now() finished, signals := det.Process(now, spectrum, cfg.CenterHz) + eventMu.Lock() for _, ev := range finished { _ = enc.Encode(ev) } + eventMu.Unlock() h.broadcast(SpectrumFrame{ Timestamp: now.UnixMilli(), CenterHz: cfg.CenterHz, diff --git a/internal/detector/detector.go b/internal/detector/detector.go index 500725f..b0bd2e2 100644 --- a/internal/detector/detector.go +++ b/internal/detector/detector.go @@ -129,12 +129,22 @@ func (d *Detector) matchSignals(now time.Time, signals []Signal) []Event { used := make(map[int64]bool, len(d.active)) for _, s := range signals { var best *activeEvent + var candidates []struct { + ev *activeEvent + dist float64 + } for _, ev := range d.active { if overlapHz(s.CenterHz, s.BWHz, ev.centerHz, ev.bwHz) && math.Abs(s.CenterHz-ev.centerHz) < (s.BWHz+ev.bwHz)/2.0 { - best = ev - break + candidates = append(candidates, struct { + ev *activeEvent + dist float64 + }{ev: ev, dist: math.Abs(s.CenterHz - ev.centerHz)}) } } + if len(candidates) > 0 { + sort.Slice(candidates, func(i, j int) bool { return candidates[i].dist < candidates[j].dist }) + best = candidates[0].ev + } if best == nil { id := d.nextID d.nextID++ diff --git a/internal/fft/gpufft/gpu.go b/internal/fft/gpufft/gpu.go index 5ba0341..00eca2f 100644 --- a/internal/fft/gpufft/gpu.go +++ b/internal/fft/gpufft/gpu.go @@ -80,9 +80,12 @@ func (e *Engine) Exec(in []complex64) ([]complex64, error) { if C.cufftExecC2C(e.plan, e.data, e.data, C.CUFFT_FORWARD) != C.CUFFT_SUCCESS { return nil, errors.New("cufftExecC2C failed") } - if C.cudaMemcpy(unsafe.Pointer(&in[0]), unsafe.Pointer(e.data), e.bytes, C.cudaMemcpyDeviceToHost) != C.cudaSuccess { + if C.cudaDeviceSynchronize() != C.cudaSuccess { + return nil, errors.New("cudaDeviceSynchronize failed") + } + out := make([]complex64, e.n) + if C.cudaMemcpy(unsafe.Pointer(&out[0]), unsafe.Pointer(e.data), e.bytes, C.cudaMemcpyDeviceToHost) != C.cudaSuccess { return nil, errors.New("cudaMemcpy D2H failed") } - _ = C.cudaDeviceSynchronize() - return in, nil + return out, nil } diff --git a/internal/mock/source.go b/internal/mock/source.go index 377559e..a257205 100644 --- a/internal/mock/source.go +++ b/internal/mock/source.go @@ -16,13 +16,14 @@ type Source struct { phase3 float64 sampleRate float64 noise float64 + rng *rand.Rand } func New(sampleRate int) *Source { - rand.Seed(time.Now().UnixNano()) return &Source{ sampleRate: float64(sampleRate), noise: 0.02, + rng: rand.New(rand.NewSource(time.Now().UnixNano())), } } @@ -45,14 +46,27 @@ func (s *Source) ReadIQ(n int) ([]complex64, error) { f1 := 50e3 f2 := -120e3 f3 := 300e3 + const twoPi = 2 * math.Pi for i := 0; i < n; i++ { - s.phase += 2 * math.Pi * f1 / s.sampleRate - s.phase2 += 2 * math.Pi * f2 / s.sampleRate - s.phase3 += 2 * math.Pi * f3 / s.sampleRate + s.phase += twoPi * f1 / s.sampleRate + s.phase2 += twoPi * f2 / s.sampleRate + s.phase3 += twoPi * f3 / s.sampleRate + if s.phase > twoPi { + s.phase -= twoPi + } + if s.phase2 > twoPi { + s.phase2 -= twoPi + } + if s.phase2 < 0 { + s.phase2 += twoPi + } + if s.phase3 > twoPi { + s.phase3 -= twoPi + } re := math.Cos(s.phase) + 0.7*math.Cos(s.phase2) + 0.4*math.Cos(s.phase3) im := math.Sin(s.phase) + 0.7*math.Sin(s.phase2) + 0.4*math.Sin(s.phase3) - re += s.noise * rand.NormFloat64() - im += s.noise * rand.NormFloat64() + re += s.noise * s.rng.NormFloat64() + im += s.noise * s.rng.NormFloat64() out[i] = complex(float32(re), float32(im)) } return out, nil diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 85b902c..6c8c7cb 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -56,6 +56,9 @@ func (m *Manager) ApplyConfig(update ConfigUpdate) (config.Config, error) { next := m.cfg if update.CenterHz != nil { + if *update.CenterHz < 1e3 || *update.CenterHz > 2e9 { + return m.cfg, errors.New("center_hz out of range") + } next.CenterHz = *update.CenterHz } if update.SampleRate != nil { diff --git a/internal/sdrplay/sdrplay.go b/internal/sdrplay/sdrplay.go index dcccc2d..d787b3e 100644 --- a/internal/sdrplay/sdrplay.go +++ b/internal/sdrplay/sdrplay.go @@ -113,8 +113,7 @@ func New(sampleRate int, centerHz float64, gainDb float64, bwKHz int) (sdr.Sourc s.resizeBuffer(sampleRate, 0) s.handle = cgo.NewHandle(s) if err := s.configure(sampleRate, centerHz, gainDb, bwKHz); err != nil { - s.handle.Delete() - s.handle = 0 + _ = s.Stop() return nil, err } return s, nil @@ -329,11 +328,15 @@ func max(a, b int) int { func (s *Source) Stop() error { s.mu.Lock() - defer s.mu.Unlock() - if s.params != nil { + params := s.params + s.params = nil + s.mu.Unlock() + if params != nil { _ = cErr(C.sdrplay_api_Uninit(s.dev.dev)) - s.params = nil } + + s.mu.Lock() + defer s.mu.Unlock() if s.open { _ = cErr(C.sdrplay_api_ReleaseDevice(&s.dev)) _ = cErr(C.sdrplay_api_Close())