CUDNN Frontend API  8.3.0
cudnn_frontend_VariantPack.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice shall be included in
12  * all copies or substantial portions of the Software.
13  *
14  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20  * DEALINGS IN THE SOFTWARE.
21  */
22 
23 #pragma once
24 
25 #include <algorithm>
26 #include <array>
27 #include <functional>
28 #include <memory>
29 #include <set>
30 #include <sstream>
31 #include <utility>
32 
33 #include <cudnn.h>
34 #include <cudnn_backend.h>
35 
36 #include "cudnn_frontend_utils.h"
37 
38 namespace cudnn_frontend {
39 
53  public:
54  friend class VariantPackBuilder_v8;
55  std::string
56  describe() const override {
57  std::stringstream ss;
58  ss << "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR :"
59  << " has " << num_ptrs << " data pointers";
60  return ss.str();
61  }
62 
63  VariantPack_v8(VariantPack_v8 &&from) = default;
65  operator=(VariantPack_v8 &&from) = default;
66 
67  ~VariantPack_v8() = default;
68 
69  private:
70  VariantPack_v8() = default;
71  VariantPack_v8(VariantPack_v8 const &) = delete;
73  operator=(VariantPack_v8 const &) = delete;
74 
75  void *workspace = nullptr;
76  void *data_pointers[10] = {nullptr};
77  int64_t uid[10] = {-1};
78  int64_t num_ptrs = -1;
79 };
80 
85  public:
90  auto
92  setDataPointers(int64_t num_ptr, void **ptrs) -> VariantPackBuilder_v8 & {
93  std::copy(ptrs, ptrs + num_ptr, m_variant_pack.data_pointers);
94  m_variant_pack.num_ptrs = num_ptr;
95  return *this;
96  }
98  auto
99  setUids(int64_t num_uids, int64_t *uid) -> VariantPackBuilder_v8 & {
100  std::copy(uid, uid + num_uids, m_variant_pack.uid);
101  return *this;
102  }
104  auto
105  setDataPointers(std::set<std::pair<uint64_t, void *>> const &data_pointers) -> VariantPackBuilder_v8 & {
106  auto i = 0;
107  for (auto &data_pointer : data_pointers) {
108  m_variant_pack.uid[i] = data_pointer.first;
109  m_variant_pack.data_pointers[i] = data_pointer.second;
110  i++;
111  }
112  m_variant_pack.num_ptrs = data_pointers.size();
113  return *this;
114  }
116  auto
118  m_variant_pack.workspace = ws;
119  return *this;
120  }
123  VariantPack_v8 &&
126  build() {
127  // Create a descriptor. Memory allocation happens here.
128  auto status = m_variant_pack.initialize_managed_backend_pointer(CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR);
129  if (status != CUDNN_STATUS_SUCCESS) {
131  &m_variant_pack, status, "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR: cudnnCreate Failed");
132  return std::move(m_variant_pack);
133  }
134 
135  status = cudnnBackendSetAttribute(m_variant_pack.pointer->get_backend_descriptor(),
136  CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS,
137  CUDNN_TYPE_VOID_PTR,
138  m_variant_pack.num_ptrs,
139  m_variant_pack.data_pointers);
140  if (status != CUDNN_STATUS_SUCCESS) {
142  &m_variant_pack,
143  status,
144  "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR: SetAttribute CUDNN_ATTR_VARIANT_PACK_DATA_POINTERS Failed");
145  return std::move(m_variant_pack);
146  }
147 
148  status = cudnnBackendSetAttribute(m_variant_pack.pointer->get_backend_descriptor(),
149  CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS,
150  CUDNN_TYPE_INT64,
151  m_variant_pack.num_ptrs,
152  m_variant_pack.uid);
153  if (status != CUDNN_STATUS_SUCCESS) {
155  &m_variant_pack,
156  status,
157  "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR: SetAttribute CUDNN_ATTR_VARIANT_PACK_UNIQUE_IDS Failed");
158  return std::move(m_variant_pack);
159  }
160 
161  status = cudnnBackendSetAttribute(m_variant_pack.pointer->get_backend_descriptor(),
162  CUDNN_ATTR_VARIANT_PACK_WORKSPACE,
163  CUDNN_TYPE_VOID_PTR,
164  1,
165  &m_variant_pack.workspace);
166  if (status != CUDNN_STATUS_SUCCESS) {
168  &m_variant_pack,
169  status,
170  "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR: SetAttribute CUDNN_ATTR_VARIANT_PACK_WORKSPACE Failed");
171  return std::move(m_variant_pack);
172  }
173 
174  // Finalizing the descriptor
175  status = cudnnBackendFinalize(m_variant_pack.pointer->get_backend_descriptor());
176  if (status != CUDNN_STATUS_SUCCESS) {
178  &m_variant_pack, status, "CUDNN_BACKEND_VARIANT_PACK_DESCRIPTOR: cudnnFinalize Failed");
179  return std::move(m_variant_pack);
180  }
181  getLogger() << "[cudnn_frontend] "<< m_variant_pack << std::endl;
182  return std::move(m_variant_pack);
183  }
184 
185  explicit VariantPackBuilder_v8() = default;
186  ~VariantPackBuilder_v8() = default;
190  operator=(VariantPackBuilder_v8 const &) = delete;
191 
192  private:
194 };
195 }
ConditionalStreamer & getLogger()
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
auto setWorkspacePointer(void *ws) -> VariantPackBuilder_v8 &
Set Workspace.
auto setDataPointers(int64_t num_ptr, void **ptrs) -> VariantPackBuilder_v8 &
Set dataPointers for the VariantPack_v8.
auto setUids(int64_t num_uids, int64_t *uid) -> VariantPackBuilder_v8 &
Set Uids for the VariantPack_v8.
auto setDataPointers(std::set< std::pair< uint64_t, void *>> const &data_pointers) -> VariantPackBuilder_v8 &
Set the data pointers and their UIDs from a set of (uid, data pointer) pairs.
std::string describe() const override
Return a string describing the backend Descriptor.
VariantPack_v8 & operator=(VariantPack_v8 &&from)=default
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.