Skip to content

Inconsistent Results for Same Circuit #43

@Arthaj-Octopus

Description

@Arthaj-Octopus

Hello, I'm trying to simulate an optical filter circuit in SAX. I'm attempting to reproduce results I previously obtained for this filter with another tool, ANSYS INTERCONNECT. The expected response is:
icfilter
However, when I recreate the circuit in SAX, I get different results, even between two scripts that are meant to describe the same circuit:
out1
out2
Note that I have already confirmed that the waveguide and coupler models are equivalent to those in INTERCONNECT, so they are not the source of the discrepancy. Here are the scripts for each plot:

First plot

import numpy as np
import matplotlib.pyplot as plt
import sax
import networkx as nx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sax
import math
#import meow
from typing import List
from scipy.constants import c
from sax.circuit import (
    _create_dag,
    _find_leaves,
    _find_root,
    _flat_circuit,
    _validate_models,
    draw_dag,
)

def swap_output_order(dictionary):
    """Return a copy of ``dictionary`` with its ``out*``-keyed entries reversed.

    Every entry whose key starts with ``"out"`` (e.g. an external port such
    as ``"out1": "dc_2,out0"``) is replaced by the reverse mapping
    (``"dc_2,out0": "out1"``); all other entries are copied unchanged. If
    several ``out*`` keys share a target, only the last one is reversed and
    the others are kept as-is. A reversed entry overwrites any existing entry
    with the same key. The input mapping is not modified.
    """
    reversed_entries = {}
    for port, target in dictionary.items():
        if port.startswith("out"):
            reversed_entries[target] = port

    result = dictionary.copy()
    result.update(reversed_entries)
    # Drop the original forward entries that were just reversed.
    for port in reversed_entries.values():
        del result[port]
    return result
def create_graph(connections):
    """Draw a netlist's instance-level connectivity, laid out left to right.

    Args:
        connections: Mapping of ``"instance,port"`` strings to
            ``"instance,port"`` strings (here: the netlist connections merged
            with its external port map). Entries keyed by ``out*`` are first
            reversed via ``swap_output_order`` so that external output ports
            end up as edge targets.

    Only the text before the first comma is used as a node name, so all ports
    of one instance collapse into a single node. The layout uses networkx's
    ``nx_agraph`` Graphviz interface with the ``dot`` program, and the figure
    is displayed with ``plt.show()``.
    """
    directed = swap_output_order(connections)

    graph = nx.DiGraph()
    graph.add_edges_from(
        (src.split(",")[0], dst.split(",")[0]) for src, dst in directed.items()
    )

    layout = nx.nx_agraph.graphviz_layout(graph, prog="dot", args="-Grankdir=LR")
    plt.figure(figsize=(20, 6))
    nx.draw(
        graph,
        layout,
        with_labels=True,
        node_size=3000,
        node_color="skyblue",
        font_size=10,
        font_weight="bold",
        arrowsize=20,
    )
    plt.title("Left-to-Right Diagram of Components", fontsize=14)
    plt.show()

# Wavelength sweep: 100,000 evenly spaced points from wli to wlf, both
# endpoints included (np.linspace default).
wli, wlf = 1.5493214338870651, 1.555747057602491
wl = np.linspace(wli, wlf, num=100_000)


# # Component Definitions

def coupler(coupling=0.5) -> sax.SDict:
    """Return the S-dictionary of a lossless 2x2 directional coupler.

    Args:
        coupling: Power coupling ratio from ``in0`` to ``out1`` (and from
            ``in1`` to ``out0``). Must lie in ``[0, 1]``.

    Returns:
        Reciprocal SAX S-dict with through field transmission
        ``sqrt(1 - coupling)`` (``in0``->``out0``, ``in1``->``out1``) and cross
        field transmission ``1j * sqrt(coupling)``.

    Raises:
        ValueError: If ``coupling`` is a plain Python/NumPy scalar (``int`` or
            ``float``) outside ``[0, 1]`` or NaN. Array-valued or traced
            couplings are not checked, so vectorised sweeps keep working.
    """
    # A ``tau**2 + kappa**2 == 1`` check is a tautology here: for an
    # out-of-range coupling, ``x ** 0.5`` of a negative float is a complex
    # number whose square still cancels, so such a check never fires (and
    # ``assert`` is stripped under ``python -O``). Check the range directly.
    if isinstance(coupling, (int, float)) and not 0 <= coupling <= 1:
        raise ValueError(f"coupling must be in [0, 1], got {coupling!r}")
    kappa = coupling**0.5
    tau = (1 - coupling) ** 0.5
    return sax.reciprocal({
        ("in0", "out0"): tau,
        ("in0", "out1"): 1j * kappa,
        ("in1", "out0"): 1j * kappa,
        ("in1", "out1"): tau,
    })

