Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DXIL] Add support for root signature flag element in DXContainer #123147

Open
wants to merge 13 commits into
base: users/joaosaffran/122396
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/Target/DirectX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ add_llvm_target(DirectXCodeGen
DXILResourceAccess.cpp
DXILShaderFlags.cpp
DXILTranslateMetadata.cpp

DXILRootSignature.cpp
LINK_COMPONENTS
Analysis
AsmPrinter
Expand Down
26 changes: 26 additions & 0 deletions llvm/lib/Target/DirectX/DXContainerGlobals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//

#include "DXILRootSignature.h"
#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/ADT/SmallVector.h"
Expand All @@ -23,6 +24,7 @@
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/MC/DXContainerPSVInfo.h"
#include "llvm/MC/DXContainerRootSignature.h"
#include "llvm/Pass.h"
#include "llvm/Support/MD5.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
Expand All @@ -41,6 +43,7 @@ class DXContainerGlobals : public llvm::ModulePass {
GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name,
StringRef SectionName);
void addSignature(Module &M, SmallVector<GlobalValue *> &Globals);
void addRootSignature(Module &M, SmallVector<GlobalValue *> &Globals);
void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV);
void addPipelineStateValidationInfo(Module &M,
SmallVector<GlobalValue *> &Globals);
Expand All @@ -60,6 +63,7 @@ class DXContainerGlobals : public llvm::ModulePass {
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesAll();
AU.addRequired<ShaderFlagsAnalysisWrapper>();
AU.addRequired<RootSignatureAnalysisWrapper>();
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
AU.addRequired<DXILResourceTypeWrapperPass>();
AU.addRequired<DXILResourceBindingWrapperPass>();
Expand All @@ -73,6 +77,7 @@ bool DXContainerGlobals::runOnModule(Module &M) {
Globals.push_back(getFeatureFlags(M));
Globals.push_back(computeShaderHash(M));
addSignature(M, Globals);
addRootSignature(M, Globals);
addPipelineStateValidationInfo(M, Globals);
appendToCompilerUsed(M, Globals);
return true;
Expand Down Expand Up @@ -144,6 +149,27 @@ void DXContainerGlobals::addSignature(Module &M,
Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1"));
}

void DXContainerGlobals::addRootSignature(Module &M,
SmallVector<GlobalValue *> &Globals) {

std::optional<ModuleRootSignature> MRS =
getAnalysis<RootSignatureAnalysisWrapper>().getRootSignature();
if (!MRS.has_value())
return;

SmallString<256> Data;
raw_svector_ostream OS(Data);

RootSignatureHeader RSH;
RSH.Flags = MRS->Flags;

RSH.write(OS);

Constant *Constant =
ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.rts0", "RTS0"));
}

void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) {
const DXILBindingMap &DBM =
getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();
Expand Down
158 changes: 158 additions & 0 deletions llvm/lib/Target/DirectX/DXILRootSignature.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
//===- DXILRootSignature.cpp - DXIL Root Signature helper objects ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains helper objects and APIs for working with DXIL
/// Root Signatures.
///
//===----------------------------------------------------------------------===//
#include "DXILRootSignature.h"
#include "DirectX.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Module.h"
#include <cstdint>

using namespace llvm;
using namespace llvm::dxil;

static bool reportError(Twine Message) {
report_fatal_error(Message, false);
return true;
}

static bool parseRootFlags(ModuleRootSignature *MRS, MDNode *RootFlagNode) {

if (RootFlagNode->getNumOperands() != 2)
return reportError("Invalid format for RootFlag Element");

auto *Flag = mdconst::extract<ConstantInt>(RootFlagNode->getOperand(1));
uint32_t Value = Flag->getZExtValue();

// Root Element validation, as specified:
// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#validations-during-dxil-generation
if ((Value & ~0x80000fff) != 0)
return reportError("Invalid flag value for RootFlag");

MRS->Flags = Value;
return false;
}

static bool parseRootSignatureElement(ModuleRootSignature *MRS,
MDNode *Element) {
MDString *ElementText = cast<MDString>(Element->getOperand(0));
if (ElementText == nullptr)
return reportError("Invalid format for Root Element");

RootSignatureElementKind ElementKind =
StringSwitch<RootSignatureElementKind>(ElementText->getString())
.Case("RootFlags", RootSignatureElementKind::RootFlags)
.Case("RootConstants", RootSignatureElementKind::RootConstants)
.Case("RootCBV", RootSignatureElementKind::RootDescriptor)
.Case("RootSRV", RootSignatureElementKind::RootDescriptor)
.Case("RootUAV", RootSignatureElementKind::RootDescriptor)
.Case("Sampler", RootSignatureElementKind::RootDescriptor)
.Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
.Case("StaticSampler", RootSignatureElementKind::StaticSampler)
.Default(RootSignatureElementKind::None);

switch (ElementKind) {

case RootSignatureElementKind::RootFlags: {
return parseRootFlags(MRS, Element);
break;
}

case RootSignatureElementKind::RootConstants:
case RootSignatureElementKind::RootDescriptor:
case RootSignatureElementKind::DescriptorTable:
case RootSignatureElementKind::StaticSampler:
case RootSignatureElementKind::None:
return reportError("Invalid Root Element: " + ElementText->getString());
break;
}

return true;
}

bool ModuleRootSignature::parse(NamedMDNode *Root) {
bool HasError = false;

/** Root Signature are specified as following in the metadata:
!dx.rootsignatures = !{!2} ; list of function/root signature pairs
!2 = !{ ptr @main, !3 } ; function, root signature
!3 = !{ !4, !5, !6, !7 } ; list of root signature elements
So for each MDNode inside dx.rootsignatures NamedMDNode
(the Root parameter of this function), the parsing process needs
to loop through each of it's operand and process the pairs function
signature pair.
*/

for (unsigned int Sid = 0; Sid < Root->getNumOperands(); Sid++) {
MDNode *Node = dyn_cast<MDNode>(Root->getOperand(Sid));
Comment on lines +98 to +99
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (unsigned int Sid = 0; Sid < Root->getNumOperands(); Sid++) {
MDNode *Node = dyn_cast<MDNode>(Root->getOperand(Sid));
for (const MDNode *Node: Root->operands()) {


if (Node == nullptr || Node->getNumOperands() != 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    if (Node == nullptr || Node->getNumOperands() != 2)

Can Node really ever be nullptr? I'm not sure it can.

return reportError("Invalid format for Root Signature Definition. Pairs "
"of function, root signature expected.");

// Get the Root Signature Description from the function signature pair.
MDNode *RS = dyn_cast<MDNode>(Node->getOperand(1).get());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    MDNode *RS = dyn_cast<MDNode>(Node->getOperand(1).get());

If I change this dyn_cast to cast then the tests still pass. Test hole? Or maybe cast is sufficient here?

I'm not sure how to construct a metadata entry here that isn't an MDNode so maybe it just isn't possible to generate something that isn't an MDNode?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might have MDString, instead, like here: https://github.com/llvm/llvm-project/pull/123147/files#diff-a24a6166cc67ad1172e3dccc299cdb3f4533033c5e02f8affd7f54ec3cd82608. Not sure if this case might happen, I know that some other tools might modify this metadata, but I am not aware of the extent of such changes.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but that's still an MDString inside an MDNode. I couldn't find a way to have an operand of an NamedMDNode be anything other than an MDNode.

This, for example, is not valid:

!2 = !"RootFlags"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the docs https://llvm.org/docs/LangRef.html#named-metadata, this is correct. I will remove the test and the check


if (RS == nullptr)
return reportError("Missing Root Signature Metadata node.");

// Loop through the Root Elements of the root signature.
for (unsigned int Eid = 0; Eid < RS->getNumOperands(); Eid++) {

MDNode *Element = dyn_cast<MDNode>(RS->getOperand(Eid));
if (Element == nullptr)
return reportError("Missing Root Element Metadata Node.");

HasError = HasError || parseRootSignatureElement(this, Element);
}
}
return HasError;
}

ModuleRootSignature ModuleRootSignature::analyzeModule(Module &M) {
ModuleRootSignature MRS;

NamedMDNode *RootSignatureNode = M.getNamedMetadata("dx.rootsignatures");
if (RootSignatureNode) {
if (MRS.parse(RootSignatureNode))
llvm_unreachable("Invalid Root Signature Metadata.");
}

return MRS;
}

AnalysisKey RootSignatureAnalysis::Key;

ModuleRootSignature RootSignatureAnalysis::run(Module &M,
ModuleAnalysisManager &AM) {
return ModuleRootSignature::analyzeModule(M);
}

//===----------------------------------------------------------------------===//
bool RootSignatureAnalysisWrapper::runOnModule(Module &M) {

this->MRS = MRS = ModuleRootSignature::analyzeModule(M);

return false;
}

void RootSignatureAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
}

char RootSignatureAnalysisWrapper::ID = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's ID for, and is 0 really the correct value for it? It looks like this is just passed by value to the base class? Does it need to be a non-const static variable at all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is used by LLVM legacy pass manager to uniquely identity the pass. The reason why the ID is often 0 is that this does not require a unique ID for it's functionality.


INITIALIZE_PASS(RootSignatureAnalysisWrapper, "dx-root-signature-analysis",
"DXIL Root Signature Analysis", true, true)
74 changes: 74 additions & 0 deletions llvm/lib/Target/DirectX/DXILRootSignature.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//===- DXILRootSignature.h - DXIL Root Signature helper objects
//---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file contains helper objects and APIs for working with DXIL
/// Root Signatures.
///
//===----------------------------------------------------------------------===//

#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include <optional>

namespace llvm {
namespace dxil {

enum class RootSignatureElementKind {
None = 0,
RootFlags = 1,
RootConstants = 2,
RootDescriptor = 3,
DescriptorTable = 4,
StaticSampler = 5
Comment on lines +26 to +29
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we shouldn't add these until they're actually implemented:

  RootConstants = 2,
  RootDescriptor = 3,
  DescriptorTable = 4,
  StaticSampler = 5

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are other places in the code where we specify all possible options and add llvm_unreachable to handle any unimplemented options. I find it helpful to identify places where I should edit the code in future PRs. That is why I did this way

};

struct ModuleRootSignature {
uint32_t Flags;

ModuleRootSignature() = default;

bool parse(NamedMDNode *Root);

static ModuleRootSignature analyzeModule(Module &M);
};

class RootSignatureAnalysis : public AnalysisInfoMixin<RootSignatureAnalysis> {
friend AnalysisInfoMixin<RootSignatureAnalysis>;
static AnalysisKey Key;

public:
RootSignatureAnalysis() = default;

using Result = ModuleRootSignature;

ModuleRootSignature run(Module &M, ModuleAnalysisManager &AM);
};

/// Wrapper pass for the legacy pass manager.
///
/// This is required because the passes that will depend on this are codegen
/// passes which run through the legacy pass manager.
class RootSignatureAnalysisWrapper : public ModulePass {
std::optional<ModuleRootSignature> MRS;

public:
static char ID;

RootSignatureAnalysisWrapper() : ModulePass(ID) {}

const std::optional<ModuleRootSignature> &getRootSignature() { return MRS; }

bool runOnModule(Module &M) override;

void getAnalysisUsage(AnalysisUsage &AU) const override;
};

} // namespace dxil
} // namespace llvm
3 changes: 3 additions & 0 deletions llvm/lib/Target/DirectX/DirectX.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ void initializeDXILPrettyPrinterLegacyPass(PassRegistry &);
/// Initializer for dxil::ShaderFlagsAnalysisWrapper pass.
void initializeShaderFlagsAnalysisWrapperPass(PassRegistry &);

/// Initializer for dxil::RootSignatureAnalysisWrapper pass.
void initializeRootSignatureAnalysisWrapperPass(PassRegistry &);

/// Initializer for DXContainerGlobals pass.
void initializeDXContainerGlobalsPass(PassRegistry &);

Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
initializeDXILTranslateMetadataLegacyPass(*PR);
initializeDXILResourceMDWrapperPass(*PR);
initializeShaderFlagsAnalysisWrapperPass(*PR);
initializeRootSignatureAnalysisWrapperPass(*PR);
initializeDXILFinalizeLinkageLegacyPass(*PR);
}

Expand Down
17 changes: 17 additions & 0 deletions llvm/test/CodeGen/DirectX/ContainerData/RootSignature-Error.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
; RUN: not llc %s --filetype=obj -o - 2>&1 | FileCheck %s

target triple = "dxil-unknown-shadermodel6.0-compute"

; CHECK: LLVM ERROR: Invalid format for Root Signature Definition. Pairs of function, root signature expected.


define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }


!dx.rootsignatures = !{!1} ; list of function/root signature pairs
!1= !{ !"RootFlags" } ; function, root signature
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
; RUN: not llc %s --filetype=obj -o - 2>&1 | FileCheck %s

target triple = "dxil-unknown-shadermodel6.0-compute"

; CHECK: LLVM ERROR: Invalid Root Element: NOTRootFlags


define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }


!dx.rootsignatures = !{!2} ; list of function/root signature pairs
!2 = !{ ptr @main, !3 } ; function, root signature
!3 = !{ !4 } ; list of root signature elements
!4 = !{ !"NOTRootFlags", i32 1 } ; 1 = allow_input_assembler_input_layout
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
; RUN: not llc %s --filetype=obj -o - 2>&1 | FileCheck %s

target triple = "dxil-unknown-shadermodel6.0-compute"

; CHECK: LLVM ERROR: Invalid flag value for RootFlag


define void @main() #0 {
entry:
ret void
}

attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }


!dx.rootsignatures = !{!2} ; list of function/root signature pairs
!2 = !{ ptr @main, !3 } ; function, root signature
!3 = !{ !4 } ; list of root signature elements
!4 = !{ !"RootFlags", i32 2147487744 } ; 1 = allow_input_assembler_input_layout
Loading