Skip to content

Commit 52b345d

Browse files
authored
[RISCV][TII] Add and use new hook to simplify/canonicalize instructions after MachineCopyPropagation (#137973)
PR #136875 was posted as a draft PR that handled a subset of these cases, using the CompressPat mechanism. The consensus from that discussion (and a conclusion I agree with) is that it would be beneficial to do this optimisation earlier on, and in a way that isn't limited just to cases that can be handled by instruction compression. The most common source for instructions that can be optimized/canonicalized in this way is through tail duplication in MachineBlockPlacement followed by machine copy propagation. For RISC-V, choosing a more canonical instruction allows it to be compressed when it couldn't be before. There is the potential that it would make other MI-level optimisations easier. This modifies ~910 instructions across an llvm-test-suite build including SPEC2017, targeting rva22u64. Looking at the diff, it seems there's room for eliminating instructions or further propagating after this. Coverage of instructions is based on observations from a script written to find redundant or improperly canonicalized instructions (though I aim to support all instructions in a 'group' at once, e.g. MUL* even if I only saw some variants of MUL in practice).
1 parent 245def9 commit 52b345d

File tree

5 files changed

+310
-69
lines changed

5 files changed

+310
-69
lines changed

‎llvm/include/llvm/CodeGen/TargetInstrInfo.h

+10
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,16 @@ class TargetInstrInfo : public MCInstrInfo {
510510
return false;
511511
}
512512

513+
/// If possible, converts the instruction to a simplified/canonical form.
514+
/// Returns true if the instruction was modified.
515+
///
516+
/// This function is only called after register allocation. The MI will be
517+
/// modified in place. This is called by passes such as
518+
/// MachineCopyPropagation, where their mutation of the MI operands may
519+
/// expose opportunities to convert the instruction to a simpler form (e.g.
520+
/// a load of 0).
521+
/// The default implementation performs no simplification and returns false.
virtual bool simplifyInstruction(MachineInstr &MI) const { return false; }
522+
513523
/// A pair composed of a register and a sub-register index.
514524
/// Used to give some type checking when modeling Reg:SubReg.
515525
struct RegSubRegPair {

‎llvm/lib/CodeGen/MachineCopyPropagation.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,12 @@ void MachineCopyPropagation::forwardUses(MachineInstr &MI) {
870870
++NumCopyForwards;
871871
Changed = true;
872872
}
873+
// Attempt to canonicalize/optimize the instruction now its arguments have
874+
// been mutated.
875+
if (TII->simplifyInstruction(MI)) {
876+
Changed = true;
877+
LLVM_DEBUG(dbgs() << "MCP: After optimizeInstruction: " << MI);
878+
}
873879
}
874880

875881
void MachineCopyPropagation::ForwardCopyPropagateBlock(MachineBasicBlock &MBB) {

‎llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

+225
Original file line numberDiff line numberDiff line change
@@ -2347,6 +2347,21 @@ static unsigned getSHXADDShiftAmount(unsigned Opc) {
23472347
}
23482348
}
23492349

2350+
// Returns the shift amount from a SHXADD.UW instruction. Returns 0 if the
2351+
// instruction is not a SHXADD.UW.
2352+
static unsigned getSHXADDUWShiftAmount(unsigned Opc) {
2353+
switch (Opc) {
2354+
default:
2355+
return 0;
2356+
case RISCV::SH1ADD_UW:
2357+
return 1;
2358+
case RISCV::SH2ADD_UW:
2359+
return 2;
2360+
case RISCV::SH3ADD_UW:
2361+
return 3;
2362+
}
2363+
}
2364+
23502365
// Look for opportunities to combine (sh3add Z, (add X, (slli Y, 5))) into
23512366
// (sh3add (sh2add Y, Z), X).
23522367
static bool getSHXADDPatterns(const MachineInstr &Root,
@@ -3876,6 +3891,216 @@ MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI,
38763891
#undef CASE_VFMA_OPCODE_VV
38773892
#undef CASE_VFMA_SPLATS
38783893

3894+
/// Rewrites \p MI in place into a simpler/canonical equivalent when one of its
/// source operands is X0 (the zero register) or, for XOR and MIN/MAX, when both
/// source operands name the same register. The replacement opcodes used are
/// ADDI, ADDIW, SLLI and SLLI_UW.
///
/// \param MI the instruction to inspect; its operands and descriptor are
///           mutated directly when a rewrite applies.
/// \returns true if \p MI was modified, false if it was left untouched. Every
///          path that mutates \p MI must return true so callers can track
///          that a change was made.
bool RISCVInstrInfo::simplifyInstruction(MachineInstr &MI) const {
  switch (MI.getOpcode()) {
  default:
    break;
  case RISCV::ADD:
  case RISCV::OR:
  case RISCV::XOR:
    // Normalize (so we hit the next if clause).
    // add/[x]or rd, zero, rs => add/[x]or rd, rs, zero
    if (MI.getOperand(1).getReg() == RISCV::X0)
      commuteInstruction(MI);
    // add/[x]or rd, rs, zero => addi rd, rs, 0
    if (MI.getOperand(2).getReg() == RISCV::X0) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    // xor rd, rs, rs => addi rd, zero, 0
    if (MI.getOpcode() == RISCV::XOR &&
        MI.getOperand(1).getReg() == MI.getOperand(2).getReg()) {
      MI.getOperand(1).setReg(RISCV::X0);
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::ORI:
  case RISCV::XORI:
    // [x]ori rd, zero, N => addi rd, zero, N
    // The operands already have the right shape; only the opcode changes.
    if (MI.getOperand(1).getReg() == RISCV::X0) {
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SUB:
    // sub rd, rs, zero => addi rd, rs, 0
    if (MI.getOperand(2).getReg() == RISCV::X0) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SUBW:
    // subw rd, rs, zero => addiw rd, rs, 0
    if (MI.getOperand(2).getReg() == RISCV::X0) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDIW));
      return true;
    }
    break;
  case RISCV::ADDW:
    // Normalize (so we hit the next if clause).
    // addw rd, zero, rs => addw rd, rs, zero
    if (MI.getOperand(1).getReg() == RISCV::X0)
      commuteInstruction(MI);
    // addw rd, rs, zero => addiw rd, rs, 0
    if (MI.getOperand(2).getReg() == RISCV::X0) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDIW));
      return true;
    }
    break;
  case RISCV::SH1ADD:
  case RISCV::SH1ADD_UW:
  case RISCV::SH2ADD:
  case RISCV::SH2ADD_UW:
  case RISCV::SH3ADD:
  case RISCV::SH3ADD_UW:
    // shNadd[.uw] rd, zero, rs => addi rd, rs, 0
    // Drop the zero operand so rs becomes the only register source, then
    // append the immediate.
    if (MI.getOperand(1).getReg() == RISCV::X0) {
      MI.removeOperand(1);
      MI.addOperand(MachineOperand::CreateImm(0));
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    // shNadd[.uw] rd, rs, zero => slli[.uw] rd, rs, N
    if (MI.getOperand(2).getReg() == RISCV::X0) {
      MI.removeOperand(2);
      // The opcode is still the original shNadd[.uw] here; it selects both
      // the shift amount and whether the .uw shift form is required.
      unsigned Opc = MI.getOpcode();
      if (Opc == RISCV::SH1ADD_UW || Opc == RISCV::SH2ADD_UW ||
          Opc == RISCV::SH3ADD_UW) {
        MI.addOperand(MachineOperand::CreateImm(getSHXADDUWShiftAmount(Opc)));
        MI.setDesc(get(RISCV::SLLI_UW));
        return true;
      }
      MI.addOperand(MachineOperand::CreateImm(getSHXADDShiftAmount(Opc)));
      MI.setDesc(get(RISCV::SLLI));
      return true;
    }
    break;
  case RISCV::AND:
  case RISCV::MUL:
  case RISCV::MULH:
  case RISCV::MULHSU:
  case RISCV::MULHU:
  case RISCV::MULW:
    // and rd, zero, rs => addi rd, zero, 0
    // mul* rd, zero, rs => addi rd, zero, 0
    // and rd, rs, zero => addi rd, zero, 0
    // mul* rd, rs, zero => addi rd, zero, 0
    if (MI.getOperand(1).getReg() == RISCV::X0 ||
        MI.getOperand(2).getReg() == RISCV::X0) {
      MI.getOperand(1).setReg(RISCV::X0);
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::ANDI:
    // andi rd, zero, C => addi rd, zero, 0
    if (MI.getOperand(1).getReg() == RISCV::X0) {
      MI.getOperand(2).setImm(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SLL:
  case RISCV::SRL:
  case RISCV::SRA:
    // shift rd, zero, rs => addi rd, zero, 0
    if (MI.getOperand(1).getReg() == RISCV::X0) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    // shift rd, rs, zero => addi rd, rs, 0
    if (MI.getOperand(2).getReg() == RISCV::X0) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SLLW:
  case RISCV::SRLW:
  case RISCV::SRAW:
    // shiftw rd, zero, rs => addi rd, zero, 0
    if (MI.getOperand(1).getReg() == RISCV::X0) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SLLI:
  case RISCV::SRLI:
  case RISCV::SRAI:
  case RISCV::SLLIW:
  case RISCV::SRLIW:
  case RISCV::SRAIW:
  case RISCV::SLLI_UW:
    // shiftimm rd, zero, N => addi rd, zero, 0
    if (MI.getOperand(1).getReg() == RISCV::X0) {
      MI.getOperand(2).setImm(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SLTU:
  case RISCV::ADD_UW:
    // sltu rd, zero, zero => addi rd, zero, 0
    // add.uw rd, zero, zero => addi rd, zero, 0
    if (MI.getOperand(1).getReg() == RISCV::X0 &&
        MI.getOperand(2).getReg() == RISCV::X0) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    // add.uw rd, zero, rs => addi rd, rs, 0
    if (MI.getOpcode() == RISCV::ADD_UW &&
        MI.getOperand(1).getReg() == RISCV::X0) {
      MI.removeOperand(1);
      MI.addOperand(MachineOperand::CreateImm(0));
      MI.setDesc(get(RISCV::ADDI));
      // MI has been rewritten, so report the modification to the caller
      // instead of falling through to the final `return false`.
      return true;
    }
    break;
  case RISCV::SLTIU:
    // sltiu rd, zero, NZC => addi rd, zero, 1
    // sltiu rd, zero, 0 => addi rd, zero, 0
    if (MI.getOperand(1).getReg() == RISCV::X0) {
      // The comparison result (bool) becomes the new immediate: 1 for any
      // non-zero constant, 0 otherwise.
      MI.getOperand(2).setImm(MI.getOperand(2).getImm() != 0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::SEXT_H:
  case RISCV::SEXT_B:
  case RISCV::ZEXT_H_RV32:
  case RISCV::ZEXT_H_RV64:
    // sext.[hb] rd, zero => addi rd, zero, 0
    // zext.h rd, zero => addi rd, zero, 0
    // These are single-source instructions, so the immediate operand that
    // ADDI needs has to be appended.
    if (MI.getOperand(1).getReg() == RISCV::X0) {
      MI.addOperand(MachineOperand::CreateImm(0));
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  case RISCV::MIN:
  case RISCV::MINU:
  case RISCV::MAX:
  case RISCV::MAXU:
    // min|max rd, rs, rs => addi rd, rs, 0
    if (MI.getOperand(1).getReg() == MI.getOperand(2).getReg()) {
      MI.getOperand(2).ChangeToImmediate(0);
      MI.setDesc(get(RISCV::ADDI));
      return true;
    }
    break;
  }
  // No rewrite applied; MI is unchanged.
  return false;
}
4103+
38794104
// clang-format off
38804105
#define CASE_WIDEOP_OPCODE_COMMON(OP, LMUL) \
38814106
RISCV::PseudoV##OP##_##LMUL##_TIED

‎llvm/lib/Target/RISCV/RISCVInstrInfo.h

+2
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
242242
unsigned OpIdx1,
243243
unsigned OpIdx2) const override;
244244

245+
/// RISC-V override of the simplifyInstruction hook. Returns true if \p MI
/// was modified.
bool simplifyInstruction(MachineInstr &MI) const override;
246+
245247
MachineInstr *convertToThreeAddress(MachineInstr &MI, LiveVariables *LV,
246248
LiveIntervals *LIS) const override;
247249

0 commit comments

Comments
 (0)