summaryrefslogtreecommitdiff
path: root/willow/lib/IR/Verifier.cpp
blob: b692b2fd620a09a3229a3eeeccc45541b7a15a32 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#include <willow/IR/BasicBlock.h>
#include <willow/IR/Diagnostic.h>
#include <willow/IR/DiagnosticEngine.h>
#include <willow/IR/Module.h>
#include <willow/IR/Verifier.h>

#include <cassert>
#include <cstddef>
#include <format>
#include <optional>
#include <utility>
#include <vector>

namespace willow {

/// Verify that an instruction defines an SSA result, i.e. that the
/// instruction has a name (Instruction::hasName()).
LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags);

/// Verify that `inst` has exactly `expected` operands
/// (inst.getNumOperands() == expected), reporting a problem to `diags`
/// otherwise.
LogicalResult verifyNumOperands(const Instruction &inst,
                                DiagnosticEngine &diags, std::size_t expected);

/// Verify every function in `module`.
///
/// Each function is checked against its own DiagnosticEngine whose handler
/// only buffers diagnostics. For every function that fails, one error naming
/// the function is emitted to `diags` and the buffered diagnostics are
/// attached to it as notes. Verification continues past a failing function so
/// that all failures are reported.
///
/// Returns failure() if any function failed verification, success() otherwise.
LogicalResult verifyModule(const Module &module, DiagnosticEngine &diags) {
  bool all_passed = true;

  for (auto &fn : module.getFunctions()) {
    // Declared before `sink` so the engine (whose handler captures this
    // vector by reference) is destroyed first.
    std::vector<Diagnostic> buffered;
    DiagnosticEngine sink(
        [&](Diagnostic d) { buffered.push_back(std::move(d)); });

    if (succeeded(verifyFunction(fn, sink)))
      continue;

    all_passed = false;

    auto summary = emit(diags, Severity::Error, std::nullopt);
    summary << "verification failed for function: '" << fn.getName() << "'";
    for (auto &note : buffered)
      summary.note(std::move(note));
  }

  if (all_passed)
    return success();
  return failure();
}

// LogicalResult verifyFunction(const Function&function, DiagnosticEngine
// &diags) {
//   bool any_failure = false;
// }

LogicalResult verifyInst(const Instruction &inst, DiagnosticEngine &diags);

/// Verify a basic block.
///
/// An empty block is rejected immediately. Otherwise the block's last
/// instruction (BB.trailer()) must be a terminator, and every instruction in
/// the body is checked with verifyInst. All problems found in a non-empty
/// block are reported, not just the first one.
///
/// Returns failure() if any check failed, success() otherwise.
LogicalResult verifyBasicBlock(const BasicBlock &BB, DiagnosticEngine &diags) {
  if (BB.empty())
    return emit(diags, Severity::Error, BB.getLoc())
           << "Basic block '" << BB.getName() << "' has an empty body";

  bool any_failure = false;

  // trailer() is only queried once the block is known to be non-empty.
  if (!BB.trailer()->isTerminator()) {
    any_failure = true;
    auto diag = emit(diags, Severity::Error, BB.getLoc());
    diag << "Basic block '" << BB.getName()
         << "' does not end with a terminator";
  }

  // Keep going after a failure so every bad instruction gets a diagnostic.
  for (auto &inst : BB.getBody()) {
    if (failed(verifyInst(inst, diags)))
      any_failure = true;
  }

  return any_failure ? failure() : success();
}

LogicalResult verifyBinaryInst(const Instruction &inst,
                               DiagnosticEngine &diags);

