Skip to content

compute_at after fuse & split: the result is wrong #3679

@Youjk

Description

@Youjk

When compute_at is applied after fuse and split, the generated IR is not what we expect. Here is a demo:

# Minimal reproduction for TVM issue #3679: attaching producer stages with
# compute_at after every stage's axes have been fused and split.  The lowered
# IR is printed after each scheduling step so the loop nests can be compared.
#
# NOTE(review): this uses the pre-`tvm.te` top-level API (tvm.placeholder,
# tvm.compute, tvm.create_schedule, tvm.lower); newer TVM releases moved these
# under `tvm.te` -- confirm the TVM version this was filed against.
import tvm

# 1024 * 30522 = 31254528 elements in total (matches the allocation size in
# the printed IR).
shape = [1024, 30522]
dtype = 'float16'
power_num = tvm.const(0.5, dtype = dtype)
data = tvm.placeholder(shape, name="data", dtype=dtype)

# Three chained element-wise stages:  vexp_t = exp(log(data) * 0.5)
vlog_t = tvm.compute(shape, lambda *indice: tvm.log(data(*indice)), name = "vlog_t")
vmuls_t = tvm.compute(shape, lambda *indice: vlog_t(*indice) * power_num, name = "vmuls_t")
vexp_t = tvm.compute(shape, lambda *indice: tvm.exp(vmuls_t(*indice)), name = "vexp_t")

s = tvm.create_schedule(vexp_t.op)

# Step 0: default schedule -- one (i0, i1) loop nest per stage.
print(tvm.lower(s, [data, vexp_t], simple_mode=True))
print("-----------------")

# Step 1: fuse both axes of every stage into one flat axis (extent 31254528).
exp_fused_axis = s[vexp_t].fuse(*vexp_t.op.axis)
muls_fused_axis = s[vmuls_t].fuse(*vmuls_t.op.axis)
log_fused_axis = s[vlog_t].fuse(*vlog_t.op.axis)

print(tvm.lower(s, [data, vexp_t], simple_mode=True))
print("-----------------")

# Step 2: split each flat axis into (outer, inner) with inner extent 2048.
# 31254528 / 2048 = 15261 exactly, so the split itself has no tail.  However,
# 30522 % 2048 == 1850, so a 2048-element chunk of the fused axis does not
# line up with rows of the original [1024, 30522] shape.
# NOTE(review): the producer splits pass `factor` positionally while the
# consumer split uses `factor=`; presumably both mean the same -- confirm.
factor = 2048
xo, xi = s[vexp_t].split(exp_fused_axis, factor=factor)
mo, mi = s[vmuls_t].split(muls_fused_axis, factor)
lo, li = s[vlog_t].split(log_fused_axis, factor)

print(tvm.lower(s, [data, vexp_t], simple_mode=True))

# Step 3: attach both producers to the consumer's outer loop `xo`.  The intent
# is for each outer iteration to compute only the 2048 producer elements it
# consumes.  In the IR reported with this issue, however, the producers still
# loop over the full fused range inside every iteration of vexp_t's outer
# loop, i.e. the whole tensors are recomputed 15261 times.
# NOTE(review): presumably bound inference cannot map a partial, non
# row-aligned range of the fused axis back onto (i0, i1) and falls back to the
# full i0/i1 domain -- confirm against TVM's InferBound / Fuse relation
# handling.  Workarounds to try: a split factor that evenly divides 30522, or
# compute_inline() for these purely element-wise producers.
s[vmuls_t].compute_at(s[vexp_t], xo)
s[vlog_t].compute_at(s[vexp_t], xo)

print(tvm.lower(s, [data, vexp_t], simple_mode=True))
print("-----------------")

The original IR:

// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vlog_t {
  for (i0, 0, 1024) {
    for (i1, 0, 30522) {
      vlog_t[((i0*30522) + i1)] = log(data[((i0*30522) + i1)])
    }
  }
}
produce vmuls_t {
  for (i0, 0, 1024) {
    for (i1, 0, 30522) {
      vlog_t[((i0*30522) + i1)] = (vlog_t[((i0*30522) + i1)]*0.500000h)
    }
  }
}
produce vexp_t {
  for (i0, 0, 1024) {
    for (i1, 0, 30522) {
      vexp_t[((i0*30522) + i1)] = exp(vlog_t[((i0*30522) + i1)])
    }
  }
}

The IR after fuse:

// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vlog_t {
  for (i0.i1.fused, 0, 31254528) {
    vlog_t[i0.i1.fused] = log(data[i0.i1.fused])
  }
}
produce vmuls_t {
  for (i0.i1.fused, 0, 31254528) {
    vlog_t[i0.i1.fused] = (vlog_t[i0.i1.fused]*0.500000h)
  }
}
produce vexp_t {
  for (i0.i1.fused, 0, 31254528) {
    vexp_t[i0.i1.fused] = exp(vlog_t[i0.i1.fused])
  }
}

The IR after fuse and split:

// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vlog_t {
  for (i0.i1.fused.outer, 0, 15261) {
    for (i0.i1.fused.inner, 0, 2048) {
      vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = log(data[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
    }
  }
}
produce vmuls_t {
  for (i0.i1.fused.outer, 0, 15261) {
    for (i0.i1.fused.inner, 0, 2048) {
      vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = (vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)]*0.500000h)
    }
  }
}
produce vexp_t {
  for (i0.i1.fused.outer, 0, 15261) {
    for (i0.i1.fused.inner, 0, 2048) {
      vexp_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = exp(vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
    }
  }
}

The IR after compute_at (applied after fuse and split):

// attr [vlog_t] storage_scope = "global"
allocate vlog_t[float16 * 31254528]
produce vexp_t {
  for (i0.i1.fused.outer, 0, 15261) {
    produce vlog_t {
      for (i0.i1.fused.outer, 0, 15261) {
        for (i0.i1.fused.inner, 0, 2048) {
          vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = log(data[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
        }
      }
    }
    produce vmuls_t {
      for (i0.i1.fused.outer, 0, 15261) {
        for (i0.i1.fused.inner, 0, 2048) {
          vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = (vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)]*0.500000h)
        }
      }
    }
    for (i0.i1.fused.inner, 0, 2048) {
      vexp_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)] = exp(vlog_t[((i0.i1.fused.outer*2048) + i0.i1.fused.inner)])
    }
  }
}

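As the IR shows, after compute_at the producer axes are not combined with the consumer's outer axis. Inside every iteration of vexp_t's outer loop, vlog_t and vmuls_t still iterate over the full fused range (15261 × 2048), so the whole tensors are recomputed each time. Is this a bug?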

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions