// Table of (name, opcode) pairs. Other code in this file iterates it and
// binary-searches it by name.
32 const auto opcode_names = []{
33 std::vector<std::pair<std::string, OpGraphOpCode>> result {
// Default pair ordering: entries are ordered by name first, then by opcode.
83 std::sort(result.begin(), result.end());
// Reverse lookup table: maps each opcode to a name by swapping every
// (name, opcode) entry of opcode_names.
90 const auto opcode_to_name = []{
91 std::map<OpGraphOpCode, std::string> result;
92 for (
const auto& name_and_code : opcode_names) {
// emplace does not overwrite: if several names share an opcode, the first one
// visited in opcode_names order is the one kept.
93 result.emplace(name_and_code.second, name_and_code.first);
// Print the name registered for `opcode` in opcode_to_name.
101 const auto lookup_result = opcode_to_name.find(opcode);
102 if (lookup_result != opcode_to_name.end()) {
103 return os << lookup_result->second;
// No name is registered for this value: raise a cgrame_error whose message is
// the enclosing function name followed by the numeric opcode value.
106 throw make_from_stream<cgrame_error>([&,fname=__func__](
auto&& s) {
// The leading space keeps the function name from running into the text;
// unary + makes the opcode print as a number.
107 s << fname <<
" can't print opcode with value " << +opcode;
// Case-insensitive "less than" between the name held in a (name, value) pair
// and a plain string; usable as the comparator of std::lower_bound over a
// table of such pairs.
//   lhs_elem: pair-like object whose .first is the name to compare
//   rhs_val:  string to compare against
// Returns true when lhs_elem.first orders strictly before rhs_val once both
// are lower-cased; a proper prefix orders before the longer string.
const auto name_comp_less = [](auto&& lhs_elem, auto& rhs_val) {
    // std::tolower has undefined behaviour for negative char values, so each
    // character is converted to unsigned char before the call.
    const auto lowered = [](const char c) {
        return std::tolower(static_cast<unsigned char>(c));
    };

    auto first1 = lhs_elem.first.begin();
    auto first2 = rhs_val.begin();
    const auto last1 = lhs_elem.first.end();
    const auto last2 = rhs_val.end();

    // Lexicographical comparison: the first differing position decides.
    for ( ; (first1 != last1) && (first2 != last2); ++first1, (void) ++first2 ) {
        const auto c1 = lowered(*first1);
        const auto c2 = lowered(*first2);
        if (c1 < c2) { return true; }
        if (c2 < c1) { return false; }
    }

    // Common prefix exhausted: lhs is smaller only if it is the shorter one.
    return (first1 == last1) && (first2 != last2);
};
// Binary-search the name table for the first entry whose name does not order
// before `str` under name_comp_less.
// NOTE(review): opcode_names is sorted with the default (case-sensitive) pair
// ordering while this search uses name_comp_less; the two agree only if the
// stored names are consistently cased — confirm against the table contents.
129 const auto search_result = std::lower_bound(opcode_names.begin(), opcode_names.end(), str, name_comp_less);
// lower_bound only guarantees "not less than `str`"; the second test rejects a
// hit that orders after `str`, i.e. anything that is not an equivalent name.
131 if (search_result == opcode_names.end() || name_comp_less(std::make_pair(str,0), search_result->first)) {
132 throw make_from_stream<std::logic_error>([&,fname=__func__](
auto&& s) {
133 s << fname <<
" can't convert '" << str <<
"' into an opcode";
// Equivalent name found: return the opcode stored next to it.
136 return search_result->second;
161 make_and_throw<std::logic_error>([&](
auto& s) {
162 s <<
"Trying to create a op graph const without initializing a value";
166 make_and_throw<std::logic_error>([&](
auto& s) {
167 s <<
"Trying to create a op graph cmp without initializing a mode";
174 , bitwidth(_bitwidth)
176 , bitConfig(_bitConfig)
180 make_and_throw<std::logic_error>([&](
auto& s) {
181 s <<
"Trying to create a op graph const without initializing a value";
185 make_and_throw<std::logic_error>([&](
auto& s) {
186 s <<
"Trying to create a op graph cmp without initializing a mode";
199 make_and_throw<std::logic_error>([&](
auto& s) {
200 s <<
"Trying to assign a value to an op that is not const";
217 make_and_throw<std::logic_error>([&](
auto& s) {
218 s <<
"Trying to assign a value to an op that is not cmp or memory access";
236 if (*(
const OpGraphNode*)
this != rhs) {
return false; }
263 this->
input = input_op;
271 const auto ival_output_pos = std::find(
output.begin(),
output.end(), op);
272 const auto dist = std::distance(
output.begin(), ival_output_pos);
277 const auto ival_output_pos = std::find(
output.begin(),
output.end(), op);
278 const auto dist = std::distance(
output.begin(), ival_output_pos);
321 std::map<OpDescriptor, OpDescriptor> src_to_this;
322 for (
const auto& src_op : src.
opNodes()) {
325 for (
const auto& src_op : src.
opNodes()) {
326 const auto src_val = src.
outputVal(src_op);
327 int output_index = -1;
328 for (
const auto& sink_op : src.
outputOps(src_val)) {
330 link(src_to_this.at(src_op), src_to_this.at(sink_op),
331 src.
getNodeRef(src_val).output_operand.at(output_index),
342 make_and_throw<std::logic_error>([&](
auto& s) {
343 s <<
"Ops should have unique names. An op with the name " << op->
getName() <<
" already exists";
354 std::string operand_group,
358 auto fanout_op =
insert(std::move(fanout_to_move_in));
359 auto val =
link(driver, fanout_op, operand_group, bitwidth, dist, kind);
361 return { val, fanout_op };
366 std::string operand_group,
385 make_and_throw<std::logic_error>([&](
auto& s) {
386 s <<
"Vals should have unique names. A val with the name " << val->
getName() <<
" already exists";
399 driver_val_ref.output_operand.emplace_back(operand_group);
400 driver_val_ref.output_predicate.emplace_back(predicate);
412 driver_val_ref.output_operand.emplace_back(operand_group);
419 return link(driver, fanout, getOperandTag(base), getBitwidth(base), getDist(base), getKind(base));
// Locate `fanout` among the ops fed by the driving value; it is an error to
// unlink a pair that is not connected.
425 const auto ival_output_pos = std::find(ival.output.begin(), ival.output.end(),
fanout);
426 if (ival_output_pos == ival.output.end()) {
427 throw make_from_stream<std::logic_error>([&](
auto&& s) {
428 s <<
"trying to unlink " << driver_val <<
" -> "
429 <<
fanout <<
" but the latter op is not a fanout of the former val";
// Remove the value from the fanout op's input list.
434 auto& fo_ins = oval.input;
435 const auto oval_input_pos = std::find(fo_ins.begin(), fo_ins.end(), driver_val);
437 fo_ins.erase(oval_input_pos);
// The index must be taken before the erase below invalidates the iterator; it
// is then used to drop the matching entry of the parallel output_operand list.
// NOTE(review): no matching erase of output_predicate is visible here although
// other code appends to it alongside output_operand — confirm.
440 const auto dist = std::distance(ival.output.begin(), ival_output_pos);
441 ival.output.erase(ival_output_pos);
442 ival.output_operand.erase(std::next(ival.output_operand.begin(), dist));
// The value has no remaining fanout.
444 if (ival.output.empty()) {
446 const auto driver = ival.input;
454 const auto pos_in_op_nodes = std::find(
op_nodes.begin(),
op_nodes.end(), op);
455 if (pos_in_op_nodes ==
op_nodes.end()) {
456 throw make_from_stream<std::logic_error>([&](
auto&& s) {
457 s <<
"trying to erase op " << op <<
" which is not part of the OpGraph " <<
this;
461 const auto input_vals_copy =
inputVals(op);
462 for (
auto& ival : input_vals_copy) {
466 const auto output_ops_copy =
outputOps(op);
467 for (
auto& oop : output_ops_copy) {
477 if (not op_desc) {
return null_val; }
479 if (op.output) {
return op.output; }
484 if (not op_desc) {
return {}; }
498 std::vector<OpDescriptor> result;
499 const auto& ivals = inputVals(op);
500 std::transform(ivals.begin(), ivals.end(), std::back_inserter(result), [
this](
const auto& ival) { return this->inputOp(ival); });
505 if (not val) {
return null_op; }
506 return getNodeRef(val).
input;
526 if (not ed) {
return null_op; }
527 return getNodeRef(ed.val).
output.at(ed.output_index);
531 std::vector<EdgeDescriptor> result;
532 for (
const auto& out_val : outputVals(op)) {
533 for (
int i = 0; i < (std::ptrdiff_t)outputOps(out_val).size(); ++i) {
534 result.push_back({out_val, i});
541 std::vector<EdgeDescriptor> result;
542 for (
const auto& target_ival : inputVals(op)) {
543 for (
const auto& edge_to_target : outEdges(inputOp(target_ival))) {
544 if (targetOfEdge(edge_to_target) != op) {
continue; }
545 result.push_back(edge_to_target);
578 unsigned int num_scheduled = 0;
582 if(v->input.size() == 0)
584 schedule.
cycle[v] = 0;
590 schedule.
cycle[v] = -1;
596 while(num_scheduled < V.size())
602 for(
int i = 0; v && i < (std::ptrdiff_t)v->input.size(); i++)
605 std::cout <<
"Path latency (" << *(in->
input) <<
", " << *v <<
"): " <<
"\n";
606 max = std::max(max, schedule.
cycle.at(in->
input) );
616 schedule.
cycle.at(v) = max;
619 max_latency = std::max(max, max_latency);
624 schedule.
latency = max_latency;
635 unsigned int num_scheduled = 0;
640 if(v->output ==
nullptr)
642 schedule.
cycle[v] = max_cycles;
648 schedule.
cycle[v] = -1;
652 while(num_scheduled < V.size())
658 for(
int i = 0; v && i < (std::ptrdiff_t)v->output->output.size(); i++)
663 std::cout <<
"Path latency (" << *v <<
", " << *succ <<
"): " <<
"\n";
664 min = std::min(min, schedule.
cycle.at(succ) );
666 if(schedule.
cycle.at(succ) == -1)
675 schedule.
cycle.at(v) = min;
686 std::vector<VerifyMessage> result;
693 (void)add_warning; (void)add_info;
696 for(
auto op : this->
opNodes()) {
698 add_error([&](
auto&& s) { s <<
"found null op\n"; });
703 add_error([&](
auto&& s) { s <<
"found null val\n"; });
705 const auto expected_size = val->output.size();
706 if (val->output_operand.size() != expected_size) {
707 add_error([&](
auto&& s) { s <<
"operand index list size does not match number of outputs\n"; });
712 for(
auto temp_val : this->
valNodes())
714 if (not temp_val) {
continue; }
716 auto& output = temp_val->output;
717 for(
auto temp_op : output)
719 if (not temp_op) {
continue; }
721 auto temp_it = std::find(temp_op->input.begin(), temp_op->input.end(), temp_val);
722 if (temp_it == temp_op->input.end()) {
723 add_error([&](
auto&& s) { s <<
"expected to find " << *temp_val <<
" in input list of " << *temp_op; });
728 for(
auto temp_op : this->
opNodes())
730 if (not temp_op) {
continue; }
732 auto& input = temp_op->input;
733 for(
auto temp_val : input)
735 if (not temp_val) {
continue; }
737 auto temp_it = std::find(temp_val->output.begin(), temp_val->output.end(), temp_op);
738 if (temp_it == temp_val->output.end()) {
739 add_error([&](
auto&& s) { s <<
"expected to find " << *temp_op <<
" in output list of " << *temp_val; });
752 throw cgrame_error (
"Operation not found within the opgraph");
776 throw cgrame_error (
"Index is larger than the number of operations");
784 throw cgrame_error (
"Index is larger than the number of operations");
791 const auto verify_output = opgraph.
verify();
793 if (throw_if_errors && found_errors) {
794 make_and_throw<cgrame_error>([&](
auto&& s) { s <<
"OpGraph verify check failed! check stdout for results"; });
796 return not found_errors;
// Summarise a list of verify messages on `os` and report whether any of them
// is an error.
803 const auto contains_error = std::any_of(
begin(messages),
end(messages), [](
auto&& msg) {
return msg.type == Type::Error; });
804 const auto has_messages = !messages.empty();
// Output is produced when there is an error, or when the caller did not ask
// for silence in the error-free case.
805 const auto print_all_messages = contains_error || not silent_on_no_errors;
807 if (print_all_messages) {
// Headline: overall verdict.
808 if (contains_error) {
809 os <<
"OpGraph verify FAILED";
811 os <<
"OpGraph verify passed";
// Report body: one "<type>: <message>" line per entry.
815 os <<
". Begin report:\n";
817 for (
auto& msg : messages) {
818 os << msg.type <<
": " << msg.message <<
'\n';
821 os <<
"End OpGraph verify report\n";
// Alternative ending for the headline (presumably taken when has_messages is
// false — the selecting condition is not shown here).
823 os <<
", and nothing to report.\n";
// True when at least one message has type Error.
828 return contains_error;
// Recursive depth-first visit used to measure cycle lengths.
//   time           timestamp recorded for `op`; passed by value, so changes made
//                  in this call are not seen by the caller
//   longest_cycle  in/out: raised whenever a longer cycle is measured
//   dfs_colour     per-op visit state (0 triggers a recursive visit, 1 triggers
//                  a cycle measurement; presumably unvisited / on the current
//                  DFS path — confirm)
//   dfs_timestamp  per-op timestamp recorded on entry
//   op             op being visited
834 static void dfs_visit(
int time ,
int & longest_cycle, std::map<OpGraphOp*, int> & dfs_colour, std::map<OpGraphOp*, int> & dfs_timestamp,
OpGraphOp* op)
839 dfs_timestamp[op] = time;
// State 0: descend into n.
851 if(dfs_colour[n] == 0)
853 dfs_visit(time, longest_cycle, dfs_colour, dfs_timestamp, n);
// State 1: the cycle size is taken as the timestamp difference between op
// and n, plus one.
855 else if(dfs_colour[n] == 1)
857 int size = dfs_timestamp[op] - dfs_timestamp[n] + 1;
859 if(size > longest_cycle)
860 longest_cycle = size;
871 std::map<OpGraphOp*, int> dfs_colour;
872 std::map<OpGraphOp*, int> dfs_timestamp;
884 if(dfs_colour[op] == 0)
887 dfs_visit(0, result, dfs_colour, dfs_timestamp, op);
895 std::reverse(cyclic_trail.begin(), cyclic_trail.end());
896 auto new_begin = std::min_element(cyclic_trail.begin(), cyclic_trail.end(), [&](
const auto& lhs,
const auto& rhs) {
897 return opgraph.getNodeRef(lhs.val).getName() < opgraph.getNodeRef(rhs.val).getName();
899 std::rotate(cyclic_trail.begin(), new_begin, cyclic_trail.end());
908 std::map<EdgeDescriptor,int> result;
909 for (
const auto& op : this->
opNodes()) {
910 for (
const auto& e : this->
outEdges(op)) {
915 std::set<std::vector<EdgeDescriptor>> cycles_found;
918 std::set<OpDescriptor> ops_visited;
922 std::set<EdgeDescriptor> nte_and_convergence_edges_found;
931 if (ops_visited.find(op) != ops_visited.end()) {
continue; }
935 {std::set<OpDescriptor> ancestors;
938 if (
fanin.empty() || not ancestors.insert(op).second) {
break; }
943 const OpGraph* opgraph =
nullptr;
944 std::set<OpDescriptor> examined_op_targets = {};
945 std::set<EdgeDescriptor> non_tree_edges = {};
949 if (not examined_op_targets.insert(opgraph->
targetOfEdge(e)).second) {
954 visitor.opgraph =
this;
955 visitor.examined_op_targets.insert(op);
963 for (
const auto nte : visitor.non_tree_edges) {
965 if (not nte_and_convergence_edges_found.insert(nte).second) {
continue; }
969 const auto reverse_search_tree = galgos.wavedBreadthFirstVisit(
970 makeReversedGraphFromFaninLists<EdgeDescriptor>(&search_tree), {nte},
974 if (found) { cycle_edge = e; }
979 auto traceback = galgos.singleTraceback(
singleItemSet(nte), cycle_edge, reverse_search_tree);
980 if (not traceback.success) {
throw make_from_stream<cgrame_error>([fname=__func__](
auto&& s) {
981 s << fname <<
" couldn't find original non-tree edge when searching a reverse search tree";
983 if (nte == cycle_edge) {
984 traceback.path.pop_back();
987 if (cycles_found.insert(reverse_and_canonicalize(std::move(traceback.path), *
this)).second) {
994 std::copy(visitor.examined_op_targets.begin(), visitor.examined_op_targets.end(), std::inserter(ops_visited, ops_visited.end()));
1003 unsigned int counter = 0;
1004 s <<
"digraph G {\n";
1008 std::map<OpGraphOp*, std::string> opnode_map;
1013 s << opnode_map[(*it)] <<
"[opcode=" << (*it)->opcode;
1015 s <<
", constVal=" << (*it)->constVal;
1017 s <<
", cmpMode=" << (*it)->cmpMode;
1019 s <<
", memName=" << (*it)->memName;
1028 std::string inputnode = opnode_map[(*it)->input];
1029 for(
unsigned int o = 0; o < (*it)->output.size(); o++)
1033 s << inputnode <<
"->" << opnode_map[op] <<
"[operand=" << operand <<
"]; ";
1034 s <<
"//" << (*it)->input->name <<
"->" << op->
name <<
"\n";
1042 unsigned int counter = 0;
1043 s <<
"digraph G {\n";
1047 std::map<OpGraphOp*, std::string> opnode_map;
1052 s << opnode_map[(*it)] <<
"[opcode=" << (*it)->opcode <<
"];\n";
1058 std::string inputnode = opnode_map[(*it)->input];
1059 for(
unsigned int o = 0; o < (*it)->output.size(); o++)
1063 s << inputnode <<
"->" << opnode_map[op] <<
"[operand=" << operand <<
"]; ";
1064 s <<
"//" << (*it)->input->name <<
"->" << op->
name <<
"\n";
1095 const auto lookup = op_print_ranking.find(op);
1096 if (lookup == op_print_ranking.end()) {
1097 return std::numeric_limits<int>::max();
1099 return lookup->second;
1104 auto sorted_op_nodes =
opNodes();
1105 std::stable_sort(sorted_op_nodes.begin(), sorted_op_nodes.end(), [&,
this](
const auto& lhs,
const auto& rhs) {
1106 return get_ranking(lhs) < get_ranking(rhs);
1109 s <<
"digraph G {\n";
1111 for (
const auto& op_desc : sorted_op_nodes) {
1123 for (
const auto& src_op_desc : sorted_op_nodes) {
1124 const auto& src_op =
getNodeRef(src_op_desc);
1127 auto sorted_out_edges =
outEdges(src_op_desc);
1128 std::stable_sort(sorted_out_edges.begin(), sorted_out_edges.end(), [&,
this](
const auto& lhs,
const auto& rhs) {
1129 const auto& lhs_rank = get_ranking(this->targetOfEdge(lhs));
1130 const auto& rhs_rank = get_ranking(this->targetOfEdge(rhs));
1131 if (lhs_rank == rhs_rank) {
1132 return lhs.output_index < rhs.output_index;
1134 return lhs_rank < rhs_rank;
1138 for (
const auto& out_edge : sorted_out_edges) {
1141 const auto& operand_tag = val.output_operand.at(out_edge.output_index);
1142 const bool has_operand_tag = not operand_tag.empty();
1143 const bool has_attributes = has_operand_tag;
1145 if (has_attributes) {
1147 if (has_operand_tag) {
1160 if (&lhs == &rhs) {
return true; }
1164 for (
const auto& lhs_name_and_opdesc : lhs.
ops_by_name) {
1166 const auto rhs_search_result = rhs.
ops_by_name.find(lhs_name_and_opdesc.first);
1167 if (rhs_search_result == rhs.
ops_by_name.end()) {
return false; }
1170 const auto& lhs_node = lhs.
getNodeRef(lhs_name_and_opdesc.second);
1171 const auto& rhs_node = rhs.
getNodeRef(rhs_search_result->second);
1172 if (not (lhs_node == rhs_node)) {
return false; }
1175 const auto& lhs_out_edges = lhs.
outEdges(lhs_name_and_opdesc.second);
1176 const auto& rhs_out_edges = rhs.
outEdges(rhs_search_result->second);
1177 const auto& mismatch_result = std::mismatch(
1178 lhs_out_edges.begin(), lhs_out_edges.end(),
1179 rhs_out_edges.begin(), rhs_out_edges.end(),
1180 [&](
const auto& lhs_edge,
const auto& rhs_edge) {
1182 if (not (bool)lhs_edge || not (bool)rhs_edge) { return false; }
1186 const auto& lhs_val = lhs.
getNodeRef(lhs_edge.val);
1187 const auto& rhs_val_lookup = rhs.
vals_by_name.find(lhs_val.name);
1189 const auto& rhs_val = rhs.
getNodeRef(rhs_edge.val);
1190 return lhs_val == rhs_val && (
1191 lhs_val.output_operand.at(lhs_edge.output_index)
1192 == rhs_val.output_operand.at(rhs_edge.output_index)
1196 if (mismatch_result != std::make_pair(lhs_out_edges.end(), rhs_out_edges.end())) {
return false; }
1200 if (lhs.vals_by_name.size() != rhs.vals_by_name.size()) {
return false; }
1201 for (
const auto& lhs_name_and_valdesc : lhs.vals_by_name) {
1202 const auto rhs_search_result = rhs.vals_by_name.find(lhs_name_and_valdesc.first);
1203 if (rhs_search_result == rhs.vals_by_name.end()) {
return false; }
1204 const auto& lhs_node = lhs.getNodeRef(lhs_name_and_valdesc.second);
1205 const auto& rhs_node = rhs.getNodeRef(rhs_search_result->second);
1206 if (not (lhs_node == rhs_node)) {
return false; }
1214 const std::set<OpGraph::OpDescriptor>& allowed_ops
1219 for (
const auto& src_op : src.
opNodes()) {
1220 if (allowed_ops.end() == allowed_ops.find(src_op)) {
continue; }
1227 for (
const auto& src_op : src.
opNodes()) {
1228 if (allowed_ops.end() == allowed_ops.find(src_op)) {
continue; }
1231 int output_idex = 0;
1232 for (
const auto& src_oop : src.
outputOps(src_op)) {
1233 if (allowed_ops.end() == allowed_ops.find(src_oop)) {
continue; }
1236 src_val.output_operand.at(output_idex), src_val.bitwidth, src_val.dist, src_val.kind);
// Parameters: the ops the search starts from, and the op-count threshold.
1248 const std::vector<OpGraphOp*>& starting_points,
1249 const std::ptrdiff_t n_ops
// Visitor state: the threshold, the ops collected so far, and a flag that
// stops further collection.
1255 std::ptrdiff_t n_ops = -1;
1256 std::set<ODesc> ops = {};
1257 bool have_enough =
false;
// Record every examined op until the flag is set.
1259 void onExamine(
const ODesc& op) {
if (!have_enough) { ops.insert(op); } }
// The flag is only re-evaluated here, so ops keep being collected between
// calls; collection stops once strictly more than n_ops ops are held.
1260 void onWaveEnd() { have_enough = (std::ptrdiff_t)ops.size() > n_ops; }
// Run the waved breadth-first visit from the (de-duplicated) starting points.
1265 std::set<OpGraph::OpDescriptor> starting_points_as_set(starting_points.begin(), starting_points.end());
1266 g_algos.wavedBreadthFirstVisit(opgraph, starting_points_as_set,
visitAllVertices(), v);
// v.ops is a member, not a local, so the explicit move is what avoids a copy.
1268 return std::move(v.ops);
1279 if (ed) { os << ed.
val->
name; }
else { os <<
"nullptr"; }
// Print the message type as a fixed word.
1285 case Type::Info: os <<
"Info";
break;
1286 case Type::Warning: os <<
"Warning";
break;
1287 case Type::Error: os <<
"Error";
break;
// Unknown enumerator: print a marker followed by the numeric value
// (unary + makes the value print as a number).
1288 default: os <<
"OpGraphVMTNotImplementedByPrinter" << +vm_type;
break;