From 656784097c8f4b66efd0fe0c5de0f42c3b29edd9 Mon Sep 17 00:00:00 2001 From: Jan Svabenik Date: Fri, 3 Apr 2026 09:13:09 +0200 Subject: [PATCH] feat: improve WAV ingest robustness and add linear interpolation resampling --- internal/audio/resample.go | 51 ++++++--- internal/audio/resample_test.go | 49 ++++++-- internal/audio/wav.go | 196 +++++++++++++++++++++----------- internal/audio/wav_test.go | 181 ++++++++++++++++++++++------- 4 files changed, 346 insertions(+), 131 deletions(-) diff --git a/internal/audio/resample.go b/internal/audio/resample.go index 7bc4b91..665cbc5 100644 --- a/internal/audio/resample.go +++ b/internal/audio/resample.go @@ -1,28 +1,43 @@ package audio +// ResampledSource adapts a WAVSource to a different sample rate using +// linear interpolation between adjacent samples. type ResampledSource struct { - src *WAVSource - ratio float64 - position float64 + src *WAVSource + ratio float64 // source rate / target rate + position float64 // fractional position in source frames } +// NewResampledSource wraps a WAV source with rate conversion. func NewResampledSource(src *WAVSource, targetSampleRate float64) *ResampledSource { - ratio := 1.0 - if src != nil && src.SampleRate > 0 && targetSampleRate > 0 { - ratio = float64(src.SampleRate) / targetSampleRate - } - return &ResampledSource{src: src, ratio: ratio} + ratio := 1.0 + if src != nil && src.SampleRate > 0 && targetSampleRate > 0 { + ratio = float64(src.SampleRate) / targetSampleRate + } + return &ResampledSource{src: src, ratio: ratio} } +// NextFrame returns the next interpolated stereo frame. func (s *ResampledSource) NextFrame() Frame { - if s.src == nil || len(s.src.frames) == 0 { - return NewFrame(0, 0) - } - idx := int(s.position) % len(s.src.frames) - frame := s.src.frames[idx] - s.position += s.ratio - for s.position >= float64(len(s.src.frames)) { - s.position -= float64(len(s.src.frames)) - } - return frame + if s.src == nil || len(s.src.frames) == 0 { + return NewFrame(0, 0) + } + + n := len(s.src.frames) + idx0 := int(s.position) % n + idx1 := (idx0 + 1) % n + frac := s.position - float64(int(s.position)) + + f0 := s.src.frames[idx0] + f1 := s.src.frames[idx1] + + l := float64(f0.L)*(1-frac) + float64(f1.L)*frac + r := float64(f0.R)*(1-frac) + float64(f1.R)*frac + + s.position += s.ratio + for s.position >= float64(n) { + s.position -= float64(n) + } + + return NewFrame(Sample(l), Sample(r)) } diff --git a/internal/audio/resample_test.go b/internal/audio/resample_test.go index c9798e1..6b95a70 100644 --- a/internal/audio/resample_test.go +++ b/internal/audio/resample_test.go @@ -1,13 +1,46 @@ package audio -import "testing" +import ( + "math" + "testing" +) func TestResampledSource(t *testing.T) { - src := &WAVSource{frames: []Frame{NewFrame(0.1, 0.1), NewFrame(0.2, 0.2)}, SampleRate: 48000} - rs := NewResampledSource(src, 96000) - a := rs.NextFrame() - b := rs.NextFrame() - if a == (Frame{}) || b == (Frame{}) { - t.Fatal("expected frames") - } + src := &WAVSource{frames: []Frame{NewFrame(0.1, 0.1), NewFrame(0.2, 0.2)}, SampleRate: 48000} + rs := NewResampledSource(src, 96000) + a := rs.NextFrame() + b := rs.NextFrame() + if a == (Frame{}) || b == (Frame{}) { + t.Fatal("expected non-zero frames") + } +} + +func TestResampledSourceInterpolation(t *testing.T) { + // 2 samples at 48k, target at 96k -> ratio=0.5, so we should + // get interpolated values between the two source frames. + src := &WAVSource{ + frames: []Frame{NewFrame(0, 0), NewFrame(1, 1)}, + SampleRate: 48000, + } + rs := NewResampledSource(src, 96000) + + // First sample: position=0.0, exact frame[0] -> (0,0) + f0 := rs.NextFrame() + if math.Abs(float64(f0.L)) > 1e-9 { + t.Fatalf("expected L=0 at pos 0, got %.6f", f0.L) + } + + // Second sample: position=0.5, interpolated -> (0.5, 0.5) + f1 := rs.NextFrame() + if math.Abs(float64(f1.L)-0.5) > 1e-9 { + t.Fatalf("expected L=0.5 at pos 0.5, got %.6f", f1.L) + } +} + +func TestResampledSourceNilSrc(t *testing.T) { + rs := NewResampledSource(nil, 48000) + f := rs.NextFrame() + if f.L != 0 || f.R != 0 { + t.Fatal("expected zero frame for nil source") + } } diff --git a/internal/audio/wav.go b/internal/audio/wav.go index 8725caf..7d8aecd 100644 --- a/internal/audio/wav.go +++ b/internal/audio/wav.go @@ -1,88 +1,152 @@ package audio import ( - "encoding/binary" - "fmt" - "io" - "os" + "encoding/binary" + "fmt" + "io" + "os" ) +// WAVSource loads a PCM WAV file into memory and provides frame-by-frame access. type WAVSource struct { - frames []Frame - index int - SampleRate int - Channels int + frames []Frame + index int + SampleRate int + Channels int } +// LoadWAVSource reads and decodes a WAV file. It properly scans for the "fmt " +// and "data" chunks, handling files with extra metadata chunks (LIST, INFO, +// bext, etc.) that appear between headers. func LoadWAVSource(path string) (*WAVSource, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() - header := make([]byte, 44) - if _, err := io.ReadFull(f, header); err != nil { - return nil, fmt.Errorf("read wav header: %w", err) - } - if string(header[0:4]) != "RIFF" || string(header[8:12]) != "WAVE" { - return nil, fmt.Errorf("unsupported wav header") - } + // Read RIFF header (12 bytes) + riffHeader := make([]byte, 12) + if _, err := io.ReadFull(f, riffHeader); err != nil { + return nil, fmt.Errorf("read riff header: %w", err) + } + if string(riffHeader[0:4]) != "RIFF" || string(riffHeader[8:12]) != "WAVE" { + return nil, fmt.Errorf("not a RIFF/WAVE file") + } - audioFormat := binary.LittleEndian.Uint16(header[20:22]) - channels := binary.LittleEndian.Uint16(header[22:24]) - sampleRate := binary.LittleEndian.Uint32(header[24:28]) - bitsPerSample := binary.LittleEndian.Uint16(header[34:36]) - dataSize := binary.LittleEndian.Uint32(header[40:44]) + var ( + audioFormat uint16 + channels uint16 + sampleRate uint32 + bitsPerSample uint16 + dataBytes []byte + fmtFound bool + dataFound bool + ) - if audioFormat != 1 { - return nil, fmt.Errorf("only PCM wav supported") - } - if bitsPerSample != 16 { - return nil, fmt.Errorf("only 16-bit PCM wav supported") - } - if channels != 1 && channels != 2 { - return nil, fmt.Errorf("only mono/stereo wav supported") - } - if sampleRate == 0 { - return nil, fmt.Errorf("invalid wav sample rate") - } + // Scan chunks + for { + var chunkID [4]byte + var chunkSize uint32 + if _, err := io.ReadFull(f, chunkID[:]); err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + break + } + return nil, fmt.Errorf("read chunk id: %w", err) + } + if err := binary.Read(f, binary.LittleEndian, &chunkSize); err != nil { + return nil, fmt.Errorf("read chunk size: %w", err) + } - raw := make([]byte, dataSize) - if _, err := io.ReadFull(f, raw); err != nil { - return nil, fmt.Errorf("read wav data: %w", err) - } + switch string(chunkID[:]) { + case "fmt ": + if chunkSize < 16 { + return nil, fmt.Errorf("fmt chunk too small: %d", chunkSize) + } + fmtData := make([]byte, chunkSize) + if _, err := io.ReadFull(f, fmtData); err != nil { + return nil, fmt.Errorf("read fmt chunk: %w", err) + } + audioFormat = binary.LittleEndian.Uint16(fmtData[0:2]) + channels = binary.LittleEndian.Uint16(fmtData[2:4]) + sampleRate = binary.LittleEndian.Uint32(fmtData[4:8]) + bitsPerSample = binary.LittleEndian.Uint16(fmtData[14:16]) + fmtFound = true - step := int(channels) * 2 - frames := make([]Frame, 0, len(raw)/step) - for i := 0; i+step <= len(raw); i += step { - l := pcm16ToSample(int16(binary.LittleEndian.Uint16(raw[i : i+2]))) - r := l - if channels == 2 { - r = pcm16ToSample(int16(binary.LittleEndian.Uint16(raw[i+2 : i+4]))) - } - frames = append(frames, NewFrame(l, r)) - } + case "data": + dataBytes = make([]byte, chunkSize) + if _, err := io.ReadFull(f, dataBytes); err != nil { + return nil, fmt.Errorf("read data chunk: %w", err) + } + dataFound = true - return &WAVSource{ - frames: frames, - SampleRate: int(sampleRate), - Channels: int(channels), - }, nil + default: + // Skip unknown chunks, respecting RIFF padding (chunks are word-aligned) + skip := int64(chunkSize) + if chunkSize%2 != 0 { + skip++ + } + if _, err := io.CopyN(io.Discard, f, skip); err != nil { + // Could be EOF if this is the last chunk + break + } + } + + if fmtFound && dataFound { + break + } + } + + if !fmtFound { + return nil, fmt.Errorf("no fmt chunk found") + } + if !dataFound { + return nil, fmt.Errorf("no data chunk found") + } + if audioFormat != 1 { + return nil, fmt.Errorf("only PCM wav supported (format=%d)", audioFormat) + } + if bitsPerSample != 16 { + return nil, fmt.Errorf("only 16-bit PCM wav supported (bits=%d)", bitsPerSample) + } + if channels != 1 && channels != 2 { + return nil, fmt.Errorf("only mono/stereo wav supported (channels=%d)", channels) + } + if sampleRate == 0 { + return nil, fmt.Errorf("invalid wav sample rate") + } + + step := int(channels) * 2 + frames := make([]Frame, 0, len(dataBytes)/step) + for i := 0; i+step <= len(dataBytes); i += step { + l := pcm16ToSample(int16(binary.LittleEndian.Uint16(dataBytes[i : i+2]))) + r := l + if channels == 2 { + r = pcm16ToSample(int16(binary.LittleEndian.Uint16(dataBytes[i+2 : i+4]))) + } + frames = append(frames, NewFrame(l, r)) + } + + return &WAVSource{ + frames: frames, + SampleRate: int(sampleRate), + Channels: int(channels), + }, nil } +// NextFrame returns the next audio frame, looping at the end. func (s *WAVSource) NextFrame() Frame { - if len(s.frames) == 0 { - return NewFrame(0, 0) - } - frame := s.frames[s.index] - s.index++ - if s.index >= len(s.frames) { - s.index = 0 - } - return frame + if len(s.frames) == 0 { + return NewFrame(0, 0) + } + frame := s.frames[s.index] + s.index++ + if s.index >= len(s.frames) { + s.index = 0 + } + return frame } func pcm16ToSample(v int16) Sample { - return Sample(float64(v) / 32768.0).Clamp() + return Sample(float64(v) / 32768.0).Clamp() } diff --git a/internal/audio/wav_test.go b/internal/audio/wav_test.go index 63e45fb..41bbbf0 100644 --- a/internal/audio/wav_test.go +++ b/internal/audio/wav_test.go @@ -1,52 +1,155 @@ package audio import ( - "os" - "path/filepath" - "testing" + "encoding/binary" + "os" + "path/filepath" + "testing" ) func TestPCM16ToSample(t *testing.T) { - if pcm16ToSample(32767) <= 0 { - t.Fatal("expected positive sample") - } - if pcm16ToSample(-32768) < -1.0 { - t.Fatal("expected clamped lower bound") - } + if pcm16ToSample(32767) <= 0 { + t.Fatal("expected positive sample") + } + if pcm16ToSample(-32768) < -1.0 { + t.Fatal("expected clamped lower bound") + } } func TestLoadWAVSource(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "test.wav") - wav := []byte{ - 'R','I','F','F', 52,0,0,0, 'W','A','V','E', - 'f','m','t',' ', 16,0,0,0, 1,0, 1,0, 0x80,0xbb,0x00,0x00, 0x00,0x77,0x01,0x00, 2,0, 16,0, - 'd','a','t','a', 8,0,0,0, - 0,0, 255,127, 0,128, 0,0, - } - if err := os.WriteFile(path, wav, 0o644); err != nil { - t.Fatalf("write wav: %v", err) - } - src, err := LoadWAVSource(path) - if err != nil { - t.Fatalf("LoadWAVSource failed: %v", err) - } - if src.SampleRate != 48000 { - t.Fatalf("unexpected sample rate: %d", src.SampleRate) - } - if src.Channels != 1 { - t.Fatalf("unexpected channels: %d", src.Channels) - } - _ = src.NextFrame() + dir := t.TempDir() + path := filepath.Join(dir, "test.wav") + wav := buildMinimalWAV(48000, 1, []int16{0, 32767, -32768, 0}) + if err := os.WriteFile(path, wav, 0o644); err != nil { + t.Fatalf("write wav: %v", err) + } + src, err := LoadWAVSource(path) + if err != nil { + t.Fatalf("LoadWAVSource failed: %v", err) + } + if src.SampleRate != 48000 { + t.Fatalf("unexpected sample rate: %d", src.SampleRate) + } + if src.Channels != 1 { + t.Fatalf("unexpected channels: %d", src.Channels) + } + _ = src.NextFrame() +} + +func TestLoadWAVWithExtraChunks(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "extra.wav") + + // Build a WAV with a LIST chunk between fmt and data + wav := buildWAVWithExtraChunks(44100, 1, []int16{100, -100}) + if err := os.WriteFile(path, wav, 0o644); err != nil { + t.Fatalf("write wav: %v", err) + } + src, err := LoadWAVSource(path) + if err != nil { + t.Fatalf("LoadWAVSource with extra chunks failed: %v", err) + } + if src.SampleRate != 44100 { + t.Fatalf("unexpected sample rate: %d", src.SampleRate) + } + if len(src.frames) != 2 { + t.Fatalf("expected 2 frames, got %d", len(src.frames)) + } } func TestRejectInvalidWAV(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "bad.wav") - if err := os.WriteFile(path, []byte("nope"), 0o644); err != nil { - t.Fatalf("write wav: %v", err) - } - if _, err := LoadWAVSource(path); err == nil { - t.Fatal("expected wav load error") - } + dir := t.TempDir() + path := filepath.Join(dir, "bad.wav") + if err := os.WriteFile(path, []byte("nope"), 0o644); err != nil { + t.Fatalf("write wav: %v", err) + } + if _, err := LoadWAVSource(path); err == nil { + t.Fatal("expected wav load error") + } +} + +func TestStereoWAV(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "stereo.wav") + wav := buildMinimalWAV(48000, 2, []int16{1000, -1000, 2000, -2000}) + if err := os.WriteFile(path, wav, 0o644); err != nil { + t.Fatalf("write wav: %v", err) + } + src, err := LoadWAVSource(path) + if err != nil { + t.Fatalf("LoadWAVSource stereo failed: %v", err) + } + if src.Channels != 2 { + t.Fatalf("expected 2 channels, got %d", src.Channels) + } + if len(src.frames) != 2 { + t.Fatalf("expected 2 frames, got %d", len(src.frames)) + } +} + +// -- helpers -- + +func buildMinimalWAV(sampleRate int, channels int, samples []int16) []byte { + dataSize := len(samples) * 2 + fileSize := 36 + dataSize + + buf := make([]byte, 0, 44+dataSize) + buf = append(buf, []byte("RIFF")...) + buf = binary.LittleEndian.AppendUint32(buf, uint32(fileSize)) + buf = append(buf, []byte("WAVE")...) + + // fmt chunk + buf = append(buf, []byte("fmt ")...) + buf = binary.LittleEndian.AppendUint32(buf, 16) + buf = binary.LittleEndian.AppendUint16(buf, 1) // PCM + buf = binary.LittleEndian.AppendUint16(buf, uint16(channels)) + buf = binary.LittleEndian.AppendUint32(buf, uint32(sampleRate)) + buf = binary.LittleEndian.AppendUint32(buf, uint32(sampleRate*channels*2)) // byte rate + buf = binary.LittleEndian.AppendUint16(buf, uint16(channels*2)) // block align + buf = binary.LittleEndian.AppendUint16(buf, 16) // bits per sample + + // data chunk + buf = append(buf, []byte("data")...) + buf = binary.LittleEndian.AppendUint32(buf, uint32(dataSize)) + for _, s := range samples { + buf = binary.LittleEndian.AppendUint16(buf, uint16(s)) + } + return buf +} + +func buildWAVWithExtraChunks(sampleRate int, channels int, samples []int16) []byte { + dataSize := len(samples) * 2 + // Add a fake LIST chunk of 12 bytes between fmt and data + listChunkData := []byte("INFOtest") // 8 bytes + listChunkSize := uint32(len(listChunkData)) + totalExtraChunk := 8 + len(listChunkData) // "LIST" + size + data + fileSize := 36 + totalExtraChunk + dataSize + + buf := make([]byte, 0, 44+totalExtraChunk+dataSize) + buf = append(buf, []byte("RIFF")...) + buf = binary.LittleEndian.AppendUint32(buf, uint32(fileSize)) + buf = append(buf, []byte("WAVE")...) + + // fmt chunk + buf = append(buf, []byte("fmt ")...) + buf = binary.LittleEndian.AppendUint32(buf, 16) + buf = binary.LittleEndian.AppendUint16(buf, 1) + buf = binary.LittleEndian.AppendUint16(buf, uint16(channels)) + buf = binary.LittleEndian.AppendUint32(buf, uint32(sampleRate)) + buf = binary.LittleEndian.AppendUint32(buf, uint32(sampleRate*channels*2)) + buf = binary.LittleEndian.AppendUint16(buf, uint16(channels*2)) + buf = binary.LittleEndian.AppendUint16(buf, 16) + + // LIST chunk (extra, should be skipped) + buf = append(buf, []byte("LIST")...) + buf = binary.LittleEndian.AppendUint32(buf, listChunkSize) + buf = append(buf, listChunkData...) + + // data chunk + buf = append(buf, []byte("data")...) + buf = binary.LittleEndian.AppendUint32(buf, uint32(dataSize)) + for _, s := range samples { + buf = binary.LittleEndian.AppendUint16(buf, uint16(s)) + } + return buf }