Using MLIR's Affine Dialect With FIR

Rajan Walia, Eric Schweitz, Steve Scalpone

Affine Dialect

Affine Dialect

The Good!

  • Dialect for polyhedral optimizations
  • Provides array, loop and condition operations.
  • Comes with optimizations for loop fusion, tiling, unrolling, vectorization, etc.

Affine Maps

#map0 = affine_map<()[s0, s1] -> (s0 - s1 + 1)>
#map1 = affine_map<(d0, d1)[s0, s1, s2, s3, s4, s5] -> 
           ((d1 - s3) * ((s1 - s0 + 1) * s2) + d0 - s0)>

...
 %5 = affine.apply #map1 (%arg3, %arg4)
                    [%c0, %c20, %c1, %c10, %c20, %c1]
...

Affine loops

  affine.for %i = 1 to 10 {
     ...
  }
  affine.for %arg4 = %22 to affine_map<()[s0] -> (s0 + 1)>()[%24] {
     ...
  }

Affine Constraints

  • Array index can only be specific values
  • No partial indexing for arrays
  • Every affine loop must finish to the end

Affine memory ops

   %12 = affine.load %11[%10] : memref<?xf32>
   affine.store %10, %11[%5] : memref<?xf32>

   

Affine Dialect

The Bad!

  • Heavily depended on memref type.

    • Array values used in affine operations must be of memref type.

Dummy converts

%7 = fir.convert %3 : (!fir.ref<!fir.array<?x?xf32>>) 
                      -> memref<?xf32>
fortran
optimizations for affine
affine promotion
fir
fir + affine
affine optimizations
affine demotion
fir
llvm ir

Changes to FIR

adding new operations

Operations

  • shape, shape_shift, slice
  • array_coor

shape operations

%sh = fir.shape %row_sz, %col_sz : 
	(index, index) -> !fir.shape<2>
%ss = fir.shape_shift %lo, %extent :
	(index, index) -> !fir.shapeshift<1>
%sl = fir.slice %lo, %hi, %step :
	(index, index, index) -> !fir.slice<1>

array coordinate

%s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
%1 = fir.array_coor %a(%s) %i, %j : 
	(!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>, index, index) 
    	-> !fir.ref<f32>
      

Example

subroutine f1dc(a1,a2,ret)
  integer a1(0:4), a2(0:4), ret(0:4)
  integer t1(0:4)

  do i = 0,4
     t1(i) = a1(i) + a1(i)
  end do
  do i = 0,4
     ret(i) = t1(i) * a2(i)
  end do
end subroutine f1dc

Original FIR

func @_QPf1dc(%arg0: !fir.ref<!fir.array<5xi32>>, 
              %arg1: !fir.ref<!fir.array<5xi32>>, 
              %arg2: !fir.ref<!fir.array<5xi32>>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c4 = constant 4 : index
  %0 = fir.alloca i32 {name = "i"}
  %1 = fir.alloca !fir.array<5xi32> {name = "t1"}
  %4 = fir.do_loop %arg3 = %c0 to %c4 step %c1 
   iter_args(%arg4 = %2) -> (index) {
    %10 = fir.convert %arg3 : (index) -> i32
    fir.store %10 to %0 : !fir.ref<i32>
    %11 = fir.load %0 : !fir.ref<i32>
    ...
    %14 = subi %12, %13 : i64
    %15 = fir.coordinate_of %1, %14 : 
      		(!fir.ref<!fir.array<5xi32>>, i64) -> !fir.ref<i32>
    %16 = fir.load %0 : !fir.ref<i32>
    ...
    %19 = subi %17, %18 : i64
    %20 = fir.coordinate_of %arg0, %19 : 
    		(!fir.ref<!fir.array<5xi32>>, i64) -> !fir.ref<i32>
    %21 = fir.load %20 : !fir.ref<i32>
    %22 = fir.load %0 : !fir.ref<i32>
    %25 = subi %23, %24 : i64
    %26 = fir.coordinate_of %arg0, %25 : 
    		(!fir.ref<!fir.array<5xi32>>, i64) -> !fir.ref<i32>
    %27 = fir.load %26 : !fir.ref<i32>
    %28 = addi %21, %27 : i32
    fir.store %28 to %15 : !fir.ref<i32>
    %29 = addi %arg3, %c1 : index
    fir.result %29 : index
  }
  %5 = fir.convert %4 : (index) -> i32
  fir.store %5 to %0 : !fir.ref<i32>
  %8 = fir.do_loop %arg3 = %c0 to %c4 step %c1 iter_args(%arg4 = %6) -> (index) {
    %10 = fir.convert %arg3 : (index) -> i32
    fir.store %10 to %0 : !fir.ref<i32>
    %11 = fir.load %0 : !fir.ref<i32>
    %12 = fir.convert %11 : (i32) -> i64
    %13 = fir.convert %c0_1 : (index) -> i64
    %14 = subi %12, %13 : i64
    %15 = fir.coordinate_of %arg2, %14 : (!fir.ref<!fir.array<5xi32>>, i64) -> !fir.ref<i32>
    %16 = fir.load %0 : !fir.ref<i32>
    %17 = fir.convert %16 : (i32) -> i64
    %18 = fir.convert %c0_2 : (index) -> i64
    %19 = subi %17, %18 : i64
    %20 = fir.coordinate_of %1, %19 : (!fir.ref<!fir.array<5xi32>>, i64) -> !fir.ref<i32>
    %21 = fir.load %20 : !fir.ref<i32>
    %22 = fir.load %0 : !fir.ref<i32>
    %23 = fir.convert %22 : (i32) -> i64
    %24 = fir.convert %c0_0 : (index) -> i64
    %25 = subi %23, %24 : i64
    %26 = fir.coordinate_of %arg1, %25 : (!fir.ref<!fir.array<5xi32>>, i64) -> !fir.ref<i32>
    %27 = fir.load %26 : !fir.ref<i32>
    %28 = muli %21, %27 : i32
    fir.store %28 to %15 : !fir.ref<i32>
    %29 = addi %arg3, %c1_5 : index
    fir.result %29 : index
  }
  %9 = fir.convert %8 : (index) -> i32
  fir.store %9 to %0 : !fir.ref<i32>
  return
}

