diff options
| author | Stefan Weigl-Bosker <stefan@s00.xyz> | 2026-01-20 11:10:38 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2026-01-20 11:10:38 -0500 |
| commit | c5b2905c5a64433f8519531a77d3acc42d881f17 (patch) | |
| tree | 8d4d555c057b2ca00adab68797a9b814ad5c8891 | |
| parent | 8d40f659fabdba2d6a17228f76168e7bdbf5c955 (diff) | |
| download | compiler-c5b2905c5a64433f8519531a77d3acc42d881f17.tar.gz | |
[willow]: finish verifier (#7)
| -rw-r--r-- | willow/include/willow/IR/BasicBlock.h | 20 | ||||
| -rw-r--r-- | willow/include/willow/IR/Constant.h | 6 | ||||
| -rw-r--r-- | willow/include/willow/IR/Context.h | 8 | ||||
| -rw-r--r-- | willow/include/willow/IR/Function.h | 19 | ||||
| -rw-r--r-- | willow/include/willow/IR/Instruction.h | 2 | ||||
| -rw-r--r-- | willow/include/willow/IR/Instructions.h | 132 | ||||
| -rw-r--r-- | willow/include/willow/IR/TypeContext.h | 25 | ||||
| -rw-r--r-- | willow/include/willow/IR/Types.h | 16 | ||||
| -rw-r--r-- | willow/include/willow/IR/Value.h | 1 | ||||
| -rw-r--r-- | willow/include/willow/IR/Verifier.h | 14 | ||||
| -rw-r--r-- | willow/lib/IR/Verifier.cpp | 323 | ||||
| -rw-r--r-- | willow/tools/willowc/BUILD.bazel | 0 |
12 files changed, 479 insertions, 87 deletions
diff --git a/willow/include/willow/IR/BasicBlock.h b/willow/include/willow/IR/BasicBlock.h index 753c04a..c18dfc2 100644 --- a/willow/include/willow/IR/BasicBlock.h +++ b/willow/include/willow/IR/BasicBlock.h @@ -24,6 +24,7 @@ class BasicBlock : public Value { std::optional<Location> loc; std::list<std::unique_ptr<Instruction>> body; + std::unordered_map<BasicBlock *, size_t> predecessors; public: // ~BasicBlock() = TODO @@ -65,6 +66,25 @@ public: } std::optional<Location> getLoc() const { return loc; } + + std::unordered_map<BasicBlock*, size_t>& preds() { return predecessors; } + const std::unordered_map<BasicBlock*, size_t>& preds() const { return predecessors; } + + inline void addPred(BasicBlock *bb) { + auto [it, inserted] = predecessors.try_emplace(bb, 1); + + if (!inserted) + it->second += 1; + } + + inline void delPred(BasicBlock *bb) { + auto it = preds().find(bb); + + it->second -= 1; + + if (it->second <= 0) + preds().erase(it); + } }; Instruction *BasicBlock::trailer() { diff --git a/willow/include/willow/IR/Constant.h b/willow/include/willow/IR/Constant.h index 4476ac6..d93d65d 100644 --- a/willow/include/willow/IR/Constant.h +++ b/willow/include/willow/IR/Constant.h @@ -9,7 +9,6 @@ enum class ConstantKind { Int, //< Integer value with known bits Undef, //< Known undef Poison, //< Known poison - Label, //< Known reference to a BasicBlock }; class Constant : public Value { @@ -88,11 +87,6 @@ public: explicit PoisonVal(Type ty) : Constant(ty, ConstantKind::Poison) {} }; -class BlockRef final : public Constant { -public: - explicit BlockRef(Type ty, Block *b) : Constant(ty, ConstantKind::Label) {} -}; - } // namespace willow #endif // WILLOW_INCLUDE_IR_CONSTANT_H diff --git a/willow/include/willow/IR/Context.h b/willow/include/willow/IR/Context.h index 4f787ef..3f4f2ba 100644 --- a/willow/include/willow/IR/Context.h +++ b/willow/include/willow/IR/Context.h @@ -1,6 +1,7 @@ #ifndef WILLOW_INCLUDE_IR_CONTEXT_H #define WILLOW_INCLUDE_IR_CONTEXT_H +#include 
<willow/IR/ConstantPool.h> #include <willow/IR/Module.h> #include <willow/IR/TypeContext.h> @@ -8,9 +9,14 @@ namespace willow { /// The global context. Contains all shared state. class WillowContext { - TypeContext ctx; + TypeContext typectx; + ConstantPool constantpool; std::vector<std::unique_ptr<Module>> modules; + +public: + TypeContext &types() { return typectx; } + ConstantPool &constants() { return constantpool; } }; } // namespace willow diff --git a/willow/include/willow/IR/Function.h b/willow/include/willow/IR/Function.h index dd11553..4be4de7 100644 --- a/willow/include/willow/IR/Function.h +++ b/willow/include/willow/IR/Function.h @@ -30,9 +30,9 @@ public: /// This should usually not be called directly. /// \param name Identifier. /// \param parent Owning module. - /// \param ty Signature. Should be + /// \param fty Signature. Should be a function type /// \param params List of named parameters to the function. Should match \p - /// ty. + /// fty. Function(std::string name, Module *parent, Type fty, std::vector<std::unique_ptr<Parameter>> params) : Value(ValueKind::Function, std::move(name), fty), parent(parent), @@ -61,6 +61,21 @@ public: [](auto &p) -> const BasicBlock & { return *p; }); } + auto getParams() { + return params | + std::views::transform([](auto &p) -> Parameter & { return *p; }); + } + + auto getParams() const { + return params | std::views::transform( + [](auto &p) -> const Parameter & { return *p; }); + } + + Parameter *getParam(std::size_t idx) { return params[idx].get(); } + const Parameter *getParam(std::size_t idx) const { return params[idx].get(); } + + std::size_t getNumParams() const { return params.size(); } + /// \return The SSA values that exist in this block. 
auto getValues() { return blocks | diff --git a/willow/include/willow/IR/Instruction.h b/willow/include/willow/IR/Instruction.h index 17981ea..a74c409 100644 --- a/willow/include/willow/IR/Instruction.h +++ b/willow/include/willow/IR/Instruction.h @@ -63,7 +63,7 @@ public: Jmp, ///< goto %0 Br, ///< goto (%0 ? %1 : %2) - Call, ///< call %0 <args> + Call, ///< call %0 ... Ret, ///< ret val? Phi, ///< phi ^label1 %val1 ^label2 %val2 ... diff --git a/willow/include/willow/IR/Instructions.h b/willow/include/willow/IR/Instructions.h index d1d41fa..569c372 100644 --- a/willow/include/willow/IR/Instructions.h +++ b/willow/include/willow/IR/Instructions.h @@ -42,127 +42,215 @@ public: const Value *getRHS() const { return getOperand(1); } }; +/// Add two integers. class AddInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an integer type. + /// \p lhs The first integer + /// \p rhs The second integer AddInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Add, type, lhs, rhs, loc) {} }; +/// Multiply two integers. class MulInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an integer type. + /// \p lhs The first integer + /// \p rhs The second integer MulInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Mul, type, lhs, rhs, loc) {} }; +/// Subtract two integers. class SubInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an integer type. 
+ /// \p lhs The first integer + /// \p rhs The second integer SubInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Sub, type, lhs, rhs, loc) {} }; +/// Divide two integers. class DivInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an integer type. + /// \p lhs The dividend + /// \p rhs The divisor DivInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Div, type, lhs, rhs, loc) {} }; +/// Compute the modulus of two integers. class ModInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an integer type. + /// \p lhs The dividend + /// \p rhs The divisor ModInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Mod, type, lhs, rhs, loc) {} }; +/// Compute the bitwise left shift of an integer class ShlInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an integer type. + /// \p lhs The integer to be shifted + /// \p rhs The shift amount ShlInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Shl, type, lhs, rhs, loc) {} }; +/// Compute the bitwise right shift of an integer class ShrInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an integer type. 
+ /// \p lhs The integer to be shifted + /// \p rhs The shift amount ShrInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Shr, type, lhs, rhs, loc) {} }; +/// Compute the arithmetic left shift of an integer (preserves sign) class AshlInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an integer type. + /// \p lhs The integer to be shifted + /// \p rhs The shift amount AshlInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Ashl, type, lhs, rhs, loc) {} }; +/// Compute the arithmetic right shift of an integer (preserves sign) class AshrInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an integer type. + /// \p lhs The integer to be shifted + /// \p rhs The shift amount AshrInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Ashr, type, lhs, rhs, loc) {} }; +/// Test two integers for equality. class EqInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result. Must be an i1. + /// \p lhs The first integer + /// \p rhs The second integer EqInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Eq, type, lhs, rhs, loc) {} }; +/// Test if one integer is less than another. class LtInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result. Must be an i1. 
+ /// \p lhs The first integer + /// \p rhs The second integer LtInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Lt, type, lhs, rhs, loc) {} }; +/// Test if one integer is greater than another. class GtInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result. Must be an i1. + /// \p lhs The first integer + /// \p rhs The second integer GtInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Gt, type, lhs, rhs, loc) {} }; +/// Test if one integer is less than or equal to another. class LeInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result. Must be an i1. + /// \p lhs The first integer + /// \p rhs The second integer LeInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Le, type, lhs, rhs, loc) {} }; +/// Test if one integer is greater than or equal to another. class GeInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result. Must be an i1. + /// \p lhs The first integer + /// \p rhs The second integer GeInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Ge, type, lhs, rhs, loc) {} }; +/// Compute the logical and of two boolean values. class AndInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an i1. 
+ /// \p lhs The first value + /// \p rhs The second value AndInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::And, type, lhs, rhs, loc) {} }; +/// Compute the logical or of two boolean values. class OrInst : public BinaryInst { public: + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. Must be an i1. + /// \p lhs The first value + /// \p rhs The second value OrInst(std::string name, Type type, Value *lhs, Value *rhs, std::optional<Location> loc = std::nullopt) : BinaryInst(std::move(name), Opcode::Or, type, lhs, rhs, loc) {} }; +/// Compute the logical negation of a boolean value. class NotInst : public UnaryInst { + /// \p name The name of the ssa result produced by the instruction. + /// \p type Type of the result and operand. Must be an i1. + /// \p value The boolean value to negate + /// \p loc Optional location of the instruction public: explicit NotInst(std::string name, Type type, Value *value, std::optional<Location> loc = std::nullopt) : UnaryInst(std::move(name), Opcode::Not, type, value, loc) {} }; +/// Jump to another basic block. class JmpInst : public UnaryInst { public: + /// \p type Must be a TypeID::BasicBlock + /// \p target Target of the jump JmpInst(Type type, BasicBlock *target, std::optional<Location> loc = std::nullopt) : UnaryInst(Opcode::Jmp, type, static_cast<Value *>(target), loc) {} @@ -173,8 +261,14 @@ public: } }; +/// Conditionally branch to another basic block. class BrInst : public Instruction { public: + /// \p type Type of the ssa result. Must be a TypeID::Void. + /// \p condition The boolean condition used to select the destination. Must + /// have type i1. + /// \p true_target The basic block to jump to if the condition is true + /// \p false_target 
The basic block to jump to if the condition is false BrInst(Type type, Value *condition, BasicBlock *true_target, BasicBlock *false_target, std::optional<Location> loc = std::nullopt) : Instruction(Opcode::Br, type, loc) { @@ -201,8 +295,15 @@ public: } }; +/// Call another function class CallInst : public Instruction { public: + /// \p name Name of the ssa result produced by the function call, if any + /// \p rty Type of the ssa result produced by the function. Must match the + /// signature of the function + /// \p func The function to be called + /// \p args List of arguments passed to the function. Must match the signature + /// of the function CallInst(std::string name, Type rty, Function *func, std::initializer_list<Value *> args, std::optional<Location> loc = std::nullopt) @@ -220,6 +321,7 @@ public: } }; +/// Return instruction class RetInst : public Instruction { public: RetInst(Type voidty, std::optional<Location> loc) @@ -230,8 +332,12 @@ public: } }; +/// Select a control-flow variant value class PhiInst : public Instruction { public: + /// \p name name of the ssa result produced by the instruction. + /// \p type Type of the result and operands. 
+ /// \p args List of (predecessor block, value to take) pairs PhiInst(std::string name, Type type, std::initializer_list<std::pair<BasicBlock *, Value *>> args, std::optional<Location> loc) @@ -241,10 +347,36 @@ public: addOperand(v); } } + + auto incomings() { + auto ops = std::span<Value *>{getOperands()}; + assert(ops.size() % 2 == 0); + + return ops | std::views::chunk(2) | + std::views::transform( + [](auto c) -> std::pair<BasicBlock *, Value *> { + return {static_cast<BasicBlock *>(c[0]), + static_cast<Value *>(c[1])}; + }); + } + + auto incomings() const { + auto ops = std::span<Value *const>{getOperands()}; + assert(ops.size() % 2 == 0); + + return ops | std::views::chunk(2) | + std::views::transform( + [](auto c) -> std::pair<const BasicBlock *, const Value *> { + return {static_cast<const BasicBlock *>(c[0]), + static_cast<const Value *>(c[1])}; + }); + } }; +/// Allocate stack space of a value. class AllocaInst : public Instruction { public: + /// \p name Name of the ssa result. /// \p type The pointer type to allocate. AllocaInst(std::string name, Type type, std::optional<Location> loc) : Instruction(std::move(name), Opcode::Alloca, type, loc) { diff --git a/willow/include/willow/IR/TypeContext.h b/willow/include/willow/IR/TypeContext.h index 665369d..5bdcf51 100644 --- a/willow/include/willow/IR/TypeContext.h +++ b/willow/include/willow/IR/TypeContext.h @@ -13,28 +13,28 @@ class TypeContext { public: TypeContext() : voidty(std::make_unique<VoidTypeImpl>()), - labelty(std::make_unique<LabelTypeImpl>()) {} + labelty(std::make_unique<BasicBlockTypeImpl>()) {} TypeContext(const TypeContext &) = delete; TypeContext &operator=(const TypeContext &) = delete; /// \param width Bit width of the integer type. /// \return Uniqued integer type for the requested width. - Type getIntType(std::size_t width); + Type IntType(std::size_t width); /// \param pointee Type of the pointee. /// \param size Size in bits of the pointer representation. 
/// \return Uniqued pointer type. - Type getPtrType(Type pointee, std::size_t size); + Type PtrType(Type pointee, std::size_t size); /// \return Uniqued void type. - Type getVoidType(); + Type VoidType(); /// \return Uniqued label type. - Type getLabelType(); + Type BasicBlockType(); /// \param ret Return type of the function /// \param params Parameter types of the function - Type getFunctionType(Type ret, std::initializer_list<Type> params); + Type FunctionType(Type ret, std::initializer_list<Type> params); private: std::unordered_map<IntTypeImpl::Key, std::unique_ptr<IntTypeImpl>, @@ -47,29 +47,28 @@ private: FunctionTypeImpl::Hash> fncache; std::unique_ptr<VoidTypeImpl> voidty; - std::unique_ptr<LabelTypeImpl> labelty; + std::unique_ptr<BasicBlockTypeImpl> labelty; }; -Type TypeContext::getIntType(std::size_t width) { +Type TypeContext::IntType(std::size_t width) { auto [it, _] = icache.try_emplace(IntTypeImpl::Key{width}, std::make_unique<IntTypeImpl>(width)); return Type(it->second.get()); } -Type TypeContext::getPtrType(Type pointee, std::size_t size) { +Type TypeContext::PtrType(Type pointee, std::size_t size) { auto [it, _] = pcache.try_emplace( PtrTypeImpl::Key{pointee}, std::make_unique<PtrTypeImpl>(pointee, size)); return Type(it->second.get()); } -Type TypeContext::getVoidType() { return Type(voidty.get()); } +Type TypeContext::VoidType() { return Type(voidty.get()); } -Type TypeContext::getLabelType() { return Type{labelty.get()}; } +Type TypeContext::BasicBlockType() { return Type{labelty.get()}; } -Type TypeContext::getFunctionType(Type ret, - std::initializer_list<Type> params) { +Type TypeContext::FunctionType(Type ret, std::initializer_list<Type> params) { auto [it, _] = fncache.try_emplace(FunctionTypeImpl::Key{ret, params}, std::make_unique<FunctionTypeImpl>(ret, params)); diff --git a/willow/include/willow/IR/Types.h b/willow/include/willow/IR/Types.h index 886c3d5..9a4f868 100644 --- a/willow/include/willow/IR/Types.h +++ 
b/willow/include/willow/IR/Types.h @@ -7,13 +7,16 @@ #include <willow/Util/Hashing.h> +/// \file Types.h +/// The type system. + namespace willow { /// IDs of all base types. enum class TypeID { // Primitive types - Void, ///< Void type - Label, ///< Label type + Void, ///< Void type + BasicBlock, ///< BasicBlock type // Parameterized types Int, ///< Integer type of any bit width @@ -64,6 +67,7 @@ public: bool isPtr() const { return kind() == TypeID::Ptr; } bool isVoid() const { return kind() == TypeID::Void; } bool isFunction() const { return kind() == TypeID::Function; } + bool isBasicBlock() const { return kind() == TypeID::BasicBlock; } }; class IntTypeImpl : public TypeImpl { @@ -118,7 +122,7 @@ constexpr std::string_view primitiveTypeName(TypeID tid) { switch (tid) { case TypeID::Void: return "void"; - case TypeID::Label: + case TypeID::BasicBlock: return "label"; case TypeID::Int: return "int"; @@ -133,9 +137,9 @@ constexpr std::string_view primitiveTypeName(TypeID tid) { } } -class LabelTypeImpl : public TypeImpl { +class BasicBlockTypeImpl : public TypeImpl { public: - LabelTypeImpl() : TypeImpl{TypeID::Label} {} + BasicBlockTypeImpl() : TypeImpl{TypeID::BasicBlock} {} }; class FunctionTypeImpl : public TypeImpl { @@ -178,7 +182,7 @@ struct std::formatter<willow::Type> { using enum willow::TypeID; switch (ty.kind()) { case Void: - case Label: + case BasicBlock: case Function: return std::format_to(ctx.out(), "{}", willow::primitiveTypeName(ty.kind())); diff --git a/willow/include/willow/IR/Value.h b/willow/include/willow/IR/Value.h index 636fcf4..c219788 100644 --- a/willow/include/willow/IR/Value.h +++ b/willow/include/willow/IR/Value.h @@ -46,6 +46,7 @@ public: ValueKind getValueKind() const { return valuekind; } bool isVoid() const { return type.isVoid(); } + bool isBasicBlock() const { return type.isBasicBlock(); } /// Get the instructions that use this value, and the number of times they use /// it. 
diff --git a/willow/include/willow/IR/Verifier.h b/willow/include/willow/IR/Verifier.h index cdd1747..0a2c855 100644 --- a/willow/include/willow/IR/Verifier.h +++ b/willow/include/willow/IR/Verifier.h @@ -3,7 +3,10 @@ /// \file Verifier.h These are generic helpers for verification of IR, that can /// be used to check the validity of a transformation. +/// +/// This doesn't use the pass interface because it is useful to have on its own. +#include <willow/IR/Context.h> #include <willow/Util/LogicalResult.h> namespace willow { @@ -15,10 +18,13 @@ class BasicBlock; class Instruction; class BinaryInst; -LogicalResult verifyModule(const Module &, DiagnosticEngine &); -LogicalResult verifyFunction(const Function &, DiagnosticEngine &); -LogicalResult verifyBasicBlock(const BasicBlock &, DiagnosticEngine &); -LogicalResult verifyInst(const Instruction &, DiagnosticEngine &); +LogicalResult verifyModule(WillowContext &, const Module &, DiagnosticEngine &); +LogicalResult verifyFunction(WillowContext &, const Function &, + DiagnosticEngine &); +LogicalResult verifyBasicBlock(WillowContext &, const BasicBlock &, + DiagnosticEngine &); +LogicalResult verifyInst(WillowContext &, const Instruction &, + DiagnosticEngine &); } // namespace willow diff --git a/willow/lib/IR/Verifier.cpp b/willow/lib/IR/Verifier.cpp index f016bbb..d19bc83 100644 --- a/willow/lib/IR/Verifier.cpp +++ b/willow/lib/IR/Verifier.cpp @@ -1,6 +1,7 @@ #include <willow/IR/BasicBlock.h> #include <willow/IR/Diagnostic.h> #include <willow/IR/DiagnosticEngine.h> +#include <willow/IR/Instructions.h> #include <willow/IR/Module.h> #include <willow/IR/Verifier.h> @@ -8,14 +9,30 @@ namespace willow { /// Verify that an instruction defines an SSA result LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags); +/// Verify that an instruction does not define an ssa result +LogicalResult verifyNoResult(const Instruction &inst, DiagnosticEngine &diags); /// Verify that an instruction has the expected 
number of operands LogicalResult verifyNumOperands(const Instruction &inst, DiagnosticEngine &diags, std::size_t expected); -LogicalResult verifyBinaryInst(const Instruction &, DiagnosticEngine &); +/// Verify operand type +LogicalResult expectOperandType(const Instruction &inst, + DiagnosticEngine &diags, std::size_t opidx, + Type expected); +LogicalResult expectOperandType(const Instruction &inst, + DiagnosticEngine &diags, const Value *operand, + Type expected); -LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) { +LogicalResult expectResultType(const Instruction &inst, DiagnosticEngine &diags, + Type expected); + +LogicalResult verifyBinaryIntegerInst(const Instruction &, DiagnosticEngine &); +LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx, const Instruction &, + DiagnosticEngine &); + +LogicalResult verifyModule(WillowContext &ctx, const Module &module, + DiagnosticEngine &diags) { bool any_failure = false; for (auto &func : module.getFunctions()) { @@ -23,7 +40,7 @@ LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) { DiagnosticEngine eng( [&](Diagnostic d) { collected.push_back(std::move(d)); }); - auto r = verifyFunction(func, eng); + auto r = verifyFunction(ctx, func, eng); if (succeeded(r)) continue; @@ -40,12 +57,22 @@ LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) { return any_failure ? failure() : success(); } -// LogicalResult verifyFunction(const Function&function, DiagnosticEngine -// &diags) { -// bool any_failure = false; -// } +LogicalResult verifyFunction(WillowContext &ctx, const Function &function, + DiagnosticEngine &diags) { + if (function.empty()) + return success(); + + bool has_failed = false; + for (auto &block : function.getBlocks()) { + if (failed(verifyBasicBlock(ctx, block, diags))) + has_failed = true; + } + + return has_failed ? 
failure() : success(); +} -LogicalResult verifyBasicBlock(const BasicBlock &BB, DiagnosticEngine &diags) { +LogicalResult verifyBasicBlock(WillowContext &ctx, const BasicBlock &BB, + DiagnosticEngine &diags) { if (BB.empty()) return emit(diags, Severity::Error, BB.getLoc()) << "Basic block '" << BB.getName() << "' has an empty body"; @@ -55,19 +82,27 @@ LogicalResult verifyBasicBlock(const BasicBlock &BB, DiagnosticEngine &diags) { << "Basic block '" << BB.getName() << "' does not end with a terminator"; + bool has_failed = false; for (auto &inst : BB.getBody()) { // verify inst - if (failed(verifyInst(inst, diags))) - return failure(); + if (failed(verifyInst(ctx, inst, diags))) + has_failed = true; if (&inst != BB.trailer() && inst.isTerminator()) return emit(diags, Severity::Error, BB.getLoc()) << "Illegal terminator in the middle of a block"; } + + return has_failed ? failure() : success(); } -// TODO: better instruction printing -LogicalResult verifyInst(const Instruction &inst, DiagnosticEngine &diags) { +/// Verify an instruction. This will stop on the first invariant that fails to +/// hold. +LogicalResult verifyInst(WillowContext &ctx, const Instruction &inst, + DiagnosticEngine &diags) { + const BasicBlock *BB = inst.getParent(); + const Function *fn = BB ? 
BB->getParent() : nullptr; + using enum Instruction::Opcode; switch (inst.opcode()) { case Add: @@ -81,70 +116,177 @@ LogicalResult verifyInst(const Instruction &inst, DiagnosticEngine &diags) { case Ashr: case And: case Or: - return verifyBinaryInst(inst, diags); + return verifyBinaryIntegerInst(inst, diags); case Eq: case Lt: case Gt: case Le: - case Ge: { + case Ge: + return verifyBinaryIntegerCmp(ctx, inst, diags); + case Not: { Type ty = inst.getType(); - if (!ty.isInt() || ty.getNumBits() != 1) + + if (failed(verifyResult(inst, diags))) + return failure(); + + const Value *operand = inst.getOperand(0); + if (!operand) return emit(diags, Severity::Error, inst.getLoc()) - << std::format("unexpected result type '{}': compare " - "instructions return i1", - ty); + << "instruction 'not' requires 1 operand"; - if (failed(verifyNumOperands(inst, diags, 2))) + Type oty = operand->getType(); + if (ty != oty) + return emit(diags, Severity::Error, inst.getLoc()) + << std::format("expected argument type '{}', got '{}'", ty, oty); + return success(); + } + case Jmp: { + if (failed(verifyNoResult(inst, diags))) return failure(); - const Value *lhs = inst.getOperand(0); - const Value *rhs = inst.getOperand(1); + if (failed(verifyNumOperands(inst, diags, 1))) + return failure(); - Type lty = lhs->getType(), rty = rhs->getType(); + const BasicBlock *dst = static_cast<const BasicBlock *>(inst.getOperand(0)); - if (!lty.isInt()) - return emit(diags, Severity::Error, inst.getLoc()) << std::format( - "invalid operand type '{}': expected integral type", lty); + if (failed( + expectOperandType(inst, diags, dst, ctx.types().BasicBlockType()))) + return failure(); - if (!rty.isInt()) - return emit(diags, Severity::Error, inst.getLoc()) << std::format( - "invalid operand type '{}': expected integral type", rty); + if (BB && fn) { + if (dst->getParent() != fn) + return emit(diags, Severity::Error, inst.getLoc()) << std::format( + "trying to jump to a block outside of the current 
function"); + } + return success(); + } + case Br: { + if (failed(verifyNoResult(inst, diags))) + return failure(); + + if (failed(verifyNumOperands(inst, diags, 3))) + return failure(); + + auto *cond = inst.getOperand(0); + auto *truedst = static_cast<const BasicBlock *>(inst.getOperand(1)); + auto *falsedst = static_cast<const BasicBlock *>(inst.getOperand(2)); + + if (failed(expectOperandType(inst, diags, cond, ctx.types().IntType(1)))) + return failure(); + + if (failed(expectOperandType(inst, diags, truedst, + ctx.types().BasicBlockType()))) + return failure(); + + if (failed(expectOperandType(inst, diags, falsedst, + ctx.types().BasicBlockType()))) + return failure(); - if (lty != rty) + if (BB && fn && (fn != truedst->getParent() || fn != falsedst->getParent())) return emit(diags, Severity::Error, inst.getLoc()) - << "mismatched operand types"; + << "branching to a basic block that does not belong to the " + "current function"; - break; + return success(); } - case Not: { - Type ty = inst.getType(); + case Call: { + auto &operands = inst.getOperands(); + const Function *callee = static_cast<const Function *>(inst.getOperand(0)); - if (failed(verifyResult(inst, diags))) + Type rty = callee->getReturnType(); + auto has_result = (rty != ctx.types().VoidType()); + + if (failed((has_result ? 
verifyResult : verifyNoResult)(inst, diags))) return failure(); - const Value *operand = inst.getOperand(0); - if (!operand) - return emit(diags, Severity::Error, inst.getLoc()) - << "instruction 'not' requires 1 operand"; + if (failed(expectResultType(inst, diags, callee->getReturnType()))) + return failure(); - Type oty = operand->getType(); - if (ty != oty) + auto args = std::ranges::subrange(operands.begin() + 1, operands.end()); + auto params = callee->getParams(); + + if (args.size() != params.size()) return emit(diags, Severity::Error, inst.getLoc()) - << std::format("expected argument type '{}', got '{}'", ty, oty); + << "expected " << params.size() + << " operands to match the signature of function '" + << callee->getName() << "', got " << operands.size(); + + for (const auto &&[arg, param] : std::views::zip(args, params)) { + // TODO normalize interface + auto aty = arg->getType(); + auto pty = param.getType(); + + if (aty == pty) + continue; + + DiagnosticBuilder d(diags, Severity::Error, inst.getLoc()); + d << "invalid argument: expected '" << pty << "', got '" << aty << "'"; + if (param.hasName()) + d.note(Diagnostic{Severity::Remark, + std::format("param name: {}", param.getName())}); + + return failure(); + } + + return success(); + } + case Ret: { + if (!BB || !fn) + return success(); // not much we can say + + bool has_arg = (fn->getReturnType() != ctx.types().VoidType()); + + if (!has_arg) { + if (failed(verifyNumOperands(inst, diags, 0))) + return failure(); + } else { + if (failed(verifyNumOperands(inst, diags, 1))) + return failure(); + + if (failed(expectOperandType(inst, diags, inst.getOperand(0), + fn->getReturnType()))) + return failure(); + } break; } - case Jmp: - case Br: - case Call: - case Ret: - case Phi: - case Alloca: + case Phi: { + auto phi = static_cast<const PhiInst *>(&inst); + if (phi->getNumOperands() % 2) + return emit(diags, Severity::Error, inst.getLoc()) + << "Expected even number of arguments"; + + for (auto [pred, val] 
: phi->incomings()) { + if (!pred->isBasicBlock()) + return emit(diags, Severity::Error, inst.getLoc()) + << "Expected basic block"; + + if (!BB->preds().contains(const_cast<BasicBlock *>(pred))) + return emit(diags, Severity::Error, inst.getLoc()) + << "Incoming phi edge is not a predecessor"; + + if (BB && fn && (pred->getParent() != fn)) + return emit(diags, Severity::Error, inst.getLoc()) + << "basic block: '" << pred->getName() + << "' is not a child of function '" << fn->getName() << "'"; + } + + return success(); + } + case Alloca: { + Type vty = inst.getType(); + if (!vty.isPtr()) + return emit(diags, Severity::Error, inst.getLoc()) + << "expected alloca to produce a pointer"; + + return success(); } + } + + return success(); } -// TODO: better naming? -LogicalResult verifyBinaryInst(const Instruction &inst, - DiagnosticEngine &diags) { +LogicalResult verifyBinaryIntegerInst(const Instruction &inst, + DiagnosticEngine &diags) { Type ty = inst.getType(); // TODO non scalars @@ -159,9 +301,6 @@ LogicalResult verifyBinaryInst(const Instruction &inst, auto *lhs = inst.getOperand(0); auto *rhs = inst.getOperand(1); - assert(lhs && "Binary op lacks LHS"); - assert(rhs && "Binary op needs RHS"); - if (lhs->getType() != ty) { return emit(diags, Severity::Error, inst.getLoc()) << std::format( "expected operand type '{}' got '{}'", ty, lhs->getType()); @@ -171,6 +310,37 @@ LogicalResult verifyBinaryInst(const Instruction &inst, return emit(diags, Severity::Error, inst.getLoc()) << std::format( "expected operand type '{}' got '{}'", ty, rhs->getType()); } + + return success(); +} + +LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx, + const Instruction &inst, + DiagnosticEngine &diags) { + if (failed(expectResultType(inst, diags, ctx.types().IntType(1)))) + return failure(); + + if (failed(verifyNumOperands(inst, diags, 2))) + return failure(); + + const Value *lhs = inst.getOperand(0); + const Value *rhs = inst.getOperand(1); + + Type lty = lhs->getType(), rty 
= rhs->getType(); + + if (!lty.isInt()) + return emit(diags, Severity::Error, inst.getLoc()) << std::format( + "invalid operand type '{}': expected integral type", lty); + + if (!rty.isInt()) + return emit(diags, Severity::Error, inst.getLoc()) << std::format( + "invalid operand type '{}': expected integral type", rty); + + if (lty != rty) + return emit(diags, Severity::Error, inst.getLoc()) + << "mismatched operand types"; + + return success(); } LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) { @@ -180,15 +350,60 @@ LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) { return emit(diags, Severity::Error, inst.getLoc()) << "expected ssa result"; } +LogicalResult verifyNoResult(const Instruction &inst, DiagnosticEngine &diags) { + if (!inst.hasName()) + return success(); + + return emit(diags, Severity::Error, inst.getLoc()) << "unexpected ssa result"; +} + LogicalResult verifyNumOperands(const Instruction &inst, DiagnosticEngine &diags, std::size_t expected) { std::size_t num_operands = inst.getNumOperands(); if (num_operands != expected) return emit(diags, Severity::Error, inst.getLoc()) - << std::format("'{}':, expected {} operands, got {}", inst.opcode(), + << std::format("wrong number of operands: expected {}, found {}", expected, num_operands); return success(); } +LogicalResult expectOperandType(const Instruction &inst, + DiagnosticEngine &diags, std::size_t opidx, + Type expected) { + auto *operand = inst.getOperand(opidx); + + assert(operand && "expected operand"); + + if (operand->getType() == expected) + return success(); + + return emit(diags, Severity::Error, inst.getLoc()) + << std::format("expected operand #{} to be of type '{}', but got '{}'", + opidx, expected, operand->getType()); +} + +LogicalResult expectOperandType(const Instruction &inst, + DiagnosticEngine &diags, const Value *operand, + Type expected) { + assert(operand && "expected operand"); + + auto ty = operand->getType(); + if (ty == 
expected) + return success(); + + return emit(diags, Severity::Error, inst.getLoc()) << std::format( + "unexpected operand type '{}': expected '{}'", ty, expected); +} + +LogicalResult expectResultType(const Instruction &inst, DiagnosticEngine &diags, + Type expected) { + auto ty = inst.getType(); + if (ty == expected) + return success(); + + return emit(diags, Severity::Error) << std::format( + "unexpected result type: expected '{}', found '{}'", expected, ty); +} + } // namespace willow diff --git a/willow/tools/willowc/BUILD.bazel b/willow/tools/willowc/BUILD.bazel new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/willow/tools/willowc/BUILD.bazel |