Skip to content

Commit

Permalink
finish transitioning computing module
Browse files Browse the repository at this point in the history
  • Loading branch information
joksas committed Aug 9, 2023
1 parent 31c092b commit c041024
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 91 deletions.
146 changes: 74 additions & 72 deletions badcrossbar/computing/extract.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
from collections import namedtuple

import numpy as np
import numpy.typing as npt
import jax.numpy as jnp
from jax import Array

from badcrossbar import utils
from badcrossbar.computing import solve

Expand All @@ -16,11 +17,7 @@


def solution(
resistances: npt.NDArray,
r_i_word_line: float,
r_i_bit_line: float,
applied_voltages: npt.NDArray,
**kwargs
resistances: Array, r_i_word_line: float, r_i_bit_line: float, applied_voltages: Array, **kwargs
) -> Solution:
"""Extracts branch currents and node voltages of a crossbar in a
convenient form.
Expand All @@ -37,7 +34,7 @@ def solution(
"""
r_i = Interconnect(r_i_word_line, r_i_bit_line)

if r_i.word_line == r_i.bit_line == np.inf:
if r_i.word_line == r_i.bit_line == jnp.inf:
return insulating_interconnect_solution(resistances, applied_voltages, **kwargs)

v = solve.v(resistances, r_i, applied_voltages)
Expand All @@ -52,9 +49,9 @@ def solution(

def currents(
extracted_voltages: Voltages,
resistances: npt.NDArray,
resistances: Array,
r_i: Interconnect,
applied_voltages: npt.NDArray,
applied_voltages: Array,
**kwargs
) -> Currents:
"""Extracts crossbar branch currents in a convenient format.
Expand Down Expand Up @@ -89,7 +86,7 @@ def currents(
return extracted_currents


def voltages(v: npt.NDArray, resistances: npt.NDArray, **kwargs) -> Voltages:
def voltages(v: Array, resistances: Array, **kwargs) -> Voltages:
"""Extracts crossbar node voltages in a convenient format.
Args:
Expand All @@ -108,7 +105,7 @@ def voltages(v: npt.NDArray, resistances: npt.NDArray, **kwargs) -> Voltages:
return extracted_voltages


def word_line_voltages(v: npt.NDArray, resistances: npt.NDArray) -> npt.NDArray:
def word_line_voltages(v: Array, resistances: Array) -> Array:
"""Extracts voltages at the nodes on the word lines.
Args:
Expand All @@ -118,13 +115,11 @@ def word_line_voltages(v: npt.NDArray, resistances: npt.NDArray) -> npt.NDArray:
Returns:
Voltages at the nodes on the word lines.
"""
v_domain = v[
: resistances.size,
]
v_domain = v[: resistances.size,]
return utils.distributed_array(v_domain, resistances)


def bit_line_voltages(v: npt.NDArray, resistances: npt.NDArray) -> npt.NDArray:
def bit_line_voltages(v: Array, resistances: Array) -> Array:
"""Extracts voltages at the nodes on the bit lines.
Args:
Expand All @@ -134,15 +129,13 @@ def bit_line_voltages(v: npt.NDArray, resistances: npt.NDArray) -> npt.NDArray:
Returns:
Voltages at the nodes on the bit lines.
"""
v_domain = v[
resistances.size :,
]
v_domain = v[resistances.size :,]
return utils.distributed_array(v_domain, resistances)


def output_currents(
extracted_voltages: Voltages, extracted_device_currents: npt.NDArray, r_i: Interconnect
) -> npt.NDArray:
extracted_voltages: Voltages, extracted_device_currents: Array, r_i: Interconnect
) -> Array:
"""Extracts output currents.
Args:
Expand All @@ -156,22 +149,17 @@ def output_currents(
Output currents.
"""
if r_i.bit_line > 0:
output_i = (
extracted_voltages.bit_line[
-1,
]
/ r_i.bit_line
)
output_i = extracted_voltages.bit_line[-1,] / r_i.bit_line
else:
output_i = np.sum(extracted_device_currents, axis=0)
output_i = jnp.sum(extracted_device_currents, axis=0)

output_i = np.transpose(output_i)
output_i = jnp.transpose(output_i)
if output_i.ndim == 1:
output_i = output_i.reshape(1, output_i.shape[0])
return output_i