Using new operations

func @_QPf1dc(%arg0: !fir.ref<!fir.array<5xi32>>, %arg1: !fir.ref<!fir.array<5xi32>>, %arg2: !fir.ref<!fir.array<5xi32>>) {
  ...
  %2 = fir.shape_shift %c0, %c5 : (index, index) -> !fir.shapeshift<1>
  %3 = fir.shape_shift %c0, %c5 : ...
  %4 = fir.shape_shift %c0, %c5 : ...
  fir.do_loop %arg3 = %c0 to %c4 step %c1 {
    %8 = fir.array_coor %1(%2) %arg3 : (..., !fir.shapeshift<1>, index) -> !fir.ref<i32>
    %9 = fir.array_coor %arg0(%3) %arg3 : ...
    %10 = fir.load %9 : !fir.ref<i32>
    %11 = fir.array_coor %arg0(%4) %arg3 : 
    		(!fir.ref<!fir.array<5xi32>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
    %12 = fir.load %11 : !fir.ref<i32>
    %13 = addi %10, %12 : i32
    fir.store %13 to %8 : !fir.ref<i32>
  }
  %5 = fir.shape_shift %c0, %c5 : (index, index) -> !fir.shapeshift<1>
  %6 = fir.shape_shift %c0, %c5 : (index, index) -> !fir.shapeshift<1>
  %7 = fir.shape_shift %c0, %c5 : (index, index) -> !fir.shapeshift<1>
  fir.do_loop %arg3 = %c0 to %c4 step %c1 {
    %8 = fir.array_coor %arg2(%5) %arg3 : (!fir.ref<!fir.array<5xi32>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
    %9 = fir.array_coor %1(%6) %arg3 : (!fir.ref<!fir.array<5xi32>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
    %10 = fir.load %9 : !fir.ref<i32>
    %11 = fir.array_coor %arg1(%7) %arg3 : (!fir.ref<!fir.array<5xi32>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
    %12 = fir.load %11 : !fir.ref<i32>
    %13 = muli %10, %12 : i32
    fir.store %13 to %8 : !fir.ref<i32>
  }
  return
}

Affine Promotion

func @_QPf1dc(%arg0: !fir.ref<!fir.array<5xi32>>, %arg1: !fir.ref<!fir.array<5xi32>>, %arg2: !fir.ref<!fir.array<5xi32>>) {
  ...
  affine.for %arg3 = %c0 to affine_map<()[s0] -> (s0 + 1)>()[%c4] {
    %8 = fir.array_coor %1(%2) %arg3 : ...
    %9 = fir.array_coor %arg0(%3) %arg3 : ...
    %10 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)>(%arg3)[%c0, %c5, %c1_0]
    %11 = fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
    %12 = affine.load %11[%10] : memref<?xi32>
    %14 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)>(%arg3)[%c0, %c5, %c1_1]
    %15 = fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
    %16 = affine.load %15[%14] : memref<?xi32>
    %17 = addi %12, %16 : i32
    %18 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)>(%arg3)[%c0, %c5, %c1_2]
    affine.store %17, %19[%18] : memref<?xi32>
  }
  affine.for %arg3 = %c0 to affine_map<()[s0] -> (s0 + 1)>()[%c4] {
    ...
    %10 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)>(%arg3)[%c0, %c5, %c1_0]
    %11 = fir.convert %1 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
    %12 = affine.load %11[%10] : memref<?xi32>
    %14 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)>(%arg3)[%c0, %c5, %c1_1]
    %15 = fir.convert %arg1 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
    %16 = affine.load %15[%14] : memref<?xi32>
    %17 = muli %12, %16 : i32
    %18 = affine.apply affine_map<(d0)[s0, s1, s2] -> (d0 * s2 - s0)>(%arg3)[%c0, %c5, %c1_2]
    %19 = fir.convert %arg2 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
    affine.store %17, %19[%18] : memref<?xi32>
  }
  return
}

