From 2a06225a0d8177082232b6acd8b04c417b8ecac5 Mon Sep 17 00:00:00 2001 From: Jan Svabenik Date: Tue, 17 Mar 2026 22:13:21 +0100 Subject: [PATCH] Fix websocket writes, validation, and GPU FFT safety --- cmd/sdrd/main.go | 64 +++++++++++++++++++++++++------------ internal/sdrplay/sdrplay.go | 12 +++++-- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/cmd/sdrd/main.go b/cmd/sdrd/main.go index 1e0d575..68912b8 100644 --- a/cmd/sdrd/main.go +++ b/cmd/sdrd/main.go @@ -38,9 +38,14 @@ type SpectrumFrame struct { Signals []detector.Signal `json:"signals"` } +type client struct { + conn *websocket.Conn + send chan []byte +} + type hub struct { mu sync.Mutex - clients map[*websocket.Conn]struct{} + clients map[*client]struct{} } type gpuStatus struct { @@ -68,33 +73,40 @@ func (g *gpuStatus) snapshot() gpuStatus { } func newHub() *hub { - return &hub{clients: map[*websocket.Conn]struct{}{}} + return &hub{clients: map[*client]struct{}{}} } -func (h *hub) add(c *websocket.Conn) { +func (h *hub) add(c *client) { h.mu.Lock() defer h.mu.Unlock() h.clients[c] = struct{}{} } -func (h *hub) remove(c *websocket.Conn) { +func (h *hub) remove(c *client) { h.mu.Lock() defer h.mu.Unlock() delete(h.clients, c) + close(c.send) } func (h *hub) broadcast(frame SpectrumFrame) { + b, err := json.Marshal(frame) + if err != nil { + log.Printf("marshal frame: %v", err) + return + } + h.mu.Lock() - clients := make([]*websocket.Conn, 0, len(h.clients)) + clients := make([]*client, 0, len(h.clients)) for c := range h.clients { clients = append(clients, c) } h.mu.Unlock() - b, _ := json.Marshal(frame) for _, c := range clients { - _ = c.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) - if err := c.WriteMessage(websocket.TextMessage, b); err != nil { + select { + case c.send <- b: + default: h.remove(c) } } @@ -285,32 +297,44 @@ func main() { return origin == "" || strings.HasPrefix(origin, "http://localhost") || strings.HasPrefix(origin, "http://127.0.0.1") }} http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { - c, err := upgrader.Upgrade(w, r, nil) + conn, err := upgrader.Upgrade(w, r, nil) if err != nil { return } + c := &client{conn: conn, send: make(chan []byte, 32)} h.add(c) defer func() { h.remove(c) - _ = c.Close() + _ = conn.Close() }() - c.SetReadDeadline(time.Now().Add(60 * time.Second)) - c.SetPongHandler(func(string) error { - c.SetReadDeadline(time.Now().Add(60 * time.Second)) + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) return nil }) go func() { - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - for range ticker.C { - _ = c.SetWriteDeadline(time.Now().Add(5 * time.Second)) - if err := c.WriteMessage(websocket.PingMessage, nil); err != nil { - return + ping := time.NewTicker(30 * time.Second) + defer ping.Stop() + for { + select { + case msg, ok := <-c.send: + if !ok { + return + } + _ = conn.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) + if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil { + return + } + case <-ping.C: + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } } } }() for { - _, _, err := c.ReadMessage() + _, _, err := conn.ReadMessage() if err != nil { return } diff --git a/internal/sdrplay/sdrplay.go b/internal/sdrplay/sdrplay.go index d787b3e..1bf09dd 100644 --- a/internal/sdrplay/sdrplay.go +++ b/internal/sdrplay/sdrplay.go @@ -83,6 +83,7 @@ type Source struct { mu sync.Mutex dev C.sdrplay_api_DeviceT params *C.sdrplay_api_DeviceParamsT + devSelected bool ch chan []complex64 handle cgo.Handle open bool @@ -127,20 +128,24 @@ func (s *Source) configure(sampleRate int, centerHz float64, gainDb float64, bwK if err := cErr(C.sdrplay_api_LockDeviceApi()); err != nil { return fmt.Errorf("sdrplay_api_LockDeviceApi: %w", err) } - defer func() { _ = cErr(C.sdrplay_api_UnlockDeviceApi()) }() var numDevs C.uint var devices [8]C.sdrplay_api_DeviceT if err := cErr(C.sdrplay_api_GetDevices(&devices[0], &numDevs, C.uint(len(devices)))); err != nil { + _ = cErr(C.sdrplay_api_UnlockDeviceApi()) return fmt.Errorf("sdrplay_api_GetDevices: %w", err) } if numDevs == 0 { + _ = cErr(C.sdrplay_api_UnlockDeviceApi()) return errors.New("no SDRplay devices found") } s.dev = devices[0] if err := cErr(C.sdrplay_api_SelectDevice(&s.dev)); err != nil { + _ = cErr(C.sdrplay_api_UnlockDeviceApi()) return fmt.Errorf("sdrplay_api_SelectDevice: %w", err) } + s.devSelected = true + _ = cErr(C.sdrplay_api_UnlockDeviceApi()) var params *C.sdrplay_api_DeviceParamsT if err := cErr(C.sdrplay_api_GetDeviceParams(s.dev.dev, ¶ms)); err != nil { @@ -338,7 +343,10 @@ func (s *Source) Stop() error { s.mu.Lock() defer s.mu.Unlock() if s.open { - _ = cErr(C.sdrplay_api_ReleaseDevice(&s.dev)) + if s.devSelected { + _ = cErr(C.sdrplay_api_ReleaseDevice(&s.dev)) + s.devSelected = false + } _ = cErr(C.sdrplay_api_Close()) s.open = false }