15 #include <catch2/catch.hpp>
// Test-case description that pins only the expected *counts* of the result:
// how many noncyclic trail pairs and how many cycle paths should be found.
// Aggregate field order matters: test tables initialise it positionally as
// { noncyclic pair count, cycle count, description, generator }.
19 struct ComputeTrailsToBalanceTestDescription {
// Expected size of the result's noncyclic_trail_pairs container.
20 int expected_number_of_noncyclic_trail_pairs;
// Expected size of the result's cyclic_trails container.
21 int expected_number_of_cycle_paths;
// Human-readable DFG name; spliced into the GIVEN(...) title by the test runner.
22 std::string dfg_description;
// Factory invoked once per test run to build the OpGraph under test.
23 std::function<
OpGraph()> opgraph_generator;
// Accessors give this type the same interface as
// ComputeTrailsToBalanceTestDescriptionWithExpectedResult, so the templated
// test runner can accept either description type.
24 auto getExpectedNumberOfNoncyclicTrailPairs()
const {
return expected_number_of_noncyclic_trail_pairs; }
25 auto getExpectedNumberOfCyclePaths()
const {
return expected_number_of_cycle_paths; }
// Test-case description that lists the acceptable results explicitly, with
// graph values referred to by name so they can be written down before the
// OpGraph exists. The names are resolved against the generated OpGraph in
// runComputeTrailsToBalanceTestWithExpectedResult.
28 struct ComputeTrailsToBalanceTestDescriptionWithExpectedResult {
// (val name, index) pair. The name is what gets passed to OpGraph::getVal;
// the int is presumably the output/operand index of the edge -- TODO confirm.
29 using EdgeSpec = std::pair<std::string, int>;
// A trail is an ordered sequence of edges.
30 using Trail = std::vector<EdgeSpec>;
// Two trails that are expected to be reported together as one noncyclic pair.
31 using TrailPair = std::pair<Trail,Trail>;
// One complete acceptable result, still expressed with val names.
32 struct UnresolvedBalanceTrails {
33 std::vector<TrailPair> noncyclic_trail_pairs;
34 std::vector<Trail> cyclic_trails;
// Lexicographic ordering over (noncyclic_trail_pairs, cyclic_trails);
// required so instances can be stored in the std::set below.
35 bool operator<(
const UnresolvedBalanceTrails& rhs)
const {
36 return std::tie(this->noncyclic_trail_pairs, this->cyclic_trails)
37 < std::tie( rhs.noncyclic_trail_pairs, rhs.cyclic_trails);
// Human-readable DFG name; spliced into the GIVEN(...) title by the test runner.
41 std::string dfg_description;
// Factory invoked once per test run to build the OpGraph under test.
42 std::function<
OpGraph()> opgraph_generator;
// Set of alternative results; the test passes when the computed result
// equals any one of them (checked with VectorContains after resolution).
62 std::set<UnresolvedBalanceTrails> expected_balance_trails;
// Both getters read the counts from the FIRST element of the set only.
// NOTE(review): dereferencing begin() is undefined behaviour if
// expected_balance_trails is empty, and this presumably relies on every
// alternative having the same counts -- confirm against the test tables.
// The return type is the containers' unsigned size type, not int.
63 auto getExpectedNumberOfNoncyclicTrailPairs()
const {
return expected_balance_trails.begin()->noncyclic_trail_pairs.size(); }
64 auto getExpectedNumberOfCyclePaths()
const {
return expected_balance_trails.begin()->cyclic_trails.size(); }
// Shared driver for the trails-to-balance scenarios.
//
// TestDesc must provide dfg_description, opgraph_generator and the two
// getExpectedNumberOf...() accessors. Continuation is called as
// continuation(test, opgraph, trails_to_balance) after the count checks, so
// callers can add further assertions; it defaults to DoNothing (declared
// elsewhere).
67 template<
typename TestDesc,
typename Continuation = DoNothing>
68 void runComputeTrailsToBalanceTest(
const TestDesc& test, Continuation&& continuation = {}) {
69 GIVEN(
"A " << test.dfg_description <<
" DFG") {
// Build a fresh graph for this test case.
70 const auto opgraph = test.opgraph_generator();
71 WHEN(
"Computing paths to balance") {
// NOTE(review): the statement that defines trails_to_balance sits on lines
// not visible in this excerpt.
// Dump the graph and the computed trails to stdout to aid debugging.
76 opgraph.serialize(std::cout);
77 std::cout <<
'\n' << trails_to_balance <<
'\n';
79 THEN(
"There may be some paths to balance") {
// Check each count individually, then their sum. The actual sizes are cast to
// std::ptrdiff_t before being compared with the expected values.
84 CHECK(test.getExpectedNumberOfNoncyclicTrailPairs() == (std::ptrdiff_t)trails_to_balance.noncyclic_trail_pairs.size());
85 CHECK(test.getExpectedNumberOfCyclePaths() == (std::ptrdiff_t)trails_to_balance.cyclic_trails.size());
86 CHECK(test.getExpectedNumberOfNoncyclicTrailPairs() + test.getExpectedNumberOfCyclePaths()
87 == (std::ptrdiff_t)trails_to_balance.noncyclic_trail_pairs.size() + (std::ptrdiff_t)trails_to_balance.cyclic_trails.size());
// Hand over to the caller-supplied follow-up checks.
89 continuation(test, opgraph, trails_to_balance);
// Runs the shared count checks via runComputeTrailsToBalanceTest and then, in
// a continuation, verifies that the computed result equals one of the
// alternatives listed in test.expected_balance_trails.
95 template<
typename TestDesc>
96 void runComputeTrailsToBalanceTestWithExpectedResult(
const TestDesc& test) {
97 runComputeTrailsToBalanceTest(test, [](
const auto& test,
const auto& opgraph,
const auto& trails_to_balance) {
// Helper: walks a list of (val name, index) pairs and looks each name up in
// the OpGraph. NOTE(review): the code that stores the resolved entry into
// `dest` is on lines not visible in this excerpt.
98 const auto resolve_names_into = [&](
const auto& val_name_and_out_indices,
auto& dest) {
99 for (
const auto& val_name_and_oindex : val_name_and_out_indices) {
// Look the val up by the name half of the pair.
101 const auto val = opgraph.getVal(val_name_and_oindex.first);
103 }
catch (
const std::exception& e) {
// Re-throw as a nested std::logic_error so the failing val name appears in
// the diagnostic while the original exception is preserved.
104 std::throw_with_nested(make_from_stream<std::logic_error>([&](
auto&& s) {
105 s <<
"problem while resolving val name " << val_name_and_oindex.first
106 <<
" does it exist in the OpGraph?";
112 AND_THEN(
"The paths to balance should match") {
// One resolved TrailsToBalance is built per expected alternative.
113 std::vector<TrailsToBalance> expected_trails_to_balance;
115 for (
const auto & expected : test.expected_balance_trails) {
116 expected_trails_to_balance.emplace_back();
// Resolve both halves of every noncyclic pair and insert the pair.
117 for (
const auto& val_name_and_oindex_trail_pair : expected.noncyclic_trail_pairs) {
119 resolve_names_into(val_name_and_oindex_trail_pair.first, trail_pair.first);
120 resolve_names_into(val_name_and_oindex_trail_pair.second, trail_pair.second);
121 expected_trails_to_balance.back().noncyclic_trail_pairs.insert(std::move(trail_pair));
// Resolve every cyclic trail and insert it.
123 for (
const auto& val_name_and_oindexs : expected.cyclic_trails) {
125 resolve_names_into(val_name_and_oindexs,cycle);
126 expected_trails_to_balance.back().cyclic_trails.insert(std::move(cycle));
// Passes when the computed result equals any one of the resolved alternatives.
130 using Catch::Matchers::VectorContains;
131 CHECK_THAT(expected_trails_to_balance, VectorContains(trails_to_balance));
// Small DFGs for which exactly one set of trails is listed as acceptable.
// Each table entry is a ComputeTrailsToBalanceTestDescriptionWithExpectedResult;
// Catch2's GENERATE re-runs the scenario once per entry.
// NOTE(review): the description/generator rows of most entries sit on lines
// not visible in this excerpt; only the expected-trail rows appear below.
138 SCENARIO (
"Computing Trails To Balance -- Single Solution, Small",
"") {
139 const auto& test = GENERATE(values<ComputeTrailsToBalanceTestDescriptionWithExpectedResult>({
// Judging by brace depth, the rows from original lines 145-157 list cyclic
// trails and those from 162-168 list noncyclic trail pairs -- TODO confirm
// against the full file.
145 { { {
"op_l_out",0} } } } } },
148 { { {
"op_l_out",1} } } } } },
151 { { {
"op_a_out",0}, {
"op_b_out",0} } } } } },
154 { { {
"op_a_out",0}, {
"op_b_out",0}, {
"op_c_out",0} } } } } },
157 { { {
"op_l_out",0}, {
"op_m_out",1} } } } } },
162 { { { { { {
"op_a_out",0}, {
"op_b_out", 0} }, { {
"op_a_out",1}, {
"op_b2_out",0} } } },
165 { { { { { {
"op_b_out",0}, {
"op_c1_out",0} }, { {
"op_b_out",1}, {
"op_c2_out",0} } } },
168 { { { { { {
"op_a_out",0}, }, { {
"op_a_out",1} } } },
// Count checks plus exact match against the expected trails.
172 runComputeTrailsToBalanceTestWithExpectedResult(test);
// A larger DFG (built by makeDFG_Large(false)) for which exactly one set of
// trails is listed as acceptable.
175 SCENARIO (
"Computing Trails To Balance -- Single Solution, Large",
"") {
176 const auto& test = GENERATE(values<ComputeTrailsToBalanceTestDescriptionWithExpectedResult>({
// Entry layout: { description, generator, { expected alternatives } }.
177 {
"large requiring only NN FUs", []{
return makeDFG_Large(
false); },
// Expected noncyclic trail pair(s), as (val name, index) sequences.
178 { { { { { {
"add0_out",0}, {
"mul2_out",0}, {
"load4_out",0}, {
"add10_out",0} }, { {
"add0_out",2}, {
"mul11_out",0}, {
"load13_out",0} } } },
// Expected cyclic trails.
179 { { {
"add0_out", 3} },
180 { {
"add16_out",0} } } } } },
// Count checks plus exact match against the expected trails.
183 runComputeTrailsToBalanceTestWithExpectedResult(test);
// Small DFGs where more than one result is acceptable: each entry lists
// several alternatives and the test passes if the computed result equals any
// one of them.
// NOTE(review): the description/generator rows and some separators between
// alternatives sit on lines not visible in this excerpt.
186 SCENARIO (
"Computing Trails To Balance -- Non-Unique Solution, Small",
"") {
187 const auto& test = GENERATE(values<ComputeTrailsToBalanceTestDescriptionWithExpectedResult>({
// The rows from original lines 191-192 and 194-195 appear to be two
// alternatives that share the same cyclic trail but split the noncyclic
// pair differently -- TODO confirm against the full file.
191 { { { {
"op_i_out",0} }, { {
"op_i_out",1}, {
"op_b_out",0} } } },
192 { { {
"op_a_out",0}, {
"op_b_out",0} } }
194 { { { {
"op_i_out",0}, {
"op_a_out",0} }, { {
"op_i_out",1} } } },
195 { { {
"op_a_out",0}, {
"op_b_out",0} } }
199 { { { { { {
"op_i1_out",0}, {
"op_a1_out",0} }, { {
"op_i1_out",1}, {
"op_a2_out",1} } },
200 { { {
"op_i1_out",0}, {
"op_a1_out",0}, {
"op_b1_out",0} }, { {
"op_i1_out",1}, {
"op_a2_out",0}, {
"op_b2_out",0} } } },
// Count checks plus match against any one of the listed alternatives.
204 runComputeTrailsToBalanceTestWithExpectedResult(test);
// Larger DFGs with several acceptable results. Only the expected counts are
// checked here: entries are ComputeTrailsToBalanceTestDescription values and
// the count-only runner is used.
207 SCENARIO (
"Computing Trails To Balance -- Non-Unique Solutions, Large",
"") {
// Entry layout (positional aggregate init):
// { expected noncyclic trail pairs, expected cycle paths, description, generator }.
// The doubled ";;" after the makeDFG_DeviceFiller calls is a harmless stray
// empty statement.
209 const auto& test = GENERATE(values<ComputeTrailsToBalanceTestDescription>({
210 { 3, 2,
"large requiring non NN FUs", []{
return makeDFG_Large(
true); } },
211 { 5, 0,
"DeviceFiller 3c2x2 -- no loops", []{
return makeDFG_DeviceFiller({{{
"width",
"2"}, {
"height",
"2"}, {
"num_contexts",
"3"}}});; } },
212 { 5, 4,
"DeviceFiller 3c2x2 -- loops", []{
return makeDFG_DeviceFiller({{{
"width",
"2"}, {
"height",
"2"}, {
"num_contexts",
"3"}, {
"self_loop_via_regfile",
"yes"}}});; } },
213 { 5, 4,
"DeviceFiller 3c2x2 -- loops & consts", []{
return makeDFG_DeviceFiller({{{
"width",
"2"}, {
"height",
"2"}, {
"num_contexts",
"3"}, {
"self_loop_via_regfile",
"yes"}, {
"num_consts_in_pe",
"2"}}});; } },
// Count checks only; no continuation is supplied.
218 runComputeTrailsToBalanceTest(test);