summaryrefslogtreecommitdiff
path: root/willow/lib/IR/Verifier.cpp
diff options
context:
space:
mode:
authorStefan Weigl-Bosker <stefan@s00.xyz>2026-01-20 11:10:38 -0500
committerGitHub <noreply@github.com>2026-01-20 11:10:38 -0500
commitc5b2905c5a64433f8519531a77d3acc42d881f17 (patch)
tree8d4d555c057b2ca00adab68797a9b814ad5c8891 /willow/lib/IR/Verifier.cpp
parent8d40f659fabdba2d6a17228f76168e7bdbf5c955 (diff)
downloadcompiler-c5b2905c5a64433f8519531a77d3acc42d881f17.tar.gz
[willow]: finish verifier (#7)
Diffstat (limited to 'willow/lib/IR/Verifier.cpp')
-rw-r--r--willow/lib/IR/Verifier.cpp323
1 files changed, 269 insertions, 54 deletions
diff --git a/willow/lib/IR/Verifier.cpp b/willow/lib/IR/Verifier.cpp
index f016bbb..d19bc83 100644
--- a/willow/lib/IR/Verifier.cpp
+++ b/willow/lib/IR/Verifier.cpp
@@ -1,6 +1,7 @@
#include <willow/IR/BasicBlock.h>
#include <willow/IR/Diagnostic.h>
#include <willow/IR/DiagnosticEngine.h>
+#include <willow/IR/Instructions.h>
#include <willow/IR/Module.h>
#include <willow/IR/Verifier.h>
@@ -8,14 +9,30 @@ namespace willow {
/// Verify that an instruction defines an SSA result
LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags);
+/// Verify that an instruction does not define an ssa result
+LogicalResult verifyNoResult(const Instruction &inst, DiagnosticEngine &diags);
/// Verify that an instruction has the expected number of operands
LogicalResult verifyNumOperands(const Instruction &inst,
DiagnosticEngine &diags, std::size_t expected);
-LogicalResult verifyBinaryInst(const Instruction &, DiagnosticEngine &);
+/// Verify operand type
+LogicalResult expectOperandType(const Instruction &inst,
+ DiagnosticEngine &diags, std::size_t opidx,
+ Type expected);
+LogicalResult expectOperandType(const Instruction &inst,
+ DiagnosticEngine &diags, const Value *operand,
+ Type expected);
-LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) {
+LogicalResult expectResultType(const Instruction &inst, DiagnosticEngine &diags,
+ Type expected);
+
+LogicalResult verifyBinaryIntegerInst(const Instruction &, DiagnosticEngine &);
+LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx, const Instruction &,
+ DiagnosticEngine &);
+
+LogicalResult verifyModule(WillowContext &ctx, const Module &module,
+ DiagnosticEngine &diags) {
bool any_failure = false;
for (auto &func : module.getFunctions()) {
@@ -23,7 +40,7 @@ LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) {
DiagnosticEngine eng(
[&](Diagnostic d) { collected.push_back(std::move(d)); });
- auto r = verifyFunction(func, eng);
+ auto r = verifyFunction(ctx, func, eng);
if (succeeded(r))
continue;
@@ -40,12 +57,22 @@ LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) {
return any_failure ? failure() : success();
}
-// LogicalResult verifyFunction(const Function&function, DiagnosticEngine
-// &diags) {
-// bool any_failure = false;
-// }
+LogicalResult verifyFunction(WillowContext &ctx, const Function &function,
+ DiagnosticEngine &diags) {
+ if (function.empty())
+ return success();
+
+ bool has_failed = false;
+ for (auto &block : function.getBlocks()) {
+ if (failed(verifyBasicBlock(ctx, block, diags)))
+ has_failed = true;
+ }
+
+ return has_failed ? failure() : success();
+}
-LogicalResult verifyBasicBlock(const BasicBlock &BB, DiagnosticEngine &diags) {
+LogicalResult verifyBasicBlock(WillowContext &ctx, const BasicBlock &BB,
+ DiagnosticEngine &diags) {
if (BB.empty())
return emit(diags, Severity::Error, BB.getLoc())
<< "Basic block '" << BB.getName() << "' has an empty body";
@@ -55,19 +82,27 @@ LogicalResult verifyBasicBlock(const BasicBlock &BB, DiagnosticEngine &diags) {
<< "Basic block '" << BB.getName()
<< "' does not end with a terminator";
+ bool has_failed = false;
for (auto &inst : BB.getBody()) {
// verify inst
- if (failed(verifyInst(inst, diags)))
- return failure();
+ if (failed(verifyInst(ctx, inst, diags)))
+ has_failed = true;
if (&inst != BB.trailer() && inst.isTerminator())
return emit(diags, Severity::Error, BB.getLoc())
<< "Illegal terminator in the middle of a block";
}
+
+ return has_failed ? failure() : success();
}
-// TODO: better instruction printing
-LogicalResult verifyInst(const Instruction &inst, DiagnosticEngine &diags) {
+/// Verify an instruction. This will stop on the first invariant that fails to
+/// hold.
+LogicalResult verifyInst(WillowContext &ctx, const Instruction &inst,
+ DiagnosticEngine &diags) {
+ const BasicBlock *BB = inst.getParent();
+ const Function *fn = BB ? BB->getParent() : nullptr;
+
using enum Instruction::Opcode;
switch (inst.opcode()) {
case Add:
@@ -81,70 +116,177 @@ LogicalResult verifyInst(const Instruction &inst, DiagnosticEngine &diags) {
case Ashr:
case And:
case Or:
- return verifyBinaryInst(inst, diags);
+ return verifyBinaryIntegerInst(inst, diags);
case Eq:
case Lt:
case Gt:
case Le:
- case Ge: {
+ case Ge:
+ return verifyBinaryIntegerCmp(ctx, inst, diags);
+ case Not: {
Type ty = inst.getType();
- if (!ty.isInt() || ty.getNumBits() != 1)
+
+ if (failed(verifyResult(inst, diags)))
+ return failure();
+
+ const Value *operand = inst.getOperand(0);
+ if (!operand)
return emit(diags, Severity::Error, inst.getLoc())
- << std::format("unexpected result type '{}': compare "
- "instructions return i1",
- ty);
+ << "instruction 'not' requires 1 operand";
- if (failed(verifyNumOperands(inst, diags, 2)))
+ Type oty = operand->getType();
+ if (ty != oty)
+ return emit(diags, Severity::Error, inst.getLoc())
+ << std::format("expected argument type '{}', got '{}'", ty, oty);
+ return success();
+ }
+ case Jmp: {
+ if (failed(verifyNoResult(inst, diags)))
return failure();
- const Value *lhs = inst.getOperand(0);
- const Value *rhs = inst.getOperand(1);
+ if (failed(verifyNumOperands(inst, diags, 1)))
+ return failure();
- Type lty = lhs->getType(), rty = rhs->getType();
+ const BasicBlock *dst = static_cast<const BasicBlock *>(inst.getOperand(0));
- if (!lty.isInt())
- return emit(diags, Severity::Error, inst.getLoc()) << std::format(
- "invalid operand type '{}': expected integral type", lty);
+ if (failed(
+ expectOperandType(inst, diags, dst, ctx.types().BasicBlockType())))
+ return failure();
- if (!rty.isInt())
- return emit(diags, Severity::Error, inst.getLoc()) << std::format(
- "invalid operand type '{}': expected integral type", rty);
+ if (BB && fn) {
+ if (dst->getParent() != fn)
+ return emit(diags, Severity::Error, inst.getLoc()) << std::format(
+ "trying to jump to a block outside of the current function");
+ }
+ return success();
+ }
+ case Br: {
+ if (failed(verifyNoResult(inst, diags)))
+ return failure();
+
+ if (failed(verifyNumOperands(inst, diags, 3)))
+ return failure();
+
+ auto *cond = inst.getOperand(0);
+ auto *truedst = static_cast<const BasicBlock *>(inst.getOperand(1));
+ auto *falsedst = static_cast<const BasicBlock *>(inst.getOperand(2));
+
+ if (failed(expectOperandType(inst, diags, cond, ctx.types().IntType(1))))
+ return failure();
+
+ if (failed(expectOperandType(inst, diags, truedst,
+ ctx.types().BasicBlockType())))
+ return failure();
+
+ if (failed(expectOperandType(inst, diags, falsedst,
+ ctx.types().BasicBlockType())))
+ return failure();
- if (lty != rty)
+ if (BB && fn && (fn != truedst->getParent() || fn != falsedst->getParent()))
return emit(diags, Severity::Error, inst.getLoc())
- << "mismatched operand types";
+ << "branching to a basic block that does not belong to the "
+ "current function";
- break;
+ return success();
}
- case Not: {
- Type ty = inst.getType();
+ case Call: {
+ auto &operands = inst.getOperands();
+ const Function *callee = static_cast<const Function *>(inst.getOperand(0));
- if (failed(verifyResult(inst, diags)))
+ Type rty = callee->getReturnType();
+ auto has_result = (rty != ctx.types().VoidType());
+
+ if (failed((has_result ? verifyResult : verifyNoResult)(inst, diags)))
return failure();
- const Value *operand = inst.getOperand(0);
- if (!operand)
- return emit(diags, Severity::Error, inst.getLoc())
- << "instruction 'not' requires 1 operand";
+ if (failed(expectResultType(inst, diags, callee->getReturnType())))
+ return failure();
- Type oty = operand->getType();
- if (ty != oty)
+ auto args = std::ranges::subrange(operands.begin() + 1, operands.end());
+ auto params = callee->getParams();
+
+ if (args.size() != params.size())
return emit(diags, Severity::Error, inst.getLoc())
- << std::format("expected argument type '{}', got '{}'", ty, oty);
+ << "expected " << params.size()
+ << " operands to match the signature of function '"
+ << callee->getName() << "', got " << operands.size();
+
+ for (const auto &&[arg, param] : std::views::zip(args, params)) {
+ // TODO normalize interface
+ auto aty = arg->getType();
+ auto pty = param.getType();
+
+ if (aty == pty)
+ continue;
+
+ DiagnosticBuilder d(diags, Severity::Error, inst.getLoc());
+ d << "invalid argument: expected '" << pty << "', got '" << aty << "'";
+ if (param.hasName())
+ d.note(Diagnostic{Severity::Remark,
+ std::format("param name: {}", param.getName())});
+
+ return failure();
+ }
+
+ return success();
+ }
+ case Ret: {
+ if (!BB || !fn)
+ return success(); // not much we can say
+
+ bool has_arg = (fn->getReturnType() != ctx.types().VoidType());
+
+ if (!has_arg) {
+ if (failed(verifyNumOperands(inst, diags, 0)))
+ return failure();
+ } else {
+ if (failed(verifyNumOperands(inst, diags, 1)))
+ return failure();
+
+ if (failed(expectOperandType(inst, diags, inst.getOperand(0),
+ fn->getReturnType())))
+ return failure();
+ }
break;
}
- case Jmp:
- case Br:
- case Call:
- case Ret:
- case Phi:
- case Alloca:
+ case Phi: {
+ auto phi = static_cast<const PhiInst *>(&inst);
+ if (phi->getNumOperands() % 2)
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "Expected even number of arguments";
+
+ for (auto [pred, val] : phi->incomings()) {
+ if (!pred->isBasicBlock())
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "Expected basic block";
+
+ if (!BB->preds().contains(const_cast<BasicBlock *>(pred)))
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "Incoming phi edge is not a predecessor";
+
+ if (BB && fn && (pred->getParent() != fn))
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "basic block: '" << pred->getName()
+ << "' is not a child of function '" << fn->getName() << "'";
+ }
+
+ return success();
+ }
+ case Alloca: {
+ Type vty = inst.getType();
+ if (!vty.isPtr())
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "expected alloca to produce a pointer";
+
+ return success();
}
+ }
+
+ return success();
}
-// TODO: better naming?
-LogicalResult verifyBinaryInst(const Instruction &inst,
- DiagnosticEngine &diags) {
+LogicalResult verifyBinaryIntegerInst(const Instruction &inst,
+ DiagnosticEngine &diags) {
Type ty = inst.getType();
// TODO non scalars
@@ -159,9 +301,6 @@ LogicalResult verifyBinaryInst(const Instruction &inst,
auto *lhs = inst.getOperand(0);
auto *rhs = inst.getOperand(1);
- assert(lhs && "Binary op lacks LHS");
- assert(rhs && "Binary op needs RHS");
-
if (lhs->getType() != ty) {
return emit(diags, Severity::Error, inst.getLoc()) << std::format(
"expected operand type '{}' got '{}'", ty, lhs->getType());
@@ -171,6 +310,37 @@ LogicalResult verifyBinaryInst(const Instruction &inst,
return emit(diags, Severity::Error, inst.getLoc()) << std::format(
"expected operand type '{}' got '{}'", ty, rhs->getType());
}
+
+ return success();
+}
+
+LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx,
+ const Instruction &inst,
+ DiagnosticEngine &diags) {
+ if (failed(expectResultType(inst, diags, ctx.types().IntType(1))))
+ return failure();
+
+ if (failed(verifyNumOperands(inst, diags, 2)))
+ return failure();
+
+ const Value *lhs = inst.getOperand(0);
+ const Value *rhs = inst.getOperand(1);
+
+ Type lty = lhs->getType(), rty = rhs->getType();
+
+ if (!lty.isInt())
+ return emit(diags, Severity::Error, inst.getLoc()) << std::format(
+ "invalid operand type '{}': expected integral type", lty);
+
+ if (!rty.isInt())
+ return emit(diags, Severity::Error, inst.getLoc()) << std::format(
+ "invalid operand type '{}': expected integral type", rty);
+
+ if (lty != rty)
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "mismatched operand types";
+
+ return success();
}
LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) {
@@ -180,15 +350,60 @@ LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) {
return emit(diags, Severity::Error, inst.getLoc()) << "expected ssa result";
}
+LogicalResult verifyNoResult(const Instruction &inst, DiagnosticEngine &diags) {
+ if (!inst.hasName())
+ return success();
+
+ return emit(diags, Severity::Error, inst.getLoc()) << "unexpected ssa result";
+}
+
LogicalResult verifyNumOperands(const Instruction &inst,
DiagnosticEngine &diags, std::size_t expected) {
std::size_t num_operands = inst.getNumOperands();
if (num_operands != expected)
return emit(diags, Severity::Error, inst.getLoc())
- << std::format("'{}':, expected {} operands, got {}", inst.opcode(),
+ << std::format("wrong number of operands: expected {}, found {}",
expected, num_operands);
return success();
}
+LogicalResult expectOperandType(const Instruction &inst,
+ DiagnosticEngine &diags, std::size_t opidx,
+ Type expected) {
+ auto *operand = inst.getOperand(opidx);
+
+ assert(operand && "expected operand");
+
+ if (operand->getType() == expected)
+ return success();
+
+ return emit(diags, Severity::Error, inst.getLoc())
+ << std::format("expected operand #{} to be of type '{}', but got '{}'",
+ opidx, expected, operand->getType());
+}
+
+LogicalResult expectOperandType(const Instruction &inst,
+ DiagnosticEngine &diags, const Value *operand,
+ Type expected) {
+ assert(operand && "expected operand");
+
+ auto ty = operand->getType();
+ if (ty == expected)
+ return success();
+
+ return emit(diags, Severity::Error, inst.getLoc()) << std::format(
+ "unexpected operand type '{}': expected '{}'", ty, expected);
+}
+
+LogicalResult expectResultType(const Instruction &inst, DiagnosticEngine &diags,
+ Type expected) {
+ auto ty = inst.getType();
+ if (ty == expected)
+ return success();
+
+ return emit(diags, Severity::Error) << std::format(
+ "unexpected result type: expected '{}', found '{}'", expected, ty);
+}
+
} // namespace willow