Parcourir la source

feat: improve WAV ingest robustness and add linear interpolation resampling

tags/v0.4.0-pre
Jan Svabenik il y a 1 mois
Parent
révision
656784097c
4 fichiers modifiés avec 346 ajouts et 131 suppressions
  1. +33
    -18
      internal/audio/resample.go
  2. +41
    -8
      internal/audio/resample_test.go
  3. +130
    -66
      internal/audio/wav.go
  4. +142
    -39
      internal/audio/wav_test.go

+ 33
- 18
internal/audio/resample.go Voir le fichier

@@ -1,28 +1,43 @@
package audio

// ResampledSource adapts a WAVSource to a different sample rate using
// linear interpolation between adjacent samples.
type ResampledSource struct {
src *WAVSource
ratio float64
position float64
src *WAVSource
ratio float64 // source rate / target rate
position float64 // fractional position in source frames
}

// NewResampledSource wraps a WAV source with rate conversion.
func NewResampledSource(src *WAVSource, targetSampleRate float64) *ResampledSource {
ratio := 1.0
if src != nil && src.SampleRate > 0 && targetSampleRate > 0 {
ratio = float64(src.SampleRate) / targetSampleRate
}
return &ResampledSource{src: src, ratio: ratio}
ratio := 1.0
if src != nil && src.SampleRate > 0 && targetSampleRate > 0 {
ratio = float64(src.SampleRate) / targetSampleRate
}
return &ResampledSource{src: src, ratio: ratio}
}

// NextFrame returns the next interpolated stereo frame.
func (s *ResampledSource) NextFrame() Frame {
if s.src == nil || len(s.src.frames) == 0 {
return NewFrame(0, 0)
}
idx := int(s.position) % len(s.src.frames)
frame := s.src.frames[idx]
s.position += s.ratio
for s.position >= float64(len(s.src.frames)) {
s.position -= float64(len(s.src.frames))
}
return frame
if s.src == nil || len(s.src.frames) == 0 {
return NewFrame(0, 0)
}

n := len(s.src.frames)
idx0 := int(s.position) % n
idx1 := (idx0 + 1) % n
frac := s.position - float64(int(s.position))

f0 := s.src.frames[idx0]
f1 := s.src.frames[idx1]

l := float64(f0.L)*(1-frac) + float64(f1.L)*frac
r := float64(f0.R)*(1-frac) + float64(f1.R)*frac

s.position += s.ratio
for s.position >= float64(n) {
s.position -= float64(n)
}

return NewFrame(Sample(l), Sample(r))
}

+ 41
- 8
internal/audio/resample_test.go Voir le fichier

@@ -1,13 +1,46 @@
package audio

import "testing"
import (
"math"
"testing"
)

func TestResampledSource(t *testing.T) {
src := &WAVSource{frames: []Frame{NewFrame(0.1, 0.1), NewFrame(0.2, 0.2)}, SampleRate: 48000}
rs := NewResampledSource(src, 96000)
a := rs.NextFrame()
b := rs.NextFrame()
if a == (Frame{}) || b == (Frame{}) {
t.Fatal("expected frames")
}
src := &WAVSource{frames: []Frame{NewFrame(0.1, 0.1), NewFrame(0.2, 0.2)}, SampleRate: 48000}
rs := NewResampledSource(src, 96000)
a := rs.NextFrame()
b := rs.NextFrame()
if a == (Frame{}) || b == (Frame{}) {
t.Fatal("expected non-zero frames")
}
}

func TestResampledSourceInterpolation(t *testing.T) {
// 2 samples at 48k, target at 96k -> ratio=0.5, so we should
// get interpolated values between the two source frames.
src := &WAVSource{
frames: []Frame{NewFrame(0, 0), NewFrame(1, 1)},
SampleRate: 48000,
}
rs := NewResampledSource(src, 96000)

// First sample: position=0.0, exact frame[0] -> (0,0)
f0 := rs.NextFrame()
if math.Abs(float64(f0.L)) > 1e-9 {
t.Fatalf("expected L=0 at pos 0, got %.6f", f0.L)
}

// Second sample: position=0.5, interpolated -> (0.5, 0.5)
f1 := rs.NextFrame()
if math.Abs(float64(f1.L)-0.5) > 1e-9 {
t.Fatalf("expected L=0.5 at pos 0.5, got %.6f", f1.L)
}
}

func TestResampledSourceNilSrc(t *testing.T) {
rs := NewResampledSource(nil, 48000)
f := rs.NextFrame()
if f.L != 0 || f.R != 0 {
t.Fatal("expected zero frame for nil source")
}
}

+ 130
- 66
internal/audio/wav.go Voir le fichier

