Skip to content

Commit

Permalink
feat: complete Python interpolation implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
mgreminger committed Jul 28, 2024
1 parent 731456b commit efc10de
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 11 deletions.
80 changes: 72 additions & 8 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ class InterpolationFunction(TypedDict):
inputDims: list[float]
outputDims: list[float]
order: int
symbolic_function: NotRequired[UndefinedFunction] # this item is created in Python and doesn't exist in the incoming json

class CustomBaseUnits(TypedDict):
mass: str
Expand Down Expand Up @@ -1256,6 +1257,67 @@ def get_fluid_placeholder_map(fluid_functions: list[FluidFunction]) -> dict[Func
return new_map


NP = None

def load_numpy():
global NP
if NP is None:
NP = import_module('numpy')

def get_interpolation_wrapper(interpolation_function: InterpolationFunction):
global NP
if NP is None:
load_numpy()
NP = cast(Any, NP)

input_values = NP.array(interpolation_function["inputValues"])
output_values = NP.array(interpolation_function["outputValues"])

if not NP.all(NP.diff(input_values) > 0):
raise ValueError('The input values must be an increasing sequence for interpolation')

def interpolation_wrapper(input: Expr):
global NP
NP = cast(Any, NP)

if input.is_number:
float_input = float(input)

if float_input < input_values[0] or float_input > input_values[-1]:
raise ValueError('Attempt to extrapolate with an interpolation function')

return sympify(NP.interp(input, input_values, output_values))
else:
if "symbolic_function" not in interpolation_function:
custom_func = cast(Callable[[Expr], Expr], Function(interpolation_function["name"]))
custom_func = implemented_function(custom_func, lambda arg: cast(Any, NP).interp(float(arg), input_values, output_values) )
interpolation_function["symbolic_function"] = cast(UndefinedFunction, custom_func)

return interpolation_function["symbolic_function"](input)

def interpolation_dims_wrapper(input):
ensure_dims_all_compatible(get_dims(interpolation_function["inputDims"]), input)

return get_dims(interpolation_function["outputDims"])

return interpolation_wrapper, interpolation_dims_wrapper

def get_interpolation_placeholder_map(interpolation_functions: list[InterpolationFunction]) -> dict[Function, PlaceholderFunction]:
new_map = {}

for interpolation_function in interpolation_functions:
match interpolation_function["type"]:
case "interpolation":
sympy_func, dim_func = get_interpolation_wrapper(interpolation_function)
case _:
continue

new_map[Function(interpolation_function["name"])] = {"dim_func": dim_func,
"sympy_func": sympy_func}

return new_map


custom_data_table_id = Function('custom_data_table_id')

class DataTableSubs:
Expand Down Expand Up @@ -2801,14 +2863,16 @@ def solve_sheet(statements_and_systems):
"selectedSolution": selected_solution
})

# if there are fluid definitions, update placeholder functions
if len(fluid_definitions) > 0:
fluid_placeholder_map = get_fluid_placeholder_map(fluid_definitions)
placeholder_map = global_placeholder_map | fluid_placeholder_map
placeholder_set = set(placeholder_map.keys())
else:
placeholder_map = global_placeholder_map
placeholder_set = global_placeholder_set
fluid_placeholder_map = get_fluid_placeholder_map(fluid_definitions)

try:
interpolation_placeholder_map = get_interpolation_placeholder_map(interpolation_definitions)
except Exception as e:
error = f"Error generating interpolation or polyfit function: {e}"
return dumps(Results(error=error, results=[], systemResults=[]))

placeholder_map = global_placeholder_map | fluid_placeholder_map | interpolation_placeholder_map
placeholder_set = set(placeholder_map.keys())

custom_definition_names = [value["name"] for value in fluid_definitions]
custom_definition_names.extend( (value["name"] for value in interpolation_definitions) )
Expand Down
6 changes: 6 additions & 0 deletions public/webworker.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ let py_funcs;
let recursionError = false;
let pyodide;
let coolpropLoaded = false;
let numpyLoaded = false;

async function setup() {
try {
Expand Down Expand Up @@ -44,6 +45,11 @@ self.onmessage = async function(e){
coolpropLoaded = true;
}

if (e.data.needNumpy && !coolpropLoaded && !numpyLoaded) {
await pyodide.loadPackage("numpy");
numpyLoaded = true;
}

const result = py_funcs.solveSheet(e.data.data);

self.postMessage(JSON.parse(result));
Expand Down
8 changes: 5 additions & 3 deletions src/App.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,8 @@
refreshSheet(); // pushState does not trigger onpopstate event
}
function getResults(statementsAndSystems: string, myRefreshCount: BigInt, needCoolprop: Boolean) {
function getResults(statementsAndSystems: string, myRefreshCount: BigInt,
needCoolprop: Boolean, needNumpy: Boolean) {
return new Promise<Results>((resolve, reject) => {
function handleWorkerMessage(e) {
forcePyodidePromiseRejection = null;
Expand All @@ -808,7 +809,7 @@
} else {
forcePyodidePromiseRejection = () => reject("Restarting pyodide.")
pyodideWorker.onmessage = handleWorkerMessage;
pyodideWorker.postMessage({cmd: 'sheet_solve', data: statementsAndSystems, needCoolprop});
pyodideWorker.postMessage({cmd: 'sheet_solve', data: statementsAndSystems, needCoolprop, needNumpy});
}
});
}
Expand Down Expand Up @@ -964,7 +965,8 @@
}
pyodidePromise = getResults(statementsAndSystems,
myRefreshCount,
Boolean(statementsAndSystemsObject.fluidFunctions.length > 0))
Boolean(statementsAndSystemsObject.fluidFunctions.length > 0),
Boolean(statementsAndSystemsObject.interpolationFunctions.length > 0))
.then((data: Results) => {
$results = [];
$resultsInvalid = false;
Expand Down

0 comments on commit efc10de

Please sign in to comment.