summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Weigl-Bosker <stefan@s00.xyz>2026-01-20 11:10:38 -0500
committerGitHub <noreply@github.com>2026-01-20 11:10:38 -0500
commitc5b2905c5a64433f8519531a77d3acc42d881f17 (patch)
tree8d4d555c057b2ca00adab68797a9b814ad5c8891
parent8d40f659fabdba2d6a17228f76168e7bdbf5c955 (diff)
downloadcompiler-c5b2905c5a64433f8519531a77d3acc42d881f17.tar.gz
[willow]: finish verifier (#7)
-rw-r--r--willow/include/willow/IR/BasicBlock.h20
-rw-r--r--willow/include/willow/IR/Constant.h6
-rw-r--r--willow/include/willow/IR/Context.h8
-rw-r--r--willow/include/willow/IR/Function.h19
-rw-r--r--willow/include/willow/IR/Instruction.h2
-rw-r--r--willow/include/willow/IR/Instructions.h132
-rw-r--r--willow/include/willow/IR/TypeContext.h25
-rw-r--r--willow/include/willow/IR/Types.h16
-rw-r--r--willow/include/willow/IR/Value.h1
-rw-r--r--willow/include/willow/IR/Verifier.h14
-rw-r--r--willow/lib/IR/Verifier.cpp323
-rw-r--r--willow/tools/willowc/BUILD.bazel0
12 files changed, 479 insertions, 87 deletions
diff --git a/willow/include/willow/IR/BasicBlock.h b/willow/include/willow/IR/BasicBlock.h
index 753c04a..c18dfc2 100644
--- a/willow/include/willow/IR/BasicBlock.h
+++ b/willow/include/willow/IR/BasicBlock.h
@@ -24,6 +24,7 @@ class BasicBlock : public Value {
std::optional<Location> loc;
std::list<std::unique_ptr<Instruction>> body;
+ std::unordered_map<BasicBlock *, size_t> predecessors;
public:
// ~BasicBlock() = TODO
@@ -65,6 +66,25 @@ public:
}
std::optional<Location> getLoc() const { return loc; }
+
+ std::unordered_map<BasicBlock*, size_t>& preds() { return predecessors; }
+ const std::unordered_map<BasicBlock*, size_t>& preds() const { return predecessors; }
+
+ inline void addPred(BasicBlock *bb) {
+ auto [it, inserted] = predecessors.try_emplace(bb, 1);
+
+ if (!inserted)
+ it->second += 1;
+ }
+
+ inline void delPred(BasicBlock *bb) {
+ auto it = preds().find(bb);
+
+ it->second -= 1;
+
+ if (it->second <= 0)
+ preds().erase(it);
+ }
};
Instruction *BasicBlock::trailer() {
diff --git a/willow/include/willow/IR/Constant.h b/willow/include/willow/IR/Constant.h
index 4476ac6..d93d65d 100644
--- a/willow/include/willow/IR/Constant.h
+++ b/willow/include/willow/IR/Constant.h
@@ -9,7 +9,6 @@ enum class ConstantKind {
Int, //< Integer value with known bits
Undef, //< Known undef
Poison, //< Known poison
- Label, //< Known reference to a BasicBlock
};
class Constant : public Value {
@@ -88,11 +87,6 @@ public:
explicit PoisonVal(Type ty) : Constant(ty, ConstantKind::Poison) {}
};
-class BlockRef final : public Constant {
-public:
- explicit BlockRef(Type ty, Block *b) : Constant(ty, ConstantKind::Label) {}
-};
-
} // namespace willow
#endif // WILLOW_INCLUDE_IR_CONSTANT_H
diff --git a/willow/include/willow/IR/Context.h b/willow/include/willow/IR/Context.h
index 4f787ef..3f4f2ba 100644
--- a/willow/include/willow/IR/Context.h
+++ b/willow/include/willow/IR/Context.h
@@ -1,6 +1,7 @@
#ifndef WILLOW_INCLUDE_IR_CONTEXT_H
#define WILLOW_INCLUDE_IR_CONTEXT_H
+#include <willow/IR/ConstantPool.h>
#include <willow/IR/Module.h>
#include <willow/IR/TypeContext.h>
@@ -8,9 +9,14 @@ namespace willow {
/// The global context. Contains all shared state.
class WillowContext {
- TypeContext ctx;
+ TypeContext typectx;
+ ConstantPool constantpool;
std::vector<std::unique_ptr<Module>> modules;
+
+public:
+ TypeContext &types() { return typectx; }
+ ConstantPool &constants() { return constantpool; }
};
} // namespace willow
diff --git a/willow/include/willow/IR/Function.h b/willow/include/willow/IR/Function.h
index dd11553..4be4de7 100644
--- a/willow/include/willow/IR/Function.h
+++ b/willow/include/willow/IR/Function.h
@@ -30,9 +30,9 @@ public:
/// This should usually not be called directly.
/// \param name Identifier.
/// \param parent Owning module.
- /// \param ty Signature. Should be
+ /// \param fty Signature. Should be a function type
/// \param params List of named parameters to the function. Should match \p
- /// ty.
+ /// fty.
Function(std::string name, Module *parent, Type fty,
std::vector<std::unique_ptr<Parameter>> params)
: Value(ValueKind::Function, std::move(name), fty), parent(parent),
@@ -61,6 +61,21 @@ public:
[](auto &p) -> const BasicBlock & { return *p; });
}
+ auto getParams() {
+ return params |
+ std::views::transform([](auto &p) -> Parameter & { return *p; });
+ }
+
+ auto getParams() const {
+ return params | std::views::transform(
+ [](auto &p) -> const Parameter & { return *p; });
+ }
+
+ Parameter *getParam(std::size_t idx) { return params[idx].get(); }
+ const Parameter *getParam(std::size_t idx) const { return params[idx].get(); }
+
+ std::size_t getNumParams() const { return params.size(); }
+
/// \return The SSA values that exist in this block.
auto getValues() {
return blocks |
diff --git a/willow/include/willow/IR/Instruction.h b/willow/include/willow/IR/Instruction.h
index 17981ea..a74c409 100644
--- a/willow/include/willow/IR/Instruction.h
+++ b/willow/include/willow/IR/Instruction.h
@@ -63,7 +63,7 @@ public:
Jmp, ///< goto %0
Br, ///< goto (%0 ? %1 : %2)
- Call, ///< call %0 <args>
+ Call, ///< call %0 ...
Ret, ///< ret val?
Phi, ///< phi ^label1 %val1 ^label2 %val2 ...
diff --git a/willow/include/willow/IR/Instructions.h b/willow/include/willow/IR/Instructions.h
index d1d41fa..569c372 100644
--- a/willow/include/willow/IR/Instructions.h
+++ b/willow/include/willow/IR/Instructions.h
@@ -42,127 +42,215 @@ public:
const Value *getRHS() const { return getOperand(1); }
};
+/// Add two integers.
class AddInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an integer type.
+ /// \p lhs The first integer
+ /// \p lhs The second integer
AddInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Add, type, lhs, rhs, loc) {}
};
+/// Multiply two integers.
class MulInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an integer type.
+ /// \p lhs The first integer
+ /// \p lhs The second integer
MulInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Mul, type, lhs, rhs, loc) {}
};
+/// Subtract two integers.
class SubInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an integer type.
+ /// \p lhs The first integer
+ /// \p lhs The second integer
SubInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Sub, type, lhs, rhs, loc) {}
};
+/// Divide two integers.
class DivInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an integer type.
+ /// \p lhs The dividend
+ /// \p lhs The divisor
DivInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Div, type, lhs, rhs, loc) {}
};
+/// Compute the modulus of two integers.
class ModInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an integer type.
+ /// \p lhs The dividend
+ /// \p lhs The divisor
ModInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Mod, type, lhs, rhs, loc) {}
};
+/// Compute the bitwise left shift of an integer
class ShlInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an integer type.
+ /// \p lhs The integer to be shifted
+ /// \p lhs The shift amount
ShlInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Shl, type, lhs, rhs, loc) {}
};
+/// Compute the bitwise right shift of an integer
class ShrInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an integer type.
+ /// \p lhs The integer to be shifted
+ /// \p lhs The shift amount
ShrInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Shr, type, lhs, rhs, loc) {}
};
+/// Compute the arithmetic left shift of an integer (preserves sign)
class AshlInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an integer type.
+ /// \p lhs The integer to be shifted
+ /// \p lhs The shift amount
AshlInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Ashl, type, lhs, rhs, loc) {}
};
+/// Compute the arithmetic right shift of an integer (preserves sign)
class AshrInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an integer type.
+ /// \p lhs The integer to be shifted
+ /// \p lhs The shift amount
AshrInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Ashr, type, lhs, rhs, loc) {}
};
+/// Test two integers for equality.
class EqInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an i1.
+ /// \p lhs The first integer
+ /// \p lhs The second integer
EqInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Eq, type, lhs, rhs, loc) {}
};
+/// Test if one integer is less than another.
class LtInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an i1.
+ /// \p lhs The first integer
+ /// \p lhs The second integer
LtInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Lt, type, lhs, rhs, loc) {}
};
+// Test if one integer is greater than another.
class GtInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an i1.
+ /// \p lhs The first integer
+ /// \p lhs The second integer
GtInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Gt, type, lhs, rhs, loc) {}
};
+// Test if one integer is less than or equal to another.
class LeInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an i1.
+ /// \p lhs The first integer
+ /// \p lhs The second integer
LeInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Le, type, lhs, rhs, loc) {}
};
+// Test if one integer is greater than or equal to another.
class GeInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an i1.
+ /// \p lhs The first integer
+ /// \p lhs The second integer
GeInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Ge, type, lhs, rhs, loc) {}
};
+/// Compute the logical and of two boolean values.
class AndInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an i1.
+ /// \p lhs The first value
+ /// \p lhs The second value
AndInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::And, type, lhs, rhs, loc) {}
};
+/// Compute the logical and of two boolean values.
class OrInst : public BinaryInst {
public:
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an i1.
+ /// \p lhs The first value
+ /// \p lhs The second value
OrInst(std::string name, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
: BinaryInst(std::move(name), Opcode::Or, type, lhs, rhs, loc) {}
};
+/// Compute the logical negation of a boolean value.
class NotInst : public UnaryInst {
+ /// \p name The name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands. Must be an i1.
+ /// \p lhs The first value
+ /// \p lhs The second value
public:
explicit NotInst(std::string name, Type type, Value *value,
std::optional<Location> loc = std::nullopt)
: UnaryInst(std::move(name), Opcode::Not, type, value, loc) {}
};
+/// Jump to another basic block.
class JmpInst : public UnaryInst {
public:
+ /// \p type Must be a TypeID::BasicBlock
+ /// \p target Target of the jump
JmpInst(Type type, BasicBlock *target,
std::optional<Location> loc = std::nullopt)
: UnaryInst(Opcode::Jmp, type, static_cast<Value *>(target), loc) {}
@@ -173,8 +261,14 @@ public:
}
};
+/// Conditionally branch to another basic block.
class BrInst : public Instruction {
public:
+ /// \p type Type of the ssa result. Must be a TypeID::Void.
+ /// \p condition The boolean condition used to select the destination. Mut
+ /// have type i1.
+ /// \p true_target. The basic block to jump to if the condition is true
+ /// \p false_target. The basic block to jump to if the condition is false
BrInst(Type type, Value *condition, BasicBlock *true_target,
BasicBlock *false_target, std::optional<Location> loc = std::nullopt)
: Instruction(Opcode::Br, type, loc) {
@@ -201,8 +295,15 @@ public:
}
};
+/// Call another function
class CallInst : public Instruction {
public:
+ /// \p optional name name of the ssa result produced by the function call
+ /// \p rty Type of the ssa result produced by the function. Must match the
+ /// signature of the function
+ /// \p func The function to be called
+ /// \p args List of arguments passed to the function. Must match the signature
+ /// of the function
CallInst(std::string name, Type rty, Function *func,
std::initializer_list<Value *> args,
std::optional<Location> loc = std::nullopt)
@@ -220,6 +321,7 @@ public:
}
};
+/// Return instruction
class RetInst : public Instruction {
public:
RetInst(Type voidty, std::optional<Location> loc)
@@ -230,8 +332,12 @@ public:
}
};
+/// Select a control-flow variant value
class PhiInst : public Instruction {
public:
+ /// \p name name of the ssa result produced by the instruction.
+ /// \p type Type of the result and operands.
+ /// \p args List of (predecessor block, value to take) pairs
PhiInst(std::string name, Type type,
std::initializer_list<std::pair<BasicBlock *, Value *>> args,
std::optional<Location> loc)
@@ -241,10 +347,36 @@ public:
addOperand(v);
}
}
+
+ auto incomings() {
+ auto ops = std::span<Value *>{getOperands()};
+ assert(ops.size() % 2 == 0);
+
+ return ops | std::views::chunk(2) |
+ std::views::transform(
+ [](auto c) -> std::pair<BasicBlock *, Value *> {
+ return {static_cast<BasicBlock *>(c[0]),
+ static_cast<Value *>(c[1])};
+ });
+ }
+
+ auto incomings() const {
+ auto ops = std::span<Value *const>{getOperands()};
+ assert(ops.size() % 2 == 0);
+
+ return ops | std::views::chunk(2) |
+ std::views::transform(
+ [](auto c) -> std::pair<const BasicBlock *, const Value *> {
+ return {static_cast<const BasicBlock *>(c[0]),
+ static_cast<const Value *>(c[1])};
+ });
+ }
};
+/// Allocate stack space of a value.
class AllocaInst : public Instruction {
public:
+ /// \p name Name of the ssa result.
/// \p type The pointer type to allocate.
AllocaInst(std::string name, Type type, std::optional<Location> loc)
: Instruction(std::move(name), Opcode::Alloca, type, loc) {
diff --git a/willow/include/willow/IR/TypeContext.h b/willow/include/willow/IR/TypeContext.h
index 665369d..5bdcf51 100644
--- a/willow/include/willow/IR/TypeContext.h
+++ b/willow/include/willow/IR/TypeContext.h
@@ -13,28 +13,28 @@ class TypeContext {
public:
TypeContext()
: voidty(std::make_unique<VoidTypeImpl>()),
- labelty(std::make_unique<LabelTypeImpl>()) {}
+ labelty(std::make_unique<BasicBlockTypeImpl>()) {}
TypeContext(const TypeContext &) = delete;
TypeContext &operator=(const TypeContext &) = delete;
/// \param width Bit width of the integer type.
/// \return Uniqued integer type for the requested width.
- Type getIntType(std::size_t width);
+ Type IntType(std::size_t width);
/// \param pointee Type of the pointee.
/// \param size Size in bits of the pointer representation.
/// \return Uniqued pointer type.
- Type getPtrType(Type pointee, std::size_t size);
+ Type PtrType(Type pointee, std::size_t size);
/// \return Uniqued void type.
- Type getVoidType();
+ Type VoidType();
/// \return Uniqued label type.
- Type getLabelType();
+ Type BasicBlockType();
/// \param ret Return type of the function
/// \param params Parameter types of the function
- Type getFunctionType(Type ret, std::initializer_list<Type> params);
+ Type FunctionType(Type ret, std::initializer_list<Type> params);
private:
std::unordered_map<IntTypeImpl::Key, std::unique_ptr<IntTypeImpl>,
@@ -47,29 +47,28 @@ private:
FunctionTypeImpl::Hash>
fncache;
std::unique_ptr<VoidTypeImpl> voidty;
- std::unique_ptr<LabelTypeImpl> labelty;
+ std::unique_ptr<BasicBlockTypeImpl> labelty;
};
-Type TypeContext::getIntType(std::size_t width) {
+Type TypeContext::IntType(std::size_t width) {
auto [it, _] = icache.try_emplace(IntTypeImpl::Key{width},
std::make_unique<IntTypeImpl>(width));
return Type(it->second.get());
}
-Type TypeContext::getPtrType(Type pointee, std::size_t size) {
+Type TypeContext::PtrType(Type pointee, std::size_t size) {
auto [it, _] = pcache.try_emplace(
PtrTypeImpl::Key{pointee}, std::make_unique<PtrTypeImpl>(pointee, size));
return Type(it->second.get());
}
-Type TypeContext::getVoidType() { return Type(voidty.get()); }
+Type TypeContext::VoidType() { return Type(voidty.get()); }
-Type TypeContext::getLabelType() { return Type{labelty.get()}; }
+Type TypeContext::BasicBlockType() { return Type{labelty.get()}; }
-Type TypeContext::getFunctionType(Type ret,
- std::initializer_list<Type> params) {
+Type TypeContext::FunctionType(Type ret, std::initializer_list<Type> params) {
auto [it, _] =
fncache.try_emplace(FunctionTypeImpl::Key{ret, params},
std::make_unique<FunctionTypeImpl>(ret, params));
diff --git a/willow/include/willow/IR/Types.h b/willow/include/willow/IR/Types.h
index 886c3d5..9a4f868 100644
--- a/willow/include/willow/IR/Types.h
+++ b/willow/include/willow/IR/Types.h
@@ -7,13 +7,16 @@
#include <willow/Util/Hashing.h>
+/// \file Types.h
+/// The type system.
+
namespace willow {
/// IDs of all base types.
enum class TypeID {
// Primitive types
- Void, ///< Void type
- Label, ///< Label type
+ Void, ///< Void type
+ BasicBlock, ///< BasicBlock type
// Parameterized types
Int, ///< Integer type of any bit width
@@ -64,6 +67,7 @@ public:
bool isPtr() const { return kind() == TypeID::Ptr; }
bool isVoid() const { return kind() == TypeID::Void; }
bool isFunction() const { return kind() == TypeID::Function; }
+ bool isBasicBlock() const { return kind() == TypeID::BasicBlock; }
};
class IntTypeImpl : public TypeImpl {
@@ -118,7 +122,7 @@ constexpr std::string_view primitiveTypeName(TypeID tid) {
switch (tid) {
case TypeID::Void:
return "void";
- case TypeID::Label:
+ case TypeID::BasicBlock:
return "label";
case TypeID::Int:
return "int";
@@ -133,9 +137,9 @@ constexpr std::string_view primitiveTypeName(TypeID tid) {
}
}
-class LabelTypeImpl : public TypeImpl {
+class BasicBlockTypeImpl : public TypeImpl {
public:
- LabelTypeImpl() : TypeImpl{TypeID::Label} {}
+ BasicBlockTypeImpl() : TypeImpl{TypeID::BasicBlock} {}
};
class FunctionTypeImpl : public TypeImpl {
@@ -178,7 +182,7 @@ struct std::formatter<willow::Type> {
using enum willow::TypeID;
switch (ty.kind()) {
case Void:
- case Label:
+ case BasicBlock:
case Function:
return std::format_to(ctx.out(), "{}",
willow::primitiveTypeName(ty.kind()));
diff --git a/willow/include/willow/IR/Value.h b/willow/include/willow/IR/Value.h
index 636fcf4..c219788 100644
--- a/willow/include/willow/IR/Value.h
+++ b/willow/include/willow/IR/Value.h
@@ -46,6 +46,7 @@ public:
ValueKind getValueKind() const { return valuekind; }
bool isVoid() const { return type.isVoid(); }
+ bool isBasicBlock() const { return type.isBasicBlock(); }
/// Get the instructions that use this value, and the number of times they use
/// it.
diff --git a/willow/include/willow/IR/Verifier.h b/willow/include/willow/IR/Verifier.h
index cdd1747..0a2c855 100644
--- a/willow/include/willow/IR/Verifier.h
+++ b/willow/include/willow/IR/Verifier.h
@@ -3,7 +3,10 @@
/// \file Verifier.h These are generic helpers for verification of IR, that can
/// be used to check the validity of a transformation.
+///
+/// This doesn't use the pass interface because it is useful to have on its own.
+#include <willow/IR/Context.h>
#include <willow/Util/LogicalResult.h>
namespace willow {
@@ -15,10 +18,13 @@ class BasicBlock;
class Instruction;
class BinaryInst;
-LogicalResult verifyModule(const Module &, DiagnosticEngine &);
-LogicalResult verifyFunction(const Function &, DiagnosticEngine &);
-LogicalResult verifyBasicBlock(const BasicBlock &, DiagnosticEngine &);
-LogicalResult verifyInst(const Instruction &, DiagnosticEngine &);
+LogicalResult verifyModule(WillowContext &, const Module &, DiagnosticEngine &);
+LogicalResult verifyFunction(WillowContext &, const Function &,
+ DiagnosticEngine &);
+LogicalResult verifyBasicBlock(WillowContext &, const BasicBlock &,
+ DiagnosticEngine &);
+LogicalResult verifyInst(WillowContext &, const Instruction &,
+ DiagnosticEngine &);
} // namespace willow
diff --git a/willow/lib/IR/Verifier.cpp b/willow/lib/IR/Verifier.cpp
index f016bbb..d19bc83 100644
--- a/willow/lib/IR/Verifier.cpp
+++ b/willow/lib/IR/Verifier.cpp
@@ -1,6 +1,7 @@
#include <willow/IR/BasicBlock.h>
#include <willow/IR/Diagnostic.h>
#include <willow/IR/DiagnosticEngine.h>
+#include <willow/IR/Instructions.h>
#include <willow/IR/Module.h>
#include <willow/IR/Verifier.h>
@@ -8,14 +9,30 @@ namespace willow {
/// Verify that an instruction defines an SSA result
LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags);
+/// Verify that an instruction does not define an ssa result
+LogicalResult verifyNoResult(const Instruction &inst, DiagnosticEngine &diags);
/// Verify that an instruction has the expected number of operands
LogicalResult verifyNumOperands(const Instruction &inst,
DiagnosticEngine &diags, std::size_t expected);
-LogicalResult verifyBinaryInst(const Instruction &, DiagnosticEngine &);
+/// Verify operand type
+LogicalResult expectOperandType(const Instruction &inst,
+ DiagnosticEngine &diags, std::size_t opidx,
+ Type expected);
+LogicalResult expectOperandType(const Instruction &inst,
+ DiagnosticEngine &diags, const Value *operand,
+ Type expected);
-LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) {
+LogicalResult expectResultType(const Instruction &inst, DiagnosticEngine &diags,
+ Type expected);
+
+LogicalResult verifyBinaryIntegerInst(const Instruction &, DiagnosticEngine &);
+LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx, const Instruction &,
+ DiagnosticEngine &);
+
+LogicalResult verifyModule(WillowContext &ctx, const Module &module,
+ DiagnosticEngine &diags) {
bool any_failure = false;
for (auto &func : module.getFunctions()) {
@@ -23,7 +40,7 @@ LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) {
DiagnosticEngine eng(
[&](Diagnostic d) { collected.push_back(std::move(d)); });
- auto r = verifyFunction(func, eng);
+ auto r = verifyFunction(ctx, func, eng);
if (succeeded(r))
continue;
@@ -40,12 +57,22 @@ LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) {
return any_failure ? failure() : success();
}
-// LogicalResult verifyFunction(const Function&function, DiagnosticEngine
-// &diags) {
-// bool any_failure = false;
-// }
+LogicalResult verifyFunction(WillowContext &ctx, const Function &function,
+ DiagnosticEngine &diags) {
+ if (function.empty())
+ return success();
+
+ bool has_failed = false;
+ for (auto &block : function.getBlocks()) {
+ if (failed(verifyBasicBlock(ctx, block, diags)))
+ has_failed = true;
+ }
+
+ return has_failed ? failure() : success();
+}
-LogicalResult verifyBasicBlock(const BasicBlock &BB, DiagnosticEngine &diags) {
+LogicalResult verifyBasicBlock(WillowContext &ctx, const BasicBlock &BB,
+ DiagnosticEngine &diags) {
if (BB.empty())
return emit(diags, Severity::Error, BB.getLoc())
<< "Basic block '" << BB.getName() << "' has an empty body";
@@ -55,19 +82,27 @@ LogicalResult verifyBasicBlock(const BasicBlock &BB, DiagnosticEngine &diags) {
<< "Basic block '" << BB.getName()
<< "' does not end with a terminator";
+ bool has_failed = false;
for (auto &inst : BB.getBody()) {
// verify inst
- if (failed(verifyInst(inst, diags)))
- return failure();
+ if (failed(verifyInst(ctx, inst, diags)))
+ has_failed = true;
if (&inst != BB.trailer() && inst.isTerminator())
return emit(diags, Severity::Error, BB.getLoc())
<< "Illegal terminator in the middle of a block";
}
+
+ return has_failed ? failure() : success();
}
-// TODO: better instruction printing
-LogicalResult verifyInst(const Instruction &inst, DiagnosticEngine &diags) {
+/// Verify an instruction. This will stop on the first invariant that fails to
+/// hold.
+LogicalResult verifyInst(WillowContext &ctx, const Instruction &inst,
+ DiagnosticEngine &diags) {
+ const BasicBlock *BB = inst.getParent();
+ const Function *fn = BB ? BB->getParent() : nullptr;
+
using enum Instruction::Opcode;
switch (inst.opcode()) {
case Add:
@@ -81,70 +116,177 @@ LogicalResult verifyInst(const Instruction &inst, DiagnosticEngine &diags) {
case Ashr:
case And:
case Or:
- return verifyBinaryInst(inst, diags);
+ return verifyBinaryIntegerInst(inst, diags);
case Eq:
case Lt:
case Gt:
case Le:
- case Ge: {
+ case Ge:
+ return verifyBinaryIntegerCmp(ctx, inst, diags);
+ case Not: {
Type ty = inst.getType();
- if (!ty.isInt() || ty.getNumBits() != 1)
+
+ if (failed(verifyResult(inst, diags)))
+ return failure();
+
+ const Value *operand = inst.getOperand(0);
+ if (!operand)
return emit(diags, Severity::Error, inst.getLoc())
- << std::format("unexpected result type '{}': compare "
- "instructions return i1",
- ty);
+ << "instruction 'not' requires 1 operand";
- if (failed(verifyNumOperands(inst, diags, 2)))
+ Type oty = operand->getType();
+ if (ty != oty)
+ return emit(diags, Severity::Error, inst.getLoc())
+ << std::format("expected argument type '{}', got '{}'", ty, oty);
+ return success();
+ }
+ case Jmp: {
+ if (failed(verifyNoResult(inst, diags)))
return failure();
- const Value *lhs = inst.getOperand(0);
- const Value *rhs = inst.getOperand(1);
+ if (failed(verifyNumOperands(inst, diags, 1)))
+ return failure();
- Type lty = lhs->getType(), rty = rhs->getType();
+ const BasicBlock *dst = static_cast<const BasicBlock *>(inst.getOperand(0));
- if (!lty.isInt())
- return emit(diags, Severity::Error, inst.getLoc()) << std::format(
- "invalid operand type '{}': expected integral type", lty);
+ if (failed(
+ expectOperandType(inst, diags, dst, ctx.types().BasicBlockType())))
+ return failure();
- if (!rty.isInt())
- return emit(diags, Severity::Error, inst.getLoc()) << std::format(
- "invalid operand type '{}': expected integral type", rty);
+ if (BB && fn) {
+ if (dst->getParent() != fn)
+ return emit(diags, Severity::Error, inst.getLoc()) << std::format(
+ "trying to jump to a block outside of the current function");
+ }
+ return success();
+ }
+ case Br: {
+ if (failed(verifyNoResult(inst, diags)))
+ return failure();
+
+ if (failed(verifyNumOperands(inst, diags, 3)))
+ return failure();
+
+ auto *cond = inst.getOperand(0);
+ auto *truedst = static_cast<const BasicBlock *>(inst.getOperand(1));
+ auto *falsedst = static_cast<const BasicBlock *>(inst.getOperand(2));
+
+ if (failed(expectOperandType(inst, diags, cond, ctx.types().IntType(1))))
+ return failure();
+
+ if (failed(expectOperandType(inst, diags, truedst,
+ ctx.types().BasicBlockType())))
+ return failure();
+
+ if (failed(expectOperandType(inst, diags, falsedst,
+ ctx.types().BasicBlockType())))
+ return failure();
- if (lty != rty)
+ if (BB && fn && (fn != truedst->getParent() || fn != falsedst->getParent()))
return emit(diags, Severity::Error, inst.getLoc())
- << "mismatched operand types";
+ << "branching to a basic block that does not belong to the "
+ "current function";
- break;
+ return success();
}
- case Not: {
- Type ty = inst.getType();
+ case Call: {
+ auto &operands = inst.getOperands();
+ const Function *callee = static_cast<const Function *>(inst.getOperand(0));
- if (failed(verifyResult(inst, diags)))
+ Type rty = callee->getReturnType();
+ auto has_result = (rty != ctx.types().VoidType());
+
+ if (failed((has_result ? verifyResult : verifyNoResult)(inst, diags)))
return failure();
- const Value *operand = inst.getOperand(0);
- if (!operand)
- return emit(diags, Severity::Error, inst.getLoc())
- << "instruction 'not' requires 1 operand";
+ if (failed(expectResultType(inst, diags, callee->getReturnType())))
+ return failure();
- Type oty = operand->getType();
- if (ty != oty)
+ auto args = std::ranges::subrange(operands.begin() + 1, operands.end());
+ auto params = callee->getParams();
+
+ if (args.size() != params.size())
return emit(diags, Severity::Error, inst.getLoc())
- << std::format("expected argument type '{}', got '{}'", ty, oty);
+ << "expected " << params.size()
+ << " operands to match the signature of function '"
+ << callee->getName() << "', got " << operands.size();
+
+ for (const auto &&[arg, param] : std::views::zip(args, params)) {
+ // TODO normalize interface
+ auto aty = arg->getType();
+ auto pty = param.getType();
+
+ if (aty == pty)
+ continue;
+
+ DiagnosticBuilder d(diags, Severity::Error, inst.getLoc());
+ d << "invalid argument: expected '" << pty << "', got '" << aty << "'";
+ if (param.hasName())
+ d.note(Diagnostic{Severity::Remark,
+ std::format("param name: {}", param.getName())});
+
+ return failure();
+ }
+
+ return success();
+ }
+ case Ret: {
+ if (!BB || !fn)
+ return success(); // not much we can say
+
+ bool has_arg = (fn->getReturnType() != ctx.types().VoidType());
+
+ if (!has_arg) {
+ if (failed(verifyNumOperands(inst, diags, 0)))
+ return failure();
+ } else {
+ if (failed(verifyNumOperands(inst, diags, 1)))
+ return failure();
+
+ if (failed(expectOperandType(inst, diags, inst.getOperand(0),
+ fn->getReturnType())))
+ return failure();
+ }
break;
}
- case Jmp:
- case Br:
- case Call:
- case Ret:
- case Phi:
- case Alloca:
+ case Phi: {
+ auto phi = static_cast<const PhiInst *>(&inst);
+ if (phi->getNumOperands() % 2)
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "Expected even number of arguments";
+
+ for (auto [pred, val] : phi->incomings()) {
+ if (!pred->isBasicBlock())
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "Expected basic block";
+
+ if (!BB->preds().contains(const_cast<BasicBlock *>(pred)))
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "Incoming phi edge is not a predecessor";
+
+ if (BB && fn && (pred->getParent() != fn))
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "basic block: '" << pred->getName()
+ << "' is not a child of function '" << fn->getName() << "'";
+ }
+
+ return success();
+ }
+ case Alloca: {
+ Type vty = inst.getType();
+ if (!vty.isPtr())
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "expected alloca to produce a pointer";
+
+ return success();
}
+ }
+
+ return success();
}
-// TODO: better naming?
-LogicalResult verifyBinaryInst(const Instruction &inst,
- DiagnosticEngine &diags) {
+LogicalResult verifyBinaryIntegerInst(const Instruction &inst,
+ DiagnosticEngine &diags) {
Type ty = inst.getType();
// TODO non scalars
@@ -159,9 +301,6 @@ LogicalResult verifyBinaryInst(const Instruction &inst,
auto *lhs = inst.getOperand(0);
auto *rhs = inst.getOperand(1);
- assert(lhs && "Binary op lacks LHS");
- assert(rhs && "Binary op needs RHS");
-
if (lhs->getType() != ty) {
return emit(diags, Severity::Error, inst.getLoc()) << std::format(
"expected operand type '{}' got '{}'", ty, lhs->getType());
@@ -171,6 +310,37 @@ LogicalResult verifyBinaryInst(const Instruction &inst,
return emit(diags, Severity::Error, inst.getLoc()) << std::format(
"expected operand type '{}' got '{}'", ty, rhs->getType());
}
+
+ return success();
+}
+
+LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx,
+ const Instruction &inst,
+ DiagnosticEngine &diags) {
+ if (failed(expectResultType(inst, diags, ctx.types().IntType(1))))
+ return failure();
+
+ if (failed(verifyNumOperands(inst, diags, 2)))
+ return failure();
+
+ const Value *lhs = inst.getOperand(0);
+ const Value *rhs = inst.getOperand(1);
+
+ Type lty = lhs->getType(), rty = rhs->getType();
+
+ if (!lty.isInt())
+ return emit(diags, Severity::Error, inst.getLoc()) << std::format(
+ "invalid operand type '{}': expected integral type", lty);
+
+ if (!rty.isInt())
+ return emit(diags, Severity::Error, inst.getLoc()) << std::format(
+ "invalid operand type '{}': expected integral type", rty);
+
+ if (lty != rty)
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "mismatched operand types";
+
+ return success();
}
LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) {
@@ -180,15 +350,60 @@ LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) {
return emit(diags, Severity::Error, inst.getLoc()) << "expected ssa result";
}
+LogicalResult verifyNoResult(const Instruction &inst, DiagnosticEngine &diags) {
+ if (!inst.hasName())
+ return success();
+
+ return emit(diags, Severity::Error, inst.getLoc()) << "unexpected ssa result";
+}
+
LogicalResult verifyNumOperands(const Instruction &inst,
DiagnosticEngine &diags, std::size_t expected) {
std::size_t num_operands = inst.getNumOperands();
if (num_operands != expected)
return emit(diags, Severity::Error, inst.getLoc())
- << std::format("'{}':, expected {} operands, got {}", inst.opcode(),
+ << std::format("wrong number of operands: expected {}, found {}",
expected, num_operands);
return success();
}
+LogicalResult expectOperandType(const Instruction &inst,
+ DiagnosticEngine &diags, std::size_t opidx,
+ Type expected) {
+ auto *operand = inst.getOperand(opidx);
+
+ assert(operand && "expected operand");
+
+ if (operand->getType() == expected)
+ return success();
+
+ return emit(diags, Severity::Error, inst.getLoc())
+ << std::format("expected operand #{} to be of type '{}', but got '{}'",
+ opidx, expected, operand->getType());
+}
+
+LogicalResult expectOperandType(const Instruction &inst,
+ DiagnosticEngine &diags, const Value *operand,
+ Type expected) {
+ assert(operand && "expected operand");
+
+ auto ty = operand->getType();
+ if (ty == expected)
+ return success();
+
+ return emit(diags, Severity::Error, inst.getLoc()) << std::format(
+ "unexpected operand type '{}': expected '{}'", ty, expected);
+}
+
+LogicalResult expectResultType(const Instruction &inst, DiagnosticEngine &diags,
+ Type expected) {
+ auto ty = inst.getType();
+ if (ty == expected)
+ return success();
+
+ return emit(diags, Severity::Error) << std::format(
+ "unexpected result type: expected '{}', found '{}'", expected, ty);
+}
+
} // namespace willow
diff --git a/willow/tools/willowc/BUILD.bazel b/willow/tools/willowc/BUILD.bazel
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/willow/tools/willowc/BUILD.bazel