Skip to content

Add docs about local variables lifetime #534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 178 additions & 1 deletion docs/user_guides/debugging/local_variables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ When the debugger hits the last line of the kernel, ``info locals`` command retu

.. note::

The debugger can show the variable values, but these values may be equal to 0 after the variable is explicitly deleted or the function scope is ended. For more information, refer to `Numba variable policy <https://numba.pydata.org/numba-doc/latest/developer/live_variable_analysis.html?highlight=delete#live-variable-analysis>`_.
The debugger can show the variable values, but these values may be equal to
0 after the variable is explicitly deleted or the function scope is ended.
For more info see :ref:`local-variables-lifetime`.

When you use "O1 optimization" level ``NUMBA_OPT=1`` and run the ``info locals`` command, the output is as follows:

Expand Down Expand Up @@ -80,3 +82,178 @@ To print the type of a variable, run the ``ptype <variable>`` or ``whatis <varia
See also:

- `Local variables in GDB* <https://sourceware.org/gdb/current/onlinedocs/gdb/Frame-Info.html#Frame-Info>`_

.. _local-variables-lifetime:

Lifetime of local variables
---------------------------

Numba uses live variable analysis.
Lifetime of Python variables are different from lifetime of variables in
compiled code.

.. note::
For more information, refer to `Numba variable policy <https://numba.pydata.org/numba-doc/latest/developer/live_variable_analysis.html?highlight=delete#live-variable-analysis>`_.



It affects debugging experience in following way.

Consider Numba-dppy kernel code from :file:`sum_local_vars.py`:

.. literalinclude:: ../../../numba_dppy/examples/debug/sum_local_vars.py
:lines: 20-25
:linenos:
:lineno-match:

Run this code with environment variable :samp:`NUMBA_DUMP_ANNOTATION=1` and it
will show where numba inserts `del` for variables.

.. code-block::
:linenos:
:emphasize-lines: 28

-----------------------------------ANNOTATION-----------------------------------
# File: numba_dppy/examples/debug/sum_local_vars.py
# --- LINE 20 ---

@dppy.kernel(debug=True)

# --- LINE 21 ---

def data_parallel_sum(a, b, c):

# --- LINE 22 ---
# label 0
# a = arg(0, name=a) :: array(float32, 1d, C)
# b = arg(1, name=b) :: array(float32, 1d, C)
# c = arg(2, name=c) :: array(float32, 1d, C)
# $2load_global.0 = global(dppy: <module 'numba_dppy' from '.../numba-dppy/numba_dppy/__init__.py'>) :: Module(<module 'numba_dppy' from '.../numba-dppy/numba_dppy/__init__.py'>)
# $4load_method.1 = getattr(value=$2load_global.0, attr=get_global_id) :: Function(<function get_global_id at 0x7f82b8bae430>)
# del $2load_global.0
# $const6.2 = const(int, 0) :: Literal[int](0)
# i = call $4load_method.1($const6.2, func=$4load_method.1, args=[Var($const6.2, sum_local_vars.py:22)], kws=(), vararg=None, target=None) :: (uint32,) -> int64
# del $const6.2
# del $4load_method.1

i = dppy.get_global_id(0)

# --- LINE 23 ---
# $16binary_subscr.6 = getitem(value=a, index=i, fn=<built-in function getitem>) :: float32
# del a
# $const18.7 = const(float, 2.5) :: float64
# l1 = $16binary_subscr.6 + $const18.7 :: float64
# del $const18.7
# del $16binary_subscr.6

l1 = a[i] + 2.5

# --- LINE 24 ---
# $28binary_subscr.11 = getitem(value=b, index=i, fn=<built-in function getitem>) :: float32
# del b
# $const30.12 = const(float, 0.3) :: float64
# l2 = $28binary_subscr.11 * $const30.12 :: float64
# del $const30.12
# del $28binary_subscr.11

l2 = b[i] * 0.3

# --- LINE 25 ---
# $40binary_add.16 = l1 + l2 :: float64
# del l2
# del l1
# c[i] = $40binary_add.16 :: (array(float32, 1d, C), int64, float64) -> none
# del i
# del c
# del $40binary_add.16
# $const48.19 = const(NoneType, None) :: none
# $50return_value.20 = cast(value=$const48.19) :: none
# del $const48.19
# return $50return_value.20

c[i] = l1 + l2

I.e. in `LINE 23` variable `a` used the last time and numba inserts `del a` as
shown in annotated code in line 28. It means you will see value 0 for the
variable `a` when you set breakpoint at `LINE 24`.

As a workaround you can expand lifetime of the variable by using it (i.e.
passing to dummy function `revive()`) at the end of the function. So numba will
not insert `del a` until the end of the function.

.. literalinclude:: ../../../numba_dppy/examples/debug/sum_local_vars_revive.py
:lines: 20-31
:linenos:
:lineno-match:

.. code-block::
:linenos:
:emphasize-lines: 59

-----------------------------------ANNOTATION-----------------------------------
# File: numba_dppy/examples/debug/sum_local_vars_revive.py
# --- LINE 24 ---

@dppy.kernel(debug=True)

# --- LINE 25 ---

def data_parallel_sum(a, b, c):

# --- LINE 26 ---
# label 0
# a = arg(0, name=a) :: array(float32, 1d, C)
# b = arg(1, name=b) :: array(float32, 1d, C)
# c = arg(2, name=c) :: array(float32, 1d, C)
# $2load_global.0 = global(dppy: <module 'numba_dppy' from '.../numba-dppy/numba_dppy/__init__.py'>) :: Module(<module 'numba_dppy' from '.../numba-dppy/numba_dppy/__init__.py'>)
# $4load_method.1 = getattr(value=$2load_global.0, attr=get_global_id) :: Function(<function get_global_id at 0x7fcdf7e8c4c0>)
# del $2load_global.0
# $const6.2 = const(int, 0) :: Literal[int](0)
# i = call $4load_method.1($const6.2, func=$4load_method.1, args=[Var($const6.2, sum_local_vars_revive.py:26)], kws=(), vararg=None, target=None) :: (uint32,) -> int64
# del $const6.2
# del $4load_method.1

i = dppy.get_global_id(0)

# --- LINE 27 ---
# $16binary_subscr.6 = getitem(value=a, index=i, fn=<built-in function getitem>) :: float32
# $const18.7 = const(float, 2.5) :: float64
# l1 = $16binary_subscr.6 + $const18.7 :: float64
# del $const18.7
# del $16binary_subscr.6

l1 = a[i] + 2.5

# --- LINE 28 ---
# $28binary_subscr.11 = getitem(value=b, index=i, fn=<built-in function getitem>) :: float32
# del b
# $const30.12 = const(float, 0.3) :: float64
# l2 = $28binary_subscr.11 * $const30.12 :: float64
# del $const30.12
# del $28binary_subscr.11

l2 = b[i] * 0.3

# --- LINE 29 ---
# $40binary_add.16 = l1 + l2 :: float64
# del l2
# del l1
# c[i] = $40binary_add.16 :: (array(float32, 1d, C), int64, float64) -> none
# del i
# del c
# del $40binary_add.16

c[i] = l1 + l2

# --- LINE 30 ---
# $48load_global.19 = global(revive: <numba_dppy.compiler.DPPYFunctionTemplate object at 0x7fce12e5cc40>) :: Function(<numba_dppy.compiler.DPPYFunctionTemplate object at 0x7fce12e5cc40>)
# $52call_function.21 = call $48load_global.19(a, func=$48load_global.19, args=[Var(a, sum_local_vars_revive.py:26)], kws=(), vararg=None, target=None) :: (array(float32, 1d, C),) -> array(float32, 1d, C)
# del a
# del $52call_function.21
# del $48load_global.19
# $const56.22 = const(NoneType, None) :: none
# $58return_value.23 = cast(value=$const56.22) :: none
# del $const56.22
# return $58return_value.23

revive(a) # pass variable to dummy function
45 changes: 45 additions & 0 deletions numba_dppy/examples/debug/sum_local_vars_revive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2020, 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.import numpy as np

import numpy as np
import numba_dppy as dppy
import dpctl


@dppy.func
def revive(x):
return x


@dppy.kernel(debug=True)
def data_parallel_sum(a, b, c):
i = dppy.get_global_id(0)
l1 = a[i] + 2.5
l2 = b[i] * 0.3
c[i] = l1 + l2
revive(a) # pass variable to dummy function


global_size = 10
N = global_size

a = np.array(np.random.random(N), dtype=np.float32)
b = np.array(np.random.random(N), dtype=np.float32)
c = np.ones_like(a)

device = dpctl.SyclDevice("opencl:gpu")
with dppy.offload_to_sycl_device(device):
data_parallel_sum[global_size, dppy.DEFAULT_LOCAL_SIZE](a, b, c)

print("Done...")