Browse Source

Harden runtime validation and event/websocket handling

master
Jan Svabenik 4 days ago
parent
commit
9a47140b89
6 changed files with 78 additions and 21 deletions
  1. +29
    -5
      cmd/sdrd/main.go
  2. +12
    -2
      internal/detector/detector.go
  3. +6
    -3
      internal/fft/gpufft/gpu.go
  4. +20
    -6
      internal/mock/source.go
  5. +3
    -0
      internal/runtime/runtime.go
  6. +8
    -5
      internal/sdrplay/sdrplay.go

+ 29
- 5
cmd/sdrd/main.go View File

@@ -265,6 +265,7 @@ func main() {
log.Fatalf("open events: %v", err) log.Fatalf("open events: %v", err)
} }
defer eventFile.Close() defer eventFile.Close()
eventMu := &sync.Mutex{}


det := detector.New(cfg.Detector.ThresholdDb, cfg.SampleRate, cfg.FFTSize, det := detector.New(cfg.Detector.ThresholdDb, cfg.SampleRate, cfg.FFTSize,
time.Duration(cfg.Detector.MinDurationMs)*time.Millisecond, time.Duration(cfg.Detector.MinDurationMs)*time.Millisecond,
@@ -277,9 +278,12 @@ func main() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()


go runDSP(ctx, srcMgr, cfg, det, window, h, eventFile, dspUpdates, gpuState)
go runDSP(ctx, srcMgr, cfg, det, window, h, eventFile, eventMu, dspUpdates, gpuState)


upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
origin := r.Header.Get("Origin")
return origin == "" || strings.HasPrefix(origin, "http://localhost") || strings.HasPrefix(origin, "http://127.0.0.1")
}}
http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil) c, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
@@ -290,6 +294,21 @@ func main() {
h.remove(c) h.remove(c)
_ = c.Close() _ = c.Close()
}() }()
c.SetReadDeadline(time.Now().Add(60 * time.Second))
c.SetPongHandler(func(string) error {
c.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
_ = c.SetWriteDeadline(time.Now().Add(5 * time.Second))
if err := c.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}()
for { for {
_, _, err := c.ReadMessage() _, _, err := c.ReadMessage()
if err != nil { if err != nil {
@@ -414,7 +433,10 @@ func main() {
return return
} }
} }
evs, err := events.ReadRecent(cfg.EventPath, limit, since)
snap := cfgManager.Snapshot()
eventMu.Lock()
evs, err := events.ReadRecent(snap.EventPath, limit, since)
eventMu.Unlock()
if err != nil { if err != nil {
http.Error(w, "failed to read events", http.StatusInternalServerError) http.Error(w, "failed to read events", http.StatusInternalServerError)
return return
@@ -440,7 +462,7 @@ func main() {
_ = server.Shutdown(ctxTimeout) _ = server.Shutdown(ctxTimeout)
} }


func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det *detector.Detector, window []float64, h *hub, eventFile *os.File, updates <-chan dspUpdate, gpuState *gpuStatus) {
func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det *detector.Detector, window []float64, h *hub, eventFile *os.File, eventMu *sync.Mutex, updates <-chan dspUpdate, gpuState *gpuStatus) {
ticker := time.NewTicker(cfg.FrameInterval()) ticker := time.NewTicker(cfg.FrameInterval())
defer ticker.Stop() defer ticker.Stop()
logTicker := time.NewTicker(5 * time.Second) logTicker := time.NewTicker(5 * time.Second)
@@ -543,7 +565,7 @@ func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det *
gpuState.set(false, err) gpuState.set(false, err)
} }
useGPU = false useGPU = false
spectrum = fftutil.Spectrum(iq, window)
spectrum = fftutil.SpectrumWithPlan(iq, nil, plan)
} else { } else {
spectrum = fftutil.SpectrumFromFFT(out) spectrum = fftutil.SpectrumFromFFT(out)
} }
@@ -552,9 +574,11 @@ func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det *
} }
now := time.Now() now := time.Now()
finished, signals := det.Process(now, spectrum, cfg.CenterHz) finished, signals := det.Process(now, spectrum, cfg.CenterHz)
eventMu.Lock()
for _, ev := range finished { for _, ev := range finished {
_ = enc.Encode(ev) _ = enc.Encode(ev)
} }
eventMu.Unlock()
h.broadcast(SpectrumFrame{ h.broadcast(SpectrumFrame{
Timestamp: now.UnixMilli(), Timestamp: now.UnixMilli(),
CenterHz: cfg.CenterHz, CenterHz: cfg.CenterHz,


+ 12
- 2
internal/detector/detector.go View File

@@ -129,12 +129,22 @@ func (d *Detector) matchSignals(now time.Time, signals []Signal) []Event {
used := make(map[int64]bool, len(d.active)) used := make(map[int64]bool, len(d.active))
for _, s := range signals { for _, s := range signals {
var best *activeEvent var best *activeEvent
var candidates []struct {
ev *activeEvent
dist float64
}
for _, ev := range d.active { for _, ev := range d.active {
if overlapHz(s.CenterHz, s.BWHz, ev.centerHz, ev.bwHz) && math.Abs(s.CenterHz-ev.centerHz) < (s.BWHz+ev.bwHz)/2.0 { if overlapHz(s.CenterHz, s.BWHz, ev.centerHz, ev.bwHz) && math.Abs(s.CenterHz-ev.centerHz) < (s.BWHz+ev.bwHz)/2.0 {
best = ev
break
candidates = append(candidates, struct {
ev *activeEvent
dist float64
}{ev: ev, dist: math.Abs(s.CenterHz - ev.centerHz)})
} }
} }
if len(candidates) > 0 {
sort.Slice(candidates, func(i, j int) bool { return candidates[i].dist < candidates[j].dist })
best = candidates[0].ev
}
if best == nil { if best == nil {
id := d.nextID id := d.nextID
d.nextID++ d.nextID++


+ 6
- 3
internal/fft/gpufft/gpu.go View File

@@ -80,9 +80,12 @@ func (e *Engine) Exec(in []complex64) ([]complex64, error) {
if C.cufftExecC2C(e.plan, e.data, e.data, C.CUFFT_FORWARD) != C.CUFFT_SUCCESS { if C.cufftExecC2C(e.plan, e.data, e.data, C.CUFFT_FORWARD) != C.CUFFT_SUCCESS {
return nil, errors.New("cufftExecC2C failed") return nil, errors.New("cufftExecC2C failed")
} }
if C.cudaMemcpy(unsafe.Pointer(&in[0]), unsafe.Pointer(e.data), e.bytes, C.cudaMemcpyDeviceToHost) != C.cudaSuccess {
if C.cudaDeviceSynchronize() != C.cudaSuccess {
return nil, errors.New("cudaDeviceSynchronize failed")
}
out := make([]complex64, e.n)
if C.cudaMemcpy(unsafe.Pointer(&out[0]), unsafe.Pointer(e.data), e.bytes, C.cudaMemcpyDeviceToHost) != C.cudaSuccess {
return nil, errors.New("cudaMemcpy D2H failed") return nil, errors.New("cudaMemcpy D2H failed")
} }
_ = C.cudaDeviceSynchronize()
return in, nil
return out, nil
} }

+ 20
- 6
internal/mock/source.go View File

@@ -16,13 +16,14 @@ type Source struct {
phase3 float64 phase3 float64
sampleRate float64 sampleRate float64
noise float64 noise float64
rng *rand.Rand
} }


func New(sampleRate int) *Source { func New(sampleRate int) *Source {
rand.Seed(time.Now().UnixNano())
return &Source{ return &Source{
sampleRate: float64(sampleRate), sampleRate: float64(sampleRate),
noise: 0.02, noise: 0.02,
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
} }
} }


@@ -45,14 +46,27 @@ func (s *Source) ReadIQ(n int) ([]complex64, error) {
f1 := 50e3 f1 := 50e3
f2 := -120e3 f2 := -120e3
f3 := 300e3 f3 := 300e3
const twoPi = 2 * math.Pi
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
s.phase += 2 * math.Pi * f1 / s.sampleRate
s.phase2 += 2 * math.Pi * f2 / s.sampleRate
s.phase3 += 2 * math.Pi * f3 / s.sampleRate
s.phase += twoPi * f1 / s.sampleRate
s.phase2 += twoPi * f2 / s.sampleRate
s.phase3 += twoPi * f3 / s.sampleRate
if s.phase > twoPi {
s.phase -= twoPi
}
if s.phase2 > twoPi {
s.phase2 -= twoPi
}
if s.phase2 < 0 {
s.phase2 += twoPi
}
if s.phase3 > twoPi {
s.phase3 -= twoPi
}
re := math.Cos(s.phase) + 0.7*math.Cos(s.phase2) + 0.4*math.Cos(s.phase3) re := math.Cos(s.phase) + 0.7*math.Cos(s.phase2) + 0.4*math.Cos(s.phase3)
im := math.Sin(s.phase) + 0.7*math.Sin(s.phase2) + 0.4*math.Sin(s.phase3) im := math.Sin(s.phase) + 0.7*math.Sin(s.phase2) + 0.4*math.Sin(s.phase3)
re += s.noise * rand.NormFloat64()
im += s.noise * rand.NormFloat64()
re += s.noise * s.rng.NormFloat64()
im += s.noise * s.rng.NormFloat64()
out[i] = complex(float32(re), float32(im)) out[i] = complex(float32(re), float32(im))
} }
return out, nil return out, nil


+ 3
- 0
internal/runtime/runtime.go View File

@@ -56,6 +56,9 @@ func (m *Manager) ApplyConfig(update ConfigUpdate) (config.Config, error) {


next := m.cfg next := m.cfg
if update.CenterHz != nil { if update.CenterHz != nil {
if *update.CenterHz < 1e3 || *update.CenterHz > 2e9 {
return m.cfg, errors.New("center_hz out of range")
}
next.CenterHz = *update.CenterHz next.CenterHz = *update.CenterHz
} }
if update.SampleRate != nil { if update.SampleRate != nil {


+ 8
- 5
internal/sdrplay/sdrplay.go View File

@@ -113,8 +113,7 @@ func New(sampleRate int, centerHz float64, gainDb float64, bwKHz int) (sdr.Sourc
s.resizeBuffer(sampleRate, 0) s.resizeBuffer(sampleRate, 0)
s.handle = cgo.NewHandle(s) s.handle = cgo.NewHandle(s)
if err := s.configure(sampleRate, centerHz, gainDb, bwKHz); err != nil { if err := s.configure(sampleRate, centerHz, gainDb, bwKHz); err != nil {
s.handle.Delete()
s.handle = 0
_ = s.Stop()
return nil, err return nil, err
} }
return s, nil return s, nil
@@ -329,11 +328,15 @@ func max(a, b int) int {


func (s *Source) Stop() error { func (s *Source) Stop() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock()
if s.params != nil {
params := s.params
s.params = nil
s.mu.Unlock()
if params != nil {
_ = cErr(C.sdrplay_api_Uninit(s.dev.dev)) _ = cErr(C.sdrplay_api_Uninit(s.dev.dev))
s.params = nil
} }

s.mu.Lock()
defer s.mu.Unlock()
if s.open { if s.open {
_ = cErr(C.sdrplay_api_ReleaseDevice(&s.dev)) _ = cErr(C.sdrplay_api_ReleaseDevice(&s.dev))
_ = cErr(C.sdrplay_api_Close()) _ = cErr(C.sdrplay_api_Close())


Loading…
Cancel
Save