Sfoglia il codice sorgente

Harden runtime validation and event/websocket handling

master
Jan Svabenik 4 giorni fa
parent
commit
9a47140b89
6 ha cambiato i file con 78 aggiunte e 21 eliminazioni
  1. +29
    -5
      cmd/sdrd/main.go
  2. +12
    -2
      internal/detector/detector.go
  3. +6
    -3
      internal/fft/gpufft/gpu.go
  4. +20
    -6
      internal/mock/source.go
  5. +3
    -0
      internal/runtime/runtime.go
  6. +8
    -5
      internal/sdrplay/sdrplay.go

+ 29
- 5
cmd/sdrd/main.go Vedi File

@@ -265,6 +265,7 @@ func main() {
log.Fatalf("open events: %v", err)
}
defer eventFile.Close()
eventMu := &sync.Mutex{}

det := detector.New(cfg.Detector.ThresholdDb, cfg.SampleRate, cfg.FFTSize,
time.Duration(cfg.Detector.MinDurationMs)*time.Millisecond,
@@ -277,9 +278,12 @@ func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

go runDSP(ctx, srcMgr, cfg, det, window, h, eventFile, dspUpdates, gpuState)
go runDSP(ctx, srcMgr, cfg, det, window, h, eventFile, eventMu, dspUpdates, gpuState)

upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
origin := r.Header.Get("Origin")
return origin == "" || strings.HasPrefix(origin, "http://localhost") || strings.HasPrefix(origin, "http://127.0.0.1")
}}
http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
@@ -290,6 +294,21 @@ func main() {
h.remove(c)
_ = c.Close()
}()
c.SetReadDeadline(time.Now().Add(60 * time.Second))
c.SetPongHandler(func(string) error {
c.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
_ = c.SetWriteDeadline(time.Now().Add(5 * time.Second))
if err := c.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}()
for {
_, _, err := c.ReadMessage()
if err != nil {
@@ -414,7 +433,10 @@ func main() {
return
}
}
evs, err := events.ReadRecent(cfg.EventPath, limit, since)
snap := cfgManager.Snapshot()
eventMu.Lock()
evs, err := events.ReadRecent(snap.EventPath, limit, since)
eventMu.Unlock()
if err != nil {
http.Error(w, "failed to read events", http.StatusInternalServerError)
return
@@ -440,7 +462,7 @@ func main() {
_ = server.Shutdown(ctxTimeout)
}

func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det *detector.Detector, window []float64, h *hub, eventFile *os.File, updates <-chan dspUpdate, gpuState *gpuStatus) {
func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det *detector.Detector, window []float64, h *hub, eventFile *os.File, eventMu *sync.Mutex, updates <-chan dspUpdate, gpuState *gpuStatus) {
ticker := time.NewTicker(cfg.FrameInterval())
defer ticker.Stop()
logTicker := time.NewTicker(5 * time.Second)
@@ -543,7 +565,7 @@ func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det *
gpuState.set(false, err)
}
useGPU = false
spectrum = fftutil.Spectrum(iq, window)
spectrum = fftutil.SpectrumWithPlan(iq, nil, plan)
} else {
spectrum = fftutil.SpectrumFromFFT(out)
}
@@ -552,9 +574,11 @@ func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det *
}
now := time.Now()
finished, signals := det.Process(now, spectrum, cfg.CenterHz)
eventMu.Lock()
for _, ev := range finished {
_ = enc.Encode(ev)
}
eventMu.Unlock()
h.broadcast(SpectrumFrame{
Timestamp: now.UnixMilli(),
CenterHz: cfg.CenterHz,


+ 12
- 2
internal/detector/detector.go Vedi File

@@ -129,12 +129,22 @@ func (d *Detector) matchSignals(now time.Time, signals []Signal) []Event {
used := make(map[int64]bool, len(d.active))
for _, s := range signals {
var best *activeEvent
var candidates []struct {
ev *activeEvent
dist float64
}
for _, ev := range d.active {
if overlapHz(s.CenterHz, s.BWHz, ev.centerHz, ev.bwHz) && math.Abs(s.CenterHz-ev.centerHz) < (s.BWHz+ev.bwHz)/2.0 {
best = ev
break
candidates = append(candidates, struct {
ev *activeEvent
dist float64
}{ev: ev, dist: math.Abs(s.CenterHz - ev.centerHz)})
}
}
if len(candidates) > 0 {
sort.Slice(candidates, func(i, j int) bool { return candidates[i].dist < candidates[j].dist })
best = candidates[0].ev
}
if best == nil {
id := d.nextID
d.nextID++


+ 6
- 3
internal/fft/gpufft/gpu.go Vedi File

@@ -80,9 +80,12 @@ func (e *Engine) Exec(in []complex64) ([]complex64, error) {
if C.cufftExecC2C(e.plan, e.data, e.data, C.CUFFT_FORWARD) != C.CUFFT_SUCCESS {
return nil, errors.New("cufftExecC2C failed")
}
if C.cudaMemcpy(unsafe.Pointer(&in[0]), unsafe.Pointer(e.data), e.bytes, C.cudaMemcpyDeviceToHost) != C.cudaSuccess {
if C.cudaDeviceSynchronize() != C.cudaSuccess {
return nil, errors.New("cudaDeviceSynchronize failed")
}
out := make([]complex64, e.n)
if C.cudaMemcpy(unsafe.Pointer(&out[0]), unsafe.Pointer(e.data), e.bytes, C.cudaMemcpyDeviceToHost) != C.cudaSuccess {
return nil, errors.New("cudaMemcpy D2H failed")
}
_ = C.cudaDeviceSynchronize()
return in, nil
return out, nil
}

+ 20
- 6
internal/mock/source.go Vedi File

@@ -16,13 +16,14 @@ type Source struct {
phase3 float64
sampleRate float64
noise float64
rng *rand.Rand
}

func New(sampleRate int) *Source {
rand.Seed(time.Now().UnixNano())
return &Source{
sampleRate: float64(sampleRate),
noise: 0.02,
rng: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}

@@ -45,14 +46,27 @@ func (s *Source) ReadIQ(n int) ([]complex64, error) {
f1 := 50e3
f2 := -120e3
f3 := 300e3
const twoPi = 2 * math.Pi
for i := 0; i < n; i++ {
s.phase += 2 * math.Pi * f1 / s.sampleRate
s.phase2 += 2 * math.Pi * f2 / s.sampleRate
s.phase3 += 2 * math.Pi * f3 / s.sampleRate
s.phase += twoPi * f1 / s.sampleRate
s.phase2 += twoPi * f2 / s.sampleRate
s.phase3 += twoPi * f3 / s.sampleRate
if s.phase > twoPi {
s.phase -= twoPi
}
if s.phase2 > twoPi {
s.phase2 -= twoPi
}
if s.phase2 < 0 {
s.phase2 += twoPi
}
if s.phase3 > twoPi {
s.phase3 -= twoPi
}
re := math.Cos(s.phase) + 0.7*math.Cos(s.phase2) + 0.4*math.Cos(s.phase3)
im := math.Sin(s.phase) + 0.7*math.Sin(s.phase2) + 0.4*math.Sin(s.phase3)
re += s.noise * rand.NormFloat64()
im += s.noise * rand.NormFloat64()
re += s.noise * s.rng.NormFloat64()
im += s.noise * s.rng.NormFloat64()
out[i] = complex(float32(re), float32(im))
}
return out, nil


+ 3
- 0
internal/runtime/runtime.go Vedi File

@@ -56,6 +56,9 @@ func (m *Manager) ApplyConfig(update ConfigUpdate) (config.Config, error) {

next := m.cfg
if update.CenterHz != nil {
if *update.CenterHz < 1e3 || *update.CenterHz > 2e9 {
return m.cfg, errors.New("center_hz out of range")
}
next.CenterHz = *update.CenterHz
}
if update.SampleRate != nil {


+ 8
- 5
internal/sdrplay/sdrplay.go Vedi File

@@ -113,8 +113,7 @@ func New(sampleRate int, centerHz float64, gainDb float64, bwKHz int) (sdr.Sourc
s.resizeBuffer(sampleRate, 0)
s.handle = cgo.NewHandle(s)
if err := s.configure(sampleRate, centerHz, gainDb, bwKHz); err != nil {
s.handle.Delete()
s.handle = 0
_ = s.Stop()
return nil, err
}
return s, nil
@@ -329,11 +328,15 @@ func max(a, b int) int {

func (s *Source) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.params != nil {
params := s.params
s.params = nil
s.mu.Unlock()
if params != nil {
_ = cErr(C.sdrplay_api_Uninit(s.dev.dev))
s.params = nil
}

s.mu.Lock()
defer s.mu.Unlock()
if s.open {
_ = cErr(C.sdrplay_api_ReleaseDevice(&s.dev))
_ = cErr(C.sdrplay_api_Close())


Loading…
Annulla
Salva