CUDNN Frontend API  8.3.0
cudnn_frontend_Operation.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20  * DEALINGS IN THE SOFTWARE.
21  */
22 
23 #pragma once
24 
25 #include <algorithm>
26 #include <array>
27 #include <functional>
28 #include <memory>
29 #include <sstream>
30 #include <utility>
31 
32 #include <cudnn.h>
33 #include <cudnn_backend.h>
34 
39 #include "cudnn_frontend_Tensor.h"
40 #include "cudnn_frontend_utils.h"
41 
42 namespace cudnn_frontend {
68  public:
69  friend class OperationBuilder_v8;
70  std::string
71  describe() const override {
72  std::stringstream ss;
73  ss << "CUDNN_BACKEND_OPERATION :"
74  << " OpMode: " << std::to_string(op_mode);
75  ss << std::hex << " X " << xdesc;
76  ss << std::hex << " Y " << ydesc;
77  ss << std::hex << " W " << wdesc;
78  ss << std::hex << " B " << bdesc;
79  ss << std::hex << " DW " << dwdesc;
80  ss << std::hex << " DY " << dydesc;
81  ss << std::hex << " DX " << dxdesc;
82  ss << std::hex << " C " << cdesc;
83  ss << std::hex << " A Mtrix " << amatdesc;
84  ss << std::hex << " B Mtrix " << bmatdesc;
85  ss << std::hex << " C Mtrix " << cmatdesc;
86  ss << std::hex << " P " << pwdesc;
87  ss << std::hex << " MatMul " << matmuldesc;
88  ss << std::hex << " Reduction " << reductiondesc;
89  ss << std::dec << " alphabetaType " << alphabetaType;
90  ss << " Alpha: " << alpha_s << " " << alpha_d;
91  ss << " Alpha2: " << alpha2_s << " " << alpha2_d;
92  ss << " Beta: " << beta_s << " " << beta_d;
93  return ss.str();
94  }
95 
96  Operation_v8(Operation_v8 &&from) = default;
97  Operation_v8 &
98  operator= (Operation_v8 &&from) = default;
99 
102  return (op_mode == CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) ? cmatdesc : ydesc;
103  }
104 
105  std::string const &
106  getTag() const {
107  return operationTag;
108  }
109 
112  return feature_vector;
113  }
114 
115  ~Operation_v8() = default;
116 
117  private:
118  Operation_v8() = default;
119  Operation_v8(Operation_v8 const &) = delete;
120  Operation_v8 &
121  operator=(Operation_v8 const &) = delete;
122 
123  cudnnBackendDescriptorType_t op_mode = CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR;
124 
139 
140  cudnnBackendAttributeType_t alphabetaType = CUDNN_TYPE_FLOAT;
141  float alpha_s = 1.0f, beta_s = .0f, alpha2_s = 1.0f;
142  double alpha_d = 1.0, beta_d = 0.0, alpha2_d = 1.0;
143  int64_t pointwise_port_count = -1;
144  cudnnPointwiseMode_t pointwise_mode;
147  bool is_pointwise_math_op = false;
148  std::string operationTag;
150 };
151 
155 
157  private:
159  bool is_convolution_op = false;
160  bool is_pointwise_op = false;
161  bool is_matmul_op = false;
162  bool is_reduction_op = false;
163 
164  using Message_t = const char *;
165 
166  int64_t xTensor_dimA[CUDNN_DIM_MAX + 1];
167  int64_t xTensor_strA[CUDNN_DIM_MAX + 1];
168  int64_t wTensor_dimA[CUDNN_DIM_MAX + 1];
169  int64_t wTensor_strA[CUDNN_DIM_MAX + 1];
170  int64_t yTensor_dimA[CUDNN_DIM_MAX + 1];
171  int64_t yTensor_strA[CUDNN_DIM_MAX + 1];
172 
173  bool is2D = true;
174 
175  int64_t conv_padding [CUDNN_DIM_MAX];
176  int64_t conv_dilation[CUDNN_DIM_MAX];
177  int64_t conv_stride [CUDNN_DIM_MAX];
178  int64_t mode;
179  int64_t xType, yType, wType, cType /* compute_precision */;
180 
181  int64_t tensor_dims = 0;
182 
183  Operation_v8 &&
185  m_operation.operationTag = "Reduction";
186  auto status = CUDNN_STATUS_SUCCESS;
187  if ((cudnnGetVersion() / 100) == 81) { // workaround for cudnn 8.1
188  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
189  CUDNN_ATTR_REDUCTION_OPERATOR,
190  CUDNN_TYPE_BACKEND_DESCRIPTOR,
191  1,
192  &(m_operation.reductiondesc->get_backend_descriptor()));
193  } else {
194  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
195  CUDNN_ATTR_OPERATION_REDUCTION_DESC,
196  CUDNN_TYPE_BACKEND_DESCRIPTOR,
197  1,
198  &(m_operation.reductiondesc->get_backend_descriptor()));
199  }
200  if (status != CUDNN_STATUS_SUCCESS) {
202  &m_operation,
203  status,
204  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_REDUCTION_DESC Failed");
205  return std::move(m_operation);
206  }
207  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
208  CUDNN_ATTR_OPERATION_REDUCTION_XDESC,
209  CUDNN_TYPE_BACKEND_DESCRIPTOR,
210  1,
211  &(m_operation.xdesc->get_backend_descriptor()));
212  if (status != CUDNN_STATUS_SUCCESS) {
214  &m_operation,
215  status,
216  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_REDUCTION_XDESC Failed");
217  return std::move(m_operation);
218  }
219  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
220  CUDNN_ATTR_OPERATION_REDUCTION_YDESC,
221  CUDNN_TYPE_BACKEND_DESCRIPTOR,
222  1,
223  &(m_operation.ydesc->get_backend_descriptor()));
224  if (status != CUDNN_STATUS_SUCCESS) {
226  &m_operation,
227  status,
228  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_REDUCTION_YDESC Failed");
229  return std::move(m_operation);
230  }
231  status = cudnnBackendFinalize(m_operation.pointer->get_backend_descriptor());
232  if (status != CUDNN_STATUS_SUCCESS) {
233  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed");
234  return std::move(m_operation);
235  }
236  return std::move(m_operation);
237  }
238 
239  Operation_v8 &&
241  m_operation.operationTag = "Matmul";
242  auto status = CUDNN_STATUS_SUCCESS;
243  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
244  CUDNN_ATTR_OPERATION_MATMUL_ADESC,
245  CUDNN_TYPE_BACKEND_DESCRIPTOR,
246  1,
247  &(m_operation.amatdesc->get_backend_descriptor()));
248  if (status != CUDNN_STATUS_SUCCESS) {
250  &m_operation,
251  status,
252  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_ADESC Failed");
253  return std::move(m_operation);
254  }
255  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
256  CUDNN_ATTR_OPERATION_MATMUL_BDESC,
257  CUDNN_TYPE_BACKEND_DESCRIPTOR,
258  1,
259  &(m_operation.bmatdesc->get_backend_descriptor()));
260  if (status != CUDNN_STATUS_SUCCESS) {
262  &m_operation,
263  status,
264  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_BDESC Failed");
265  return std::move(m_operation);
266  }
267  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
268  CUDNN_ATTR_OPERATION_MATMUL_CDESC,
269  CUDNN_TYPE_BACKEND_DESCRIPTOR,
270  1,
271  &(m_operation.cmatdesc->get_backend_descriptor()));
272  if (status != CUDNN_STATUS_SUCCESS) {
274  &m_operation,
275  status,
276  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_CDESC Failed");
277  return std::move(m_operation);
278  }
279  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
280  CUDNN_ATTR_OPERATION_MATMUL_DESC,
281  CUDNN_TYPE_BACKEND_DESCRIPTOR,
282  1,
283  &(m_operation.matmuldesc->get_backend_descriptor()));
284  if (status != CUDNN_STATUS_SUCCESS) {
286  &m_operation,
287  status,
288  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_MATMUL_DESC Failed");
289  return std::move(m_operation);
290  }
291  status = cudnnBackendFinalize(m_operation.pointer->get_backend_descriptor());
292  if (status != CUDNN_STATUS_SUCCESS) {
293  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed");
294  return std::move(m_operation);
295  }
296  return std::move(m_operation);
297  }
298 
299  Operation_v8 &&
301  auto status = CUDNN_STATUS_SUCCESS;
302 
303  switch (m_operation.pointwise_mode) {
304  case CUDNN_POINTWISE_ADD:
305  m_operation.operationTag = "Add";
306  break;
307  case CUDNN_POINTWISE_MUL:
308  m_operation.operationTag = "Mul";
309  break;
310 #if (CUDNN_VERSION >= 8300)
311  case CUDNN_POINTWISE_DIV:
312  m_operation.operationTag = "Div";
313  break;
314  case CUDNN_POINTWISE_ADD_SQUARE:
315  m_operation.operationTag = "AddSquare";
316  break;
317  case CUDNN_POINTWISE_EXP:
318  m_operation.operationTag = "Exp";
319  break;
320  case CUDNN_POINTWISE_SUB:
321  m_operation.operationTag = "Sub";
322  break;
323  case CUDNN_POINTWISE_CMP_EQ:
324  m_operation.operationTag = "CmpEq";
325  break;
326  case CUDNN_POINTWISE_CMP_NEQ:
327  m_operation.operationTag = "CmpNeq";
328  break;
329  case CUDNN_POINTWISE_CMP_GT:
330  m_operation.operationTag = "CmpGT";
331  break;
332  case CUDNN_POINTWISE_CMP_GE:
333  m_operation.operationTag = "CmpGE";
334  break;
335  case CUDNN_POINTWISE_CMP_LT:
336  m_operation.operationTag = "CmpLT";
337  break;
338  case CUDNN_POINTWISE_CMP_LE:
339  m_operation.operationTag = "CmpLE";
340  break;
341  case CUDNN_POINTWISE_LOGICAL_OR:
342  m_operation.operationTag = "LogicalOr";
343  break;
344  case CUDNN_POINTWISE_LOGICAL_AND:
345  m_operation.operationTag = "LogicalAnd";
346  break;
347  case CUDNN_POINTWISE_LOGICAL_NOT:
348  m_operation.operationTag = "LogicalNot";
349  break;
350  case CUDNN_POINTWISE_LOG:
351  m_operation.operationTag = "Log";
352  break;
353  case CUDNN_POINTWISE_NEG:
354  m_operation.operationTag = "Neg";
355  break;
356  case CUDNN_POINTWISE_MOD:
357  m_operation.operationTag = "Mod";
358  break;
359  case CUDNN_POINTWISE_POW:
360  m_operation.operationTag = "Pow";
361  break;
362  case CUDNN_POINTWISE_ABS:
363  m_operation.operationTag = "Abs";
364  break;
365  case CUDNN_POINTWISE_CEIL:
366  m_operation.operationTag = "Ceil";
367  break;
368  case CUDNN_POINTWISE_FLOOR:
369  m_operation.operationTag = "Floor";
370  break;
371  case CUDNN_POINTWISE_SIN:
372  m_operation.operationTag = "Sine";
373  break;
374  case CUDNN_POINTWISE_COS:
375  m_operation.operationTag = "Cosine";
376  break;
377  case CUDNN_POINTWISE_TAN:
378  m_operation.operationTag = "Tan";
379  break;
380  case CUDNN_POINTWISE_RSQRT:
381  m_operation.operationTag = "RSqrt";
382  break;
383 #endif
384  case CUDNN_POINTWISE_MIN:
385  m_operation.operationTag = "Min";
386  break;
387  case CUDNN_POINTWISE_MAX:
388  m_operation.operationTag = "Max";
389  break;
390  case CUDNN_POINTWISE_SQRT:
391  m_operation.operationTag = "Sqrt";
392  break;
393  case CUDNN_POINTWISE_RELU_FWD:
394  m_operation.operationTag = "ReluFwd";
395  break;
396  case CUDNN_POINTWISE_TANH_FWD:
397  m_operation.operationTag = "TanhFwd";
398  break;
399  case CUDNN_POINTWISE_SIGMOID_FWD:
400  m_operation.operationTag = "SigmoidFwd";
401  break;
402  case CUDNN_POINTWISE_ELU_FWD:
403  m_operation.operationTag = "EluFwd";
404  break;
405  case CUDNN_POINTWISE_GELU_FWD:
406  m_operation.operationTag = "GeluFwd";
407  break;
408  case CUDNN_POINTWISE_SOFTPLUS_FWD:
409  m_operation.operationTag = "SoftplusFwd";
410  break;
411  case CUDNN_POINTWISE_SWISH_FWD:
412  m_operation.operationTag = "SwishFwd";
413  break;
414  case CUDNN_POINTWISE_RELU_BWD:
415  m_operation.operationTag = "ReluBwd";
416  break;
417  case CUDNN_POINTWISE_TANH_BWD:
418  m_operation.operationTag = "TanhBwd";
419  break;
420  case CUDNN_POINTWISE_SIGMOID_BWD:
421  m_operation.operationTag = "SigmoidBwd";
422  break;
423  case CUDNN_POINTWISE_ELU_BWD:
424  m_operation.operationTag = "EluBwd";
425  break;
426  case CUDNN_POINTWISE_GELU_BWD:
427  m_operation.operationTag = "GeluBwd";
428  break;
429  case CUDNN_POINTWISE_SOFTPLUS_BWD:
430  m_operation.operationTag = "SoftplusBwd";
431  break;
432  case CUDNN_POINTWISE_SWISH_BWD:
433  m_operation.operationTag = "SwishBwd";
434  break;
435  }
436 
437  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
438  CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR,
439  CUDNN_TYPE_BACKEND_DESCRIPTOR,
440  1,
441  &(m_operation.pwdesc->get_backend_descriptor()));
442  if (status != CUDNN_STATUS_SUCCESS) {
444  &m_operation,
445  status,
446  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_PW_DESCRIPTOR Failed");
447  return std::move(m_operation);
448  }
449 
450  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
451  CUDNN_ATTR_OPERATION_POINTWISE_XDESC,
452  CUDNN_TYPE_BACKEND_DESCRIPTOR,
453  1,
454  &(m_operation.xdesc->get_backend_descriptor()));
455  if (status != CUDNN_STATUS_SUCCESS) {
457  &m_operation,
458  status,
459  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_XDESC Failed");
460  return std::move(m_operation);
461  }
462 
463  if (!m_operation.is_pointwise_activation_bwd_op) {
464  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
465  CUDNN_ATTR_OPERATION_POINTWISE_YDESC,
466  CUDNN_TYPE_BACKEND_DESCRIPTOR,
467  1,
468  &(m_operation.ydesc->get_backend_descriptor()));
469  if (status != CUDNN_STATUS_SUCCESS) {
471  &m_operation,
472  status,
473  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_YDESC Failed");
474  return std::move(m_operation);
475  }
476  } else {
477  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
478  CUDNN_ATTR_OPERATION_POINTWISE_DYDESC,
479  CUDNN_TYPE_BACKEND_DESCRIPTOR,
480  1,
481  &(m_operation.dydesc->get_backend_descriptor()));
482  if (status != CUDNN_STATUS_SUCCESS) {
484  &m_operation,
485  status,
486  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_DYDESC Failed");
487  return std::move(m_operation);
488  }
489 
490  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
491  CUDNN_ATTR_OPERATION_POINTWISE_DXDESC,
492  CUDNN_TYPE_BACKEND_DESCRIPTOR,
493  1,
494  &(m_operation.dxdesc->get_backend_descriptor()));
495  if (status != CUDNN_STATUS_SUCCESS) {
497  &m_operation,
498  status,
499  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_DXDESC Failed");
500  return std::move(m_operation);
501  }
502  }
503 
504  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
505  : static_cast<void *>(&m_operation.alpha_d));
506  void *alpha2 = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha2_s)
507  : static_cast<void *>(&m_operation.alpha2_d));
508  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
509  CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1,
510  m_operation.alphabetaType,
511  1,
512  alpha);
513  if (status != CUDNN_STATUS_SUCCESS) {
515  &m_operation,
516  status,
517  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_ALPHA1 Failed");
518  return std::move(m_operation);
519  }
520  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
521  CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2,
522  m_operation.alphabetaType,
523  1,
524  alpha2);
525  if (status != CUDNN_STATUS_SUCCESS) {
527  &m_operation,
528  status,
529  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_ALPHA2 Failed");
530  return std::move(m_operation);
531  }
532 
533  if (m_operation.pointwise_port_count == 3 && !m_operation.is_pointwise_activation_bwd_op) {
534  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
535  CUDNN_ATTR_OPERATION_POINTWISE_BDESC,
536  CUDNN_TYPE_BACKEND_DESCRIPTOR,
537  1,
538  &(m_operation.bdesc->get_backend_descriptor()));
539  if (status != CUDNN_STATUS_SUCCESS) {
541  &m_operation,
542  status,
543  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_POINTWISE_BDESC Failed");
544  return std::move(m_operation);
545  }
546  }
547  status = cudnnBackendFinalize(m_operation.pointer->get_backend_descriptor());
548  if (status != CUDNN_STATUS_SUCCESS) {
549  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed");
550  return std::move(m_operation);
551  }
552  return std::move(m_operation);
553  }
554 
555  Operation_v8 &&
557  m_operation.operationTag = "ConvBwdData";
558 
559  auto status = CUDNN_STATUS_SUCCESS;
560 
561  auto dxdesc_ = m_operation.dxdesc != nullptr ? m_operation.dxdesc : m_operation.xdesc;
562  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
563  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX,
564  CUDNN_TYPE_BACKEND_DESCRIPTOR,
565  1,
566  &(dxdesc_->get_backend_descriptor()));
567  if (status != CUDNN_STATUS_SUCCESS) {
569  &m_operation,
570  status,
571  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX Failed");
572  return std::move(m_operation);
573  }
574 
575  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
576  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W,
577  CUDNN_TYPE_BACKEND_DESCRIPTOR,
578  1,
579  &(m_operation.wdesc->get_backend_descriptor()));
580  if (status != CUDNN_STATUS_SUCCESS) {
582  &m_operation,
583  status,
584  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W Failed");
585  return std::move(m_operation);
586  }
587 
588  auto dydesc_ = m_operation.dydesc != nullptr ? m_operation.dydesc : m_operation.ydesc;
589  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
590  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY,
591  CUDNN_TYPE_BACKEND_DESCRIPTOR,
592  1,
593  &(dydesc_->get_backend_descriptor()));
594  if (status != CUDNN_STATUS_SUCCESS) {
596  &m_operation,
597  status,
598  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY Failed");
599  return std::move(m_operation);
600  }
601 
602  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
603  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC,
604  CUDNN_TYPE_BACKEND_DESCRIPTOR,
605  1,
606  &(m_operation.cdesc->get_backend_descriptor()));
607  if (status != CUDNN_STATUS_SUCCESS) {
609  &m_operation,
610  status,
611  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC Failed");
612  return std::move(m_operation);
613  }
614 
615  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
616  : static_cast<void *>(&m_operation.alpha_d));
617  void *beta = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.beta_s)
618  : static_cast<void *>(&m_operation.beta_d));
619  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
620  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA,
621  m_operation.alphabetaType,
622  1,
623  alpha);
624  if (status != CUDNN_STATUS_SUCCESS) {
626  &m_operation,
627  status,
628  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA Failed");
629  return std::move(m_operation);
630  }
631  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
632  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA,
633  m_operation.alphabetaType,
634  1,
635  beta);
636  if (status != CUDNN_STATUS_SUCCESS) {
638  &m_operation,
639  status,
640  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA Failed");
641  return std::move(m_operation);
642  }
643 
644  status = cudnnBackendFinalize(m_operation.pointer->get_backend_descriptor());
645  if (status != CUDNN_STATUS_SUCCESS) {
646  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed");
647  return std::move(m_operation);
648  }
649  getLogger() << "Extracting the feature vector" << std::endl;
650  extract_feature_vector(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
651  return std::move(m_operation);
652  }
653 
654  Operation_v8 &&
656  m_operation.operationTag = "ConvBwdFilter";
657 
658  auto status = CUDNN_STATUS_SUCCESS;
659 
660  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
661  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X,
662  CUDNN_TYPE_BACKEND_DESCRIPTOR,
663  1,
664  &(m_operation.xdesc->get_backend_descriptor()));
665  if (status != CUDNN_STATUS_SUCCESS) {
667  &m_operation,
668  status,
669  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X Failed");
670  return std::move(m_operation);
671  }
672 
673  auto dwdesc_ = m_operation.dwdesc != nullptr ? m_operation.dwdesc : m_operation.wdesc;
674  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
675  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW,
676  CUDNN_TYPE_BACKEND_DESCRIPTOR,
677  1,
678  &(dwdesc_->get_backend_descriptor()));
679  if (status != CUDNN_STATUS_SUCCESS) {
681  &m_operation,
682  status,
683  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW Failed");
684  return std::move(m_operation);
685  }
686 
687  auto dydesc_ = m_operation.dydesc != nullptr ? m_operation.dydesc : m_operation.ydesc;
688  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
689  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY,
690  CUDNN_TYPE_BACKEND_DESCRIPTOR,
691  1,
692  &(dydesc_->get_backend_descriptor()));
693  if (status != CUDNN_STATUS_SUCCESS) {
695  &m_operation,
696  status,
697  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY Failed");
698  return std::move(m_operation);
699  }
700 
701  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
702  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC,
703  CUDNN_TYPE_BACKEND_DESCRIPTOR,
704  1,
705  &(m_operation.cdesc->get_backend_descriptor()));
706  if (status != CUDNN_STATUS_SUCCESS) {
707  set_error_and_throw_exception(&m_operation,
708  status,
709  "CUDNN_BACKEND_OPERATION: SetAttribute "
710  "CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC Failed");
711  return std::move(m_operation);
712  }
713  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
714  : static_cast<void *>(&m_operation.alpha_d));
715  void *beta = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.beta_s)
716  : static_cast<void *>(&m_operation.beta_d));
717  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
718  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA,
719  m_operation.alphabetaType,
720  1,
721  alpha);
722  if (status != CUDNN_STATUS_SUCCESS) {
724  &m_operation,
725  status,
726  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA Failed");
727  return std::move(m_operation);
728  }
729  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
730  CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA,
731  m_operation.alphabetaType,
732  1,
733  beta);
734  if (status != CUDNN_STATUS_SUCCESS) {
736  &m_operation,
737  status,
738  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA Failed");
739  return std::move(m_operation);
740  }
741 
742  status = cudnnBackendFinalize(m_operation.pointer->get_backend_descriptor());
743  if (status != CUDNN_STATUS_SUCCESS) {
744  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed");
745  return std::move(m_operation);
746  }
747  getLogger() << "Extracting the feature vector" << std::endl;
748  extract_feature_vector(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
749  return std::move(m_operation);
750  }
751 
752  Operation_v8 &&
754  m_operation.operationTag = "ConvFwd";
755 
756  auto status = CUDNN_STATUS_SUCCESS;
757 
758  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
759  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X,
760  CUDNN_TYPE_BACKEND_DESCRIPTOR,
761  1,
762  &(m_operation.xdesc->get_backend_descriptor()));
763  if (status != CUDNN_STATUS_SUCCESS) {
765  &m_operation,
766  status,
767  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X Failed");
768  return std::move(m_operation);
769  }
770  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
771  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W,
772  CUDNN_TYPE_BACKEND_DESCRIPTOR,
773  1,
774  &(m_operation.wdesc->get_backend_descriptor()));
775  if (status != CUDNN_STATUS_SUCCESS) {
777  &m_operation,
778  status,
779  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W Failed");
780  return std::move(m_operation);
781  }
782  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
783  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y,
784  CUDNN_TYPE_BACKEND_DESCRIPTOR,
785  1,
786  &(m_operation.ydesc->get_backend_descriptor()));
787  if (status != CUDNN_STATUS_SUCCESS) {
789  &m_operation,
790  status,
791  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y Failed");
792  return std::move(m_operation);
793  }
794  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
795  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC,
796  CUDNN_TYPE_BACKEND_DESCRIPTOR,
797  1,
798  &(m_operation.cdesc->get_backend_descriptor()));
799  if (status != CUDNN_STATUS_SUCCESS) {
801  &m_operation,
802  status,
803  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC Failed");
804  return std::move(m_operation);
805  }
806  void *alpha = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.alpha_s)
807  : static_cast<void *>(&m_operation.alpha_d));
808  void *beta = (m_operation.alphabetaType == CUDNN_TYPE_FLOAT ? static_cast<void *>(&m_operation.beta_s)
809  : static_cast<void *>(&m_operation.beta_d));
810  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
811  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA,
812  m_operation.alphabetaType,
813  1,
814  alpha);
815  if (status != CUDNN_STATUS_SUCCESS) {
817  &m_operation,
818  status,
819  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA Failed");
820  return std::move(m_operation);
821  }
822  status = cudnnBackendSetAttribute(m_operation.pointer->get_backend_descriptor(),
823  CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA,
824  m_operation.alphabetaType,
825  1,
826  beta);
827  if (status != CUDNN_STATUS_SUCCESS) {
829  &m_operation,
830  status,
831  "CUDNN_BACKEND_OPERATION: SetAttribute CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA Failed");
832  return std::move(m_operation);
833  }
834  status = cudnnBackendFinalize(m_operation.pointer->get_backend_descriptor());
835  if (status != CUDNN_STATUS_SUCCESS) {
836  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnFinalize Failed");
837  return std::move(m_operation);
838  }
839 
840  getLogger() << "Extracting the feature vector" << std::endl;
841  extract_feature_vector(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR);
842  return std::move(m_operation);
843  }
844 
845  void extract_feature_vector(cudnnBackendDescriptorType_t op_type) {
847  m_operation.feature_vector.reserve(50);
848 
849  m_operation.feature_vector.push_back(op_type);
850  for (auto i = 0; i < tensor_dims; i++) {
851  m_operation.feature_vector.push_back(xTensor_dimA[i]); // n, c, (g), d, h , w
852  }
853  for (auto i = 0; i < tensor_dims; i++) {
854  m_operation.feature_vector.push_back(wTensor_dimA[i]); // n, c, (g), d, h , w
855  }
856  for (auto i = 0; i < tensor_dims; i++) {
857  m_operation.feature_vector.push_back(yTensor_dimA[i]); // n, c, (g), d, h , w
858  }
859  const int max_spatial_dim = 3;
860 
862  for (auto i = 0; i < max_spatial_dim; i++) {
863  if (i == 0 && is2D) {
864  m_operation.feature_vector.push_back(0);
865  } else {
866  m_operation.feature_vector.push_back(conv_padding[i]);
867  }
868  }
870  for (auto i = 0; i < max_spatial_dim; i++) {
871  if (i == 0 && is2D) {
872  m_operation.feature_vector.push_back(0);
873  } else {
874  m_operation.feature_vector.push_back(conv_dilation[i]);
875  }
876  }
878  for (auto i = 0; i < max_spatial_dim; i++) {
879  if (i == 0 && is2D) {
880  m_operation.feature_vector.push_back(0);
881  } else {
882  m_operation.feature_vector.push_back(conv_stride[i]);
883  }
884  }
885 
886  m_operation.feature_vector.push_back(xType);
887  m_operation.feature_vector.push_back(wType);
888  m_operation.feature_vector.push_back(yType);
889  m_operation.feature_vector.push_back(cType);
890  m_operation.feature_vector.push_back(mode);
891 
892  for (auto i = 0; i < tensor_dims; i++) {
893  m_operation.feature_vector.push_back(xTensor_strA[i]); // n, c, (g), d, h , w
894  }
895  for (auto i = 0; i < tensor_dims; i++) {
896  m_operation.feature_vector.push_back(wTensor_strA[i]); // n, c, (g), d, h , w
897  }
898  for (auto i = 0; i < tensor_dims; i++) {
899  m_operation.feature_vector.push_back(yTensor_strA[i]); // n, c, (g), d, h , w
900  }
901  }
902 
903  cudnnStatus_t
905  if (m_operation.matmuldesc == nullptr) {
906  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_DESC";
907  return CUDNN_STATUS_BAD_PARAM;
908  }
909  if (m_operation.amatdesc == nullptr) {
910  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_ADESC";
911  return CUDNN_STATUS_BAD_PARAM;
912  }
913  if (m_operation.bmatdesc == nullptr) {
914  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_BDESC";
915  return CUDNN_STATUS_BAD_PARAM;
916  }
917  if (m_operation.cmatdesc == nullptr) {
918  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_MATMUL_CDESC";
919  return CUDNN_STATUS_BAD_PARAM;
920  }
921  return CUDNN_STATUS_SUCCESS;
922  }
923 
924  cudnnStatus_t
926  if (m_operation.reductiondesc == nullptr) {
927  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_REDUCTION_DESC";
928  return CUDNN_STATUS_BAD_PARAM;
929  }
930  if (m_operation.xdesc == nullptr) {
931  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_REDUCTION_XDESC";
932  return CUDNN_STATUS_BAD_PARAM;
933  }
934  if (m_operation.ydesc == nullptr) {
935  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_REDUCTION_YDESC";
936  return CUDNN_STATUS_BAD_PARAM;
937  }
938  return CUDNN_STATUS_SUCCESS;
939  }
940 
941  cudnnStatus_t
943  if (m_operation.xdesc == nullptr) {
944  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_XDESC";
945  return CUDNN_STATUS_BAD_PARAM;
946  }
947  if (m_operation.is_pointwise_math_op) {
948  if (m_operation.pointwise_port_count == 3 && m_operation.bdesc == nullptr) {
949  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_BDESC";
950  return CUDNN_STATUS_BAD_PARAM;
951  }
952  if (m_operation.ydesc == nullptr) {
953  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_YDESC";
954  return CUDNN_STATUS_BAD_PARAM;
955  }
956  } else if (m_operation.is_pointwise_activation_fwd_op) {
957  if (m_operation.ydesc == nullptr) {
958  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_YDESC";
959  return CUDNN_STATUS_BAD_PARAM;
960  }
961  } else if (m_operation.is_pointwise_activation_bwd_op) {
962  if (m_operation.dydesc == nullptr) {
963  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_DYDESC";
964  return CUDNN_STATUS_BAD_PARAM;
965  }
966  if (m_operation.dxdesc == nullptr) {
967  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_POINTWISE_DXDESC";
968  return CUDNN_STATUS_BAD_PARAM;
969  }
970  } else {
971  msg = "CUDNN_BACKEND_OPERATION: Unsupported cudnn pointwise mode. Check and set CUDNN_POINTWISE_*";
972  return CUDNN_STATUS_BAD_PARAM;
973  }
974  return CUDNN_STATUS_SUCCESS;
975  }
976 
977  cudnnStatus_t
979  if (m_operation.cdesc == nullptr) {
980  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_CONV_DESC";
981  return CUDNN_STATUS_BAD_PARAM;
982  }
983  if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
984  if (m_operation.xdesc == nullptr) {
985  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_X";
986  return CUDNN_STATUS_BAD_PARAM;
987  }
988  if (m_operation.wdesc == nullptr) {
989  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_W";
990  return CUDNN_STATUS_BAD_PARAM;
991  }
992  if (m_operation.ydesc == nullptr) {
993  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_Y";
994  return CUDNN_STATUS_BAD_PARAM;
995  }
996 
997  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
998  if (m_operation.ydesc != nullptr && m_operation.dydesc != nullptr) {
999  msg = "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set only one of setyDesc() or setdyDesc()";
1000  return CUDNN_STATUS_BAD_PARAM;
1001  }
1002  if (m_operation.ydesc == nullptr && m_operation.dydesc == nullptr) {
1003  msg = "CUDNN_BACKEND_OPERATION: Choose and Set one of setyDesc() or setdyDesc()";
1004  return CUDNN_STATUS_BAD_PARAM;
1005  }
1006  if (m_operation.xdesc == nullptr) {
1007  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_X";
1008  return CUDNN_STATUS_BAD_PARAM;
1009  }
1010  if (m_operation.wdesc != nullptr && m_operation.dwdesc != nullptr) {
1011  msg = "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set only one of setwDesc() or setdwDesc()";
1012  return CUDNN_STATUS_BAD_PARAM;
1013  }
1014  if (m_operation.wdesc == nullptr && m_operation.dwdesc == nullptr) {
1015  msg = "CUDNN_BACKEND_OPERATION: Choose and Set one of setwDesc() or setdwDesc()";
1016  return CUDNN_STATUS_BAD_PARAM;
1017  }
1018  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
1019  if (m_operation.ydesc != nullptr && m_operation.dydesc != nullptr) {
1020  msg = "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set only one of setyDesc() or setdyDesc()";
1021  return CUDNN_STATUS_BAD_PARAM;
1022  }
1023  if (m_operation.ydesc == nullptr && m_operation.dydesc == nullptr) {
1024  msg = "CUDNN_BACKEND_OPERATION: Choose and Set one of setyDesc() or setdyDesc()";
1025  return CUDNN_STATUS_BAD_PARAM;
1026  }
1027  if (m_operation.wdesc == nullptr) {
1028  msg = "CUDNN_BACKEND_OPERATION: Check and Set the CUDNN_ATTR_OPERATION_CONVOLUTION_*_W";
1029  return CUDNN_STATUS_BAD_PARAM;
1030  }
1031  if (m_operation.xdesc != nullptr && m_operation.dxdesc != nullptr) {
1032  msg = "CUDNN_BACKEND_OPERATION: Ambiguous specification. Choose and Set only one of setxDesc() or setdxDesc()";
1033  return CUDNN_STATUS_BAD_PARAM;
1034  }
1035  if (m_operation.xdesc == nullptr && m_operation.dxdesc == nullptr) {
1036  msg = "CUDNN_BACKEND_OPERATION: Choose and Set one of setxDesc() or setdxDesc()";
1037  return CUDNN_STATUS_BAD_PARAM;
1038  }
1039  } else {
1040  msg = "CUDNN_BACKEND_OPERATION: Unsupported convolution operation. Check and set CUDNN_BACKEND_OPERATION_CONVOLUTION_*_DESCRIPTOR";
1041  return CUDNN_STATUS_BAD_PARAM;
1042  }
1043  return CUDNN_STATUS_SUCCESS;
1044  }
1045 
1046  void
1047  copy_dims_and_strides(const int64_t *from, int64_t *to) const {
1048  for (auto i = 0; i < CUDNN_DIM_MAX + 1; i++) {
1049  to[i] = from[i];
1050  }
1051  }
1052 
1053  public:
1058  auto
1061  m_operation.xdesc = raw_tensor;
1062  return *this;
1063  }
1064 
1065  auto
1067  m_operation.xdesc = tensor.get_desc();
1068  copy_dims_and_strides(tensor.getDimArray(), xTensor_dimA);
1069  copy_dims_and_strides(tensor.getStrideArray(), xTensor_strA);
1070  tensor_dims = tensor.getDimensionCount();
1071  xType = tensor.getDataType();
1072  return *this;
1073  }
1074  auto
1076  if (is_pointwise_op == false) {
1078  &m_operation,
1079  CUDNN_STATUS_BAD_PARAM,
1080  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Pointwise operation does not need bTensor");
1081  }
1082  m_operation.bdesc = tensor.get_desc();
1083  return *this;
1084  }
1085  auto
1087  m_operation.ydesc = tensor.get_desc();
1088  copy_dims_and_strides(tensor.getDimArray(), yTensor_dimA);
1089  copy_dims_and_strides(tensor.getStrideArray(), yTensor_strA);
1090  yType = tensor.getDataType();
1091  return *this;
1092  }
1093  auto
1095  if (is_convolution_op == false) {
1097  &m_operation,
1098  CUDNN_STATUS_BAD_PARAM,
1099  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Convolution operation does not need wTensor");
1100  }
1101  m_operation.wdesc = tensor.get_desc();
1102  copy_dims_and_strides(tensor.getDimArray(), wTensor_dimA);
1103  copy_dims_and_strides(tensor.getStrideArray(), wTensor_strA);
1104  wType = tensor.getDataType();
1105  return *this;
1106  }
1107 
1109  auto
1111  m_operation.dydesc = raw_tensor;
1112  return *this;
1113  }
1114  auto
1116  m_operation.dydesc = tensor.get_desc();
1117  copy_dims_and_strides(tensor.getDimArray(), yTensor_dimA);
1118  copy_dims_and_strides(tensor.getStrideArray(), yTensor_strA);
1119  yType = tensor.getDataType();
1120  return *this;
1121  }
1122  auto
1124  m_operation.dxdesc = tensor.get_desc();
1125  copy_dims_and_strides(tensor.getDimArray(), xTensor_dimA);
1126  copy_dims_and_strides(tensor.getStrideArray(), xTensor_strA);
1127  tensor_dims = tensor.getDimensionCount();
1128  xType = tensor.getDataType();
1129  return *this;
1130  }
1131  auto
1133  m_operation.dwdesc = tensor.get_desc();
1134  copy_dims_and_strides(tensor.getDimArray(), wTensor_dimA);
1135  copy_dims_and_strides(tensor.getStrideArray(), wTensor_strA);
1136  wType = tensor.getDataType();
1137  return *this;
1138  }
1139 
1140  auto
1142  if (is_convolution_op == false) {
1144  &m_operation,
1145  CUDNN_STATUS_BAD_PARAM,
1146  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Convolution operation does not need Convolution DESCRIPTOR");
1147  }
1148  m_operation.cdesc = conv.get_desc();
1149  if (conv.getComputePrecision() == CUDNN_DATA_DOUBLE) {
1150  m_operation.alphabetaType = CUDNN_TYPE_DOUBLE;
1151  }
1152  is2D = conv.getDimensionCount() == 2;
1153  copy_dims_and_strides(conv.getPadding(), conv_padding);
1154  copy_dims_and_strides(conv.getDilation(), conv_dilation);
1155  copy_dims_and_strides(conv.getStride(), conv_stride);
1156  cType = conv.getComputePrecision();
1157  mode = conv.getMathMode();
1158  return *this;
1159  }
1160  auto
1162  if (is_matmul_op == false) {
1164  &m_operation,
1165  CUDNN_STATUS_BAD_PARAM,
1166  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need a Matrix Tensor");
1167  }
1168  m_operation.amatdesc = tensor.get_desc();
1169  return *this;
1170  }
1171  auto
1173  if (is_matmul_op == false) {
1175  &m_operation,
1176  CUDNN_STATUS_BAD_PARAM,
1177  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need b Matrix Tensor");
1178  }
1179  m_operation.bmatdesc = tensor.get_desc();
1180  return *this;
1181  }
1182  auto
1184  if (is_matmul_op == false) {
1186  &m_operation,
1187  CUDNN_STATUS_BAD_PARAM,
1188  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need c Matrix Tensor");
1189  }
1190  m_operation.cmatdesc = tensor.get_desc();
1191  return *this;
1192  }
1193  auto
1195  if (is_matmul_op == false) {
1197  &m_operation,
1198  CUDNN_STATUS_BAD_PARAM,
1199  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Matmul operation does not need MATMUL DESCRIPTOR");
1200  }
1201  m_operation.matmuldesc = matmulDesc.get_desc();
1202  return *this;
1203  }
1204  auto
1206  if (is_reduction_op == false) {
1208  &m_operation,
1209  CUDNN_STATUS_BAD_PARAM,
1210  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Reduction operation does not need REDUCTION DESCRIPTOR");
1211  }
1212  m_operation.reductiondesc = reductionDesc.get_desc();
1213  return *this;
1214  }
1215  auto
1216  setpwDesc(PointWiseDesc_v8 const &pointWiseDesc) -> OperationBuilder_v8 & {
1217  if (is_pointwise_op == false) {
1219  &m_operation,
1220  CUDNN_STATUS_BAD_PARAM,
1221  "CUDNN_BACKEND_OPERATION_*_DESCRIPTOR: Non Pointwise operation does not need POINTWISE DESCRIPTOR");
1222  }
1223  m_operation.pwdesc = pointWiseDesc.get_desc();
1224  m_operation.pointwise_port_count = pointWiseDesc.getPortCount();
1225  m_operation.pointwise_mode = pointWiseDesc.getPointWiseMode();
1226 
1227  m_operation.is_pointwise_math_op = ((m_operation.pointwise_mode == CUDNN_POINTWISE_ADD) ||
1228  (m_operation.pointwise_mode == CUDNN_POINTWISE_MUL) ||
1229 #if (CUDNN_VERSION >= 8300)
1230  (m_operation.pointwise_mode == CUDNN_POINTWISE_DIV) ||
1231  (m_operation.pointwise_mode == CUDNN_POINTWISE_SUB) ||
1232  (m_operation.pointwise_mode == CUDNN_POINTWISE_ADD_SQUARE) ||
1233  (m_operation.pointwise_mode == CUDNN_POINTWISE_RSQRT) ||
1234  (m_operation.pointwise_mode == CUDNN_POINTWISE_SIN) ||
1235  (m_operation.pointwise_mode == CUDNN_POINTWISE_COS) ||
1236  (m_operation.pointwise_mode == CUDNN_POINTWISE_TAN) ||
1237  (m_operation.pointwise_mode == CUDNN_POINTWISE_LOGICAL_OR) ||
1238  (m_operation.pointwise_mode == CUDNN_POINTWISE_LOGICAL_AND) ||
1239  (m_operation.pointwise_mode == CUDNN_POINTWISE_LOGICAL_NOT) ||
1240  (m_operation.pointwise_mode == CUDNN_POINTWISE_CMP_EQ) ||
1241  (m_operation.pointwise_mode == CUDNN_POINTWISE_CMP_NEQ) ||
1242  (m_operation.pointwise_mode == CUDNN_POINTWISE_CMP_GT) ||
1243  (m_operation.pointwise_mode == CUDNN_POINTWISE_CMP_GE) ||
1244  (m_operation.pointwise_mode == CUDNN_POINTWISE_CMP_LT) ||
1245  (m_operation.pointwise_mode == CUDNN_POINTWISE_CMP_LE) ||
1246  (m_operation.pointwise_mode == CUDNN_POINTWISE_LOG) ||
1247  (m_operation.pointwise_mode == CUDNN_POINTWISE_NEG) ||
1248  (m_operation.pointwise_mode == CUDNN_POINTWISE_MOD) ||
1249  (m_operation.pointwise_mode == CUDNN_POINTWISE_POW) ||
1250  (m_operation.pointwise_mode == CUDNN_POINTWISE_ABS) ||
1251  (m_operation.pointwise_mode == CUDNN_POINTWISE_CEIL) ||
1252  (m_operation.pointwise_mode == CUDNN_POINTWISE_FLOOR) ||
1253 #endif
1254  (m_operation.pointwise_mode == CUDNN_POINTWISE_MIN) ||
1255  (m_operation.pointwise_mode == CUDNN_POINTWISE_MAX) ||
1256  (m_operation.pointwise_mode == CUDNN_POINTWISE_SQRT));
1257 
1258  m_operation.is_pointwise_activation_fwd_op = ((m_operation.pointwise_mode == CUDNN_POINTWISE_RELU_FWD) ||
1259  (m_operation.pointwise_mode == CUDNN_POINTWISE_TANH_FWD) ||
1260  (m_operation.pointwise_mode == CUDNN_POINTWISE_SIGMOID_FWD) ||
1261  (m_operation.pointwise_mode == CUDNN_POINTWISE_ELU_FWD) ||
1262  (m_operation.pointwise_mode == CUDNN_POINTWISE_GELU_FWD) ||
1263  (m_operation.pointwise_mode == CUDNN_POINTWISE_SOFTPLUS_FWD) ||
1264 #if (CUDNN_VERSION >= 8300)
1265  (m_operation.pointwise_mode == CUDNN_POINTWISE_EXP) ||
1266 #endif
1267  (m_operation.pointwise_mode == CUDNN_POINTWISE_SWISH_FWD));
1268 
1269  m_operation.is_pointwise_activation_bwd_op = ((m_operation.pointwise_mode == CUDNN_POINTWISE_RELU_BWD) ||
1270  (m_operation.pointwise_mode == CUDNN_POINTWISE_TANH_BWD) ||
1271  (m_operation.pointwise_mode == CUDNN_POINTWISE_SIGMOID_BWD) ||
1272  (m_operation.pointwise_mode == CUDNN_POINTWISE_ELU_BWD) ||
1273  (m_operation.pointwise_mode == CUDNN_POINTWISE_GELU_BWD) ||
1274  (m_operation.pointwise_mode == CUDNN_POINTWISE_SOFTPLUS_BWD) ||
1275  (m_operation.pointwise_mode == CUDNN_POINTWISE_SWISH_BWD));
1276 
1277  return *this;
1278  }
1279 
1280  auto
1281  setAlpha(float alpha) -> OperationBuilder_v8 & {
1282  m_operation.alpha_d = static_cast<double>(alpha);
1283  m_operation.alpha_s = alpha;
1284  return *this;
1285  }
1286  auto
1287  setAlpha(double alpha) -> OperationBuilder_v8 & {
1288  m_operation.alpha_s = static_cast<float>(alpha);
1289  m_operation.alpha_d = alpha;
1290  return *this;
1291  }
1292  auto
1293  setAlpha2(float alpha) -> OperationBuilder_v8 & {
1294  m_operation.alpha2_d = static_cast<double>(alpha);
1295  m_operation.alpha2_s = alpha;
1296  return *this;
1297  }
1298  auto
1299  setAlpha2(double alpha) -> OperationBuilder_v8 & {
1300  m_operation.alpha2_s = static_cast<float>(alpha);
1301  m_operation.alpha2_d = alpha;
1302  return *this;
1303  }
1304  auto
1305  setBeta(float beta) -> OperationBuilder_v8 & {
1306  m_operation.beta_d = static_cast<double>(beta);
1307  m_operation.beta_s = beta;
1308  return *this;
1309  }
1310  auto
1311  setBeta(double beta) -> OperationBuilder_v8 & {
1312  m_operation.beta_s = static_cast<float>(beta);
1313  m_operation.beta_d = beta;
1314  return *this;
1315  }
1316 
1317  OperationBuilder_v8(cudnnBackendDescriptorType_t mode) {
1318  m_operation.op_mode = mode;
1319  is_convolution_op = ((m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) ||
1320  (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) ||
1321  (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR));
1322 
1323  is_pointwise_op = (m_operation.op_mode == CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR);
1324  is_matmul_op = (m_operation.op_mode == CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR);
1325  is_reduction_op = (m_operation.op_mode == CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR);
1326  }
1329  Operation_v8 &&
1332  build() {
1333  if (m_operation.status != CUDNN_STATUS_SUCCESS) {
1335  &m_operation, m_operation.status, "CUDNN_BACKEND_OPERATION: Operation not initialized properly");
1336  return std::move(m_operation);
1337  }
1338 
1339  Message_t msg = nullptr;
1340  cudnnStatus_t status_ = CUDNN_STATUS_SUCCESS;
1341  if (is_convolution_op) {
1342  status_ = validate_convolution_op(msg);
1343  } else if (is_pointwise_op) {
1344  status_ = validate_pointwise_op(msg);
1345  } else if (is_matmul_op) {
1346  status_ = validate_matmul_op(msg);
1347  } else if (is_reduction_op) {
1348  status_ = validate_reduction_op(msg);
1349  } else {
1350  status_ = CUDNN_STATUS_BAD_PARAM;
1351  msg = "CUDNN_BACKEND_OPERATION_DESCRIPTOR: Unsupported cudnn backend descriptor type. Check and set CUDNN_BACKEND_OPERATION_*_DESCRIPTOR";
1352  }
1353  if (status_ != CUDNN_STATUS_SUCCESS) {
1354  set_error_and_throw_exception(&m_operation, status_, msg);
1355  return std::move(m_operation);
1356  }
1357 
1358  // Create the descriptor.
1359  auto status = m_operation.initialize_managed_backend_pointer(m_operation.op_mode);
1360  if (status != CUDNN_STATUS_SUCCESS) {
1361  set_error_and_throw_exception(&m_operation, status, "CUDNN_BACKEND_OPERATION: cudnnCreate Failed");
1362  return std::move(m_operation);
1363  }
1364 
1365  if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
1366  return build_conv_forward();
1367  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
1368  return build_conv_backward_filter();
1369  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
1370  return build_conv_backward_data();
1371  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) {
1372  return build_pointwise_op();
1373  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) {
1374  return build_matmul_op();
1375  } else if (m_operation.op_mode == CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) {
1376  return build_reduction_op();
1377  }
1378  getLogger() << "[cudnn_frontend] " << m_operation << std::endl;
1379  return std::move(m_operation);
1380  }
1381 };
1382 }
std::vector< int64_t > feature_vector_t
Detailed feature vector, generally built from the Tensor and Operation properties.
auto setcDesc(ConvDesc_v8 const &conv) -> OperationBuilder_v8 &
ConditionalStreamer & getLogger()
cudnnStatus_t initialize_managed_backend_pointer(cudnnBackendDescriptorType_t type)
Initializes the underlying managed descriptor.
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
NLOHMANN_BASIC_JSON_TPL_DECLARATION std::string to_string(const NLOHMANN_BASIC_JSON_TPL &j)
user-defined to_string function for JSON values
Definition: json.hpp:25855
auto setAlpha(float alpha) -> OperationBuilder_v8 &
cudnnStatus_t validate_matmul_op(Message_t &msg)
cudnnStatus_t validate_convolution_op(Message_t &msg)
auto setdxDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setwDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
cudnnStatus_t validate_reduction_op(Message_t &msg)
auto setbDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
void copy_dims_and_strides(const int64_t *from, int64_t *to) const
auto setaMatDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
void extract_feature_vector(cudnnBackendDescriptorType_t op_type)
auto setmatmulDesc(MatMulDesc_v8 const &matmulDesc) -> OperationBuilder_v8 &
cudnnBackendDescriptorType_t op_mode
auto setBeta(float beta) -> OperationBuilder_v8 &
auto setpwDesc(PointWiseDesc_v8 const &pointWiseDesc) -> OperationBuilder_v8 &
auto setAlpha2(float alpha) -> OperationBuilder_v8 &
auto setdwDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setreductionDesc(ReductionDesc_v8 const &reductionDesc) -> OperationBuilder_v8 &
auto setBeta(double beta) -> OperationBuilder_v8 &
auto setbMatDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
ManagedOpaqueDescriptor reductiondesc
std::shared_ptr< OpaqueBackendPointer > ManagedOpaqueDescriptor
auto setdyDesc(ManagedOpaqueDescriptor const &raw_tensor) -> OperationBuilder_v8 &
Will be deprecated; do not use.
std::string describe() const override
Return a string describing the backend Descriptor.
Operation_v8 & operator=(Operation_v8 &&from)=default
auto setdyDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
cudnnStatus_t validate_pointwise_op(Message_t &msg)
cudnnBackendAttributeType_t alphabetaType
auto setxDesc(ManagedOpaqueDescriptor const &raw_tensor) -> OperationBuilder_v8 &
Will be deprecated; do not use.
auto setcMatDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setyDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
auto setxDesc(Tensor_v8 const &tensor) -> OperationBuilder_v8 &
ManagedOpaqueDescriptor getOutputTensor()
auto setAlpha2(double alpha) -> OperationBuilder_v8 &
OperationBuilder_v8(cudnnBackendDescriptorType_t mode)
std::string const & getTag() const
auto setAlpha(double alpha) -> OperationBuilder_v8 &
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.
feature_vector_t getFeatureVector() const