/// Verify a single instruction according to its opcode.
///
/// - Arithmetic, shift and bitwise opcodes are delegated to verifyBinaryInst.
/// - Comparisons (eq/lt/gt/le/ge) must produce an i1 and take exactly two
///   integer operands of the same type.
/// - `not` must define an SSA result and take exactly one operand whose type
///   equals the result type.
/// - Jmp, Br, Call, Ret, Phi and Alloca have no checks yet and are accepted.
///
/// Emits one error to `diags` for the first problem found and returns
/// failure(); returns success() otherwise.
// TODO: better instruction printing
LogicalResult verifyInst(const Instruction &inst, DiagnosticEngine &diags) {
  using enum Instruction::Opcode;
  switch (inst.opcode()) {
  case Add:
  case Mul:
  case Sub:
  case Div:
  case Mod:
  case Shl:
  case Shr:
  case Ashl:
  case Ashr:
  case And:
  case Or:
    return verifyBinaryInst(inst, diags);

  case Eq:
  case Lt:
  case Gt:
  case Le:
  case Ge: {
    Type ty = inst.getType();
    if (!ty.isInt() || ty.getNumBits() != 1)
      return emit(diags, Severity::Error, inst.getLoc())
             << std::format("unexpected result type '{}': compare "
                            "instructions return i1",
                            ty);

    std::size_t num_operands = inst.getNumOperands();
    if (num_operands != 2)
      return emit(diags, Severity::Error, inst.getLoc())
             << std::format("expected 2 operands, found {}", num_operands);

    const Value *lhs = inst.getOperand(0);
    const Value *rhs = inst.getOperand(1);

    Type lty = lhs->getType();
    Type rty = rhs->getType();

    if (!lty.isInt())
      return emit(diags, Severity::Error, inst.getLoc()) << std::format(
                 "invalid operand type '{}': expected integral type", lty);

    if (!rty.isInt())
      return emit(diags, Severity::Error, inst.getLoc()) << std::format(
                 "invalid operand type '{}': expected integral type", rty);

    if (lty != rty)
      return emit(diags, Severity::Error, inst.getLoc())
             << "mismatched operand types";

    // Must return here: this case used to fall through into the `not` checks.
    return success();
  }

  case Not: {
    if (failed(verifyResult(inst, diags)))
      return failure();

    // Check the count before calling getOperand(0) so that a `not` without
    // operands is diagnosed instead of being read out of range.
    if (inst.getNumOperands() != 1 || !inst.getOperand(0))
      return emit(diags, Severity::Error, inst.getLoc())
             << "instruction 'not' requires 1 operand";

    Type ty = inst.getType();
    Type oty = inst.getOperand(0)->getType();
    if (ty != oty)
      return emit(diags, Severity::Error, inst.getLoc())
             << std::format("expected argument type '{}', got '{}'", ty, oty);

    return success();
  }

  case Jmp:
  case Br:
  case Call:
  case Ret:
  case Phi:
  case Alloca:
    // TODO: verify control flow, calls, phis and allocas.
    return success();
  }

  // Only reached for opcodes not listed above; there are no checks for them
  // yet, so accept rather than report a spurious error.
  return success();
}

/// Verify a two-operand integer instruction (arithmetic, shift, bitwise).
///
/// The result type must be an integer type, the instruction must have exactly
/// two operands, and both operands must have the result type. Emits one error
/// to `diags` for the first violation and returns failure(); returns
/// success() otherwise.
// TODO: better naming?
LogicalResult verifyBinaryInst(const Instruction &inst,
                               DiagnosticEngine &diags) {
  Type ty = inst.getType();

  // TODO non scalars
  if (!ty.isInt())
    return emit(diags, Severity::Error, inst.getLoc())
           << "invalid instruction '" << inst << "': "
           << "expected an integral type, got '" << ty << "'";

  // Diagnose a malformed operand list instead of reading operands that may
  // not exist.
  std::size_t num_operands = inst.getNumOperands();
  if (num_operands != 2)
    return emit(diags, Severity::Error, inst.getLoc())
           << std::format("expected 2 operands, found {}", num_operands);

  const auto *lhs = inst.getOperand(0);
  const auto *rhs = inst.getOperand(1);

  assert(lhs && "Binary op lacks LHS");
  assert(rhs && "Binary op needs RHS");

  if (lhs->getType() != ty)
    return emit(diags, Severity::Error, inst.getLoc()) << std::format(
               "expected operand type '{}' got '{}'", ty, lhs->getType());

  if (rhs->getType() != ty)
    return emit(diags, Severity::Error, inst.getLoc()) << std::format(
               "expected operand type '{}' got '{}'", ty, rhs->getType());

  // Previously missing: falling off the end of a non-void function is UB.
  return success();
}

/// Check that `inst` defines a named SSA value.
///
/// Returns success() when inst.hasName() is true. Otherwise emits
/// "expected ssa result" at the instruction's location and returns failure.
LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) {
  if (!inst.hasName())
    return emit(diags, Severity::Error, inst.getLoc()) << "expected ssa result";
  return success();
}

/// Check that `inst` has exactly `expected` operands.
///
/// Emits "expected N operands, found M" at the instruction's location and
/// returns failure on a mismatch; returns success() otherwise.
LogicalResult verifyNumOperands(const Instruction &inst,
                                DiagnosticEngine &diags, std::size_t expected) {
  std::size_t found = inst.getNumOperands();
  if (found != expected)
    return emit(diags, Severity::Error, inst.getLoc())
           << std::format("expected {} operands, found {}", expected, found);

  return success();
}

} // namespace willow