shader_recompiler: Maintain loss of precision when folding half-float unpack.

This commit is contained in:
squidbus 2025-01-21 06:23:13 -08:00
parent adbff4056f
commit 283d2dfb13
6 changed files with 25 additions and 1 deletions

View file

@ -58,4 +58,8 @@ Id EmitUnpackHalf2x16(EmitContext& ctx, Id value) {
return ctx.OpUnpackHalf2x16(ctx.F32[2], value);
}
Id EmitQuantizeHalf2x16(EmitContext& ctx, Id value) {
return ctx.OpQuantizeToF16(ctx.F32[2], value);
}
} // namespace Shader::Backend::SPIRV

View file

@ -197,6 +197,7 @@ Id EmitPackFloat2x16(EmitContext& ctx, Id value);
Id EmitUnpackFloat2x16(EmitContext& ctx, Id value);
Id EmitPackHalf2x16(EmitContext& ctx, Id value);
Id EmitUnpackHalf2x16(EmitContext& ctx, Id value);
Id EmitQuantizeHalf2x16(EmitContext& ctx, Id value);
Id EmitFPAbs16(EmitContext& ctx, Id value);
Id EmitFPAbs32(EmitContext& ctx, Id value);
Id EmitFPAbs64(EmitContext& ctx, Id value);

View file

@ -795,6 +795,10 @@ Value IREmitter::UnpackHalf2x16(const U32& value) {
return Inst(Opcode::UnpackHalf2x16, value);
}
Value IREmitter::QuantizeHalf2x16(const Value& value) {
return Inst(Opcode::QuantizeHalf2x16, value);
}
F32F64 IREmitter::FPMul(const F32F64& a, const F32F64& b) {
if (a.Type() != b.Type()) {
UNREACHABLE_MSG("Mismatching types {} and {}", a.Type(), b.Type());

View file

@ -175,6 +175,7 @@ public:
[[nodiscard]] U32 PackHalf2x16(const Value& vector);
[[nodiscard]] Value UnpackHalf2x16(const U32& value);
[[nodiscard]] Value QuantizeHalf2x16(const Value& value);
[[nodiscard]] F32F64 FPAdd(const F32F64& a, const F32F64& b);
[[nodiscard]] F32F64 FPSub(const F32F64& a, const F32F64& b);

View file

@ -187,6 +187,7 @@ OPCODE(PackFloat2x16, U32, F16x
OPCODE(UnpackFloat2x16, F16x2, U32, )
OPCODE(PackHalf2x16, U32, F32x2, )
OPCODE(UnpackHalf2x16, F32x2, U32, )
OPCODE(QuantizeHalf2x16, F32x2, F32x2, )
// Floating-point operations
OPCODE(FPAbs32, F32, F32, )

View file

@ -204,6 +204,19 @@ void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) {
}
}
void FoldUnpackHalf2x16(IR::Block& block, IR::Inst& inst) {
const IR::Value value{inst.Arg(0)};
if (value.IsImmediate()) {
return;
}
IR::Inst* const arg_inst{value.InstRecursive()};
if (arg_inst->GetOpcode() == IR::Opcode::PackHalf2x16) {
// When reversing pack half instruction, keep the loss of precision using quantization.
IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
inst.ReplaceUsesWithAndRemove(ir.QuantizeHalf2x16(arg_inst->Arg(0)));
}
}
template <typename T>
void FoldAdd(IR::Block& block, IR::Inst& inst) {
if (!FoldCommutative<T>(inst, [](T a, T b) { return a + b; })) {
@ -343,7 +356,7 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
case IR::Opcode::PackHalf2x16:
return FoldInverseFunc(inst, IR::Opcode::UnpackHalf2x16);
case IR::Opcode::UnpackHalf2x16:
return FoldInverseFunc(inst, IR::Opcode::PackHalf2x16);
return FoldUnpackHalf2x16(block, inst);
case IR::Opcode::PackFloat2x16:
return FoldInverseFunc(inst, IR::Opcode::UnpackFloat2x16);
case IR::Opcode::UnpackFloat2x16: