add isl_schedule_node_band_tile
authorSven Verdoolaege <skimo@kotnet.org>
Thu, 1 Aug 2013 08:47:03 +0000 (1 10:47 +0200)
committerSven Verdoolaege <skimo@kotnet.org>
Wed, 11 Feb 2015 15:24:57 +0000 (11 16:24 +0100)
Signed-off-by: Sven Verdoolaege <skimo@kotnet.org>
doc/user.pod
include/isl/schedule_node.h
isl_schedule_band.c
isl_schedule_band.h
isl_schedule_node.c
isl_schedule_tree.c
isl_schedule_tree.h
isl_test.c

index e175fb5..58702c1 100644 (file)
@@ -7515,6 +7515,29 @@ two filter nodes are merged into one.
 These functions insert a new sequence or set node with the given
 filters as children.
 
+A band node can be tiled using the following function.
+
+       #include <isl/schedule_node.h>
+       __isl_give isl_schedule_node *isl_schedule_node_band_tile(
+               __isl_take isl_schedule_node *node,
+               __isl_take isl_multi_val *sizes);
+
+       int isl_options_set_tile_scale_tile_loops(isl_ctx *ctx,
+               int val);
+       int isl_options_get_tile_scale_tile_loops(isl_ctx *ctx);
+       int isl_options_set_tile_shift_point_loops(isl_ctx *ctx,
+               int val);
+       int isl_options_get_tile_shift_point_loops(isl_ctx *ctx);
+
+The C<isl_schedule_node_band_tile> function tiles
+the band using the given tile sizes inside its schedule.
+A new child band node is created to represent the point loops and it is
+inserted between the modified band and its children.
+The C<tile_scale_tile_loops> option specifies whether the tile
+loops iterators should be scaled by the tile sizes.
+If the C<tile_shift_point_loops> option is set, then the point loops
+are shifted to start at zero.
+
 A representation of the schedule node can be printed using
 
        #include <isl/schedule_node.h>
index 4ae0c67..1d2792e 100644 (file)
@@ -61,6 +61,14 @@ int isl_schedule_node_band_get_permutable(__isl_keep isl_schedule_node *node);
 __isl_give isl_schedule_node *isl_schedule_node_band_set_permutable(
        __isl_take isl_schedule_node *node, int permutable);
 
