Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga] resolve the size of override-sized arrays in backends #6787

Open
wants to merge 2 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ pub enum Error {
/// [`crate::Sampling::First`] is unsupported.
#[error("`{:?}` sampling is unsupported", crate::Sampling::First)]
FirstSamplingNotSupported,
#[error(transparent)]
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
}

/// Binary operation with a different logic on the GLSL side.
Expand Down Expand Up @@ -976,13 +978,12 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, "[")?;

// Write the array size
// Writes nothing if `ArraySize::Dynamic`
match size {
crate::ArraySize::Constant(size) => {
// Writes nothing if `ResolvedSize::Runtime`
match proc::resolve_array_size(size, self.module.to_ctx())? {
proc::ResolvedSize::Constant(size) => {
write!(self.out, "{size}")?;
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => (),
proc::ResolvedSize::Runtime => (),
}

write!(self.out, "]")?;
Expand Down Expand Up @@ -4519,13 +4520,9 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, ")")?;
}
TypeInner::Array { base, size, .. } => {
let count = match size
.to_indexable_length(self.module)
.expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Pending => unreachable!(),
proc::IndexableLength::Dynamic => return Ok(()),
let count = match proc::resolve_array_size(size, self.module.to_ctx())? {
proc::ResolvedSize::Constant(size) => size,
proc::ResolvedSize::Runtime => return Ok(()),
};
self.write_type(base)?;
self.write_array_size(base, size)?;
Expand Down
19 changes: 9 additions & 10 deletions naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ impl crate::TypeInner {
}
}

pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> u32 {
pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> Result<u32, Error> {
match *self {
Self::Matrix {
columns,
Expand All @@ -62,19 +62,18 @@ impl crate::TypeInner {
} => {
let stride = Alignment::from(rows) * scalar.width as u32;
let last_row_size = rows as u32 * scalar.width as u32;
((columns as u32 - 1) * stride) + last_row_size
Ok(((columns as u32 - 1) * stride) + last_row_size)
}
Self::Array { base, size, stride } => {
let count = match size {
crate::ArraySize::Constant(size) => size.get(),
// A dynamically-sized array has to have at least one element
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => 1,
let count = match crate::proc::resolve_array_size(size, gctx)? {
crate::proc::ResolvedSize::Constant(size) => size,
// A runtime-sized array has to have at least one element
crate::proc::ResolvedSize::Runtime => 1,
};
let last_el_size = gctx.types[base].inner.size_hlsl(gctx);
((count - 1) * stride) + last_el_size
let last_el_size = gctx.types[base].inner.size_hlsl(gctx)?;
Ok(((count - 1) * stride) + last_el_size)
}
_ => self.size(gctx),
_ => Ok(self.size(gctx)),
}
}

Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ pub enum Error {
Custom(String),
#[error("overrides should not be present at this stage")]
Override,
#[error(transparent)]
ResolveArraySizeError(#[from] proc::ResolveArraySizeError),
}

#[derive(Default)]
Expand Down
10 changes: 4 additions & 6 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1074,12 +1074,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
) -> BackendResult {
write!(self.out, "[")?;

match size {
crate::ArraySize::Constant(size) => {
match proc::resolve_array_size(size, module.to_ctx())? {
proc::ResolvedSize::Constant(size) => {
write!(self.out, "{size}")?;
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => unreachable!(),
proc::ResolvedSize::Runtime => unreachable!(),
}

write!(self.out, "]")?;
Expand Down Expand Up @@ -1124,7 +1123,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
let ty_inner = &module.types[member.ty].inner;
last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx());
last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx())?;

// The indentation is only for readability
write!(self.out, "{}", back::INDENT)?;
Expand Down Expand Up @@ -2822,7 +2821,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => unreachable!(),
}
write!(self.out, ")")?;
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ pub enum Error {
Override,
#[error("bitcasting to {0:?} is not supported")]
UnsupportedBitCast(crate::TypeInner),
#[error(transparent)]
ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
}

#[derive(Clone, Debug, PartialEq, thiserror::Error)]
Expand Down
24 changes: 8 additions & 16 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use super::{sampler as sm, Error, LocationMode, Options, PipelineOptions, Transl
use crate::{
arena::{Handle, HandleSet},
back::{self, Baked},
proc::index,
proc::{self, NameKey, TypeResolution},
proc::{self, index, NameKey, TypeResolution},
valid, FastHashMap, FastHashSet,
};
#[cfg(test)]
Expand Down Expand Up @@ -2555,7 +2554,6 @@ impl<W: Write> Writer<W> {
self.out.write_str(") < ")?;
match length {
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => {
let global =
context.function.originating_global(base).ok_or_else(|| {
Expand Down Expand Up @@ -2692,7 +2690,7 @@ impl<W: Write> Writer<W> {
) -> BackendResult {
let accessing_wrapped_array = match *base_ty {
crate::TypeInner::Array {
size: crate::ArraySize::Constant(_),
size: crate::ArraySize::Constant(_) | crate::ArraySize::Pending(_),
..
} => true,
_ => false,
Expand Down Expand Up @@ -2720,7 +2718,6 @@ impl<W: Write> Writer<W> {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => {
let global = context.function.originating_global(base).ok_or_else(|| {
Error::GenericValidation("Could not find originating global".into())
Expand Down Expand Up @@ -3911,8 +3908,8 @@ impl<W: Write> Writer<W> {
first_time: false,
};

match size {
crate::ArraySize::Constant(size) => {
match proc::resolve_array_size(size, module.to_ctx())? {
proc::ResolvedSize::Constant(size) => {
writeln!(self.out, "struct {name} {{")?;
writeln!(
self.out,
Expand All @@ -3924,10 +3921,7 @@ impl<W: Write> Writer<W> {
)?;
writeln!(self.out, "}};")?;
}
crate::ArraySize::Pending(_) => {
unreachable!()
}
crate::ArraySize::Dynamic => {
proc::ResolvedSize::Runtime => {
writeln!(self.out, "typedef {base_name} {name}[1];")?;
}
}
Expand Down Expand Up @@ -6321,11 +6315,9 @@ mod workgroup_mem_init {
writeln!(self.out, ", 0, {NAMESPACE}::memory_order_relaxed);")?;
}
crate::TypeInner::Array { base, size, .. } => {
let count = match size.to_indexable_length(module).expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Pending => unreachable!(),
proc::IndexableLength::Dynamic => unreachable!(),
let count = match proc::resolve_array_size(size, module.to_ctx())? {
proc::ResolvedSize::Constant(size) => size,
proc::ResolvedSize::Runtime => unreachable!(),
};

access_stack.enter_array(|access_stack, array_depth| {
Expand Down
113 changes: 62 additions & 51 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
Span, Statement, TypeInner, WithSpan,
Span, Statement, Type, TypeInner, UniqueArena, WithSpan,
};
use std::{borrow::Cow, collections::HashSet, mem};
use thiserror::Error;
Expand All @@ -29,19 +29,29 @@ pub enum PipelineConstantError {
NegativeWorkgroupSize,
}

/// Replace all overrides in `module` with constants.
/// Replace all overrides in `module` with fully-evaluated constant expressions.
///
/// If no changes are needed, this just returns `Cow::Borrowed`
/// references to `module` and `module_info`. Otherwise, it clones
/// `module`, edits its [`global_expressions`] arena to contain only
/// fully-evaluated expressions, and returns `Cow::Owned` values
/// holding the simplified module and its validation results.
/// Given `pipeline_constants`, providing values for all overrides in
/// `module`:
///
/// In either case, the module returned has an empty `overrides`
/// arena, and the `global_expressions` arena contains only
/// fully-evaluated expressions.
/// - Replace all [`Override`] expressions with fully-evaluated
/// constant expressions.
///
/// [`global_expressions`]: Module::global_expressions
/// - Replace all [`Override`][paso] array sizes with [`Expression`]
/// array sizes, referring to fully-evaluated constant expressions.
///
/// - Empty out the `module.overrides` arena.
///
/// Although the above is described in terms of changes to `module`'s
/// contents, this function only actually has shared access to
/// `module`. When changes are needed, this function clones `module`
/// and returns a [`Cow::Owned`] value. If no changes are needed, this
/// function returns a [`Cow::Borrowed`] value that just passes along
/// the original reference.
///
/// [`Override`]: Expression::Override
/// [paso]: crate::PendingArraySize::Override
/// [`Expression`]: crate::PendingArraySize::Expression
pub fn process_overrides<'a>(
module: &'a Module,
module_info: &'a ModuleInfo,
Expand All @@ -51,6 +61,7 @@ pub fn process_overrides<'a>(
return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
}

let original_module_types = &module.types;
let mut module = module.clone();

// A map from override handles to the handles of the constants
Expand Down Expand Up @@ -196,7 +207,12 @@ pub fn process_overrides<'a>(
}
module.entry_points = entry_points;

process_pending(&mut module, &override_map, &adjusted_global_expressions)?;
process_pending(
&mut module,
original_module_types,
&override_map,
&adjusted_global_expressions,
);

// Now that we've rewritten all the expressions, we need to
// recompute their types and other metadata. For the time being,
Expand All @@ -209,60 +225,55 @@ pub fn process_overrides<'a>(

fn process_pending(
module: &mut Module,
original_module_types: &UniqueArena<Type>,
override_map: &HandleVec<Override, Handle<Constant>>,
adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
) -> Result<(), PipelineConstantError> {
for (handle, ty) in module.types.clone().iter() {
) {
for (handle, ty) in original_module_types.iter() {
if let TypeInner::Array {
base,
size: crate::ArraySize::Pending(size),
stride,
} = ty.inner
{
let expr = match size {
match size {
crate::PendingArraySize::Expression(size_expr) => {
adjusted_global_expressions[size_expr]
let expr = adjusted_global_expressions[size_expr];
if expr != size_expr {
module.types.replace(
handle,
Type {
name: ty.name.clone(),
inner: TypeInner::Array {
base,
size: crate::ArraySize::Pending(
crate::PendingArraySize::Expression(expr),
),
stride,
},
},
);
}
}
crate::PendingArraySize::Override(size_override) => {
module.constants[override_map[size_override]].init
let expr = module.constants[override_map[size_override]].init;
module.types.replace(
handle,
Type {
name: ty.name.clone(),
inner: TypeInner::Array {
base,
size: crate::ArraySize::Pending(
crate::PendingArraySize::Expression(expr),
),
stride,
},
},
);
}
};
let value = module
.to_ctx()
.eval_expr_to_u32(expr)
.map(|n| {
if n == 0 {
Err(PipelineConstantError::ValidationError(
WithSpan::new(ValidationError::ArraySizeError { handle: expr })
.with_span(
module.global_expressions.get_span(expr),
"evaluated to zero",
),
))
} else {
Ok(std::num::NonZeroU32::new(n).unwrap())
}
})
.map_err(|_| {
PipelineConstantError::ValidationError(
WithSpan::new(ValidationError::ArraySizeError { handle: expr })
.with_span(module.global_expressions.get_span(expr), "negative"),
)
})??;
module.types.replace(
handle,
crate::Type {
name: None,
inner: TypeInner::Array {
base,
size: crate::ArraySize::Constant(value),
stride,
},
},
);
}
}
Ok(())
}

fn process_workgroup_size_override(
Expand Down
7 changes: 3 additions & 4 deletions naga/src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,12 @@ impl BlockContext<'_> {
block: &mut Block,
) -> Result<MaybeKnown<u32>, Error> {
let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
match sequence_ty.indexable_length(self.ir_module) {
match sequence_ty
.indexable_length(self.ir_module, crate::ArraySize::indexable_length_resolved)
{
Ok(crate::proc::IndexableLength::Known(known_length)) => {
Ok(MaybeKnown::Known(known_length))
}
Ok(crate::proc::IndexableLength::Pending) => {
unreachable!()
}
Ok(crate::proc::IndexableLength::Dynamic) => {
let length_id = self.write_runtime_array_length(sequence, block)?;
Ok(MaybeKnown::Computed(length_id))
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ pub enum Error {
Validation(&'static str),
#[error("overrides should not be present at this stage")]
Override,
#[error(transparent)]
ResolveArraySizeError(#[from] crate::proc::ResolveArraySizeError),
}

#[derive(Default)]
Expand Down
Loading
Loading