* Copyright 2008-2009 Katholieke Universiteit Leuven
* Copyright 2010 INRIA Saclay
* Copyright 2011 Sven Verdoolaege
+ * Copyright 2023 Cerebras Systems
*
* Use of this software is governed by the MIT license
*
* Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
* and INRIA Saclay - Ile-de-France, Parc Club Orsay Universite,
* ZAC des vignes, 4 rue Jacques Monod, 91893 Orsay, France
+ * and Cerebras Systems, 1237 E Arques Ave, Sunnyvale, CA, USA
*/
+/* Paste "TYPE" and "SUFFIX" together into a single identifier token.
+ * Operands of ## are not macro-expanded before pasting, so this is
+ * presumably only meant to be used through a wrapper macro
+ * (such as the SF macro used in this file) -- confirm.
+ */
#define xSF(TYPE,SUFFIX) TYPE ## SUFFIX
@@ -137,10 +139,55 @@ error:
return NULL;
}
+/* Given that the output dimension of "bmap" at position "d" is equal to "aff",
+ * exploit this information to reduce the effective dimensionality of "bmap" and
+ * then call basic_map_partial_lexopt recursively.
+ *
+ * In particular, introduce a dimension in the context "dom" (and the domain
+ * of "bmap") that is equal to "aff" and equate output dimension "d"
+ * to this new input dimension.
+ * This essentially moves the output dimension to the input, but
+ * leaves a placeholder so that the value "aff" can easily be plugged
+ * into the result of the recursive call.
+ *
+ * "bmap", "dom" and "aff" are all consumed.
+ * If "empty" is not NULL, then the set stored in *empty by the recursive
+ * call is expressed in terms of the extended domain and is therefore
+ * mapped back to the original domain by taking its preimage
+ * under the same domain extension.
+ */
+static __isl_give TYPE *SF(basic_map_partial_lexopt_plugin,SUFFIX)(
+ __isl_take isl_basic_map *bmap, __isl_take isl_basic_set *dom,
+ __isl_give isl_set **empty, int max, int d, __isl_take isl_aff *aff)
+{
+ isl_size n_in;
+ isl_multi_aff *ma;
+ isl_basic_map *insert;
+ TYPE *res;
+
+ /* The newly introduced input dimension ends up at position "n_in".
+  * On error, "bmap" is freed and the resulting NULL is passed
+  * through the calls below, which are expected to propagate
+  * the error -- confirm.
+  */
+ n_in = isl_aff_dim(aff, isl_dim_in);
+ if (n_in < 0)
+ bmap = isl_basic_map_free(bmap);
+
+ /* "ma" extends the domain of "aff" with a dimension equal to "aff";
+  * "insert" is the same transformation in the form of a basic map.
+  * A copy of "ma" is kept for mapping the results back afterwards.
+  */
+ ma = isl_aff_as_domain_extension(aff);
+ insert = isl_basic_map_from_multi_aff2(isl_multi_aff_copy(ma), 0);
+
+ bmap = isl_basic_map_apply_domain(bmap, isl_basic_map_copy(insert));
+ dom = isl_basic_set_apply(dom, insert);
+ bmap = isl_basic_map_equate(bmap, isl_dim_in, n_in, isl_dim_out, d);
+
+ res = SF(basic_map_partial_lexopt,SUFFIX)(bmap, dom, empty, max);
+ /* Plug "aff" back in for the placeholder dimension,
+  * both in *empty (if requested) and in the result.
+  */
+ if (empty)
+ *empty = isl_set_preimage_multi_aff(*empty,
+ isl_multi_aff_copy(ma));
+ res = FN(TYPE,pullback_multi_aff)(res, ma);
+
+ return res;
+}
+
/* Recursive part of isl_tab_basic_map_partial_lexopt*, after detecting
* equalities and removing redundant constraints.
*
- * We first check if there are any parallel constraints (left).
+ * First check if some combination of constraints can be found that forces
+ * an output dimension to be equal to the floor or modulo
+ * of some affine combination of the input dimensions.
+ * If so, plug in this expression and continue.
+ *
+ * Otherwise, check if there are any parallel constraints (left).
* If not, we are in the base case.
* If there are parallel constraints, we replace them by a single
* constraint in basic_map_partial_lexopt_symm_pma and then call
@@ -150,9 +197,18 @@ static __isl_give TYPE *SF(basic_map_partial_lexopt,SUFFIX)(
__isl_take isl_basic_map *bmap, __isl_take isl_basic_set *dom,
__isl_give isl_set **empty, int max)
{
+ int d;
isl_bool par = isl_bool_false;
int first, second;
isl_ctx *ctx;
+ isl_maybe_isl_aff div_mod;
+
+ div_mod = isl_basic_map_try_find_any_output_div_mod(bmap, &d);
+ if (div_mod.valid < 0)
+ goto error;
+ if (div_mod.valid)
+ return SF(basic_map_partial_lexopt_plugin,SUFFIX)(bmap, dom,
+ empty, max, d, div_mod.value);
if (!bmap)
goto error;
@@ -457,7 +457,7 @@ static void test_lexmin(isl::ctx ctx)
{ "{ [a=0:11] -> [b=0:3] : -1 + b <= 2*floor((a)/6) <= b }", true },
{ "{ [a = 0:2, b = 0:1] -> [c = 0:9, d = (-a + b) mod 3] : "
- "10a + 5b - 3c <= 5d <= 12 + 10a + 5b - 3c }", false },
+ "10a + 5b - 3c <= 5d <= 12 + 10a + 5b - 3c }", true },
});
C(&isl::map::lexmin_pw_multi_aff, {
@@ -466,19 +466,18 @@ static void test_lexmin(isl::ctx ctx)
* The lexicographic minimum of both should consist of a single cell.
*/
{ "{ [a=0:3] -> [b=a//2] : 0 <= b <= 1 }",
- "{ [a=0:3] -> [(a - floor((1 + a)/2))] }" },
+ "{ [a=0:3] -> [(floor((a)/2))] }" },
{ "{ [a] -> [b=a//2] : 0 <= b <= 1 }",
- "{ [a=0:1] -> [(0)]; [a=2:3] -> [(1)] }" },
+ "{ [a=0:3] -> [(floor((a)/2))] }" },
{ "{ [a = 0:2, b = 0:1] -> [c = 0:9, d = (-a + b) mod 3] : "
"10a + 5b - 3c <= 5d <= 12 + 10a + 5b - 3c }",
- "{ [a = 0:2, b = 0:1] -> [(0), (2a + b)] : b <= 2 - 2a; "
- "[a = 0:2, b = 0:1] -> [(5), (-3 + 2a + b)] : 3 - 2a <= b }" },
+ "{ [a = 0:2, b = 0:1] -> [5*(2a + b)//3, (2a + b) mod 3] }" },
});
C(&isl::set::lexmin_pw_multi_aff, {
{ "[a] -> { [b=a//2] : 0 <= b <= 1 }",
- "[a] -> { [(0)] : 0 <= a <= 1; [(1)] : 2 <= a <= 3 }" },
+ "[a=0:3] -> { [(floor((a)/2))] }" },
});
}