// NOTE(review): the six header names on the next line were not visible when
// this block was reviewed; restore them from the original file.
#include #include #include #include #include #include

namespace willow {

/// Verify that an instruction defines an SSA result.
LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags);

/// Verify that an instruction does not define an SSA result.
LogicalResult verifyNoResult(const Instruction &inst, DiagnosticEngine &diags);

/// Verify that an instruction has the expected number of operands.
LogicalResult verifyNumOperands(const Instruction &inst,
                                DiagnosticEngine &diags, std::size_t expected);

/// Verify that the operand at index `opidx` has type `expected`.
LogicalResult expectOperandType(const Instruction &inst,
                                DiagnosticEngine &diags, std::size_t opidx,
                                Type expected);

/// Verify that `operand` (which must be non-null) has type `expected`.
LogicalResult expectOperandType(const Instruction &inst,
                                DiagnosticEngine &diags, const Value *operand,
                                Type expected);

/// Verify that the instruction's own type is `expected`.
LogicalResult expectResultType(const Instruction &inst,
                               DiagnosticEngine &diags, Type expected);

/// Verify a two-operand integer instruction whose operands share its type.
LogicalResult verifyBinaryIntegerInst(const Instruction &, DiagnosticEngine &);

/// Verify a two-operand integer comparison.
LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx, const Instruction &,
                                     DiagnosticEngine &);

/// Verify every function in `module`.
///
/// Diagnostics raised while verifying one function are buffered; if that
/// function fails, a single error naming the function is emitted on `diags`
/// with the buffered diagnostics attached as notes. All functions are visited
/// even after a failure.
///
/// \returns failure() if at least one function failed, success() otherwise.
LogicalResult verifyModule(WillowContext &ctx, const Module &module,
                           DiagnosticEngine &diags) {
  bool any_failure = false;

  for (auto &func : module.getFunctions()) {
    // Collect this function's diagnostics locally so they can be nested under
    // one per-function error instead of being reported flat.
    std::vector<Diagnostic> collected;
    DiagnosticEngine eng(
        [&](Diagnostic d) { collected.push_back(std::move(d)); });

    auto r = verifyFunction(ctx, func, eng);
    if (succeeded(r))
      continue;

    any_failure = true;
    auto diag = emit(diags, Severity::Error, std::nullopt);
    diag << "verification failed for function: '" << func.getName() << "'";
    for (auto &d : collected)
      diag.note(std::move(d));
  }

  return any_failure ? failure() : success();
}

/// Verify every basic block of `function`.
///
/// A function for which `empty()` is true is accepted without further checks.
/// All blocks are visited even after one of them fails.
LogicalResult verifyFunction(WillowContext &ctx, const Function &function,
                             DiagnosticEngine &diags) {
  if (function.empty())
    return success();

  bool has_failed = false;
  for (auto &block : function.getBlocks()) {
    if (failed(verifyBasicBlock(ctx, block, diags)))
      has_failed = true;
  }

  return has_failed ? failure() : success();
}

/// Verify a basic block.
///
/// The block must be non-empty and its trailer must be a terminator. Each
/// instruction in the body is then verified; encountering a terminator that is
/// not the trailer reports an error and stops immediately.
LogicalResult verifyBasicBlock(WillowContext &ctx, const BasicBlock &BB,
                               DiagnosticEngine &diags) {
  if (BB.empty())
    return emit(diags, Severity::Error, BB.getLoc())
           << "Basic block '" << BB.getName() << "' has an empty body";

  if (!BB.trailer()->isTerminator())
    return emit(diags, Severity::Error, BB.getLoc())
           << "Basic block '" << BB.getName()
           << "' does not end with a terminator";

  bool has_failed = false;
  for (auto &inst : BB.getBody()) {
    if (failed(verifyInst(ctx, inst, diags)))
      has_failed = true;

    if (&inst != BB.trailer() && inst.isTerminator())
      return emit(diags, Severity::Error, BB.getLoc())
             << "Illegal terminator in the middle of a block";
  }

  return has_failed ? failure() : success();
}

