|
// Command wmdecode is an STFT-domain spread-spectrum watermark decoder.
//
// It decodes the watermark embedded in FM broadcast recordings, following the
// decoder architecture of Kirovski & Malvar (IEEE TSP 2003).
//
// Usage:
//
//	wmdecode <file.wav> <key> [key ...]
//
// At least one key is required; each key is tried in turn.
- package main
-
- import (
- "encoding/binary"
- "fmt"
- "math"
- "math/cmplx"
- "os"
- "sort"
- "time"
-
- "github.com/jan/fm-rds-tx/internal/dsp"
- "github.com/jan/fm-rds-tx/internal/watermark"
- )
-
- func main() {
- if len(os.Args) < 2 {
- fmt.Fprintln(os.Stderr, "usage: wmdecode <file.wav> [key ...]")
- os.Exit(1)
- }
-
- t0 := time.Now()
-
- samples, recRate, err := readMonoWAV(os.Args[1])
- if err != nil {
- fmt.Fprintf(os.Stderr, "read WAV: %v\n", err)
- os.Exit(1)
- }
- rms := rmsLevel(samples)
- fmt.Printf("WAV: %d samples @ %.0f Hz = %.2fs, RMS %.1f dBFS\n",
- len(samples), recRate, float64(len(samples))/recRate, 20*math.Log10(rms+1e-9))
-
- // Step 1: Decimate to WMRate (12 kHz)
- wmRate := float64(watermark.WMRate)
- decimFactor := int(recRate / wmRate)
- if decimFactor < 1 {
- decimFactor = 1
- }
- actualRate := recRate / float64(decimFactor)
- fmt.Printf("Downsample: %d:1 (%.0f Hz → %.0f Hz)\n", decimFactor, recRate, actualRate)
-
- lpfCoeffs := designLPF8(5500, recRate)
- filtered := applyIIR(samples, lpfCoeffs)
-
- nDown := len(filtered) / decimFactor
- down := make([]float64, nDown)
- for i := 0; i < nDown; i++ {
- down[i] = filtered[i*decimFactor]
- }
- fmt.Printf("Downsampled: %d samples, %.1fs\n", nDown, float64(nDown)/wmRate)
-
- // Step 2: Compute ALL STFT frames with cepstrum filtering
- fftSize := watermark.FFTSize
- hop := watermark.FFTHop
- nFrames := (nDown - fftSize) / hop
- if nFrames <= 0 {
- fmt.Fprintln(os.Stderr, "Recording too short")
- os.Exit(1)
- }
-
- var window [watermark.FFTSize]float64
- dsp.HannWindow(window[:])
- fmt.Printf("STFT: %d frames (%d-point, hop=%d)\n", nFrames, fftSize, hop)
-
- type stftMag [watermark.FFTSize / 2]float64
- frameMags := make([]stftMag, nFrames)
- for f := 0; f < nFrames; f++ {
- offset := f * hop
- var buf [watermark.FFTSize]complex128
- for i := 0; i < fftSize; i++ {
- buf[i] = complex(down[offset+i]*window[i], 0)
- }
- dsp.FFT(buf[:])
- for bin := 0; bin < fftSize/2; bin++ {
- mag := cmplx.Abs(buf[bin])
- if mag < 1e-12 {
- mag = 1e-12
- }
- frameMags[f][bin] = 20 * math.Log10(mag)
- }
- cepstrumFilter(frameMags[f][:], 8)
- }
-
- // Step 3: For each key, search cycle offset + rep offset
- keys := os.Args[2:]
- if len(keys) == 0 {
- fmt.Println("No keys supplied.")
- os.Exit(1)
- }
-
- for _, key := range keys {
- fmt.Printf("\nKey: %q\n", key)
- det := watermark.NewSTFTDetector(key)
-
- totalGroups := watermark.TotalGroups
- timeRep := watermark.TimeRep
- framesPerWM := watermark.FramesPerWM
- numBins := watermark.NumBins
- binLow := watermark.BinLow
- centerRep := timeRep / 2
-
- bestMetric := -1.0
- var bestCorrs [watermark.PayloadBits]float64
- bestCycleOff := 0
- bestRepOff := 0
-
- nCandidates := 0
- for cycleOff := 0; cycleOff < framesPerWM; cycleOff += timeRep {
- for repOff := 0; repOff < timeRep; repOff++ {
- var testCorrs [watermark.PayloadBits]float64
-
- for f := 0; f < nFrames; f++ {
- wmFrame := ((f - cycleOff - repOff) % framesPerWM + framesPerWM) % framesPerWM
- if wmFrame%timeRep != centerRep {
- continue
- }
- g := wmFrame / timeRep
- if g >= totalGroups {
- continue
- }
-
- var corr float64
- for b := 0; b < numBins; b++ {
- corr += frameMags[f][binLow+b] * float64(det.PNChipAt(g, b))
- }
- testCorrs[det.GroupBit(g)] += corr
- }
-
- var metric float64
- for _, c := range testCorrs {
- metric += c * c
- }
-
- if metric > bestMetric {
- bestMetric = metric
- bestCorrs = testCorrs
- bestCycleOff = cycleOff
- bestRepOff = repOff
- }
- nCandidates++
- }
- }
-
- fmt.Printf("Searched %d candidates in %v\n", nCandidates, time.Since(t0).Round(time.Millisecond))
- fmt.Printf("Best: cycleOff=%d, repOff=%d, metric=%.0f\n", bestCycleOff, bestRepOff, bestMetric)
-
- var sumAbs float64
- for _, c := range bestCorrs {
- sumAbs += math.Abs(c)
- }
- fmt.Printf("Corrs: avg|c|=%.1f\n", sumAbs/128)
-
- // BER diagnostic against known key
- knownPayload := watermark.KeyToPayload(key)
- knownCW := watermark.RSEncode(knownPayload)
- var knownBits [watermark.PayloadBits]int
- for i := 0; i < watermark.PayloadBits; i++ {
- knownBits[i] = int((knownCW[i/8] >> uint(7-(i%8))) & 1)
- }
- nerr := 0
- for i := 0; i < watermark.PayloadBits; i++ {
- hard := 0
- if bestCorrs[i] < 0 {
- hard = 1
- }
- if hard != knownBits[i] {
- nerr++
- }
- }
- fmt.Printf("BER: %d/128 (%.1f%%)\n", nerr, 100*float64(nerr)/128)
-
- // Show recv vs expected
- var recv [watermark.RsTotalBytes]byte
- confs := make([]float64, watermark.PayloadBits)
- for i := 0; i < watermark.PayloadBits; i++ {
- confs[i] = math.Abs(bestCorrs[i])
- if bestCorrs[i] < 0 {
- recv[i/8] |= 1 << uint(7-(i%8))
- }
- }
- fmt.Printf("recv: %x\nwant: %x\n", recv, knownCW)
-
- // Confidence-based erasure (MIN bit confidence per byte)
- type bc struct{ idx int; conf float64 }
- byteConfs := make([]bc, watermark.RsTotalBytes)
- for b := 0; b < watermark.RsTotalBytes; b++ {
- minC := confs[b*8]
- for bit := 1; bit < 8; bit++ {
- if confs[b*8+bit] < minC {
- minC = confs[b*8+bit]
- }
- }
- byteConfs[b] = bc{b, minC}
- }
- sort.Slice(byteConfs, func(a, b int) bool { return byteConfs[a].conf < byteConfs[b].conf })
-
- decoded := false
- for nErase := 0; nErase <= watermark.RsCheckBytes; nErase++ {
- if nErase == 0 {
- p, ok := watermark.RSDecode(recv, nil)
- if ok && watermark.KeyMatchesPayload(key, p) {
- fmt.Printf(" ✓ MATCH (0 erasures), payload=%x\n", p)
- decoded = true
- break
- }
- continue
- }
- erasePos := make([]int, nErase)
- for i := 0; i < nErase; i++ {
- erasePos[i] = byteConfs[i].idx
- }
- sort.Ints(erasePos)
- p, ok := watermark.RSDecode(recv, erasePos)
- if ok && watermark.KeyMatchesPayload(key, p) {
- fmt.Printf(" ✓ MATCH (%d erasures), payload=%x\n", nErase, p)
- decoded = true
- break
- }
- }
-
- if !decoded {
- fmt.Println(" ✗ NOT FOUND")
- }
- }
-
- fmt.Printf("\nDone in %v\n", time.Since(t0).Round(time.Millisecond))
- }
-
// cepstrumFilter removes the slowly varying spectral envelope from a
// log-magnitude spectrum, in place.
//
// magDB is transformed with an unnormalised DCT-II,
//
//	c[k] = Σ_i magDB[i]·cos(π·k·(i+½)/n),
//
// the nCeps lowest-quefrency coefficients c[0..nCeps-1] are discarded, and the
// result is mapped back with the matching inverse (a scaled DCT-III),
//
//	magDB[i] = (2/n)·(c[0]/2 + Σ_{k≥1} c[k]·cos(π·k·(i+½)/n)).
//
// Spectra with fewer than 2*nCeps bins are left unchanged; nCeps <= 0 is an
// identity round trip (up to rounding).
func cepstrumFilter(magDB []float64, nCeps int) {
	n := len(magDB)
	if n < nCeps*2 {
		return
	}
	first := nCeps // lowest coefficient that survives the lifter
	if first < 0 {
		first = 0
	}

	// Every basis value cos(π·k·(2i+1)/(2n)) equals cos(π·m/(2n)) for
	// m = k·(2i+1) mod 4n, so one 4n-entry table replaces the 2n² math.Cos
	// calls needed to evaluate each term directly. This runs once per STFT
	// frame, so the saving is substantial.
	period := 4 * n
	cosTab := make([]float64, period)
	for m := range cosTab {
		cosTab[m] = math.Cos(math.Pi * float64(m) / float64(2*n))
	}

	// Forward DCT-II. Coefficients below first would be zeroed anyway, so
	// they are left at zero instead of being computed and discarded.
	ceps := make([]float64, n)
	for k := first; k < n; k++ {
		var sum float64
		for i, x := range magDB {
			sum += x * cosTab[(k*(2*i+1))%period]
		}
		ceps[k] = sum
	}

	// Inverse (scaled DCT-III) over the surviving coefficients only.
	for i := range magDB {
		var sum float64
		for k := first; k < n; k++ {
			w := 1.0
			if k == 0 {
				w = 0.5 // the DC term carries half weight in the inverse
			}
			sum += w * ceps[k] * cosTab[(k*(2*i+1))%period]
		}
		magDB[i] = sum * 2.0 / float64(n)
	}
}
-
// biquad holds the coefficients of one second-order IIR section, normalised
// so that a0 = 1:
//
//	y[n] = b0·x[n] + b1·x[n-1] + b2·x[n-2] - a1·y[n-1] - a2·y[n-2]
type biquad struct {
	b0, b1, b2 float64 // feed-forward (numerator) coefficients
	a1, a2     float64 // feedback (denominator) coefficients; a0 is implicit
}