@@ -1,88 +1,152 @@
package audio

import (
"encoding/binary"
"fmt"
"io"
"os"
"encoding/binary"
"fmt"
"io"
"os"
)

// WAVSource loads a PCM WAV file into memory and provides frame-by-frame access.
type WAVSource struct {
frames []Frame
index int
SampleRate int
Channels int
frames []Frame
index int
SampleRate int
Channels int
}

// LoadWAVSource reads and decodes a WAV file. It properly scans for the "fmt "
// and "data" chunks, handling files with extra metadata chunks (LIST, INFO,
// bext, etc.) that appear between headers.
func LoadWAVSource(path string) (*WAVSource, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()

header := make([]byte, 44)
if _, err := io.ReadFull(f, header); err != nil {
return nil, fmt.Errorf("read wav header: %w", err)
}
if string(header[0:4]) != "RIFF" || string(header[8:12]) != "WAVE" {
return nil, fmt.Errorf("unsupported wav header")
}
// Read RIFF header (12 bytes)
riffHeader := make([]byte, 12)
if _, err := io.ReadFull(f, riffHeader); err != nil {
return nil, fmt.Errorf("read riff header: %w", err)
}
if string(riffHeader[0:4]) != "RIFF" || string(riffHeader[8:12]) != "WAVE" {
return nil, fmt.Errorf("not a RIFF/WAVE file")
}

audioFormat := binary.LittleEndian.Uint16(header[20:22])
channels := binary.LittleEndian.Uint16(header[22:24])
sampleRate := binary.LittleEndian.Uint32(header[24:28])
bitsPerSample := binary.LittleEndian.Uint16(header[34:36])
dataSize := binary.LittleEndian.Uint32(header[40:44])
var (
audioFormat uint16
channels uint16
sampleRate uint32
bitsPerSample uint16
dataBytes []byte
fmtFound bool
dataFound bool
)

if audioFormat != 1 {
return nil, fmt.Errorf("only PCM wav supported")
}
if bitsPerSample != 16 {
return nil, fmt.Errorf("only 16-bit PCM wav supported")
}
if channels != 1 && channels != 2 {
return nil, fmt.Errorf("only mono/stereo wav supported")
}
if sampleRate == 0 {
return nil, fmt.Errorf("invalid wav sample rate")
}
// Scan chunks
for {
var chunkID [4]byte
var chunkSize uint32
if _, err := io.ReadFull(f, chunkID[:]); err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
break
}
return nil, fmt.Errorf("read chunk id: %w", err)
}
if err := binary.Read(f, binary.LittleEndian, &chunkSize); err != nil {
return nil, fmt.Errorf("read chunk size: %w", err)
}

raw := make([]byte, dataSize)
if _, err := io.ReadFull(f, raw); err != nil {
return nil, fmt.Errorf("read wav data: %w", err)
}
switch string(chunkID[:]) {
case "fmt ":
if chunkSize < 16 {
return nil, fmt.Errorf("fmt chunk too small: %d", chunkSize)
}
fmtData := make([]byte, chunkSize)
if _, err := io.ReadFull(f, fmtData); err != nil {
return nil, fmt.Errorf("read fmt chunk: %w", err)
}
audioFormat = binary.LittleEndian.Uint16(fmtData[0:2])
channels = binary.LittleEndian.Uint16(fmtData[2:4])
sampleRate = binary.LittleEndian.Uint32(fmtData[4:8])
bitsPerSample = binary.LittleEndian.Uint16(fmtData[14:16])
fmtFound = true

step := int(channels) * 2
frames := make([]Frame, 0, len(raw)/step)
for i := 0; i+step <= len(raw); i += step {
l := pcm16ToSample(int16(binary.LittleEndian.Uint16(raw[i : i+2])))
r := l
if channels == 2 {
r = pcm16ToSample(int16(binary.LittleEndian.Uint16(raw[i+2 : i+4])))
}
frames = append(frames, NewFrame(l, r))
}
case "data":
dataBytes = make([]byte, chunkSize)
if _, err := io.ReadFull(f, dataBytes); err != nil {
return nil, fmt.Errorf("read data chunk: %w", err)
}
dataFound = true

return &WAVSource{
frames: frames,
SampleRate: int(sampleRate),
Channels: int(channels),
}, nil
default:
// Skip unknown chunks, respecting RIFF padding (chunks are word-aligned)
skip := int64(chunkSize)
if chunkSize%2 != 0 {
skip++
}
if _, err := io.CopyN(io.Discard, f, skip); err != nil {
// Could be EOF if this is the last chunk
break
}
}

if fmtFound && dataFound {
break
}
}

