34 #include <cudnn_backend.h> 58 ss <<
"CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR :";
75 int64_t global_count = -1;
76 auto status = cudnnBackendGetAttribute(
pointer->get_backend_descriptor(),
77 CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT,
82 if (
status != CUDNN_STATUS_SUCCESS) {
85 "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: GetAttribute " 86 "CUDNN_ATTR_OPERATIONGRAPH_ENGINE_GLOBAL_COUNT Failed");
119 std::array<ManagedOpaqueDescriptor, 10>
ops{};
137 m_operationGraph.handle = handle_;
143 m_operationGraph.numOps = numOps_;
144 m_operationGraph.feature_vectors.resize(numOps_);
145 for (
auto i = 0u; i < numOps_; i++) {
146 m_operationGraph.ops[i] = ops_[i]->get_desc();
147 m_operationGraph.opGraphTag += ops_[i]->getTag() +
'_';
148 m_operationGraph.feature_vectors[i] = ops_[i]->getFeatureVector();
158 if (m_operationGraph.numOps <= 0) {
161 CUDNN_STATUS_BAD_PARAM,
162 "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and Set the CUDNN_ATTR_OPERATIONGRAPH_OPS Count field");
163 return std::move(m_operationGraph);
165 if (m_operationGraph.ops[0] ==
nullptr) {
168 CUDNN_STATUS_BAD_PARAM,
169 "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and set CUDNN_ATTR_OPERATIONGRAPH_OPS field");
170 return std::move(m_operationGraph);
172 if (m_operationGraph.handle ==
nullptr) {
175 CUDNN_STATUS_BAD_PARAM,
176 "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: Check and Set CUDNN_ATTR_OPERATIONGRAPH_HANDLE");
177 return std::move(m_operationGraph);
181 auto status = m_operationGraph.initialize_managed_backend_pointer(CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR);
182 if (
status != CUDNN_STATUS_SUCCESS) {
184 &m_operationGraph,
status,
"CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: cudnnCreate Failed");
185 return std::move(m_operationGraph);
188 std::array<cudnnBackendDescriptor_t, 10> ops_raw{
nullptr};
189 for (
auto i = 0u; i < m_operationGraph.numOps; i++) {
190 ops_raw[i] = m_operationGraph.ops[i]->get_backend_descriptor();
193 status = cudnnBackendSetAttribute(m_operationGraph.pointer->get_backend_descriptor(),
194 CUDNN_ATTR_OPERATIONGRAPH_OPS,
195 CUDNN_TYPE_BACKEND_DESCRIPTOR,
196 m_operationGraph.numOps,
198 if (
status != CUDNN_STATUS_SUCCESS) {
202 "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: SetAttribute CUDNN_ATTR_OPERATIONGRAPH_OPS Failed");
203 return std::move(m_operationGraph);
205 status = cudnnBackendSetAttribute(m_operationGraph.pointer->get_backend_descriptor(),
206 CUDNN_ATTR_OPERATIONGRAPH_HANDLE,
209 &m_operationGraph.handle);
210 if (
status != CUDNN_STATUS_SUCCESS) {
214 "CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: SetAttribute CUDNN_ATTR_OPERATIONGRAPH_HANDLE Failed");
215 return std::move(m_operationGraph);
219 status = cudnnBackendFinalize(m_operationGraph.pointer->get_backend_descriptor());
220 if (
status != CUDNN_STATUS_SUCCESS) {
222 &m_operationGraph,
status,
"CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR: cudnnFinalize Failed");
223 return std::move(m_operationGraph);
226 getLogger() <<
"[cudnn_frontend] " << m_operationGraph << std::endl;
227 return std::move(m_operationGraph);
std::vector< int64_t > feature_vector_t
Detailed feature_vector. Generally the Tensor and Operation properties.
uint64_t getOpCount() const
OperationGraph_v8 & operator=(OperationGraph_v8 &&from)=default
ConditionalStreamer & getLogger()
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
auto setHandle(cudnnHandle_t handle_) -> OperationGraphBuilder_v8 &
Set cudnnHandle for the operations.
auto setOperationGraph(int64_t numOps_, Operation_v8 const **ops_) -> OperationGraphBuilder_v8 &
Set numoperations and the operations.
auto getEngineCount(void) const -> int64_t
Query the total count of the engines for the Operation Set.
friend class OperationGraphBuilder_v8
OperationGraph_v8 m_operationGraph
std::string const & getTag() const
feature_vector_t getFeatureVector() const
std::string describe() const override
Return a string describing the backend Descriptor.
OperationGraph_v8()=default
~OperationGraph_v8()=default
OperationGraph_v8 && build()
std::array< ManagedOpaqueDescriptor, 10 > ops
std::vector< feature_vector_t > feature_vectors
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.
ManagedOpaqueDescriptor pointer