Wideband autonomous SDR analysis engine forked from sdr-visual-suite
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

92 lines
2.1KB

  1. //go:build cufft
  2. package gpufft
  3. /*
  4. #cgo windows LDFLAGS: -lcufft64_12 -lcudart64_13
  5. #include <cuda_runtime.h>
  6. #include <cufft.h>
  7. */
  8. import "C"
  9. import (
  10. "errors"
  11. "fmt"
  12. "unsafe"
  13. )
  14. type Engine struct {
  15. plan C.cufftHandle
  16. n int
  17. data *C.cufftComplex
  18. bytes C.size_t
  19. }
  20. func Available() bool {
  21. var count C.int
  22. if C.cudaGetDeviceCount(&count) != C.cudaSuccess {
  23. return false
  24. }
  25. return count > 0
  26. }
  27. func New(n int) (*Engine, error) {
  28. if n <= 0 {
  29. return nil, errors.New("invalid fft size")
  30. }
  31. if !Available() {
  32. return nil, errors.New("cuda device not available")
  33. }
  34. var plan C.cufftHandle
  35. if C.cufftPlan1d(&plan, C.int(n), C.CUFFT_C2C, 1) != C.CUFFT_SUCCESS {
  36. return nil, errors.New("cufftPlan1d failed")
  37. }
  38. var ptr unsafe.Pointer
  39. bytes := C.size_t(n) * C.size_t(unsafe.Sizeof(C.cufftComplex{}))
  40. if C.cudaMalloc(&ptr, bytes) != C.cudaSuccess {
  41. C.cufftDestroy(plan)
  42. return nil, errors.New("cudaMalloc failed")
  43. }
  44. return &Engine{plan: plan, n: n, data: (*C.cufftComplex)(ptr), bytes: bytes}, nil
  45. }
  46. func (e *Engine) Close() {
  47. if e == nil {
  48. return
  49. }
  50. if e.plan != 0 {
  51. _ = C.cufftDestroy(e.plan)
  52. e.plan = 0
  53. }
  54. if e.data != nil {
  55. _ = C.cudaFree(unsafe.Pointer(e.data))
  56. e.data = nil
  57. }
  58. }
  59. func (e *Engine) Exec(in []complex64) ([]complex64, error) {
  60. if e == nil {
  61. return nil, errors.New("gpu fft not initialized")
  62. }
  63. if len(in) != e.n {
  64. return nil, fmt.Errorf("expected %d samples, got %d", e.n, len(in))
  65. }
  66. if len(in) == 0 {
  67. return nil, nil
  68. }
  69. if C.cudaMemcpy(unsafe.Pointer(e.data), unsafe.Pointer(&in[0]), e.bytes, C.cudaMemcpyHostToDevice) != C.cudaSuccess {
  70. return nil, errors.New("cudaMemcpy H2D failed")
  71. }
  72. if C.cufftExecC2C(e.plan, e.data, e.data, C.CUFFT_FORWARD) != C.CUFFT_SUCCESS {
  73. return nil, errors.New("cufftExecC2C failed")
  74. }
  75. if C.cudaDeviceSynchronize() != C.cudaSuccess {
  76. return nil, errors.New("cudaDeviceSynchronize failed")
  77. }
  78. out := make([]complex64, e.n)
  79. if C.cudaMemcpy(unsafe.Pointer(&out[0]), unsafe.Pointer(e.data), e.bytes, C.cudaMemcpyDeviceToHost) != C.cudaSuccess {
  80. return nil, errors.New("cudaMemcpy D2H failed")
  81. }
  82. return out, nil
  83. }