more changes might have overlooked something
Some checks failed
eden-license / license-header (pull_request) Failing after 23s

This commit is contained in:
Ribbit 2025-10-07 21:08:24 -07:00 committed by crueter
parent cb6da0409b
commit 7bbeafc0ca
2 changed files with 203 additions and 60 deletions

View file

@ -98,57 +98,46 @@ Id ImageType(EmitContext& ctx, const ImageDescriptor& desc, Id sampled_type) {
throw InvalidArgument("Invalid texture type {}", desc.type); throw InvalidArgument("Invalid texture type {}", desc.type);
} }
bool MatchesVectorType(const VectorTypes& vectors, Id type) {
for (std::size_t components = 1; components <= 4; ++components) { bool IsFragmentStage(const EmitContext& ctx) {
const Id candidate{vectors[components]}; return ctx.stage == Stage::Fragment;
if (candidate.value != 0 && candidate.value == type.value) {
return true;
}
}
return false;
} }
bool HasIntegerOrDoubleComponent(const EmitContext& ctx, Id type) { bool IsUserVaryingInput(bool is_builtin, bool has_location) {
if (MatchesVectorType(ctx.U32, type) || MatchesVectorType(ctx.S32, type)) { return !is_builtin && has_location;
return true;
}
if (MatchesVectorType(ctx.F64, type)) {
return true;
}
if (ctx.profile.support_int8) {
if ((ctx.U8.value != 0 && ctx.U8.value == type.value) ||
(ctx.S8.value != 0 && ctx.S8.value == type.value)) {
return true;
}
}
if (ctx.profile.support_int16) {
if ((ctx.U16.value != 0 && ctx.U16.value == type.value) ||
(ctx.S16.value != 0 && ctx.S16.value == type.value)) {
return true;
}
}
if (ctx.profile.support_int64) {
if (ctx.U64.value != 0 && ctx.U64.value == type.value) {
return true;
}
}
return false;
} }
bool RequiresFlatDecoration(const EmitContext& ctx, Id type, spv::StorageClass storage_class) { bool IsIntegerOrBoolType(EmitContext& ctx, Id type) {
return ctx.stage == Stage::Fragment && storage_class == spv::StorageClass::Input && return ctx.IsIntegerOrBoolType(type);
HasIntegerOrDoubleComponent(ctx, type); }
bool RequiresFlatDecoration(EmitContext& ctx, Id type, spv::StorageClass storage_class,
bool is_builtin, bool has_location) {
if (!IsFragmentStage(ctx)) {
return false;
}
if (storage_class != spv::StorageClass::Input) {
return false;
}
if (!IsUserVaryingInput(is_builtin, has_location)) {
return false;
}
return IsIntegerOrBoolType(ctx, type);
} }
Id DefineVariable(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin, Id DefineVariable(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin,
spv::StorageClass storage_class, std::optional<Id> initializer = std::nullopt) { spv::StorageClass storage_class, std::optional<Id> initializer = std::nullopt,
bool has_location = false) {
const Id pointer_type{ctx.TypePointer(storage_class, type)}; const Id pointer_type{ctx.TypePointer(storage_class, type)};
const Id id{ctx.AddGlobalVariable(pointer_type, storage_class, initializer)}; const Id id{ctx.AddGlobalVariable(pointer_type, storage_class, initializer)};
if (builtin) { const bool is_builtin{builtin.has_value()};
ctx.Decorate(id, spv::Decoration::BuiltIn, *builtin); if (is_builtin) {
ctx.DecorateUnique(id, spv::Decoration::BuiltIn, static_cast<u32>(*builtin));
} }
if (RequiresFlatDecoration(ctx, type, storage_class)) { // Flat only for integer/bool user varyings in fragment input; never for BuiltIns; dedupe avoids
ctx.Decorate(id, spv::Decoration::Flat); // multiple identical decorations.
if (RequiresFlatDecoration(ctx, type, storage_class, is_builtin, has_location)) {
ctx.DecorateUnique(id, spv::Decoration::Flat);
} }
ctx.interfaces.push_back(id); ctx.interfaces.push_back(id);
return id; return id;
@ -171,7 +160,8 @@ u32 NumVertices(InputTopology input_topology) {
} }
Id DefineInput(EmitContext& ctx, Id type, bool per_invocation, Id DefineInput(EmitContext& ctx, Id type, bool per_invocation,
std::optional<spv::BuiltIn> builtin = std::nullopt) { std::optional<spv::BuiltIn> builtin = std::nullopt,
bool has_location = false) {
switch (ctx.stage) { switch (ctx.stage) {
case Stage::TessellationControl: case Stage::TessellationControl:
case Stage::TessellationEval: case Stage::TessellationEval:
@ -188,7 +178,7 @@ Id DefineInput(EmitContext& ctx, Id type, bool per_invocation,
default: default:
break; break;
} }
return DefineVariable(ctx, type, builtin, spv::StorageClass::Input); return DefineVariable(ctx, type, builtin, spv::StorageClass::Input, std::nullopt, has_location);
} }
Id DefineOutput(EmitContext& ctx, Id type, std::optional<u32> invocations, Id DefineOutput(EmitContext& ctx, Id type, std::optional<u32> invocations,
@ -215,7 +205,7 @@ void DefineGenericOutput(EmitContext& ctx, size_t index, std::optional<u32> invo
const u32 num_components{xfb_varying ? xfb_varying->components : remainder}; const u32 num_components{xfb_varying ? xfb_varying->components : remainder};
const Id id{DefineOutput(ctx, ctx.F32[num_components], invocations)}; const Id id{DefineOutput(ctx, ctx.F32[num_components], invocations)};
ctx.Decorate(id, spv::Decoration::Location, static_cast<u32>(index)); ctx.DecorateUnique(id, spv::Decoration::Location, static_cast<u32>(index));
if (element > 0) { if (element > 0) {
ctx.Decorate(id, spv::Decoration::Component, element); ctx.Decorate(id, spv::Decoration::Component, element);
} }
@ -490,6 +480,94 @@ Id DescType(EmitContext& ctx, Id sampled_type, Id pointer_type, u32 count) {
} }
} // Anonymous namespace } // Anonymous namespace
Id EmitContext::TypeArray(Id element_type, Id length) {
const Id array_type{Sirit::Module::TypeArray(element_type, length)};
array_element_types[array_type] = element_type;
type_integer_or_bool_cache[array_type] = IsIntegerOrBoolType(element_type);
return array_type;
}
Id EmitContext::TypeStruct(Id member) {
const std::array<Id, 1> members{member};
return TypeStruct(std::span<const Id>(members));
}
Id EmitContext::TypeStruct(std::span<const Id> members) {
const Id struct_type{Sirit::Module::TypeStruct(members)};
struct_member_types[struct_type] = std::vector<Id>(members.begin(), members.end());
const bool has_integer_member{
std::any_of(members.begin(), members.end(),
[this](Id member_type) { return IsIntegerOrBoolType(member_type); })};
type_integer_or_bool_cache[struct_type] = has_integer_member;
return struct_type;
}
Id EmitContext::TypeVector(Id element_type, u32 components) {
const Id vector_type{
Sirit::Module::TypeVector(element_type, static_cast<int>(components))};
type_integer_or_bool_cache[vector_type] = IsIntegerOrBoolType(element_type);
return vector_type;
}
bool EmitContext::HasDecoration(Id id, spv::Decoration decoration,
std::optional<u32> literal) const {
const auto list_it{decorations.find(id)};
if (list_it == decorations.end()) {
return false;
}
const auto& records{list_it->second};
return std::any_of(records.begin(), records.end(), [&](const DecorationRecord& record) {
if (record.decoration != decoration) {
return false;
}
if (!literal.has_value()) {
return true;
}
return record.literal.has_value() && record.literal.value() == literal.value();
});
}
void EmitContext::DecorateUnique(Id id, spv::Decoration decoration,
std::optional<u32> literal) {
if (decoration == spv::Decoration::Flat || decoration == spv::Decoration::NoPerspective) {
// SPIR-V only allows non-default interpolation decorations on user-defined inputs.
ASSERT_MSG(!HasDecoration(id, spv::Decoration::BuiltIn),
"Interpolation decoration applied to a BuiltIn");
}
if (HasDecoration(id, decoration, literal)) {
return;
}
decorations[id].emplace_back(DecorationRecord{decoration, literal});
if (literal.has_value()) {
Sirit::Module::Decorate(id, decoration, literal.value());
} else {
Sirit::Module::Decorate(id, decoration);
}
}
bool EmitContext::IsIntegerOrBoolType(Id type) {
if (const auto it = type_integer_or_bool_cache.find(type);
it != type_integer_or_bool_cache.end()) {
return it->second;
}
if (const auto array_it = array_element_types.find(type); array_it != array_element_types.end()) {
const bool result{IsIntegerOrBoolType(array_it->second)};
type_integer_or_bool_cache[type] = result;
return result;
}
if (const auto struct_it = struct_member_types.find(type);
struct_it != struct_member_types.end()) {
const bool result{std::any_of(struct_it->second.begin(), struct_it->second.end(),
[this](Id member_type) {
return IsIntegerOrBoolType(member_type);
})};
type_integer_or_bool_cache[type] = result;
return result;
}
type_integer_or_bool_cache[type] = false;
return false;
}
void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) { void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) {
defs[0] = sirit_ctx.Name(base_type, name); defs[0] = sirit_ctx.Name(base_type, name);
@ -577,11 +655,24 @@ Id EmitContext::BitOffset16(const IR::Value& offset) {
void EmitContext::DefineCommonTypes(const Info& info) { void EmitContext::DefineCommonTypes(const Info& info) {
void_id = TypeVoid(); void_id = TypeVoid();
const auto mark_vector_type = [this](VectorTypes& vectors, bool is_integer_or_bool) {
for (size_t components = 1; components <= 4; ++components) {
const Id type{vectors[components]};
if (type.value != 0) {
type_integer_or_bool_cache[type] = is_integer_or_bool;
}
}
};
U1 = Name(TypeBool(), "u1"); U1 = Name(TypeBool(), "u1");
type_integer_or_bool_cache[U1] = true;
F32.Define(*this, TypeFloat(32), "f32"); F32.Define(*this, TypeFloat(32), "f32");
mark_vector_type(F32, false);
U32.Define(*this, TypeInt(32, false), "u32"); U32.Define(*this, TypeInt(32, false), "u32");
mark_vector_type(U32, true);
S32.Define(*this, TypeInt(32, true), "s32"); S32.Define(*this, TypeInt(32, true), "s32");
mark_vector_type(S32, true);
private_u32 = Name(TypePointer(spv::StorageClass::Private, U32[1]), "private_u32"); private_u32 = Name(TypePointer(spv::StorageClass::Private, U32[1]), "private_u32");
@ -596,23 +687,30 @@ void EmitContext::DefineCommonTypes(const Info& info) {
AddCapability(spv::Capability::Int8); AddCapability(spv::Capability::Int8);
U8 = Name(TypeInt(8, false), "u8"); U8 = Name(TypeInt(8, false), "u8");
S8 = Name(TypeInt(8, true), "s8"); S8 = Name(TypeInt(8, true), "s8");
type_integer_or_bool_cache[U8] = true;
type_integer_or_bool_cache[S8] = true;
} }
if (info.uses_int16 && profile.support_int16) { if (info.uses_int16 && profile.support_int16) {
AddCapability(spv::Capability::Int16); AddCapability(spv::Capability::Int16);
U16 = Name(TypeInt(16, false), "u16"); U16 = Name(TypeInt(16, false), "u16");
S16 = Name(TypeInt(16, true), "s16"); S16 = Name(TypeInt(16, true), "s16");
type_integer_or_bool_cache[U16] = true;
type_integer_or_bool_cache[S16] = true;
} }
if (info.uses_int64 && profile.support_int64) { if (info.uses_int64 && profile.support_int64) {
AddCapability(spv::Capability::Int64); AddCapability(spv::Capability::Int64);
U64 = Name(TypeInt(64, false), "u64"); U64 = Name(TypeInt(64, false), "u64");
type_integer_or_bool_cache[U64] = true;
} }
if (info.uses_fp16) { if (info.uses_fp16) {
AddCapability(spv::Capability::Float16); AddCapability(spv::Capability::Float16);
F16.Define(*this, TypeFloat(16), "f16"); F16.Define(*this, TypeFloat(16), "f16");
mark_vector_type(F16, false);
} }
if (info.uses_fp64) { if (info.uses_fp64) {
AddCapability(spv::Capability::Float64); AddCapability(spv::Capability::Float64);
F64.Define(*this, TypeFloat(64), "f64"); F64.Define(*this, TypeFloat(64), "f64");
mark_vector_type(F64, false);
} }
} }
@ -1117,7 +1215,7 @@ void EmitContext::DefineRescalingInputUniformConstant() {
const Id pointer_type{TypePointer(spv::StorageClass::UniformConstant, F32[4])}; const Id pointer_type{TypePointer(spv::StorageClass::UniformConstant, F32[4])};
rescaling_uniform_constant = rescaling_uniform_constant =
AddGlobalVariable(pointer_type, spv::StorageClass::UniformConstant); AddGlobalVariable(pointer_type, spv::StorageClass::UniformConstant);
Decorate(rescaling_uniform_constant, spv::Decoration::Location, 0u); DecorateUnique(rescaling_uniform_constant, spv::Decoration::Location, 0u);
if (profile.supported_spirv >= 0x00010400) { if (profile.supported_spirv >= 0x00010400) {
interfaces.push_back(rescaling_uniform_constant); interfaces.push_back(rescaling_uniform_constant);
@ -1494,7 +1592,6 @@ void EmitContext::DefineInputs(const IR::Program& program) {
AddCapability(spv::Capability::GroupNonUniform); AddCapability(spv::Capability::GroupNonUniform);
subgroup_local_invocation_id = subgroup_local_invocation_id =
DefineInput(*this, U32[1], false, spv::BuiltIn::SubgroupLocalInvocationId); DefineInput(*this, U32[1], false, spv::BuiltIn::SubgroupLocalInvocationId);
Decorate(subgroup_local_invocation_id, spv::Decoration::Flat);
} }
if (info.uses_fswzadd) { if (info.uses_fswzadd) {
const Id f32_one{Const(1.0f)}; const Id f32_one{Const(1.0f)};
@ -1510,7 +1607,6 @@ void EmitContext::DefineInputs(const IR::Program& program) {
if (loads[IR::Attribute::Layer]) { if (loads[IR::Attribute::Layer]) {
AddCapability(spv::Capability::Geometry); AddCapability(spv::Capability::Geometry);
layer = DefineInput(*this, U32[1], false, spv::BuiltIn::Layer); layer = DefineInput(*this, U32[1], false, spv::BuiltIn::Layer);
Decorate(layer, spv::Decoration::Flat);
} }
if (loads.AnyComponent(IR::Attribute::PositionX)) { if (loads.AnyComponent(IR::Attribute::PositionX)) {
const bool is_fragment{stage == Stage::Fragment}; const bool is_fragment{stage == Stage::Fragment};
@ -1586,8 +1682,8 @@ void EmitContext::DefineInputs(const IR::Program& program) {
continue; continue;
} }
const Id type{GetAttributeType(*this, input_type)}; const Id type{GetAttributeType(*this, input_type)};
const Id id{DefineInput(*this, type, true)}; const Id id{DefineInput(*this, type, true, std::nullopt, true)};
Decorate(id, spv::Decoration::Location, static_cast<u32>(index)); DecorateUnique(id, spv::Decoration::Location, static_cast<u32>(index));
Name(id, fmt::format("in_attr{}", index)); Name(id, fmt::format("in_attr{}", index));
input_generics[index] = GetAttributeInfo(*this, input_type, id); input_generics[index] = GetAttributeInfo(*this, input_type, id);
@ -1597,19 +1693,20 @@ void EmitContext::DefineInputs(const IR::Program& program) {
if (stage != Stage::Fragment) { if (stage != Stage::Fragment) {
continue; continue;
} }
if (RequiresFlatDecoration(*this, type, spv::StorageClass::Input)) { if (RequiresFlatDecoration(*this, type, spv::StorageClass::Input, false, true)) {
ASSERT_MSG(HasDecoration(id, spv::Decoration::Flat),
"Flat decoration missing on integer/bool user varying input");
continue; continue;
} }
switch (info.interpolation[index]) { switch (info.interpolation[index]) {
case Interpolation::Smooth: case Interpolation::Smooth:
// Default // Default interpolation per SPIR-V spec; no decoration emitted.
// Decorate(id, spv::Decoration::Smooth);
break; break;
case Interpolation::NoPerspective: case Interpolation::NoPerspective:
Decorate(id, spv::Decoration::NoPerspective); DecorateUnique(id, spv::Decoration::NoPerspective);
break; break;
case Interpolation::Flat: case Interpolation::Flat:
Decorate(id, spv::Decoration::Flat); DecorateUnique(id, spv::Decoration::Flat);
break; break;
} }
} }
@ -1618,9 +1715,9 @@ void EmitContext::DefineInputs(const IR::Program& program) {
if (!info.uses_patches[index]) { if (!info.uses_patches[index]) {
continue; continue;
} }
const Id id{DefineInput(*this, F32[4], false)}; const Id id{DefineInput(*this, F32[4], false, std::nullopt, true)};
Decorate(id, spv::Decoration::Patch); Decorate(id, spv::Decoration::Patch);
Decorate(id, spv::Decoration::Location, static_cast<u32>(index)); DecorateUnique(id, spv::Decoration::Location, static_cast<u32>(index));
patches[index] = id; patches[index] = id;
} }
} }
@ -1697,7 +1794,7 @@ void EmitContext::DefineOutputs(const IR::Program& program) {
} }
const Id id{DefineOutput(*this, F32[4], std::nullopt)}; const Id id{DefineOutput(*this, F32[4], std::nullopt)};
Decorate(id, spv::Decoration::Patch); Decorate(id, spv::Decoration::Patch);
Decorate(id, spv::Decoration::Location, static_cast<u32>(index)); DecorateUnique(id, spv::Decoration::Location, static_cast<u32>(index));
patches[index] = id; patches[index] = id;
} }
break; break;
@ -1707,17 +1804,19 @@ void EmitContext::DefineOutputs(const IR::Program& program) {
continue; continue;
} }
frag_color[index] = DefineOutput(*this, F32[4], std::nullopt); frag_color[index] = DefineOutput(*this, F32[4], std::nullopt);
Decorate(frag_color[index], spv::Decoration::Location, index); DecorateUnique(frag_color[index], spv::Decoration::Location, index);
Name(frag_color[index], fmt::format("frag_color{}", index)); Name(frag_color[index], fmt::format("frag_color{}", index));
} }
if (info.stores_frag_depth) { if (info.stores_frag_depth) {
frag_depth = DefineOutput(*this, F32[1], std::nullopt); frag_depth = DefineOutput(*this, F32[1], std::nullopt);
Decorate(frag_depth, spv::Decoration::BuiltIn, spv::BuiltIn::FragDepth); DecorateUnique(frag_depth, spv::Decoration::BuiltIn,
static_cast<u32>(spv::BuiltIn::FragDepth));
} }
if (info.stores_sample_mask) { if (info.stores_sample_mask) {
const Id array_type{TypeArray(U32[1], Const(1U))}; const Id array_type{TypeArray(U32[1], Const(1U))};
sample_mask = DefineOutput(*this, array_type, std::nullopt); sample_mask = DefineOutput(*this, array_type, std::nullopt);
Decorate(sample_mask, spv::Decoration::BuiltIn, spv::BuiltIn::SampleMask); DecorateUnique(sample_mask, spv::Decoration::BuiltIn,
static_cast<u32>(spv::BuiltIn::SampleMask));
} }
break; break;
default: default:

