summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Weigl-Bosker <stefan@s00.xyz>2026-02-19 13:13:41 -0500
committerGitHub <noreply@github.com>2026-02-19 13:13:41 -0500
commit1fd2d6d88f5f78d879bf38bb3fba7fa2e749d3b0 (patch)
treeeb5a0740956812678131970687377339fad5a541
parentadd95b14f74e6dbe04a6efe98ff0f20424930b73 (diff)
downloadcompiler-1fd2d6d88f5f78d879bf38bb3fba7fa2e749d3b0.tar.gz
[willow]: initial IRBuilder API (#9)
- add IRBuilder api - remove `name` field from `Value` - fix some bugs in IList interface - more verifier tests
-rw-r--r--willow/include/willow/ADT/IList.h202
-rw-r--r--willow/include/willow/IR/BasicBlock.h123
-rw-r--r--willow/include/willow/IR/ConstantPool.h6
-rw-r--r--willow/include/willow/IR/Function.h8
-rw-r--r--willow/include/willow/IR/IRBuilder.h76
-rw-r--r--willow/include/willow/IR/Instruction.h30
-rw-r--r--willow/include/willow/IR/Instructions.h124
-rw-r--r--willow/include/willow/IR/Module.h2
-rw-r--r--willow/include/willow/IR/TypeContext.h28
-rw-r--r--willow/include/willow/IR/Types.h4
-rw-r--r--willow/include/willow/IR/Value.h19
-rw-r--r--willow/include/willow/Util/LogicalResult.h2
-rw-r--r--willow/lib/IR/IRBuilder.cpp210
-rw-r--r--willow/lib/IR/TypeContext.cpp31
-rw-r--r--willow/lib/IR/Value.cpp54
-rw-r--r--willow/lib/IR/Verifier.cpp29
-rw-r--r--willow/unittest/BUILD.bazel2
-rw-r--r--willow/unittest/IR/BUILD.bazel (renamed from willow/unittest/ir/BUILD.bazel)0
-rw-r--r--willow/unittest/IR/VerifierTest.cpp82
-rw-r--r--willow/unittest/ir/VerifierTest.cpp51
20 files changed, 762 insertions, 321 deletions
diff --git a/willow/include/willow/ADT/IList.h b/willow/include/willow/ADT/IList.h
index 0c55ad5..ed9218e 100644
--- a/willow/include/willow/ADT/IList.h
+++ b/willow/include/willow/ADT/IList.h
@@ -1,7 +1,9 @@
#ifndef WILLOW_INCLUDE_ADT_ILIST_H
#define WILLOW_INCLUDE_ADT_ILIST_H
+#include <iterator>
#include <type_traits>
+#include <cassert>
namespace willow {
@@ -15,6 +17,7 @@ class IListTrait {
public:
IListTrait() = default;
+ IListTrait(IListTrait<T> *prev = nullptr, IListTrait<T> *next = nullptr) : prev{prev}, next{next} {}
IListTrait(const IListTrait &) = delete;
IListTrait &operator=(const IListTrait &) = delete;
IListTrait(IListTrait &&) = delete;
@@ -28,15 +31,17 @@ public:
void insertAfter(IListTrait<T> &node) {
auto old = this->next;
this->next = &node;
- node->prev = this;
- node->next = old;
+ node.prev = this;
+ node.next = old;
+ old->prev = &node;
};
void insertBefore(IListTrait<T> &node) {
auto old = this->prev;
- this->prev = node;
- node->next = this;
- node->prev = old;
+ this->prev = &node;
+ node.next = this;
+ node.prev = old;
+ old->next = &node;
};
/// Erase the node from the list. Return a pointer to the unliked node
@@ -48,6 +53,17 @@ public:
return static_cast<T *>(this);
}
+ [[maybe_unused]]
+ T *deleteFromList() {
+ prev->next = next;
+ next->prev = prev;
+ T *tmp = static_cast<T *>(next);
+ delete static_cast<T *>(this);
+ return tmp;
+ }
+
+ bool isLinked() const { return prev != nullptr || next != nullptr; }
+
private:
IListTrait<T> *prev = nullptr;
IListTrait<T> *next = nullptr;
@@ -58,17 +74,183 @@ template <typename T>
class IList {
static_assert(std::is_base_of_v<IListTrait<T>, T>,
"T must implement IListTrait");
- using node = IListTrait<T>;
+ using Node = IListTrait<T>;
public:
- IList() : dummyBegin{nullptr, &dummyEnd}, dummyEnd{&dummyBegin, dummyEnd} {}
+ IList() : dummyBegin{nullptr, static_cast<IListTrait<T>*>(&dummyEnd)}, dummyEnd{static_cast<IListTrait<T>*>(&dummyBegin), static_cast<IListTrait<T>*>(&dummyEnd)} {}
IList(const IList &) = delete;
IList(IList &&) = delete;
private:
- // yes we could save 16 bytes, no i dont really care
- node dummyBegin;
- node dummyEnd;
+ Node dummyBegin;
+ Node dummyEnd;
+
+public:
+ class Iterator {
+ friend class IList<T>;
+ explicit Iterator(Node *n) : cur(n) {}
+ explicit Iterator(Node &n) : cur(&n) {}
+ Node *cur = nullptr;
+
+ public:
+ using iterator_category = std::bidirectional_iterator_tag;
+ using value_type = T;
+ using difference_type = std::ptrdiff_t;
+ using pointer = T *;
+ using reference = T &;
+
+ Iterator() = default;
+
+ reference operator*() { return *static_cast<T *>(this->cur); }
+ pointer operator->() { return static_cast<T *>(this->cur); }
+
+ Iterator &operator++() {
+ cur = cur->next;
+ return *this;
+ }
+
+ Iterator operator++(int) {
+ Iterator tmp = *this;
+ ++(*this);
+ return tmp;
+ }
+
+ Iterator &operator--() {
+ cur = cur->prev;
+ return *this;
+ }
+
+ Iterator operator--(int) {
+ Iterator tmp = *this;
+ --(*this);
+ return tmp;
+ }
+
+ friend bool operator==(const Iterator &a, const Iterator &b) {
+ return a.cur == b.cur;
+ }
+
+ friend bool operator!=(const Iterator &a, const Iterator &b) {
+ return a.cur != b.cur;
+ }
+ };
+
+ class ConstIterator {
+ friend class IList<T>;
+ explicit ConstIterator(const Node *n) : cur(n) {}
+ explicit ConstIterator(const Node &n) : cur(&n) {}
+ const Node *cur = nullptr;
+
+ public:
+ using iterator_category = std::bidirectional_iterator_tag;
+ using value_type = const T;
+ using difference_type = std::ptrdiff_t;
+ using pointer = const T *;
+ using reference = const T &;
+
+ ConstIterator() = default;
+ ConstIterator(const Iterator &it) : cur(it.cur) {}
+
+ reference operator*() const { return *static_cast<const T *>(this->cur); }
+ pointer operator->() const { return static_cast<const T *>(this->cur); }
+
+ ConstIterator &operator++() {
+ cur = cur->next;
+ return *this;
+ }
+
+ ConstIterator operator++(int) {
+ Iterator tmp = *this;
+ ++(*this);
+ return tmp;
+ }
+
+ ConstIterator &operator--() {
+ cur = cur->prev;
+ return *this;
+ }
+
+ ConstIterator operator--(int) {
+ Iterator tmp = *this;
+ --(*this);
+ return tmp;
+ }
+
+ friend bool operator==(const ConstIterator &a, const ConstIterator &b) {
+ return a.cur == b.cur;
+ }
+
+ friend bool operator!=(const ConstIterator &a, const ConstIterator &b) {
+ return a.cur != b.cur;
+ }
+ };
+
+ bool empty() const { return dummyBegin.next == &dummyEnd; }
+
+ T &front() {
+ assert(!empty());
+ return *static_cast<T *>(dummyBegin.next);
+ }
+ const T &front() const {
+ assert(!empty());
+ return *static_cast<const T *>(dummyBegin.next);
+ }
+
+ T &back() {
+ assert(!empty());
+ return *static_cast<T *>(dummyEnd.prev);
+ }
+
+ const T &back() const {
+ assert(!empty());
+ return *static_cast<const T *>(dummyEnd.prev);
+ }
+
+ Iterator begin() { return Iterator{dummyBegin.next}; }
+ Iterator end() { return Iterator{&dummyEnd}; }
+
+ ConstIterator begin() const { return ConstIterator{dummyBegin.next}; }
+ ConstIterator end() const { return ConstIterator{&dummyEnd}; }
+
+ ConstIterator cbegin() const { return begin(); }
+ ConstIterator cend() const { return end(); }
+
+ Iterator insert(Iterator pos, T& node) { pos.cur->insertBefore(node); return Iterator(node); }
+
+ void push_front(T &node) {
+ dummyBegin.insertAfter(node);
+ }
+
+ void push_back(T &node) {
+ dummyEnd.insertBefore(node);
+ }
+
+ /// unlink the node and return it
+ T& remove(T &node) {
+ return static_cast<Node &>(node).removeFromList();
+ }
+
+ /// Unlink the node at \p pos
+ /// \returns Iterator to next node
+ Iterator erase(Iterator pos) {
+ assert(pos.cur != &dummyEnd);
+ Node *next = static_cast<Node *>(pos.cur)->next;
+ pos->removeFromList();
+ return Iterator(next);
+ }
+
+ /// Unlink all nodes
+ void clear() {
+ Node *n = dummyBegin.next;
+ while (n != &dummyEnd) {
+ Node *next = n->next;
+ n->prev = nullptr;
+ n->next = nullptr;
+ n = next;
+ }
+ dummyBegin.next = &dummyEnd;
+ dummyEnd.prev = &dummyBegin;
+ }
};
} // namespace willow
diff --git a/willow/include/willow/IR/BasicBlock.h b/willow/include/willow/IR/BasicBlock.h
index 5db8538..eb08c5d 100644
--- a/willow/include/willow/IR/BasicBlock.h
+++ b/willow/include/willow/IR/BasicBlock.h
@@ -1,74 +1,98 @@
#ifndef WILLOW_INCLUDE_IR_BASIC_BLOCK_H
#define WILLOW_INCLUDE_IR_BASIC_BLOCK_H
+#include <willow/ADT/IList.h>
#include <willow/IR/Instruction.h>
#include <willow/IR/Location.h>
#include <willow/IR/Value.h>
-#include <list>
#include <memory>
-#include <ranges>
-#include <string>
-#include <string_view>
-#include <utility>
-#include <vector>
namespace willow {
class Function;
/// A sequence of consecutively executed instructions, followed by a terminator.
-class BasicBlock : public Value {
+class BasicBlock final : public Value {
Function *parent = nullptr;
std::optional<Location> loc;
- std::list<std::unique_ptr<Instruction>> body;
+ IList<Instruction> body;
std::unordered_map<BasicBlock *, size_t> predecessors;
public:
- // ~BasicBlock() = TODO
+ using Iterator = IList<Instruction>::Iterator;
+ using ConstIterator = IList<Instruction>::ConstIterator;
+ ~BasicBlock() final {
+ assert(getUses().empty() && "Removing a basic block that still has uses");
+ while (!empty())
+ body.back().deleteFromList();
+ }
/// Create a basic block with a name and parent function.
- BasicBlock(std::string name, Function *parent, Type bbty,
- std::optional<Location> loc)
- : Value(ValueKind::BasicBlock, std::move(name), bbty), parent(parent),
- loc(loc) {}
+ BasicBlock(Function *parent, Type bbty,
+ std::optional<Location> loc = std::nullopt)
+ : Value(ValueKind::BasicBlock, bbty), parent(parent), loc(loc) {}
Function *getParent() const { return parent; }
bool hasParent() const { return parent; }
void setParent(Function *parent) { this->parent = parent; }
bool empty() const { return body.empty(); }
- std::size_t size() { return body.size(); }
- auto getBody() {
- return body |
- std::views::transform([](auto &p) -> Instruction & { return *p; });
+ Iterator begin() { return body.begin(); }
+ Iterator end() { return body.end(); }
+ ConstIterator begin() const { return body.begin(); }
+ ConstIterator end() const { return body.end(); }
+ ConstIterator cbegin() const { return body.cbegin(); }
+ ConstIterator cend() const { return body.cend(); }
+
+ Instruction &trailer() { return body.back(); }
+ Instruction &leader() { return body.front(); }
+ const Instruction &trailer() const { return body.front(); };
+ const Instruction &leader() const { return body.back(); };
+
+ Instruction *addInstruction(std::unique_ptr<Instruction> inst);
+
+ void push_back(Instruction &inst) {
+ assert(!inst.hasParent() && "Instruction is already parented");
+ body.push_back(inst);
+ inst.setParent(this);
}
- auto getBody() const {
- return body | std::views::transform(
- [](auto &p) -> const Instruction & { return *p; });
+ void push_front(Instruction &inst) {
+ assert(!inst.hasParent() && "Instruction is already parented");
+ body.push_front(inst);
+ inst.setParent(this);
}
- Instruction *trailer();
- Instruction *leader();
- const Instruction *trailer() const;
- const Instruction *leader() const;
+ Iterator insert(Iterator pos, Instruction &inst) {
+ assert(!inst.hasParent() && "Instruction is already parented");
+ auto it = body.insert(pos, inst);
+ inst.setParent(this);
+ return it;
+ }
- Instruction *addInstruction(std::unique_ptr<Instruction> inst);
+ Iterator erase(Iterator pos) {
+ Instruction &I = *pos;
+ I.setParent(nullptr);
+ return body.erase(pos);
+ }
- template <typename... Args>
- Instruction *emplaceInstruction(Args &&...args) {
- body.push_back(std::make_unique<Instruction>(std::forward(args)...));
- body.back()->setParent(this);
- return body.back().get();
+ Iterator eraseAndDelete(Iterator pos) {
+ Instruction &inst = *pos;
+ pos->setParent(nullptr);
+ auto it = body.erase(pos);
+ delete &inst;
+ return it;
}
std::optional<Location> getLoc() const { return loc; }
- std::unordered_map<BasicBlock*, size_t>& preds() { return predecessors; }
- const std::unordered_map<BasicBlock*, size_t>& preds() const { return predecessors; }
+ std::unordered_map<BasicBlock *, size_t> &preds() { return predecessors; }
+ const std::unordered_map<BasicBlock *, size_t> &preds() const {
+ return predecessors;
+ }
void addPred(BasicBlock *bb) {
auto [it, inserted] = predecessors.try_emplace(bb, 1);
@@ -87,41 +111,6 @@ public:
}
};
-inline Instruction *BasicBlock::trailer() {
- if (empty())
- return nullptr;
- else
- return body.back().get();
-}
-
-inline Instruction *BasicBlock::leader() {
- if (empty())
- return nullptr;
- else
- return body.back().get();
-}
-
-inline const Instruction *BasicBlock::trailer() const {
- if (empty())
- return nullptr;
- else
- return body.back().get();
-}
-
-inline const Instruction *BasicBlock::leader() const {
- if (empty())
- return nullptr;
- else
- return body.back().get();
-}
-
-inline Instruction *BasicBlock::addInstruction(std::unique_ptr<Instruction> inst) {
- Instruction *p = inst.get();
- p->setParent(this);
- body.push_back(std::move(inst));
- return p;
-}
-
} // namespace willow
#endif // WILLOW_INCLUDE_IR_BASIC_BLOCK_H
diff --git a/willow/include/willow/IR/ConstantPool.h b/willow/include/willow/IR/ConstantPool.h
index 2524c70..7d5347b 100644
--- a/willow/include/willow/IR/ConstantPool.h
+++ b/willow/include/willow/IR/ConstantPool.h
@@ -32,7 +32,7 @@ private:
std::unordered_map<TypeImpl *, std::unique_ptr<PoisonVal>> pcache;
};
-ConstantInt *ConstantPool::getInt(Type ty, uint64_t val) {
+inline ConstantInt *ConstantPool::getInt(Type ty, uint64_t val) {
assert(ty.isInt() && "Expected integer type");
ConstantInt::Key &&k{ty.getImpl(), ty.getNumBits()};
std::lock_guard<std::mutex> lock(int_mutex);
@@ -42,7 +42,7 @@ ConstantInt *ConstantPool::getInt(Type ty, uint64_t val) {
return it->second.get();
}
-UndefVal *ConstantPool::getUndefVal(Type ty) {
+inline UndefVal *ConstantPool::getUndefVal(Type ty) {
std::lock_guard<std::mutex> lock(undef_mutex);
auto [it, _] =
@@ -51,7 +51,7 @@ UndefVal *ConstantPool::getUndefVal(Type ty) {
return it->second.get();
}
-PoisonVal *ConstantPool::getPoisonVal(Type ty) {
+inline PoisonVal *ConstantPool::getPoisonVal(Type ty) {
std::lock_guard<std::mutex> lock(poison_mutex);
auto [it, _] =
diff --git a/willow/include/willow/IR/Function.h b/willow/include/willow/IR/Function.h
index 43a301b..bd6ef3d 100644
--- a/willow/include/willow/IR/Function.h
+++ b/willow/include/willow/IR/Function.h
@@ -19,6 +19,7 @@ class Module;
/// Groups basic blocks and the values they define.
class Function : public Value {
Module *parent = nullptr;
+ std::string name;
std::vector<std::unique_ptr<BasicBlock>> blocks;
std::vector<std::unique_ptr<Value>> values;
@@ -35,17 +36,20 @@ public:
/// fty.
Function(std::string name, Module *parent, Type fty,
std::vector<std::unique_ptr<Parameter>> params)
- : Value(ValueKind::Function, std::move(name), fty), parent(parent),
+ : Value(ValueKind::Function, fty), parent(parent), name(std::move(name)),
params(std::move(params)) {}
Function(std::string name, Module *parent, Type fty)
- : Value(ValueKind::Function, std::move(name), fty), parent(parent) {}
+ : Value(ValueKind::Function, fty), parent(parent), name(std::move(name)) {}
+
/// \return Parent module or nullptr.
Module *getParent() { return parent; }
const Module *getParent() const { return parent; }
void setParent(Module *parent) { this->parent = parent; }
+ const std::string& getName() const { return name; }
+
/// \return Entry block or nullptr.
BasicBlock *entryBlock();
const BasicBlock *entryBlock() const;
diff --git a/willow/include/willow/IR/IRBuilder.h b/willow/include/willow/IR/IRBuilder.h
index 6036e33..f2f36f2 100644
--- a/willow/include/willow/IR/IRBuilder.h
+++ b/willow/include/willow/IR/IRBuilder.h
@@ -1,23 +1,89 @@
#ifndef WILLOW_INCLUDE_IR_IR_BUILDER_H
#define WILLOW_INCLUDE_IR_IR_BUILDER_H
-#include <willow/IR/Context.h>
#include <willow/IR/BasicBlock.h>
+#include <willow/IR/Context.h>
#include <willow/IR/Function.h>
#include <willow/IR/Instruction.h>
+#include <willow/IR/Instructions.h>
+
+#include <cassert>
namespace willow {
/// Helper for constructing and modifiying IR.
class IRBuilder {
+ BasicBlock::Iterator insert_point;
WillowContext &ctx;
- const Module &mod;
- BasicBlock *block = nullptr;
public:
- explicit IRBuilder(BasicBlock *block) : block(block) {}
+ explicit IRBuilder(WillowContext &ctx, BasicBlock &bb,
+ BasicBlock::Iterator insertion_point)
+ : insert_point(insertion_point), ctx(ctx) {}
+
+ void SetInsertPoint(BasicBlock::Iterator point) { insert_point = point; }
+
+ AddInst *BuildAdd(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ MulInst *BuildMul(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ SubInst *BuildSub(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ DivInst *BuildDiv(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ ModInst *BuildMod(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ ShlInst *BuildShl(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ ShrInst *BuildShr(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ AshlInst *BuildAshl(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ AshrInst *BuildAshr(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+
+ EqInst *BuildEq(Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ LtInst *BuildLt(Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ GtInst *BuildGt(Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ LeInst *BuildLe(Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ GeInst *BuildGe(Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+
+ AndInst *BuildAnd(Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ OrInst *BuildOr(Value *lhs, Value *rhs,
+ std::optional<Location> loc = std::nullopt);
+ NotInst *BuildNot(Value *val, std::optional<Location> loc = std::nullopt);
+
+ JmpInst *BuildJmp(BasicBlock *dst,
+ std::optional<Location> loc = std::nullopt);
+ BrInst *BuildBr(Value *predicate, BasicBlock *truedst, BasicBlock *falsedst,
+ std::optional<Location> loc = std::nullopt);
+ template <typename... Args>
+ CallInst *BuildCall(Type rty, Function *func,
+ std::optional<Location> loc = std::nullopt,
+ Args &&...args) {
+ auto *inst = new CallInst(rty, func, std::forward<Args>(args)...);
+ assert(inst);
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+ }
+ RetInst *BuildRet(Value *val, std::optional<Location> loc = std::nullopt);
+ RetInst *BuildRet(std::optional<Location> loc = std::nullopt);
+
+ PhiInst *
+ BuildPhi(Type ty,
+ std::initializer_list<std::pair<BasicBlock *, Value *>> args,
+ std::optional<Location> loc = std::nullopt);
+ AllocaInst *BuildAlloca(Type ty, std::optional<Location> loc = std::nullopt);
};
} // namespace willow
-#endif // WILLOW_INCLUDE_IR_IR_BUILDER_H
+#endif // WILLOW_INCLUDE_IR_IR_BUILDER_H*
diff --git a/willow/include/willow/IR/Instruction.h b/willow/include/willow/IR/Instruction.h
index 1198546..e02e308 100644
--- a/willow/include/willow/IR/Instruction.h
+++ b/willow/include/willow/IR/Instruction.h
@@ -1,6 +1,7 @@
#ifndef WILLOW_INCLUDE_IR_INSTRUCTION_H
#define WILLOW_INCLUDE_IR_INSTRUCTION_H
+#include <willow/ADT/IList.h>
#include <willow/IR/Location.h>
#include <willow/IR/Types.h>
#include <willow/IR/Value.h>
@@ -42,7 +43,7 @@ public:
};
/// Defines an IR instruction.
-class Instruction : public Value {
+class Instruction : public Value, public IListTrait<Instruction> {
public:
enum class Opcode {
Add, ///< Int addition
@@ -98,19 +99,18 @@ public:
/// \param op Opcode for this instruction.
/// \param type Type of the result of this instruction.
/// \param loc Source location of this instruction.
- Instruction(Opcode op, Type type, std::optional<Location> loc = std::nullopt)
- : Value(ValueKind::Instruction, type), op(op), loc(loc) {}
-
- /// \param name Name of the ssa value produced by the instruction.
- /// \param op Opcode for the instruction.
- /// \param type Type of the result of the instruction.
- /// \param loc Source location of this instruction.
- Instruction(std::string name, Opcode op, Type type,
- std::optional<Location> loc = std::nullopt)
- : Value(ValueKind::Instruction, std::move(name), type), op(op), loc(loc) {
- }
+ /// \param prev Previous instruction in the block.
+ /// \param next Next instruction in the block.
+ Instruction(Opcode op, Type type, std::optional<Location> loc = std::nullopt,
+ Instruction *prev = nullptr, Instruction *next = nullptr)
+ : Value(ValueKind::Instruction, type),
+ IListTrait<Instruction>{prev, next}, op(op), loc(loc) {}
~Instruction() override {
+ assert(!IListTrait<Instruction>::isLinked() &&
+ "Instruction is destroyed before it is unlinked");
+ assert(getUses().empty() && "Removing an instruction that still has uses");
+ // TODO prob need a use wrapper, for op %1 %1
for (Value *operand : operands)
operand->delUse(this);
}
@@ -218,11 +218,13 @@ struct std::formatter<willow::Instruction::Opcode> {
}
};
-inline std::ostream& operator<<(std::ostream &os, const willow::Instruction::Opcode op) {
+inline std::ostream &operator<<(std::ostream &os,
+ const willow::Instruction::Opcode op) {
return os << willow::Instruction::opcodeName(op);
}
-inline std::ostream& operator<<(std::ostream &os, const willow::Instruction &inst) {
+inline std::ostream &operator<<(std::ostream &os,
+ const willow::Instruction &inst) {
auto vty = inst.getType();
os << vty << " " << willow::Instruction::opcodeName(inst.opcode());
diff --git a/willow/include/willow/IR/Instructions.h b/willow/include/willow/IR/Instructions.h
index 569c372..753ea70 100644
--- a/willow/include/willow/IR/Instructions.h
+++ b/willow/include/willow/IR/Instructions.h
@@ -10,11 +10,6 @@ namespace willow {
/// The base class for unary instructions
class UnaryInst : public Instruction {
public:
- UnaryInst(std::string name, Opcode op, Type type, Value *value,
- std::optional<Location> loc = std::nullopt)
- : Instruction(std::move(name), op, type, loc) {
- addOperand(value);
- }
UnaryInst(Opcode op, Type type, Value *value,
std::optional<Location> loc = std::nullopt)
: Instruction(op, type, loc) {
@@ -28,9 +23,9 @@ public:
/// The base class for binary instructions
class BinaryInst : public Instruction {
public:
- BinaryInst(std::string name, Opcode op, Type type, Value *lhs, Value *rhs,
+ BinaryInst(Opcode op, Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : Instruction(std::move(name), op, type, loc) {
+ : Instruction(op, type, loc) {
addOperand(lhs);
addOperand(rhs);
}
@@ -45,205 +40,188 @@ public:
/// Add two integers.
class AddInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an integer type.
/// \p lhs The first integer
/// \p lhs The second integer
- AddInst(std::string name, Type type, Value *lhs, Value *rhs,
+ AddInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Add, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Add, type, lhs, rhs, loc) {}
};
/// Multiply two integers.
class MulInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an integer type.
/// \p lhs The first integer
/// \p lhs The second integer
- MulInst(std::string name, Type type, Value *lhs, Value *rhs,
+ MulInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Mul, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Mul, type, lhs, rhs, loc) {}
};
/// Subtract two integers.
class SubInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an integer type.
/// \p lhs The first integer
/// \p lhs The second integer
- SubInst(std::string name, Type type, Value *lhs, Value *rhs,
+ SubInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Sub, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Sub, type, lhs, rhs, loc) {}
};
/// Divide two integers.
class DivInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an integer type.
/// \p lhs The dividend
/// \p lhs The divisor
- DivInst(std::string name, Type type, Value *lhs, Value *rhs,
+ DivInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Div, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Div, type, lhs, rhs, loc) {}
};
/// Compute the modulus of two integers.
class ModInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an integer type.
/// \p lhs The dividend
/// \p lhs The divisor
- ModInst(std::string name, Type type, Value *lhs, Value *rhs,
+ ModInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Mod, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Mod, type, lhs, rhs, loc) {}
};
/// Compute the bitwise left shift of an integer
class ShlInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an integer type.
/// \p lhs The integer to be shifted
/// \p lhs The shift amount
- ShlInst(std::string name, Type type, Value *lhs, Value *rhs,
+ ShlInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Shl, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Shl, type, lhs, rhs, loc) {}
};
/// Compute the bitwise right shift of an integer
class ShrInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an integer type.
/// \p lhs The integer to be shifted
/// \p lhs The shift amount
- ShrInst(std::string name, Type type, Value *lhs, Value *rhs,
+ ShrInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Shr, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Shr, type, lhs, rhs, loc) {}
};
/// Compute the arithmetic left shift of an integer (preserves sign)
class AshlInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an integer type.
/// \p lhs The integer to be shifted
/// \p lhs The shift amount
- AshlInst(std::string name, Type type, Value *lhs, Value *rhs,
+ AshlInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Ashl, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Ashl, type, lhs, rhs, loc) {}
};
/// Compute the arithmetic right shift of an integer (preserves sign)
class AshrInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an integer type.
/// \p lhs The integer to be shifted
/// \p lhs The shift amount
- AshrInst(std::string name, Type type, Value *lhs, Value *rhs,
+ AshrInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Ashr, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Ashr, type, lhs, rhs, loc) {}
};
/// Test two integers for equality.
class EqInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an i1.
/// \p lhs The first integer
/// \p lhs The second integer
- EqInst(std::string name, Type type, Value *lhs, Value *rhs,
+ EqInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Eq, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Eq, type, lhs, rhs, loc) {}
};
/// Test if one integer is less than another.
class LtInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an i1.
/// \p lhs The first integer
/// \p lhs The second integer
- LtInst(std::string name, Type type, Value *lhs, Value *rhs,
+ LtInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Lt, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Lt, type, lhs, rhs, loc) {}
};
// Test if one integer is greater than another.
class GtInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an i1.
/// \p lhs The first integer
/// \p lhs The second integer
- GtInst(std::string name, Type type, Value *lhs, Value *rhs,
+ GtInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Gt, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Gt, type, lhs, rhs, loc) {}
};
// Test if one integer is less than or equal to another.
class LeInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an i1.
/// \p lhs The first integer
/// \p lhs The second integer
- LeInst(std::string name, Type type, Value *lhs, Value *rhs,
+ LeInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Le, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Le, type, lhs, rhs, loc) {}
};
// Test if one integer is greater than or equal to another.
class GeInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an i1.
/// \p lhs The first integer
/// \p lhs The second integer
- GeInst(std::string name, Type type, Value *lhs, Value *rhs,
+ GeInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Ge, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Ge, type, lhs, rhs, loc) {}
};
/// Compute the logical and of two boolean values.
class AndInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an i1.
/// \p lhs The first value
/// \p lhs The second value
- AndInst(std::string name, Type type, Value *lhs, Value *rhs,
+ AndInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::And, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::And, type, lhs, rhs, loc) {}
};
/// Compute the logical and of two boolean values.
class OrInst : public BinaryInst {
public:
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an i1.
/// \p lhs The first value
/// \p lhs The second value
- OrInst(std::string name, Type type, Value *lhs, Value *rhs,
+ OrInst(Type type, Value *lhs, Value *rhs,
std::optional<Location> loc = std::nullopt)
- : BinaryInst(std::move(name), Opcode::Or, type, lhs, rhs, loc) {}
+ : BinaryInst(Opcode::Or, type, lhs, rhs, loc) {}
};
/// Compute the logical negation of a boolean value.
class NotInst : public UnaryInst {
- /// \p name The name of the ssa result produced by the instruction.
/// \p type Type of the result and operands. Must be an i1.
/// \p lhs The first value
/// \p lhs The second value
public:
- explicit NotInst(std::string name, Type type, Value *value,
+ explicit NotInst(Type type, Value *value,
std::optional<Location> loc = std::nullopt)
- : UnaryInst(std::move(name), Opcode::Not, type, value, loc) {}
+ : UnaryInst(Opcode::Not, type, value, loc) {}
};
/// Jump to another basic block.
@@ -298,26 +276,16 @@ public:
/// Call another function
class CallInst : public Instruction {
public:
- /// \p optional name name of the ssa result produced by the function call
/// \p rty Type of the ssa result produced by the function. Must match the
/// signature of the function
/// \p func The function to be called
/// \p args List of arguments passed to the function. Must match the signature
/// of the function
- CallInst(std::string name, Type rty, Function *func,
- std::initializer_list<Value *> args,
- std::optional<Location> loc = std::nullopt)
- : Instruction(std::move(name), Opcode::Call, rty, loc) {
- addOperand(func);
- for (auto a : args)
- addOperand(a);
- }
- CallInst(Function *func, Type voidty, std::initializer_list<Value *> args,
- std::optional<Location> loc = std::nullopt)
- : Instruction(Opcode::Call, voidty, loc) {
- addOperand(func);
- for (auto a : args)
- addOperand(a);
+ template <typename... Args>
+ CallInst(Type rty, Function *func, std::optional<Location> loc = std::nullopt,
+ Args &&...args)
+ : Instruction(Opcode::Call, rty, loc) {
+ (addOperand(args), ...);
}
};
@@ -335,13 +303,12 @@ public:
/// Select a control-flow variant value
class PhiInst : public Instruction {
public:
- /// \p name name of the ssa result produced by the instruction.
/// \p type Type of the result and operands.
/// \p args List of (predecessor block, value to take) pairs
- PhiInst(std::string name, Type type,
+ PhiInst(Type type,
std::initializer_list<std::pair<BasicBlock *, Value *>> args,
std::optional<Location> loc)
- : Instruction(std::move(name), Opcode::Phi, type, loc) {
+ : Instruction(Opcode::Phi, type, loc) {
for (auto [bb, v] : args) {
addOperand(static_cast<Value *>(bb));
addOperand(v);
@@ -376,10 +343,9 @@ public:
/// Allocate stack space of a value.
class AllocaInst : public Instruction {
public:
- /// \p name Name of the ssa result.
/// \p type The pointer type to allocate.
- AllocaInst(std::string name, Type type, std::optional<Location> loc)
- : Instruction(std::move(name), Opcode::Alloca, type, loc) {
+ AllocaInst(Type type, std::optional<Location> loc)
+ : Instruction(Opcode::Alloca, type, loc) {
assert(type.isPtr() && "alloca must return a pointer");
}
};
diff --git a/willow/include/willow/IR/Module.h b/willow/include/willow/IR/Module.h
index 59fb6a4..d854f2d 100644
--- a/willow/include/willow/IR/Module.h
+++ b/willow/include/willow/IR/Module.h
@@ -42,7 +42,7 @@ public:
}
};
-Function *Module::addFunction(std::unique_ptr<Function> fn) {
+inline Function *Module::addFunction(std::unique_ptr<Function> fn) {
auto p = fn.get();
functions.push_back(std::move(fn));
p->setParent(this);
diff --git a/willow/include/willow/IR/TypeContext.h b/willow/include/willow/IR/TypeContext.h
index 5bdcf51..e7b0712 100644
--- a/willow/include/willow/IR/TypeContext.h
+++ b/willow/include/willow/IR/TypeContext.h
@@ -24,7 +24,7 @@ public:
/// \param pointee Type of the pointee.
/// \param size Size in bits of the pointer representation.
/// \return Uniqued pointer type.
- Type PtrType(Type pointee, std::size_t size);
+ Type PtrType(Type pointee);
/// \return Uniqued void type.
Type VoidType();
@@ -50,32 +50,6 @@ private:
std::unique_ptr<BasicBlockTypeImpl> labelty;
};
-Type TypeContext::IntType(std::size_t width) {
- auto [it, _] = icache.try_emplace(IntTypeImpl::Key{width},
- std::make_unique<IntTypeImpl>(width));
-
- return Type(it->second.get());
-}
-
-Type TypeContext::PtrType(Type pointee, std::size_t size) {
- auto [it, _] = pcache.try_emplace(
- PtrTypeImpl::Key{pointee}, std::make_unique<PtrTypeImpl>(pointee, size));
-
- return Type(it->second.get());
-}
-
-Type TypeContext::VoidType() { return Type(voidty.get()); }
-
-Type TypeContext::BasicBlockType() { return Type{labelty.get()}; }
-
-Type TypeContext::FunctionType(Type ret, std::initializer_list<Type> params) {
- auto [it, _] =
- fncache.try_emplace(FunctionTypeImpl::Key{ret, params},
- std::make_unique<FunctionTypeImpl>(ret, params));
-
- return Type(it->second.get());
-};
-
} // namespace willow
#endif // WILLOW_INCLUDE_IR_TYPE_CONTEXT_H
diff --git a/willow/include/willow/IR/Types.h b/willow/include/willow/IR/Types.h
index 40bb098..0e4a02d 100644
--- a/willow/include/willow/IR/Types.h
+++ b/willow/include/willow/IR/Types.h
@@ -94,8 +94,8 @@ class PtrTypeImpl : public TypeImpl {
Type pointee;
public:
- PtrTypeImpl(Type pointee_type, std::size_t size)
- : TypeImpl{TypeID::Ptr, size}, pointee{pointee_type.getImpl()} {}
+ PtrTypeImpl(Type pointee_type)
+ : TypeImpl{TypeID::Ptr}, pointee{pointee_type.getImpl()} {}
Type getPointee() const { return Type{pointee.getImpl()}; }
diff --git a/willow/include/willow/IR/Value.h b/willow/include/willow/IR/Value.h
index 650e773..e252fa2 100644
--- a/willow/include/willow/IR/Value.h
+++ b/willow/include/willow/IR/Value.h
@@ -4,8 +4,6 @@
#include <willow/IR/Types.h>
#include <cassert>
-#include <string>
-#include <string_view>
#include <unordered_map>
#include <ostream>
@@ -26,7 +24,6 @@ enum class ValueKind {
/// An SSA value that may be used.
class Value {
ValueKind valuekind;
- std::string name;
Type type;
// Instructions that use this value
@@ -35,14 +32,8 @@ class Value {
public:
virtual ~Value() = default;
- Value(ValueKind valuekind, std::string name, Type type)
- : valuekind(valuekind), name(std::move(name)), type(type) {}
- Value(ValueKind valuekind, Type type) : valuekind(valuekind), type(type) {}
-
- bool hasName() const { return !name.empty(); }
-
- std::string_view getName() const { return name; }
- void setName(std::string name) { this->name = std::move(name); }
+ Value(ValueKind valuekind, Type type)
+ : valuekind(valuekind), type(type) {}
Type getType() const { return type; }
ValueKind getValueKind() const { return valuekind; }
@@ -78,12 +69,10 @@ public:
class Parameter : public Value {
public:
- Parameter(std::string name, Type type)
- : Value(ValueKind::Parameter, std::move(name), type) {}
+ Parameter(Type type)
+ : Value(ValueKind::Parameter, type) {}
};
} // namespace willow
-inline std::ostream &operator<<(std::ostream &os, const willow::Value &v);
-
#endif // WILLOW_INCLUDE_IR_VALUE_H
diff --git a/willow/include/willow/Util/LogicalResult.h b/willow/include/willow/Util/LogicalResult.h
index c09310f..c1d582d 100644
--- a/willow/include/willow/Util/LogicalResult.h
+++ b/willow/include/willow/Util/LogicalResult.h
@@ -24,7 +24,7 @@ inline LogicalResult success(bool is_success = true) {
return LogicalResult::success(is_success);
}
inline LogicalResult failure(bool is_failure = true) {
- return LogicalResult::success(is_failure);
+ return LogicalResult::failure(is_failure);
}
inline bool succeeded(LogicalResult r) { return r.succeeded(); }
inline bool failed(LogicalResult r) { return r.failed(); }
diff --git a/willow/lib/IR/IRBuilder.cpp b/willow/lib/IR/IRBuilder.cpp
new file mode 100644
index 0000000..62e7a98
--- /dev/null
+++ b/willow/lib/IR/IRBuilder.cpp
@@ -0,0 +1,210 @@
+#include <willow/IR/IRBuilder.h>
+
+namespace willow {
+
+AddInst *IRBuilder::BuildAdd(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new AddInst{type, lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+MulInst *IRBuilder::BuildMul(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new MulInst{type, lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+SubInst *IRBuilder::BuildSub(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new SubInst{type, lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+DivInst *IRBuilder::BuildDiv(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new DivInst{type, lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+ModInst *IRBuilder::BuildMod(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new ModInst{type, lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+ShlInst *IRBuilder::BuildShl(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new ShlInst{type, lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+ShrInst *IRBuilder::BuildShr(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new ShrInst{type, lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+AshlInst *IRBuilder::BuildAshl(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new AshlInst{type, lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+AshrInst *IRBuilder::BuildAshr(Type type, Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new AshrInst{type, lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+EqInst *IRBuilder::BuildEq(Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new EqInst{ctx.types().IntType(1), lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+LtInst *IRBuilder::BuildLt(Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new LtInst{ctx.types().IntType(1), lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+GtInst *IRBuilder::BuildGt(Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new GtInst{ctx.types().IntType(1), lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+LeInst *IRBuilder::BuildLe(Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new LeInst{ctx.types().IntType(1), lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+GeInst *IRBuilder::BuildGe(Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new GeInst{ctx.types().IntType(1), lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+AndInst *IRBuilder::BuildAnd(Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new AndInst{ctx.types().IntType(1), lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+OrInst *IRBuilder::BuildOr(Value *lhs, Value *rhs,
+ std::optional<Location> loc) {
+ auto *inst = new OrInst{ctx.types().IntType(1), lhs, rhs, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+NotInst *IRBuilder::BuildNot(Value *val, std::optional<Location> loc) {
+ auto *inst = new NotInst{ctx.types().IntType(1), val, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+JmpInst *IRBuilder::BuildJmp(BasicBlock *dst, std::optional<Location> loc) {
+
+ auto *inst = new JmpInst{ctx.types().VoidType(), dst, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+BrInst *IRBuilder::BuildBr(Value *predicate, BasicBlock *truedst,
+ BasicBlock *falsedst, std::optional<Location> loc) {
+ auto *inst =
+ new BrInst{ctx.types().VoidType(), predicate, truedst, falsedst, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+RetInst *IRBuilder::BuildRet(Value *val, std::optional<Location> loc) {
+ auto *inst = new RetInst{ctx.types().VoidType(), val, loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+RetInst *IRBuilder::BuildRet(std::optional<Location> loc) {
+ auto *inst = new RetInst{ctx.types().VoidType(), loc};
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+PhiInst *IRBuilder::BuildPhi(
+ Type ty, std::initializer_list<std::pair<BasicBlock *, Value *>> args,
+ std::optional<Location> loc) {
+ auto *inst = new PhiInst(ty, args, loc);
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+AllocaInst *IRBuilder::BuildAlloca(Type ty, std::optional<Location> loc) {
+ Type pty = ctx.types().PtrType(ty);
+ auto *inst = new AllocaInst(pty, loc);
+ insert_point->insertAfter(*inst);
+ insert_point++;
+
+ return inst;
+}
+
+}; // namespace willow
diff --git a/willow/lib/IR/TypeContext.cpp b/willow/lib/IR/TypeContext.cpp
index e69de29..21e5173 100644
--- a/willow/lib/IR/TypeContext.cpp
+++ b/willow/lib/IR/TypeContext.cpp
@@ -0,0 +1,31 @@
+#include <willow/IR/TypeContext.h>
+
+namespace willow {
+
+Type TypeContext::IntType(std::size_t width) {
+ auto [it, _] = icache.try_emplace(IntTypeImpl::Key{width},
+ std::make_unique<IntTypeImpl>(width));
+
+ return Type(it->second.get());
+}
+
+Type TypeContext::PtrType(Type pointee) {
+ auto [it, _] = pcache.try_emplace(
+ PtrTypeImpl::Key{pointee}, std::make_unique<PtrTypeImpl>(pointee));
+
+ return Type(it->second.get());
+}
+
+Type TypeContext::VoidType() { return Type(voidty.get()); }
+
+Type TypeContext::BasicBlockType() { return Type{labelty.get()}; }
+
+Type TypeContext::FunctionType(Type ret, std::initializer_list<Type> params) {
+ auto [it, _] =
+ fncache.try_emplace(FunctionTypeImpl::Key{ret, params},
+ std::make_unique<FunctionTypeImpl>(ret, params));
+
+ return Type(it->second.get());
+};
+
+};
diff --git a/willow/lib/IR/Value.cpp b/willow/lib/IR/Value.cpp
index 13e029f..0bd1079 100644
--- a/willow/lib/IR/Value.cpp
+++ b/willow/lib/IR/Value.cpp
@@ -1,27 +1,27 @@
-#include <willow/IR/Value.h>
-#include <willow/IR/Constant.h>
-#include <ostream>
-
-std::ostream &operator<<(std::ostream &os, const willow::Value &v) {
- using willow::ValueKind;
- auto ty = v.getType();
- if (!v.isVoid())
- os << ty << " ";
-
- switch (v.getValueKind()) {
- case ValueKind::Parameter:
- [[fallthrough]];
- case ValueKind::Instruction: {
- return os << "%" << v.getName();
- }
- case ValueKind::BasicBlock: {
- return os << "^" << v.getName();
- }
- case ValueKind::Function: {
- return os << "@" << v.getName();
- }
- case ValueKind::Constant: {
- return os << *static_cast<const willow::Constant*>(&v);
- }
- }
-}
+// #include <willow/IR/Value.h>
+// #include <willow/IR/Constant.h>
+// #include <ostream>
+//
+// std::ostream &operator<<(std::ostream &os, const willow::Value &v) {
+// using willow::ValueKind;
+// auto ty = v.getType();
+// if (!v.isVoid())
+// os << ty << " ";
+//
+// switch (v.getValueKind()) {
+// case ValueKind::Parameter:
+// [[fallthrough]];
+// case ValueKind::Instruction: {
+// return os << "%" << v.getName();
+// }
+// case ValueKind::BasicBlock: {
+// return os << "^" << v.getName();
+// }
+// case ValueKind::Function: {
+// return os << "@" << v.getName();
+// }
+// case ValueKind::Constant: {
+// return os << *static_cast<const willow::Constant*>(&v);
+// }
+// }
+// }
diff --git a/willow/lib/IR/Verifier.cpp b/willow/lib/IR/Verifier.cpp
index d19bc83..b622f10 100644
--- a/willow/lib/IR/Verifier.cpp
+++ b/willow/lib/IR/Verifier.cpp
@@ -1,3 +1,4 @@
+#include <iostream>
#include <willow/IR/BasicBlock.h>
#include <willow/IR/Diagnostic.h>
#include <willow/IR/DiagnosticEngine.h>
@@ -75,20 +76,19 @@ LogicalResult verifyBasicBlock(WillowContext &ctx, const BasicBlock &BB,
DiagnosticEngine &diags) {
if (BB.empty())
return emit(diags, Severity::Error, BB.getLoc())
- << "Basic block '" << BB.getName() << "' has an empty body";
+ << "Basic block has empty body";
- if (!BB.trailer()->isTerminator())
+ if (!BB.trailer().isTerminator())
return emit(diags, Severity::Error, BB.getLoc())
- << "Basic block '" << BB.getName()
- << "' does not end with a terminator";
+ << "Basic block does not end with a terminator";
bool has_failed = false;
- for (auto &inst : BB.getBody()) {
+ for (auto& inst : BB) {
// verify inst
if (failed(verifyInst(ctx, inst, diags)))
has_failed = true;
- if (&inst != BB.trailer() && inst.isTerminator())
+ if (&inst != &BB.trailer() && inst.isTerminator())
return emit(diags, Severity::Error, BB.getLoc())
<< "Illegal terminator in the middle of a block";
}
@@ -219,11 +219,9 @@ LogicalResult verifyInst(WillowContext &ctx, const Instruction &inst,
if (aty == pty)
continue;
- DiagnosticBuilder d(diags, Severity::Error, inst.getLoc());
- d << "invalid argument: expected '" << pty << "', got '" << aty << "'";
- if (param.hasName())
- d.note(Diagnostic{Severity::Remark,
- std::format("param name: {}", param.getName())});
+ return emit(diags, Severity::Error, inst.getLoc())
+ << "invalid argument: expected '" << pty << "', got '" << aty
+ << "'";
return failure();
}
@@ -266,8 +264,7 @@ LogicalResult verifyInst(WillowContext &ctx, const Instruction &inst,
if (BB && fn && (pred->getParent() != fn))
return emit(diags, Severity::Error, inst.getLoc())
- << "basic block: '" << pred->getName()
- << "' is not a child of function '" << fn->getName() << "'";
+ << "invalid predecessor";
}
return success();
@@ -344,14 +341,14 @@ LogicalResult verifyBinaryIntegerCmp(WillowContext &ctx,
}
LogicalResult verifyResult(const Instruction &inst, DiagnosticEngine &diags) {
- if (inst.hasName())
- return success();
+ if (!inst.getType().isVoid())
+ return success();
return emit(diags, Severity::Error, inst.getLoc()) << "expected ssa result";
}
LogicalResult verifyNoResult(const Instruction &inst, DiagnosticEngine &diags) {
- if (!inst.hasName())
+ if (inst.getType().isVoid())
return success();
return emit(diags, Severity::Error, inst.getLoc()) << "unexpected ssa result";
diff --git a/willow/unittest/BUILD.bazel b/willow/unittest/BUILD.bazel
index 8f1c55d..bfd0430 100644
--- a/willow/unittest/BUILD.bazel
+++ b/willow/unittest/BUILD.bazel
@@ -1,6 +1,6 @@
test_suite(
name = "unittest",
tests = [
- "//willow/unittest/ir:ir_tests",
+ "//willow/unittest/IR:ir_tests",
]
)
diff --git a/willow/unittest/ir/BUILD.bazel b/willow/unittest/IR/BUILD.bazel
index b41dfcd..b41dfcd 100644
--- a/willow/unittest/ir/BUILD.bazel
+++ b/willow/unittest/IR/BUILD.bazel
diff --git a/willow/unittest/IR/VerifierTest.cpp b/willow/unittest/IR/VerifierTest.cpp
new file mode 100644
index 0000000..efe34db
--- /dev/null
+++ b/willow/unittest/IR/VerifierTest.cpp
@@ -0,0 +1,82 @@
+#include <catch2/catch_test_macros.hpp>
+
+#include <willow/IR/Context.h>
+#include <willow/IR/Diagnostic.h>
+#include <willow/IR/DiagnosticEngine.h>
+#include <willow/IR/Function.h>
+#include <willow/IR/IRBuilder.h>
+#include <willow/IR/Module.h>
+#include <willow/IR/Verifier.h>
+
+using namespace willow;
+
+TEST_CASE("valid modules", "[verifier]") {
+ WillowContext ctx;
+ std::vector<Diagnostic> diags;
+ DiagnosticEngine eng([&](Diagnostic d) { diags.push_back(std::move(d)); });
+
+ auto &m = *ctx.addModule("test");
+ SECTION("empty module") {
+ REQUIRE(succeeded(verifyModule(ctx, m, eng)));
+ REQUIRE(diags.empty());
+ }
+}
+
+TEST_CASE("valid function", "[verifier]") {
+ WillowContext ctx;
+ std::vector<Diagnostic> diags;
+ DiagnosticEngine eng([&](Diagnostic d) { diags.push_back(std::move(d)); });
+
+ auto &m = *ctx.addModule("test");
+
+ Type fty = ctx.types().FunctionType(ctx.types().VoidType(), {});
+ auto &fn = *m.emplaceFunction("fn", &m, fty);
+
+ REQUIRE(succeeded(verifyFunction(ctx, fn, eng)));
+ REQUIRE(diags.empty());
+}
+
+TEST_CASE("invalid basic block", "[verifier]") {
+ WillowContext ctx;
+ std::vector<Diagnostic> diags;
+ DiagnosticEngine eng([&](Diagnostic d) { diags.push_back(std::move(d)); });
+
+ Type i64Ty = ctx.types().IntType(64);
+ Type voidTy = ctx.types().VoidType();
+ auto *one = ctx.constants().getInt(i64Ty, 1);
+
+ auto &m = *ctx.addModule("test");
+ Type fty = ctx.types().FunctionType(ctx.types().VoidType(), {});
+ Type ifty = ctx.types().FunctionType(i64Ty, {});
+ auto &fn = *m.emplaceFunction("fn", &m, fty);
+ auto &fn2 = *m.emplaceFunction("fn2", &m, ifty);
+ auto *bb = fn.addBlock(
+ std::make_unique<BasicBlock>(&fn, ctx.types().BasicBlockType()));
+ IRBuilder builder{ctx, *bb, bb->end()};
+
+ SECTION("Empty basic block") {
+ REQUIRE(bb->empty());
+ REQUIRE(failed(verifyBasicBlock(ctx, *bb, eng)));
+ }
+
+ SECTION("Basic block with no terminator") {
+ builder.BuildAdd(i64Ty, one, one);
+ REQUIRE(failed(verifyBasicBlock(ctx, *bb, eng)));
+ }
+
+ SECTION("Teminator must be the last instruction in a basic block") {
+ builder.BuildCall(i64Ty, &fn2);
+ builder.BuildAdd(i64Ty, one, one);
+ REQUIRE(failed(verifyBasicBlock(ctx, *bb, eng)));
+ }
+
+ SECTION("Basic block with invalid instruction") {
+ auto *bb2 = fn.addBlock(
+ std::make_unique<BasicBlock>(&fn, ctx.types().BasicBlockType()));
+ builder.SetInsertPoint(bb2->end());
+ builder.BuildAdd(voidTy, one, one);
+ builder.BuildRet();
+
+ REQUIRE(failed(verifyBasicBlock(ctx, *bb, eng)));
+ }
+}
diff --git a/willow/unittest/ir/VerifierTest.cpp b/willow/unittest/ir/VerifierTest.cpp
deleted file mode 100644
index 959d72a..0000000
--- a/willow/unittest/ir/VerifierTest.cpp
+++ /dev/null
@@ -1,51 +0,0 @@
-#include <catch2/catch_test_macros.hpp>
-
-#include <willow/IR/Context.h>
-#include <willow/IR/Module.h>
-#include <willow/IR/Verifier.h>
-#include <willow/IR/Diagnostic.h>
-#include <willow/IR/DiagnosticEngine.h>
-#include <willow/IR/Function.h>
-
-using namespace willow;
-
-TEST_CASE("valid modules", "[verifier]") {
- WillowContext ctx;
- std::vector<Diagnostic> diags;
- DiagnosticEngine eng(
- [&](Diagnostic d) { diags.push_back(std::move(d)); });
-
- auto &m = *ctx.addModule("test");
- SECTION("empty module") {
- REQUIRE(succeeded(verifyModule(ctx, m, eng)));
- REQUIRE(diags.empty());
- }
-}
-
-TEST_CASE("valid function", "[verifier]") {
- WillowContext ctx;
- std::vector<Diagnostic> diags;
- DiagnosticEngine eng(
- [&](Diagnostic d) { diags.push_back(std::move(d)); });
-
- auto &m = *ctx.addModule("test");
-
- Type fty = ctx.types().FunctionType(ctx.types().VoidType(), {});
- auto &fn = *m.emplaceFunction("fn", &m, fty);
-
- REQUIRE(succeeded(verifyFunction(ctx, fn, eng)));
- REQUIRE(diags.empty());
-}
-
-TEST_CASE("invalid basic block", "[verifier]") {
- WillowContext ctx;
- std::vector<Diagnostic> diags;
- DiagnosticEngine eng(
- [&](Diagnostic d) { diags.push_back(std::move(d)); });
-
- auto &m = *ctx.addModule("test");
-
- Type fty = ctx.types().FunctionType(ctx.types().VoidType(), {});
- auto &fn = *m.emplaceFunction("fn", &m, fty);
- // TODO
-}