Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mark group convolution always cost viable to convert #958

Merged
merged 1 commit into from
Jan 7, 2025

Conversation

anakinxc
Copy link
Collaborator

@anakinxc anakinxc commented Jan 6, 2025

Pull Request

Should fixed #955 @w-gc feel free to run some testing.

What problem does this PR solve?

Issue Number: Fixed #955

Possible side effects?

  • Performance: n/a

  • Backward compatibility: n/a

@anakinxc anakinxc requested review from w-gc and rivertalk January 6, 2025 10:19
@anakinxc anakinxc merged commit 5ecbead into main Jan 7, 2025
11 of 13 checks passed
@anakinxc anakinxc deleted the fix/group_conv branch January 7, 2025 06:02
@github-actions github-actions bot locked and limited conversation to collaborators Jan 7, 2025
@w-gc
Copy link
Collaborator

w-gc commented Jan 7, 2025

It works on test case:

import jax.lax as jlax
import jax.numpy as jnp
import numpy as np

import spu.spu_pb2 as spu_pb2
import spu.utils.simulation as ppsim

if __name__ == "__main__":
    """
    You can modify the code below for debug purpose only.
    Please DONT commit it unless it will cause build break.
    """

    sim = ppsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64)
    copts = spu_pb2.CompilerOptions()
    # Tweak compiler options
    copts.disable_div_sqrt_rewrite = True

    x = np.random.randn(4, 64, 64)
    y = np.random.randn(2, 8, 8)
    dn = jlax.conv_dimension_numbers(x.shape, y.shape,
                                ('NWC', 'WIO', 'NWC'))
    fn = lambda x, y: jlax.conv_general_dilated(x, y, window_strides=[2], padding='SAME', lhs_dilation=None, rhs_dilation=None, dimension_numbers=dn, feature_group_count=8, batch_group_count=1, precision=None, preferred_element_type=None)
    spu_fn = ppsim.sim_jax(sim, fn, copts=copts)
    z = spu_fn(x, y)

    print(spu_fn.pphlo)

    print(f"spu out = {z}")
    print(f"cpu out = {fn(x, y)}")

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
3 participants