def waveguide(wl=1.55, wl0=1.55, neff=8.06019, ng=8.05894, length=10.0, loss=0.0) -> sax.SDict:
    """Return the S-dictionary of a dispersive straight waveguide (ports ``in``/``out``).

    The effective index is expanded to first order about ``wl0`` using the
    group index: ``neff(wl) = neff - (wl - wl0) * (ng - neff) / wl0``.

    Args:
        wl: Evaluation wavelength(s); must share a length unit with ``length``
            and ``wl0`` for the phase to be dimensionless.
        wl0: Reference wavelength at which ``neff`` and ``ng`` are given.
        neff: Effective index at ``wl0``.
        ng: Group index at ``wl0``.
        length: Physical length of the waveguide.
        loss: Propagation loss in dB per unit of ``length``.

    Returns:
        Reciprocal SAX S-dict with a single ``in``<->``out`` transmission.
    """
    detuning = wl - wl0
    slope = (ng - neff) / wl0
    n_eff = neff - detuning * slope
    phase = 2 * jnp.pi * n_eff * length / wl
    # dB -> field amplitude, hence /20 rather than /10.
    field_amplitude = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    return sax.reciprocal({("in", "out"): field_amplitude * jnp.exp(1j * phase)})

# Nominal effective and group index. Kept for reference only: they are not
# passed to any instance below, and the waveguide model's defaults already
# use the same values.
neff=8.06019
ng=8.05894

# Three-coupler, two-arm lattice:
#   dc_0 -out0-> wg_0      -> dc_1 -out0-> wg_1      -> dc_2   (arm lengths set below)
#   dc_0 -out1-> null_wg_0 -> dc_1 -out1-> null_wg_1 -> dc_2   (zero-length arm)
netlist={
    "instances": {
        "wg_0": "waveguide",
        "wg_1": "waveguide",
        "null_wg_0": "waveguide",
        "null_wg_1": "waveguide",
        "dc_0": "dc",
        "dc_1": "dc",
        "dc_2": "dc",
    },
    "connections": {
        "dc_0,out0": "wg_0,in", #Top
        "wg_0,out": "dc_1,in0",
        "dc_1,out0": "wg_1,in",
        "wg_1,out": "dc_2,in0",
        #
        "dc_0,out1": "null_wg_0,in", #Bot
        "null_wg_0,out": "dc_1,in1",
        "dc_1,out1": "null_wg_1,in",
        "null_wg_1,out": "dc_2,in1",
    },
    "ports": {
        "in0": "dc_0,in0",
        "out1": "dc_2,out0",
        "out2": "dc_2,out1",
        "in1": "dc_0,in1"
    },
}

models={
    "waveguide": waveguide,
    "dc": coupler
    }
fi_1, info = sax.circuit(netlist=netlist,models=models, backend="klu")

# Connections plus external ports, used only for the topology plot
# (named port_map to avoid shadowing the builtin ``map``).
port_map = netlist["connections"] | netlist["ports"]
create_graph(port_map)

# Bare expression: its result is only displayed in an interactive session.
sax.get_settings(fi_1)

# Power coupling ratios. Each value must be used exactly once: dc_0 gets
# 0.7499... and dc_1 gets 0.0664..., the same order as dc_array in the
# cascaded-generator script (coupler_1, coupler_2). Note the variable names
# are crossed relative to the instances they feed; passing dc_1_r to both
# couplers would leave dc_0_r unused and simulate a different filter.
dc_0_r=0.066393494773
dc_1_r=0.7499093045
# NOTE(review): this netlist is driven from in0 and routes each coupler's out0
# into the long arm, whereas the cascaded-generator script is driven from in1
# and routes out1 into its long (top) arms. Because the coupler model puts a
# 1j factor on the cross terms only, the two may not be equivalent circuits
# even with matching couplings; confirm the intended port mapping against
# the INTERCONNECT schematic.
S = fi_1(wl=wl,
         dc_0={"coupling": dc_1_r},
         dc_1={"coupling": dc_0_r},
         dc_2={"coupling": 0.5},
         wg_0={"length": 619.9997725, "wl": wl},
         wg_1={"length": 309.9998863, "wl": wl},
         null_wg_0={"length": 0, "wl": wl},
         null_wg_1={"length": 0, "wl": wl},)


