package aoiprxkit import ( "context" "fmt" "net" "sync" "time" ) type PCMFrame struct { SequenceNumber uint16 Timestamp uint32 SampleRateHz int Channels int Samples []int32 // interleaved ReceivedAt time.Time Source string } type FrameHandler func(frame PCMFrame) type Receiver struct { cfg Config onFrame FrameHandler mu sync.Mutex conn *net.UDPConn cancel context.CancelFunc done chan struct{} doneOnce sync.Once stats statsAtomic } func NewReceiver(cfg Config, onFrame FrameHandler) (*Receiver, error) { if err := cfg.Validate(); err != nil { return nil, err } if onFrame == nil { return nil, fmt.Errorf("onFrame must not be nil") } return &Receiver{ cfg: cfg, onFrame: onFrame, done: make(chan struct{}), }, nil } func (r *Receiver) Start(ctx context.Context) error { r.mu.Lock() defer r.mu.Unlock() if r.conn != nil { return fmt.Errorf("receiver already started") } group := net.ParseIP(r.cfg.MulticastGroup) ifi, err := resolveMulticastInterface(r.cfg.InterfaceName) if err != nil { return err } addr := &net.UDPAddr{IP: group, Port: r.cfg.Port} conn, err := net.ListenMulticastUDP("udp4", ifi, addr) if err != nil { return fmt.Errorf("listen multicast UDP: %w", err) } if r.cfg.ReadBufferBytes > 0 { _ = conn.SetReadBuffer(r.cfg.ReadBufferBytes) } cctx, cancel := context.WithCancel(ctx) r.conn = conn r.cancel = cancel r.done = make(chan struct{}) r.doneOnce = sync.Once{} go r.loop(cctx) return nil } func (r *Receiver) Stop() error { r.mu.Lock() if r.conn == nil { r.mu.Unlock() return nil } conn := r.conn cancel := r.cancel done := r.done r.conn = nil r.cancel = nil r.mu.Unlock() if cancel != nil { cancel() } _ = conn.Close() <-done return nil } func (r *Receiver) Stats() Stats { return r.stats.snapshot() } func (r *Receiver) loop(ctx context.Context) { defer r.doneOnce.Do(func() { close(r.done) }) jb := newJitterBuffer(r.cfg.JitterDepthPackets) buf := make([]byte, 64*1024) for { select { case <-ctx.Done(): return default: } _ = r.conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) n, _, err := r.conn.ReadFromUDP(buf) if err != nil { if ne, ok := err.(net.Error); ok && ne.Timeout() { continue } return } r.stats.packetsReceived.Add(1) if n < 12 { r.stats.packetsShort.Add(1) continue } pkt, err := ParseRTPPacket(buf[:n]) if err != nil { r.stats.packetsShort.Add(1) continue } r.stats.packetsParsed.Add(1) if pkt.PayloadType != r.cfg.PayloadType { r.stats.packetsWrongPT.Add(1) continue } ready, lateDrop, gapLoss, reorder := jb.push(pkt) if lateDrop { r.stats.packetsLateDrop.Add(1) continue } if gapLoss > 0 { r.stats.packetsGapLoss.Add(gapLoss) } if reorder { r.stats.jitterReorders.Add(1) } for _, rp := range ready { samples, err := DecodeL24BE(rp.Payload, r.cfg.Channels) if err != nil { r.stats.decodeErrors.Add(1) continue } frame := PCMFrame{ SequenceNumber: rp.SequenceNumber, Timestamp: rp.Timestamp, SampleRateHz: r.cfg.SampleRateHz, Channels: r.cfg.Channels, Samples: samples, ReceivedAt: time.Now(), Source: fmt.Sprintf("rtp://%s:%d", r.cfg.MulticastGroup, r.cfg.Port), } r.onFrame(frame) r.stats.packetsDelivered.Add(1) r.stats.samplesDelivered.Add(uint64(len(samples))) if r.cfg.Channels > 0 { r.stats.framesDelivered.Add(uint64(len(samples) / r.cfg.Channels)) } r.stats.lastSequence.Store(uint32(rp.SequenceNumber)) r.stats.sequenceValid.Store(1) } } } func resolveMulticastInterface(name string) (*net.Interface, error) { if name == "" { return nil, nil } ifi, err := net.InterfaceByName(name) if err != nil { return nil, fmt.Errorf("resolve interface %q: %w", name, err) } return ifi, nil }