handle control point strides that arent a multiple of 16 (#2172)

This commit is contained in:
baggins183 2025-01-17 00:14:54 -08:00 committed by GitHub
parent 3b474a12f9
commit c13b29662e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 16 additions and 17 deletions

View file

@ -395,7 +395,7 @@ void EmitContext::DefineInputs() {
DefineVariable(U32[1], spv::BuiltIn::PatchVertices, spv::StorageClass::Input);
primitive_id = DefineVariable(U32[1], spv::BuiltIn::PrimitiveId, spv::StorageClass::Input);
const u32 num_attrs = runtime_info.hs_info.ls_stride >> 4;
const u32 num_attrs = Common::AlignUp(runtime_info.hs_info.ls_stride, 16) >> 4;
if (num_attrs > 0) {
const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))};
// The input vertex count isn't statically known, so make length 32 (what glslang does)
@ -409,7 +409,7 @@ void EmitContext::DefineInputs() {
tess_coord = DefineInput(F32[3], std::nullopt, spv::BuiltIn::TessCoord);
primitive_id = DefineVariable(U32[1], spv::BuiltIn::PrimitiveId, spv::StorageClass::Input);
const u32 num_attrs = runtime_info.vs_info.hs_output_cp_stride >> 4;
const u32 num_attrs = Common::AlignUp(runtime_info.vs_info.hs_output_cp_stride, 16) >> 4;
if (num_attrs > 0) {
const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))};
// The input vertex count isn't statically known, so make length 32 (what glslang does)
@ -418,7 +418,7 @@ void EmitContext::DefineInputs() {
Name(input_attr_array, "in_attrs");
}
u32 patch_base_location = runtime_info.vs_info.hs_output_cp_stride >> 4;
const u32 patch_base_location = num_attrs;
for (size_t index = 0; index < 30; ++index) {
if (!(info.uses_patches & (1U << index))) {
continue;
@ -453,7 +453,7 @@ void EmitContext::DefineOutputs() {
DefineVariable(type, spv::BuiltIn::CullDistance, spv::StorageClass::Output);
}
if (stage == Shader::Stage::Local && runtime_info.ls_info.links_with_tcs) {
const u32 num_attrs = runtime_info.ls_info.ls_stride >> 4;
const u32 num_attrs = Common::AlignUp(runtime_info.ls_info.ls_stride, 16) >> 4;
if (num_attrs > 0) {
const Id type{TypeArray(F32[4], ConstU32(num_attrs))};
output_attr_array = DefineOutput(type, 0);
@ -488,7 +488,7 @@ void EmitContext::DefineOutputs() {
Decorate(output_tess_level_inner, spv::Decoration::Patch);
}
const u32 num_attrs = runtime_info.hs_info.hs_output_cp_stride >> 4;
const u32 num_attrs = Common::AlignUp(runtime_info.hs_info.hs_output_cp_stride, 16) >> 4;
if (num_attrs > 0) {
const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))};
// The input vertex count isn't statically known, so make length 32 (what glslang does)
@ -498,7 +498,7 @@ void EmitContext::DefineOutputs() {
Name(output_attr_array, "out_attrs");
}
u32 patch_base_location = runtime_info.hs_info.hs_output_cp_stride >> 4;
const u32 patch_base_location = num_attrs;
for (size_t index = 0; index < 30; ++index) {
if (!(info.uses_patches & (1U << index))) {
continue;

View file

@ -349,11 +349,11 @@ static IR::F32 ReadTessControlPointAttribute(IR::U32 addr, const u32 stride, IR:
addr = ir.IAdd(addr, ir.Imm32(off_dw));
}
const IR::U32 control_point_index = ir.IDiv(addr, ir.Imm32(stride));
const IR::U32 addr_for_attrs = TryOptimizeAddressModulo(addr, stride, ir);
const IR::U32 attr_index =
ir.ShiftRightLogical(ir.IMod(addr_for_attrs, ir.Imm32(stride)), ir.Imm32(4u));
const IR::U32 opt_addr = TryOptimizeAddressModulo(addr, stride, ir);
const IR::U32 offset = ir.IMod(opt_addr, ir.Imm32(stride));
const IR::U32 attr_index = ir.ShiftRightLogical(offset, ir.Imm32(4u));
const IR::U32 comp_index =
ir.ShiftRightLogical(ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u));
ir.ShiftRightLogical(ir.BitwiseAnd(offset, ir.Imm32(0xFU)), ir.Imm32(2u));
if (is_output_read_in_tcs) {
return ir.ReadTcsGenericOuputAttribute(control_point_index, attr_index, comp_index);
} else {
@ -452,13 +452,13 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
if (off_dw > 0) {
addr = ir.IAdd(addr, ir.Imm32(off_dw));
}
u32 stride = runtime_info.hs_info.hs_output_cp_stride;
const u32 stride = runtime_info.hs_info.hs_output_cp_stride;
// Invocation ID array index is implicit, handled by SPIRV backend
const IR::U32 addr_for_attrs = TryOptimizeAddressModulo(addr, stride, ir);
const IR::U32 attr_index = ir.ShiftRightLogical(
ir.IMod(addr_for_attrs, ir.Imm32(stride)), ir.Imm32(4u));
const IR::U32 opt_addr = TryOptimizeAddressModulo(addr, stride, ir);
const IR::U32 offset = ir.IMod(opt_addr, ir.Imm32(stride));
const IR::U32 attr_index = ir.ShiftRightLogical(offset, ir.Imm32(4u));
const IR::U32 comp_index = ir.ShiftRightLogical(
ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u));
ir.BitwiseAnd(offset, ir.Imm32(0xFU)), ir.Imm32(2u));
ir.SetTcsGenericAttribute(data_component, attr_index, comp_index);
} else {
ASSERT(output_kind == AttributeRegion::PatchConst);
@ -535,8 +535,7 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
// ...
IR::IREmitter ir{*entry_block, it};
ASSERT(runtime_info.hs_info.ls_stride % 16 == 0);
u32 num_attributes = runtime_info.hs_info.ls_stride / 16;
u32 num_attributes = Common::AlignUp(runtime_info.hs_info.ls_stride, 16) >> 4;
const auto invocation_id = ir.GetAttributeU32(IR::Attribute::InvocationId);
for (u32 attr_no = 0; attr_no < num_attributes; attr_no++) {
for (u32 comp = 0; comp < 4; comp++) {