@@ -186,9 +186,6 @@ bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const {
186186#define CHECK_ARG (x ) ok = ok && (x)
187187#define CHECK_MASK (mask_name, mask_field ) \
188188 CHECK_ARG (IMPLICATION ((bool )(~mask & (mask_name)), (mask_field).defined ()))
189- CHECK_MASK (smask_t ::scales, scales_);
190- CHECK_MASK (smask_t ::zero_points, zero_points_);
191- CHECK_MASK (smask_t ::post_ops, post_ops_);
192189 CHECK_MASK (smask_t ::rnn_data_qparams, rnn_data_qparams_);
193190 CHECK_MASK (smask_t ::rnn_weights_qparams, rnn_weights_qparams_);
194191 CHECK_MASK (smask_t ::rnn_weights_projection_qparams,
@@ -200,6 +197,8 @@ bool primitive_attr_t::defined(dnnl_primitive_attr::skip_mask_t mask) const {
200197
201198status_t post_ops_t::append_sum (
202199 float scale, int32_t zero_point, data_type_t dt) {
200+ if (is_runtime_value (scale)) return invalid_arguments;
201+
203202 entry_.emplace_back ();
204203 auto &e = entry_.back ();
205204 e.kind = primitive_kind::sum;
@@ -213,6 +212,9 @@ status_t post_ops_t::append_eltwise(
213212 float scale, alg_kind_t alg, float alpha, float beta) {
214213 if (!math::is_eltwise_ok (data_type::f32 , alg, alpha, beta))
215214 return invalid_arguments;
215+ if (is_runtime_value (scale)) return invalid_arguments;
216+ if (is_runtime_value (alpha)) return invalid_arguments;
217+ if (is_runtime_value (beta)) return invalid_arguments;
216218
217219 entry_.emplace_back ();
218220 auto &e = entry_.back ();
@@ -310,27 +312,6 @@ status_t post_ops_t::append_prelu(int mask) {
310312 return success;
311313}
312314
313- bool post_ops_t::defined () const {
314- for (int idx = 0 ; idx < len (); ++idx) {
315- auto kind = entry_[idx].kind ;
316- if (kind == primitive_kind::sum) {
317- if (is_runtime_value (entry_[idx].sum .scale )) return false ;
318- } else if (kind == primitive_kind::eltwise) {
319- const auto &e = entry_[idx].eltwise ;
320- if (is_runtime_value (e.scale ) || is_runtime_value (e.alpha )
321- || is_runtime_value (e.beta ))
322- return false ;
323- } else if (utils::one_of (kind, primitive_kind::binary,
324- primitive_kind::prelu,
325- primitive_kind::convolution)) {
326- // binary is always defined
327- } else {
328- assert (!" unreachable" );
329- }
330- }
331- return true ;
332- }
333-
334315status_t post_ops_t::set_default_formats (const memory_desc_t *dst_md) {
335316 for (int idx = 0 ; idx < len (); ++idx) {
336317 if (!contain (primitive_kind::binary, idx)) continue ;
0 commit comments