diff options
Diffstat (limited to 'willow/lib/IR')
| -rw-r--r-- | willow/lib/IR/Verifier.cpp | 323 |
1 files changed, 269 insertions, 54 deletions
diff --git a/willow/lib/IR/Verifier.cpp b/willow/lib/IR/Verifier.cpp index f016bbb..d19bc83 100644 --- a/willow/lib/IR/Verifier.cpp +++ b/willow/lib/IR/Verifier.cpp @@ -1,6 +1,7 @@ #include <willow/IR/BasicBlock.h> #include <willow/IR/Diagnostic.h> #include <willow/IR/DiagnosticEngine.h> +#include <willow/IR/Instructions.h> #include <willow/IR/Module.h> #include <willow/IR/Verifier.h> @@ -8,14 +9,30 @@ namespace willow { /// Verify that an instruction defines an SSA result LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags); +/// Verify that an instruction does not define an ssa result +LogicalResult verifyNoResult(const Instruction &inst, DiagnosticEngine &diags); /// Verify that an instruction has the expected number of operands LogicalResult verifyNumOperands(const Instruction &inst, DiagnosticEngine &diags, std::size_t expected); -LogicalResult verifyBinaryInst(const Instruction &, DiagnosticEngine &); +/// Verify operand type +LogicalResult expectOperandType(const Instruction &inst, + DiagnosticEngine &diags, std::size_t opidx, + Type expected); +LogicalResult expectOperandType(const Instruction &inst, + DiagnosticEngine &diags, const Value *operand, + Type expected); -LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) { +LogicalResult expectResultType(const Instruction &inst, DiagnosticEngine &diags, + Type expected); + +LogicalResult verifyBinaryIntegerInst(const Instruction &, DiagnosticEngine &); +LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx, const Instruction &, + DiagnosticEngine &); + +LogicalResult verifyModule(WillowContext &ctx, const Module &module, + DiagnosticEngine &diags) { bool any_failure = false; for (auto &func : module.getFunctions()) { @@ -23,7 +40,7 @@ LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) { DiagnosticEngine eng( [&](Diagnostic d) { collected.push_back(std::move(d)); }); - auto r = verifyFunction(func, eng); + auto r = verifyFunction(ctx, func, eng); if (succeeded(r)) continue; @@ -40,12 +57,22 @@ LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) { return any_failure ? failure() : success(); } -// LogicalResult verifyFunction(const Function&function, DiagnosticEngine -// &diags) { -// bool any_failure = false; -// } +LogicalResult verifyFunction(WillowContext &ctx, const Function &function, + DiagnosticEngine &diags) { + if (function.empty()) + return success(); + + bool has_failed = false; + for (auto &block : function.getBlocks()) { + if (failed(verifyBasicBlock(ctx, block, diags))) + has_failed = true; + } + + return has_failed ? failure() : success(); +} -LogicalResult verifyBasicBlock(const BasicBlock &BB, DiagnosticEngine &diags) { +LogicalResult verifyBasicBlock(WillowContext &ctx, const BasicBlock &BB, + DiagnosticEngine &diags) { if (BB.empty()) return emit(diags, Severity::Error, BB.getLoc()) << "Basic block '" << BB.getName() << "' has an empty body"; @@ -55,19 +82,27 @@ LogicalResult verifyBasicBlock(const BasicBlock &BB, DiagnosticEngine &diags) { << "Basic block '" << BB.getName() << "' does not end with a terminator"; + bool has_failed = false; for (auto &inst : BB.getBody()) { // verify inst - if (failed(verifyInst(inst, diags))) - return failure(); + if (failed(verifyInst(ctx, inst, diags))) + has_failed = true; if (&inst != BB.trailer() && inst.isTerminator()) return emit(diags, Severity::Error, BB.getLoc()) << "Illegal terminator in the middle of a block"; } + + return has_failed ? failure() : success(); } -// TODO: better instruction printing -LogicalResult verifyInst(const Instruction &inst, DiagnosticEngine &diags) { +/// Verify an instruction. This will stop on the first invariant that fails to +/// hold. +LogicalResult verifyInst(WillowContext &ctx, const Instruction &inst, + DiagnosticEngine &diags) { + const BasicBlock *BB = inst.getParent(); + const Function *fn = BB ? BB->getParent() : nullptr; + using enum Instruction::Opcode; switch (inst.opcode()) { case Add: @@ -81,70 +116,177 @@ LogicalResult verifyInst(const Instruction &inst, DiagnosticEngine &diags) { case Ashr: case And: case Or: - return verifyBinaryInst(inst, diags); + return verifyBinaryIntegerInst(inst, diags); case Eq: case Lt: case Gt: case Le: - case Ge: { + case Ge: + return verifyBinaryIntegerCmp(ctx, inst, diags); + case Not: { Type ty = inst.getType(); - if (!ty.isInt() || ty.getNumBits() != 1) + + if (failed(verifyResult(inst, diags))) + return failure(); + + const Value *operand = inst.getOperand(0); + if (!operand) return emit(diags, Severity::Error, inst.getLoc()) - << std::format("unexpected result type '{}': compare " - "instructions return i1", - ty); + << "instruction 'not' requires 1 operand"; - if (failed(verifyNumOperands(inst, diags, 2))) + Type oty = operand->getType(); + if (ty != oty) + return emit(diags, Severity::Error, inst.getLoc()) + << std::format("expected argument type '{}', got '{}'", ty, oty); + return success(); + } + case Jmp: { + if (failed(verifyNoResult(inst, diags))) return failure(); - const Value *lhs = inst.getOperand(0); - const Value *rhs = inst.getOperand(1); + if (failed(verifyNumOperands(inst, diags, 1))) + return failure(); - Type lty = lhs->getType(), rty = rhs->getType(); + const BasicBlock *dst = static_cast<const BasicBlock *>(inst.getOperand(0)); - if (!lty.isInt()) - return emit(diags, Severity::Error, inst.getLoc()) << std::format( - "invalid operand type '{}': expected integral type", lty); + if (failed( + expectOperandType(inst, diags, dst, ctx.types().BasicBlockType()))) + return failure(); - if (!rty.isInt()) - return emit(diags, Severity::Error, inst.getLoc()) << std::format( - "invalid operand type '{}': expected integral type", rty); + if (BB && fn) { + if (dst->getParent() != fn) + return emit(diags, Severity::Error, inst.getLoc()) << std::format( + "trying to jump to a block outside of the current function"); + } + return success(); + } + case Br: { + if (failed(verifyNoResult(inst, diags))) + return failure(); + + if (failed(verifyNumOperands(inst, diags, 3))) + return failure(); + + auto *cond = inst.getOperand(0); + auto *truedst = static_cast<const BasicBlock *>(inst.getOperand(1)); + auto *falsedst = static_cast<const BasicBlock *>(inst.getOperand(2)); + + if (failed(expectOperandType(inst, diags, cond, ctx.types().IntType(1)))) + return failure(); + + if (failed(expectOperandType(inst, diags, truedst, + ctx.types().BasicBlockType()))) + return failure(); + + if (failed(expectOperandType(inst, diags, falsedst, + ctx.types().BasicBlockType()))) + return failure(); - if (lty != rty) + if (BB && fn && (fn != truedst->getParent() || fn != falsedst->getParent())) return emit(diags, Severity::Error, inst.getLoc()) - << "mismatched operand types"; + << "branching to a basic block that does not belong to the " + "current function"; - break; + return success(); } - case Not: { - Type ty = inst.getType(); + case Call: { + auto &operands = inst.getOperands(); + const Function *callee = static_cast<const Function *>(inst.getOperand(0)); - if (failed(verifyResult(inst, diags))) + Type rty = callee->getReturnType(); + auto has_result = (rty != ctx.types().VoidType()); + + if (failed((has_result ? verifyResult : verifyNoResult)(inst, diags))) return failure(); - const Value *operand = inst.getOperand(0); - if (!operand) - return emit(diags, Severity::Error, inst.getLoc()) - << "instruction 'not' requires 1 operand"; + if (failed(expectResultType(inst, diags, callee->getReturnType()))) + return failure(); - Type oty = operand->getType(); - if (ty != oty) + auto args = std::ranges::subrange(operands.begin() + 1, operands.end()); + auto params = callee->getParams(); + + if (args.size() != params.size()) return emit(diags, Severity::Error, inst.getLoc()) - << std::format("expected argument type '{}', got '{}'", ty, oty); + << "expected " << params.size() + << " operands to match the signature of function '" + << callee->getName() << "', got " << operands.size(); + + for (const auto &&[arg, param] : std::views::zip(args, params)) { + // TODO normalize interface + auto aty = arg->getType(); + auto pty = param.getType(); + + if (aty == pty) + continue; + + DiagnosticBuilder d(diags, Severity::Error, inst.getLoc()); + d << "invalid argument: expected '" << pty << "', got '" << aty << "'"; + if (param.hasName()) + d.note(Diagnostic{Severity::Remark, + std::format("param name: {}", param.getName())}); + + return failure(); + } + + return success(); + } + case Ret: { + if (!BB || !fn) + return success(); // not much we can say + + bool has_arg = (fn->getReturnType() != ctx.types().VoidType()); + + if (!has_arg) { + if (failed(verifyNumOperands(inst, diags, 0))) + return failure(); + } else { + if (failed(verifyNumOperands(inst, diags, 1))) + return failure(); + + if (failed(expectOperandType(inst, diags, inst.getOperand(0), + fn->getReturnType()))) + return failure(); + } break; } - case Jmp: - case Br: - case Call: - case Ret: - case Phi: - case Alloca: + case Phi: { + auto phi = static_cast<const PhiInst *>(&inst); + if (phi->getNumOperands() % 2) + return emit(diags, Severity::Error, inst.getLoc()) + << "Expected even number of arguments"; + + for (auto [pred, val] : phi->incomings()) { + if (!pred->isBasicBlock()) + return emit(diags, Severity::Error, inst.getLoc()) + << "Expected basic block"; + + if (!BB->preds().contains(const_cast<BasicBlock *>(pred))) + return emit(diags, Severity::Error, inst.getLoc()) + << "Incoming phi edge is not a predecessor"; + + if (BB && fn && (pred->getParent() != fn)) + return emit(diags, Severity::Error, inst.getLoc()) + << "basic block: '" << pred->getName() + << "' is not a child of function '" << fn->getName() << "'"; + } + + return success(); + } + case Alloca: { + Type vty = inst.getType(); + if (!vty.isPtr()) + return emit(diags, Severity::Error, inst.getLoc()) + << "expected alloca to produce a pointer"; + + return success(); } + } + + return success(); } -// TODO: better naming? -LogicalResult verifyBinaryInst(const Instruction &inst, - DiagnosticEngine &diags) { +LogicalResult verifyBinaryIntegerInst(const Instruction &inst, + DiagnosticEngine &diags) { Type ty = inst.getType(); // TODO non scalars @@ -159,9 +301,6 @@ LogicalResult verifyBinaryInst(const Instruction &inst, auto *lhs = inst.getOperand(0); auto *rhs = inst.getOperand(1); - assert(lhs && "Binary op lacks LHS"); - assert(rhs && "Binary op needs RHS"); - if (lhs->getType() != ty) { return emit(diags, Severity::Error, inst.getLoc()) << std::format( "expected operand type '{}' got '{}'", ty, lhs->getType()); @@ -171,6 +310,37 @@ LogicalResult verifyBinaryInst(const Instruction &inst, return emit(diags, Severity::Error, inst.getLoc()) << std::format( "expected operand type '{}' got '{}'", ty, rhs->getType()); } + + return success(); +} + +LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx, + const Instruction &inst, + DiagnosticEngine &diags) { + if (failed(expectResultType(inst, diags, ctx.types().IntType(1)))) + return failure(); + + if (failed(verifyNumOperands(inst, diags, 2))) + return failure(); + + const Value *lhs = inst.getOperand(0); + const Value *rhs = inst.getOperand(1); + + Type lty = lhs->getType(), rty = rhs->getType(); + + if (!lty.isInt()) + return emit(diags, Severity::Error, inst.getLoc()) << std::format( + "invalid operand type '{}': expected integral type", lty); + + if (!rty.isInt()) + return emit(diags, Severity::Error, inst.getLoc()) << std::format( + "invalid operand type '{}': expected integral type", rty); + + if (lty != rty) + return emit(diags, Severity::Error, inst.getLoc()) + << "mismatched operand types"; + + return success(); } LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) { @@ -180,15 +350,60 @@ LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) { return emit(diags, Severity::Error, inst.getLoc()) << "expected ssa result"; } +LogicalResult verifyNoResult(const Instruction &inst, DiagnosticEngine &diags) { + if (!inst.hasName()) + return success(); + + return emit(diags, Severity::Error, inst.getLoc()) << "unexpected ssa result"; +} + LogicalResult verifyNumOperands(const Instruction &inst, DiagnosticEngine &diags, std::size_t expected) { std::size_t num_operands = inst.getNumOperands(); if (num_operands != expected) return emit(diags, Severity::Error, inst.getLoc()) - << std::format("'{}':, expected {} operands, got {}", inst.opcode(), + << std::format("wrong number of operands: expected {}, found {}", expected, num_operands); return success(); } +LogicalResult expectOperandType(const Instruction &inst, + DiagnosticEngine &diags, std::size_t opidx, + Type expected) { + auto *operand = inst.getOperand(opidx); + + assert(operand && "expected operand"); + + if (operand->getType() == expected) + return success(); + + return emit(diags, Severity::Error, inst.getLoc()) + << std::format("expected operand #{} to be of type '{}', but got '{}'", + opidx, expected, operand->getType()); +} + +LogicalResult expectOperandType(const Instruction &inst, + DiagnosticEngine &diags, const Value *operand, + Type expected) { + assert(operand && "expected operand"); + + auto ty = operand->getType(); + if (ty == expected) + return success(); + + return emit(diags, Severity::Error, inst.getLoc()) << std::format( + "unexpected operand type '{}': expected '{}'", ty, expected); +} + +LogicalResult expectResultType(const Instruction &inst, DiagnosticEngine &diags, + Type expected) { + auto ty = inst.getType(); + if (ty == expected) + return success(); + + return emit(diags, Severity::Error) << std::format( + "unexpected result type: expected '{}', found '{}'", expected, ty); +} + } // namespace willow |