From a49c7e9dcb451c16c4372a01f8ce1b28f0bf329f Mon Sep 17 00:00:00 2001
From: squidbus <175574877+squidbus@users.noreply.github.com>
Date: Thu, 12 Sep 2024 12:59:52 -0700
Subject: [PATCH] shader_recompiler: Add buffer offset calculation when swizzle
 is enabled. (#834)

---
 .../backend/spirv/emit_spirv_instructions.h   |  2 ++
 .../backend/spirv/emit_spirv_integer.cpp      |  8 ++++++
 src/shader_recompiler/ir/ir_emitter.cpp       |  4 +++
 src/shader_recompiler/ir/ir_emitter.h         |  1 +
 src/shader_recompiler/ir/opcodes.inc          |  2 ++
 .../ir/passes/resource_tracking_pass.cpp      | 27 ++++++++++++++++---
 6 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h
index e506ced3..8b76938b 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h
@@ -269,6 +269,8 @@ Id EmitIMul32(EmitContext& ctx, Id a, Id b);
 Id EmitIMul64(EmitContext& ctx, Id a, Id b);
 Id EmitSDiv32(EmitContext& ctx, Id a, Id b);
 Id EmitUDiv32(EmitContext& ctx, Id a, Id b);
+Id EmitSMod32(EmitContext& ctx, Id a, Id b);
+Id EmitUMod32(EmitContext& ctx, Id a, Id b);
 Id EmitINeg32(EmitContext& ctx, Id value);
 Id EmitINeg64(EmitContext& ctx, Id value);
 Id EmitIAbs32(EmitContext& ctx, Id value);
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp
index a9becb1e..02af9238 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp
@@ -96,6 +96,14 @@ Id EmitUDiv32(EmitContext& ctx, Id a, Id b) {
     return ctx.OpUDiv(ctx.U32[1], a, b);
 }
 
+Id EmitSMod32(EmitContext& ctx, Id a, Id b) { // Signed-modulo emitter: a mod b via OpSMod.
+    return ctx.OpSMod(ctx.U32[1], a, b); // Result is typed ctx.U32[1] even on the signed path.
+}
+
+Id EmitUMod32(EmitContext& ctx, Id a, Id b) { // Unsigned-modulo emitter: a mod b via OpUMod.
+    return ctx.OpUMod(ctx.U32[1], a, b); // Result is typed ctx.U32[1].
+}
+
 Id EmitINeg32(EmitContext& ctx, Id value) {
     return ctx.OpSNegate(ctx.U32[1], value);
 }
diff --git a/src/shader_recompiler/ir/ir_emitter.cpp b/src/shader_recompiler/ir/ir_emitter.cpp
index 7e52cfb5..7e0264bc 100644
--- a/src/shader_recompiler/ir/ir_emitter.cpp
+++ b/src/shader_recompiler/ir/ir_emitter.cpp
@@ -1055,6 +1055,10 @@ U32 IREmitter::IDiv(const U32& a, const U32& b, bool is_signed) {
     return Inst<U32>(is_signed ? Opcode::SDiv32 : Opcode::UDiv32, a, b);
 }
 
+U32 IREmitter::IMod(const U32& a, const U32& b, bool is_signed) { // 32-bit modulo: a mod b.
+    return Inst<U32>(is_signed ? Opcode::SMod32 : Opcode::UMod32, a, b); // is_signed picks opcode.
+}
+
 U32U64 IREmitter::INeg(const U32U64& value) {
     switch (value.Type()) {
     case Type::U32:
diff --git a/src/shader_recompiler/ir/ir_emitter.h b/src/shader_recompiler/ir/ir_emitter.h
index 01e71893..46f6157a 100644
--- a/src/shader_recompiler/ir/ir_emitter.h
+++ b/src/shader_recompiler/ir/ir_emitter.h
@@ -194,6 +194,7 @@ public:
     [[nodiscard]] Value IMulExt(const U32& a, const U32& b, bool is_signed = false);
     [[nodiscard]] U32U64 IMul(const U32U64& a, const U32U64& b);
     [[nodiscard]] U32 IDiv(const U32& a, const U32& b, bool is_signed = false);
+    [[nodiscard]] U32 IMod(const U32& a, const U32& b, bool is_signed = false);
     [[nodiscard]] U32U64 INeg(const U32U64& value);
     [[nodiscard]] U32 IAbs(const U32& value);
     [[nodiscard]] U32U64 ShiftLeftLogical(const U32U64& base, const U32& shift);
diff --git a/src/shader_recompiler/ir/opcodes.inc b/src/shader_recompiler/ir/opcodes.inc
index 4b922d55..263096c6 100644
--- a/src/shader_recompiler/ir/opcodes.inc
+++ b/src/shader_recompiler/ir/opcodes.inc
@@ -243,6 +243,8 @@ OPCODE(SMulExt,                                             U32x2,          U32,
 OPCODE(UMulExt,                                             U32x2,          U32,            U32,                                                            )
 OPCODE(SDiv32,                                              U32,            U32,            U32,                                                            )
 OPCODE(UDiv32,                                              U32,            U32,            U32,                                                            )
+OPCODE(SMod32,                                              U32,            U32,            U32,                                                            )
+OPCODE(UMod32,                                              U32,            U32,            U32,                                                            )
 OPCODE(INeg32,                                              U32,            U32,                                                                            )
 OPCODE(INeg64,                                              U64,            U64,                                                                            )
 OPCODE(IAbs32,                                              U32,            U32,                                                                            )
diff --git a/src/shader_recompiler/ir/passes/resource_tracking_pass.cpp b/src/shader_recompiler/ir/passes/resource_tracking_pass.cpp
index 6b2aa8bb..f1fc14d0 100644
--- a/src/shader_recompiler/ir/passes/resource_tracking_pass.cpp
+++ b/src/shader_recompiler/ir/passes/resource_tracking_pass.cpp
@@ -378,24 +378,45 @@ void PatchBufferInstruction(IR::Block& block, IR::Inst& inst, Info& info,
     // Replace handle with binding index in buffer resource list.
     IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
     inst.SetArg(0, ir.Imm32(binding));
-    ASSERT(!buffer.swizzle_enable && !buffer.add_tid_enable);
+    ASSERT(!buffer.add_tid_enable);
 
     // Address of constant buffer reads can be calculated at IR emittion time.
     if (inst.GetOpcode() == IR::Opcode::ReadConstBuffer) {
         return;
     }
 
+    const IR::U32 index_stride = ir.Imm32(buffer.index_stride);
+    const IR::U32 element_size = ir.Imm32(buffer.element_size);
+
     // Compute address of the buffer using the stride.
     IR::U32 address = ir.Imm32(inst_info.inst_offset.Value());
     if (inst_info.index_enable) {
         const IR::U32 index = inst_info.offset_enable ? IR::U32{ir.CompositeExtract(inst.Arg(1), 0)}
                                                       : IR::U32{inst.Arg(1)};
-        address = ir.IAdd(address, ir.IMul(index, ir.Imm32(buffer.GetStride())));
+        if (buffer.swizzle_enable) {
+            const IR::U32 stride_index_stride =
+                ir.Imm32(static_cast<u32>(buffer.stride * buffer.index_stride));
+            const IR::U32 index_msb = ir.IDiv(index, index_stride);
+            const IR::U32 index_lsb = ir.IMod(index, index_stride);
+            address = ir.IAdd(address, ir.IAdd(ir.IMul(index_msb, stride_index_stride),
+                                               ir.IMul(index_lsb, element_size)));
+        } else {
+            address = ir.IAdd(address, ir.IMul(index, ir.Imm32(buffer.GetStride())));
+        }
     }
     if (inst_info.offset_enable) {
         const IR::U32 offset = inst_info.index_enable ? IR::U32{ir.CompositeExtract(inst.Arg(1), 1)}
                                                       : IR::U32{inst.Arg(1)};
-        address = ir.IAdd(address, offset);
+        if (buffer.swizzle_enable) {
+            const IR::U32 element_size_index_stride =
+                ir.Imm32(buffer.element_size * buffer.index_stride);
+            const IR::U32 offset_msb = ir.IDiv(offset, element_size);
+            const IR::U32 offset_lsb = ir.IMod(offset, element_size);
+            address = ir.IAdd(address,
+                              ir.IAdd(ir.IMul(offset_msb, element_size_index_stride), offset_lsb));
+        } else {
+            address = ir.IAdd(address, offset);
+        }
     }
     inst.SetArg(1, address);
 }