Basic Affine Optimizations

func @_QPf1dc(%arg0: !fir.ref<!fir.array<5xi32>>, %arg1: !fir.ref<!fir.array<5xi32>>, %arg2: !fir.ref<!fir.array<5xi32>>) {
  %0 = fir.alloca i32 {name = "i"}
  %1 = fir.alloca !fir.array<5xi32> {name = "t1"}
  %2 = fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
  %3 = fir.convert %1 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
  affine.for %arg3 = 0 to 5 {
    %6 = affine.load %2[%arg3] : memref<?xi32>
    %7 = affine.load %2[%arg3] : memref<?xi32>
    %8 = addi %6, %7 : i32
    affine.store %8, %3[%arg3] : memref<?xi32>
  }
  %4 = fir.convert %arg1 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
  %5 = fir.convert %arg2 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
  affine.for %arg3 = 0 to 5 {
    %6 = affine.load %3[%arg3] : memref<?xi32>
    %7 = affine.load %4[%arg3] : memref<?xi32>
    %8 = muli %6, %7 : i32
    affine.store %8, %5[%arg3] : memref<?xi32>
  }
  return
}

Affine Loop Fusion

func @_QPf1dc(%arg0: !fir.ref<!fir.array<5xi32>>, %arg1: !fir.ref<!fir.array<5xi32>>, %arg2: !fir.ref<!fir.array<5xi32>>) {
  %0 = fir.alloca i32 {name = "i"}
  %1 = fir.alloca !fir.array<5xi32> {name = "t1"}
  %2 = fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
  %3 = fir.convert %1 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
  %4 = fir.convert %arg1 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
  %5 = fir.convert %arg2 : (!fir.ref<!fir.array<5xi32>>) -> memref<?xi32>
  affine.for %arg3 = 0 to 5 {
    %6 = affine.load %2[%arg3] : memref<?xi32>
    %7 = affine.load %2[%arg3] : memref<?xi32>
    %8 = addi %6, %7 : i32
    %9 = affine.load %4[%arg3] : memref<?xi32>
    %10 = muli %8, %9 : i32
    affine.store %10, %5[%arg3] : memref<?xi32>
  }
  return
}

Affine Demotion

func @_QPf1dc(%arg0: !fir.ref<!fir.array<5xi32>>, %arg1: !fir.ref<!fir.array<5xi32>>, %arg2: !fir.ref<!fir.array<5xi32>>) {
  %0 = fir.alloca !fir.array<1xi32>
  %1 = fir.alloca i32 {name = "i"}
  %2 = fir.alloca !fir.array<5xi32> {name = "t1"}
  affine.for %arg3 = 0 to 5 {
    %7 = fir.coordinate_of %arg0, %arg3 : (!fir.ref<!fir.array<?xi32>>, index) -> !fir.ref<i32>
    %8 = fir.load %7 : !fir.ref<i32>
    %9 = fir.coordinate_of %arg0, %arg3 : ...
    %10 = fir.load %9 : !fir.ref<i32>
    %11 = addi %8, %10 : i32
    %12 = fir.coordinate_of %0, %c0 : ...
    fir.store %11 to %12 : !fir.ref<i32>
    %13 = fir.coordinate_of %0, %c0 : ...
    %14 = fir.load %13 : !fir.ref<i32>
    %15 = fir.coordinate_of %arg1, %arg3 : ...
    %16 = fir.load %15 : !fir.ref<i32>
    %17 = muli %14, %16 : i32
    %18 = fir.coordinate_of %arg2, %arg3 : ...
    fir.store %17 to %18 : !fir.ref<i32>
  }
  return
}
Made with Slides.com