forked from eden-emu/eden
		
	shader: Misc fixes
This commit is contained in:
		
					parent
					
						
							
								c4d75e4b78
							
						
					
				
			
			
				commit
				
					
						76a3a2510f
					
				
			
		
					 10 changed files with 104 additions and 89 deletions
				
			
		|  | @ -25,6 +25,9 @@ EmitContext::EmitContext(IR::Program& program) { | ||||||
|     f16.Define(*this, TypeFloat(16), "f16"); |     f16.Define(*this, TypeFloat(16), "f16"); | ||||||
|     f64.Define(*this, TypeFloat(64), "f64"); |     f64.Define(*this, TypeFloat(64), "f64"); | ||||||
| 
 | 
 | ||||||
|  |     true_value = ConstantTrue(u1); | ||||||
|  |     false_value = ConstantFalse(u1); | ||||||
|  | 
 | ||||||
|     for (const IR::Function& function : program.functions) { |     for (const IR::Function& function : program.functions) { | ||||||
|         for (IR::Block* const block : function.blocks) { |         for (IR::Block* const block : function.blocks) { | ||||||
|             block_label_map.emplace_back(block, OpLabel()); |             block_label_map.emplace_back(block, OpLabel()); | ||||||
|  | @ -58,6 +61,7 @@ EmitSPIRV::EmitSPIRV(IR::Program& program) { | ||||||
|     std::fclose(file); |     std::fclose(file); | ||||||
|     std::system("spirv-dis shader.spv"); |     std::system("spirv-dis shader.spv"); | ||||||
|     std::system("spirv-val shader.spv"); |     std::system("spirv-val shader.spv"); | ||||||
|  |     std::system("spirv-cross shader.spv"); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <auto method> | template <auto method> | ||||||
|  | @ -109,6 +113,8 @@ static Id TypeId(const EmitContext& ctx, IR::Type type) { | ||||||
|     switch (type) { |     switch (type) { | ||||||
|     case IR::Type::U1: |     case IR::Type::U1: | ||||||
|         return ctx.u1; |         return ctx.u1; | ||||||
|  |     case IR::Type::U32: | ||||||
|  |         return ctx.u32[1]; | ||||||
|     default: |     default: | ||||||
|         throw NotImplementedException("Phi node type {}", type); |         throw NotImplementedException("Phi node type {}", type); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  | @ -79,6 +79,8 @@ public: | ||||||
|             return def_map.Consume(value.Inst()); |             return def_map.Consume(value.Inst()); | ||||||
|         } |         } | ||||||
|         switch (value.Type()) { |         switch (value.Type()) { | ||||||
|  |         case IR::Type::U1: | ||||||
|  |             return value.U1() ? true_value : false_value; | ||||||
|         case IR::Type::U32: |         case IR::Type::U32: | ||||||
|             return Constant(u32[1], value.U32()); |             return Constant(u32[1], value.U32()); | ||||||
|         case IR::Type::F32: |         case IR::Type::F32: | ||||||
|  | @ -108,6 +110,9 @@ public: | ||||||
|     VectorTypes f16; |     VectorTypes f16; | ||||||
|     VectorTypes f64; |     VectorTypes f64; | ||||||
| 
 | 
 | ||||||
|  |     Id true_value{}; | ||||||
|  |     Id false_value{}; | ||||||
|  | 
 | ||||||
|     Id workgroup_id{}; |     Id workgroup_id{}; | ||||||
|     Id local_invocation_id{}; |     Id local_invocation_id{}; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -113,7 +113,7 @@ static std::string ArgToIndex(const std::map<const Block*, size_t>& block_to_ind | ||||||
|     if (arg.IsLabel()) { |     if (arg.IsLabel()) { | ||||||
|         return BlockToIndex(block_to_index, arg.Label()); |         return BlockToIndex(block_to_index, arg.Label()); | ||||||
|     } |     } | ||||||
|     if (!arg.IsImmediate()) { |     if (!arg.IsImmediate() || arg.IsIdentity()) { | ||||||
|         return fmt::format("%{}", InstIndex(inst_to_index, inst_index, arg.Inst())); |         return fmt::format("%{}", InstIndex(inst_to_index, inst_index, arg.Inst())); | ||||||
|     } |     } | ||||||
|     switch (arg.Type()) { |     switch (arg.Type()) { | ||||||
|  | @ -166,7 +166,7 @@ std::string DumpBlock(const Block& block, const std::map<const Block*, size_t>& | ||||||
|             const std::string arg_str{ArgToIndex(block_to_index, inst_to_index, inst_index, arg)}; |             const std::string arg_str{ArgToIndex(block_to_index, inst_to_index, inst_index, arg)}; | ||||||
|             ret += arg_index != 0 ? ", " : " "; |             ret += arg_index != 0 ? ", " : " "; | ||||||
|             if (op == Opcode::Phi) { |             if (op == Opcode::Phi) { | ||||||
|                 ret += fmt::format("[ {}, {} ]", arg_index, |                 ret += fmt::format("[ {}, {} ]", arg_str, | ||||||
|                                    BlockToIndex(block_to_index, inst.PhiBlock(arg_index))); |                                    BlockToIndex(block_to_index, inst.PhiBlock(arg_index))); | ||||||
|             } else { |             } else { | ||||||
|                 ret += arg_str; |                 ret += arg_str; | ||||||
|  |  | ||||||
|  | @ -46,10 +46,12 @@ F64 IREmitter::Imm64(f64 value) const { | ||||||
| 
 | 
 | ||||||
| void IREmitter::Branch(Block* label) { | void IREmitter::Branch(Block* label) { | ||||||
|     label->AddImmediatePredecessor(block); |     label->AddImmediatePredecessor(block); | ||||||
|  |     block->SetBranch(label); | ||||||
|     Inst(Opcode::Branch, label); |     Inst(Opcode::Branch, label); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void IREmitter::BranchConditional(const U1& condition, Block* true_label, Block* false_label) { | void IREmitter::BranchConditional(const U1& condition, Block* true_label, Block* false_label) { | ||||||
|  |     block->SetBranches(IR::Condition{true}, true_label, false_label); | ||||||
|     true_label->AddImmediatePredecessor(block); |     true_label->AddImmediatePredecessor(block); | ||||||
|     false_label->AddImmediatePredecessor(block); |     false_label->AddImmediatePredecessor(block); | ||||||
|     Inst(Opcode::BranchConditional, condition, true_label, false_label); |     Inst(Opcode::BranchConditional, condition, true_label, false_label); | ||||||
|  |  | ||||||
|  | @ -143,19 +143,21 @@ Value Inst::Arg(size_t index) const { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void Inst::SetArg(size_t index, Value value) { | void Inst::SetArg(size_t index, Value value) { | ||||||
|     if (op == Opcode::Phi) { |     if (index >= NumArgs()) { | ||||||
|         throw LogicError("Setting argument on a phi instruction"); |  | ||||||
|     } |  | ||||||
|     if (index >= NumArgsOf(op)) { |  | ||||||
|         throw InvalidArgument("Out of bounds argument index {} in opcode {}", index, op); |         throw InvalidArgument("Out of bounds argument index {} in opcode {}", index, op); | ||||||
|     } |     } | ||||||
|     if (!args[index].IsImmediate()) { |     const IR::Value arg{Arg(index)}; | ||||||
|         UndoUse(args[index]); |     if (!arg.IsImmediate()) { | ||||||
|  |         UndoUse(arg); | ||||||
|     } |     } | ||||||
|     if (!value.IsImmediate()) { |     if (!value.IsImmediate()) { | ||||||
|         Use(value); |         Use(value); | ||||||
|     } |     } | ||||||
|  |     if (op == Opcode::Phi) { | ||||||
|  |         phi_args[index].second = value; | ||||||
|  |     } else { | ||||||
|         args[index] = value; |         args[index] = value; | ||||||
|  |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Block* Inst::PhiBlock(size_t index) const { | Block* Inst::PhiBlock(size_t index) const { | ||||||
|  |  | ||||||
|  | @ -76,8 +76,8 @@ void IADD(TranslatorVisitor& v, u64 insn, IR::U32 op_b) { | ||||||
| } | } | ||||||
| } // Anonymous namespace
 | } // Anonymous namespace
 | ||||||
| 
 | 
 | ||||||
| void TranslatorVisitor::IADD_reg(u64) { | void TranslatorVisitor::IADD_reg(u64 insn) { | ||||||
|     throw NotImplementedException("IADD (reg)"); |     IADD(*this, insn, GetReg20(insn)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void TranslatorVisitor::IADD_cbuf(u64 insn) { | void TranslatorVisitor::IADD_cbuf(u64 insn) { | ||||||
|  |  | ||||||
|  | @ -92,8 +92,8 @@ void TranslatorVisitor::ISETP_cbuf(u64 insn) { | ||||||
|     ISETP(*this, insn, GetCbuf(insn)); |     ISETP(*this, insn, GetCbuf(insn)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void TranslatorVisitor::ISETP_imm(u64) { | void TranslatorVisitor::ISETP_imm(u64 insn) { | ||||||
|     throw NotImplementedException("ISETP_imm"); |     ISETP(*this, insn, GetImm20(insn)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| } // namespace Shader::Maxwell
 | } // namespace Shader::Maxwell
 | ||||||
|  |  | ||||||
|  | @ -32,6 +32,8 @@ template <typename T> | ||||||
|         return value.U1(); |         return value.U1(); | ||||||
|     } else if constexpr (std::is_same_v<T, u32>) { |     } else if constexpr (std::is_same_v<T, u32>) { | ||||||
|         return value.U32(); |         return value.U32(); | ||||||
|  |     } else if constexpr (std::is_same_v<T, s32>) { | ||||||
|  |         return static_cast<s32>(value.U32()); | ||||||
|     } else if constexpr (std::is_same_v<T, f32>) { |     } else if constexpr (std::is_same_v<T, f32>) { | ||||||
|         return value.F32(); |         return value.F32(); | ||||||
|     } else if constexpr (std::is_same_v<T, u64>) { |     } else if constexpr (std::is_same_v<T, u64>) { | ||||||
|  | @ -39,17 +41,8 @@ template <typename T> | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <typename ImmFn> | template <typename T, typename ImmFn> | ||||||
| bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { | bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { | ||||||
|     const auto arg = [](const IR::Value& value) { |  | ||||||
|         if constexpr (std::is_invocable_r_v<bool, ImmFn, bool, bool>) { |  | ||||||
|             return value.U1(); |  | ||||||
|         } else if constexpr (std::is_invocable_r_v<u32, ImmFn, u32, u32>) { |  | ||||||
|             return value.U32(); |  | ||||||
|         } else if constexpr (std::is_invocable_r_v<u64, ImmFn, u64, u64>) { |  | ||||||
|             return value.U64(); |  | ||||||
|         } |  | ||||||
|     }; |  | ||||||
|     const IR::Value lhs{inst.Arg(0)}; |     const IR::Value lhs{inst.Arg(0)}; | ||||||
|     const IR::Value rhs{inst.Arg(1)}; |     const IR::Value rhs{inst.Arg(1)}; | ||||||
| 
 | 
 | ||||||
|  | @ -57,14 +50,14 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { | ||||||
|     const bool is_rhs_immediate{rhs.IsImmediate()}; |     const bool is_rhs_immediate{rhs.IsImmediate()}; | ||||||
| 
 | 
 | ||||||
|     if (is_lhs_immediate && is_rhs_immediate) { |     if (is_lhs_immediate && is_rhs_immediate) { | ||||||
|         const auto result{imm_fn(arg(lhs), arg(rhs))}; |         const auto result{imm_fn(Arg<T>(lhs), Arg<T>(rhs))}; | ||||||
|         inst.ReplaceUsesWith(IR::Value{result}); |         inst.ReplaceUsesWith(IR::Value{result}); | ||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
|     if (is_lhs_immediate && !is_rhs_immediate) { |     if (is_lhs_immediate && !is_rhs_immediate) { | ||||||
|         IR::Inst* const rhs_inst{rhs.InstRecursive()}; |         IR::Inst* const rhs_inst{rhs.InstRecursive()}; | ||||||
|         if (rhs_inst->Opcode() == inst.Opcode() && rhs_inst->Arg(1).IsImmediate()) { |         if (rhs_inst->Opcode() == inst.Opcode() && rhs_inst->Arg(1).IsImmediate()) { | ||||||
|             const auto combined{imm_fn(arg(lhs), arg(rhs_inst->Arg(1)))}; |             const auto combined{imm_fn(Arg<T>(lhs), Arg<T>(rhs_inst->Arg(1)))}; | ||||||
|             inst.SetArg(0, rhs_inst->Arg(0)); |             inst.SetArg(0, rhs_inst->Arg(0)); | ||||||
|             inst.SetArg(1, IR::Value{combined}); |             inst.SetArg(1, IR::Value{combined}); | ||||||
|         } else { |         } else { | ||||||
|  | @ -76,7 +69,7 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { | ||||||
|     if (!is_lhs_immediate && is_rhs_immediate) { |     if (!is_lhs_immediate && is_rhs_immediate) { | ||||||
|         const IR::Inst* const lhs_inst{lhs.InstRecursive()}; |         const IR::Inst* const lhs_inst{lhs.InstRecursive()}; | ||||||
|         if (lhs_inst->Opcode() == inst.Opcode() && lhs_inst->Arg(1).IsImmediate()) { |         if (lhs_inst->Opcode() == inst.Opcode() && lhs_inst->Arg(1).IsImmediate()) { | ||||||
|             const auto combined{imm_fn(arg(rhs), arg(lhs_inst->Arg(1)))}; |             const auto combined{imm_fn(Arg<T>(rhs), Arg<T>(lhs_inst->Arg(1)))}; | ||||||
|             inst.SetArg(0, lhs_inst->Arg(0)); |             inst.SetArg(0, lhs_inst->Arg(0)); | ||||||
|             inst.SetArg(1, IR::Value{combined}); |             inst.SetArg(1, IR::Value{combined}); | ||||||
|         } |         } | ||||||
|  | @ -101,7 +94,7 @@ void FoldAdd(IR::Inst& inst) { | ||||||
|     if (inst.HasAssociatedPseudoOperation()) { |     if (inst.HasAssociatedPseudoOperation()) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|     if (!FoldCommutative(inst, [](T a, T b) { return a + b; })) { |     if (!FoldCommutative<T>(inst, [](T a, T b) { return a + b; })) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|     const IR::Value rhs{inst.Arg(1)}; |     const IR::Value rhs{inst.Arg(1)}; | ||||||
|  | @ -119,7 +112,7 @@ void FoldSelect(IR::Inst& inst) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void FoldLogicalAnd(IR::Inst& inst) { | void FoldLogicalAnd(IR::Inst& inst) { | ||||||
|     if (!FoldCommutative(inst, [](bool a, bool b) { return a && b; })) { |     if (!FoldCommutative<bool>(inst, [](bool a, bool b) { return a && b; })) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|     const IR::Value rhs{inst.Arg(1)}; |     const IR::Value rhs{inst.Arg(1)}; | ||||||
|  | @ -133,7 +126,7 @@ void FoldLogicalAnd(IR::Inst& inst) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void FoldLogicalOr(IR::Inst& inst) { | void FoldLogicalOr(IR::Inst& inst) { | ||||||
|     if (!FoldCommutative(inst, [](bool a, bool b) { return a || b; })) { |     if (!FoldCommutative<bool>(inst, [](bool a, bool b) { return a || b; })) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|     const IR::Value rhs{inst.Arg(1)}; |     const IR::Value rhs{inst.Arg(1)}; | ||||||
|  | @ -226,6 +219,8 @@ void ConstantPropagation(IR::Inst& inst) { | ||||||
|         return FoldLogicalOr(inst); |         return FoldLogicalOr(inst); | ||||||
|     case IR::Opcode::LogicalNot: |     case IR::Opcode::LogicalNot: | ||||||
|         return FoldLogicalNot(inst); |         return FoldLogicalNot(inst); | ||||||
|  |     case IR::Opcode::SLessThan: | ||||||
|  |         return FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; }); | ||||||
|     case IR::Opcode::ULessThan: |     case IR::Opcode::ULessThan: | ||||||
|         return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); |         return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); | ||||||
|     case IR::Opcode::BitFieldUExtract: |     case IR::Opcode::BitFieldUExtract: | ||||||
|  |  | ||||||
|  | @ -113,6 +113,7 @@ private: | ||||||
|     IR::Value ReadVariableRecursive(auto variable, IR::Block* block) { |     IR::Value ReadVariableRecursive(auto variable, IR::Block* block) { | ||||||
|         IR::Value val; |         IR::Value val; | ||||||
|         if (const std::span preds{block->ImmediatePredecessors()}; preds.size() == 1) { |         if (const std::span preds{block->ImmediatePredecessors()}; preds.size() == 1) { | ||||||
|  |             // Optimize the common case of one predecessor: no phi needed
 | ||||||
|             val = ReadVariable(variable, preds.front()); |             val = ReadVariable(variable, preds.front()); | ||||||
|         } else { |         } else { | ||||||
|             // Break potential cycles with operandless phi
 |             // Break potential cycles with operandless phi
 | ||||||
|  | @ -160,12 +161,8 @@ private: | ||||||
| 
 | 
 | ||||||
|     DefTable current_def; |     DefTable current_def; | ||||||
| }; | }; | ||||||
| } // Anonymous namespace
 |  | ||||||
| 
 | 
 | ||||||
| void SsaRewritePass(IR::Function& function) { | void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) { | ||||||
|     Pass pass; |  | ||||||
|     for (IR::Block* const block : function.blocks) { |  | ||||||
|         for (IR::Inst& inst : block->Instructions()) { |  | ||||||
|     switch (inst.Opcode()) { |     switch (inst.Opcode()) { | ||||||
|     case IR::Opcode::SetRegister: |     case IR::Opcode::SetRegister: | ||||||
|         if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) { |         if (const IR::Reg reg{inst.Arg(0).Reg()}; reg != IR::Reg::RZ) { | ||||||
|  | @ -220,6 +217,14 @@ void SsaRewritePass(IR::Function& function) { | ||||||
|     default: |     default: | ||||||
|         break; |         break; | ||||||
|     } |     } | ||||||
|  | } | ||||||
|  | } // Anonymous namespace
 | ||||||
|  | 
 | ||||||
|  | void SsaRewritePass(IR::Function& function) { | ||||||
|  |     Pass pass; | ||||||
|  |     for (IR::Block* const block : function.blocks) { | ||||||
|  |         for (IR::Inst& inst : block->Instructions()) { | ||||||
|  |             VisitInst(pass, block, inst); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -38,7 +38,8 @@ void RunDatabase() { | ||||||
|         map.emplace_back(std::make_unique<FileEnvironment>(path.string().c_str())); |         map.emplace_back(std::make_unique<FileEnvironment>(path.string().c_str())); | ||||||
|     }); |     }); | ||||||
|     auto block_pool{std::make_unique<ObjectPool<Flow::Block>>()}; |     auto block_pool{std::make_unique<ObjectPool<Flow::Block>>()}; | ||||||
|     auto t0 = std::chrono::high_resolution_clock::now(); |     using namespace std::chrono; | ||||||
|  |     auto t0 = high_resolution_clock::now(); | ||||||
|     int N = 1; |     int N = 1; | ||||||
|     int n = 0; |     int n = 0; | ||||||
|     for (int i = 0; i < N; ++i) { |     for (int i = 0; i < N; ++i) { | ||||||
|  | @ -55,9 +56,8 @@ void RunDatabase() { | ||||||
|             // const std::string code{EmitGLASM(program)};
 |             // const std::string code{EmitGLASM(program)};
 | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     auto t = std::chrono::high_resolution_clock::now(); |     auto t = high_resolution_clock::now(); | ||||||
|     fmt::print(stdout, "{} ms", |     fmt::print(stdout, "{} ms", duration_cast<milliseconds>(t - t0).count() / double(N)); | ||||||
|                std::chrono::duration_cast<std::chrono::milliseconds>(t - t0).count() / double(N)); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| int main() { | int main() { | ||||||
|  | @ -67,8 +67,8 @@ int main() { | ||||||
|     auto inst_pool{std::make_unique<ObjectPool<IR::Inst>>()}; |     auto inst_pool{std::make_unique<ObjectPool<IR::Inst>>()}; | ||||||
|     auto block_pool{std::make_unique<ObjectPool<IR::Block>>()}; |     auto block_pool{std::make_unique<ObjectPool<IR::Block>>()}; | ||||||
| 
 | 
 | ||||||
|     FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"}; |     // FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"};
 | ||||||
|     // FileEnvironment env{"D:\\Shaders\\shader.bin"};
 |     FileEnvironment env{"D:\\Shaders\\shader.bin"}; | ||||||
|     for (int i = 0; i < 1; ++i) { |     for (int i = 0; i < 1; ++i) { | ||||||
|         block_pool->ReleaseContents(); |         block_pool->ReleaseContents(); | ||||||
|         inst_pool->ReleaseContents(); |         inst_pool->ReleaseContents(); | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 ReinUsesLisp
				ReinUsesLisp