1616#include < migraphx/instruction.hpp>
1717#include < migraphx/config.hpp>
1818#include < migraphx/onnx.hpp>
19+ #include < migraphx/pad_calc.hpp>
1920
2021namespace migraphx {
2122inline namespace MIGRAPHX_INLINE_NS {
@@ -302,6 +303,24 @@ struct onnx_parser
302303 return curr_ins;
303304 }
304305
306+ template <class Op >
307+ void check_asym_padding (instruction_ref& ins,
308+ std::vector<int64_t >& padding,
309+ Op& op,
310+ float pad_val = 0 )
311+ {
312+ if (padding[0 ] != padding[2 ] || padding[1 ] != padding[3 ])
313+ {
314+ padding = {0 , 0 , padding[0 ], padding[1 ], 0 , 0 , padding[2 ], padding[3 ]};
315+ ins = prog.add_instruction (op::pad{padding, pad_val}, ins);
316+ }
317+ else
318+ {
319+ op.padding [0 ] = padding[0 ];
320+ op.padding [1 ] = padding[1 ];
321+ }
322+ }
323+
305324 instruction_ref parse_clip (const std::string&,
306325 const attribute_map& attributes,
307326 std::vector<instruction_ref> args)
@@ -424,7 +443,8 @@ struct onnx_parser
424443 parse_conv (const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
425444 {
426445 Op op;
427- auto l0 = args[0 ];
446+ auto l0 = args[0 ];
447+ auto weights = args[1 ];
428448 if (contains (attributes, " pads" ))
429449 {
430450 if (contains (attributes, " auto_pad" ))
@@ -441,17 +461,7 @@ struct onnx_parser
441461 {
442462 MIGRAPHX_THROW (" padding should have 4 values" );
443463 }
444- if (padding[0 ] != padding[2 ] || padding[1 ] != padding[3 ])
445- {
446- // insert zeros for pad op (args[0] has 4 dims)
447- padding = {0 , 0 , padding[0 ], padding[1 ], 0 , 0 , padding[2 ], padding[3 ]};
448- l0 = prog.add_instruction (op::pad{padding}, l0);
449- }
450- else
451- {
452- op.padding [0 ] = padding[0 ];
453- op.padding [1 ] = padding[1 ];
454- }
464+ check_asym_padding (l0, padding, op);
455465 }
456466 if (contains (attributes, " strides" ))
457467 {
@@ -471,7 +481,19 @@ struct onnx_parser
471481
472482 if (s.find (" SAME" ) != std::string::npos)
473483 {
474- op.padding_mode = op::padding_mode_t ::same;
484+ op.padding_mode = op::padding_mode_t ::same;
485+ std::vector<size_t > weight_dims = weights->get_shape ().lens ();
486+ size_t weight_h = weight_dims[2 ];
487+ size_t weight_w = weight_dims[3 ];
488+
489+ auto input_dims = l0->get_shape ().lens ();
490+ std::vector<int64_t > padding (input_dims.size ());
491+ calculate_padding (
492+ 0 , padding, input_dims[2 ], op.stride [0 ], op.dilation [0 ], weight_h);
493+ calculate_padding (
494+ 1 , padding, input_dims[3 ], op.stride [1 ], op.dilation [1 ], weight_w);
495+
496+ check_asym_padding (l0, padding, op);
475497 }
476498 }
477499 if (contains (attributes, " group" ))
@@ -618,27 +640,10 @@ struct onnx_parser
618640 {
619641 MIGRAPHX_THROW (" PARSE_POOLING: padding should have 4 values" );
620642 }
621- if (padding[0 ] != padding[2 ] || padding[1 ] != padding[3 ])
622- {
623- // insert zeros for pad op (args[0] has 4 dims)
624- padding = {0 , 0 , padding[0 ], padding[1 ], 0 , 0 , padding[2 ], padding[3 ]};
625- // MaxPool
626- if (op.mode == " max" )
627- {
628- l0 = prog.add_instruction (
629- op::pad{padding, std::numeric_limits<float >::lowest ()}, l0);
630- }
631- // AveragePool
632- else
633- {
634- l0 = prog.add_instruction (op::pad{padding}, l0);
635- }
636- }
637- else
638- {
639- op.padding [0 ] = padding[0 ];
640- op.padding [1 ] = padding[1 ];
641- }
643+ float pad_val = 0 ;
644+ if (op.mode == " max" )
645+ pad_val = std::numeric_limits<float >::lowest ();
646+ check_asym_padding (l0, padding, op, pad_val);
642647 }
643648
644649 if (contains (attributes, " strides" ))
0 commit comments