You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

114 line
3.5KB

  1. //go:build cufft && windows
  2. package gpudemod
  3. /*
  4. #include <stdlib.h>
  5. #include <cuda_runtime.h>
  6. typedef struct { float x; float y; } gpud_float2;
  7. typedef void* gpud_stream_handle;
  8. extern int gpud_stream_create(gpud_stream_handle* out);
  9. extern int gpud_stream_destroy(gpud_stream_handle stream);
  10. extern int gpud_stream_sync(gpud_stream_handle stream);
  11. extern int gpud_launch_freq_shift_stream(gpud_float2 *in, gpud_float2 *out, int n, double phase_inc, double phase_start, gpud_stream_handle stream);
  12. extern int gpud_launch_fir_stream(gpud_float2 *in, gpud_float2 *out, int n, int num_taps, gpud_stream_handle stream);
  13. extern int gpud_launch_decimate_stream(gpud_float2 *in, gpud_float2 *out, int n_out, int factor, gpud_stream_handle stream);
  14. extern int gpud_memcpy_h2d(void *dst, const void *src, size_t bytes);
  15. extern int gpud_memcpy_d2h(void *dst, const void *src, size_t bytes);
  16. */
  17. import "C"
  18. import (
  19. "errors"
  20. "math"
  21. "unsafe"
  22. "sdr-visual-suite/internal/dsp"
  23. )
  24. func (r *BatchRunner) shiftFilterDecimateBatchImpl(iq []complex64) ([][]complex64, []int, error) {
  25. outs := make([][]complex64, len(r.slots))
  26. rates := make([]int, len(r.slots))
  27. streams := make([]C.gpud_stream_handle, len(r.slots))
  28. for i := range streams {
  29. _ = C.gpud_stream_create(&streams[i])
  30. }
  31. defer func() {
  32. for _, s := range streams {
  33. if s != nil {
  34. _ = C.gpud_stream_destroy(s)
  35. }
  36. }
  37. }()
  38. for i := range r.slots {
  39. if !r.slots[i].active {
  40. continue
  41. }
  42. out, rate, err := r.shiftFilterDecimateSlot(iq, r.slots[i].job, streams[i])
  43. if err != nil {
  44. return nil, nil, err
  45. }
  46. r.slots[i].out = out
  47. r.slots[i].rate = rate
  48. outs[i] = out
  49. rates[i] = rate
  50. }
  51. return outs, rates, nil
  52. }
  53. func (r *BatchRunner) shiftFilterDecimateSlot(iq []complex64, job ExtractJob, stream C.gpud_stream_handle) ([]complex64, int, error) {
  54. e := r.eng
  55. if e == nil || !e.cudaReady {
  56. return nil, 0, ErrUnavailable
  57. }
  58. if len(iq) == 0 {
  59. return nil, 0, nil
  60. }
  61. cutoff := job.BW / 2
  62. if cutoff < 200 {
  63. cutoff = 200
  64. }
  65. taps := e.firTaps
  66. if len(taps) == 0 {
  67. base64 := dsp.LowpassFIR(cutoff, e.sampleRate, 101)
  68. taps = make([]float32, len(base64))
  69. for i, v := range base64 {
  70. taps[i] = float32(v)
  71. }
  72. e.SetFIR(taps)
  73. }
  74. decim := int(math.Round(float64(e.sampleRate) / float64(job.OutRate)))
  75. if decim < 1 {
  76. decim = 1
  77. }
  78. n := len(iq)
  79. nOut := n / decim
  80. if nOut <= 0 {
  81. return nil, 0, errors.New("not enough output samples after decimation")
  82. }
  83. bytesIn := C.size_t(n) * C.size_t(unsafe.Sizeof(complex64(0)))
  84. if C.gpud_memcpy_h2d(unsafe.Pointer(e.dIQIn), unsafe.Pointer(&iq[0]), bytesIn) != C.cudaSuccess {
  85. return nil, 0, errors.New("cudaMemcpy H2D failed")
  86. }
  87. phaseInc := -2.0 * math.Pi * job.OffsetHz / float64(e.sampleRate)
  88. if C.gpud_launch_freq_shift_stream(e.dIQIn, e.dShifted, C.int(n), C.double(phaseInc), C.double(e.phase), stream) != 0 {
  89. return nil, 0, errors.New("gpu freq shift failed")
  90. }
  91. if C.gpud_launch_fir_stream(e.dShifted, e.dFiltered, C.int(n), C.int(len(taps)), stream) != 0 {
  92. return nil, 0, errors.New("gpu FIR failed")
  93. }
  94. if C.gpud_launch_decimate_stream(e.dFiltered, e.dDecimated, C.int(nOut), C.int(decim), stream) != 0 {
  95. return nil, 0, errors.New("gpu decimate failed")
  96. }
  97. if C.gpud_stream_sync(stream) != 0 {
  98. return nil, 0, errors.New("cuda stream sync failed")
  99. }
  100. out := make([]complex64, nOut)
  101. outBytes := C.size_t(nOut) * C.size_t(unsafe.Sizeof(complex64(0)))
  102. if C.gpud_memcpy_d2h(unsafe.Pointer(&out[0]), unsafe.Pointer(e.dDecimated), outBytes) != C.cudaSuccess {
  103. return nil, 0, errors.New("cudaMemcpy D2H failed")
  104. }
  105. return out, e.sampleRate / decim, nil
  106. }