From c13b29662e23220ff88811173da13463deb1392e Mon Sep 17 00:00:00 2001 From: baggins183 Date: Fri, 17 Jan 2025 00:14:54 -0800 Subject: [PATCH] handle control point strides that arent a multiple of 16 (#2172) --- .../backend/spirv/spirv_emit_context.cpp | 12 +++++------ .../ir/passes/hull_shader_transform.cpp | 21 +++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp index 7e86dfb4b..7d492a384 100644 --- a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp +++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp @@ -395,7 +395,7 @@ void EmitContext::DefineInputs() { DefineVariable(U32[1], spv::BuiltIn::PatchVertices, spv::StorageClass::Input); primitive_id = DefineVariable(U32[1], spv::BuiltIn::PrimitiveId, spv::StorageClass::Input); - const u32 num_attrs = runtime_info.hs_info.ls_stride >> 4; + const u32 num_attrs = Common::AlignUp(runtime_info.hs_info.ls_stride, 16) >> 4; if (num_attrs > 0) { const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))}; // The input vertex count isn't statically known, so make length 32 (what glslang does) @@ -409,7 +409,7 @@ void EmitContext::DefineInputs() { tess_coord = DefineInput(F32[3], std::nullopt, spv::BuiltIn::TessCoord); primitive_id = DefineVariable(U32[1], spv::BuiltIn::PrimitiveId, spv::StorageClass::Input); - const u32 num_attrs = runtime_info.vs_info.hs_output_cp_stride >> 4; + const u32 num_attrs = Common::AlignUp(runtime_info.vs_info.hs_output_cp_stride, 16) >> 4; if (num_attrs > 0) { const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))}; // The input vertex count isn't statically known, so make length 32 (what glslang does) @@ -418,7 +418,7 @@ void EmitContext::DefineInputs() { Name(input_attr_array, "in_attrs"); } - u32 patch_base_location = runtime_info.vs_info.hs_output_cp_stride >> 4; + const u32 patch_base_location = num_attrs; for (size_t index = 0; index < 30; ++index) { if (!(info.uses_patches & (1U << index))) { continue; @@ -453,7 +453,7 @@ void EmitContext::DefineOutputs() { DefineVariable(type, spv::BuiltIn::CullDistance, spv::StorageClass::Output); } if (stage == Shader::Stage::Local && runtime_info.ls_info.links_with_tcs) { - const u32 num_attrs = runtime_info.ls_info.ls_stride >> 4; + const u32 num_attrs = Common::AlignUp(runtime_info.ls_info.ls_stride, 16) >> 4; if (num_attrs > 0) { const Id type{TypeArray(F32[4], ConstU32(num_attrs))}; output_attr_array = DefineOutput(type, 0); @@ -488,7 +488,7 @@ void EmitContext::DefineOutputs() { Decorate(output_tess_level_inner, spv::Decoration::Patch); } - const u32 num_attrs = runtime_info.hs_info.hs_output_cp_stride >> 4; + const u32 num_attrs = Common::AlignUp(runtime_info.hs_info.hs_output_cp_stride, 16) >> 4; if (num_attrs > 0) { const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))}; // The input vertex count isn't statically known, so make length 32 (what glslang does) @@ -498,7 +498,7 @@ void EmitContext::DefineOutputs() { Name(output_attr_array, "out_attrs"); } - u32 patch_base_location = runtime_info.hs_info.hs_output_cp_stride >> 4; + const u32 patch_base_location = num_attrs; for (size_t index = 0; index < 30; ++index) { if (!(info.uses_patches & (1U << index))) { continue; diff --git a/src/shader_recompiler/ir/passes/hull_shader_transform.cpp b/src/shader_recompiler/ir/passes/hull_shader_transform.cpp index 6164fec12..b41e38339 100644 --- a/src/shader_recompiler/ir/passes/hull_shader_transform.cpp +++ b/src/shader_recompiler/ir/passes/hull_shader_transform.cpp @@ -349,11 +349,11 @@ static IR::F32 ReadTessControlPointAttribute(IR::U32 addr, const u32 stride, IR: addr = ir.IAdd(addr, ir.Imm32(off_dw)); } const IR::U32 control_point_index = ir.IDiv(addr, ir.Imm32(stride)); - const IR::U32 addr_for_attrs = TryOptimizeAddressModulo(addr, stride, ir); - const IR::U32 attr_index = - ir.ShiftRightLogical(ir.IMod(addr_for_attrs, ir.Imm32(stride)), ir.Imm32(4u)); + const IR::U32 opt_addr = TryOptimizeAddressModulo(addr, stride, ir); + const IR::U32 offset = ir.IMod(opt_addr, ir.Imm32(stride)); + const IR::U32 attr_index = ir.ShiftRightLogical(offset, ir.Imm32(4u)); const IR::U32 comp_index = - ir.ShiftRightLogical(ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u)); + ir.ShiftRightLogical(ir.BitwiseAnd(offset, ir.Imm32(0xFU)), ir.Imm32(2u)); if (is_output_read_in_tcs) { return ir.ReadTcsGenericOuputAttribute(control_point_index, attr_index, comp_index); } else { @@ -452,13 +452,13 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { if (off_dw > 0) { addr = ir.IAdd(addr, ir.Imm32(off_dw)); } - u32 stride = runtime_info.hs_info.hs_output_cp_stride; + const u32 stride = runtime_info.hs_info.hs_output_cp_stride; // Invocation ID array index is implicit, handled by SPIRV backend - const IR::U32 addr_for_attrs = TryOptimizeAddressModulo(addr, stride, ir); - const IR::U32 attr_index = ir.ShiftRightLogical( - ir.IMod(addr_for_attrs, ir.Imm32(stride)), ir.Imm32(4u)); + const IR::U32 opt_addr = TryOptimizeAddressModulo(addr, stride, ir); + const IR::U32 offset = ir.IMod(opt_addr, ir.Imm32(stride)); + const IR::U32 attr_index = ir.ShiftRightLogical(offset, ir.Imm32(4u)); const IR::U32 comp_index = ir.ShiftRightLogical( - ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u)); + ir.BitwiseAnd(offset, ir.Imm32(0xFU)), ir.Imm32(2u)); ir.SetTcsGenericAttribute(data_component, attr_index, comp_index); } else { ASSERT(output_kind == AttributeRegion::PatchConst); @@ -535,8 +535,7 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { // ... IR::IREmitter ir{*entry_block, it}; - ASSERT(runtime_info.hs_info.ls_stride % 16 == 0); - u32 num_attributes = runtime_info.hs_info.ls_stride / 16; + u32 num_attributes = Common::AlignUp(runtime_info.hs_info.ls_stride, 16) >> 4; const auto invocation_id = ir.GetAttributeU32(IR::Attribute::InvocationId); for (u32 attr_no = 0; attr_no < num_attributes; attr_no++) { for (u32 comp = 0; comp < 4; comp++) {