Merge pull request #6722 from ReinUsesLisp/xmad-opts
shader: Fold integer FMA from Nvidia's pattern
This commit is contained in:
commit
a98f14e9b0
|
@ -57,6 +57,7 @@ public:
|
||||||
|
|
||||||
[[nodiscard]] IR::Inst* Inst() const;
|
[[nodiscard]] IR::Inst* Inst() const;
|
||||||
[[nodiscard]] IR::Inst* InstRecursive() const;
|
[[nodiscard]] IR::Inst* InstRecursive() const;
|
||||||
|
[[nodiscard]] IR::Inst* TryInstRecursive() const;
|
||||||
[[nodiscard]] IR::Value Resolve() const;
|
[[nodiscard]] IR::Value Resolve() const;
|
||||||
[[nodiscard]] IR::Reg Reg() const;
|
[[nodiscard]] IR::Reg Reg() const;
|
||||||
[[nodiscard]] IR::Pred Pred() const;
|
[[nodiscard]] IR::Pred Pred() const;
|
||||||
|
@ -308,6 +309,13 @@ inline IR::Inst* Value::InstRecursive() const {
|
||||||
return inst;
|
return inst;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline IR::Inst* Value::TryInstRecursive() const {
|
||||||
|
if (IsIdentity()) {
|
||||||
|
return inst->Arg(0).TryInstRecursive();
|
||||||
|
}
|
||||||
|
return type == Type::Opaque ? inst : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
inline IR::Value Value::Resolve() const {
|
inline IR::Value Value::Resolve() const {
|
||||||
if (IsIdentity()) {
|
if (IsIdentity()) {
|
||||||
return inst->Arg(0).Resolve();
|
return inst->Arg(0).Resolve();
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
// Refer to the license.txt file included.
|
// Refer to the license.txt file included.
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
|
@ -88,6 +89,26 @@ bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return true when all values in a range are equal
|
||||||
|
template <typename Range>
|
||||||
|
bool AreEqual(const Range& range) {
|
||||||
|
auto resolver{[](const auto& value) { return value.Resolve(); }};
|
||||||
|
auto equal{[](const IR::Value& lhs, const IR::Value& rhs) {
|
||||||
|
if (lhs == rhs) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// Not equal, but try to match if they read the same constant buffer
|
||||||
|
if (!lhs.IsImmediate() && !rhs.IsImmediate() &&
|
||||||
|
lhs.Inst()->GetOpcode() == IR::Opcode::GetCbufU32 &&
|
||||||
|
rhs.Inst()->GetOpcode() == IR::Opcode::GetCbufU32 &&
|
||||||
|
lhs.Inst()->Arg(0) == rhs.Inst()->Arg(0) && lhs.Inst()->Arg(1) == rhs.Inst()->Arg(1)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}};
|
||||||
|
return std::ranges::adjacent_find(range, std::not_fn(equal), resolver) == std::end(range);
|
||||||
|
}
|
||||||
|
|
||||||
void FoldGetRegister(IR::Inst& inst) {
|
void FoldGetRegister(IR::Inst& inst) {
|
||||||
if (inst.Arg(0).Reg() == IR::Reg::RZ) {
|
if (inst.Arg(0).Reg() == IR::Reg::RZ) {
|
||||||
inst.ReplaceUsesWith(IR::Value{u32{0}});
|
inst.ReplaceUsesWith(IR::Value{u32{0}});
|
||||||
|
@ -100,6 +121,157 @@ void FoldGetPred(IR::Inst& inst) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Replaces the XMAD pattern generated by an integer FMA
|
||||||
|
bool FoldXmadMultiplyAdd(IR::Block& block, IR::Inst& inst) {
|
||||||
|
/*
|
||||||
|
* We are looking for this specific pattern:
|
||||||
|
* %6 = BitFieldUExtract %op_b, #0, #16
|
||||||
|
* %7 = BitFieldUExtract %op_a', #16, #16
|
||||||
|
* %8 = IMul32 %6, %7
|
||||||
|
* %10 = BitFieldUExtract %op_a', #0, #16
|
||||||
|
* %11 = BitFieldInsert %8, %10, #16, #16
|
||||||
|
* %15 = BitFieldUExtract %op_b, #0, #16
|
||||||
|
* %16 = BitFieldUExtract %op_a, #0, #16
|
||||||
|
* %17 = IMul32 %15, %16
|
||||||
|
* %18 = IAdd32 %17, %op_c
|
||||||
|
* %22 = BitFieldUExtract %op_b, #16, #16
|
||||||
|
* %23 = BitFieldUExtract %11, #16, #16
|
||||||
|
* %24 = IMul32 %22, %23
|
||||||
|
* %25 = ShiftLeftLogical32 %24, #16
|
||||||
|
* %26 = ShiftLeftLogical32 %11, #16
|
||||||
|
* %27 = IAdd32 %26, %18
|
||||||
|
* %result = IAdd32 %25, %27
|
||||||
|
*
|
||||||
|
* And replace it with:
|
||||||
|
* %temp = IMul32 %op_a, %op_b
|
||||||
|
* %result = IAdd32 %temp, %op_c
|
||||||
|
*
|
||||||
|
* This optimization has been proven safe by Nvidia's compiler logic being reversed.
|
||||||
|
* (If Nvidia generates this code from 'fma(a, b, c)', we can do the same in the reverse order.)
|
||||||
|
*/
|
||||||
|
const IR::Value zero{0u};
|
||||||
|
const IR::Value sixteen{16u};
|
||||||
|
IR::Inst* const _25{inst.Arg(0).TryInstRecursive()};
|
||||||
|
IR::Inst* const _27{inst.Arg(1).TryInstRecursive()};
|
||||||
|
if (!_25 || !_27) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_27->GetOpcode() != IR::Opcode::IAdd32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_25->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || _25->Arg(1) != sixteen) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const _24{_25->Arg(0).TryInstRecursive()};
|
||||||
|
if (!_24 || _24->GetOpcode() != IR::Opcode::IMul32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const _22{_24->Arg(0).TryInstRecursive()};
|
||||||
|
IR::Inst* const _23{_24->Arg(1).TryInstRecursive()};
|
||||||
|
if (!_22 || !_23) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_22->GetOpcode() != IR::Opcode::BitFieldUExtract) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_23->GetOpcode() != IR::Opcode::BitFieldUExtract) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_22->Arg(1) != sixteen || _22->Arg(2) != sixteen) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_23->Arg(1) != sixteen || _23->Arg(2) != sixteen) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const _11{_23->Arg(0).TryInstRecursive()};
|
||||||
|
if (!_11 || _11->GetOpcode() != IR::Opcode::BitFieldInsert) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_11->Arg(2) != sixteen || _11->Arg(3) != sixteen) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const _8{_11->Arg(0).TryInstRecursive()};
|
||||||
|
IR::Inst* const _10{_11->Arg(1).TryInstRecursive()};
|
||||||
|
if (!_8 || !_10) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_8->GetOpcode() != IR::Opcode::IMul32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_10->GetOpcode() != IR::Opcode::BitFieldUExtract) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const _6{_8->Arg(0).TryInstRecursive()};
|
||||||
|
IR::Inst* const _7{_8->Arg(1).TryInstRecursive()};
|
||||||
|
if (!_6 || !_7) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_6->GetOpcode() != IR::Opcode::BitFieldUExtract) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_7->GetOpcode() != IR::Opcode::BitFieldUExtract) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_6->Arg(1) != zero || _6->Arg(2) != sixteen) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_7->Arg(1) != sixteen || _7->Arg(2) != sixteen) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const _26{_27->Arg(0).TryInstRecursive()};
|
||||||
|
IR::Inst* const _18{_27->Arg(1).TryInstRecursive()};
|
||||||
|
if (!_26 || !_18) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_26->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || _26->Arg(1) != sixteen) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_26->Arg(0).InstRecursive() != _11) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_18->GetOpcode() != IR::Opcode::IAdd32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const _17{_18->Arg(0).TryInstRecursive()};
|
||||||
|
if (!_17 || _17->GetOpcode() != IR::Opcode::IMul32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const _15{_17->Arg(0).TryInstRecursive()};
|
||||||
|
IR::Inst* const _16{_17->Arg(1).TryInstRecursive()};
|
||||||
|
if (!_15 || !_16) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_15->GetOpcode() != IR::Opcode::BitFieldUExtract) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_16->GetOpcode() != IR::Opcode::BitFieldUExtract) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_15->Arg(1) != zero || _16->Arg(1) != zero || _10->Arg(1) != zero) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (_15->Arg(2) != sixteen || _16->Arg(2) != sixteen || _10->Arg(2) != sixteen) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const std::array<IR::Value, 3> op_as{
|
||||||
|
_7->Arg(0).Resolve(),
|
||||||
|
_16->Arg(0).Resolve(),
|
||||||
|
_10->Arg(0).Resolve(),
|
||||||
|
};
|
||||||
|
const std::array<IR::Value, 3> op_bs{
|
||||||
|
_22->Arg(0).Resolve(),
|
||||||
|
_6->Arg(0).Resolve(),
|
||||||
|
_15->Arg(0).Resolve(),
|
||||||
|
};
|
||||||
|
const IR::U32 op_c{_18->Arg(1)};
|
||||||
|
if (!AreEqual(op_as) || !AreEqual(op_bs)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
|
||||||
|
inst.ReplaceUsesWith(ir.IAdd(ir.IMul(IR::U32{op_as[0]}, IR::U32{op_bs[1]}), op_c));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
/// Replaces the pattern generated by two XMAD multiplications
|
/// Replaces the pattern generated by two XMAD multiplications
|
||||||
bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
|
bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
|
||||||
/*
|
/*
|
||||||
|
@ -116,33 +288,31 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
|
||||||
*
|
*
|
||||||
* This optimization has been proven safe by LLVM and MSVC.
|
* This optimization has been proven safe by LLVM and MSVC.
|
||||||
*/
|
*/
|
||||||
const IR::Value lhs_arg{inst.Arg(0)};
|
IR::Inst* const lhs_shl{inst.Arg(0).TryInstRecursive()};
|
||||||
const IR::Value rhs_arg{inst.Arg(1)};
|
IR::Inst* const rhs_mul{inst.Arg(1).TryInstRecursive()};
|
||||||
if (lhs_arg.IsImmediate() || rhs_arg.IsImmediate()) {
|
if (!lhs_shl || !rhs_mul) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};
|
|
||||||
if (lhs_shl->GetOpcode() != IR::Opcode::ShiftLeftLogical32 ||
|
if (lhs_shl->GetOpcode() != IR::Opcode::ShiftLeftLogical32 ||
|
||||||
lhs_shl->Arg(1) != IR::Value{16U}) {
|
lhs_shl->Arg(1) != IR::Value{16U}) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (lhs_shl->Arg(0).IsImmediate()) {
|
IR::Inst* const lhs_mul{lhs_shl->Arg(0).TryInstRecursive()};
|
||||||
|
if (!lhs_mul) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()};
|
|
||||||
IR::Inst* const rhs_mul{rhs_arg.InstRecursive()};
|
|
||||||
if (lhs_mul->GetOpcode() != IR::Opcode::IMul32 || rhs_mul->GetOpcode() != IR::Opcode::IMul32) {
|
if (lhs_mul->GetOpcode() != IR::Opcode::IMul32 || rhs_mul->GetOpcode() != IR::Opcode::IMul32) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
const IR::U32 factor_b{lhs_mul->Arg(1)};
|
const IR::U32 factor_b{lhs_mul->Arg(1)};
|
||||||
if (lhs_mul->Arg(0).IsImmediate() || rhs_mul->Arg(0).IsImmediate()) {
|
if (factor_b.Resolve() != rhs_mul->Arg(1).Resolve()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
IR::Inst* const lhs_bfe{lhs_mul->Arg(0).TryInstRecursive()};
|
||||||
|
IR::Inst* const rhs_bfe{rhs_mul->Arg(0).TryInstRecursive()};
|
||||||
|
if (!lhs_bfe || !rhs_bfe) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()};
|
|
||||||
IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()};
|
|
||||||
if (lhs_bfe->GetOpcode() != IR::Opcode::BitFieldUExtract) {
|
if (lhs_bfe->GetOpcode() != IR::Opcode::BitFieldUExtract) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -155,10 +325,10 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
|
||||||
if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) {
|
if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (lhs_bfe->Arg(0).Resolve() != rhs_bfe->Arg(0).Resolve()) {
|
const IR::U32 factor_a{lhs_bfe->Arg(0)};
|
||||||
|
if (factor_a.Resolve() != rhs_bfe->Arg(0).Resolve()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const IR::U32 factor_a{lhs_bfe->Arg(0)};
|
|
||||||
IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
|
IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
|
||||||
inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b));
|
inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b));
|
||||||
return true;
|
return true;
|
||||||
|
@ -181,6 +351,9 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) {
|
||||||
if (FoldXmadMultiply(block, inst)) {
|
if (FoldXmadMultiply(block, inst)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (FoldXmadMultiplyAdd(block, inst)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue