From 2e2e83844169d875246a1a6c199a26c04aedf4fe Mon Sep 17 00:00:00 2001 From: Jan Date: Tue, 7 Apr 2026 16:08:56 +0200 Subject: [PATCH] ingest: harden icecast decoder fallback handling --- internal/ingest/adapters/icecast/source.go | 24 ++- .../ingest/adapters/icecast/source_test.go | 142 ++++++++++++++++++ internal/ingest/decoder/decoder_test.go | 12 ++ 3 files changed, 177 insertions(+), 1 deletion(-) diff --git a/internal/ingest/adapters/icecast/source.go b/internal/ingest/adapters/icecast/source.go index 0106eff..3601d80 100644 --- a/internal/ingest/adapters/icecast/source.go +++ b/internal/ingest/adapters/icecast/source.go @@ -1,6 +1,7 @@ package icecast import ( + "bytes" "context" "errors" "fmt" @@ -245,11 +246,15 @@ func (s *Source) decodeWithPreference(ctx context.Context, stream io.Reader, met // only when native selection/decode reports "unsupported". native, err := s.decReg.SelectByContentType(meta.ContentType) if err == nil { - if err := native.DecodeStream(ctx, stream, meta, s.emitChunk); err == nil { + captured := &capturingReader{r: stream} + if err := native.DecodeStream(ctx, captured, meta, s.emitChunk); err == nil { return nil } else if !errors.Is(err, decoder.ErrUnsupported) { return err } + // Native decode can consume stream bytes before returning "unsupported". + // Reconstruct a full reader for fallback: consumed prefix + remaining stream. + stream = io.MultiReader(bytes.NewReader(captured.Bytes()), stream) } else if !errors.Is(err, decoder.ErrUnsupported) { return fmt.Errorf("icecast decoder select: %w", err) } @@ -259,6 +264,23 @@ func (s *Source) decodeWithPreference(ctx context.Context, stream io.Reader, met } } +type capturingReader struct { + r io.Reader + buf bytes.Buffer +} + +func (r *capturingReader) Read(p []byte) (int, error) { + n, err := r.r.Read(p) + if n > 0 { + _, _ = r.buf.Write(p[:n]) + } + return n, err +} + +func (r *capturingReader) Bytes() []byte { + return r.buf.Bytes() +} + func (s *Source) decodeNamed(ctx context.Context, name string, stream io.Reader, meta decoder.StreamMeta) error { dec, err := s.decReg.Create(name) if err != nil { diff --git a/internal/ingest/adapters/icecast/source_test.go b/internal/ingest/adapters/icecast/source_test.go index 84b4572..8171ebe 100644 --- a/internal/ingest/adapters/icecast/source_test.go +++ b/internal/ingest/adapters/icecast/source_test.go @@ -24,6 +24,38 @@ func (d *testDecoder) DecodeStream(_ context.Context, _ io.Reader, _ decoder.Str return d.err } +type consumingUnsupportedDecoder struct { + n int + called int +} + +func (d *consumingUnsupportedDecoder) Name() string { return "native-consuming-unsupported" } + +func (d *consumingUnsupportedDecoder) DecodeStream(_ context.Context, r io.Reader, _ decoder.StreamMeta, _ func(ingest.PCMChunk) error) error { + d.called++ + buf := make([]byte, d.n) + _, _ = io.ReadFull(r, buf) + return decoder.ErrUnsupported +} + +type captureStreamDecoder struct { + name string + called int + payload []byte +} + +func (d *captureStreamDecoder) Name() string { return d.name } + +func (d *captureStreamDecoder) DecodeStream(_ context.Context, r io.Reader, _ decoder.StreamMeta, _ func(ingest.PCMChunk) error) error { + d.called++ + data, err := io.ReadAll(r) + if err != nil { + return err + } + d.payload = data + return nil +} + func TestDecodeWithPreferenceAutoFallsBackFromNativeUnsupported(t *testing.T) { native := &testDecoder{name: "native", err: decoder.ErrUnsupported} fallback := &testDecoder{name: "ffmpeg"} @@ -156,6 +188,116 @@ func TestDecodeWithPreferenceAutoUsesOggNativeForOggContentType(t *testing.T) { } } +func TestDecodeWithPreferenceAutoUsesMP3NativeForMPEGContentType(t *testing.T) { + mp3Native := &testDecoder{name: "mp3"} + fallback := &testDecoder{name: "ffmpeg"} + + reg := decoder.NewRegistry() + reg.Register("mp3", func() decoder.Decoder { return mp3Native }) + reg.Register("ffmpeg", func() decoder.Decoder { return fallback }) + + src := New("ice-test", "http://example", nil, ReconnectConfig{}, + WithDecoderRegistry(reg), + WithDecoderPreference("auto"), + ) + + err := src.decodeWithPreference(context.Background(), bytes.NewReader(nil), decoder.StreamMeta{ + ContentType: "audio/mpeg; charset=utf-8", + SourceID: "ice-test", + }) + if err != nil { + t.Fatalf("decode: %v", err) + } + if mp3Native.called != 1 { + t.Fatalf("mp3 native decoder called %d times", mp3Native.called) + } + if fallback.called != 0 { + t.Fatalf("fallback should not be called, got %d", fallback.called) + } +} + +func TestDecodeWithPreferenceAutoNativeErrorDoesNotFallback(t *testing.T) { + nativeErr := errors.New("native hard failure") + mp3Native := &testDecoder{name: "mp3", err: nativeErr} + fallback := &testDecoder{name: "ffmpeg"} + + reg := decoder.NewRegistry() + reg.Register("mp3", func() decoder.Decoder { return mp3Native }) + reg.Register("ffmpeg", func() decoder.Decoder { return fallback }) + + src := New("ice-test", "http://example", nil, ReconnectConfig{}, + WithDecoderRegistry(reg), + WithDecoderPreference("auto"), + ) + + err := src.decodeWithPreference(context.Background(), bytes.NewReader(nil), decoder.StreamMeta{ + ContentType: "audio/mpeg", + SourceID: "ice-test", + }) + if !errors.Is(err, nativeErr) { + t.Fatalf("expected native error, got %v", err) + } + if fallback.called != 0 { + t.Fatalf("fallback should not be called on native hard error, got %d", fallback.called) + } +} + +func TestDecodeWithPreferenceAutoFallbackSeesFullStreamAfterNativeConsumesPrefix(t *testing.T) { + const consumed = 4 + input := []byte("0123456789abcdef") + + native := &consumingUnsupportedDecoder{n: consumed} + fallback := &captureStreamDecoder{name: "ffmpeg"} + + reg := decoder.NewRegistry() + reg.Register("mp3", func() decoder.Decoder { return native }) + reg.Register("ffmpeg", func() decoder.Decoder { return fallback }) + + src := New("ice-test", "http://example", nil, ReconnectConfig{}, + WithDecoderRegistry(reg), + WithDecoderPreference("auto"), + ) + + err := src.decodeWithPreference(context.Background(), bytes.NewReader(input), decoder.StreamMeta{ + ContentType: "audio/mpeg", + SourceID: "ice-test", + }) + if err != nil { + t.Fatalf("decode: %v", err) + } + if native.called != 1 { + t.Fatalf("native called %d times", native.called) + } + if fallback.called != 1 { + t.Fatalf("fallback called %d times", fallback.called) + } + if !bytes.Equal(fallback.payload, input) { + t.Fatalf("fallback payload mismatch: got %q want %q", string(fallback.payload), string(input)) + } +} + +func TestDecodeWithPreferenceNativeUnsupportedContentTypeFailsWithoutFallback(t *testing.T) { + fallback := &testDecoder{name: "ffmpeg"} + reg := decoder.NewRegistry() + reg.Register("ffmpeg", func() decoder.Decoder { return fallback }) + + src := New("ice-test", "http://example", nil, ReconnectConfig{}, + WithDecoderRegistry(reg), + WithDecoderPreference("native"), + ) + + err := src.decodeWithPreference(context.Background(), bytes.NewReader(nil), decoder.StreamMeta{ + ContentType: "application/octet-stream", + SourceID: "ice-test", + }) + if err == nil { + t.Fatal("expected native-mode select error for unsupported content-type") + } + if fallback.called != 0 { + t.Fatalf("fallback should not be called in native mode, got %d", fallback.called) + } +} + func TestWithDecoderPreferenceFallbackAliasNormalizesToFFmpeg(t *testing.T) { src := New("ice-test", "http://example", nil, ReconnectConfig{}, WithDecoderPreference("fallback")) if got := src.Descriptor().Codec; got != "ffmpeg" { diff --git a/internal/ingest/decoder/decoder_test.go b/internal/ingest/decoder/decoder_test.go index b304d79..7e724bc 100644 --- a/internal/ingest/decoder/decoder_test.go +++ b/internal/ingest/decoder/decoder_test.go @@ -2,6 +2,7 @@ package decoder import ( "context" + "errors" "io" "testing" @@ -27,8 +28,11 @@ func TestRegistrySelectByContentType(t *testing.T) { want string }{ {"audio/mpeg", "mp3"}, + {"audio/mpeg; charset=utf-8", "mp3"}, {"application/ogg", "ogg"}, + {"audio/ogg;codecs=vorbis", "ogg"}, {"audio/aac", "aac"}, + {"audio/aacp", "aac"}, } for _, tt := range tests { dec, err := r.SelectByContentType(tt.ct) @@ -40,3 +44,11 @@ func TestRegistrySelectByContentType(t *testing.T) { } } } + +func TestRegistrySelectByContentTypeUnsupported(t *testing.T) { + r := NewRegistry() + _, err := r.SelectByContentType("application/octet-stream") + if !errors.Is(err, ErrUnsupported) { + t.Fatalf("expected ErrUnsupported, got %v", err) + } +}