Open3D (C++ API)  0.18.0
Loading...
Searching...
No Matches
BlasWrapper.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
8#pragma once
9
13
14namespace open3d {
15namespace core {
16
17template <typename scalar_t>
18inline void gemm_cpu(CBLAS_LAYOUT layout,
19 CBLAS_TRANSPOSE trans_A,
20 CBLAS_TRANSPOSE trans_B,
24 scalar_t alpha,
25 const scalar_t *A_data,
27 const scalar_t *B_data,
29 scalar_t beta,
30 scalar_t *C_data,
32 utility::LogError("Unsupported data type.");
33}
34
35template <>
36inline void gemm_cpu<float>(CBLAS_LAYOUT layout,
37 CBLAS_TRANSPOSE trans_A,
38 CBLAS_TRANSPOSE trans_B,
42 float alpha,
43 const float *A_data,
45 const float *B_data,
47 float beta,
48 float *C_data,
50 cblas_sgemm(layout, trans_A, trans_B, m, n, k, alpha, A_data, lda, B_data,
51 ldb, beta, C_data, ldc);
52}
53
54template <>
55inline void gemm_cpu<double>(CBLAS_LAYOUT layout,
56 CBLAS_TRANSPOSE trans_A,
57 CBLAS_TRANSPOSE trans_B,
61 double alpha,
62 const double *A_data,
64 const double *B_data,
66 double beta,
67 double *C_data,
69 cblas_dgemm(layout, trans_A, trans_B, m, n, k, alpha, A_data, lda, B_data,
70 ldb, beta, C_data, ldc);
71}
72
73#ifdef BUILD_CUDA_MODULE
74template <typename scalar_t>
75inline cublasStatus_t gemm_cuda(cublasHandle_t handle,
76 cublasOperation_t transa,
77 cublasOperation_t transb,
78 int m,
79 int n,
80 int k,
81 const scalar_t *alpha,
82 const scalar_t *A_data,
83 int lda,
84 const scalar_t *B_data,
85 int ldb,
86 const scalar_t *beta,
87 scalar_t *C_data,
88 int ldc) {
89 utility::LogError("Unsupported data type.");
90 return CUBLAS_STATUS_NOT_SUPPORTED;
91}
92
93template <typename scalar_t>
94inline cublasStatus_t trsm_cuda(cublasHandle_t handle,
95 cublasSideMode_t side,
96 cublasFillMode_t uplo,
97 cublasOperation_t trans,
98 cublasDiagType_t diag,
99 int m,
100 int n,
101 const scalar_t *alpha,
102 const scalar_t *A,
103 int lda,
104 scalar_t *B,
105 int ldb) {
106 utility::LogError("Unsupported data type.");
107 return CUBLAS_STATUS_NOT_SUPPORTED;
108}
109
110template <>
111inline cublasStatus_t gemm_cuda<float>(cublasHandle_t handle,
112 cublasOperation_t transa,
113 cublasOperation_t transb,
114 int m,
115 int n,
116 int k,
117 const float *alpha,
118 const float *A_data,
119 int lda,
120 const float *B_data,
121 int ldb,
122 const float *beta,
123 float *C_data,
124 int ldc) {
125 return cublasSgemm(handle, transa,
126 transb, // A, B transpose flag
127 m, n, k, // dimensions
128 alpha, static_cast<const float *>(A_data), lda,
129 static_cast<const float *>(B_data),
130 ldb, // input and their leading dims
131 beta, static_cast<float *>(C_data), ldc);
132}
133
134template <>
135inline cublasStatus_t gemm_cuda<double>(cublasHandle_t handle,
136 cublasOperation_t transa,
137 cublasOperation_t transb,
138 int m,
139 int n,
140 int k,
141 const double *alpha,
142 const double *A_data,
143 int lda,
144 const double *B_data,
145 int ldb,
146 const double *beta,
147 double *C_data,
148 int ldc) {
149 return cublasDgemm(handle, transa,
150 transb, // A, B transpose flag
151 m, n, k, // dimensions
152 alpha, static_cast<const double *>(A_data), lda,
153 static_cast<const double *>(B_data),
154 ldb, // input and their leading dims
155 beta, static_cast<double *>(C_data), ldc);
156}
157
158template <>
159inline cublasStatus_t trsm_cuda<float>(cublasHandle_t handle,
160 cublasSideMode_t side,
161 cublasFillMode_t uplo,
162 cublasOperation_t trans,
163 cublasDiagType_t diag,
164 int m,
165 int n,
166 const float *alpha,
167 const float *A,
168 int lda,
169 float *B,
170 int ldb) {
171 return cublasStrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B,
172 ldb);
173}
174
175template <>
176inline cublasStatus_t trsm_cuda<double>(cublasHandle_t handle,
177 cublasSideMode_t side,
178 cublasFillMode_t uplo,
179 cublasOperation_t trans,
180 cublasDiagType_t diag,
181 int m,
182 int n,
183 const double *alpha,
184 const double *A,
185 int lda,
186 double *B,
187 int ldb) {
188 return cublasDtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B,
189 ldb);
190}
191#endif
192
193} // namespace core
194} // namespace open3d
#define OPEN3D_CPU_LINALG_INT
Definition LinalgHeadersCPU.h:23
Eigen::Matrix3d B
Definition PointCloudPlanarPatchDetection.cpp:506
void gemm_cpu< double >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, double alpha, const double *A_data, OPEN3D_CPU_LINALG_INT lda, const double *B_data, OPEN3D_CPU_LINALG_INT ldb, double beta, double *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition BlasWrapper.h:55
void gemm_cpu(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, scalar_t alpha, const scalar_t *A_data, OPEN3D_CPU_LINALG_INT lda, const scalar_t *B_data, OPEN3D_CPU_LINALG_INT ldb, scalar_t beta, scalar_t *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition BlasWrapper.h:18
void gemm_cpu< float >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, float alpha, const float *A_data, OPEN3D_CPU_LINALG_INT lda, const float *B_data, OPEN3D_CPU_LINALG_INT ldb, float beta, float *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition BlasWrapper.h:36
Definition PinholeCameraIntrinsic.cpp:16