diff --git a/cmd/sdrd/dsp_loop.go b/cmd/sdrd/dsp_loop.go index 5980dd0..2ec918c 100644 --- a/cmd/sdrd/dsp_loop.go +++ b/cmd/sdrd/dsp_loop.go @@ -183,7 +183,7 @@ func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det * if i < len(snips) { snip = snips[i] } - cls := classifier.Classify(classifier.SignalInput{FirstBin: signals[i].FirstBin, LastBin: signals[i].LastBin, SNRDb: signals[i].SNRDb}, spectrum, cfg.SampleRate, cfg.FFTSize, snip) + cls := classifier.Classify(classifier.SignalInput{FirstBin: signals[i].FirstBin, LastBin: signals[i].LastBin, SNRDb: signals[i].SNRDb, CenterHz: signals[i].CenterHz}, spectrum, cfg.SampleRate, cfg.FFTSize, snip) signals[i].Class = cls } det.UpdateClasses(signals) diff --git a/internal/classifier/classifier.go b/internal/classifier/classifier.go index 1c32fc9..4c751a7 100644 --- a/internal/classifier/classifier.go +++ b/internal/classifier/classifier.go @@ -13,6 +13,6 @@ func Classify(input SignalInput, spectrum []float64, sampleRate int, fftSize int feat.InstFreqStd = instStd feat.CrestFactor = crest } - cls := RuleClassify(feat) + cls := RuleClassify(feat, input.CenterHz, input.SNRDb) return &cls } diff --git a/internal/classifier/classifier_test.go b/internal/classifier/classifier_test.go index afe036b..679c3f4 100644 --- a/internal/classifier/classifier_test.go +++ b/internal/classifier/classifier_test.go @@ -10,11 +10,11 @@ func TestRuleClassifyWFM(t *testing.T) { spectrum[i] = -100 } start := 100 - end := 350 // ~244 bins -> ~238 kHz + end := 350 for i := start; i <= end; i++ { spectrum[i] = -10 } - cls := Classify(SignalInput{FirstBin: start, LastBin: end}, spectrum, sampleRate, fftSize, nil) + cls := Classify(SignalInput{FirstBin: start, LastBin: end, CenterHz: 100e6, SNRDb: 30}, spectrum, sampleRate, fftSize, nil) if cls == nil || cls.ModType != ClassWFM { t.Fatalf("expected WFM, got %+v", cls) } @@ -38,3 +38,72 @@ func TestSoftmaxConfidence(t *testing.T) { t.Fatalf("empty should return 0.1: %f", c3) } } + +func TestClassifierProfiles(t *testing.T) { + tests := []struct { + name string + feat Features + centerHz float64 + snrDb float64 + wantBest SignalClass + }{ + { + name: "FM Broadcast 100 MHz", + feat: Features{BW3dB: 120000, SpectralFlat: 0.3, PeakToAvg: 1.5, Symmetry: 0.05, + RolloffLeft: 20, RolloffRight: 22, EnvVariance: 0.01, InstFreqStd: 0.8}, + centerHz: 100.0e6, snrDb: 40, + wantBest: ClassWFM, + }, + { + name: "FT8 auf 7.074 MHz", + feat: Features{BW3dB: 2500, SpectralFlat: 0.6, PeakToAvg: 1.8, Symmetry: 0.1, + EnvVariance: 0.03, InstFreqStd: 0.4}, + centerHz: 7.074e6, snrDb: 15, + wantBest: ClassFT8, + }, + { + name: "USB Voice 14.230 MHz", + feat: Features{BW3dB: 2800, SpectralFlat: 0.35, PeakToAvg: 3.5, Symmetry: 0.4, + RolloffLeft: 5, RolloffRight: 18, EnvVariance: 0.25, InstFreqStd: 0.6}, + centerHz: 14.230e6, snrDb: 25, + wantBest: ClassSSBUSB, + }, + { + name: "DMR auf 438 MHz", + feat: Features{BW3dB: 12500, SpectralFlat: 0.7, PeakToAvg: 1.2, Symmetry: 0.02, + RolloffLeft: 25, RolloffRight: 24, EnvVariance: 0.01, InstFreqStd: 0.35}, + centerHz: 438.5e6, snrDb: 20, + wantBest: ClassDMR, + }, + { + name: "Airband AM 121.5 MHz", + feat: Features{BW3dB: 7000, SpectralFlat: 0.25, PeakToAvg: 4.0, Symmetry: 0.05, + RolloffLeft: 15, RolloffRight: 16, EnvVariance: 0.2, InstFreqStd: 0.7}, + centerHz: 121.5e6, snrDb: 30, + wantBest: ClassAM, + }, + { + name: "CW auf 7.020 MHz", + feat: Features{BW3dB: 80, SpectralFlat: 0.15, PeakToAvg: 8.0, Symmetry: 0.0, + EnvVariance: 0.9, InstFreqStd: 0.05, CrestFactor: 3.5}, + centerHz: 7.020e6, snrDb: 20, + wantBest: ClassCW, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cls := RuleClassify(tt.feat, tt.centerHz, tt.snrDb) + if cls.ModType != tt.wantBest { + t.Errorf("got %s (conf=%.2f), want %s. Scores: %v", cls.ModType, cls.Confidence, tt.wantBest, cls.Scores) + } + }) + } +} + +func TestLowSNRConfidence(t *testing.T) { + feat := Features{BW3dB: 3000, SpectralFlat: 0.5, PeakToAvg: 1.5} + cls := RuleClassify(feat, 14.2e6, 5) + if cls.Confidence > 0.5 { + t.Errorf("low SNR should have low confidence: got %.2f", cls.Confidence) + } +} diff --git a/internal/classifier/context.go b/internal/classifier/context.go new file mode 100644 index 0000000..ec872a0 --- /dev/null +++ b/internal/classifier/context.go @@ -0,0 +1,89 @@ +package classifier + +import ( + _ "embed" + "encoding/json" + "math" +) + +//go:embed frequency_context.json +var frequencyContextJSON []byte + +type frequencyRange struct { + Name string `json:"name"` + StartMHz float64 `json:"start_mhz"` + EndMHz float64 `json:"end_mhz"` +} + +type frequencyContextConfig struct { + FT8MHz []float64 `json:"ft8_mhz"` + WSPRMHz []float64 `json:"wspr_mhz"` + Ranges []frequencyRange `json:"ranges"` +} + +var frequencyContext = loadFrequencyContext() + +func loadFrequencyContext() frequencyContextConfig { + var cfg frequencyContextConfig + if err := json.Unmarshal(frequencyContextJSON, &cfg); err != nil { + return frequencyContextConfig{} + } + return cfg +} + +func addFrequencyContext(add func(SignalClass, float64), centerHz float64, bw float64) { + mhz := centerHz / 1e6 + for _, r := range frequencyContext.Ranges { + if mhz < r.StartMHz || mhz > r.EndMHz { + continue + } + switch r.Name { + case "hf": + for _, f := range frequencyContext.FT8MHz { + if math.Abs(mhz-f) < 0.003 && bw >= 1500 && bw <= 3500 { + add(ClassFT8, 2.0) + break + } + } + for _, f := range frequencyContext.WSPRMHz { + if math.Abs(mhz-f) < 0.001 && bw >= 100 && bw <= 500 { + add(ClassWSPR, 2.0) + break + } + } + if bw < 500 { + add(ClassCW, 0.5) + } + if bw >= 2000 && bw <= 4000 { + if mhz < 10 { + add(ClassSSBLSB, 0.8) + } else { + add(ClassSSBUSB, 0.8) + } + } + case "vhf_2m": + if bw >= 6000 && bw <= 16000 { + add(ClassNFM, 0.5) + } + if bw >= 2000 && bw <= 4000 { + add(ClassSSBUSB, 0.5) + } + case "uhf_70cm": + if bw >= 6000 && bw <= 16000 { + add(ClassNFM, 0.3) + add(ClassDMR, 0.5) + add(ClassDStar, 0.3) + } + case "pmr446": + add(ClassNFM, 1.0) + case "broadcast_fm": + if bw >= 50000 { + add(ClassWFM, 1.5) + } + case "airband": + if bw >= 5000 && bw <= 10000 { + add(ClassAM, 1.5) + } + } + } +} diff --git a/internal/classifier/features_iq.go b/internal/classifier/features_iq.go index 274bf4e..c271175 100644 --- a/internal/classifier/features_iq.go +++ b/internal/classifier/features_iq.go @@ -19,13 +19,17 @@ func ExtractTemporalFeatures(iq []complex64) (envVar float64, zeroCross float64, } mean /= float64(len(iq)) rms = math.Sqrt(rms / float64(len(iq))) - // env variance + // normalized env variance (coefficient of variation squared) var sumVar float64 for _, v := range env { d := v - mean sumVar += d * d } - envVar = sumVar / float64(len(iq)) + if mean > 1e-12 { + envVar = (sumVar / float64(len(iq))) / (mean * mean) + } else { + envVar = 0 + } if rms > 0 { crest = maxFloat(env) / rms } diff --git a/internal/classifier/frequency_context.json b/internal/classifier/frequency_context.json new file mode 100644 index 0000000..71af5b0 --- /dev/null +++ b/internal/classifier/frequency_context.json @@ -0,0 +1,12 @@ +{ + "ft8_mhz": [1.84, 3.573, 5.357, 7.074, 10.136, 14.074, 18.1, 21.074, 24.915, 28.074], + "wspr_mhz": [1.8366, 3.5926, 7.0386, 10.1387, 14.0956, 18.1046, 21.0946, 24.9246, 28.1246], + "ranges": [ + {"name": "hf", "start_mhz": 1.8, "end_mhz": 30.0}, + {"name": "vhf_2m", "start_mhz": 144.0, "end_mhz": 148.0}, + {"name": "broadcast_fm", "start_mhz": 87.5, "end_mhz": 108.0}, + {"name": "airband", "start_mhz": 118.0, "end_mhz": 137.0}, + {"name": "uhf_70cm", "start_mhz": 430.0, "end_mhz": 440.0}, + {"name": "pmr446", "start_mhz": 446.0, "end_mhz": 446.2} + ] +} diff --git a/internal/classifier/rules.go b/internal/classifier/rules.go index 7e9b660..7b3226b 100644 --- a/internal/classifier/rules.go +++ b/internal/classifier/rules.go @@ -2,7 +2,7 @@ package classifier import "math" -func RuleClassify(feat Features) Classification { +func RuleClassify(feat Features, centerHz float64, snrDb float64) Classification { bw := feat.BW3dB flat := feat.SpectralFlat sym := feat.Symmetry @@ -10,10 +10,13 @@ func RuleClassify(feat Features) Classification { scores := map[SignalClass]float64{} add := func(c SignalClass, w float64) { - if w <= 0 { + if w == 0 { return } scores[c] += w + if scores[c] < 0 { + scores[c] = 0 + } } switch { @@ -23,7 +26,7 @@ func RuleClassify(feat Features) Classification { add(ClassWFM, 1.4) add(ClassNFM, 0.8) case bw >= 6e3 && bw < 25e3: - add(ClassNFM, 2.0) + add(ClassNFM, 1.2) case bw >= 3e3 && bw < 6e3: add(ClassSSBUSB, 0.6) add(ClassSSBLSB, 0.6) @@ -48,10 +51,42 @@ func RuleClassify(feat Features) Classification { } else if sym < -0.2 { add(ClassSSBLSB, 1.2) } - if feat.EnvVariance < 0.6 && feat.InstFreqStd < 0.7 && bw >= 2000 && bw < 3000 { + + rollAvg := (feat.RolloffLeft + feat.RolloffRight) / 2.0 + rollAsym := math.Abs(feat.RolloffLeft - feat.RolloffRight) + if rollAvg > 15 && rollAsym < 5 { + if bw >= 6000 { + add(ClassNFM, 0.4) + } + if bw >= 80000 { + add(ClassWFM, 0.4) + } + if bw >= 3000 && bw <= 10000 { + add(ClassAM, 0.3) + } + } + if rollAsym > 10 && bw >= 2000 && bw <= 4000 { + if feat.RolloffLeft > feat.RolloffRight { + add(ClassSSBLSB, 0.6) + } else { + add(ClassSSBUSB, 0.6) + } + } + + if feat.EnvVariance < 0.08 && bw >= 10000 && bw <= 14000 && flat > 0.55 { + add(ClassDMR, 1.5) + } + if feat.EnvVariance < 0.08 && bw >= 5000 && bw <= 8000 && flat > 0.55 { + add(ClassDStar, 1.3) + } + if feat.EnvVariance < 0.03 && bw >= 5000 && bw <= 16000 { + add(ClassNFM, -0.5) + } + + if feat.EnvVariance < 0.08 && feat.InstFreqStd < 0.7 && bw >= 2000 && bw < 3000 { add(ClassFT8, 1.4) } - if feat.EnvVariance < 0.4 && feat.InstFreqStd < 0.5 && bw >= 150 && bw < 500 { + if feat.EnvVariance < 0.05 && feat.InstFreqStd < 0.5 && bw >= 150 && bw < 500 { add(ClassWSPR, 1.3) } if feat.InstFreqStd > 0.9 { @@ -65,10 +100,12 @@ func RuleClassify(feat Features) Classification { if flat > 0.85 && bw > 2e3 { add(ClassNoise, 1.0) } - if feat.InstFreqStd < 0.5 && feat.EnvVariance < 0.3 && bw >= 6e3 && bw < 25e3 { + if feat.EnvVariance < 0.08 && feat.InstFreqStd < 0.5 && bw >= 6e3 && bw < 25e3 { add(ClassDMR, 0.7) } + addFrequencyContext(add, centerHz, bw) + best, _, second, _ := top2(scores) if best == "" { best = ClassUnknown @@ -79,10 +116,14 @@ func RuleClassify(feat Features) Classification { conf := softmaxConfidence(scores, best) if best == ClassNFM || best == ClassWFM { - conf = conf * (0.8 + 0.2*clamp01(1-flat)) + conf *= 0.8 + 0.2*clamp01(1-flat) } if best == ClassAM { - conf = conf * (0.7 + 0.3*clamp01(p2a/6.0)) + conf *= 0.7 + 0.3*clamp01(p2a/6.0) + } + if snrDb < 20 { + snrFactor := clamp01((snrDb - 3) / 17.0) + conf *= 0.3 + 0.7*snrFactor } if math.IsNaN(conf) || conf <= 0 { conf = 0.1 diff --git a/internal/classifier/types.go b/internal/classifier/types.go index 379e3a6..5cb0b72 100644 --- a/internal/classifier/types.go +++ b/internal/classifier/types.go @@ -52,4 +52,5 @@ type SignalInput struct { FirstBin int LastBin int SNRDb float64 + CenterHz float64 }