handle control point strides that arent a multiple of 16 (#2172)

This commit is contained in:
baggins183 2025-01-17 00:14:54 -08:00 committed by GitHub
parent 3b474a12f9
commit c13b29662e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 16 additions and 17 deletions

View file

@ -395,7 +395,7 @@ void EmitContext::DefineInputs() {
DefineVariable(U32[1], spv::BuiltIn::PatchVertices, spv::StorageClass::Input); DefineVariable(U32[1], spv::BuiltIn::PatchVertices, spv::StorageClass::Input);
primitive_id = DefineVariable(U32[1], spv::BuiltIn::PrimitiveId, spv::StorageClass::Input); primitive_id = DefineVariable(U32[1], spv::BuiltIn::PrimitiveId, spv::StorageClass::Input);
const u32 num_attrs = runtime_info.hs_info.ls_stride >> 4; const u32 num_attrs = Common::AlignUp(runtime_info.hs_info.ls_stride, 16) >> 4;
if (num_attrs > 0) { if (num_attrs > 0) {
const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))}; const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))};
// The input vertex count isn't statically known, so make length 32 (what glslang does) // The input vertex count isn't statically known, so make length 32 (what glslang does)
@ -409,7 +409,7 @@ void EmitContext::DefineInputs() {
tess_coord = DefineInput(F32[3], std::nullopt, spv::BuiltIn::TessCoord); tess_coord = DefineInput(F32[3], std::nullopt, spv::BuiltIn::TessCoord);
primitive_id = DefineVariable(U32[1], spv::BuiltIn::PrimitiveId, spv::StorageClass::Input); primitive_id = DefineVariable(U32[1], spv::BuiltIn::PrimitiveId, spv::StorageClass::Input);
const u32 num_attrs = runtime_info.vs_info.hs_output_cp_stride >> 4; const u32 num_attrs = Common::AlignUp(runtime_info.vs_info.hs_output_cp_stride, 16) >> 4;
if (num_attrs > 0) { if (num_attrs > 0) {
const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))}; const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))};
// The input vertex count isn't statically known, so make length 32 (what glslang does) // The input vertex count isn't statically known, so make length 32 (what glslang does)
@ -418,7 +418,7 @@ void EmitContext::DefineInputs() {
Name(input_attr_array, "in_attrs"); Name(input_attr_array, "in_attrs");
} }
u32 patch_base_location = runtime_info.vs_info.hs_output_cp_stride >> 4; const u32 patch_base_location = num_attrs;
for (size_t index = 0; index < 30; ++index) { for (size_t index = 0; index < 30; ++index) {
if (!(info.uses_patches & (1U << index))) { if (!(info.uses_patches & (1U << index))) {
continue; continue;
@ -453,7 +453,7 @@ void EmitContext::DefineOutputs() {
DefineVariable(type, spv::BuiltIn::CullDistance, spv::StorageClass::Output); DefineVariable(type, spv::BuiltIn::CullDistance, spv::StorageClass::Output);
} }
if (stage == Shader::Stage::Local && runtime_info.ls_info.links_with_tcs) { if (stage == Shader::Stage::Local && runtime_info.ls_info.links_with_tcs) {
const u32 num_attrs = runtime_info.ls_info.ls_stride >> 4; const u32 num_attrs = Common::AlignUp(runtime_info.ls_info.ls_stride, 16) >> 4;
if (num_attrs > 0) { if (num_attrs > 0) {
const Id type{TypeArray(F32[4], ConstU32(num_attrs))}; const Id type{TypeArray(F32[4], ConstU32(num_attrs))};
output_attr_array = DefineOutput(type, 0); output_attr_array = DefineOutput(type, 0);
@ -488,7 +488,7 @@ void EmitContext::DefineOutputs() {
Decorate(output_tess_level_inner, spv::Decoration::Patch); Decorate(output_tess_level_inner, spv::Decoration::Patch);
} }
const u32 num_attrs = runtime_info.hs_info.hs_output_cp_stride >> 4; const u32 num_attrs = Common::AlignUp(runtime_info.hs_info.hs_output_cp_stride, 16) >> 4;
if (num_attrs > 0) { if (num_attrs > 0) {
const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))}; const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))};
// The input vertex count isn't statically known, so make length 32 (what glslang does) // The input vertex count isn't statically known, so make length 32 (what glslang does)
@ -498,7 +498,7 @@ void EmitContext::DefineOutputs() {
Name(output_attr_array, "out_attrs"); Name(output_attr_array, "out_attrs");
} }
u32 patch_base_location = runtime_info.hs_info.hs_output_cp_stride >> 4; const u32 patch_base_location = num_attrs;
for (size_t index = 0; index < 30; ++index) { for (size_t index = 0; index < 30; ++index) {
if (!(info.uses_patches & (1U << index))) { if (!(info.uses_patches & (1U << index))) {
continue; continue;

View file

@ -349,11 +349,11 @@ static IR::F32 ReadTessControlPointAttribute(IR::U32 addr, const u32 stride, IR:
addr = ir.IAdd(addr, ir.Imm32(off_dw)); addr = ir.IAdd(addr, ir.Imm32(off_dw));
} }
const IR::U32 control_point_index = ir.IDiv(addr, ir.Imm32(stride)); const IR::U32 control_point_index = ir.IDiv(addr, ir.Imm32(stride));
const IR::U32 addr_for_attrs = TryOptimizeAddressModulo(addr, stride, ir); const IR::U32 opt_addr = TryOptimizeAddressModulo(addr, stride, ir);
const IR::U32 attr_index = const IR::U32 offset = ir.IMod(opt_addr, ir.Imm32(stride));
ir.ShiftRightLogical(ir.IMod(addr_for_attrs, ir.Imm32(stride)), ir.Imm32(4u)); const IR::U32 attr_index = ir.ShiftRightLogical(offset, ir.Imm32(4u));
const IR::U32 comp_index = const IR::U32 comp_index =
ir.ShiftRightLogical(ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u)); ir.ShiftRightLogical(ir.BitwiseAnd(offset, ir.Imm32(0xFU)), ir.Imm32(2u));
if (is_output_read_in_tcs) { if (is_output_read_in_tcs) {
return ir.ReadTcsGenericOuputAttribute(control_point_index, attr_index, comp_index); return ir.ReadTcsGenericOuputAttribute(control_point_index, attr_index, comp_index);
} else { } else {
@ -452,13 +452,13 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
if (off_dw > 0) { if (off_dw > 0) {
addr = ir.IAdd(addr, ir.Imm32(off_dw)); addr = ir.IAdd(addr, ir.Imm32(off_dw));
} }
u32 stride = runtime_info.hs_info.hs_output_cp_stride; const u32 stride = runtime_info.hs_info.hs_output_cp_stride;
// Invocation ID array index is implicit, handled by SPIRV backend // Invocation ID array index is implicit, handled by SPIRV backend
const IR::U32 addr_for_attrs = TryOptimizeAddressModulo(addr, stride, ir); const IR::U32 opt_addr = TryOptimizeAddressModulo(addr, stride, ir);
const IR::U32 attr_index = ir.ShiftRightLogical( const IR::U32 offset = ir.IMod(opt_addr, ir.Imm32(stride));
ir.IMod(addr_for_attrs, ir.Imm32(stride)), ir.Imm32(4u)); const IR::U32 attr_index = ir.ShiftRightLogical(offset, ir.Imm32(4u));
const IR::U32 comp_index = ir.ShiftRightLogical( const IR::U32 comp_index = ir.ShiftRightLogical(
ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u)); ir.BitwiseAnd(offset, ir.Imm32(0xFU)), ir.Imm32(2u));
ir.SetTcsGenericAttribute(data_component, attr_index, comp_index); ir.SetTcsGenericAttribute(data_component, attr_index, comp_index);
} else { } else {
ASSERT(output_kind == AttributeRegion::PatchConst); ASSERT(output_kind == AttributeRegion::PatchConst);
@ -535,8 +535,7 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
// ... // ...
IR::IREmitter ir{*entry_block, it}; IR::IREmitter ir{*entry_block, it};
ASSERT(runtime_info.hs_info.ls_stride % 16 == 0); u32 num_attributes = Common::AlignUp(runtime_info.hs_info.ls_stride, 16) >> 4;
u32 num_attributes = runtime_info.hs_info.ls_stride / 16;
const auto invocation_id = ir.GetAttributeU32(IR::Attribute::InvocationId); const auto invocation_id = ir.GetAttributeU32(IR::Attribute::InvocationId);
for (u32 attr_no = 0; attr_no < num_attributes; attr_no++) { for (u32 attr_no = 0; attr_no < num_attributes; attr_no++) {
for (u32 comp = 0; comp < 4; comp++) { for (u32 comp = 0; comp < 4; comp++) {