// iirCoeffs is a cascade of biquad sections, applied in slice order.
type iirCoeffs []biquad
-
- func designLPF8(cutoffHz, sampleRate float64) iirCoeffs {
- angles := []float64{math.Pi / 16, 3 * math.Pi / 16, 5 * math.Pi / 16, 7 * math.Pi / 16}
- coeffs := make(iirCoeffs, 4)
- for i, angle := range angles {
- q := 1.0 / (2 * math.Cos(angle))
- omega := 2 * math.Pi * cutoffHz / sampleRate
- cosW := math.Cos(omega)
- sinW := math.Sin(omega)
- alpha := sinW / (2 * q)
- a0 := 1 + alpha
- coeffs[i] = biquad{
- b0: (1 - cosW) / 2 / a0, b1: (1 - cosW) / a0, b2: (1 - cosW) / 2 / a0,
- a1: (-2 * cosW) / a0, a2: (1 - alpha) / a0,
- }
- }
- return coeffs
- }
-
- func applyIIR(samples []float64, coeffs iirCoeffs) []float64 {
- out := make([]float64, len(samples))
- copy(out, samples)
- for _, bq := range coeffs {
- var z1, z2 float64
- for i, x := range out {
- y := bq.b0*x + z1
- z1 = bq.b1*x - bq.a1*y + z2
- z2 = bq.b2*x - bq.a2*y
- out[i] = y
- }
- }
- return out
- }
-
// rmsLevel returns the root-mean-square value of s. An empty slice yields 0
// rather than the NaN that 0/0 would produce, so callers printing a dB level
// get a finite (floor) value instead of "NaN".
func rmsLevel(s []float64) float64 {
	if len(s) == 0 {
		return 0
	}
	var sumSq float64
	for _, v := range s {
		sumSq += v * v
	}
	return math.Sqrt(sumSq / float64(len(s)))
}
-
// readMonoWAV reads a 16-bit PCM WAV file and returns its samples scaled to
// [-1, 1) together with the sample rate in Hz.
//
// Multi-channel input is downmixed to mono by averaging the first two
// channels; any further channels are ignored. Chunks are scanned in file
// order until both "fmt " and "data" have been seen. A chunk whose declared
// size runs past end of file (e.g. a placeholder size left by a recorder that
// stopped early) is truncated to the bytes actually present.
//
// It returns an error if the file cannot be read, is not RIFF/WAVE, has a
// truncated fmt chunk, lacks a fmt or data chunk, or is not 16-bit PCM with a
// non-zero channel count and sample rate.
func readMonoWAV(path string) ([]float64, float64, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, 0, err
	}
	if len(data) < 44 || string(data[0:4]) != "RIFF" || string(data[8:12]) != "WAVE" {
		return nil, 0, fmt.Errorf("not a RIFF/WAVE file")
	}

	const (
		wavFormatPCM        = 0x0001
		wavFormatExtensible = 0xFFFE
	)
	var audioFormat, channels, bitsPerSample uint16
	var sampleRate uint32
	haveFmt := false
	dataStart, dataLen := -1, 0

	for i := 12; i+8 <= len(data); {
		id := string(data[i : i+4])
		sz := int(binary.LittleEndian.Uint32(data[i+4 : i+8]))
		i += 8
		// Clamp sizes that overrun the file (or wrapped negative converting a
		// huge uint32 on a 32-bit platform) so all slicing stays in bounds and
		// the scan always moves forward.
		if sz < 0 || sz > len(data)-i {
			sz = len(data) - i
		}
		switch id {
		case "fmt ":
			if sz < 16 {
				return nil, 0, fmt.Errorf("truncated fmt chunk (%d bytes)", sz)
			}
			audioFormat = binary.LittleEndian.Uint16(data[i : i+2])
			channels = binary.LittleEndian.Uint16(data[i+2 : i+4])
			sampleRate = binary.LittleEndian.Uint32(data[i+4 : i+8])
			bitsPerSample = binary.LittleEndian.Uint16(data[i+14 : i+16])
			haveFmt = true
		case "data":
			dataStart, dataLen = i, sz
		}
		if haveFmt && dataStart >= 0 {
			break
		}
		// RIFF chunks are word-aligned: an odd-sized chunk is followed by a pad byte.
		i += sz
		if sz%2 != 0 {
			i++
		}
	}

	switch {
	case !haveFmt:
		return nil, 0, fmt.Errorf("unsupported WAV: no fmt chunk")
	case dataStart < 0:
		return nil, 0, fmt.Errorf("unsupported WAV: no data chunk")
	case audioFormat != wavFormatPCM && audioFormat != wavFormatExtensible:
		return nil, 0, fmt.Errorf("unsupported WAV: format tag %#04x, need PCM", audioFormat)
	case bitsPerSample != 16:
		return nil, 0, fmt.Errorf("unsupported WAV: %d bits per sample, need 16", bitsPerSample)
	case channels == 0:
		return nil, 0, fmt.Errorf("unsupported WAV: zero channels")
	case sampleRate == 0:
		return nil, 0, fmt.Errorf("unsupported WAV: zero sample rate")
	}

	step := int(channels) * 2 // bytes per interleaved sample frame
	nFrames := dataLen / step
	out := make([]float64, nFrames)
	for j := range out {
		off := dataStart + j*step
		l := float64(int16(binary.LittleEndian.Uint16(data[off : off+2])))
		r := l
		if channels >= 2 {
			r = float64(int16(binary.LittleEndian.Uint16(data[off+2 : off+4])))
		}
		out[j] = (l + r) / 2.0 / 32768.0
	}
	return out, float64(sampleRate), nil
}
|