|
- package watermark
-
- import (
- "math"
- "testing"
- )
-
- func TestSTFTRoundTrip(t *testing.T) {
- const key = "test-stft-key"
- const duration = 150.0 // seconds — need > 136.5s for one full WM cycle
-
- nSamples := int(duration * WMRate)
- t.Logf("Generating %d samples @ %d Hz (%.1fs)", nSamples, WMRate, duration)
- t.Logf("WM cycle: %d STFT frames, %.1fs", FramesPerWM, float64(SamplesPerWM)/WMRate)
-
- // Generate test signal: broadband noise (the multiplicative watermark
- // needs energy in all frequency bins to work — a pure tone only has
- // energy in one bin and the watermark has no effect on silent bins)
- audio := make([]float64, nSamples)
- // Simple LCG pseudo-random for reproducibility
- var lcg uint64 = 12345
- for i := range audio {
- lcg = lcg*6364136223846793005 + 1442695040888963407
- audio[i] = 0.3 * (float64(int32(lcg>>33))/float64(1<<31))
- }
-
- rmsIn := rmsF64(audio)
- t.Logf("Input RMS: %.1f dBFS", 20*math.Log10(rmsIn+1e-12))
-
- // Embed watermark
- embedder := NewSTFTEmbedder(key)
- watermarked := embedder.ProcessBlock(audio)
-
- rmsOut := rmsF64(watermarked)
- t.Logf("Output RMS: %.1f dBFS", 20*math.Log10(rmsOut+1e-12))
- t.Logf("RMS change: %.2f dB", 20*math.Log10(rmsOut/rmsIn))
-
- // Detect watermark
- detector := NewSTFTDetector(key)
- corrs, offset := detector.Detect(watermarked)
-
- t.Logf("Detection offset: %d", offset)
-
- // Check correlations
- var nPositive, nNegative int
- var sumAbs float64
- for _, c := range corrs {
- sumAbs += math.Abs(c)
- if c > 0 {
- nPositive++
- } else {
- nNegative++
- }
- }
- avgAbs := sumAbs / float64(payloadBits)
- t.Logf("Correlations: avg|c|=%.1f, positive=%d, negative=%d", avgAbs, nPositive, nNegative)
-
- if avgAbs < 1.0 {
- t.Errorf("avg|c| too low: %.1f (expected >> 1.0)", avgAbs)
- }
-
- // Check against known payload
- payload := KeyToPayload(key)
- codeword := RSEncode(payload)
- var expectedBits [payloadBits]int
- for i := 0; i < payloadBits; i++ {
- expectedBits[i] = int((codeword[i/8] >> uint(7-(i%8))) & 1)
- }
-
- nerr := 0
- for i := 0; i < payloadBits; i++ {
- hard := 0
- if corrs[i] < 0 {
- hard = 1
- }
- if hard != expectedBits[i] {
- nerr++
- }
- }
- t.Logf("BER: %d/%d (%.1f%%)", nerr, payloadBits, 100*float64(nerr)/float64(payloadBits))
-
- if nerr > 20 {
- t.Errorf("BER too high: %d/%d", nerr, payloadBits)
- }
-
- // Try RS decode
- var recv [rsTotalBytes]byte
- for i := 0; i < payloadBits; i++ {
- if corrs[i] < 0 {
- recv[i/8] |= 1 << uint(7-(i%8))
- }
- }
-
- // Try with erasures if needed
- decoded := false
- for nErase := 0; nErase <= rsCheckBytes; nErase++ {
- if nErase == 0 {
- // Try zero erasures (valid if BER=0)
- p, ok := RSDecode(recv, nil)
- if ok {
- if KeyMatchesPayload(key, p) {
- t.Logf("Decoded with 0 erasures: MATCH ✓")
- decoded = true
- break
- }
- }
- continue
- }
- // Erase weakest bytes by |correlation|
- type bc struct{ idx int; conf float64 }
- byteConfs := make([]bc, rsTotalBytes)
- for b := 0; b < rsTotalBytes; b++ {
- minC := math.Abs(corrs[b*8])
- for bit := 1; bit < 8; bit++ {
- c := math.Abs(corrs[b*8+bit])
- if c < minC {
- minC = c
- }
- }
- byteConfs[b] = bc{b, minC}
- }
- // Sort by confidence (weakest first)
- for i := 0; i < len(byteConfs); i++ {
- for j := i + 1; j < len(byteConfs); j++ {
- if byteConfs[j].conf < byteConfs[i].conf {
- byteConfs[i], byteConfs[j] = byteConfs[j], byteConfs[i]
- }
- }
- }
- erasePos := make([]int, nErase)
- for i := 0; i < nErase; i++ {
- erasePos[i] = byteConfs[i].idx
- }
- // Sort positions
- for i := 0; i < len(erasePos); i++ {
- for j := i + 1; j < len(erasePos); j++ {
- if erasePos[j] < erasePos[i] {
- erasePos[i], erasePos[j] = erasePos[j], erasePos[i]
- }
- }
- }
- p, ok := RSDecode(recv, erasePos)
- if ok {
- if KeyMatchesPayload(key, p) {
- t.Logf("Decoded with %d erasures: MATCH ✓", nErase)
- decoded = true
- break
- }
- }
- }
-
- if !decoded {
- t.Errorf("RS decode FAILED")
- }
- }
-
- func rmsF64(s []float64) float64 {
- var acc float64
- for _, v := range s {
- acc += v * v
- }
- return math.Sqrt(acc / float64(len(s)))
- }
|