def device_currents(extracted_voltages: Voltages, resistances: npt.NDArray):
def device_currents(extracted_voltages: Voltages, resistances: Array):
"""Extracts currents flowing through crossbar devices.
Args:
Expand All @@ -184,8 +172,8 @@ def device_currents(extracted_voltages: Voltages, resistances: npt.NDArray):
Currents flowing through crossbar devices.
"""
if extracted_voltages.word_line.ndim > 2:
resistances = np.repeat(
resistances[:, :, np.newaxis], extracted_voltages.word_line.shape[2], axis=2
resistances = jnp.repeat(
resistances[:, :, jnp.newaxis], extracted_voltages.word_line.shape[2], axis=2
)

v_diff = extracted_voltages.word_line - extracted_voltages.bit_line
Expand All @@ -196,10 +184,10 @@ def device_currents(extracted_voltages: Voltages, resistances: npt.NDArray):

def word_line_currents(
extracted_voltages: Voltages,
extracted_device_currents: npt.NDArray,
extracted_device_currents: Array,
r_i: Interconnect,
applied_voltages: npt.NDArray,
) -> npt.NDArray:
applied_voltages: Array,
) -> Array:
"""Extracts currents flowing through interconnect segments along the word
lines.
Expand All @@ -215,7 +203,7 @@ def word_line_currents(
Currents flowing through interconnect segments along the word lines.
"""
if r_i.word_line > 0:
word_line_i = np.zeros(extracted_device_currents.shape)
word_line_i = jnp.zeros(extracted_device_currents.shape)
if extracted_voltages.word_line.ndim > 2:
v_diff = (
applied_voltages
Expand All @@ -224,12 +212,13 @@ def word_line_currents(
0,
]
)
word_line_i[:, 0,] = (
v_diff / r_i.word_line
)
word_line_i = word_line_i.at[
:,
0,
].set(v_diff / r_i.word_line)
else:
v_diff = applied_voltages - extracted_voltages.word_line[:, [0]]
word_line_i[:, [0]] = v_diff / r_i.word_line
word_line_i = word_line_i.at[:, [0]].set(v_diff / r_i.word_line)

