11 #include "llvm/IR/Function.h"
12 #include "llvm/Pass.h"
13 #include "llvm/Support/raw_ostream.h"
14 #include "llvm/Support/FileSystem.h"
15 #include "llvm/Support/CommandLine.h"
16 #include "llvm/Analysis/LoopPass.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/IR/Instructions.h"
19 #include "llvm/IR/GetElementPtrTypeIterator.h"
20 #include "llvm/IR/GlobalVariable.h"
21 #include "llvm/IR/Module.h"
23 #if LLVM_VERSION_MAJOR >= 14
24 #include "llvm/Passes/PassPlugin.h"
25 #include "llvm/Passes/PassBuilder.h"
// Positional command-line option "in-tag-pairs": the path of the input file that holds
// tag-number / tag-string pairs.
47 cl::opt<std::string> inputTagPairs(
"in-tag-pairs", cl::Positional, cl::desc(
"<input file that contains tag number and string pairs>"));
// Command-line option "loop-tags": the loop tag names a DFG should be generated for.
// Where the value is consumed it is split on whitespace via stream extraction.
49 cl::opt<std::string> loopTags(
"loop-tags", cl::desc(
"Input a list of loop tag names to generate DFG for"));
// Cases of an opcode-mapping switch: each case returns the OpCode used for one LLVM
// instruction opcode (presumably the body of LLVMtoOp -- confirm against the full file).
// Floating-point arithmetic is not kept as floating point: every F* case prints a note to
// errs() and returns an integer OpCode.
59 case Instruction::FAdd : errs() <<
"Note: converting FAdd to integer op\n";
return OpCode::ADD;
61 case Instruction::FSub : errs() <<
"Note: converting FSub to integer op\n";
return OpCode::SUB;
// NOTE(review): FRem returns OpCode::DIV, the same code as FDiv, although a remainder is
// not a division -- confirm whether this is intentional.
70 case Instruction::FRem : errs() <<
"Note: converting FRem to integer op\n";
return OpCode::DIV;
73 case Instruction::FDiv : errs() <<
"Note: converting FDiv to integer op\n";
return OpCode::DIV;
75 case Instruction::FMul : errs() <<
"Note: converting FMul to integer op\n";
return OpCode::MUL;
78 case Instruction::GetElementPtr :
return OpCode::GEP;
// An opcode without a case is fatal: the offending value is printed, then the process aborts.
95 default: errs() <<
"could not look up:" << I <<
"\n"; std::abort();
// Name stem recorded for each value registered by the constructor, keyed by the value's address.
112 std::map<const Value*,std::string> instruction_unique_part = {};
// Counters keyed by category name ("input", "global", "const", "bb"); operator() looks one up
// for a value that has no name of its own.
113 std::map<std::string, int> other_category_counts = {};
// Fills instruction_unique_part with one name stem per instruction visited
// (presumably for every basic block of L -- confirm).
115 OpNameMaker(
const Loop* L) {
// Visit every instruction of the basic block bb.
117 for (
const auto& inst : *bb) {
// Start from the instruction's own LLVM name.
119 std::string name =
static_cast<std::string
>(inst.getName());
// An unnamed instruction gets the synthetic stem "i<k>_", where k is the number of entries
// already in the table.
120 if (name ==
"") { name =
'i' +
std::to_string(instruction_unique_part.size()) +
'_'; }
// emplace does not overwrite: if &inst were already present, the existing stem would be kept.
121 instruction_unique_part.emplace(&inst,std::move(name));
// Returns the name to use for val, with s_arg appended.
// A value registered by the constructor uses its stored stem. Any other value starts from its
// LLVM name, and a category is selected for it when that name is empty.
127 std::string operator()(
const Value* val, std::string s_arg =
"") {
// Registered value: stored stem + suffix.
130 const auto lookup_result = instruction_unique_part.find(val);
131 if (lookup_result != instruction_unique_part.end()) {
132 return lookup_result->second + std::move(s_arg);
// Not registered: try the value's own name.
136 std::string result =
static_cast<std::string
>(val->getName());
// Unnamed value: choose a category. "input" is kept when none of the checks below match.
138 if (result.empty()) {
139 std::string category_name =
"input";
// The checks form an if / else-if chain, so the first one that matches decides the category.
142 if (isa<GlobalValue>(val)) { category_name =
"global"; }
143 else if (isa<Constant>(val)) { category_name =
"const"; }
144 else if (isa<BasicBlock>(val)) { category_name =
"bb"; }
// operator[] creates the counter (value-initialised to 0) the first time a category is seen.
146 auto& category_count = other_category_counts[category_name];
150 return result + s_arg;
// Overload for C-string suffixes: wraps s_arg in a std::string and forwards to the
// std::string overload.
154 std::string operator()(
const Value* inst,
const char* s_arg) {
155 return (*
this)(inst, std::string(s_arg));
// Variadic overload taking a std::string suffix followed by one or more extra arguments.
// The extra arguments are taken by lvalue reference (Arg1&, Args&...), so temporaries cannot
// be passed for them.
159 template<
typename Arg1,
typename... Args>
160 std::string operator()(
const Value* inst,
const std::string& s_arg, Arg1& arg1, Args&... args) {
// Same as above, but with a C-string suffix.
165 template<
typename Arg1,
typename... Args>
166 std::string operator()(
const Value* inst,
const char* s_arg, Arg1& arg1, Args&... args) {
// Variadic overload whose first argument after the value is not a string.
171 template<
typename Arg1,
typename... Args>
172 std::string operator()(
const Value* inst, Arg1& arg1, Args&... args) {
// Returns the operand-type label (a string) for operand number op_num of op.
177 std::string operandTypeFor(
const OpGraphOp* op,
int op_num) {
// Ops whose opcode is in commutative_ops are handled first.
178 if (commutative_ops.find(op->
opcode) != commutative_ops.end()) {
188 else if (op_num == 1) {
// Operands 0, 1 and 2 are labelled "branch_cond", "branch_true" and "branch_false".
197 case 0:
return "branch_cond";
198 case 1:
return "branch_true";
199 case 2:
return "branch_false";
202 else if (op_num == 0) {
205 else if (op_num == 1) {
// Nothing matched: report it and abort. The message is written without a trailing newline.
209 errs() <<
"Unhandled case for operand setting"; std::abort();
// Tries to reduce val to an integer constant.
// For a ConstantInt it returns {true, zero-extended value}. Call sites read .first as the
// success flag and .second as the value.
216 std::pair<bool, uint64_t> try_extract_integral_constant(
const Value& val) {
// Only Constants are considered.
217 if(
const auto& as_constant = dyn_cast<Constant>(&val)) {
// Look through pointer casts before classifying the constant.
218 const auto& stripped_const = as_constant->stripPointerCasts();
// Plain integer constant: success.
220 if (
const auto& as_const_int = dyn_cast<ConstantInt>(stripped_const)) {
221 return {
true, as_const_int->getZExtValue()};
// Global variable: its initializer, when present, is examined recursively.
223 }
else if (
const auto& as_global_var = dyn_cast<GlobalVariable>(stripped_const)) {
224 if (as_global_var->hasInitializer()) {
// NOTE(review): the result of this recursive call is discarded (there is no `return`), so a
// value found in the initializer never reaches the caller -- confirm whether that is intended.
225 try_extract_integral_constant(*as_global_var->getInitializer());
// Constant expression: only casts are looked through, by recursing on the cast's operand 0.
227 }
else if (
const auto& as_constant_expr = dyn_cast<ConstantExpr>(stripped_const)) {
228 if (as_constant_expr->isCast()) {
229 return try_extract_integral_constant(*stripped_const->getOperand(0));
231 errs() <<
"Warning: unable to extract a value from a ConstantExpr: " << *stripped_const <<
"\n";
234 errs() <<
"Warning: unable to extract a value from a Constant: " << *stripped_const <<
"\n";
// Collection of OpGraph pointers.
241 std::vector<OpGraph*> graphs = {};
// Tag number -> tag string, filled from the lines of the in-tag-pairs file.
242 std::map<unsigned int, std::string> tag_pairs = {};
// Tag names taken from the loop-tags option; runOnLoop searches this list for a loop's tag string.
243 std::vector<std::string> loop_tags = {};
// Loads tag_pairs and loop_tags from the command-line options when a tag-pair file was given.
246 if(!inputTagPairs.empty())
// NOTE(review): the stream is not checked for a failed open. If the file cannot be read the
// loop below does not execute and tag_pairs stays empty, without a diagnostic.
248 std::ifstream in(inputTagPairs);
// Each line supplies a tag number followed by a tag string, separated by whitespace.
249 for(std::string line; std::getline(in, line); )
251 std::stringstream temp_sstream(line);
253 temp_sstream >> temp_tag_num;
254 std::string temp_tag_string;
255 temp_sstream >> temp_tag_string;
// emplace keeps the first entry when a tag number is repeated.
256 tag_pairs.emplace(temp_tag_num, std::move(temp_tag_string));
// The loop-tags option is split on whitespace into loop_tags.
258 std::stringstream temp_sstream(loopTags);
259 std::string temp_tag_string;
260 while(temp_sstream >> temp_tag_string)
261 loop_tags.push_back(temp_tag_string);
// Warning for the case where no tag-pair input is available (presumably the else branch of the
// check at the top -- confirm).
264 errs() <<
"Warning: No tag pair is provided as input, no DFG will be generated." <<
"\n";
// Processes one loop: looks for a DFGLOOP_TAG call in the loop header, resolves its tag number
// to a tag name, builds DFG ops and links for the loop's instructions, lowers GEP ops, ranks
// the ops, and opens "graph_<tag name>.dot" for output.
267 bool runOnLoop(Loop* L)
// Scan the loop header for a call to the function named DFGLOOP_TAG.
270 int found_tag_num = 0;
271 Instruction* the_loop_tag =
nullptr;
272 for(
auto it = L->getHeader()->begin(); it != L->getHeader()->end(); ++it)
274 if(isa<CallInst>(it))
// getCalledFunction() is null for an indirect call; such calls are skipped by the check below.
276 Function * func = dyn_cast<CallInst>(it)->getCalledFunction();
277 if((func !=
nullptr) && (func->getName() ==
"DFGLOOP_TAG")) {
// The tag number is the call's first argument.
// NOTE(review): cast<ConstantInt> asserts when the argument is not a constant integer, and the
// 64-bit getZExtValue() result is narrowed to int.
278 found_tag_num = cast<ConstantInt>(dyn_cast<CallInst>(it)->getArgOperand(0))->getValue().getZExtValue();
// Translate the tag number into its tag string; a number missing from tag_pairs is reported.
286 std::string tag_name;
288 auto tag_pairs_it = tag_pairs.find(found_tag_num);
289 if(tag_pairs_it == tag_pairs.end())
291 errs() <<
"Error: Tag could not be found from the generated script, ignoring this loop." <<
"\n";
// The tag string is then searched for in loop_tags.
296 auto tag_string_it = std::find(loop_tags.begin(), loop_tags.end(), tag_pairs_it->second);
297 if(tag_string_it == loop_tags.end())
302 tag_name = *tag_string_it;
// Shape checks: whether the loop has sub-loops, and how many basic blocks it has. Per the
// message below, an unsupported loop is reported and ignored.
305 if (!L->getSubLoops().empty())
308 unsigned int bb_count = L->getBlocks().size();
311 errs() <<
"Error: Loop with tag: " << tag_name <<
" is not supported. This loop is ignored." <<
"\n";
// Name generator for the ops of this loop.
316 OpNameMaker makeOpName {L};
// Step 1: create a DFG op for every instruction of the loop except the tag call, and handle
// operands that do not yet have an op. The result is used below as a map from LLVM value to op.
323 auto llvm_value_to_dfg_op = [&]() {
324 std::map<const Value*, OpGraph::OpDescriptor> result;
326 for(
const auto& bb : L->blocks()) {
327 for(
const auto& inst : *bb) {
328 if (&inst == the_loop_tag) {
continue; }
// The op name is the instruction's name stem plus its opcode name; 32 is passed as the second
// argument (presumably a bit width -- confirm).
331 result[&inst] = og.
emplace(makeOpName(&inst, inst.getOpcodeName()), 32, LLVMtoOp(inst));
334 for (
const auto& operand : inst.operands()) {
335 const auto& operand_as_value = *operand;
// Operand already has an op.
338 if (result.find(&operand_as_value) != result.end()) {
continue; }
// Operands that are instructions inside the loop are skipped here; the outer loops create an
// op for every instruction of the loop.
341 const auto* operand_as_inst_ptr = dyn_cast<Instruction>(&operand_as_value);
342 if (operand_as_inst_ptr && L->contains(operand_as_inst_ptr)) {
continue; }
347 const auto op_name = makeOpName(operand);
// An operand that reduces to an integer constant becomes a CONST op carrying that value.
// NOTE(review): op_name is const, so the std::move below still copies.
350 const auto success_and_value = try_extract_integral_constant(*operand);
351 if (success_and_value.first) {
352 return {std::move(op_name), 32,
OpCode::CONST, (std::int64_t)success_and_value.second};
// Step 2: link each operand's op to the op of the instruction that uses it.
365 for(
const auto& bb : L->blocks()) {
366 for(
const auto& inst : *bb) {
367 if (&inst == the_loop_tag) {
continue; }
368 const auto& op_for_inst = llvm_value_to_dfg_op.at(&inst);
370 int operand_num = -1;
371 for (
const auto& operand : inst.operands()) {
373 const auto& operand_as_value = *operand;
374 const auto& operand_op = llvm_value_to_dfg_op.at(&operand_as_value);
// The edge is labelled with the operand type computed by operandTypeFor.
375 og.
link(operand_op, op_for_inst, operandTypeFor(op_for_inst, operand_num));
// True when some user of inst is not an instruction, or is an instruction outside the loop.
380 if(std::any_of(inst.user_begin(), inst.user_end(), [&](
const auto& user) {
381 const auto* user_as_instruction_ptr = dyn_cast<Instruction>(&*user);
382 return not user_as_instruction_ptr || not L->contains(user_as_instruction_ptr);
// Reverse map: DFG op -> LLVM value.
393 auto dfg_op_to_llvm_value = [&]() {
394 std::map<OpGraph::OpDescriptor, const Value*> result;
395 for (
const auto& value_and_op : llvm_value_to_dfg_op) {
396 result.emplace(value_and_op.second, value_and_op.first);
// Step 3: lower GEP ops. The GEP op nodes are collected first, then each one is processed.
402 const auto orig_geps =
filter_collection(og.
opNodes(), [&og](
auto&& op) { return og.getNodeRef(op).getOpCode() == OpGraphOpCode::GEP; });
403 for (
const auto& gep_op : orig_geps) {
// dyn_cast yields null when the mapped value is not a GetElementPtrInst; the warning below
// covers that case.
404 const auto inst_ptr = dyn_cast<GetElementPtrInst>(dfg_op_to_llvm_value.at(gep_op));
406 errs() <<
"Warning: Instruction " << *dfg_op_to_llvm_value.at(gep_op)
407 <<
" for GEP op node " << og.
getNodeRef(gep_op).
getName() <<
" is not a GEP. Will not lower";
410 const auto& inst = *inst_ptr;
// Walk the operands after operand 0 in step with the GEP type iterator; gep_operand_num counts
// from 1 and is used as a suffix in the new op names.
416 {gep_type_iterator GTI = gep_type_begin(&inst);
417 int gep_operand_num = 1;
418 for(
auto operand_it = std::next(inst.op_begin()); operand_it != inst.op_end(); ++operand_it, ++GTI, ++gep_operand_num) {
// Per index: a CONST op holding the alloc size of the indexed type, a MUL op and an ADD op.
419 const auto data_size = inst.getParent()->getModule()->getDataLayout().getTypeAllocSize(GTI.getIndexedType());
420 const auto data_size_op = og.
emplace(makeOpName(&inst,
"data_size", gep_operand_num), 32,
OpCode::CONST, data_size);
421 const auto mult_by_size = og.
emplace(makeOpName(&inst,
"mul", gep_operand_num), 32,
OpCode::MUL);
422 const auto add_to_prev = og.
emplace(makeOpName(&inst,
"add", gep_operand_num), 32,
OpCode::ADD);
// Iterate over the edges leaving the original GEP op.
431 for (
const auto& edge : og.
outEdges(gep_op)) {
// Both maps are updated so that the GEP instruction now corresponds to the op `tip`.
436 llvm_value_to_dfg_op[inst_ptr] = tip;
437 dfg_op_to_llvm_value[tip] = inst_ptr;
// Step 4: assign each op an integer rank; per the variable name it is used for print ordering.
454 const auto op_print_ranking = [&]() {
455 std::map<OpGraph::OpDescriptor, int> result;
457 for (
const auto& inst : *bb) {
// Instructions without an op are skipped.
458 const auto lookup_result = llvm_value_to_dfg_op.find(&inst);
459 if (lookup_result == llvm_value_to_dfg_op.end()) {
continue; }
// Stack-based walk over the inputs of this instruction's op.
462 std::vector<OpGraph::OpDescriptor> to_visit = og.
inputOps(lookup_result->second);
463 std::set<OpGraph::OpDescriptor> visited;
464 std::vector<OpGraph::OpDescriptor> dfs_order;
465 while (not to_visit.empty()) {
466 const auto curr = to_visit.back();
// Ops that were already visited are not processed again.
468 if (not visited.insert(curr).second) {
continue; }
// An op that maps back to an instruction inside the loop is not expanded or recorded.
469 const auto& inst_lookup = dfg_op_to_llvm_value.find(curr);
470 if (inst_lookup != dfg_op_to_llvm_value.end()) {
471 if (
const auto inst = dyn_cast<Instruction>(inst_lookup->second)) {
472 if (L->contains(inst)) {
continue; }
475 const auto& inputs = og.
inputOps(curr);
476 std::copy(inputs.begin(), inputs.end(), std::back_inserter(to_visit));
477 dfs_order.push_back(curr);
// Ranks are handed out in reverse visiting order; insert leaves an already-ranked op unchanged.
481 for (
auto it = dfs_order.rbegin(); it != dfs_order.rend(); ++it) {
482 result.insert({*it,
static_cast<int>(result.size())});
// The instruction's own op is ranked after the ops collected for it.
486 result.insert({lookup_result->second,
static_cast<int>(result.size())});
// The output file name is derived from the tag name.
494 std::ofstream f(
"graph_" + tag_name +
".dot", std::ios::out);
// Legacy pass-manager wrapper: a LoopPass that forwards each loop to a DfgOutImpl member.
501 struct LegacyDFGOut :
public LoopPass
504 DfgOutImpl impl = {};
506 LegacyDFGOut() : LoopPass(ID) {}
// Delegates to impl.runOnLoop and returns its result; the LPPassManager argument is unused.
508 virtual bool runOnLoop(Loop* L, LPPassManager&) {
509 return impl.runOnLoop(L);
512 virtual bool doFinalization()
// Pass identifier handed to the LoopPass constructor; LLVM identifies a pass by the address of
// its ID, so the initial value is not significant.
521 char LegacyDFGOut::ID = 0;
// Registers the legacy pass under the command-line name "dfg-out".
522 static RegisterPass<LegacyDFGOut>
X(
"dfg-out",
"DFG(Data Flow Graph) Output Pass",
527 #if LLVM_VERSION_MAJOR >= 14
// New pass-manager wrapper around DfgOutImpl.
529 struct DfgOut : PassInfoMixin<DfgOut> {
530 DfgOutImpl impl = {};
// Runs the implementation on L. Reports no analyses preserved when impl.runOnLoop returns true,
// and all analyses preserved otherwise. The other parameters are unused.
531 PreservedAnalyses run(Loop& L, LoopAnalysisManager& , LoopStandardAnalysisResults& , LPMUpdater&) {
532 return impl.runOnLoop(&L) ? PreservedAnalyses::none() : PreservedAnalyses::all();
// Pass-plugin entry point: returns the plugin description, including the callback that
// registers this plugin's passes with a PassBuilder.
537 extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo llvmGetPassPluginInfo() {
539 LLVM_PLUGIN_API_VERSION,
542 [](PassBuilder& PB) {
// A loop-pipeline element named "dfg-out" adds a DfgOut pass to the loop pass manager.
543 PB.registerPipelineParsingCallback([](StringRef Name, LoopPassManager& PM, ArrayRef<PassBuilder::PipelineElement>) {
544 if (Name ==
"dfg-out") {
545 PM.addPass(DfgOut());