diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp index 5f174bb7f9..e4c3e3d9cd 100644 --- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp +++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp @@ -30,6 +30,7 @@ using namespace VideoCommon::Shader; using Maxwell = Tegra::Engines::Maxwell3D::Regs; using ShaderStage = Tegra::Engines::Maxwell3D::Regs::ShaderStage; +using Operation = const OperationNode&; // TODO(Rodrigo): Use rasterizer's value constexpr u32 MAX_CONSTBUFFER_ELEMENTS = 0x1000; @@ -73,6 +74,19 @@ constexpr u32 GetGenericAttributeLocation(Attribute::Index attribute) { return static_cast(attribute) - static_cast(Attribute::Index::Attribute_0); } +/// Returns true if an object has to be treated as precise +bool IsPrecise(Operation operand) { + const auto& meta = operand.GetMeta(); + + if (std::holds_alternative(meta)) { + return std::get(meta).precise; + } + if (std::holds_alternative(meta)) { + return std::get(meta).precise; + } + return false; +} + } // namespace class SPIRVDecompiler : public Sirit::Module { @@ -194,6 +208,10 @@ public: } private: + using OperationDecompilerFn = Id (SPIRVDecompiler::*)(Operation); + using OperationDecompilersArray = + std::array(OperationCode::Amount)>; + static constexpr auto INTERNAL_FLAGS_COUNT = static_cast(InternalFlag::Amount); static constexpr u32 CBUF_STRIDE = 16; @@ -468,7 +486,12 @@ private: Id Visit(Node node) { if (const auto operation = std::get_if(node)) { - UNIMPLEMENTED(); + const auto operation_index = static_cast(operation->GetCode()); + const auto decompiler = operation_decompilers[operation_index]; + if (decompiler == nullptr) { + UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index); + } + return (this->*decompiler)(*operation); } else if (const auto gpr = std::get_if(node)) { UNIMPLEMENTED(); @@ -500,6 +523,184 @@ private: return {}; } + template + Id Unary(Operation operation) { + const Id type_def = GetTypeDefinition(result_type); + const Id op_a = VisitOperand(operation, 0); + + const Id value = BitcastFrom(Emit((this->*func)(type_def, op_a))); + if (IsPrecise(operation)) { + Decorate(value, spv::Decoration::NoContraction); + } + return value; + } + + template + Id Binary(Operation operation) { + const Id type_def = GetTypeDefinition(result_type); + const Id op_a = VisitOperand(operation, 0); + const Id op_b = VisitOperand(operation, 1); + + const Id value = BitcastFrom(Emit((this->*func)(type_def, op_a, op_b))); + if (IsPrecise(operation)) { + Decorate(value, spv::Decoration::NoContraction); + } + return value; + } + + template + Id Ternary(Operation operation) { + const Id type_def = GetTypeDefinition(result_type); + const Id op_a = VisitOperand(operation, 0); + const Id op_b = VisitOperand(operation, 1); + const Id op_c = VisitOperand(operation, 2); + + const Id value = BitcastFrom(Emit((this->*func)(type_def, op_a, op_b, op_c))); + if (IsPrecise(operation)) { + Decorate(value, spv::Decoration::NoContraction); + } + return value; + } + + template + Id Quaternary(Operation operation) { + const Id type_def = GetTypeDefinition(result_type); + const Id op_a = VisitOperand(operation, 0); + const Id op_b = VisitOperand(operation, 1); + const Id op_c = VisitOperand(operation, 2); + const Id op_d = VisitOperand(operation, 3); + + const Id value = + BitcastFrom(Emit((this->*func)(type_def, op_a, op_b, op_c, op_d))); + if (IsPrecise(operation)) { + Decorate(value, spv::Decoration::NoContraction); + } + return value; + } + + Id Assign(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id HNegate(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id HMergeF32(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id HMergeH0(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id HMergeH1(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id HPack2(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id LogicalAssign(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id LogicalPick2(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id LogicalAll2(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id LogicalAny2(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id Texture(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id TextureLod(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id TextureGather(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id TextureQueryDimensions(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id TextureQueryLod(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id TexelFetch(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id Branch(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id PushFlowStack(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id PopFlowStack(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id Exit(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id Discard(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id EmitVertex(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id EndPrimitive(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + + Id YNegate(Operation operation) { + UNIMPLEMENTED(); + return {}; + } + Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type, const std::string& name) { const Id id = OpVariable(type, storage); @@ -518,6 +719,215 @@ private: return false; } + template + Id VisitOperand(Operation operation, std::size_t operand_index) { + const Id value = Visit(operation[operand_index]); + + switch (type) { + case Type::Bool: + case Type::Bool2: + case Type::Float: + return value; + case Type::Int: + return Emit(OpBitcast(t_int, value)); + case Type::Uint: + return Emit(OpBitcast(t_uint, value)); + case Type::HalfFloat: + UNIMPLEMENTED(); + } + UNREACHABLE(); + return value; + } + + template + Id BitcastFrom(Id value) { + switch (type) { + case Type::Bool: + case Type::Bool2: + case Type::Float: + return value; + case Type::Int: + case Type::Uint: + return Emit(OpBitcast(t_float, value)); + case Type::HalfFloat: + UNIMPLEMENTED(); + } + UNREACHABLE(); + return value; + } + + template + Id BitcastTo(Id value) { + switch (type) { + case Type::Bool: + case Type::Bool2: + UNREACHABLE(); + case Type::Float: + return Emit(OpBitcast(t_float, value)); + case Type::Int: + return Emit(OpBitcast(t_int, value)); + case Type::Uint: + return Emit(OpBitcast(t_uint, value)); + case Type::HalfFloat: + UNIMPLEMENTED(); + } + UNREACHABLE(); + return value; + } + + Id GetTypeDefinition(Type type) { + switch (type) { + case Type::Bool: + return t_bool; + case Type::Bool2: + return t_bool2; + case Type::Float: + return t_float; + case Type::Int: + return t_int; + case Type::Uint: + return t_uint; + case Type::HalfFloat: + UNIMPLEMENTED(); + } + UNREACHABLE(); + return {}; + } + + static constexpr OperationDecompilersArray operation_decompilers = { + &SPIRVDecompiler::Assign, + + &SPIRVDecompiler::Ternary<&Module::OpSelect, Type::Float, Type::Bool, Type::Float, + Type::Float>, + + &SPIRVDecompiler::Binary<&Module::OpFAdd, Type::Float>, + &SPIRVDecompiler::Binary<&Module::OpFMul, Type::Float>, + &SPIRVDecompiler::Binary<&Module::OpFDiv, Type::Float>, + &SPIRVDecompiler::Ternary<&Module::OpFma, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>, + &SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>, + &SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>, + &SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpSin, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpExp2, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpLog2, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpInverseSqrt, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpSqrt, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpRoundEven, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpFloor, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpCeil, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpTrunc, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpConvertSToF, Type::Float, Type::Int>, + &SPIRVDecompiler::Unary<&Module::OpConvertUToF, Type::Float, Type::Uint>, + + &SPIRVDecompiler::Binary<&Module::OpIAdd, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpIMul, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpSDiv, Type::Int>, + &SPIRVDecompiler::Unary<&Module::OpSNegate, Type::Int>, + &SPIRVDecompiler::Unary<&Module::OpSAbs, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpSMin, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpSMax, Type::Int>, + + &SPIRVDecompiler::Unary<&Module::OpConvertFToS, Type::Int, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Int, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Int, Type::Int, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Int, Type::Int, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Int, Type::Int, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Int>, + &SPIRVDecompiler::Unary<&Module::OpNot, Type::Int>, + &SPIRVDecompiler::Quaternary<&Module::OpBitFieldInsert, Type::Int>, + &SPIRVDecompiler::Ternary<&Module::OpBitFieldSExtract, Type::Int>, + &SPIRVDecompiler::Unary<&Module::OpBitCount, Type::Int>, + + &SPIRVDecompiler::Binary<&Module::OpIAdd, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpIMul, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpUDiv, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpUMin, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpUMax, Type::Uint>, + &SPIRVDecompiler::Unary<&Module::OpConvertFToU, Type::Uint, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>, + &SPIRVDecompiler::Unary<&Module::OpNot, Type::Uint>, + &SPIRVDecompiler::Quaternary<&Module::OpBitFieldInsert, Type::Uint>, + &SPIRVDecompiler::Ternary<&Module::OpBitFieldUExtract, Type::Uint>, + &SPIRVDecompiler::Unary<&Module::OpBitCount, Type::Uint>, + + &SPIRVDecompiler::Binary<&Module::OpFAdd, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFMul, Type::HalfFloat>, + &SPIRVDecompiler::Ternary<&Module::OpFma, Type::HalfFloat>, + &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::HalfFloat>, + &SPIRVDecompiler::HNegate, + &SPIRVDecompiler::HMergeF32, + &SPIRVDecompiler::HMergeH0, + &SPIRVDecompiler::HMergeH1, + &SPIRVDecompiler::HPack2, + + &SPIRVDecompiler::LogicalAssign, + &SPIRVDecompiler::Binary<&Module::OpLogicalAnd, Type::Bool>, + &SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>, + &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>, + &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>, + &SPIRVDecompiler::LogicalPick2, + &SPIRVDecompiler::LogicalAll2, + &SPIRVDecompiler::LogicalAny2, + + &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>, + &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>, + &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::Float>, + &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>, + &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>, + &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>, + &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool>, + + &SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpSLessThanEqual, Type::Bool, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpSGreaterThan, Type::Bool, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Int>, + &SPIRVDecompiler::Binary<&Module::OpSGreaterThanEqual, Type::Bool, Type::Int>, + + &SPIRVDecompiler::Binary<&Module::OpULessThan, Type::Bool, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpULessThanEqual, Type::Bool, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpUGreaterThan, Type::Bool, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>, + &SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>, + + &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>, + &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>, + + &SPIRVDecompiler::Texture, + &SPIRVDecompiler::TextureLod, + &SPIRVDecompiler::TextureGather, + &SPIRVDecompiler::TextureQueryDimensions, + &SPIRVDecompiler::TextureQueryLod, + &SPIRVDecompiler::TexelFetch, + + &SPIRVDecompiler::Branch, + &SPIRVDecompiler::PushFlowStack, + &SPIRVDecompiler::PopFlowStack, + &SPIRVDecompiler::Exit, + &SPIRVDecompiler::Discard, + + &SPIRVDecompiler::EmitVertex, + &SPIRVDecompiler::EndPrimitive, + + &SPIRVDecompiler::YNegate, + }; + const ShaderIR& ir; const ShaderStage stage; const Tegra::Shader::Header header;