Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ static PyMethodDef algorithms_PyMethodDef[] = {
METH_VARARGS | METH_KEYWORDS, ""},
{"insertion_sort", (PyCFunction) insertion_sort,
METH_VARARGS | METH_KEYWORDS, ""},
{"insertion_sort_llvm", (PyCFunction)insertion_sort_llvm,
METH_VARARGS | METH_KEYWORDS, ""},
{"is_ordered", (PyCFunction) is_ordered,
METH_VARARGS | METH_KEYWORDS, ""},
{"linear_search", (PyCFunction) linear_search,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,122 @@ def _materialize(dtype: str) -> int:

except Exception as e:
raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}")


def get_insertion_sort_ptr(dtype: str) -> int:
"""Get function pointer for insertion sort with specified dtype."""
dtype = dtype.lower().strip()
if dtype not in _SUPPORTED:
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

return _materialize_insertion(dtype)


def _build_insertion_sort_ir(dtype: str) -> str:
if dtype not in _SUPPORTED:
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

T, _ = _SUPPORTED[dtype]
i32 = ir.IntType(32)
i64 = ir.IntType(64)

mod = ir.Module(name=f"insertion_sort_{dtype}_module")
fn_name = f"insertion_sort_{dtype}"

fn_ty = ir.FunctionType(ir.VoidType(), [T.as_pointer(), i32])
fn = ir.Function(mod, fn_ty, name=fn_name)

arr, n = fn.args
arr.name, n.name = "arr", "n"

b_entry = fn.append_basic_block("entry")
b_outer = fn.append_basic_block("outer")
b_inner = fn.append_basic_block("inner")
b_inner_latch = fn.append_basic_block("inner.latch")
b_exit = fn.append_basic_block("exit")

b = ir.IRBuilder(b_entry)
cond_trivial = b.icmp_signed("<=", n, ir.Constant(i32, 1))
b.cbranch(cond_trivial, b_exit, b_outer)

b.position_at_end(b_outer)
i_phi = b.phi(i32, name="i")
i_phi.add_incoming(ir.Constant(i32, 1), b_entry) # start at 1

cond_outer = b.icmp_signed("<", i_phi, n)
b.cbranch(cond_outer, b_inner, b_exit)

b.position_at_end(b_inner)
# key = arr[i]
i64_idx = b.sext(i_phi, i64)
key_ptr = b.gep(arr, [i64_idx], inbounds=True)
key_val = b.load(key_ptr)

# j = i - 1
j = b.sub(i_phi, ir.Constant(i32, 1))
j64 = b.sext(j, i64)

b_inner_loop = fn.append_basic_block("inner.loop")
b.position_at_end(b_inner)
b.branch(b_inner_loop)

b.position_at_end(b_inner_loop)
cond_j = b.icmp_signed(">=", j, ir.Constant(i32, 0))
b.cbranch(cond_j, b_inner_latch, b_outer_latch)

b.position_at_end(b_inner_latch)
j64 = b.sext(j, i64)
arr_j_ptr = b.gep(arr, [j64], inbounds=True)
arr_j_val = b.load(arr_j_ptr)

# compare arr[j] > key
if isinstance(T, ir.IntType):
cmp = b.icmp_signed(">", arr_j_val, key_val)
else:
cmp = b.fcmp_ordered(">", arr_j_val, key_val)

b.cbranch(cmp, b_inner_latch, b_outer_latch)

# swap/move
b.store(arr_j_val, b.gep(arr, [j64 + ir.Constant(i64, 1)], inbounds=True))
j = b.sub(j, ir.Constant(i32, 1))
b.branch(b_inner_loop)

b.position_at_end(b_outer_latch)
b.store(key_val, b.gep(arr, [b.sext(j+ir.Constant(i32,1), i64)], inbounds=True))

i_next = b.add(i_phi, ir.Constant(i32, 1))
i_phi.add_incoming(i_next, b_outer_latch)
b.branch(b_outer)

b.position_at_end(b_exit)
b.ret_void()

return str(mod)


def _materialize_insertion(dtype: str) -> int:
_ensure_target_machine()

name = f"insertion_sort_{dtype}"
if dtype in _fn_ptr_cache:
return _fn_ptr_cache[dtype]

try:
llvm_ir = _build_insertion_sort_ir(dtype)
mod = binding.parse_assembly(llvm_ir)
mod.verify()

engine = binding.create_mcjit_compiler(mod, _target_machine)
engine.finalize_object()
engine.run_static_constructors()

addr = engine.get_function_address(name)
if not addr:
raise RuntimeError(f"Failed to get address for {name}")

_fn_ptr_cache[dtype] = addr
_engines[dtype] = engine
return addr
except Exception as e:
raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}")
Loading
Loading