/// Verify an instruction. This will stop on the first invariant that fails to
/// hold.
///
/// Checks that need the enclosing block or function are skipped when the
/// instruction has no parent block (or the block has no parent function).
LogicalResult verifyInst(WillowContext &ctx, const Instruction &inst,
                         DiagnosticEngine &diags) {
  const BasicBlock *BB = inst.getParent();
  const Function *fn = BB ? BB->getParent() : nullptr;

  using enum Instruction::Opcode;
  switch (inst.opcode()) {
  case Add:
  case Mul:
  case Sub:
  case Div:
  case Mod:
  case Shl:
  case Shr:
  case Ashl:
  case Ashr:
  case And:
  case Or:
    return verifyBinaryIntegerInst(inst, diags);

  case Eq:
  case Lt:
  case Gt:
  case Le:
  case Ge:
    return verifyBinaryIntegerCmp(ctx, inst, diags);

  case Not: {
    Type ty = inst.getType();
    if (failed(verifyResult(inst, diags)))
      return failure();

    const Value *operand = inst.getOperand(0);
    if (!operand)
      return emit(diags, Severity::Error, inst.getLoc())
             << "instruction 'not' requires 1 operand";

    // The result has the same type as the operand.
    Type oty = operand->getType();
    if (ty != oty)
      return emit(diags, Severity::Error, inst.getLoc())
             << std::format("expected argument type '{}', got '{}'", ty, oty);

    return success();
  }

  case Jmp: {
    if (failed(verifyNoResult(inst, diags)))
      return failure();
    if (failed(verifyNumOperands(inst, diags, 1)))
      return failure();

    // Check the operand's type through the Value interface first; only then
    // is the downcast to BasicBlock known to be valid.
    const Value *target = inst.getOperand(0);
    if (failed(expectOperandType(inst, diags, target,
                                 ctx.types().BasicBlockType())))
      return failure();
    const auto *dst = static_cast<const BasicBlock *>(target);

    if (BB && fn) {
      if (dst->getParent() != fn)
        return emit(diags, Severity::Error, inst.getLoc()) << std::format(
                   "trying to jump to a block outside of the current function");
    }
    return success();
  }

  case Br: {
    if (failed(verifyNoResult(inst, diags)))
      return failure();
    if (failed(verifyNumOperands(inst, diags, 3)))
      return failure();

    const Value *cond = inst.getOperand(0);
    const Value *true_target = inst.getOperand(1);
    const Value *false_target = inst.getOperand(2);

    if (failed(expectOperandType(inst, diags, cond, ctx.types().IntType(1))))
      return failure();
    if (failed(expectOperandType(inst, diags, true_target,
                                 ctx.types().BasicBlockType())))
      return failure();
    if (failed(expectOperandType(inst, diags, false_target,
                                 ctx.types().BasicBlockType())))
      return failure();

    // Both targets passed the type check above, so the downcasts are valid.
    const auto *truedst = static_cast<const BasicBlock *>(true_target);
    const auto *falsedst = static_cast<const BasicBlock *>(false_target);

    if (BB && fn &&
        (fn != truedst->getParent() || fn != falsedst->getParent()))
      return emit(diags, Severity::Error, inst.getLoc())
             << "branching to a basic block that does not belong to the "
                "current function";

    return success();
  }

  case Call: {
    // Operand 0 is the callee; the remaining operands are the arguments.
    // Reject a call without a callee before touching it: both the dereference
    // below and `operands.begin() + 1` require at least one operand.
    if (inst.getNumOperands() == 0)
      return emit(diags, Severity::Error, inst.getLoc())
             << "instruction 'call' requires a callee operand";

    auto &operands = inst.getOperands();
    const Function *callee = static_cast<const Function *>(inst.getOperand(0));
    if (!callee)
      return emit(diags, Severity::Error, inst.getLoc())
             << "instruction 'call' requires a callee operand";

    // A call produces a result exactly when the callee returns non-void.
    Type rty = callee->getReturnType();
    auto has_result = (rty != ctx.types().VoidType());
    if (failed((has_result ? verifyResult : verifyNoResult)(inst, diags)))
      return failure();
    if (failed(expectResultType(inst, diags, rty)))
      return failure();

    auto args = std::ranges::subrange(operands.begin() + 1, operands.end());
    auto params = callee->getParams();

    // Report the argument count (excluding the callee) so that the "got"
    // number is comparable with the parameter count.
    if (args.size() != params.size())
      return emit(diags, Severity::Error, inst.getLoc())
             << "expected " << params.size()
             << " operands to match the signature of function '"
             << callee->getName() << "', got " << args.size();

    for (const auto &&[arg, param] : std::views::zip(args, params)) {
      // TODO normalize interface
      auto aty = arg->getType();
      auto pty = param.getType();
      if (aty == pty)
        continue;

      DiagnosticBuilder d(diags, Severity::Error, inst.getLoc());
      d << "invalid argument: expected '" << pty << "', got '" << aty << "'";
      if (param.hasName())
        d.note(Diagnostic{Severity::Remark,
                          std::format("param name: {}", param.getName())});
      return failure();
    }
    return success();
  }

  case Ret: {
    if (!BB || !fn)
      return success(); // not much we can say

    // A void function returns no operand; any other function returns exactly
    // one operand of the function's return type.
    bool has_arg = (fn->getReturnType() != ctx.types().VoidType());
    if (!has_arg) {
      if (failed(verifyNumOperands(inst, diags, 0)))
        return failure();
    } else {
      if (failed(verifyNumOperands(inst, diags, 1)))
        return failure();
      if (failed(expectOperandType(inst, diags, inst.getOperand(0),
                                   fn->getReturnType())))
        return failure();
    }
    break;
  }

  case Phi: {
    // NOTE(review): the cast target type was not visible in the reviewed
    // text; `PhiInst` is assumed -- confirm against the IR headers.
    auto phi = static_cast<const PhiInst *>(&inst);

    if (phi->getNumOperands() % 2)
      return emit(diags, Severity::Error, inst.getLoc())
             << "Expected even number of arguments";

    for (auto [pred, val] : phi->incomings()) {
      if (!pred->isBasicBlock())
        return emit(diags, Severity::Error, inst.getLoc())
               << "Expected basic block";

      // The predecessor check needs the parent block; skip it for a detached
      // instruction instead of dereferencing a null block.
      if (BB && !BB->preds().contains(const_cast<BasicBlock *>(pred)))
        return emit(diags, Severity::Error, inst.getLoc())
               << "Incoming phi edge is not a predecessor";

      if (BB && fn && (pred->getParent() != fn))
        return emit(diags, Severity::Error, inst.getLoc())
               << "basic block: '" << pred->getName()
               << "' is not a child of function '" << fn->getName() << "'";
    }
    return success();
  }

  case Alloca: {
    Type vty = inst.getType();
    if (!vty.isPtr())
      return emit(diags, Severity::Error, inst.getLoc())
             << "expected alloca to produce a pointer";
    return success();
  }
  }

  return success();
}

