ViennaCL - The Vienna Computing Library  1.7.0
Free open-source GPU-accelerated linear algebra and solver library.
execute.hpp
Go to the documentation of this file.
1 #ifndef VIENNACL_SCHEDULER_EXECUTE_HPP
2 #define VIENNACL_SCHEDULER_EXECUTE_HPP
3 
4 /* =========================================================================
5  Copyright (c) 2010-2015, Institute for Microelectronics,
6  Institute for Analysis and Scientific Computing,
7  TU Wien.
8  Portions of this software are copyright by UChicago Argonne, LLC.
9 
10  -----------------
11  ViennaCL - The Vienna Computing Library
12  -----------------
13 
14  Project Head: Karl Rupp rupp@iue.tuwien.ac.at
15 
16  (A list of authors and contributors can be found in the manual)
17 
18  License: MIT (X11), see file LICENSE in the base directory
19 ============================================================================= */
20 
21 
26 #include "viennacl/forwards.h"
28 
34 
35 namespace viennacl
36 {
37 namespace scheduler
38 {
39 namespace detail
40 {
42  void execute_composite(statement const & s, statement_node const & root_node)
43  {
44  statement::container_type const & expr = s.array();
45  viennacl::context ctx = extract_context(root_node);
46 
47  statement_node const & leaf = expr[root_node.rhs.node_index];
48 
49  if (leaf.op.type == OPERATION_BINARY_ADD_TYPE || leaf.op.type == OPERATION_BINARY_SUB_TYPE) // x = (y) +- (z) where y and z are either data objects or expressions
50  execute_axbx(s, root_node);
51  else if (leaf.op.type == OPERATION_BINARY_MULT_TYPE || leaf.op.type == OPERATION_BINARY_DIV_TYPE) // x = (y) * / alpha;
52  {
53  bool scalar_is_temporary = (leaf.rhs.type_family != SCALAR_TYPE_FAMILY);
54 
55  statement_node scalar_temp_node;
56  if (scalar_is_temporary)
57  {
58  lhs_rhs_element temp;
61  temp.numeric_type = root_node.lhs.numeric_type;
62  detail::new_element(scalar_temp_node.lhs, temp, ctx);
63 
64  scalar_temp_node.op.type_family = OPERATION_BINARY_TYPE_FAMILY;
65  scalar_temp_node.op.type = OPERATION_BINARY_ASSIGN_TYPE;
66 
67  scalar_temp_node.rhs.type_family = COMPOSITE_OPERATION_FAMILY;
68  scalar_temp_node.rhs.subtype = INVALID_SUBTYPE;
69  scalar_temp_node.rhs.numeric_type = INVALID_NUMERIC_TYPE;
70  scalar_temp_node.rhs.node_index = leaf.rhs.node_index;
71 
72  // work on subexpression:
73  // TODO: Catch exception, free temporary, then rethrow
74  execute_composite(s, scalar_temp_node);
75  }
76 
77  if (leaf.lhs.type_family == COMPOSITE_OPERATION_FAMILY) //(y) is an expression, so introduce a temporary z = (y):
78  {
79  statement_node new_root_y;
80 
81  new_root_y.lhs.type_family = root_node.lhs.type_family;
82  new_root_y.lhs.subtype = root_node.lhs.subtype;
83  new_root_y.lhs.numeric_type = root_node.lhs.numeric_type;
84  detail::new_element(new_root_y.lhs, root_node.lhs, ctx);
85 
88 
90  new_root_y.rhs.subtype = INVALID_SUBTYPE;
92  new_root_y.rhs.node_index = leaf.lhs.node_index;
93 
94  // work on subexpression:
95  // TODO: Catch exception, free temporary, then rethrow
96  execute_composite(s, new_root_y);
97 
98  // now compute x = z * / alpha:
99  lhs_rhs_element u = root_node.lhs;
100  lhs_rhs_element v = new_root_y.lhs;
101  lhs_rhs_element alpha = scalar_is_temporary ? scalar_temp_node.lhs : leaf.rhs;
102 
103  bool is_division = (leaf.op.type == OPERATION_BINARY_DIV_TYPE);
104  switch (root_node.op.type)
105  {
107  detail::ax(u,
108  v, alpha, 1, is_division, false);
109  break;
111  detail::axbx(u,
112  u, 1.0, 1, false, false,
113  v, alpha, 1, is_division, false);
114  break;
116  detail::axbx(u,
117  u, 1.0, 1, false, false,
118  v, alpha, 1, is_division, true);
119  break;
120  default:
121  throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
122  }
123 
124  detail::delete_element(new_root_y.lhs);
125  }
126  else if (leaf.lhs.type_family != COMPOSITE_OPERATION_FAMILY)
127  {
128  lhs_rhs_element u = root_node.lhs;
129  lhs_rhs_element v = leaf.lhs;
130  lhs_rhs_element alpha = scalar_is_temporary ? scalar_temp_node.lhs : leaf.rhs;
131 
132  bool is_division = (leaf.op.type == OPERATION_BINARY_DIV_TYPE);
133  switch (root_node.op.type)
134  {
136  detail::ax(u,
137  v, alpha, 1, is_division, false);
138  break;
140  detail::axbx(u,
141  u, 1.0, 1, false, false,
142  v, alpha, 1, is_division, false);
143  break;
145  detail::axbx(u,
146  u, 1.0, 1, false, false,
147  v, alpha, 1, is_division, true);
148  break;
149  default:
150  throw statement_not_supported_exception("Unsupported binary operator for vector operation in root note (should be =, +=, or -=)");
151  }
152  }
153  else
154  throw statement_not_supported_exception("Unsupported binary operator for OPERATION_BINARY_MULT_TYPE || OPERATION_BINARY_DIV_TYPE on leaf node.");
155 
156  // clean up
157  if (scalar_is_temporary)
158  detail::delete_element(scalar_temp_node.lhs);
159  }
160  else if ( leaf.op.type == OPERATION_BINARY_INNER_PROD_TYPE
164  || leaf.op.type == OPERATION_UNARY_MAX_TYPE
165  || leaf.op.type == OPERATION_UNARY_MIN_TYPE)
166  execute_scalar_assign_composite(s, root_node);
170  || leaf.op.type == OPERATION_BINARY_ELEMENT_POW_TYPE) // element-wise operations
171  execute_element_composite(s, root_node);
172  else if ( leaf.op.type == OPERATION_BINARY_MAT_VEC_PROD_TYPE
174  execute_matrix_prod(s, root_node);
175  else if ( leaf.op.type == OPERATION_UNARY_TRANS_TYPE)
176  {
177  if (root_node.op.type == OPERATION_BINARY_ASSIGN_TYPE)
178  assign_trans(root_node.lhs, leaf.lhs);
179  else // use temporary object:
180  {
181  statement_node new_root_y;
182 
183  new_root_y.lhs.type_family = root_node.lhs.type_family;
184  new_root_y.lhs.subtype = root_node.lhs.subtype;
185  new_root_y.lhs.numeric_type = root_node.lhs.numeric_type;
186  detail::new_element(new_root_y.lhs, root_node.lhs, ctx);
187 
189  new_root_y.op.type = OPERATION_BINARY_ASSIGN_TYPE;
190 
192  new_root_y.rhs.subtype = INVALID_SUBTYPE;
194  new_root_y.rhs.node_index = root_node.rhs.node_index;
195 
196  // work on subexpression:
197  // TODO: Catch exception, free temporary, then rethrow
198  execute_composite(s, new_root_y);
199 
200  // now compute x += temp or x -= temp:
201  lhs_rhs_element u = root_node.lhs;
202  lhs_rhs_element v = new_root_y.lhs;
203 
204  if (root_node.op.type == OPERATION_BINARY_INPLACE_ADD_TYPE)
205  {
206  detail::axbx(u,
207  u, 1.0, 1, false, false,
208  v, 1.0, 1, false, false);
209  }
210  else if (root_node.op.type == OPERATION_BINARY_INPLACE_SUB_TYPE)
211  {
212  detail::axbx(u,
213  u, 1.0, 1, false, false,
214  v, 1.0, 1, false, true);
215  }
216  else
217  throw statement_not_supported_exception("Unsupported binary operator for operation in root node (should be =, +=, or -=)");
218 
219  detail::delete_element(new_root_y.lhs);
220  }
221  }
222  else
223  throw statement_not_supported_exception("Unsupported binary operator");
224  }
225 
226 
228  inline void execute_single(statement const &, statement_node const & root_node)
229  {
230  lhs_rhs_element u = root_node.lhs;
231  lhs_rhs_element v = root_node.rhs;
232  switch (root_node.op.type)
233  {
235  detail::ax(u,
236  v, 1.0, 1, false, false);
237  break;
239  detail::axbx(u,
240  u, 1.0, 1, false, false,
241  v, 1.0, 1, false, false);
242  break;
244  detail::axbx(u,
245  u, 1.0, 1, false, false,
246  v, 1.0, 1, false, true);
247  break;
248  default:
249  throw statement_not_supported_exception("Unsupported binary operator for operation in root note (should be =, +=, or -=)");
250  }
251 
252  }
253 
254 
255  inline void execute_impl(statement const & s, statement_node const & root_node)
256  {
257  if ( root_node.lhs.type_family != SCALAR_TYPE_FAMILY
258  && root_node.lhs.type_family != VECTOR_TYPE_FAMILY
259  && root_node.lhs.type_family != MATRIX_TYPE_FAMILY)
260  throw statement_not_supported_exception("Unsupported lvalue encountered in head node.");
261 
262  switch (root_node.rhs.type_family)
263  {
265  execute_composite(s, root_node);
266  break;
267  case SCALAR_TYPE_FAMILY:
268  case VECTOR_TYPE_FAMILY:
269  case MATRIX_TYPE_FAMILY:
270  execute_single(s, root_node);
271  break;
272  default:
273  throw statement_not_supported_exception("Invalid rvalue encountered in vector assignment");
274  }
275 
276  }
277 }
278 
279 inline void execute(statement const & s)
280 {
281  // simply start execution from the root node:
282  detail::execute_impl(s, s.array()[s.root()]);
283 }
284 
285 
286 }
287 } //namespace viennacl
288 
289 #endif
290 
Deals with the execution of unary and binary element-wise operations.
void execute_matrix_prod(statement const &s, statement_node const &root_node)
viennacl::context extract_context(statement_node const &root_node)
Helper routine for extracting the context in which a statement is executed.
void assign_trans(lhs_rhs_element const &A, lhs_rhs_element const &B)
Scheduler unwrapper for A =/+=/-= trans(B)
Deals with matrix-vector and matrix-matrix products.
statement_node_subtype subtype
Definition: forwards.h:340
void execute_scalar_assign_composite(statement const &s, statement_node const &root_node)
Deals with x = RHS where RHS is a scalar-valued reduction such as an inner product, max, or min.
This file provides the forward declarations for the main types used within ViennaCL.
statement_node_type_family type_family
Definition: forwards.h:339
void execute(statement const &s)
Definition: execute.hpp:279
A class representing the 'data' for the LHS or RHS operand of the respective node.
Definition: forwards.h:337
container_type const & array() const
Definition: forwards.h:528
void execute_element_composite(statement const &s, statement_node const &root_node)
Deals with x = RHS where RHS is an element-wise operation (e.g. an element-wise power).
operation_node_type_family type_family
Definition: forwards.h:473
Deals with the execution of x = RHS; for a vector x and any compatible right hand side expression RHS...
void delete_element(lhs_rhs_element &elem)
Represents a generic 'context' similar to an OpenCL context, but is backend-agnostic and thus also su...
Definition: context.hpp:39
Main namespace in ViennaCL. Holds all the basic types such as vector, matrix, etc. and defines operations upon them.
Definition: cpu_ram.hpp:34
void execute_single(statement const &, statement_node const &root_node)
Deals with x = y for a scalar/vector/matrix x, y.
Definition: execute.hpp:228
std::vector< value_type > container_type
Definition: forwards.h:507
void execute_impl(statement const &s, statement_node const &root_node)
Definition: execute.hpp:255
void axbx(lhs_rhs_element &x1, lhs_rhs_element const &x2, ScalarType1 const &alpha, vcl_size_t len_alpha, bool reciprocal_alpha, bool flip_sign_alpha, lhs_rhs_element const &x3, ScalarType2 const &beta, vcl_size_t len_beta, bool reciprocal_beta, bool flip_sign_beta)
Wrapper for viennacl::linalg::avbv(), taking care of the argument unwrapping.
Definition: blas3.hpp:36
statement_node_numeric_type numeric_type
Definition: forwards.h:341
void execute_axbx(statement const &s, statement_node const &root_node)
Deals with x = (y) +- (z) where y and z are either data objects or expressions.
Provides the data structures for dealing with a single statement such as 'x = y + z;'.
Helper metafunction for checking whether the provided type is viennacl::op_div (for division) ...
Definition: predicate.hpp:466
void execute_composite(statement const &s, statement_node const &root_node)
Deals with x = RHS where RHS is an expression and x is either a scalar, a vector, or a matrix...
Definition: execute.hpp:42
void ax(lhs_rhs_element &x1, lhs_rhs_element const &x2, ScalarType1 const &alpha, vcl_size_t len_alpha, bool reciprocal_alpha, bool flip_sign_alpha)
Wrapper for viennacl::linalg::av(), taking care of the argument unwrapping.
void new_element(lhs_rhs_element &new_elem, lhs_rhs_element const &old_element, viennacl::context const &ctx)
operation_node_type type
Definition: forwards.h:474
size_type root() const
Definition: forwards.h:530
The main class for representing a statement such as x = inner_prod(y,z); at runtime.
Definition: forwards.h:502
Provides various utilities for implementing the execution of statements.
Main data structure for a node in the statement tree.
Definition: forwards.h:478
Exception for the case the scheduler is unable to deal with the operation.
Definition: forwards.h:38
Provides the data structures for dealing with statements of the type x = (y) +- (z) ...