View file

@ -4,6 +4,11 @@
#pragma once #pragma once
#include <array> #include <array>
#include <bitset>
#include <optional>
#include <span>
#include <unordered_map>
#include <vector>
#include <sirit/sirit.h> #include <sirit/sirit.h>
@ -19,6 +24,23 @@ static std::bitset<8> clip_distance_written;
using Sirit::Id; using Sirit::Id;
struct DecorationRecord {
spv::Decoration decoration;
std::optional<u32> literal;
};
struct IdHash {
std::size_t operator()(const Id& id) const noexcept {
return std::hash<u32>{}(id.value);
}
};
struct IdEqual {
bool operator()(const Id& lhs, const Id& rhs) const noexcept {
return lhs.value == rhs.value;
}
};
class VectorTypes { class VectorTypes {
public: public:
void Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name); void Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name);
@ -204,6 +226,23 @@ public:
return Constant(F32[1], value); return Constant(F32[1], value);
} }
Id TypeArray(Id element_type, Id length);
Id TypeStruct(Id member);
Id TypeStruct(std::span<const Id> members);
Id TypeVector(Id element_type, u32 components);
template <typename... Members>
Id TypeStruct(Id first, Members... rest) {
const std::array<Id, sizeof...(rest) + 1> members{first, rest...};
return TypeStruct(std::span<const Id>(members));
}
[[nodiscard]] bool HasDecoration(Id id, spv::Decoration decoration,
std::optional<u32> literal = std::nullopt) const;
void DecorateUnique(Id id, spv::Decoration decoration,
std::optional<u32> literal = std::nullopt);
bool IsIntegerOrBoolType(Id type);
const Profile& profile; const Profile& profile;
const RuntimeInfo& runtime_info; const RuntimeInfo& runtime_info;
Stage stage{}; Stage stage{};
@ -361,6 +400,11 @@ public:
Id load_const_func_u32x2{}; Id load_const_func_u32x2{};
Id load_const_func_u32x4{}; Id load_const_func_u32x4{};
std::unordered_map<Id, std::vector<DecorationRecord>, IdHash, IdEqual> decorations;
std::unordered_map<Id, bool, IdHash, IdEqual> type_integer_or_bool_cache;
std::unordered_map<Id, Id, IdHash, IdEqual> array_element_types;
std::unordered_map<Id, std::vector<Id>, IdHash, IdEqual> struct_member_types;
private: private:
void DefineCommonTypes(const Info& info); void DefineCommonTypes(const Info& info);
void DefineCommonConstants(); void DefineCommonConstants();