shader_recompiler: Implement basic 64-bit floating point support (#915)

* shader_recompiler: Implement basic 64-bit floating point support

* Fix formatting
This commit is contained in:
Daniel R. 2024-09-15 22:53:08 +02:00 committed by GitHub
parent a441244365
commit 4006fcb7d9
13 changed files with 65 additions and 30 deletions

View file

@ -184,6 +184,9 @@ void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) {
ctx.AddCapability(spv::Capability::Float16);
ctx.AddCapability(spv::Capability::Int16);
}
if (info.uses_fp64) {
ctx.AddCapability(spv::Capability::Float64);
}
ctx.AddCapability(spv::Capability::Int64);
if (info.has_storage_images || info.has_image_buffers) {
ctx.AddCapability(spv::Capability::StorageImageExtendedFormats);

View file

@ -14,8 +14,8 @@ Id EmitBitCastU32F32(EmitContext& ctx, Id value) {
return ctx.OpBitcast(ctx.U32[1], value);
}
void EmitBitCastU64F64(EmitContext&) {
UNREACHABLE_MSG("SPIR-V Instruction");
Id EmitBitCastU64F64(EmitContext& ctx, Id value) {
return ctx.OpBitcast(ctx.U64, value);
}
Id EmitBitCastF16U16(EmitContext& ctx, Id value) {
@ -38,6 +38,10 @@ Id EmitUnpackUint2x32(EmitContext& ctx, Id value) {
return ctx.OpBitcast(ctx.U32[2], value);
}
Id EmitPackFloat2x32(EmitContext& ctx, Id value) {
return ctx.OpBitcast(ctx.F64[1], value);
}
Id EmitPackFloat2x16(EmitContext& ctx, Id value) {
return ctx.OpBitcast(ctx.U32[1], value);
}

View file

@ -158,12 +158,13 @@ Id EmitSelectF32(EmitContext& ctx, Id cond, Id true_value, Id false_value);
Id EmitSelectF64(EmitContext& ctx, Id cond, Id true_value, Id false_value);
Id EmitBitCastU16F16(EmitContext& ctx, Id value);
Id EmitBitCastU32F32(EmitContext& ctx, Id value);
void EmitBitCastU64F64(EmitContext& ctx);
Id EmitBitCastU64F64(EmitContext& ctx, Id value);
Id EmitBitCastF16U16(EmitContext& ctx, Id value);
Id EmitBitCastF32U32(EmitContext& ctx, Id value);
void EmitBitCastF64U64(EmitContext& ctx);
Id EmitPackUint2x32(EmitContext& ctx, Id value);
Id EmitUnpackUint2x32(EmitContext& ctx, Id value);
Id EmitPackFloat2x32(EmitContext& ctx, Id value);
Id EmitPackFloat2x16(EmitContext& ctx, Id value);
Id EmitUnpackFloat2x16(EmitContext& ctx, Id value);
Id EmitPackHalf2x16(EmitContext& ctx, Id value);

View file

@ -85,6 +85,9 @@ void EmitContext::DefineArithmeticTypes() {
F16[1] = Name(TypeFloat(16), "f16_id");
U16 = Name(TypeUInt(16), "u16_id");
}
if (info.uses_fp64) {
F64[1] = Name(TypeFloat(64), "f64_id");
}
F32[1] = Name(TypeFloat(32), "f32_id");
S32[1] = Name(TypeSInt(32), "i32_id");
U32[1] = Name(TypeUInt(32), "u32_id");
@ -94,6 +97,9 @@ void EmitContext::DefineArithmeticTypes() {
if (info.uses_fp16) {
F16[i] = Name(TypeVector(F16[1], i), fmt::format("f16vec{}_id", i));
}
if (info.uses_fp64) {
F64[i] = Name(TypeVector(F64[1], i), fmt::format("f64vec{}_id", i));
}
F32[i] = Name(TypeVector(F32[1], i), fmt::format("f32vec{}_id", i));
S32[i] = Name(TypeVector(S32[1], i), fmt::format("i32vec{}_id", i));
U32[i] = Name(TypeVector(U32[1], i), fmt::format("u32vec{}_id", i));

View file

@ -211,7 +211,7 @@ T Translator::GetSrc64(const InstOperand& operand) {
const auto value_lo = ir.GetVectorReg(IR::VectorReg(operand.code));
const auto value_hi = ir.GetVectorReg(IR::VectorReg(operand.code + 1));
if constexpr (is_float) {
UNREACHABLE();
value = ir.PackFloat2x32(ir.CompositeConstruct(value_lo, value_hi));
} else {
value = ir.PackUint2x32(ir.CompositeConstruct(value_lo, value_hi));
}

View file

@ -141,6 +141,7 @@ public:
void V_FMA_F32(const GcnInst& inst);
void V_CMP_F32(ConditionOp op, bool set_exec, const GcnInst& inst);
void V_MAX_F32(const GcnInst& inst, bool is_legacy = false);
void V_MAX_F64(const GcnInst& inst);
void V_MAX_U32(bool is_signed, const GcnInst& inst);
void V_RSQ_F32(const GcnInst& inst);
void V_SIN_F32(const GcnInst& inst);

View file

@ -198,6 +198,8 @@ void Translator::EmitVectorAlu(const GcnInst& inst) {
return V_FMA_F32(inst);
case Opcode::V_MAX_F32:
return V_MAX_F32(inst);
case Opcode::V_MAX_F64:
return V_MAX_F64(inst);
case Opcode::V_RSQ_F32:
return V_RSQ_F32(inst);
case Opcode::V_SIN_F32:
@ -582,6 +584,12 @@ void Translator::V_MAX_F32(const GcnInst& inst, bool is_legacy) {
SetDst(inst.dst[0], ir.FPMax(src0, src1, is_legacy));
}
void Translator::V_MAX_F64(const GcnInst& inst) {
const IR::F64 src0{GetSrc64<IR::F64>(inst.src[0])};
const IR::F64 src1{GetSrc64<IR::F64>(inst.src[1])};
SetDst64(inst.dst[0], ir.FPMax(src0, src1));
}
void Translator::V_MAX_U32(bool is_signed, const GcnInst& inst) {
const IR::U32 src0{GetSrc(inst.src[0])};
const IR::U32 src1{GetSrc(inst.src[1])};

View file

@ -168,6 +168,7 @@ struct Info {
bool uses_group_ballot{};
bool uses_shared{};
bool uses_fp16{};
bool uses_fp64{};
bool uses_step_rates{};
bool translation_failed{}; // indicates that shader has unsupported instructions

View file

@ -629,6 +629,10 @@ Value IREmitter::UnpackUint2x32(const U64& value) {
return Inst<Value>(Opcode::UnpackUint2x32, value);
}
F64 IREmitter::PackFloat2x32(const Value& vector) {
return Inst<F64>(Opcode::PackFloat2x32, vector);
}
U32 IREmitter::PackFloat2x16(const Value& vector) {
return Inst<U32>(Opcode::PackFloat2x16, vector);
}

View file

@ -142,6 +142,8 @@ public:
[[nodiscard]] U64 PackUint2x32(const Value& vector);
[[nodiscard]] Value UnpackUint2x32(const U64& value);
[[nodiscard]] F64 PackFloat2x32(const Value& vector);
[[nodiscard]] U32 PackFloat2x16(const Value& vector);
[[nodiscard]] Value UnpackFloat2x16(const U32& value);

View file

@ -34,9 +34,9 @@ OPCODE(WriteSharedU128, Void, U32,
// Shared atomic operations
OPCODE(SharedAtomicIAdd32, U32, U32, U32, )
OPCODE(SharedAtomicSMin32, U32, U32, U32, )
OPCODE(SharedAtomicSMin32, U32, U32, U32, )
OPCODE(SharedAtomicUMin32, U32, U32, U32, )
OPCODE(SharedAtomicSMax32, U32, U32, U32, )
OPCODE(SharedAtomicSMax32, U32, U32, U32, )
OPCODE(SharedAtomicUMax32, U32, U32, U32, )
// Context getters/setters
@ -54,19 +54,19 @@ OPCODE(GetAttributeU32, U32, Attr
OPCODE(SetAttribute, Void, Attribute, F32, U32, )
// Flags
OPCODE(GetScc, U1, Void, )
OPCODE(GetExec, U1, Void, )
OPCODE(GetVcc, U1, Void, )
OPCODE(GetVccLo, U32, Void, )
OPCODE(GetVccHi, U32, Void, )
OPCODE(GetM0, U32, Void, )
OPCODE(SetScc, Void, U1, )
OPCODE(SetExec, Void, U1, )
OPCODE(SetVcc, Void, U1, )
OPCODE(SetSccLo, Void, U32, )
OPCODE(SetVccLo, Void, U32, )
OPCODE(SetVccHi, Void, U32, )
OPCODE(SetM0, Void, U32, )
OPCODE(GetScc, U1, Void, )
OPCODE(GetExec, U1, Void, )
OPCODE(GetVcc, U1, Void, )
OPCODE(GetVccLo, U32, Void, )
OPCODE(GetVccHi, U32, Void, )
OPCODE(GetM0, U32, Void, )
OPCODE(SetScc, Void, U1, )
OPCODE(SetExec, Void, U1, )
OPCODE(SetVcc, Void, U1, )
OPCODE(SetSccLo, Void, U32, )
OPCODE(SetVccLo, Void, U32, )
OPCODE(SetVccHi, Void, U32, )
OPCODE(SetM0, Void, U32, )
// Undefined
OPCODE(UndefU1, U1, )
@ -88,17 +88,17 @@ OPCODE(StoreBufferU32x4, Void, Opaq
OPCODE(StoreBufferFormatF32, Void, Opaque, Opaque, U32x4, )
// Buffer atomic operations
OPCODE(BufferAtomicIAdd32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicSMin32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicUMin32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicSMax32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicUMax32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicInc32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicDec32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicAnd32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicOr32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicXor32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicSwap32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicIAdd32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicSMin32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicUMin32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicSMax32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicUMax32, U32, Opaque, Opaque, U32 )
OPCODE(BufferAtomicInc32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicDec32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicAnd32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicOr32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicXor32, U32, Opaque, Opaque, U32, )
OPCODE(BufferAtomicSwap32, U32, Opaque, Opaque, U32, )
// Vector utility
OPCODE(CompositeConstructU32x2, U32x2, U32, U32, )
@ -156,6 +156,7 @@ OPCODE(BitCastF32U32, F32, U32,
OPCODE(BitCastF64U64, F64, U64, )
OPCODE(PackUint2x32, U64, U32x2, )
OPCODE(UnpackUint2x32, U32x2, U64, )
OPCODE(PackFloat2x32, F64, F32x2, )
OPCODE(PackFloat2x16, U32, F16x2, )
OPCODE(UnpackFloat2x16, F16x2, U32, )
OPCODE(PackHalf2x16, U32, F32x2, )

View file

@ -27,6 +27,9 @@ void Visit(Info& info, IR::Inst& inst) {
case IR::Opcode::BitCastF16U16:
info.uses_fp16 = true;
break;
case IR::Opcode::BitCastU64F64:
info.uses_fp64 = true;
break;
case IR::Opcode::ImageWrite:
info.has_storage_images = true;
break;

View file

@ -289,6 +289,7 @@ bool Instance::CreateDevice() {
.shaderStorageImageExtendedFormats = features.shaderStorageImageExtendedFormats,
.shaderStorageImageMultisample = features.shaderStorageImageMultisample,
.shaderClipDistance = features.shaderClipDistance,
.shaderFloat64 = features.shaderFloat64,
.shaderInt64 = features.shaderInt64,
.shaderInt16 = features.shaderInt16,
},