ViennaCL - The Vienna Computing Library  1.7.0
Free open-source GPU-accelerated linear algebra and solver library.
matrix_product_template.hpp
1 #ifndef VIENNACL_DEVICE_SPECIFIC_TEMPLATES_MATRIX_PRODUCT_HPP
2 #define VIENNACL_DEVICE_SPECIFIC_TEMPLATES_MATRIX_PRODUCT_HPP
3 
4 /* =========================================================================
5 Copyright (c) 2010-2015, Institute for Microelectronics,
6  Institute for Analysis and Scientific Computing,
7  TU Wien.
8 Portions of this software are copyright by UChicago Argonne, LLC.
9 
10  -----------------
11  ViennaCL - The Vienna Computing Library
12  -----------------
13 
14 Project Head: Karl Rupp rupp@iue.tuwien.ac.at
15 
16 (A list of authors and contributors can be found in the manual)
17 
18 License: MIT (X11), see file LICENSE in the base directory
19 ============================================================================= */
20 
21 
22 /** @file viennacl/device_specific/templates/matrix_product_template.hpp
23  *
24  * Kernel template for matrix products.
25 */
26 
27 #include <vector>
28 
29 #include "viennacl/scheduler/forwards.h"
30 
31 #include "viennacl/matrix_def.hpp"
32 #include "viennacl/matrix_proxy.hpp"
33 
34 #include "viennacl/device_specific/templates/template_base.hpp"
35 #include "viennacl/device_specific/mapped_objects.hpp"
36 #include "viennacl/device_specific/tree_parsing.hpp"
37 #include "viennacl/device_specific/utils.hpp"
38 #include "viennacl/forwards.h"
39 
40 #include "viennacl/tools/tools.hpp"
41 
42 namespace viennacl
43 {
44 namespace device_specific
45 {
46 
47 struct matrix_product_parameters : public template_base::parameters_type
48 {
49  matrix_product_parameters(unsigned int simd_width
50  , unsigned int local_size_0, unsigned int KL, unsigned int local_size_1
51  , unsigned int ms, unsigned int ks, unsigned int ns
52  , fetching_policy_type A_fetching_policy_param, fetching_policy_type B_fetching_policy_param
53  , unsigned int local_fetch_0_param, unsigned int local_fetch_1_param): template_base::parameters_type(simd_width, local_size_0, local_size_1, 1),
54  kL(KL), mS(ms), kS(ks), nS(ns), A_fetching_policy(A_fetching_policy_param), B_fetching_policy(B_fetching_policy_param),
55  local_fetch_0(local_fetch_0_param), local_fetch_1(local_fetch_1_param),
56  mL(ms*local_size_0), nL(ns*local_size_1){}
57 
58  unsigned int kL;
59 
60  unsigned int mS;
61  unsigned int kS;
62  unsigned int nS;
63 
64  fetching_policy_type A_fetching_policy;
65  fetching_policy_type B_fetching_policy;
66 
67  unsigned int local_fetch_0;
68  unsigned int local_fetch_1;
69 
70  unsigned int mL;
71  unsigned int nL;
72 };
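// ---------------------------------------------------------------------------
// Illustration (editor's sketch, not part of the original header): a plausible
// parameter set and the tile sizes it implies. All concrete numbers here are
// hypothetical.
//
//   matrix_product_parameters p(/*simd_width*/ 4,
//                               /*local_size_0*/ 8, /*KL*/ 16, /*local_size_1*/ 8,
//                               /*ms*/ 8, /*ks*/ 4, /*ns*/ 8,
//                               FETCH_FROM_LOCAL, FETCH_FROM_LOCAL,
//                               /*local_fetch_0*/ 16, /*local_fetch_1*/ 4);
//
//   // Derived sizes: mL = ms*local_size_0 = 64 and nL = ns*local_size_1 = 64.
//   // Each work item accumulates an 8x8 sub-block of C, so an 8x8 work group
//   // owns a 64x64 tile of C and walks over K in chunks of kL = 16.
//   // Note local_fetch_0*local_fetch_1 = 64 = local_size_0*local_size_1, as
//   // check_invalid_impl() below requires when fetching through local memory.
// ---------------------------------------------------------------------------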
73 
74 class matrix_product_template : public template_base_impl<matrix_product_template, matrix_product_parameters>
75 {
76 
77 private:
78  unsigned int n_lmem_elements() const
79  {
80  unsigned int N = 0;
81  if (p_.A_fetching_policy==FETCH_FROM_LOCAL)
82  N += p_.kL * (p_.mL+1);
83  if (p_.B_fetching_policy==FETCH_FROM_LOCAL)
84  N += p_.nL * (p_.kL+1);
85  return N;
86  }
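// Worked example (using the hypothetical parameters sketched above): with both
// operands staged in local memory, n_lmem_elements() returns
//   kL*(mL+1) + nL*(kL+1) = 16*65 + 64*17 = 1040 + 1088 = 2128 scalars,
// i.e. about 8.3 KB for float, comfortably below typical 32-48 KB local memory
// limits. The "+1" row padding is the usual trick to reduce bank conflicts.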
87 
88  int check_invalid_impl(viennacl::ocl::device const & /*device*/) const
89  {
90  if (p_.A_fetching_policy!=FETCH_FROM_LOCAL && p_.B_fetching_policy!=FETCH_FROM_LOCAL&& (p_.local_fetch_0!=0 || p_.local_fetch_1!=0))
91  return TEMPLATE_GLOBAL_MEMORY_REQUIRES_ZERO_LOCAL_FETCH;
92 
93  if ((p_.mS % p_.simd_width) > 0 || (p_.nS % p_.simd_width) > 0)
94  return TEMPLATE_MS_NS_MUST_BE_SIMD_WIDTH_MULTIPLE;
95 
96  if (p_.kS > p_.kL)
97  return TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL;
98 
99  if (!(A_trans_=='N' && B_trans_=='T') && p_.simd_width>1)
100  return TEMPLATE_SIMD_WIDTH_MUST_BE_ONE;
101 
102  if (p_.A_fetching_policy==FETCH_FROM_LOCAL || p_.B_fetching_policy==FETCH_FROM_LOCAL)
103  {
104  if ((p_.local_fetch_0*p_.local_fetch_1) !=(p_.local_size_0*p_.local_size_1))
105  return TEMPLATE_LOCAL_FETCH_PRODUCT_MUST_MATCH_LOCAL_SIZE_PRODUCT;
106  }
107 
108  if (p_.A_fetching_policy==FETCH_FROM_LOCAL)
109  {
110  unsigned int bound1 = (A_trans_=='N')?p_.kL:p_.mL;
111  unsigned int bound0 = (A_trans_=='N')?p_.mL:p_.kL;
112 
113  if (p_.local_fetch_1>0 && (bound1 % p_.local_fetch_1)> 0)
114  return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
115 
116  if (p_.local_fetch_0>0 && (bound0 % (p_.local_fetch_0*p_.simd_width)) > 0)
117  return A_trans_=='N'?TEMPLATE_LOCAL_FETCH_0_MUST_BE_NL_MULTIPLE:TEMPLATE_LOCAL_FETCH_0_MUST_BE_KL_MULTIPLE;
118 
119  }
120  if (p_.B_fetching_policy==FETCH_FROM_LOCAL)
121  {
122  unsigned int bound1 = (B_trans_=='T')?p_.kL:p_.nL;
123  unsigned int bound0 = (B_trans_=='T')?p_.nL:p_.kL;
124 
125  if (p_.local_fetch_1>0 && (bound1 % p_.local_fetch_1)> 0)
126  return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
127 
128  if (p_.local_fetch_0>0 && (bound0 % (p_.local_fetch_0*p_.simd_width)) > 0)
129  return B_trans_=='T'?TEMPLATE_LOCAL_FETCH_1_MUST_BE_KL_MULTIPLE:TEMPLATE_LOCAL_FETCH_1_MUST_BE_ML_MULTIPLE;
130 
131  }
132 
133  return TEMPLATE_VALID;
134  }
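// Illustration (hypothetical): the checks above reject, for instance,
//
//   matrix_product_parameters bad(4, 8, 16, 8, 8, /*ks*/ 32, 8,
//                                 FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 16, 4);
//
// with TEMPLATE_KS_MUST_BE_SMALLER_THAN_KL (kS = 32 exceeds kL = 16), and any
// simd_width > 1 is only accepted when the template was constructed with
// A_trans_ == 'N' and B_trans_ == 'T'.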
135 
136  static void parse(scheduler::statement const & s,
137  vcl_size_t & C_idx, leaf_t & C_leaf, vcl_size_t & alpha_idx, leaf_t & alpha_leaf,
138  vcl_size_t & A_idx, leaf_t & A_leaf, bool& A_trans, vcl_size_t & B_idx, leaf_t & B_leaf, bool& B_trans,
139  vcl_size_t & beta_idx, leaf_t & beta_leaf)
140  {
141  using namespace tree_parsing;
142  using namespace scheduler;
143 
144  scheduler::statement::container_type const & array = s.array();
145  vcl_size_t root_idx = s.root();
146 
147  C_idx = root_idx;
148  C_leaf = LHS_NODE_TYPE;
149 
150  vcl_size_t node_add_idx = array[root_idx].rhs.node_index;
151 
152  vcl_size_t node_1_idx = array[node_add_idx].lhs.node_index;
153  alpha_idx = node_1_idx;
154  alpha_leaf = RHS_NODE_TYPE;
155 
156  vcl_size_t mat_prod_idx = array[node_1_idx].lhs.node_index;
157  if (array[mat_prod_idx].lhs.type_family==MATRIX_TYPE_FAMILY)
158  {
159  A_trans = false;
160  A_idx = mat_prod_idx;
161  }
162  else
163  {
164  A_trans = true;
165  A_idx = array[mat_prod_idx].lhs.node_index;
166  }
167  A_leaf = LHS_NODE_TYPE;
168 
169  if (array[mat_prod_idx].rhs.type_family==MATRIX_TYPE_FAMILY)
170  {
171  B_trans = false;
172  B_idx = mat_prod_idx;
173  B_leaf = RHS_NODE_TYPE;
174  }
175  else
176  {
177  B_trans = true;
178  B_idx = array[mat_prod_idx].rhs.node_index;
179  B_leaf = LHS_NODE_TYPE;
180  }
181 
182  vcl_size_t node_2_idx = array[node_add_idx].rhs.node_index;
183  beta_idx = node_2_idx;
184  beta_leaf = RHS_NODE_TYPE;
185  }
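// Illustration (editor's sketch): parse() assumes the statement encodes the
// GEMM update  C = alpha * prod(op(A), op(B)) + beta * C,  where op(X) is
// either X or trans(X). Roughly, the tree walked above looks like
//
//   assign
//   |-- lhs: C                          (C_idx / C_leaf)
//   `-- rhs: +
//       |-- lhs: *                      (alpha_idx, alpha on the rhs leaf)
//       |   |-- lhs: prod(op(A), op(B)) (A_idx / B_idx, trans flags set here)
//       |   `-- rhs: alpha
//       `-- rhs: *                      (beta_idx, beta on the rhs leaf)
//           `-- rhs: beta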
186 
187  void VIENNACL_HANDLE_BOUNDS(bool fallback, utils::kernel_generation_stream & stream, std::string const & inbounds, std::string const & do_if, std::string do_else) const
188  {
189  if (fallback)
190  {
191  stream << "if (" << inbounds << ")" << std::endl;
192  stream.inc_tab();
193  stream << do_if << ";" << std::endl;
194  stream.dec_tab();
195  stream << "else" << std::endl;
196  stream.inc_tab();
197  stream << do_else << ";" << std::endl;
198  stream.dec_tab();
199  }
200  else
201  stream << do_if << ";" << std::endl;
202  }
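// Illustration (hypothetical strings): in the fallback kernel a guarded load
//   VIENNACL_HANDLE_BOUNDS(true, stream, "m < M", "rA[0][0] = A[0]", "rA[0][0] = 0");
// emits
//   if (m < M)
//     rA[0][0] = A[0];
//   else
//     rA[0][0] = 0;
// whereas with fallback == false only the unguarded statement
// "rA[0][0] = A[0];" is written.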
203 
204 
205  std::string generate_impl(const std::string &kernel_prefix, const statements_container &statements, const std::vector<mapping_type> &mappings, bool fallback) const
206  {
207  using std::string;
208  using tools::to_string;
209 
210  parameters_type pfallback(1, p_.local_size_0, p_.kL, p_.local_size_1, p_.mS, p_.kS, p_.nS, p_.A_fetching_policy, p_.B_fetching_policy, p_.local_fetch_0, p_.local_fetch_1);
211  parameters_type const & p = fallback?pfallback:p_;
212 
213 #define VIENNACL_MUL_STRIDE1 string(fallback?"*#stride1":"")
214 #define VIENNACL_HANDLE_BOUNDS(in_bounds, to_load) (!fallback?string(to_load):string( string(in_bounds) + "?" + string(to_load) + ":0"))
215 #define VIENNACL_VSTORE(value, offset, ptr) vstore(p.simd_width, value, offset, ptr)
216 
217  string widthstr = tools::to_string(p.simd_width);
218 
218 
219  //////////////////
220  // Initialization
221  //////////////////
222  utils::kernel_generation_stream stream;
223  scheduler::statement const & st = statements.data().front();
224  mapping_type const & mapping = mappings.front();
225 
226  bool A_trans = false, B_trans = false;
227  vcl_size_t C_idx=0, alpha_idx=0, A_idx=0, B_idx=0, beta_idx=0;
228  leaf_t C_leaf=LHS_NODE_TYPE, alpha_leaf=LHS_NODE_TYPE, A_leaf=LHS_NODE_TYPE, B_leaf=LHS_NODE_TYPE, beta_leaf=LHS_NODE_TYPE;
229  parse(st, C_idx, C_leaf, alpha_idx, alpha_leaf, A_idx, A_leaf, A_trans, B_idx, B_leaf, B_trans, beta_idx, beta_leaf);
230 
231  mapped_matrix * C = (mapped_matrix* )at(mapping, mapping_key( C_idx, C_leaf)).get();
232  mapped_host_scalar * alpha = (mapped_host_scalar*)at(mapping, mapping_key(alpha_idx, alpha_leaf)).get();
233  mapped_matrix * A = (mapped_matrix* )at(mapping, mapping_key( A_idx, A_leaf)).get();
234  mapped_matrix * B = (mapped_matrix* )at(mapping, mapping_key( B_idx, B_leaf)).get();
235  mapped_host_scalar * beta = (mapped_host_scalar*)at(mapping, mapping_key( beta_idx, beta_leaf)).get();
236 
240 
241  stream << " __attribute__((reqd_work_group_size(" << p.local_size_0 << "," << p.local_size_1 << ",1)))" << std::endl;
242  std::map<std::string, unsigned int> widths;
243  widths[A->name()] = p.simd_width;
244  widths[B->name()] = p.simd_width;
245  generate_prototype(stream, kernel_prefix, "unsigned int M, unsigned int N, unsigned int K, ", mappings, statements, widths);
246  stream << "{" << std::endl;
247  stream.inc_tab();
248  if(!fallback)
249  {
250  stream << A->process("#start1 /= " + to_string(p.simd_width) + ";") << std::endl;
251  stream << A->process("#ld /= " + to_string(p.simd_width) + ";") << std::endl;
252  stream << B->process("#start1/= " + to_string(p.simd_width) + ";") << std::endl;
253  stream << B->process("#ld /= " + to_string(p.simd_width) + ";") << std::endl;
254  }
255  tree_parsing::process(stream, PARENT_NODE_TYPE, "matrix", "#pointer += $OFFSET{#start1, #start2};", statements, mappings);
256  tree_parsing::process(stream, PARENT_NODE_TYPE, "matrix", "#ld *= #nldstride;", statements, mappings);
257 
259  stream << C->process("#scalartype rC[" + to_string(p.mS) + "][" + to_string(p.nS) + "] = {{(#scalartype)0}};") << std::endl;
260  if (p.simd_width==1)
261  stream << A->process("#scalartype rA[" + to_string(p.kS) + "][" + to_string(p.mS) + "];") << std::endl;
262  else
263  stream << A->process(utils::append_width("#scalartype",p.simd_width) + " rA[" + to_string(p.kS) + "][" + to_string(p.mS/p.simd_width) + "];") << std::endl;
264  if (p.simd_width==1)
265  stream << B->process("#scalartype rB[" + to_string(p.kS) + "][" + to_string(p.nS) + "];");
266  else
267  stream << B->process(utils::append_width("#scalartype",p.simd_width) + " rB[" + to_string(p.kS) + "][" + to_string(p.nS/p.simd_width) + "];") << std::endl;
268 
269 
270  if (p.A_fetching_policy==FETCH_FROM_LOCAL)
271  stream << A->process("__local #scalartype lA[" + to_string(p.kL*(p.mL+1)) + "];");
272  if (p.B_fetching_policy==FETCH_FROM_LOCAL)
273  stream << B->process("__local #scalartype lB[" + to_string(p.kL*(p.nL+1)) + "];");
274  stream << std::endl;
275 
276  stream << "uint gidx = get_group_id(0);" << std::endl;
277  stream << "uint gidy = get_group_id(1);" << std::endl;
278  stream << "uint idx = get_local_id(0);" << std::endl;
279  stream << "uint idy = get_local_id(1);" << std::endl;
280 
281  if (p.A_fetching_policy==FETCH_FROM_LOCAL || p.B_fetching_policy==FETCH_FROM_LOCAL)
282  {
283  stream << std::endl;
284  stream << "uint idt = " << p.local_size_0 << "*idy + idx;" << std::endl;
285  stream << "uint idxT = idt % " << p.local_fetch_0 << ";" << std::endl;
286  stream << "uint idyT = idt / " << p.local_fetch_0 << ";" << std::endl;
287  }
288  stream << std::endl;
289 
290  if (fallback)
291  {
292  //Bounds checking for M (in A, C)
293  stream << "bool in_bounds_m[" << p.mS << "];" << std::endl;
294  stream << "for(unsigned int m = 0; m < " << p.mS << "; m++)" << std::endl;
295  stream.inc_tab();
296  switch (p.A_fetching_policy)
297  {
298  case FETCH_FROM_GLOBAL_CONTIGUOUS:
299  stream << "in_bounds_m[m] = gidx*" << p.mL << " + idx*" << p.mS << " + m < M;" << std::endl;
300  break;
301  default:
302  stream << "in_bounds_m[m] = gidx*" << p.mL << " + idx + m*" << p.local_size_0 << " < M;" << std::endl;
303  break;
304  }
305  stream.dec_tab();
306 
307  //Bounds checking for A if Local
308  if (p.A_fetching_policy==FETCH_FROM_LOCAL)
309  {
310  unsigned int fetch_size = (A_trans_=='N'?p.local_fetch_0*p.simd_width:p.local_fetch_1);
311  stream << "bool in_bounds_m_local[" << p.mL/fetch_size << "];" << std::endl;
312  stream << "for(unsigned int m = 0; m < " << p.mL/fetch_size << "; m++)" << std::endl;
313  stream.inc_tab();
314  stream << "in_bounds_m_local[m] = gidx*" << p.mL << " + " << (A_trans_=='N'?"idxT":"idyT") << " + m*" << fetch_size << " < M;" << std::endl;
315  stream.dec_tab();
316  }
317 
318  //Bounds checking for N (in B, C)
319  stream << "bool in_bounds_n[" << p.nS << "];" << std::endl;
320  stream << "for(unsigned int n = 0; n < " << p.nS << "; n++)" << std::endl;
321  stream.inc_tab();
322  switch (p.B_fetching_policy)
323  {
324  case FETCH_FROM_GLOBAL_CONTIGUOUS:
325  stream << "in_bounds_n[n] = gidy*" << p.nL << " + idy*" << p.nS << " + n < N;" << std::endl;
326  break;
327  default:
328  stream << "in_bounds_n[n] = gidy*" << p.nL << " + idy + n*" << p.local_size_1 << " < N;" << std::endl;
329  break;
330  }
331  stream.dec_tab();
332 
333  //Bounds checking for B if Local
334  if (p.B_fetching_policy==FETCH_FROM_LOCAL)
335  {
336  unsigned int fetch_size = (B_trans_=='T'?p.local_fetch_0*p.simd_width:p.local_fetch_1);
337  stream << "bool in_bounds_n_local[" << p.nL/fetch_size << "];" << std::endl;
338  stream << "for(unsigned int n = 0; n < " << p.nL/fetch_size << "; n++)" << std::endl;
339  stream.inc_tab();
340  stream << "in_bounds_n_local[n] = gidy*" << p.nL << " + " << (B_trans_=='T'?"idxT":"idyT") << " + n*" << fetch_size << " < N;" << std::endl;
341  stream.dec_tab();
342  }
343  }
344 
345  switch (p.A_fetching_policy)
346  {
347  case FETCH_FROM_LOCAL:
348  if (A_trans_=='N')
349  stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + " + idxT)" + VIENNACL_MUL_STRIDE1 + " + idyT*#ld;") << std::endl;
350  else
351  stream << A->process("#pointer += idxT" + VIENNACL_MUL_STRIDE1 + " + gidx*" + to_string(p.mL/p.simd_width) + "*#ld + idyT*#ld;") << std::endl;
352  break;
353 
354  case FETCH_FROM_GLOBAL_CONTIGUOUS:
355  if (A_trans_=='N')
356  stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx*" + to_string(p.mS/p.simd_width) + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
357  else
358  stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx*" + to_string(p.mS/p.simd_width) + ")*#ld;") << std::endl;
359  break;
360 
361  case FETCH_FROM_GLOBAL_STRIDED:
362  if (A_trans_=='N')
363  stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx" + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
364  else
365  stream << A->process("#pointer += (gidx*" + to_string(p.mL/p.simd_width) + "+ idx)*#ld;") << std::endl;
366  break;
367 
368  //default: break;
369  }
370 
371  switch (p.B_fetching_policy)
372  {
373  case FETCH_FROM_LOCAL:
374  if (B_trans_=='T')
375  stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + " + idxT" + ")" + VIENNACL_MUL_STRIDE1 + " + idyT*#ld;") << std::endl;
376  else
377  stream << B->process("#pointer += idxT" + VIENNACL_MUL_STRIDE1 + " + gidy*" + to_string(p.nL/p.simd_width) + "*#ld + idyT*#ld;") << std::endl;
378  break;
379 
380  case FETCH_FROM_GLOBAL_CONTIGUOUS:
381  if (B_trans_=='T')
382  stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy*" + to_string(p.nS/p.simd_width) + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
383  else
384  stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy*" + to_string(p.nS/p.simd_width) + ")*#ld;") << std::endl;
385  break;
386 
387  case FETCH_FROM_GLOBAL_STRIDED:
388  if (B_trans_=='T')
389  stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy" + ")" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
390  else
391  stream << B->process("#pointer += (gidy*" + to_string(p.nL/p.simd_width) + "+ idy)*#ld;") << std::endl;
392  break;
393 
394  //default: break;
395  }
396 
397  stream << std::endl;
398  stream << "for(unsigned int block_k=0; block_k < K; block_k+=" << p.kL << "){" << std::endl;
399  stream.inc_tab();
400 
401  if (p.A_fetching_policy==FETCH_FROM_LOCAL)
402  {
403  if (A_trans_=='N')
404  stream << A->process("__local #scalartype* plA = lA + idyT*" + to_string(p.mL + 1) + " + " + to_string(p.simd_width) + "*idxT;") << std::endl;
405  else
406  stream << A->process("__local #scalartype* plA = lA + idxT*" + to_string(p.mL + 1) + " + idyT;") << std::endl;
407  }
408 
409 
410  if (p.B_fetching_policy==FETCH_FROM_LOCAL)
411  {
412  if (B_trans_=='T')
413  stream << B->process("__local #scalartype* plB = lB + idyT*" + to_string(p.nL+1) + " + " + to_string(p.simd_width) + "*idxT;") << std::endl;
414  else
415  stream << B->process("__local #scalartype* plB = lB + idxT*" + to_string(p.nL+1) + "+ idyT;") <<std::endl;
416  }
417 
418 
419  if (p.A_fetching_policy==FETCH_FROM_LOCAL || p.B_fetching_policy==FETCH_FROM_LOCAL)
420  stream << "barrier(CLK_LOCAL_MEM_FENCE);" << std::endl;
421 
423  if (p.A_fetching_policy==FETCH_FROM_LOCAL && A_trans_=='N')
424  for (unsigned int k = 0; k < p.kL; k += p.local_fetch_1)
425  for (unsigned int m = 0; m < p.mL; m += p.local_fetch_0*p.simd_width)
426  {
427  string in_bounds = "in_bounds_m_local[" + to_string(m/(p.local_fetch_0*p.simd_width)) + "]";
428  string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(m/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]";
429  stream << A->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plA + " + to_string(k*(p.mL+1)+m))) << ";" << std::endl;
430  }
431  else if (p.A_fetching_policy==FETCH_FROM_LOCAL && A_trans_=='T')
432  for (unsigned int k = 0; k < p.mL; k += p.local_fetch_1)
433  for (unsigned int m = 0; m < p.kL; m += p.local_fetch_0*p.simd_width)
434  {
435  string in_bounds = "in_bounds_m_local[" + to_string(k/p.local_fetch_1) + "]";
436  string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(m/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]";
437  stream << A->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plA + " + to_string(m*(p.mL+1)+k))) << ";" << std::endl;
438  }
439 
440  if (p.B_fetching_policy==FETCH_FROM_LOCAL && B_trans_=='T')
441  for (unsigned int k = 0; k < p.kL; k += p.local_fetch_1)
442  for (unsigned int n = 0; n < p.nL; n += p.local_fetch_0*p.simd_width)
443  {
444  string in_bounds = "in_bounds_n_local[" + to_string(n/(p.local_fetch_0*p.simd_width)) + "]";
445  string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(n/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]";
446  stream << B->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plB + " + to_string(k*(p.nL+1)+n))) << ";" << std::endl;
447  }
448  else if (p.B_fetching_policy==FETCH_FROM_LOCAL && B_trans_=='N')
449  for (unsigned int k = 0; k < p.nL; k += p.local_fetch_1)
450  for (unsigned int n = 0; n < p.kL; n += p.local_fetch_0*p.simd_width)
451  {
452  string in_bounds = "in_bounds_n_local[" + to_string(k/p.local_fetch_1) + "]";
453  string to_load = "#pointer[" + to_string(k) + "*#ld + " + to_string(n/p.simd_width) + VIENNACL_MUL_STRIDE1 + "]";
454  stream << B->process(VIENNACL_VSTORE(VIENNACL_HANDLE_BOUNDS(in_bounds, to_load), "0", "plB + " + to_string(n*(p.nL+1)+k))) << ";" << std::endl;
455  }
456 
457  if (p.A_fetching_policy==FETCH_FROM_LOCAL || p.B_fetching_policy==FETCH_FROM_LOCAL)
458  {
459  stream << "barrier(CLK_LOCAL_MEM_FENCE);" << std::endl;
460  stream << "uint offA = " << p.simd_width << "*idx;" << std::endl;
461  stream << "uint offB = " << p.simd_width << "*idy;" << std::endl;
462  }
463 
464  if (fallback)
465  stream << "for(unsigned int k = 0; k < " << p.kL << " && (block_k + k < K); k+=" << p.kS << "){" << std::endl;
466  else
467  stream << "for(unsigned int k = 0; k < " << p.kL << "; k+=" << p.kS << "){" << std::endl;
468  stream.inc_tab();
469 
471  stream << "#pragma unroll " << p.kS << std::endl;
472  stream << "for(unsigned int kk = 0; kk < " << p.kS << "; kk++)" << std::endl;
473  stream << "#pragma unroll " << p.mS/p.simd_width << std::endl;
474  stream << "for(unsigned int mm = 0; mm < " << p.mS/p.simd_width << "; mm++)" << std::endl;
475  stream << "{" << std::endl;
476  stream.inc_tab();
477  switch (p.A_fetching_policy)
478  {
479  case FETCH_FROM_LOCAL:
480  for (unsigned int ss = 0; ss < p.simd_width; ++ss)
481  stream << "rA[kk][mm*" << p.simd_width << "+" << ss << "] = lA[offA + mm*" << p.local_size_0*p.simd_width << "+" << ss << "+ kk*" << (p.mL+1) << "];" << std::endl;
482  break;
483 
484  case FETCH_FROM_GLOBAL_CONTIGUOUS:
485  {
486  if (A_trans_=='N')
487  stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[kk*#ld + mm" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
488  else
489  stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[mm*#ld + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
490  break;
491  }
492 
493  case FETCH_FROM_GLOBAL_STRIDED:
494  {
495  if (A_trans_=='N')
496  stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[kk*#ld + mm*" + to_string(p.local_size_0) + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
497  else
498  stream << "rA[kk][mm] = " << A->process(VIENNACL_HANDLE_BOUNDS("in_bounds_m[mm]", "#pointer[mm*#ld*" + to_string(p.local_size_0) + " + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
499  break;
500  }
501 
502  //default: break;
503  }
504  stream.dec_tab();
505  stream << "}" << std::endl;
506 
507  stream << "#pragma unroll " << p.kS << std::endl;
508  stream << "for(unsigned int kk = 0; kk < " << p.kS << "; kk++)" << std::endl;
509  stream << "#pragma unroll " << p.nS/p.simd_width << std::endl;
510  stream << "for(unsigned int nn = 0; nn < " << p.nS/p.simd_width << "; nn++)" << std::endl;
511  stream << "{" << std::endl;
512  stream.inc_tab();
513  switch (p.B_fetching_policy)
514  {
515  case FETCH_FROM_LOCAL:
516  for (unsigned int ss = 0; ss < p.simd_width; ++ss)
517  stream << "rB[kk][nn*" << p.simd_width << "+" << ss << "] = lB[offB + nn*" << p.local_size_1*p.simd_width << "+" << ss << "+ kk*" << (p.nL+1) << "];" << std::endl;
518  break;
519 
520  case FETCH_FROM_GLOBAL_CONTIGUOUS:
521  {
522  if (B_trans_=='T')
523  stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[kk*#ld + nn" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
524  else
525  stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[nn*#ld + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
526  break;
527  }
528 
529  case FETCH_FROM_GLOBAL_STRIDED:
530  {
531  if (B_trans_=='T')
532  stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[kk*#ld + nn*" + to_string(p.local_size_1) + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
533  else
534  stream << "rB[kk][nn] = " << B->process(VIENNACL_HANDLE_BOUNDS("in_bounds_n[nn]", "#pointer[nn*#ld*" + to_string(p.local_size_1) + " + kk" + VIENNACL_MUL_STRIDE1 + "]")) << ";" << std::endl;
535  break;
536  }
537 
538  //default: break;
539  }
540  stream.dec_tab();
541  stream << "}" << std::endl;
542 
543 
545  switch (p.A_fetching_policy)
546  {
547  case FETCH_FROM_LOCAL:
548  stream << "offA += " << p.kS*(p.mL+1) << ";" << std::endl;
549  break;
550 
551  default:
552  if (A_trans_=='N')
553  stream << A->process("#pointer += " + to_string(p.kS) + "*#ld;") << std::endl;
554  else
555  stream << A->process("#pointer += " + to_string(p.kS) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
556  break;
557  }
558 
559 
560  switch (p.B_fetching_policy)
561  {
562  case FETCH_FROM_LOCAL:
563  stream << "offB += " << p.kS*(p.nL+1) << ";" << std::endl;
564  break;
565 
566  default:
567  if (B_trans_=='T')
568  stream << B->process("#pointer += " + to_string(p.kS) + "*#ld;") << std::endl;
569  else
570  stream << B->process("#pointer += " + to_string(p.kS) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
571  break;
572  }
573 
574 
575  stream << "#pragma unroll " << p.kS << std::endl;
576  stream << "for(unsigned int kk = 0; kk <" << p.kS << "; ++kk)" << std::endl;
577  stream << "{" << std::endl;
578  stream.inc_tab();
579  for (unsigned int nn=0; nn < p.nS; ++nn)
580  for (unsigned int mm=0; mm < p.mS; ++mm)
581  {
582  string res_str, lhs_str, rhs_str;
583  res_str = "rC[" + tools::to_string(mm) + "][" + tools::to_string(nn) + "]";
584  if (p.simd_width==1)
585  lhs_str = "rA[kk][" + tools::to_string(mm) + "]";
586  else
587  lhs_str = "rA[kk][" + tools::to_string(mm/p.simd_width) + "].s" + tools::to_string(mm%p.simd_width);
588  if (p.simd_width==1)
589  rhs_str = "rB[kk]["+tools::to_string(nn)+"]";
590  else
591  rhs_str = "rB[kk]["+tools::to_string(nn/p.simd_width)+"].s"+tools::to_string(nn%p.simd_width);
592  stream << res_str << "=" << "fma(" << lhs_str << "," << rhs_str << "," << res_str << ");" << std::endl;
593  }
594  stream.dec_tab();
595  stream << "}" << std::endl;
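// Illustration (hypothetical values): with simd_width = 2 and mS = nS = 2 the
// loop above emits, inside the generated "for kk" loop,
//   rC[0][0]=fma(rA[kk][0].s0,rB[kk][0].s0,rC[0][0]);
//   rC[1][0]=fma(rA[kk][0].s1,rB[kk][0].s0,rC[1][0]);
//   rC[0][1]=fma(rA[kk][0].s0,rB[kk][0].s1,rC[0][1]);
//   rC[1][1]=fma(rA[kk][0].s1,rB[kk][0].s1,rC[1][1]);
// i.e. the mS x nS register tile is updated with one fma per element and per kk.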
596 
597 
598 
599 
600  stream.dec_tab();
601  stream << "}" << std::endl;
602 
603  //Increment global pointer if local memory is used
604  //Else, it's incremented directly when fetching
605  if (p.A_fetching_policy==FETCH_FROM_LOCAL)
606  {
607  if (A_trans_=='N')
608  stream << A->process("#pointer += " + to_string(p.kL) + "*#ld;") << std::endl;
609  else
610  stream << A->process("#pointer += " + to_string(p.kL) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
611  }
612 
613  if (p.B_fetching_policy==FETCH_FROM_LOCAL)
614  {
615  if (B_trans_=='T')
616  stream << B->process("#pointer += " + to_string(p.kL) + "*#ld;") << std::endl;
617  else
618  stream << B->process("#pointer += " + to_string(p.kL) + "" + VIENNACL_MUL_STRIDE1 + ";") << std::endl;
619  }
620 
621  stream.dec_tab();
622  stream << "}" << std::endl;
623 
624 
625  if (C->row_major())
626  {
627  unsigned int ministartstride0 = p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.mS:p.simd_width;
628  unsigned int ministartstride1 = p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.nS:p.simd_width;
629 
630  stream << C->process("#pointer += gidx*" + to_string(p.mL) + "*#ld;") << std::endl;
631  stream << C->process("#pointer += idx*" + to_string(ministartstride0) + "*#ld;") << std::endl;
632  stream << C->process("#pointer += gidy*" + to_string(p.nL) + "*#stride2;") << std::endl;
633  stream << C->process("#pointer += idy*" + to_string(ministartstride1) + "*#stride2;") << std::endl;
634 
635  for (unsigned int n=0; n < p.nS; ++n)
636  {
637  for (unsigned int m=0; m < p.mS; ++m)
638  {
639  unsigned int ministride1 = p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?1:p.local_size_0;
640  string Cj = to_string((m/p.simd_width)*(ministride1*p.simd_width) + m%p.simd_width);
641  if (fallback)
642  {
643  stream << "if (in_bounds_m[" + to_string(m) + "] && in_bounds_n[" + to_string(n) + "])" << std::endl;
644  stream.inc_tab();
645  }
646  stream << C->process("#pointer[" + Cj + "*#ld] = rC[" + to_string(m) + "][" + to_string(n) + "]*" + alpha->name() + "+ #pointer[" + Cj + "*#ld]*" + beta->name() + ";") << std::endl;
647  if (fallback)
648  stream.dec_tab();
649  }
650  if ((n+1)%p.simd_width>0 || p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS)
651  stream << C->process("#pointer += #stride2;") << std::endl;
652  else
653  stream << C->process("#pointer += " + to_string((p.local_size_1*p.simd_width) - (p.simd_width-1)) + "*#stride2;") << std::endl;
654  }
655 
656  }
657  else
658  {
659  unsigned int ministartstride0 = p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.mS:p.simd_width;
660  unsigned int ministartstride1 = p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?p.nS:p.simd_width;
661 
662  stream << C->process("#pointer += gidx*" + to_string(p.mL) + "*#stride1;") << std::endl;
663  stream << C->process("#pointer += idx*" + to_string(ministartstride0) + "*#stride1;") << std::endl;
664  stream << C->process("#pointer += gidy*" + to_string(p.nL) + "*#ld;") << std::endl;
665  stream << C->process("#pointer += idy*" + to_string(ministartstride1) + "*#ld;") << std::endl;
666 
667  for (unsigned int m=0; m < p.mS; ++m)
668  {
669  for (unsigned int n=0; n < p.nS; ++n)
670  {
671  unsigned int ministride1 = p.B_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS?1:p.local_size_1;
672  string Cj = to_string((n/p.simd_width)*(ministride1*p.simd_width) + n%p.simd_width);
673  if (fallback)
674  {
675  stream << "if (in_bounds_m[" + to_string(m) + "] && in_bounds_n[" + to_string(n) + "])" << std::endl;
676  stream.inc_tab();
677  }
678  stream << C->process("#pointer[" + Cj + "*#ld] = rC[" + to_string(m) + "][" + to_string(n) + "]*" + alpha->name() + " + #pointer[" + Cj + "*#ld]*" + beta->name() + ";") << std::endl;
679  if (fallback)
680  stream.dec_tab();
681  }
682 
683  if ((m+1)%p.simd_width>0 || p.A_fetching_policy==FETCH_FROM_GLOBAL_CONTIGUOUS)
684  stream << C->process("#pointer += #stride1;") << std::endl;
685  else
686  stream << C->process("#pointer += " + to_string((p.local_size_0*p.simd_width) - (p.simd_width-1)) + "*#stride1;") << std::endl;
687  }
688  }
689 
690  stream.dec_tab();
691  stream << "}" << std::endl;
692 
693  return stream.str();
694 
695 #undef VIENNACL_MUL_STRIDE1
696 #undef VIENNACL_HANDLE_BOUNDS
697 #undef VIENNACL_VSTORE
698  }
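// ---------------------------------------------------------------------------
// Illustration (editor's sketch, not emitted verbatim): for a configuration
// that stages both operands in local memory, the kernel produced above has
// roughly this structure (names follow the strings generated in this function;
// exact indexing depends on the parameters and on the fallback flag):
//
//   __attribute__((reqd_work_group_size(local_size_0, local_size_1, 1)))
//   __kernel void kernel_prefix(unsigned int M, unsigned int N, unsigned int K, ...)
//   {
//     float rC[mS][nS] = {{0}};            // per-work-item accumulator
//     __local float lA[kL*(mL+1)];         // work-group tiles of A and B
//     __local float lB[kL*(nL+1)];
//     for (unsigned int block_k = 0; block_k < K; block_k += kL) {
//       barrier(CLK_LOCAL_MEM_FENCE);
//       /* cooperative copy of the current A and B tiles into lA / lB */
//       barrier(CLK_LOCAL_MEM_FENCE);
//       for (unsigned int k = 0; k < kL; k += kS) {
//         /* load kS x mS / kS x nS sub-blocks into the rA / rB registers */
//         /* rC[m][n] = fma(rA[kk][m], rB[kk][n], rC[m][n]) for all m, n */
//       }
//     }
//     /* write-back: C[...] = alpha*rC[m][n] + beta*C[...] */
//   }
// ---------------------------------------------------------------------------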
699 
700  std::vector<std::string> generate_impl(std::string const & kernel_prefix, statements_container const & statements, std::vector<mapping_type> const & mappings) const
701  {
702  std::vector<std::string> res;
703  res.push_back(generate_impl(kernel_prefix, statements, mappings, false));
704  res.push_back(generate_impl(kernel_prefix, statements, mappings, true));
705  return res;
706  }
707 
708  template<class NumericT>
709  void enqueue_block(scheduler::statement & statement,
710  scheduler::lhs_rhs_element & eA, scheduler::lhs_rhs_element & eB, scheduler::lhs_rhs_element & eC, scheduler::lhs_rhs_element & ebeta,
711  matrix_base<NumericT> const & A, matrix_base<NumericT> const & B, matrix_base<NumericT> const & C, NumericT beta,
712  std::vector<lazy_program_compiler> & programs, std::string const & kernel_prefix, vcl_size_t id)
713  {
714  if (A.size1()==0 || A.size2()==0 || B.size1()==0 || B.size2()==0 || C.size1()==0 || C.size2()==0)
715  return;
716 
717  viennacl::ocl::kernel& kernel = programs[id].program().get_kernel(kernel_prefix);
718 
719  kernel.local_work_size(0, p_.local_size_0);
720  kernel.local_work_size(1, p_.local_size_1);
721 
722  scheduler::statement::assign_element(eA, A);
723  scheduler::statement::assign_element(eB, B);
724  scheduler::statement::assign_element(eC, C);
725  scheduler::statement::assign_element(ebeta, beta);
726 
727  if (id==1)
728  {
729  kernel.global_work_size(0, tools::align_to_multiple(tools::align_to_multiple((unsigned int)C.size1(), p_.mS)/p_.mS, p_.local_size_0));
730  kernel.global_work_size(1, tools::align_to_multiple(tools::align_to_multiple((unsigned int)C.size2(), p_.nS)/p_.nS, p_.local_size_1));
731  }
732  else
733  {
734  kernel.global_work_size(0, C.size1()/p_.mS);
735  kernel.global_work_size(1, C.size2()/p_.nS);
736  }
737  unsigned int current_arg = 0;
738  kernel.arg(current_arg++, cl_uint(C.size1()));
739  kernel.arg(current_arg++, cl_uint(C.size2()));
740  if (A.row_major())
741  kernel.arg(current_arg++, cl_uint(A_trans_=='T'?A.size2():A.size1()));
742  else
743  kernel.arg(current_arg++, cl_uint(A_trans_=='N'?A.size2():A.size1()));
744  set_arguments(statement, kernel, current_arg);
745  viennacl::ocl::enqueue(kernel);
746 
747  }
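// Worked example (hypothetical sizes): for the optimized kernel (id == 0) with
// M = N = 1024 and mS = nS = 8, the launch grid is 1024/8 = 128 work items per
// dimension; with 8x8 work groups that is 16x16 groups, each writing one
// mL x nL = 64x64 tile of C. The id == 1 variant rounds the grid up instead,
// so it can also cover leftover rows/columns with its bounds checks.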
748 
749  template<class NumericT>
750  static matrix_slice< viennacl::matrix_base<NumericT> > create_slice(viennacl::matrix_base<NumericT>* scheduler::lhs_rhs_element::*ptr, scheduler::lhs_rhs_element & element,
751  vcl_size_t s0_0, vcl_size_t s0_1, vcl_size_t s1_0, vcl_size_t s1_1, bool swap)
752  {
753  matrix_base<NumericT> & M = *(element.*ptr);
754  slice s0(s0_0, 1, s0_1 - s0_0);
755  slice s1(s1_0, 1, s1_1 - s1_0);
756  if (swap)
757  std::swap(s0, s1);
758  return matrix_slice< viennacl::matrix_base<NumericT> >(M, s0, s1);
759  }
760 
761  template<class NumericT>
762  void enqueue_impl(viennacl::matrix_base<NumericT>* scheduler::lhs_rhs_element::*ptr_matrix,
764  NumericT beta_value, std::vector<lazy_program_compiler> & programs, std::string const & kernel_prefix)
765  {
766  using namespace device_specific::utils;
767  vcl_size_t ldstrideA = call_on_matrix(A, leading_stride());
768  vcl_size_t ldstrideB = call_on_matrix(B, leading_stride());
769  vcl_size_t ldstrideC = call_on_matrix(C, leading_stride());
770  vcl_size_t ldstartA = call_on_matrix(A, leading_start());
771  vcl_size_t ldstartB = call_on_matrix(B, leading_start());
772  bool swap_A = ((A_trans_=='T') ^ utils::call_on_matrix(A, row_major_fun()));
773  bool swap_B = ((B_trans_=='T') ^ utils::call_on_matrix(B, row_major_fun()));
774 
775  vcl_size_t M = call_on_matrix(C, size1_fun());
776  vcl_size_t N = call_on_matrix(C, size2_fun());
777  vcl_size_t K;
778  if (utils::call_on_matrix(A, row_major_fun()))
779  K = A_trans_=='T'?call_on_matrix(A, size2_fun()):call_on_matrix(A, size1_fun());
780  else
781  K = A_trans_=='N'?call_on_matrix(A, size2_fun()):call_on_matrix(A, size1_fun());
782 
783  if (M < p_.mL || N < p_.nL || K < p_.kL || ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1 ||
784  (p_.simd_width>1 && (ldstartA % p_.simd_width > 0 || ldstartB % p_.simd_width > 0)))
785  {
786  enqueue_block(statement, A, B, C, beta, create_slice(ptr_matrix, A, 0, M, 0, K, swap_A),
787  create_slice(ptr_matrix, B, 0, K, 0, N, swap_B),
788  create_slice(ptr_matrix, C, 0, M, 0, N, false), beta_value, programs, kernel_prefix, 1);
789  return;
790  }
791 
792 
793  scheduler::lhs_rhs_element Acopy = A;
794  scheduler::lhs_rhs_element Bcopy = B;
795  scheduler::lhs_rhs_element Ccopy = C;
796 
797  vcl_size_t lM = M / p_.mL * p_.mL;
798  vcl_size_t lN = N / p_.nL * p_.nL;
799  vcl_size_t lK = K / p_.kL * p_.kL;
800 
801 
802  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, 0, lN, false), beta_value, programs, kernel_prefix, 0);
803  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, 0, lN, false), (NumericT)1, programs, kernel_prefix, 1);
804 
805  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, lN, N, false), beta_value, programs, kernel_prefix, 1);
806  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, 0, lM, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, 0, lM, lN, N, false), (NumericT)1, programs, kernel_prefix, 1);
807 
808  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, 0, lN, false), beta_value, programs, kernel_prefix, 1);
809  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, 0, lN, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, 0, lN, false), (NumericT)1, programs, kernel_prefix, 1);
810 
811  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, 0, lK, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, 0, lK, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, lN, N, false), beta_value, programs, kernel_prefix, 1);
812  enqueue_block(statement, A, B, C, beta, create_slice<NumericT>(ptr_matrix, Acopy, lM, M, lK, K, swap_A), create_slice<NumericT>(ptr_matrix, Bcopy, lK, K, lN, N, swap_B), create_slice<NumericT>(ptr_matrix, Ccopy, lM, M, lN, N, false), (NumericT)1, programs, kernel_prefix, 1);
813  }
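// ---------------------------------------------------------------------------
// Illustration (editor's sketch): the eight launches above decompose the
// product by rounding M, N, K down to multiples of mL, nL, kL (lM, lN, lK):
//
//              N
//        <--lN--><-N-lN->
//    ^   +-------+------+   Each C block is updated twice, once for the K-range
//    |   | body  | right|   [0,lK) and once for [lK,K). Only "body" x [0,lK)
//   lM   |       | edge |   uses the fully unrolled kernel (id 0); every other
//    |   +-------+------+   piece goes through the bounds-checked fallback
//    v   |bottom |corner|   kernel (id 1). The first K-slice of a block applies
//  M-lM  | edge  |      |   the caller's beta, the second accumulates with
//        +-------+------+   beta = 1. Unit strides, simd-aligned starts and
//                           sufficiently large sizes are required for the fast
//                           path; otherwise everything runs through a single
//                           fallback launch (the early return above).
// ---------------------------------------------------------------------------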
814 
815 public:
816  matrix_product_template(matrix_product_template::parameters_type const & parameters, char A_trans, char B_trans) : template_base_impl<matrix_product_template, matrix_product_parameters>(parameters, BIND_ALL_UNIQUE), A_trans_(A_trans), B_trans_(B_trans){ }
817 
818  virtual void enqueue(std::string const & kernel_prefix, std::vector<lazy_program_compiler> & programs, statements_container const & statements)
819  {
820  using namespace device_specific::utils;
821  using namespace tree_parsing;
822 
823  scheduler::statement const & st = statements.data().front();
824  bool A_trans, B_trans;
825  vcl_size_t C_idx=0, A_idx=0, B_idx=0, alpha_idx=0, beta_idx = 0;
826  leaf_t C_leaf=LHS_NODE_TYPE, A_leaf=LHS_NODE_TYPE, B_leaf=LHS_NODE_TYPE, alpha_leaf=LHS_NODE_TYPE, beta_leaf=LHS_NODE_TYPE;
827  parse(st, C_idx, C_leaf, alpha_idx, alpha_leaf, A_idx, A_leaf, A_trans, B_idx, B_leaf, B_trans, beta_idx, beta_leaf);
828 
829  scheduler::statement stcopy = st;
830  scheduler::lhs_rhs_element& A = utils::lhs_rhs_element(stcopy, A_idx, A_leaf);
831  scheduler::lhs_rhs_element& B = utils::lhs_rhs_element(stcopy, B_idx, B_leaf);
832  scheduler::lhs_rhs_element& C = utils::lhs_rhs_element(stcopy, C_idx, C_leaf);
833  scheduler::lhs_rhs_element& beta = utils::lhs_rhs_element(stcopy, beta_idx, beta_leaf);
834 
835 
836 
837 
838 
839 
840  if (C.numeric_type==scheduler::FLOAT_TYPE)
841  enqueue_impl<float>(&scheduler::lhs_rhs_element::matrix_float, stcopy, A, B, C, beta, beta.host_float, programs, kernel_prefix);
842  else if (C.numeric_type==scheduler::DOUBLE_TYPE)
843  enqueue_impl<double>(&scheduler::lhs_rhs_element::matrix_double, stcopy, A, B, C, beta, beta.host_double, programs, kernel_prefix);
844  else
845  throw generator_not_supported_exception("GEMM only supported for float/double");
846 
847  }
848 
849 private:
850  const char A_trans_;
851  const char B_trans_;
852 };
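// ---------------------------------------------------------------------------
// Illustration (hypothetical usage, not part of the original header): building
// a GEMM template for C = alpha*A*trans(B) + beta*C with the parameter set
// sketched earlier. The kernel prefix, lazy_program_compiler list and
// statements_container would normally be supplied by the device_specific
// execution machinery rather than constructed by hand.
//
//   using namespace viennacl::device_specific;
//
//   matrix_product_parameters params(4, 8, 16, 8, 8, 4, 8,
//                                    FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 16, 4);
//   matrix_product_template gemm(params, 'N', 'T');
//   // gemm.enqueue(kernel_prefix, programs, statements) then generates and
//   // launches the optimized kernel and, where needed, its fallback variant.
// ---------------------------------------------------------------------------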
853 
854 }
855 
856 }
857 
858 #endif