- 
                Notifications
    
You must be signed in to change notification settings  - Fork 609
 
Open
Description
Found this issue while working on #2969.
Using `AtenStackOp` on a list built by a single `torch.prim.ListConstruct` that receives all tensors as operands (a ValueRange) does NOT fail:
// Working case: the list operand of torch.aten.stack is produced by one
// torch.prim.ListConstruct that takes all three tensors as operands.
// According to the issue text, this form does NOT fail.
module {
  // Stacks three [2,3] f32 tensors along dim 0 into a [3,2,3] tensor.
  func.func @test_stack() -> !torch.vtensor<[3,2,3],f32> {
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %int6 = torch.constant.int 6
    %float0 = torch.constant.float 0.0
    %none = torch.constant.none
    // Shape [2, 3], shared by every torch.aten.full below.
    %shape = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    // Three tensors filled with 0.0. %int6 is passed as the dtype argument,
    // and the declared result type is f32.
    %1 = torch.aten.full %shape, %float0, %int6, %none, %none, %none : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %2 = torch.aten.full %shape, %float0, %int6, %none, %none, %none : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %3 = torch.aten.full %shape, %float0, %int6, %none, %none, %none : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    // Every list element is supplied directly as a ListConstruct operand.
    %4 = torch.prim.ListConstruct %1, %2, %3 : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list<vtensor<[2,3],f32>>
    // Stack along dim 0: three [2,3] tensors -> [3,2,3].
    %5 = torch.aten.stack %4, %int0 : !torch.list<vtensor<[2,3],f32>>, !torch.int -> !torch.vtensor<[3,2,3],f32>
    return %5 : !torch.vtensor<[3,2,3],f32>
  }
}
Minimal reproducing example, in which the list is created empty and then populated with `torch.aten.append.t`:
module {
  // Reproducer: same tensors as the working case, but the list passed to
  // torch.aten.stack starts empty and is filled with torch.aten.append.t.
  func.func @test_stack() -> !torch.vtensor<[3,2,3],f32> {
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %int6 = torch.constant.int 6
    %float0 = torch.constant.float 0.0
    %none = torch.constant.none
    // Shape [2, 3], shared by every torch.aten.full below.
    %shape = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.full %shape, %float0, %int6, %none, %none, %none : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %2 = torch.aten.full %shape, %float0, %int6, %none, %none, %none : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    %3 = torch.aten.full %shape, %float0, %int6, %none, %none, %none : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
    // Empty list with no operands; its elements come only from the appends below.
    %list_of_tensors = torch.prim.ListConstruct  : () -> !torch.list<vtensor<[2,3],f32>>
    // The append results are not bound. The original %list_of_tensors value is
    // what reaches torch.aten.stack, so the repro relies on in-place list mutation.
    torch.aten.append.t %list_of_tensors, %1 : !torch.list<vtensor<[2,3],f32>>, !torch.vtensor<[2,3],f32> -> !torch.list<vtensor<[2,3],f32>>
    torch.aten.append.t %list_of_tensors, %2 : !torch.list<vtensor<[2,3],f32>>, !torch.vtensor<[2,3],f32> -> !torch.list<vtensor<[2,3],f32>>
    torch.aten.append.t %list_of_tensors, %3 : !torch.list<vtensor<[2,3],f32>>, !torch.vtensor<[2,3],f32> -> !torch.list<vtensor<[2,3],f32>>
    // Stack along dim 0; the declared result is [3,2,3].
    %5 = torch.aten.stack %list_of_tensors, %int0 : !torch.list<vtensor<[2,3],f32>>, !torch.int -> !torch.vtensor<[3,2,3],f32>
    return %5 : !torch.vtensor<[3,2,3],f32>
  }
}

Reproducing example with a loop:
// Reproducer with a loop: the list starts empty, and a torch.prim.Loop appends
// one [2,3] tensor per iteration before torch.aten.stack consumes the list.
module {
  func.func @test_stack() -> !torch.vtensor<[3,2,3],f32> {
    %int0 = torch.constant.int 0
    %int2 = torch.constant.int 2
    %int3 = torch.constant.int 3
    %int6 = torch.constant.int 6
    %float0 = torch.constant.float 0.0
    %none = torch.constant.none
    // Shape [2, 3] used for each tensor created inside the loop.
    %shape = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
    %true = torch.constant.bool true
    // Empty list; it is populated only by the append inside the loop body.
    %list_of_tensors = torch.prim.ListConstruct  : () -> !torch.list<vtensor<[2,3],f32>>

    // Trip count 3 matches the leading dimension of the [3,2,3] result.
    %loop_iters_max = torch.constant.int 3
    torch.prim.Loop %loop_iters_max, %true, init() {
      ^bb0(%iter_index: !torch.int):
      %tensor = torch.aten.full %shape, %float0, %int6, %none, %none, %none : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
      // The append result is discarded. The list defined outside the loop is
      // what torch.aten.stack reads afterwards, so this relies on in-place mutation.
      %discard_append_result = torch.aten.append.t %list_of_tensors, %tensor : !torch.list<vtensor<[2,3],f32>>, !torch.vtensor<[2,3],f32> -> !torch.list<vtensor<[2,3],f32>>
      // Always continue; only %loop_iters_max bounds the number of iterations.
      %continue = torch.constant.bool true
      torch.prim.Loop.condition %continue, iter()
    } : (!torch.int, !torch.bool) -> ()
    // Stack along dim 0; the declared result is [3,2,3].
    %5 = torch.aten.stack %list_of_tensors, %int0 : !torch.list<vtensor<[2,3],f32>>, !torch.int -> !torch.vtensor<[3,2,3],f32>
    return %5 : !torch.vtensor<[3,2,3],f32>
  }
}
Metadata
Assignees
Labels
No labels