# Power transmission from in0 to each output port.
mag1 = jnp.abs(S["in0", "out1"])**2
mag2 = jnp.abs(S["in0", "out2"])**2
fig, axs = plt.subplots(1, 1, sharex=True)
axs.plot(wl,mag1,label="1")
axs.plot(wl,mag2,label="2")
axs.legend()
axs.set_ylabel("Transmission")
plt.suptitle("Filter Response")
plt.show()

Second plot

import sax
from sax.circuit import _create_dag, draw_dag
import jax.numpy as jnp
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

def swap_key_value(dictionary):
    """Copy ``dictionary``, turning each ``out*``-keyed entry into its inverse.

    For every key beginning with ``"out"``, the pair ``key: value`` is removed
    and ``value: key`` is inserted instead (later ``out*`` keys win when two
    share a value; an inserted pair overwrites an existing key). Other entries
    are untouched, and the argument itself is not mutated.
    """
    flipped = {val: key for key, val in dictionary.items() if key.startswith("out")}
    out = dictionary.copy()
    out.update(flipped)
    for original_key in flipped.values():
        out.pop(original_key)
    return out
def create_graph(connections):
    """Plot which instances connect to which, flowing left to right.

    ``connections`` maps ``"instance,port"`` strings to ``"instance,port"``
    strings. ``out*``-keyed entries are inverted with ``swap_key_value`` first
    so external output ports are drawn as edge targets. Port suffixes are
    dropped, so every port of an instance shares one node. The ``dot`` layout
    comes from networkx's ``nx_agraph`` Graphviz interface; the figure is shown
    with ``plt.show()``.
    """
    edges = swap_key_value(connections)
    G = nx.DiGraph()
    for src, dst in edges.items():
        # Keep only the instance name, e.g. 'coupler_1,out0' -> 'coupler_1'.
        G.add_edge(src.partition(",")[0], dst.partition(",")[0])

    # Graphviz "dot" with rankdir=LR orders nodes left to right.
    positions = nx.nx_agraph.graphviz_layout(G, prog="dot", args="-Grankdir=LR")

    plt.figure(figsize=(20, 6))
    draw_style = dict(
        with_labels=True,
        node_size=3000,
        node_color="skyblue",
        font_size=10,
        font_weight="bold",
        arrowsize=20,
    )
    nx.draw(G, positions, **draw_style)
    plt.title("Left-to-Right Diagram of Components", fontsize=14)
    plt.show()


def waveguide(wl=1.55, wl0=1.55, neff=8.06019, ng=8.05894, length=10.0, loss=0.0) -> sax.SDict:
    """S-dictionary of a straight, first-order-dispersive waveguide (ports ``in0``/``out0``).

    ``neff`` is linearised about ``wl0`` from the group index,
    ``neff(wl) = neff - (wl - wl0) * (ng - neff) / wl0``; the accumulated
    phase is ``2*pi*neff(wl)*length/wl`` (``wl``, ``wl0`` and ``length`` in
    one common length unit) and ``loss`` is in dB per unit of ``length``.
    """
    dneff_per_wl = (ng - neff) / wl0
    neff_wl = neff - (wl - wl0) * dneff_per_wl
    phi = 2 * jnp.pi * neff_wl * length / wl
    # Field (not power) attenuation: divide the dB figure by 20.
    attenuation = jnp.asarray(10 ** (-loss * length / 20), dtype=complex)
    t = attenuation * jnp.exp(1j * phi)
    return sax.reciprocal({("in0", "out0"): t})

def coupler(coupling=0.5) -> sax.SDict:
    """S-dictionary of an ideal 2x2 directional coupler.

    ``coupling`` is the power fraction sent to the cross port. Through paths
    (``in0``->``out0``, ``in1``->``out1``) carry ``sqrt(1 - coupling)``; cross
    paths carry ``1j * sqrt(coupling)``. The reverse directions are filled in
    by ``sax.reciprocal``.
    """
    through = (1 - coupling) ** 0.5
    cross = 1j * coupling**0.5
    return sax.reciprocal(
        {
            ("in0", "out0"): through,
            ("in0", "out1"): cross,
            ("in1", "out0"): cross,
            ("in1", "out1"): through,
        }
    )