if !fmtFound {
return nil, fmt.Errorf("no fmt chunk found")
}
if !dataFound {
return nil, fmt.Errorf("no data chunk found")
}
if audioFormat != 1 {
return nil, fmt.Errorf("only PCM wav supported (format=%d)", audioFormat)
}
if bitsPerSample != 16 {
return nil, fmt.Errorf("only 16-bit PCM wav supported (bits=%d)", bitsPerSample)
}
if channels != 1 && channels != 2 {
return nil, fmt.Errorf("only mono/stereo wav supported (channels=%d)", channels)
}
if sampleRate == 0 {
return nil, fmt.Errorf("invalid wav sample rate")
}

step := int(channels) * 2
frames := make([]Frame, 0, len(dataBytes)/step)
for i := 0; i+step <= len(dataBytes); i += step {
l := pcm16ToSample(int16(binary.LittleEndian.Uint16(dataBytes[i : i+2])))
r := l
if channels == 2 {
r = pcm16ToSample(int16(binary.LittleEndian.Uint16(dataBytes[i+2 : i+4])))
}
frames = append(frames, NewFrame(l, r))
}

return &WAVSource{
frames: frames,
SampleRate: int(sampleRate),
Channels: int(channels),
}, nil
}

// NextFrame returns the next audio frame, looping at the end.
func (s *WAVSource) NextFrame() Frame {
if len(s.frames) == 0 {
return NewFrame(0, 0)
}
frame := s.frames[s.index]
s.index++
if s.index >= len(s.frames) {
s.index = 0
}
return frame
if len(s.frames) == 0 {
return NewFrame(0, 0)
}
frame := s.frames[s.index]
s.index++
if s.index >= len(s.frames) {
s.index = 0
}
return frame
}

func pcm16ToSample(v int16) Sample {
return Sample(float64(v) / 32768.0).Clamp()
return Sample(float64(v) / 32768.0).Clamp()
}

+ 142
- 39
internal/audio/wav_test.go Voir le fichier

@@ -1,52 +1,155 @@
package audio

import (
"os"
"path/filepath"
"testing"
"encoding/binary"
"os"
"path/filepath"
"testing"
)

func TestPCM16ToSample(t *testing.T) {
if pcm16ToSample(32767) <= 0 {
t.Fatal("expected positive sample")
}
if pcm16ToSample(-32768) < -1.0 {
t.Fatal("expected clamped lower bound")
}
if pcm16ToSample(32767) <= 0 {
t.Fatal("expected positive sample")
}
if pcm16ToSample(-32768) < -1.0 {
t.Fatal("expected clamped lower bound")
}
}

func TestLoadWAVSource(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.wav")
wav := []byte{
'R','I','F','F', 52,0,0,0, 'W','A','V','E',
'f','m','t',' ', 16,0,0,0, 1,0, 1,0, 0x80,0xbb,0x00,0x00, 0x00,0x77,0x01,0x00, 2,0, 16,0,
'd','a','t','a', 8,0,0,0,
0,0, 255,127, 0,128, 0,0,
}
if err := os.WriteFile(path, wav, 0o644); err != nil {
t.Fatalf("write wav: %v", err)
}
src, err := LoadWAVSource(path)
if err != nil {
t.Fatalf("LoadWAVSource failed: %v", err)
}
if src.SampleRate != 48000 {
t.Fatalf("unexpected sample rate: %d", src.SampleRate)
}
if src.Channels != 1 {
t.Fatalf("unexpected channels: %d", src.Channels)
}
_ = src.NextFrame()
dir := t.TempDir()
path := filepath.Join(dir, "test.wav")
wav := buildMinimalWAV(48000, 1, []int16{0, 32767, -32768, 0})
if err := os.WriteFile(path, wav, 0o644); err != nil {
t.Fatalf("write wav: %v", err)
}
src, err := LoadWAVSource(path)
if err != nil {
t.Fatalf("LoadWAVSource failed: %v", err)
}
if src.SampleRate != 48000 {
t.Fatalf("unexpected sample rate: %d", src.SampleRate)
}
if src.Channels != 1 {
t.Fatalf("unexpected channels: %d", src.Channels)
}
_ = src.NextFrame()
}

func TestLoadWAVWithExtraChunks(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "extra.wav")

// Build a WAV with a LIST chunk between fmt and data
wav := buildWAVWithExtraChunks(44100, 1, []int16{100, -100})
if err := os.WriteFile(path, wav, 0o644); err != nil {
t.Fatalf("write wav: %v", err)
}
src, err := LoadWAVSource(path)
if err != nil {
t.Fatalf("LoadWAVSource with extra chunks failed: %v", err)
}
if src.SampleRate != 44100 {
t.Fatalf("unexpected sample rate: %d", src.SampleRate)
}
if len(src.frames) != 2 {
t.Fatalf("expected 2 frames, got %d", len(src.frames))
}
}

