00001 #ifndef VIENNACL_DIRECT_SOLVE_HPP_
00002 #define VIENNACL_DIRECT_SOLVE_HPP_
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00024 #include "viennacl/vector.hpp"
00025 #include "viennacl/matrix.hpp"
00026 #include "viennacl/tools/matrix_kernel_class_deducer.hpp"
00027 #include "viennacl/tools/matrix_solve_kernel_class_deducer.hpp"
00028 #include "viennacl/ocl/kernel.hpp"
00029 #include "viennacl/ocl/device.hpp"
00030 #include "viennacl/ocl/handle.hpp"
00031
00032
00033 namespace viennacl
00034 {
00035 namespace linalg
00036 {
00038
00043 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00044 void inplace_solve(const matrix<SCALARTYPE, F1, A1> & mat,
00045 matrix<SCALARTYPE, F2, A2> & B,
00046 SOLVERTAG)
00047 {
00048 assert(mat.size1() == mat.size2());
00049 assert(mat.size2() == B.size1());
00050
00051 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00052 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass;
00053 KernelClass::init();
00054
00055 std::stringstream ss;
00056 ss << SOLVERTAG::name() << "_solve";
00057 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00058
00059 k.global_work_size(0, B.size2() * k.local_work_size());
00060 viennacl::ocl::enqueue(k(mat, cl_uint(mat.size1()), cl_uint(mat.size2()),
00061 cl_uint(mat.internal_size1()), cl_uint(mat.internal_size2()),
00062 B, cl_uint(B.size1()), cl_uint(B.size2()),
00063 cl_uint(B.internal_size1()), cl_uint(B.internal_size2()))
00064 );
00065 }
00066
00072 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00073 void inplace_solve(const matrix<SCALARTYPE, F1, A1> & mat,
00074 const matrix_expression< const matrix<SCALARTYPE, F2, A2>,
00075 const matrix<SCALARTYPE, F2, A2>,
00076 op_trans> & B,
00077 SOLVERTAG)
00078 {
00079 assert(mat.size1() == mat.size2());
00080 assert(mat.size2() == B.lhs().size2());
00081
00082 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00083 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass;
00084 KernelClass::init();
00085
00086 std::stringstream ss;
00087 ss << SOLVERTAG::name() << "_trans_solve";
00088 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00089
00090 k.global_work_size(0, B.lhs().size1() * k.local_work_size());
00091 viennacl::ocl::enqueue(k(mat, cl_uint(mat.size1()), cl_uint(mat.size2()),
00092 cl_uint(mat.internal_size1()), cl_uint(mat.internal_size2()),
00093 B.lhs(), cl_uint(B.lhs().size1()), cl_uint(B.lhs().size2()),
00094 cl_uint(B.lhs().internal_size1()), cl_uint(B.lhs().internal_size2()))
00095 );
00096 }
00097
00098
00104 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00105 void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F1, A1>,
00106 const matrix<SCALARTYPE, F1, A1>,
00107 op_trans> & proxy,
00108 matrix<SCALARTYPE, F2, A2> & B,
00109 SOLVERTAG)
00110 {
00111 assert(proxy.lhs().size1() == proxy.lhs().size2());
00112 assert(proxy.lhs().size2() == B.size1());
00113
00114 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00115 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass;
00116 KernelClass::init();
00117
00118 std::stringstream ss;
00119 ss << "trans_" << SOLVERTAG::name() << "_solve";
00120 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00121
00122 k.global_work_size(0, B.size2() * k.local_work_size());
00123 viennacl::ocl::enqueue(k(proxy.lhs(), cl_uint(proxy.lhs().size1()), cl_uint(proxy.lhs().size2()),
00124 cl_uint(proxy.lhs().internal_size1()), cl_uint(proxy.lhs().internal_size2()),
00125 B, cl_uint(B.size1()), cl_uint(B.size2()),
00126 cl_uint(B.internal_size1()), cl_uint(B.internal_size2()))
00127 );
00128 }
00129
00135 template<typename SCALARTYPE, typename F1, typename F2, unsigned int A1, unsigned int A2, typename SOLVERTAG>
00136 void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F1, A1>,
00137 const matrix<SCALARTYPE, F1, A1>,
00138 op_trans> & proxy,
00139 const matrix_expression< const matrix<SCALARTYPE, F2, A2>,
00140 const matrix<SCALARTYPE, F2, A2>,
00141 op_trans> & B,
00142 SOLVERTAG)
00143 {
00144 assert(proxy.lhs().size1() == proxy.lhs().size2());
00145 assert(proxy.lhs().size2() == B.lhs().size2());
00146
00147 typedef typename viennacl::tools::MATRIX_SOLVE_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F1, A1>,
00148 matrix<SCALARTYPE, F2, A2> >::ResultType KernelClass;
00149 KernelClass::init();
00150
00151 std::stringstream ss;
00152 ss << "trans_" << SOLVERTAG::name() << "_trans_solve";
00153 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00154
00155 k.global_work_size(0, B.lhs().size1() * k.local_work_size());
00156 viennacl::ocl::enqueue(k(proxy.lhs(), cl_uint(proxy.lhs().size1()), cl_uint(proxy.lhs().size2()),
00157 cl_uint(proxy.lhs().internal_size1()), cl_uint(proxy.lhs().internal_size2()),
00158 B.lhs(), cl_uint(B.lhs().size1()), cl_uint(B.lhs().size2()),
00159 cl_uint(B.lhs().internal_size1()), cl_uint(B.lhs().internal_size2()))
00160 );
00161 }
00162
00163 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename SOLVERTAG>
00164 void inplace_solve(const matrix<SCALARTYPE, F, ALIGNMENT> & mat,
00165 vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00166 SOLVERTAG)
00167 {
00168 assert(mat.size1() == vec.size());
00169 assert(mat.size2() == vec.size());
00170
00171 typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType KernelClass;
00172
00173 std::stringstream ss;
00174 ss << SOLVERTAG::name() << "_triangular_substitute_inplace";
00175 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00176
00177 k.global_work_size(0, k.local_work_size());
00178 viennacl::ocl::enqueue(k(mat, cl_uint(mat.size1()), cl_uint(mat.size2()),
00179 cl_uint(mat.internal_size1()), cl_uint(mat.internal_size2()), vec));
00180 }
00181
00187 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename SOLVERTAG>
00188 void inplace_solve(const matrix_expression< const matrix<SCALARTYPE, F, ALIGNMENT>,
00189 const matrix<SCALARTYPE, F, ALIGNMENT>,
00190 op_trans> & proxy,
00191 vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00192 SOLVERTAG)
00193 {
00194 assert(proxy.lhs().size1() == vec.size());
00195 assert(proxy.lhs().size2() == vec.size());
00196
00197 typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType KernelClass;
00198
00199 std::stringstream ss;
00200 ss << "trans_" << SOLVERTAG::name() << "_triangular_substitute_inplace";
00201 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), ss.str());
00202
00203 k.global_work_size(0, k.local_work_size());
00204 viennacl::ocl::enqueue(k(proxy.lhs(), cl_uint(proxy.lhs().size1()), cl_uint(proxy.lhs().size2()),
00205 cl_uint(proxy.lhs().internal_size1()), cl_uint(proxy.lhs().internal_size2()), vec));
00206 }
00207
00209
00216 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00217 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix<SCALARTYPE, F1, ALIGNMENT_A> & A,
00218 const matrix<SCALARTYPE, F2, ALIGNMENT_B> & B,
00219 TAG const & tag)
00220 {
00221
00222 matrix<SCALARTYPE, F2, ALIGNMENT_A> result(B.size1(), B.size2());
00223 result = B;
00224
00225 inplace_solve(A, result, tag);
00226
00227 return result;
00228 }
00229
00236 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00237 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix<SCALARTYPE, F1, ALIGNMENT_A> & A,
00238 const matrix_expression< const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00239 const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00240 op_trans> & proxy,
00241 TAG const & tag)
00242 {
00243
00244 matrix<SCALARTYPE, F2, ALIGNMENT_B> result(proxy.lhs().size2(), proxy.lhs().size1());
00245 result = proxy;
00246
00247 inplace_solve(A, result, tag);
00248
00249 return result;
00250 }
00251
00258 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename TAG>
00259 vector<SCALARTYPE, VEC_ALIGNMENT> solve(const matrix<SCALARTYPE, F, ALIGNMENT> & mat,
00260 const vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00261 TAG const & tag)
00262 {
00263
00264 vector<SCALARTYPE, VEC_ALIGNMENT> result(vec.size());
00265 result = vec;
00266
00267 inplace_solve(mat, result, tag);
00268
00269 return result;
00270 }
00271
00272
00274
00280 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00281 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix_expression< const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00282 const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00283 op_trans> & proxy,
00284 const matrix<SCALARTYPE, F2, ALIGNMENT_B> & B,
00285 TAG const & tag)
00286 {
00287
00288 matrix<SCALARTYPE, F2, ALIGNMENT_B> result(B.size1(), B.size2());
00289 result = B;
00290
00291 inplace_solve(proxy, result, tag);
00292
00293 return result;
00294 }
00295
00296
00303 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B, typename TAG>
00304 matrix<SCALARTYPE, F2, ALIGNMENT_B> solve(const matrix_expression< const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00305 const matrix<SCALARTYPE, F1, ALIGNMENT_A>,
00306 op_trans> & proxy_A,
00307 const matrix_expression< const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00308 const matrix<SCALARTYPE, F2, ALIGNMENT_B>,
00309 op_trans> & proxy_B,
00310 TAG const & tag)
00311 {
00312
00313 matrix<SCALARTYPE, F2, ALIGNMENT_B> result(proxy_B.lhs().size2(), proxy_B.lhs().size1());
00314 result = trans(proxy_B.lhs());
00315
00316 inplace_solve(proxy_A, result, tag);
00317
00318 return result;
00319 }
00320
00327 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT, typename TAG>
00328 vector<SCALARTYPE, VEC_ALIGNMENT> solve(const matrix_expression< const matrix<SCALARTYPE, F, ALIGNMENT>,
00329 const matrix<SCALARTYPE, F, ALIGNMENT>,
00330 op_trans> & proxy,
00331 const vector<SCALARTYPE, VEC_ALIGNMENT> & vec,
00332 TAG const & tag)
00333 {
00334
00335 vector<SCALARTYPE, VEC_ALIGNMENT> result(vec.size());
00336 result = vec;
00337
00338 inplace_solve(proxy, result, tag);
00339
00340 return result;
00341 }
00342
00343
00345
00349 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT>
00350 void lu_factorize(matrix<SCALARTYPE, F, ALIGNMENT> & mat)
00351 {
00352 assert(mat.size1() == mat.size2());
00353
00354 typedef typename viennacl::tools::MATRIX_KERNEL_CLASS_DEDUCER< matrix<SCALARTYPE, F, ALIGNMENT> >::ResultType KernelClass;
00355
00356 viennacl::ocl::kernel & k = viennacl::ocl::get_kernel(KernelClass::program_name(), "lu_factorize");
00357
00358 k.global_work_size(0, k.local_work_size());
00359 viennacl::ocl::enqueue(k(mat, cl_uint(mat.size1()), cl_uint(mat.size2()),
00360 cl_uint(mat.internal_size1()), cl_uint(mat.internal_size2())) );
00361 }
00362
00363
00369 template<typename SCALARTYPE, typename F1, typename F2, unsigned int ALIGNMENT_A, unsigned int ALIGNMENT_B>
00370 void lu_substitute(matrix<SCALARTYPE, F1, ALIGNMENT_A> const & A,
00371 matrix<SCALARTYPE, F2, ALIGNMENT_B> & B)
00372 {
00373 assert(A.size1() == A.size2());
00374 assert(A.size1() == A.size2());
00375 inplace_solve(A, B, unit_lower_tag());
00376 inplace_solve(A, B, upper_tag());
00377 }
00378
00384 template<typename SCALARTYPE, typename F, unsigned int ALIGNMENT, unsigned int VEC_ALIGNMENT>
00385 void lu_substitute(matrix<SCALARTYPE, F, ALIGNMENT> const & mat,
00386 vector<SCALARTYPE, VEC_ALIGNMENT> & vec)
00387 {
00388 assert(mat.size1() == mat.size2());
00389 inplace_solve(mat, vec, unit_lower_tag());
00390 inplace_solve(mat, vec, upper_tag());
00391 }
00392
00393 }
00394 }
00395
00396 #endif