+int isl_options_set_tile_scale_tile_loops(isl_ctx *ctx, int val);
+int isl_options_get_tile_scale_tile_loops(isl_ctx *ctx);
+int isl_options_set_tile_shift_point_loops(isl_ctx *ctx, int val);
+int isl_options_get_tile_shift_point_loops(isl_ctx *ctx);
+
+__isl_give isl_schedule_node *isl_schedule_node_band_tile(
+       __isl_take isl_schedule_node *node, __isl_take isl_multi_val *sizes);
+
 __isl_give isl_union_set *isl_schedule_node_domain_get_domain(
        __isl_keep isl_schedule_node *node);
 __isl_give isl_union_set *isl_schedule_node_filter_get_filter(
index 12d8a01..0dca714 100644 (file)
@@ -7,6 +7,7 @@
  * Ecole Normale Superieure, 45 rue d'Ulm, 75230 Paris, France
  */
 
+#include <isl/schedule_node.h>
 #include <isl_schedule_band.h>
 #include <isl_schedule_private.h>
 
@@ -232,3 +233,110 @@ __isl_give isl_multi_union_pw_aff *isl_schedule_band_get_partial_schedule(
 {
        return band ? isl_multi_union_pw_aff_copy(band->mupa) : NULL;
 }
+
+/* Given the schedule of a band, construct the corresponding
+ * schedule for the tile loops based on the given tile sizes
+ * and return the result.
+ *
+ * If the scale tile loops options is set, then the tile loops
+ * are scaled by the tile sizes.
+ *
+ * That is replace each schedule dimension "i" by either
+ * "floor(i/s)" or "s * floor(i/s)".
+ */
+static isl_multi_union_pw_aff *isl_multi_union_pw_aff_tile(
+       __isl_take isl_multi_union_pw_aff *sched,
+       __isl_take isl_multi_val *sizes)
+{
+       isl_ctx *ctx;
+       int i, n;
+       isl_val *v;
+       int scale;
+
+       ctx = isl_multi_val_get_ctx(sizes);
+       scale = isl_options_get_tile_scale_tile_loops(ctx);
+
+       n = isl_multi_union_pw_aff_dim(sched, isl_dim_set);
+       for (i = 0; i < n; ++i) {
+               isl_union_pw_aff *upa;
+
+               upa = isl_multi_union_pw_aff_get_union_pw_aff(sched, i);
+               v = isl_multi_val_get_val(sizes, i);
+
+               upa = isl_union_pw_aff_scale_down_val(upa, isl_val_copy(v));
+               upa = isl_union_pw_aff_floor(upa);
+               if (scale)
+                       upa = isl_union_pw_aff_scale_val(upa, isl_val_copy(v));
+               isl_val_free(v);
+
+               sched = isl_multi_union_pw_aff_set_union_pw_aff(sched, i, upa);
+       }
+
+       isl_multi_val_free(sizes);
+       return sched;
+}
+
+/* Replace "band" by a band corresponding to the tile loops of a tiling
+ * with the given tile sizes.
+ */
+__isl_give isl_schedule_band *isl_schedule_band_tile(
+       __isl_take isl_schedule_band *band, __isl_take isl_multi_val *sizes)
+{
+       band = isl_schedule_band_cow(band);
+       if (!band || !sizes)
+               goto error;
+       band->mupa = isl_multi_union_pw_aff_tile(band->mupa, sizes);
+       if (!band->mupa)
+               return isl_schedule_band_free(band);
+       return band;
+error:
+       isl_schedule_band_free(band);
+       isl_multi_val_free(sizes);
+       return NULL;
+}
+
+/* Replace "band" by a band corresponding to the point loops of a tiling
+ * with the given tile sizes.
+ * "tile" is the corresponding tile loop band.
+ *
+ * If the shift point loops option is set, then the point loops
+ * are shifted to start at zero.  That is, each schedule dimension "i"
+ * is replaced by "i - s * floor(i/s)".
+ * The expression "floor(i/s)" (or "s * floor(i/s)") is extracted from
+ * the tile band.
+ *
+ * Otherwise, the band is left untouched.
+ */
+__isl_give isl_schedule_band *isl_schedule_band_point(
+       __isl_take isl_schedule_band *band, __isl_keep isl_schedule_band *tile,
+       __isl_take isl_multi_val *sizes)
+{
+       isl_ctx *ctx;
+       isl_multi_union_pw_aff *scaled;
+
+       if (!band || !sizes)
+               goto error;
+
+       ctx = isl_schedule_band_get_ctx(band);
+       if (!isl_options_get_tile_shift_point_loops(ctx)) {
+               isl_multi_val_free(sizes);
+               return band;
+       }
+       band = isl_schedule_band_cow(band);
+       if (!band)
+               goto error;
+
+       scaled = isl_schedule_band_get_partial_schedule(tile);
+       if (!isl_options_get_tile_scale_tile_loops(ctx))
+               scaled = isl_multi_union_pw_aff_scale_multi_val(scaled, sizes);
+       else
+               isl_multi_val_free(sizes);
+       band->mupa = isl_multi_union_pw_aff_sub(band->mupa, scaled);
+       if (!band->mupa)
+               return isl_schedule_band_free(band);
+       return band;
+error:
+       isl_schedule_band_free(band);
+       isl_multi_val_free(sizes);
+       return NULL;
+}
index 8078bb5..40abf62 100644 (file)
@@ -48,4 +48,10 @@ int isl_schedule_band_get_permutable(__isl_keep isl_schedule_band *band);
 __isl_give isl_schedule_band *isl_schedule_band_set_permutable(
        __isl_take isl_schedule_band *band, int permutable);
 
+__isl_give isl_schedule_band *isl_schedule_band_tile(
+       __isl_take isl_schedule_band *band, __isl_take isl_multi_val *sizes);
+__isl_give isl_schedule_band *isl_schedule_band_point(
+       __isl_take isl_schedule_band *band, __isl_keep isl_schedule_band *tile,
+       __isl_take isl_multi_val *sizes);
+
 #endif
index 7fd5c56..949cb04 100644 (file)
@@ -919,6 +919,70 @@ __isl_give isl_union_map *isl_schedule_node_band_get_partial_schedule_union_map(
        return isl_union_map_from_multi_union_pw_aff(mupa);
 }
 
+/* Make sure that that spaces of "node" and "mv" are the same.
+ * Return -1 on error, reporting the error to the user.
+ */
+static int check_space_multi_val(__isl_keep isl_schedule_node *node,
+       __isl_keep isl_multi_val *mv)
+{
+       isl_space *node_space, *mv_space;
+       int equal;
+
+       node_space = isl_schedule_node_band_get_space(node);
+       mv_space = isl_multi_val_get_space(mv);
+       equal = isl_space_tuple_is_equal(node_space, isl_dim_set,
+                                       mv_space, isl_dim_set);
+       isl_space_free(mv_space);
+       isl_space_free(node_space);
+       if (equal < 0)
+               return -1;
+       if (!equal)
+               isl_die(isl_schedule_node_get_ctx(node), isl_error_invalid,
+                       "spaces don't match", return -1);
+
+       return 0;
+}
+
+/* Tile "node" with tile sizes "sizes".
+ *
+ * The current node is replaced by two nested nodes corresponding
+ * to the tile dimensions and the point dimensions.
+ *
+ * Return a pointer to the outer (tile) node.
+ *
+ * If the scale tile loops option is set, then the tile loops
+ * are scaled by the tile sizes.  If the shift point loops option is set,
+ * then the point loops are shifted to start at zero.
+ * In particular, these options affect the tile and point loop schedules
+ * as follows
+ *
+ *     scale   shift   original        tile            point
+ *
+ *     0       0       i               floor(i/s)      i
+ *     1       0       i               s * floor(i/s)  i
+ *     0       1       i               floor(i/s)      i - s * floor(i/s)
+ *     1       1       i               s * floor(i/s)  i - s * floor(i/s)
+ */
+__isl_give isl_schedule_node *isl_schedule_node_band_tile(
+       __isl_take isl_schedule_node *node, __isl_take isl_multi_val *sizes)
+{
+       isl_schedule_tree *tree;
+
+       if (!node || !sizes)
+               goto error;
+
+       if (check_space_multi_val(node, sizes) < 0)
+               goto error;
+
+       tree = isl_schedule_node_get_tree(node);
+       tree = isl_schedule_tree_band_tile(tree, sizes);
+       return isl_schedule_node_graft_tree(node, tree);
+error:
+       isl_multi_val_free(sizes);
+       isl_schedule_node_free(node);
+       return NULL;
+}
+
 /* Return the domain of the domain node "node".
  */
 __isl_give isl_union_set *isl_schedule_node_domain_get_domain(
index 393f188..034f884 100644 (file)
@@ -1041,6 +1041,47 @@ __isl_give isl_union_map *isl_schedule_tree_get_subtree_schedule_union_map(
        return subtree_schedule_extend(tree, umap);
 }
 
+/* Tile the band root node of "tree" with tile sizes "sizes".
+ *
+ * We duplicate the band node, change the schedule of one of them
+ * to the tile schedule and the other to the point schedule and then
+ * attach the point band as a child to the tile band.
+ */
+__isl_give isl_schedule_tree *isl_schedule_tree_band_tile(
+       __isl_take isl_schedule_tree *tree, __isl_take isl_multi_val *sizes)
+{
+       isl_schedule_tree *child = NULL;
+
+       if (!tree || !sizes)
+               goto error;
+       if (tree->type != isl_schedule_node_band)
+               isl_die(isl_schedule_tree_get_ctx(tree), isl_error_invalid,
+                       "not a band node", goto error);
+
+       child = isl_schedule_tree_copy(tree);
+       tree = isl_schedule_tree_cow(tree);
+       child = isl_schedule_tree_cow(child);
+       if (!tree || !child)
+               goto error;
+
+       tree->band = isl_schedule_band_tile(tree->band,
+                                           isl_multi_val_copy(sizes));
+       if (!tree->band)
+               goto error;
+       child->band = isl_schedule_band_point(child->band, tree->band, sizes);
+       if (!child->band)
+               child = isl_schedule_tree_free(child);
+
+       tree = isl_schedule_tree_replace_child(tree, 0, child);
+
+       return tree;
+error:
+       isl_schedule_tree_free(child);
+       isl_schedule_tree_free(tree);
+       isl_multi_val_free(sizes);
+       return NULL;
+}
+
 /* Are any members in "band" marked coincident?
  */
 static int any_coincident(__isl_keep isl_schedule_band *band)
index 72a72cd..107f0c5 100644 (file)
@@ -104,6 +104,9 @@ __isl_give isl_schedule_tree *isl_schedule_tree_insert_domain(
 __isl_give isl_schedule_tree *isl_schedule_tree_insert_filter(
        __isl_take isl_schedule_tree *tree, __isl_take isl_union_set *filter);
 
+__isl_give isl_schedule_tree *isl_schedule_tree_band_tile(
+       __isl_take isl_schedule_tree *tree, __isl_take isl_multi_val *sizes);
+
 __isl_give isl_schedule_tree *isl_schedule_tree_child(
        __isl_take isl_schedule_tree *tree, int pos);
 __isl_give isl_schedule_tree *isl_schedule_tree_reset_children(
index 6f39400..1a4b7b0 100644 (file)
@@ -5241,6 +5241,112 @@ static int test_dual(isl_ctx *ctx)
 }
 
 struct {
+       int scale_tile;
+       int shift_point;
+       const char *domain;
+       const char *schedule;
+       const char *sizes;
+       const char *tile;
+       const char *point;
+} tile_tests[] = {
+       { 0, 0, "[n] -> { S[i,j] : 0 <= i,j < n }",
+         "[{ S[i,j] -> [i] }, { S[i,j] -> [j] }]",
+         "{ [32,32] }",
+         "[{ S[i,j] -> [floor(i/32)] }, { S[i,j] -> [floor(j/32)] }]",
+         "[{ S[i,j] -> [i] }, { S[i,j] -> [j] }]",
+       },
+       { 1, 0, "[n] -> { S[i,j] : 0 <= i,j < n }",
+         "[{ S[i,j] -> [i] }, { S[i,j] -> [j] }]",
+         "{ [32,32] }",
+         "[{ S[i,j] -> [32*floor(i/32)] }, { S[i,j] -> [32*floor(j/32)] }]",
+         "[{ S[i,j] -> [i] }, { S[i,j] -> [j] }]",
+       },
+       { 0, 1, "[n] -> { S[i,j] : 0 <= i,j < n }",
+         "[{ S[i,j] -> [i] }, { S[i,j] -> [j] }]",
+         "{ [32,32] }",
+         "[{ S[i,j] -> [floor(i/32)] }, { S[i,j] -> [floor(j/32)] }]",
+         "[{ S[i,j] -> [i%32] }, { S[i,j] -> [j%32] }]",
+       },
+       { 1, 1, "[n] -> { S[i,j] : 0 <= i,j < n }",
+         "[{ S[i,j] -> [i] }, { S[i,j] -> [j] }]",
+         "{ [32,32] }",
+         "[{ S[i,j] -> [32*floor(i/32)] }, { S[i,j] -> [32*floor(j/32)] }]",
+         "[{ S[i,j] -> [i%32] }, { S[i,j] -> [j%32] }]",
+       },
+};
+
+/* Basic tiling tests.  Create a schedule tree with a domain and a band node,
+ * tile the band and then check if the tile and point bands have the
+ * expected partial schedule.
+ */
+static int test_tile(isl_ctx *ctx)
+{
+       int i;
+       int scale;
+       int shift;
+
+       scale = isl_options_get_tile_scale_tile_loops(ctx);
+       shift = isl_options_get_tile_shift_point_loops(ctx);
+
+       for (i = 0; i < ARRAY_SIZE(tile_tests); ++i) {
+               int opt;
+               int equal;
+               const char *str;
+               isl_union_set *domain;
+               isl_multi_union_pw_aff *mupa, *mupa2;
+               isl_schedule_node *node;
+               isl_multi_val *sizes;
+
+               opt = tile_tests[i].scale_tile;
+               isl_options_set_tile_scale_tile_loops(ctx, opt);
+               opt = tile_tests[i].shift_point;
+               isl_options_set_tile_shift_point_loops(ctx, opt);
+
+               str = tile_tests[i].domain;
+               domain = isl_union_set_read_from_str(ctx, str);
+               node = isl_schedule_node_from_domain(domain);
+               node = isl_schedule_node_child(node, 0);
+               str = tile_tests[i].schedule;
+               mupa = isl_multi_union_pw_aff_read_from_str(ctx, str);
+               node = isl_schedule_node_insert_partial_schedule(node, mupa);
+               str = tile_tests[i].sizes;
+               sizes = isl_multi_val_read_from_str(ctx, str);
+               node = isl_schedule_node_band_tile(node, sizes);
+
+               str = tile_tests[i].tile;
+               mupa = isl_multi_union_pw_aff_read_from_str(ctx, str);
+               mupa2 = isl_schedule_node_band_get_partial_schedule(node);
+               equal = isl_multi_union_pw_aff_plain_is_equal(mupa, mupa2);
+               isl_multi_union_pw_aff_free(mupa);
+               isl_multi_union_pw_aff_free(mupa2);
+
+               node = isl_schedule_node_child(node, 0);
+
+               str = tile_tests[i].point;
+               mupa = isl_multi_union_pw_aff_read_from_str(ctx, str);
+               mupa2 = isl_schedule_node_band_get_partial_schedule(node);
+               if (equal >= 0 && equal)
+                       equal = isl_multi_union_pw_aff_plain_is_equal(mupa,
+                                                                       mupa2);
+               isl_multi_union_pw_aff_free(mupa);
+               isl_multi_union_pw_aff_free(mupa2);
+
+               isl_schedule_node_free(node);
+
+               if (equal < 0)
+                       return -1;
+               if (!equal)
+                       isl_die(ctx, isl_error_unknown,
+                               "unexpected result", return -1);
+       }
+
+       isl_options_set_tile_scale_tile_loops(ctx, scale);
+       isl_options_set_tile_shift_point_loops(ctx, shift);
+
+       return 0;
+}
+
+struct {
        const char *name;
        int (*fn)(isl_ctx *ctx);
 } tests [] = {
@@ -5277,6 +5383,7 @@ struct {
        { "affine", &test_aff },
        { "injective", &test_injective },
        { "schedule", &test_schedule },
+       { "tile", &test_tile },
        { "union_pw", &test_union_pw },
        { "parse", &test_parse },
        { "single-valued", &test_sv },