def cascaded_amzi_generator(n, backend="klu"):
    """Build an ``n``-stage cascaded asymmetric-MZI circuit with SAX.

    Stage ``i`` (1-based) is ``coupler_i`` followed by two waveguide arms:
    ``coupler_i,out0`` feeds ``btm_i`` and ``coupler_i,out1`` feeds ``top_i``.
    The arms of stage ``i`` drive ``coupler_{i+1}`` (top -> ``in0``,
    btm -> ``in1``); the arms of the last stage drive a closing coupler
    ``final`` (btm -> ``in0``, top -> ``in1``). External ports are
    ``in0``/``in1`` on ``coupler_1`` and ``out0``/``out1`` on ``final``.

    Args:
        n: Number of stages; must be at least 1.
        backend: SAX circuit backend, forwarded to ``sax.circuit``.

    Returns:
        ``(circuit, info, port_map)``: the two values returned by
        ``sax.circuit`` plus the netlist connections merged with its external
        ports (suitable for ``create_graph``).

    Raises:
        ValueError: If ``n < 1``. There would be no last stage to connect to
            ``final``.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n!r}")

    netlist = {
        "instances": {},
        "connections": {},
        "ports": {},
    }
    instances = netlist["instances"]
    connections = netlist["connections"]

    models = {
        "coupler": coupler,
        "waveguide": waveguide,
    }

    for i in range(1, n + 1):
        coupler_name = f"coupler_{i}"
        top_name = f"top_{i}"
        btm_name = f"btm_{i}"
        instances[coupler_name] = "coupler"
        instances[top_name] = "waveguide"
        instances[btm_name] = "waveguide"

        # Intra-stage: coupler outputs into the two arms.
        connections[f"{coupler_name},out0"] = f"{btm_name},in0"
        connections[f"{coupler_name},out1"] = f"{top_name},in0"
        # Inter-stage: previous stage's arms into this coupler.
        if i > 1:
            connections[f"top_{i-1},out0"] = f"{coupler_name},in0"
            connections[f"btm_{i-1},out0"] = f"{coupler_name},in1"

    # Close the lattice with a final coupler (n >= 1 guarantees the names exist).
    instances["final"] = "coupler"
    connections[f"{btm_name},out0"] = "final,in0"
    connections[f"{top_name},out0"] = "final,in1"

    netlist["ports"].update({
        "in0": "coupler_1,in0",
        "in1": "coupler_1,in1",
        "out0": "final,out0",
        "out1": "final,out1",
    })

    # Connections plus ports, for plotting (avoids shadowing builtin ``map``).
    port_map = netlist["connections"] | netlist["ports"]
    mzi_ideal, info = sax.circuit(netlist=netlist, models=models, backend=backend)
    return mzi_ideal, info, port_map


wli=1.5493214338870651
wlf=1.555747057602491
wavelengths = np.linspace(wli, wlf, 100_000)

# Per-stage settings: stage i uses dc_array[i-1] as coupler_i's power
# coupling ratio and wg_array[i-1] as the top_i arm length; every btm_i arm
# has zero length. The closing "final" coupler keeps its default coupling.
dc_array=np.array([0.7499093045,0.066393494773])
wg_array=np.array([619.9997725,309.9998863])
if len(dc_array) != len(wg_array):
    raise ValueError("dc_array and wg_array must have one entry per stage")
# Derive the stage count from the settings so the two cannot drift apart.
n = len(dc_array)

params = {"wl": wavelengths}
for i, (coupling, top_length) in enumerate(zip(dc_array, wg_array), start=1):
    params[f"coupler_{i}"] = {"coupling": coupling}
    params[f"top_{i}"] = {"length": top_length}
    params[f"btm_{i}"] = {"length": 0}

mzi_ideal, info, port_map = cascaded_amzi_generator(n, backend="klu")
create_graph(port_map)

# NOTE(review): the response is taken from input port in1, which (through
# coupler_1's in1->out1 path) feeds the long top arm. Compare with the
# in0-driven single-netlist script only after confirming both use the same
# port mapping.
S = mzi_ideal(**params)
transmissions_klu = jnp.abs(S["in1", "out1"]) ** 2
transmissions_klu2 = jnp.abs(S["in1", "out0"]) ** 2

fig, axs = plt.subplots(figsize=(6, 4))
axs.plot(wavelengths, transmissions_klu, label="out1")
axs.plot(wavelengths, transmissions_klu2, label="out0")
axs.set_xlabel("wavelength (um)")
axs.set_ylabel("transmission (Normalized)")
axs.legend()
axs.grid()
axs.set_ylim(0, 1)
plt.show()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions