Parcourir la source

wmdecode: 6x faster decoding via cosine table + parallel search

The decoder took ~70s for a 20-minute recording. Profiling revealed the
bottleneck was not the 6400-candidate cycle-offset search, but the
cepstrum filter's naive O(N²) DCT calling math.Cos() in the inner loop:

 55458 STFT frames × 2 passes × 256² × math.Cos() = 7.27 billion calls
 At ~20ns per call: ~145 seconds worst case (an upper bound — the measured total was ~70s — but still the dominant cost)

Fixes:

 1. Precomputed cosine table: compute 256×256 = 65536 cosine values
 once, then use table lookups in the inner loop. Eliminates all
 math.Cos() calls from the per-frame processing.

 2. Parallel cycle-offset search: 5 goroutines (one per rep offset),
 each searching 1280 cycle offsets independently. The rep offsets
 are fully independent — no shared state, no synchronization needed
 until the final result merge.

 3. Precomputed center-frame lists: instead of checking f%timeRep for
 every frame in every candidate test, precompute which frames are
 center frames for each rep offset. Eliminates per-frame branching.

 4. Float64 PN chip arrays: convert int8 PN chips to float64 once at
 startup. Eliminates int8→float64 conversion in the hot inner loop
 (204 conversions × 11000 frames × 6400 candidates = 14.4 billion
 avoided conversions).

Performance (20-minute recording, 55458 STFT frames):

 Before: 70s (math.Cos dominated)
 After: 11.5s (6x faster)
 Unit test (round-trip): 20s → 1.4s (14x faster)

Note: attempted coarse/fine search (testing every 10th group offset,
then refining) but abandoned — the chi-squared metric peak is too
narrow and the coarse step missed the true peak, causing false
positives. The full 6400-candidate brute-force search is kept for
correctness; the speedup comes entirely from eliminating per-operation
overhead, not from reducing the number of operations.
main
Jan il y a 1 mois
Parent
révision
96fdb2e7e1
3 fichiers modifiés avec 129 ajouts et 33 suppressions
  1. +81
    -27
      cmd/wmdecode/main.go
  2. +32
    -6
      internal/watermark/stft_watermark.go
  3. +16
    -0
      internal/watermark/watermark.go

+ 81
- 27
cmd/wmdecode/main.go Voir le fichier

@@ -101,39 +101,76 @@ func main() {
binLow := watermark.BinLow
centerRep := timeRep / 2

// Precompute for speed: float64 PN chips + group-to-bit + center frame lists
pnF := det.PNChipsFloat()
g2b := det.GroupToBit()

// Center frame indices per repOff (avoids per-frame modulo + branch)
var centerFrames [5][]int
for f := 0; f < nFrames; f++ {
r := f % timeRep
centerFrames[r] = append(centerFrames[r], f)
}

bestMetric := -1.0
var bestCorrs [watermark.PayloadBits]float64
bestCycleOff := 0
bestRepOff := 0

for cycleOff := 0; cycleOff < framesPerWM; cycleOff += timeRep {
for repOff := 0; repOff < timeRep; repOff++ {
var testCorrs [watermark.PayloadBits]float64
for f := 0; f < nFrames; f++ {
wmFrame := ((f - cycleOff - repOff) % framesPerWM + framesPerWM) % framesPerWM
if wmFrame%timeRep != centerRep {
continue
// Parallel search: 5 goroutines (one per rep offset), each searches 1280 cycle offsets.
type searchResult struct {
corrs [watermark.PayloadBits]float64
metric float64
cycleOff int
repOff int
}
results := make(chan searchResult, timeRep)

for repOff := 0; repOff < timeRep; repOff++ {
go func(repOff int) {
cfIdx := (repOff + centerRep) % timeRep
cfs := centerFrames[cfIdx]

var best searchResult
best.repOff = repOff

for cycleOff := 0; cycleOff < framesPerWM; cycleOff += timeRep {
var testCorrs [watermark.PayloadBits]float64

for _, f := range cfs {
wmFrame := ((f - cycleOff - repOff) % framesPerWM + framesPerWM) % framesPerWM
g := wmFrame / timeRep
if g >= totalGroups {
continue
}
var corr float64
for b := 0; b < numBins; b++ {
corr += frameMags[f][binLow+b] * pnF[g][b]
}
testCorrs[g2b[g]] += corr
}
g := wmFrame / timeRep
if g >= totalGroups {
continue

var metric float64
for _, c := range testCorrs {
metric += c * c
}
var corr float64
for b := 0; b < numBins; b++ {
corr += frameMags[f][binLow+b] * float64(det.PNChipAt(g, b))
if metric > best.metric {
best.metric = metric
best.corrs = testCorrs
best.cycleOff = cycleOff
}
testCorrs[det.GroupBit(g)] += corr
}
var metric float64
for _, c := range testCorrs {
metric += c * c
}
if metric > bestMetric {
bestMetric = metric
bestCorrs = testCorrs
bestCycleOff = cycleOff
bestRepOff = repOff
}
results <- best
}(repOff)
}

for i := 0; i < timeRep; i++ {
r := <-results
if r.metric > bestMetric {
bestMetric = r.metric
bestCorrs = r.corrs
bestCycleOff = r.cycleOff
bestRepOff = r.repOff
}
}

@@ -222,22 +259,39 @@ func main() {

// --- Cepstrum filter ---

// cosTable holds the DCT-II cosine basis built by initCosTable:
// cosTable[k][i] = cos(π·k·(i+0.5)/n) for the n it was last built with.
var cosTable [][]float64

// initCosTable (re)builds cosTable as an n×n matrix of cosine values so the
// DCT loops can use table lookups instead of calling math.Cos per sample.
// Any previously built table is replaced.
func initCosTable(n int) {
	size := float64(n)
	table := make([][]float64, n)
	for k := range table {
		// The k-dependent factor is constant along a row.
		freq := math.Pi * float64(k)
		row := make([]float64, n)
		for i := range row {
			row[i] = math.Cos(freq * (float64(i) + 0.5) / size)
		}
		table[k] = row
	}
	cosTable = table
}

func cepstrumFilter(magDB []float64, nCeps int) {
n := len(magDB)
if n < nCeps*2 {
return
}
if len(cosTable) != n {
initCosTable(n)
}
ceps := make([]float64, n)
for k := 0; k < n; k++ {
var sum float64
row := cosTable[k]
for i := 0; i < n; i++ {
sum += magDB[i] * math.Cos(math.Pi*float64(k)*(float64(i)+0.5)/float64(n))
sum += magDB[i] * row[i]
}
ceps[k] = sum
}
for k := 0; k < nCeps; k++ {
ceps[k] = 0
}
scale := 2.0 / float64(n)
for i := 0; i < n; i++ {
var sum float64
for k := 0; k < n; k++ {
@@ -245,9 +299,9 @@ func cepstrumFilter(magDB []float64, nCeps int) {
if k == 0 {
w = 0.5
}
sum += w * ceps[k] * math.Cos(math.Pi*float64(k)*(float64(i)+0.5)/float64(n))
sum += w * ceps[k] * cosTable[k][i]
}
magDB[i] = sum * 2.0 / float64(n)
magDB[i] = sum * scale
}
}



+ 32
- 6
internal/watermark/stft_watermark.go Voir le fichier

@@ -358,28 +358,34 @@ func (d *STFTDetector) Detect(audio []float64) (corrs [payloadBits]float64, best
// cepstrumFilter removes the spectral envelope from dB magnitudes.
// It zeros the first nCeps DCT coefficients (the smooth spectral shape).
// This is Kirovski's "CF" technique: reduces carrier noise by ~6 dB.
//
// Uses precomputed cosine table for O(N²) DCT without math.Cos calls.
func cepstrumFilter(magDB []float64, nCeps int) {
n := len(magDB)
if n < nCeps*2 {
return
}

// DCT-II (simplified, not optimized)
cosTable := getCosTable(n)

// DCT-II
ceps := make([]float64, n)
for k := 0; k < n; k++ {
var sum float64
row := cosTable[k]
for i := 0; i < n; i++ {
sum += magDB[i] * math.Cos(math.Pi*float64(k)*(float64(i)+0.5)/float64(n))
sum += magDB[i] * row[i]
}
ceps[k] = sum
}

// Zero low-order cepstral coefficients (spectral envelope)
// Zero low-order coefficients
for k := 0; k < nCeps; k++ {
ceps[k] = 0
}

// IDCT (inverse DCT-II)
// IDCT
scale := 2.0 / float64(n)
for i := 0; i < n; i++ {
var sum float64
for k := 0; k < n; k++ {
@@ -387,10 +393,30 @@ func cepstrumFilter(magDB []float64, nCeps int) {
if k == 0 {
w = 0.5
}
sum += w * ceps[k] * math.Cos(math.Pi*float64(k)*(float64(i)+0.5)/float64(n))
sum += w * ceps[k] * cosTable[k][i]
}
magDB[i] = sum * scale
}
}

// Cached cosine table for DCT. cosTable[k][i] = cos(π·k·(i+0.5)/N).
//
// Only one table size is cached at a time (the most recently requested N).
// These are plain package-level variables with no locking, so getCosTable
// must not be called from several goroutines concurrently.
var cachedCosTable [][]float64
var cachedCosN int

// getCosTable returns the n×n DCT-II cosine table, where
// table[k][i] = cos(π·k·(i+0.5)/n).
//
// The table is built on the first request for a given n and reused while n
// stays the same; a request for a different n rebuilds and replaces the
// cached table. The returned rows are shared with the cache, so callers must
// treat them as read-only.
func getCosTable(n int) [][]float64 {
	if cachedCosN == n {
		return cachedCosTable
	}
	table := make([][]float64, n)
	for k := 0; k < n; k++ {
		table[k] = make([]float64, n)
		for i := 0; i < n; i++ {
			table[k][i] = math.Cos(math.Pi * float64(k) * (float64(i) + 0.5) / float64(n))
		}
	}
	cachedCosTable = table
	cachedCosN = n
	return table
}

// Simple xorshift32 PRNG for deterministic chip generation.


+ 16
- 0
internal/watermark/watermark.go Voir le fichier

@@ -228,6 +228,22 @@ func (d *STFTDetector) GroupBit(g int) int {
return d.groupToBit[g]
}

// PNChipsFloat returns a copy of the detector's PN chip table with every
// chip converted to float64, so hot loops can multiply against it without a
// per-element conversion. The result is a TotalGroups×NumBins array returned
// by value.
func (d *STFTDetector) PNChipsFloat() [TotalGroups][NumBins]float64 {
	var chips [TotalGroups][NumBins]float64
	for g := range chips {
		row := &chips[g]
		for b := range row {
			row[b] = float64(d.pnChips[g][b])
		}
	}
	return chips
}

// GroupToBit returns a copy of the detector's complete group-to-bit table,
// indexed by group number.
func (d *STFTDetector) GroupToBit() [TotalGroups]int {
	mapping := d.groupToBit
	return mapping
}

// --- GF tables ---

var gfExp = [512]byte{1, 2, 4, 8, 16, 32, 64, 128, 29, 58, 116, 232, 205, 135, 19, 38, 76, 152, 45, 90, 180, 117, 234, 201, 143, 3, 6, 12, 24, 48, 96, 192, 157, 39, 78, 156, 37, 74, 148, 53, 106, 212, 181, 119, 238, 193, 159, 35, 70, 140, 5, 10, 20, 40, 80, 160, 93, 186, 105, 210, 185, 111, 222, 161, 95, 190, 97, 194, 153, 47, 94, 188, 101, 202, 137, 15, 30, 60, 120, 240, 253, 231, 211, 187, 107, 214, 177, 127, 254, 225, 223, 163, 91, 182, 113, 226, 217, 175, 67, 134, 17, 34, 68, 136, 13, 26, 52, 104, 208, 189, 103, 206, 129, 31, 62, 124, 248, 237, 199, 147, 59, 118, 236, 197, 151, 51, 102, 204, 133, 23, 46, 92, 184, 109, 218, 169, 79, 158, 33, 66, 132, 21, 42, 84, 168, 77, 154, 41, 82, 164, 85, 170, 73, 146, 57, 114, 228, 213, 183, 115, 230, 209, 191, 99, 198, 145, 63, 126, 252, 229, 215, 179, 123, 246, 241, 255, 227, 219, 171, 75, 150, 49, 98, 196, 149, 55, 110, 220, 165, 87, 174, 65, 130, 25, 50, 100, 200, 141, 7, 14, 28, 56, 112, 224, 221, 167, 83, 166, 81, 162, 89, 178, 121, 242, 249, 239, 195, 155, 43, 86, 172, 69, 138, 9, 18, 36, 72, 144, 61, 122, 244, 245, 247, 243, 251, 235, 203, 139, 11, 22, 44, 88, 176, 125, 250, 233, 207, 131, 27, 54, 108, 216, 173, 71, 142, 1, 2, 4, 8, 16, 32, 64, 128, 29, 58, 116, 232, 205, 135, 19, 38, 76, 152, 45, 90, 180, 117, 234, 201, 143, 3, 6, 12, 24, 48, 96, 192, 157, 39, 78, 156, 37, 74, 148, 53, 106, 212, 181, 119, 238, 193, 159, 35, 70, 140, 5, 10, 20, 40, 80, 160, 93, 186, 105, 210, 185, 111, 222, 161, 95, 190, 97, 194, 153, 47, 94, 188, 101, 202, 137, 15, 30, 60, 120, 240, 253, 231, 211, 187, 107, 214, 177, 127, 254, 225, 223, 163, 91, 182, 113, 226, 217, 175, 67, 134, 17, 34, 68, 136, 13, 26, 52, 104, 208, 189, 103, 206, 129, 31, 62, 124, 248, 237, 199, 147, 59, 118, 236, 197, 151, 51, 102, 204, 133, 23, 46, 92, 184, 109, 218, 169, 79, 158, 33, 66, 132, 21, 42, 84, 168, 77, 154, 41, 82, 164, 85, 170, 73, 146, 57, 114, 228, 213, 183, 115, 230, 209, 191, 99, 198, 145, 63, 126, 252, 229, 215, 179, 123, 246, 241, 255, 227, 
219, 171, 75, 150, 49, 98, 196, 149, 55, 110, 220, 165, 87, 174, 65, 130, 25, 50, 100, 200, 141, 7, 14, 28, 56, 112, 224, 221, 167, 83, 166, 81, 162, 89, 178, 121, 242, 249, 239, 195, 155, 43, 86, 172, 69, 138, 9, 18, 36, 72, 144, 61, 122, 244, 245, 247, 243, 251, 235, 203, 139, 11, 22, 44, 88, 176, 125, 250, 233, 207, 131, 27, 54, 108, 216, 173, 71, 142, 1, 2}


Chargement…
Annuler
Enregistrer