42 void *dev_filter_ptr,
void *reordered_filter_ptr,
43 void *dev_bias_ptr,
void *reordered_bias_ptr) {
45 auto cudnn_status = CUDNN_STATUS_SUCCESS;
47 if (dev_filter_ptr && reordered_filter_ptr ==
nullptr) {
48 return CUDNN_STATUS_BAD_PARAM;
50 if (dev_bias_ptr && reordered_bias_ptr ==
nullptr) {
51 return CUDNN_STATUS_BAD_PARAM;
54 cudnnFilterDescriptor_t filterDesc =
nullptr;
56 cudnn_status = cudnnCreateFilterDescriptor(&filterDesc);
57 if (cudnn_status != CUDNN_STATUS_SUCCESS) {
return cudnn_status;}
61 auto non_shape_dims = tensor_dims - conv_dims;
63 if (non_shape_dims != 2 && non_shape_dims != 3) {
64 return CUDNN_STATUS_BAD_PARAM;
67 if (conv_dims != 2 && conv_dims != 3) {
68 return CUDNN_STATUS_BAD_PARAM;
71 int filter_dims_[5] = {1,1,1,1,1};
73 filter_dims_[0] =
static_cast<int> (filter_dims[0]);
74 filter_dims_[1] =
static_cast<int> ((non_shape_dims == 2) ? filter_dims[1] : filter_dims[2]) * 32;
75 filter_dims_[2] =
static_cast<int> ((non_shape_dims == 2) ? filter_dims[2] : filter_dims[3]);
76 filter_dims_[3] =
static_cast<int> ((non_shape_dims == 2) ? filter_dims[3] : filter_dims[4]);
78 filter_dims_[4] =
static_cast<int> ((non_shape_dims == 2) ? filter_dims[4] : filter_dims[5]);
81 cudnn_status = cudnnSetFilterNdDescriptor(filterDesc, CUDNN_DATA_INT8x32, CUDNN_TENSOR_NCHW_VECT_C, conv_dims + 2, filter_dims_);
83 if (cudnn_status != CUDNN_STATUS_SUCCESS) {
return cudnn_status;}
85 int reorderBias = (dev_bias_ptr !=
nullptr);
87 cudnn_status = cudnnReorderFilterAndBias(handle,
88 filterDesc, CUDNN_DEFAULT_REORDER, dev_filter_ptr, reordered_filter_ptr, reorderBias, dev_bias_ptr, reordered_bias_ptr);
90 cudnnDestroyFilterDescriptor(filterDesc);
int64_t getDimensionCount() const
int64_t const * getDimArray() const
int64_t getDimensionCount() const
static cudnnStatus_t cudnnReorderFilterAndBiasInt8x32(cudnnHandle_t handle, const Tensor_v8 &tensor, const ConvDesc_v8 &conv_desc, void *dev_filter_ptr, void *reordered_filter_ptr, void *dev_bias_ptr, void *reordered_bias_ptr)