Skip to content

Commit

Permalink
[PyCDE] Restrict slicing index widths to clog2(len) (#7277)
Browse files Browse the repository at this point in the history
Previously, PyCDE was somewhat more loose on indexes into
arrays/bitvectors. Now that we have the pad_or_truncate convenience
method, lets be a bit more restrictive.
  • Loading branch information
teqdruid authored Jul 3, 2024
1 parent be44848 commit 296d283
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 38 deletions.
10 changes: 5 additions & 5 deletions frontends/PyCDE/src/pycde/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ def _validate_idx(size: int, idx: Union[int, BitVectorSignal]):
if isinstance(idx, int):
if idx >= size:
raise ValueError("Subscript out-of-bounds")
elif isinstance(idx, BitVectorSignal):
if idx.type.width != (size - 1).bit_length():
raise ValueError("Index must be exactly clog2 of the size of the array")
else:
idx = support.get_value(idx)
if idx is None or not isinstance(support.type_to_pytype(idx.type),
ir.IntegerType):
raise TypeError("Subscript on array must be either int or int signal"
f" not {type(idx)}.")
raise TypeError("Subscript on array must be either int or int signal"
f" not {type(idx)}.")


def get_slice_bounds(size, idxOrSlice: Union[int, slice]):
Expand Down
38 changes: 5 additions & 33 deletions frontends/PyCDE/test/test_muxing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
# CHECK: %c0_i3_2 = hw.constant 0 : i3
# CHECK: [[R8:%.+]] = hw.array_get %In[%c0_i3_2] {sv.namehint = "In__0"} : !hw.array<5xarray<4xi3>>, i3
# CHECK: [[R9:%.+]] = hw.array_get [[R8]][%c0_i2] {sv.namehint = "In__0__0"} : !hw.array<4xi3>
# CHECK: %c0_i2_3 = hw.constant 0 : i2
# CHECK: [[R10:%.+]] = comb.concat %c0_i2_3, %Sel {sv.namehint = "Sel_padto_3"} : i2, i1
# CHECK: %false = hw.constant false
# CHECK: [[RN9:%.+]] = comb.concat %false, %Sel {sv.namehint = "Sel_padto_2"} : i1, i1
# CHECK: %false_3 = hw.constant false
# CHECK: [[R10:%.+]] = comb.concat %false_3, [[RN9]] {sv.namehint = "Sel_padto_2_padto_3"} : i1, i2
# CHECK: [[R11:%.+]] = comb.shru bin [[R9]], [[R10]] : i3
# CHECK: [[R12:%.+]] = comb.extract [[R11]] from 0 : (i3) -> i1
# CHECK: hw.output [[R3]], [[R6]], [[R12]], [[R7]] : !hw.array<4xi3>, !hw.array<2xarray<4xi3>>, i1, !hw.array<3xarray<4xi3>>
Expand All @@ -49,41 +51,11 @@ def create(ports):
ports.OutArr = Signal.create([ports.In[0], ports.In[1]])
ports.OutSlice = ports.In[0:3]

ports.OutInt = ports.In[0][0][ports.Sel]
ports.OutInt = ports.In[0][0][ports.Sel.pad_or_truncate(2)]


# -----

# CHECK-LABEL: hw.module @Slicing(in %In : !hw.array<5xarray<4xi8>>, in %Sel8 : i8, in %Sel2 : i2, out OutIntSlice : i2, out OutArrSlice8 : !hw.array<2xarray<4xi8>>, out OutArrSlice2 : !hw.array<2xarray<4xi8>>)
# CHECK: [[R0:%.+]] = hw.array_get %In[%c0_i3] {sv.namehint = "In__0"} : !hw.array<5xarray<4xi8>>
# CHECK: [[R1:%.+]] = hw.array_get %0[%c0_i2] {sv.namehint = "In__0__0"} : !hw.array<4xi8>
# CHECK: [[R2:%.+]] = comb.concat %c0_i6, %Sel2 {sv.namehint = "Sel2_padto_8"} : i6, i2
# CHECK: [[R3:%.+]] = comb.shru bin [[R1]], [[R2]] : i8
# CHECK: [[R4:%.+]] = comb.extract [[R3]] from 0 : (i8) -> i2
# CHECK: [[R5:%.+]] = comb.concat %false, %Sel2 {sv.namehint = "Sel2_padto_3"} : i1, i2
# CHECK: [[R6:%.+]] = hw.array_slice %In[[[R5]]] : (!hw.array<5xarray<4xi8>>) -> !hw.array<2xarray<4xi8>>
# CHECK: [[R7:%.+]] = comb.extract %Sel8 from 0 : (i8) -> i3
# CHECK: [[R8:%.+]] = hw.array_slice %In[[[R7]]] : (!hw.array<5xarray<4xi8>>) -> !hw.array<2xarray<4xi8>>
# CHECK: hw.output %4, %8, %6 : i2, !hw.array<2xarray<4xi8>>, !hw.array<2xarray<4xi8>>


@unittestmodule()
class Slicing(Module):
In = Input(dim(8, 4, 5))
Sel8 = Input(types.i8)
Sel2 = Input(types.i2)

OutIntSlice = Output(types.i2)
OutArrSlice8 = Output(dim(8, 4, 2))
OutArrSlice2 = Output(dim(8, 4, 2))

@generator
def create(ports):
i = ports.In[0][0]
ports.OutIntSlice = i.slice(ports.Sel2, 2)
ports.OutArrSlice2 = ports.In.slice(ports.Sel2, 2)
ports.OutArrSlice8 = ports.In.slice(ports.Sel8, 2)


# CHECK-LABEL: hw.module @SimpleMux2(in %op : i1, in %a : i32, in %b : i32, out out : i32)
# CHECK-NEXT: [[r0:%.+]] = comb.mux bin %op, %b, %a
Expand Down

0 comments on commit 296d283

Please sign in to comment.