v_diff = (
extracted_voltages.word_line[
Expand All @@ -241,11 +230,12 @@ def word_line_currents(
1:,
]
)
word_line_i[:, 1:,] = (
v_diff / r_i.word_line
)
word_line_i = word_line_i.at[
:,
1:,
].set(v_diff / r_i.word_line)
else:
word_line_i = np.repeat(
word_line_i = jnp.repeat(
extracted_device_currents[
:,
-1:,
Expand All @@ -254,21 +244,26 @@ def word_line_currents(
axis=1,
)
for i in range(1, extracted_device_currents.shape[1]):
word_line_i[:, :-i,] += np.repeat(
extracted_device_currents[
:,
-(1 + i) : -i,
],
extracted_device_currents.shape[1] - i,
axis=1,
word_line_i = word_line_i.at[
:,
:-i,
].add(
jnp.repeat(
extracted_device_currents[
:,
-(1 + i) : -i,
],
extracted_device_currents.shape[1] - i,
axis=1,
)
)

return word_line_i


def bit_line_currents(
extracted_voltages: Voltages, extracted_device_currents: npt.NDArray, r_i: Interconnect
) -> npt.NDArray:
extracted_voltages: Voltages, extracted_device_currents: Array, r_i: Interconnect
) -> Array:
"""Extracts currents flowing through interconnect segments along the bit
lines.
Expand All @@ -283,7 +278,7 @@ def bit_line_currents(
Currents flowing through interconnect segments along the bit lines.
"""
if r_i.bit_line > 0:
bit_line_i = np.zeros(extracted_device_currents.shape)
bit_line_i = jnp.zeros(extracted_device_currents.shape)
v_diff = (
extracted_voltages.bit_line[
:-1,
Expand All @@ -294,37 +289,44 @@ def bit_line_currents(
:,
]
)
bit_line_i[:-1, :,] = (
v_diff / r_i.bit_line
)
bit_line_i = bit_line_i.at[
:-1,
:,
].set(v_diff / r_i.bit_line)
if extracted_voltages.bit_line.ndim > 2:
v_diff = extracted_voltages.bit_line[
-1,
:,
]
bit_line_i[-1, :,] = (
v_diff / r_i.bit_line
)
bit_line_i = bit_line_i.at[
-1,
:,
].set(v_diff / r_i.bit_line)
else:
v_diff = extracted_voltages.bit_line[[-1], :]
bit_line_i[[-1], :] = v_diff / r_i.bit_line
bit_line_i = bit_line_i.at[[-1], :].set(v_diff / r_i.bit_line)
else:
bit_line_i = np.zeros(extracted_device_currents.shape)
bit_line_i = jnp.zeros(extracted_device_currents.shape)
for i in range(extracted_device_currents.shape[0]):
bit_line_i[i:, :,] += np.repeat(
extracted_device_currents[
i : i + 1,
:,
],
extracted_device_currents.shape[0] - i,
axis=0,
bit_line_i = bit_line_i.at[
i:,
:,
].add(
jnp.repeat(
extracted_device_currents[
i : i + 1,
:,
],
extracted_device_currents.shape[0] - i,
axis=0,
)
)

return bit_line_i


def insulating_interconnect_solution(
resistances: npt.NDArray, applied_voltages: npt.NDArray, **kwargs
resistances: Array, applied_voltages: Array, **kwargs
) -> Solution:
"""Extracts solution when all interconnects are perfectly insulating.
Expand All @@ -343,9 +345,9 @@ def insulating_interconnect_solution(
"Warning: all interconnects are perfectly insulating! Node voltages are undefined!"
)

output_i = np.zeros((applied_voltages.shape[1], resistances.shape[1]))
output_i = jnp.zeros((applied_voltages.shape[1], resistances.shape[1]))
if kwargs.get("all_currents", True):
same_i = np.zeros((resistances.shape[0], resistances.shape[1], applied_voltages.shape[1]))
same_i = jnp.zeros((resistances.shape[0], resistances.shape[1], applied_voltages.shape[1]))
same_i = utils.squeeze_third_axis(same_i)
device_i = word_line_i = bit_line_i = same_i
logger.info("Extracted currents from all branches in the crossbar.")
Expand Down
29 changes: 10 additions & 19 deletions badcrossbar/computing/solve.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import logging

import jax
from jax import Array
import jax.numpy as jnp
import numpy as np
import numpy.typing as npt
from badcrossbar.computing import fill
from jax.scipy.sparse import linalg
from jax.experimental.sparse import BCOO

logger = logging.getLogger(__name__)


def v(resistances: npt.NDArray, r_i, applied_voltages: npt.NDArray):
def v(resistances: Array, r_i, applied_voltages: Array):
"""Solves matrix equation `gv = i`.
Args:
Expand All @@ -38,7 +37,7 @@ def v(resistances: npt.NDArray, r_i, applied_voltages: npt.NDArray):
v_matrix, _ = solver(g_matrix, i, tol=1e-12, atol=1e-12)
logger.info("Solved for v.")

v_matrix = np.array(v_matrix)
v_matrix = jnp.array(v_matrix)
# if `num_examples == 1`, it can result in 1D array.
if v_matrix.ndim == 1:
v_matrix = v_matrix.reshape(v_matrix.shape[0], 1)
Expand All @@ -47,26 +46,18 @@ def v(resistances: npt.NDArray, r_i, applied_voltages: npt.NDArray):
# matrix_v had to be solved. The other half can be filled without
# solving because the node voltages are known.
if r_i.word_line == 0:
new_v_matrix = np.zeros((2 * resistances.size, applied_voltages.shape[1]))
new_v_matrix[
: resistances.size,
] = np.repeat(applied_voltages, resistances.shape[1], axis=0)
new_v_matrix[
resistances.size :,
] = v_matrix
new_v_matrix = jnp.zeros((2 * resistances.size, applied_voltages.shape[1]))
new_v_matrix = new_v_matrix.at[: resistances.size, ].set(jnp.repeat(applied_voltages, resistances.shape[1], axis=0))
new_v_matrix = new_v_matrix.at[resistances.size :,].set(v_matrix)
v_matrix = new_v_matrix
if r_i.bit_line == 0:
new_v_matrix = np.zeros((2 * resistances.size, applied_voltages.shape[1]))
new_v_matrix[
: resistances.size,
] = v_matrix
new_v_matrix = jnp.zeros((2 * resistances.size, applied_voltages.shape[1]))
new_v_matrix = new_v_matrix.at[:resistances.size, ].set(v_matrix)
v_matrix = new_v_matrix
else:
# if both interconnect resistances are zero, all node voltages are
# known.
v_matrix = np.zeros((2 * resistances.size, applied_voltages.shape[1]))
v_matrix[
: resistances.size,
] = np.repeat(applied_voltages, resistances.shape[1], axis=0)
v_matrix = jnp.zeros((2 * resistances.size, applied_voltages.shape[1]))
v_matrix = v_matrix.at[:resistances.size, ].set(jnp.repeat(applied_voltages, resistances.shape[1], axis=0))

return v_matrix

0 comments on commit c041024

Please sign in to comment.