
Commit 2459ecf

Update on "[ET-VK] Return fence after waiting is done."

This change returns a fence to the fence pool after it has been waited on.

Differential Revision: [D74484825](https://our.internmc.facebook.com/intern/diff/D74484825/)

[ghstack-poisoned]

2 parents 0283b4c + ab22cbb
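
Note that the merge diff below only brings this branch up to date with main; the fence-pool change the message describes lives elsewhere in the ET-VK runtime and is not shown here. As a rough sketch of the pattern the commit message describes (hypothetical: the FencePool class, its members, and its method names are illustrative assumptions, not the actual ET-VK API):

// Hypothetical sketch of "return the fence to the pool once it has been
// waited on", per the commit message. Not the real ET-VK implementation.
#include <vulkan/vulkan.h>
#include <cstdint>
#include <vector>

class FencePool {
 public:
  explicit FencePool(VkDevice device) : device_(device) {}

  ~FencePool() {
    for (VkFence fence : free_fences_) {
      vkDestroyFence(device_, fence, nullptr);
    }
  }

  // Hand out a recycled fence if one is available, otherwise create one.
  VkFence get_fence() {
    if (!free_fences_.empty()) {
      VkFence fence = free_fences_.back();
      free_fences_.pop_back();
      vkResetFences(device_, 1u, &fence); // reset before reuse
      return fence;
    }
    VkFence fence = VK_NULL_HANDLE;
    const VkFenceCreateInfo create_info{
        VK_STRUCTURE_TYPE_FENCE_CREATE_INFO, nullptr, 0u};
    vkCreateFence(device_, &create_info, nullptr, &fence);
    return fence;
  }

  // The behavior the commit message describes: once the submission guarded
  // by this fence has been waited on, the fence goes straight back to the
  // pool instead of lingering with its owner.
  void wait_and_return(VkFence fence) {
    vkWaitForFences(device_, 1u, &fence, VK_TRUE, UINT64_MAX);
    free_fences_.push_back(fence);
  }

 private:
  VkDevice device_;
  std::vector<VkFence> free_fences_;
};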

File tree: 17 files changed, +273 -91 lines

.github/workflows/pull.yml

Lines changed: 1 addition & 3 deletions

@@ -434,9 +434,7 @@ jobs:
       output=$(ls -la cmake-out/test/size_test)
       arr=($output)
       size=${arr[4]}
-      # threshold=48120 on devserver with gcc11.4
-      # todo(lfq): update once binary size is below 50kb.
-      threshold="47552"
+      threshold="47560"
       if [[ "$size" -le "$threshold" ]]; then
         echo "Success $size <= $threshold"
       else

backends/arm/test/conftest.py

Lines changed: 11 additions & 1 deletion

@@ -44,10 +44,20 @@ def pytest_configure(config):
     if getattr(config.option, "fast_fvp", False):
         pytest._test_options["fast_fvp"] = config.option.fast_fvp  # type: ignore[attr-defined]

+    pytest._test_options["tosa_version"] = "0.80"  # type: ignore[attr-defined]
     if config.option.arm_run_tosa_version:
         pytest._test_options["tosa_version"] = config.option.arm_run_tosa_version

-    pytest._test_options["tosa_ref_model"] = True  # type: ignore[attr-defined]
+    # Not all deployments of ET have the TOSA reference model available.
+    # Make sure we don't try to use it if it's not available.
+    try:
+        if pytest._test_options["tosa_version"] == "0.80":
+            import tosa_tools.v0_80.tosa_reference_model as tosa_reference_model
+        else:
+            import tosa_tools.tosa_ref_model as tosa_reference_model
+    except ImportError:
+        pytest._test_options["tosa_ref_model"] = False  # type: ignore[attr-defined]
+        tosa_reference_model = None  # noqa

     logging.basicConfig(level=logging.INFO, stream=sys.stdout)

backends/cadence/aot/fuse_ops.py

Lines changed: 11 additions & 2 deletions

@@ -885,6 +885,9 @@ class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
     """
     Fuse transpose or permute op pairs to a single view op.
    (transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
+    This happens when op2(op1) == identity, modulo unitary dimensions.
+    'Unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30],
+    so transpose(1, 2) then transpose(0, 2) is a pseudo-identity and should be fused.
     """

     # A list of ops that can be bypassed when looking for a

@@ -908,7 +911,7 @@ def can_fuse_for_chain(
         if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
             return False

-        # checking that permut2(permut1(identify)) == identity
+        # checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
         input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
         ident_dims = list(range(len(input_shape)))
         # this mapping helps to handle both transpose and permutations

@@ -918,14 +921,20 @@
         }
         in_dims = f[producer.target](producer, ident_dims)
         out_dims = f[consumer.target](consumer, in_dims)
-        return out_dims == ident_dims
+        # Filter out unitary dimensions
+        non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1]
+        non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1]
+        return non_unit_out_dims == non_unit_ident_dims

     def get_fused_node(
         self,
         producer: torch.fx.Node,
         consumer: torch.fx.Node,
         graph_module: torch.fx.GraphModule,
     ) -> torch.fx.Node:
+        # This step matters because we may fuse transpositions that are not perfect
+        # inverses of one another yet are equivalent once unitary dimensions are ignored.
+        # The fused operation must have the same output shape as the consumer.
         output_shape = consumer.meta["val"].shape
         with graph_module.graph.inserting_after(consumer):
             view = graph_module.graph.call_function(
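
The shape-aware check above can be reproduced standalone. A minimal sketch, assuming the producer's and consumer's dim orders have already been composed into a single permutation (the pass derives these from the FX nodes via its f[...] mapping; the function and parameter names here are illustrative):

// Standalone sketch of the "identity modulo unitary dimensions" test that
// can_fuse_for_chain performs; plain vectors stand in for the FX metadata.
#include <cstdint>
#include <vector>

bool is_pseudo_identity(
    const std::vector<int64_t>& shape,
    const std::vector<size_t>& composed_perm) {
  std::vector<size_t> non_unit_ident;
  std::vector<size_t> non_unit_perm;
  for (size_t d = 0; d < shape.size(); ++d) {
    if (shape[d] != 1) {
      non_unit_ident.push_back(d); // identity order, size-1 dims dropped
    }
  }
  for (size_t d : composed_perm) {
    if (shape[d] != 1) {
      non_unit_perm.push_back(d); // permuted order, size-1 dims dropped
    }
  }
  return non_unit_perm == non_unit_ident;
}

// The docstring's example: on shape [1, 5, 30], transpose(1, 2) followed by
// transpose(0, 2) composes to the permutation [1, 2, 0], which only moves the
// size-1 dimension, so the pair collapses to a view:
//   is_pseudo_identity({1, 5, 30}, {1, 2, 0})  -> true
//   is_pseudo_identity({5, 6, 30}, {1, 2, 0})  -> false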

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 44 additions & 0 deletions

@@ -584,6 +584,28 @@ def _create_operator(
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            False,
        ),
+        # transpose -> quant -> transpose where the second is not the inverse,
+        # BUT there is a UNITARY dimension so the result is the same in memory => fuse
+        (
+            True,
+            [0, 1],
+            True,
+            [0, 2],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            True,
+            [5, 40, 1],
+        ),
+        # transpose -> quant -> transpose where the second is not the inverse,
+        # and unitary dimensions don't help => don't fuse
+        (
+            True,
+            [0, 1],
+            True,
+            [1, 3],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            False,
+            [5, 40, 1, 4],
+        ),
        # permutation -> quant -> opposite permutation => fuse
        (
            False,

@@ -622,6 +644,28 @@ def _create_operator(
            False,
            [4, 4, 4],
        ),
+        # permutation -> quant -> a non-inverse permutation, BUT there is a UNITARY
+        # dimension so the result is the same in memory => fuse
+        (
+            False,
+            [1, 3, 2, 0],
+            False,
+            [3, 2, 1, 0],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            True,
+            [3, 1, 8, 10],
+        ),
+        # permutation -> quant -> a non-inverse permutation, and unitary dimensions
+        # don't help => don't fuse
+        (
+            False,
+            [1, 3, 2, 0],
+            False,
+            [3, 1, 2, 0],
+            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            False,
+            [3, 1, 8, 10],
+        ),
        # transpose -> quant -> transpose as a permutation => fuse
        (
            True,

backends/cadence/hifi/operators/operators.h

Lines changed: 39 additions & 0 deletions

@@ -12,6 +12,7 @@
   _(uint8_t, Byte) \
   _(int8_t, Char)

+using ::executorch::aten::IntArrayRef;
 using ::executorch::aten::optional;
 using ::executorch::aten::ScalarType;
 using ::executorch::aten::Tensor;

@@ -67,6 +68,44 @@ void quantized_linear_per_tensor_out(
     __ET_UNUSED const optional<Tensor>& offset,
     Tensor& out);

+void quantized_conv_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    int64_t groups,
+    int64_t in_zero_point,
+    const Tensor& weight_zero_point,
+    const Tensor& bias_scale,
+    double output_scale,
+    int64_t output_zero_point,
+    __ET_UNUSED const Tensor& out_multiplier,
+    __ET_UNUSED const Tensor& out_shift,
+    bool channel_last,
+    Tensor& out);
+
+void quantized_conv_per_tensor_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& input,
+    const Tensor& weight,
+    const Tensor& bias,
+    IntArrayRef stride,
+    IntArrayRef padding,
+    IntArrayRef dilation,
+    int64_t groups,
+    int64_t in_zero_point,
+    int64_t weight_zero_point,
+    double bias_scale,
+    double output_scale,
+    int64_t output_zero_point,
+    __ET_UNUSED int64_t out_multiplier,
+    __ET_UNUSED int64_t out_shift,
+    bool channel_last,
+    Tensor& out);
+
 } // namespace native
 } // namespace HiFi
 } // namespace impl

backends/vulkan/runtime/api/Context.cpp

Lines changed: 10 additions & 12 deletions

@@ -235,6 +235,15 @@ Context* context() {
        8u, // cmdPoolBatchSize
    };

+    const vkapi::DescriptorPoolConfig descriptor_pool_config{
+        VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorPoolMaxSets
+        VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorUniformBufferCount
+        VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorStorageBufferCount
+        VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorCombinedSamplerCount
+        VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorStorageImageCount
+        32u, // descriptorPileSizes
+    };
+
    const vkapi::QueryPoolConfig query_pool_config{
        VULKAN_QUERY_POOL_SIZE, // maxQueryCount
        256u, // initialReserveSize

@@ -243,7 +252,7 @@ Context* context() {
    const ContextConfig config{
        cmd_submit_frequency,
        cmd_config,
-        {},
+        descriptor_pool_config,
        query_pool_config,
    };

@@ -257,17 +266,6 @@ Context* context() {
  return context.get();
}

-vkapi::DescriptorPoolConfig default_descriptor_pool_config() {
-  return {
-      VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorPoolMaxSets
-      VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorUniformBufferCount
-      VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorStorageBufferCount
-      VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorCombinedSamplerCount
-      VULKAN_DESCRIPTOR_POOL_SIZE, // descriptorStorageImageCount
-      32u, // descriptorPileSizes
-  };
-}
-
#ifdef VULKAN_DEBUG

#ifdef VK_KHR_pipeline_executable_properties

backends/vulkan/runtime/api/Context.h

Lines changed: 0 additions & 2 deletions

@@ -267,8 +267,6 @@ bool available();
// a static local variable.
Context* context();

-vkapi::DescriptorPoolConfig default_descriptor_pool_config();
-
namespace detail {

inline void arg_is_empty(

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 6 additions & 6 deletions

@@ -594,13 +594,13 @@ void ComputeGraph::prepare() {
          prepack_descriptor_counts_.field) * \
      config_.descriptor_pool_safety_factor))

-  const uint32_t max_sets = MERGE_FIELD(descriptor_pool_max_sets);
-  const vkapi::DescriptorPoolConfig config{
+  uint32_t max_sets = MERGE_FIELD(descriptor_pool_max_sets);
+  vkapi::DescriptorPoolConfig config{
      max_sets,
-      MERGE_FIELD(descriptor_uniform_buffer_count),
-      MERGE_FIELD(descriptor_storage_buffer_count),
-      MERGE_FIELD(descriptor_combined_sampler_count),
-      MERGE_FIELD(descriptor_storage_image_count),
+      std::max(MERGE_FIELD(descriptor_uniform_buffer_count), max_sets),
+      std::max(MERGE_FIELD(descriptor_storage_buffer_count), max_sets),
+      std::max(MERGE_FIELD(descriptor_combined_sampler_count), max_sets),
+      std::max(MERGE_FIELD(descriptor_storage_image_count), max_sets),
      1u,
  };

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 2 additions & 1 deletion

@@ -33,7 +33,6 @@ DispatchNode::DispatchNode(
      spec_vars_(spec_vars),
      push_constants_(push_constants) {
  graph.update_descriptor_counts(shader, /*execute = */ true);
-  graph.context()->check_device_capabilities(shader_);
}

void DispatchNode::encode(ComputeGraph* graph) {

@@ -43,6 +42,8 @@ void DispatchNode::encode(ComputeGraph* graph) {
  api::Context* const context = graph->context();
  vkapi::PipelineBarrier pipeline_barrier{};

+  context->check_device_capabilities(shader_);
+
  std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();

  std::array<uint8_t, kMaxPushConstantSize> push_constants_data;

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 2 additions & 1 deletion

@@ -45,7 +45,6 @@ PrepackNode::PrepackNode(
      push_constants_(push_constants) {
  graph.update_descriptor_counts(shader, /*execute = */ false);
  graph.update_descriptor_counts(noop_shader_, /*execute = */ false);
-  graph.context()->check_device_capabilities(shader_);
}

api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {

@@ -71,6 +70,8 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
void PrepackNode::encode(ComputeGraph* graph) {
  api::Context* const context = graph->context();

+  context->check_device_capabilities(shader_);
+
  vTensorPtr packed = graph->get_tensor(packed_);
  api::StagingBuffer staging = create_staging_buffer(graph);

backends/vulkan/runtime/vk_api/Descriptor.cpp

Lines changed: 5 additions & 1 deletion

@@ -269,7 +269,11 @@ DescriptorPool::DescriptorPool(
      pool_(VK_NULL_HANDLE),
      config_(config),
      mutex_{},
-      piles_{} {}
+      piles_{} {
+  if (config.descriptor_pool_max_sets > 0) {
+    init(config);
+  }
+}

DescriptorPool::~DescriptorPool() {
  if (pool_ == VK_NULL_HANDLE) {

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 0 additions & 3 deletions

@@ -87,9 +87,6 @@ class VulkanComputeAPITest : public ::testing::Test {
  void SetUp() override {
    // Make sure we are starting with a clean slate
    EXPECT_TRUE(get_vma_allocation_count() == 0);
-    if (!context()->descriptor_pool()) {
-      context()->descriptor_pool().init(default_descriptor_pool_config());
-    }
  }

  void TearDown() override {

exir/passes/constant_prop_pass.py

Lines changed: 2 additions & 0 deletions

@@ -66,6 +66,8 @@ def is_const(
        )
    elif isinstance(arg, _PRIMITIVE_TYPES):
        return True
+    elif arg is None:
+        return True
    elif not isinstance(arg, torch.fx.Node):
        return False
    elif arg in const_node_to_tensor:

exir/tests/test_passes.py

Lines changed: 31 additions & 0 deletions

@@ -1823,3 +1823,34 @@ def _do_checks(
        self.assertTrue(
            torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0])
        )
+
+    def test_constant_prop_pass_none(self) -> None:
+        """
+        Checks that None arguments are treated as constants in constant_prop_pass.
+        """
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.cst = torch.ones(3, 3, 3, dtype=torch.int8)
+                self.w = torch.ones(3, 3, 3, dtype=torch.int8)
+
+            def forward(self, x):
+                # Note: using e.g. aten.linear would not work, as None is not in the graph
+                a = torch.ops.aten.convolution.default(
+                    self.cst, self.w, None, [1], [0], [1], False, [0], 1
+                )
+                return a + x
+
+        mod = M()
+        x = torch.randn([3, 3, 3])
+        mod(x)
+        edge = to_edge(
+            export(mod, (x,), strict=True),
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        # 2 constants: self.cst and self.w
+        self.assertEqual(2, len(edge.exported_program().constants))
+        pass_result = constant_prop_pass(edge.exported_program())
+        # 1 constant: a (the convolution of self.cst and self.w, folded)
+        self.assertEqual(1, len(pass_result.constants))
