diff --git a/cmd/sdrd/dsp_loop.go b/cmd/sdrd/dsp_loop.go index 2ec918c..7756795 100644 --- a/cmd/sdrd/dsp_loop.go +++ b/cmd/sdrd/dsp_loop.go @@ -183,7 +183,7 @@ func runDSP(ctx context.Context, srcMgr *sourceManager, cfg config.Config, det * if i < len(snips) { snip = snips[i] } - cls := classifier.Classify(classifier.SignalInput{FirstBin: signals[i].FirstBin, LastBin: signals[i].LastBin, SNRDb: signals[i].SNRDb, CenterHz: signals[i].CenterHz}, spectrum, cfg.SampleRate, cfg.FFTSize, snip) + cls := classifier.Classify(classifier.SignalInput{FirstBin: signals[i].FirstBin, LastBin: signals[i].LastBin, SNRDb: signals[i].SNRDb, CenterHz: signals[i].CenterHz}, spectrum, cfg.SampleRate, cfg.FFTSize, snip, classifier.ClassifierMode(cfg.ClassifierMode)) signals[i].Class = cls } det.UpdateClasses(signals) diff --git a/internal/classifier/classifier.go b/internal/classifier/classifier.go index 4c751a7..e2c3148 100644 --- a/internal/classifier/classifier.go +++ b/internal/classifier/classifier.go @@ -1,7 +1,14 @@ package classifier -// Classify builds features and applies the rule-based classifier. -func Classify(input SignalInput, spectrum []float64, sampleRate int, fftSize int, iq []complex64) *Classification { +type ClassifierMode string + +const ( + ModeRule ClassifierMode = "rule" + ModeMath ClassifierMode = "math" + ModeCombined ClassifierMode = "combined" +) + +func Classify(input SignalInput, spectrum []float64, sampleRate int, fftSize int, iq []complex64, mode ClassifierMode) *Classification { if len(spectrum) == 0 || input.FirstBin < 0 || input.LastBin < 0 { return nil } @@ -13,6 +20,25 @@ func Classify(input SignalInput, spectrum []float64, sampleRate int, fftSize int feat.InstFreqStd = instStd feat.CrestFactor = crest } - cls := RuleClassify(feat, input.CenterHz, input.SNRDb) + var cls Classification + switch mode { + case ModeMath: + if len(iq) > 0 { + mf := ExtractMathFeatures(iq) + cls = MathClassify(mf, feat.BW3dB, input.CenterHz, input.SNRDb) + cls.Features = feat + } else { + cls = RuleClassify(feat, input.CenterHz, input.SNRDb) + } + case ModeCombined: + if len(iq) > 0 { + mf := ExtractMathFeatures(iq) + cls = CombinedClassify(feat, mf, input.CenterHz, input.SNRDb) + } else { + cls = RuleClassify(feat, input.CenterHz, input.SNRDb) + } + default: + cls = RuleClassify(feat, input.CenterHz, input.SNRDb) + } return &cls } diff --git a/internal/classifier/combined.go b/internal/classifier/combined.go new file mode 100644 index 0000000..f425b00 --- /dev/null +++ b/internal/classifier/combined.go @@ -0,0 +1,37 @@ +package classifier + +func CombinedClassify(feat Features, mf MathFeatures, centerHz float64, snrDb float64) Classification { + ruleCls := RuleClassify(feat, centerHz, snrDb) + mathCls := MathClassify(mf, feat.BW3dB, centerHz, snrDb) + combined := map[SignalClass]float64{} + for k, v := range ruleCls.Scores { + combined[k] += v * 0.4 + } + for k, v := range mathCls.Scores { + combined[k] += v * 0.6 + } + best, _, second, _ := top2(combined) + if best == "" { + best = ClassUnknown + } + if second == "" { + second = ClassUnknown + } + conf := softmaxConfidence(combined, best) + if snrDb < 20 { + snrFactor := clamp01((snrDb - 3) / 17.0) + conf *= 0.3 + 0.7*snrFactor + } + if conf <= 0 { + conf = 0.1 + } + return Classification{ + ModType: best, + Confidence: conf, + BW3dB: feat.BW3dB, + Features: feat, + MathFeatures: &mf, + SecondBest: second, + Scores: combined, + } +} diff --git a/internal/classifier/math_classify.go b/internal/classifier/math_classify.go new file mode 100644 index 0000000..5874af6 --- /dev/null +++ b/internal/classifier/math_classify.go @@ -0,0 +1,229 @@ +package classifier + +import "math" + +type MathFeatures struct { + EnvCoV float64 `json:"env_cov"` + EnvKurtosis float64 `json:"env_kurtosis"` + InstFreqStd float64 `json:"inst_freq_std"` + InstFreqRange float64 `json:"inst_freq_range"` + AMIndex float64 `json:"am_index"` + FMIndex float64 `json:"fm_index"` + InstFreqModes int `json:"inst_freq_modes"` +} + +func ExtractMathFeatures(iq []complex64) MathFeatures { + if len(iq) < 10 { + return MathFeatures{} + } + n := len(iq) + env := make([]float64, n) + var envMean float64 + for i, v := range iq { + a := math.Hypot(float64(real(v)), float64(imag(v))) + env[i] = a + envMean += a + } + envMean /= float64(n) + var envVar, envM4 float64 + for _, a := range env { + d := a - envMean + envVar += d * d + envM4 += d * d * d * d + } + envVar /= float64(n) + envM4 /= float64(n) + envStd := math.Sqrt(envVar) + envCoV := 0.0 + if envMean > 1e-12 { + envCoV = envStd / envMean + } + envKurtosis := 0.0 + if envVar > 1e-20 { + envKurtosis = envM4 / (envVar * envVar) + } + instFreq := make([]float64, n-1) + var ifMean float64 + ifMin := math.Inf(1) + ifMax := math.Inf(-1) + for i := 1; i < n; i++ { + p := iq[i-1] + c := iq[i] + num := float64(real(p))*float64(imag(c)) - float64(imag(p))*float64(real(c)) + den := float64(real(p))*float64(real(c)) + float64(imag(p))*float64(imag(c)) + f := math.Atan2(num, den) + instFreq[i-1] = f + ifMean += f + if f < ifMin { + ifMin = f + } + if f > ifMax { + ifMax = f + } + } + ifMean /= float64(n - 1) + var ifVar float64 + for _, f := range instFreq { + d := f - ifMean + ifVar += d * d + } + ifVar /= float64(n - 1) + ifStd := math.Sqrt(ifVar) + ifRange := ifMax - ifMin + modes := countHistogramPeaks(instFreq, 32) + amIndex := envCoV / math.Max(ifStd, 0.001) + fmIndex := ifStd / math.Max(envCoV, 0.001) + return MathFeatures{ + EnvCoV: envCoV, + EnvKurtosis: envKurtosis, + InstFreqStd: ifStd, + InstFreqRange: ifRange, + AMIndex: amIndex, + FMIndex: fmIndex, + InstFreqModes: modes, + } +} + +func countHistogramPeaks(vals []float64, bins int) int { + if len(vals) == 0 || bins < 3 { + return 0 + } + minV, maxV := vals[0], vals[0] + for _, v := range vals { + if v < minV { + minV = v + } + if v > maxV { + maxV = v + } + } + span := maxV - minV + if span < 1e-10 { + return 1 + } + hist := make([]int, bins) + for _, v := range vals { + idx := int(float64(bins-1) * (v - minV) / span) + if idx >= bins { + idx = bins - 1 + } + if idx < 0 { + idx = 0 + } + hist[idx]++ + } + smooth := make([]int, bins) + maxSmooth := 0 + for i := range hist { + s := hist[i] + if i > 0 { + s += hist[i-1] + } + if i < bins-1 { + s += hist[i+1] + } + smooth[i] = s + if s > maxSmooth { + maxSmooth = s + } + } + peaks := 0 + for i := 1; i < bins-1; i++ { + if smooth[i] > smooth[i-1] && smooth[i] > smooth[i+1] { + if float64(smooth[i]) > 0.1*float64(maxSmooth) { + peaks++ + } + } + } + if peaks == 0 { + peaks = 1 + } + return peaks +} + +func MathClassify(mf MathFeatures, bw float64, centerHz float64, snrDb float64) Classification { + scores := map[SignalClass]float64{} + if bw < 500 && mf.InstFreqStd < 0.15 { + scores[ClassCW] += 3.0 + } + if mf.AMIndex > 3.0 { + scores[ClassAM] += 2.0 + } else if mf.AMIndex > 1.5 { + scores[ClassAM] += 1.0 + } + if mf.FMIndex > 5.0 && mf.EnvCoV < 0.1 { + if bw >= 80e3 { + scores[ClassWFM] += 2.5 + } else if bw >= 6e3 { + scores[ClassNFM] += 2.5 + } else { + scores[ClassNFM] += 1.5 + } + } else if mf.FMIndex > 2.0 && mf.EnvCoV < 0.15 { + if bw >= 50e3 { + scores[ClassWFM] += 1.5 + } else { + scores[ClassNFM] += 1.5 + } + } + if mf.AMIndex > 0.5 && mf.AMIndex < 3.0 && mf.FMIndex > 0.5 && mf.FMIndex < 3.0 { + if bw >= 2000 && bw <= 4000 { + scores[ClassSSBUSB] += 1.5 + scores[ClassSSBLSB] += 1.5 + } + } + if bw < 500 && mf.EnvKurtosis > 5.0 && mf.InstFreqStd < 0.1 { + scores[ClassCW] += 2.5 + } else if bw < 200 && mf.InstFreqStd < 0.15 { + scores[ClassCW] += 1.5 + } + if bw < 500 { + scores[ClassAM] *= 0.4 + } + if mf.EnvCoV < 0.05 && mf.InstFreqModes >= 2 { + if bw >= 10000 && bw <= 14000 { + scores[ClassDMR] += 2.0 + } else if bw >= 5000 && bw <= 8000 { + scores[ClassDStar] += 1.8 + } else { + scores[ClassFSK] += 1.5 + } + } + if mf.EnvCoV < 0.08 && mf.InstFreqModes <= 1 && mf.InstFreqStd < 0.3 { + if bw >= 100 && bw < 500 { + scores[ClassWSPR] += 1.3 + } + if bw >= 100 && bw < 3000 { + scores[ClassPSK] += 1.0 + } + } + if mf.EnvCoV < 0.15 && mf.InstFreqModes >= 3 && bw >= 2000 && bw < 3500 { + scores[ClassFT8] += 1.8 + } + if mf.AMIndex < 0.5 && mf.FMIndex < 0.5 && bw > 2000 { + scores[ClassNoise] += 1.0 + } + best, _, second, _ := top2(scores) + if best == "" { + best = ClassUnknown + } + if second == "" { + second = ClassUnknown + } + conf := softmaxConfidence(scores, best) + if snrDb < 20 { + snrFactor := clamp01((snrDb - 3) / 17.0) + conf *= 0.3 + 0.7*snrFactor + } + if math.IsNaN(conf) || conf <= 0 { + conf = 0.1 + } + return Classification{ + ModType: best, + Confidence: conf, + BW3dB: bw, + SecondBest: second, + Scores: scores, + MathFeatures: &mf, + } +} diff --git a/internal/classifier/math_classify_test.go b/internal/classifier/math_classify_test.go new file mode 100644 index 0000000..aa91047 --- /dev/null +++ b/internal/classifier/math_classify_test.go @@ -0,0 +1,81 @@ +package classifier + +import ( + "math" + "testing" +) + +func makeToneIQ(n int, freqNorm float64, am float64) []complex64 { + iq := make([]complex64, n) + for i := range iq { + phase := 2 * math.Pi * freqNorm * float64(i) + env := 1.0 + am*math.Sin(2*math.Pi*0.01*float64(i)) + iq[i] = complex(float32(env*math.Cos(phase)), float32(env*math.Sin(phase))) + } + return iq +} + +func TestMathClassifyAM(t *testing.T) { + iq := makeToneIQ(4096, 0.1, 0.8) + mf := ExtractMathFeatures(iq) + if mf.AMIndex < 1.5 { + t.Errorf("AM signal should have high AMIndex: got %.2f", mf.AMIndex) + } + cls := MathClassify(mf, 8000, 121.5e6, 25) + if cls.ModType != ClassAM { + t.Errorf("expected AM, got %s (scores: %v)", cls.ModType, cls.Scores) + } +} + +func TestMathClassifyFM(t *testing.T) { + n := 4096 + iq := make([]complex64, n) + phase := 0.0 + for i := range iq { + freqDev := 0.3 * math.Sin(2*math.Pi*0.005*float64(i)) + phase += 2 * math.Pi * (0.1 + freqDev) + iq[i] = complex(float32(math.Cos(phase)), float32(math.Sin(phase))) + } + mf := ExtractMathFeatures(iq) + if mf.FMIndex < 2.0 { + t.Errorf("FM signal should have high FMIndex: got %.2f", mf.FMIndex) + } + if mf.EnvCoV > 0.1 { + t.Errorf("FM signal should have low EnvCoV: got %.3f", mf.EnvCoV) + } + cls := MathClassify(mf, 12000, 145.5e6, 25) + if cls.ModType != ClassNFM { + t.Errorf("expected NFM, got %s (scores: %v)", cls.ModType, cls.Scores) + } +} + +func TestMathClassifyCW(t *testing.T) { + n := 4096 + iq := make([]complex64, n) + for i := range iq { + phase := 2 * math.Pi * 0.05 * float64(i) + iq[i] = complex(float32(math.Cos(phase)), float32(math.Sin(phase))) + } + mf := ExtractMathFeatures(iq) + cls := MathClassify(mf, 100, 7.02e6, 20) + if cls.ModType != ClassCW { + t.Errorf("expected CW, got %s (scores: %v, kurtosis: %.1f)", cls.ModType, cls.Scores, mf.EnvKurtosis) + } +} + +func TestCombinedClassify(t *testing.T) { + n := 4096 + iq := make([]complex64, n) + phase := 0.0 + for i := range iq { + freqDev := 0.2 * math.Sin(2*math.Pi*0.003*float64(i)) + phase += 2 * math.Pi * (0.1 + freqDev) + iq[i] = complex(float32(math.Cos(phase)), float32(math.Sin(phase))) + } + feat := Features{BW3dB: 12000, SpectralFlat: 0.3, PeakToAvg: 1.5, EnvVariance: 0.01, InstFreqStd: 0.8} + mf := ExtractMathFeatures(iq) + cls := CombinedClassify(feat, mf, 145.5e6, 25) + if cls.ModType != ClassNFM { + t.Errorf("expected NFM, got %s (scores: %v)", cls.ModType, cls.Scores) + } +} diff --git a/internal/classifier/types.go b/internal/classifier/types.go index 5cb0b72..404b1f7 100644 --- a/internal/classifier/types.go +++ b/internal/classifier/types.go @@ -39,12 +39,13 @@ type Features struct { // Classification is the classifier output attached to signals/events. type Classification struct { - ModType SignalClass `json:"mod_type"` - Confidence float64 `json:"confidence"` - BW3dB float64 `json:"bw_3db_hz"` - Features Features `json:"features,omitempty"` - SecondBest SignalClass `json:"second_best,omitempty"` - Scores map[SignalClass]float64 `json:"scores,omitempty"` + ModType SignalClass `json:"mod_type"` + Confidence float64 `json:"confidence"` + BW3dB float64 `json:"bw_3db_hz"` + Features Features `json:"features,omitempty"` + MathFeatures *MathFeatures `json:"math_features,omitempty"` + SecondBest SignalClass `json:"second_best,omitempty"` + Scores map[SignalClass]float64 `json:"scores,omitempty"` } // SignalInput is the minimal input needed for classification. diff --git a/internal/config/config.go b/internal/config/config.go index 5701395..3fde881 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -73,6 +73,7 @@ type Config struct { GainDb float64 `yaml:"gain_db" json:"gain_db"` TunerBwKHz int `yaml:"tuner_bw_khz" json:"tuner_bw_khz"` UseGPUFFT bool `yaml:"use_gpu_fft" json:"use_gpu_fft"` + ClassifierMode string `yaml:"classifier_mode" json:"classifier_mode"` AGC bool `yaml:"agc" json:"agc"` DCBlock bool `yaml:"dc_block" json:"dc_block"` IQBalance bool `yaml:"iq_balance" json:"iq_balance"` @@ -97,6 +98,7 @@ func Default() Config { GainDb: 30, TunerBwKHz: 1536, UseGPUFFT: false, + ClassifierMode: "combined", AGC: false, DCBlock: false, IQBalance: false, @@ -246,6 +248,14 @@ func applyDefaults(cfg Config) Config { if cfg.SampleRate <= 0 { cfg.SampleRate = 2_048_000 } + if cfg.ClassifierMode == "" { + cfg.ClassifierMode = "combined" + } + switch cfg.ClassifierMode { + case "rule", "math", "combined": + default: + cfg.ClassifierMode = "combined" + } if cfg.FFTSize <= 0 { cfg.FFTSize = 2048 } diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index fe8129c..026dbd5 100644 --- a/internal/runtime/runtime.go +++ b/internal/runtime/runtime.go @@ -10,14 +10,15 @@ import ( ) type ConfigUpdate struct { - CenterHz *float64 `json:"center_hz"` - SampleRate *int `json:"sample_rate"` - FFTSize *int `json:"fft_size"` - GainDb *float64 `json:"gain_db"` - TunerBwKHz *int `json:"tuner_bw_khz"` - UseGPUFFT *bool `json:"use_gpu_fft"` - Detector *DetectorUpdate `json:"detector"` - Recorder *RecorderUpdate `json:"recorder"` + CenterHz *float64 `json:"center_hz"` + SampleRate *int `json:"sample_rate"` + FFTSize *int `json:"fft_size"` + GainDb *float64 `json:"gain_db"` + TunerBwKHz *int `json:"tuner_bw_khz"` + UseGPUFFT *bool `json:"use_gpu_fft"` + ClassifierMode *string `json:"classifier_mode"` + Detector *DetectorUpdate `json:"detector"` + Recorder *RecorderUpdate `json:"recorder"` } type DetectorUpdate struct { @@ -123,6 +124,15 @@ func (m *Manager) ApplyConfig(update ConfigUpdate) (config.Config, error) { if update.UseGPUFFT != nil { next.UseGPUFFT = *update.UseGPUFFT } + if update.ClassifierMode != nil { + mode := *update.ClassifierMode + switch mode { + case "rule", "math", "combined": + next.ClassifierMode = mode + default: + return m.cfg, errors.New("classifier_mode must be rule, math, or combined") + } + } if update.Detector != nil { if update.Detector.ThresholdDb != nil { next.Detector.ThresholdDb = *update.Detector.ThresholdDb diff --git a/web/app.js b/web/app.js index bb57003..67145ae 100644 --- a/web/app.js +++ b/web/app.js @@ -30,6 +30,7 @@ const gainRange = qs('gainRange'); const gainInput = qs('gainInput'); const thresholdRange = qs('thresholdRange'); const thresholdInput = qs('thresholdInput'); +const classifierModeSelect = qs('classifierModeSelect'); const cfarModeSelect = qs('cfarModeSelect'); const cfarWrapToggle = qs('cfarWrapToggle'); const cfarGuardHzInput = qs('cfarGuardHzInput'); @@ -408,6 +409,7 @@ function applyConfigToUI(cfg) { gainInput.value = uiGain; thresholdRange.value = cfg.detector.threshold_db; thresholdInput.value = cfg.detector.threshold_db; + if (classifierModeSelect) classifierModeSelect.value = cfg.classifier_mode || 'combined'; if (cfarModeSelect) cfarModeSelect.value = cfg.detector.cfar_mode || 'OFF'; if (cfarWrapToggle) cfarWrapToggle.checked = cfg.detector.cfar_wrap_around !== false; if (cfarGuardHzInput) cfarGuardHzInput.value = cfg.detector.cfar_guard_hz ?? 500; @@ -1377,6 +1379,10 @@ thresholdInput.addEventListener('change', () => { } }); +if (classifierModeSelect) classifierModeSelect.addEventListener('change', () => { + queueConfigUpdate({ classifier_mode: classifierModeSelect.value }); +}); + if (cfarModeSelect) cfarModeSelect.addEventListener('change', () => { queueConfigUpdate({ detector: { cfar_mode: cfarModeSelect.value } }); const rankRow = cfarRankInput?.closest('.field'); diff --git a/web/index.html b/web/index.html index d4a59e7..531d336 100644 --- a/web/index.html +++ b/web/index.html @@ -181,6 +181,17 @@ +
+
Classifier
+ +
+
Detector