func TestRejectInvalidWAV(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "bad.wav")
if err := os.WriteFile(path, []byte("nope"), 0o644); err != nil {
t.Fatalf("write wav: %v", err)
}
if _, err := LoadWAVSource(path); err == nil {
t.Fatal("expected wav load error")
}
dir := t.TempDir()
path := filepath.Join(dir, "bad.wav")
if err := os.WriteFile(path, []byte("nope"), 0o644); err != nil {
t.Fatalf("write wav: %v", err)
}
if _, err := LoadWAVSource(path); err == nil {
t.Fatal("expected wav load error")
}
}

func TestStereoWAV(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "stereo.wav")
wav := buildMinimalWAV(48000, 2, []int16{1000, -1000, 2000, -2000})
if err := os.WriteFile(path, wav, 0o644); err != nil {
t.Fatalf("write wav: %v", err)
}
src, err := LoadWAVSource(path)
if err != nil {
t.Fatalf("LoadWAVSource stereo failed: %v", err)
}
if src.Channels != 2 {
t.Fatalf("expected 2 channels, got %d", src.Channels)
}
if len(src.frames) != 2 {
t.Fatalf("expected 2 frames, got %d", len(src.frames))
}
}

// -- helpers --

func buildMinimalWAV(sampleRate int, channels int, samples []int16) []byte {
dataSize := len(samples) * 2
fileSize := 36 + dataSize

buf := make([]byte, 0, 44+dataSize)
buf = append(buf, []byte("RIFF")...)
buf = binary.LittleEndian.AppendUint32(buf, uint32(fileSize))
buf = append(buf, []byte("WAVE")...)

// fmt chunk
buf = append(buf, []byte("fmt ")...)
buf = binary.LittleEndian.AppendUint32(buf, 16)
buf = binary.LittleEndian.AppendUint16(buf, 1) // PCM
buf = binary.LittleEndian.AppendUint16(buf, uint16(channels))
buf = binary.LittleEndian.AppendUint32(buf, uint32(sampleRate))
buf = binary.LittleEndian.AppendUint32(buf, uint32(sampleRate*channels*2)) // byte rate
buf = binary.LittleEndian.AppendUint16(buf, uint16(channels*2)) // block align
buf = binary.LittleEndian.AppendUint16(buf, 16) // bits per sample

// data chunk
buf = append(buf, []byte("data")...)
buf = binary.LittleEndian.AppendUint32(buf, uint32(dataSize))
for _, s := range samples {
buf = binary.LittleEndian.AppendUint16(buf, uint16(s))
}
return buf
}

func buildWAVWithExtraChunks(sampleRate int, channels int, samples []int16) []byte {
dataSize := len(samples) * 2
// Add a fake LIST chunk of 12 bytes between fmt and data
listChunkData := []byte("INFOtest") // 8 bytes
listChunkSize := uint32(len(listChunkData))
totalExtraChunk := 8 + len(listChunkData) // "LIST" + size + data
fileSize := 36 + totalExtraChunk + dataSize

buf := make([]byte, 0, 44+totalExtraChunk+dataSize)
buf = append(buf, []byte("RIFF")...)
buf = binary.LittleEndian.AppendUint32(buf, uint32(fileSize))
buf = append(buf, []byte("WAVE")...)

// fmt chunk
buf = append(buf, []byte("fmt ")...)
buf = binary.LittleEndian.AppendUint32(buf, 16)
buf = binary.LittleEndian.AppendUint16(buf, 1)
buf = binary.LittleEndian.AppendUint16(buf, uint16(channels))
buf = binary.LittleEndian.AppendUint32(buf, uint32(sampleRate))
buf = binary.LittleEndian.AppendUint32(buf, uint32(sampleRate*channels*2))
buf = binary.LittleEndian.AppendUint16(buf, uint16(channels*2))
buf = binary.LittleEndian.AppendUint16(buf, 16)

// LIST chunk (extra, should be skipped)
buf = append(buf, []byte("LIST")...)
buf = binary.LittleEndian.AppendUint32(buf, listChunkSize)
buf = append(buf, listChunkData...)

// data chunk
buf = append(buf, []byte("data")...)
buf = binary.LittleEndian.AppendUint32(buf, uint32(dataSize))
for _, s := range samples {
buf = binary.LittleEndian.AppendUint16(buf, uint16(s))
}
return buf
}

Chargement…
Annuler
Enregistrer