/// Verify a binary integer instruction: the instruction's type must be an
/// integer type, it must have exactly two operands, and both operands must
/// have the instruction's type.
LogicalResult verifyBinaryIntegerInst(const Instruction &inst,
                                      DiagnosticEngine &diags) {
  Type ty = inst.getType();

  // TODO non scalars
  if (!ty.isInt())
    return emit(diags, Severity::Error, inst.getLoc())
           << "invalid instruction '" << inst << "': "
           << "expected an integral type, got '" << ty << "'";

  if (failed(verifyNumOperands(inst, diags, 2)))
    return failure();

  auto *lhs = inst.getOperand(0);
  auto *rhs = inst.getOperand(1);

  if (lhs->getType() != ty) {
    return emit(diags, Severity::Error, inst.getLoc()) << std::format(
               "expected operand type '{}' got '{}'", ty, lhs->getType());
  }

  if (rhs->getType() != ty) {
    return emit(diags, Severity::Error, inst.getLoc()) << std::format(
               "expected operand type '{}' got '{}'", ty, rhs->getType());
  }

  return success();
}

/// Verify a binary integer comparison: the result type must be `IntType(1)`,
/// there must be exactly two operands, and both operands must be integers of
/// the same type.
LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx,
                                     const Instruction &inst,
                                     DiagnosticEngine &diags) {
  if (failed(expectResultType(inst, diags, ctx.types().IntType(1))))
    return failure();

  if (failed(verifyNumOperands(inst, diags, 2)))
    return failure();

  const Value *lhs = inst.getOperand(0);
  const Value *rhs = inst.getOperand(1);

  Type lty = lhs->getType(), rty = rhs->getType();

  if (!lty.isInt())
    return emit(diags, Severity::Error, inst.getLoc()) << std::format(
               "invalid operand type '{}': expected integral type", lty);

  if (!rty.isInt())
    return emit(diags, Severity::Error, inst.getLoc()) << std::format(
               "invalid operand type '{}': expected integral type", rty);

  if (lty != rty)
    return emit(diags, Severity::Error, inst.getLoc())
           << "mismatched operand types";

  return success();
}

/// Succeeds when `inst.hasName()` is true; otherwise reports
/// "expected ssa result" at the instruction's location.
LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) {
  if (inst.hasName())
    return success();

  return emit(diags, Severity::Error, inst.getLoc()) << "expected ssa result";
}

/// Succeeds when `inst.hasName()` is false; otherwise reports
/// "unexpected ssa result" at the instruction's location.
LogicalResult verifyNoResult(const Instruction &inst,
                             DiagnosticEngine &diags) {
  if (!inst.hasName())
    return success();

  return emit(diags, Severity::Error, inst.getLoc())
         << "unexpected ssa result";
}

/// Succeeds when the instruction has exactly `expected` operands; otherwise
/// reports the expected and actual counts.
LogicalResult verifyNumOperands(const Instruction &inst,
                                DiagnosticEngine &diags,
                                std::size_t expected) {
  std::size_t num_operands = inst.getNumOperands();
  if (num_operands != expected)
    return emit(diags, Severity::Error, inst.getLoc())
           << std::format("wrong number of operands: expected {}, found {}",
                          expected, num_operands);

  return success();
}

/// Succeeds when operand `opidx` has type `expected`. The operand must exist
/// (asserted).
LogicalResult expectOperandType(const Instruction &inst,
                                DiagnosticEngine &diags, std::size_t opidx,
                                Type expected) {
  auto *operand = inst.getOperand(opidx);
  assert(operand && "expected operand");

  if (operand->getType() == expected)
    return success();

  return emit(diags, Severity::Error, inst.getLoc())
         << std::format("expected operand #{} to be of type '{}', but got '{}'",
                        opidx, expected, operand->getType());
}

/// Succeeds when `operand` has type `expected`. `operand` must be non-null
/// (asserted). Errors are reported at the location of `inst`.
LogicalResult expectOperandType(const Instruction &inst,
                                DiagnosticEngine &diags, const Value *operand,
                                Type expected) {
  assert(operand && "expected operand");

  auto ty = operand->getType();
  if (ty == expected)
    return success();

  return emit(diags, Severity::Error, inst.getLoc()) << std::format(
             "unexpected operand type '{}': expected '{}'", ty, expected);
}

/// Succeeds when the instruction's type is `expected`; otherwise reports the
/// mismatch at the instruction's location.
LogicalResult expectResultType(const Instruction &inst,
                               DiagnosticEngine &diags, Type expected) {
  auto ty = inst.getType();
  if (ty == expected)
    return success();

  // Attach the instruction's location, as every other instruction diagnostic
  // in this file does.
  return emit(diags, Severity::Error, inst.getLoc()) << std::format(
             "unexpected result type: expected '{}', found '{}'", expected, ty);
}

} // namespace willow