
parallel computation of mask in constrained sampling #964

@mmoskal

Description


Currently (as per #899), the masks are always computed in sequence.
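Suppose "in sequence" means that the masks for the sequences in a batch are computed one after another. The sketch below shows how they could instead be computed in parallel using only the standard library (`std::thread::scope`). It is not the project's code. `compute_mask`, `masks_sequential` and `masks_parallel` are hypothetical names, and the toy rule inside `compute_mask` stands in for a real constraint engine.

```rust
use std::thread;

/// Hypothetical stand-in for the constraint engine: returns which token ids
/// are allowed next for one sequence. The toy rule (token id congruent to the
/// sequence id modulo 3) only exists so the output is easy to check.
fn compute_mask(seq_id: usize, vocab_size: usize) -> Vec<bool> {
    (0..vocab_size).map(|tok| tok % 3 == seq_id % 3).collect()
}

/// Sequential baseline: masks are computed one after another.
fn masks_sequential(seq_ids: &[usize], vocab_size: usize) -> Vec<Vec<bool>> {
    seq_ids
        .iter()
        .map(|&id| compute_mask(id, vocab_size))
        .collect()
}

/// Parallel version: one scoped thread per sequence. Joining the handles in
/// spawn order keeps the results in the same order as `seq_ids`.
fn masks_parallel(seq_ids: &[usize], vocab_size: usize) -> Vec<Vec<bool>> {
    thread::scope(|s| {
        let handles: Vec<_> = seq_ids
            .iter()
            .map(|&id| s.spawn(move || compute_mask(id, vocab_size)))
            .collect();
        handles
            .into_iter()
            .map(|h| h.join().expect("mask thread panicked"))
            .collect()
    })
}
```

For example, `masks_parallel(&[0, 1, 2, 3], 10)` returns the same four masks as `masks_sequential(&[0, 1, 2, 3], 10)`. One thread per sequence keeps the sketch simple. With large batches, a thread pool (such as the rayon pool that `tokio_rayon` dispatches to) would avoid spawning a fresh thread for every mask.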

Are the logit tensors typically asynchronous? That is, are they not yet synchronized when sample() runs, so that logits.to_vec1() in sample() is what takes almost all of the time?

If so, I guess we could keep the current interface and just run the mask computation under tokio_rayon::spawn(), as you already do for sampling.

Otherwise, it would be good to kick off mask computation before starting the forward pass.
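A minimal sketch of starting the mask before the forward pass follows. It rests on one assumption: the mask depends only on the tokens generated so far, not on the new logits. `std::thread::spawn` stands in for `tokio_rayon::spawn()` so the example has no dependencies. `compute_mask`, `forward_pass`, `sample_masked` and `decode_step` are hypothetical names, and greedy selection is used only to keep the example short.

```rust
use std::thread;

/// Hypothetical constraint engine: only even token ids are allowed. A real
/// engine would advance its parser over `_prefix`. The mask does not need the
/// new logits, which is what allows it to start before the forward pass.
fn compute_mask(_prefix: Vec<u32>, vocab_size: usize) -> Vec<bool> {
    (0..vocab_size).map(|tok| tok % 2 == 0).collect()
}

/// Hypothetical stand-in for the model's forward pass (deterministic toy logits).
fn forward_pass(prefix: &[u32], vocab_size: usize) -> Vec<f32> {
    (0..vocab_size)
        .map(|tok| ((tok as u32 * 7 + prefix.len() as u32) % 11) as f32)
        .collect()
}

/// Greedy pick of the highest-scoring token the mask allows; None if nothing is allowed.
fn sample_masked(logits: &[f32], mask: &[bool]) -> Option<u32> {
    let mut best: Option<(usize, f32)> = None;
    for (tok, (&logit, &allowed)) in logits.iter().zip(mask.iter()).enumerate() {
        if !allowed {
            continue;
        }
        if best.map_or(true, |(_, b)| logit > b) {
            best = Some((tok, logit));
        }
    }
    best.map(|(tok, _)| tok as u32)
}

/// One decoding step: the mask is kicked off on a worker thread first, the
/// forward pass runs on the current thread, and we only wait for the mask
/// once the logits are available.
fn decode_step(prefix: &[u32], vocab_size: usize) -> Option<u32> {
    let owned = prefix.to_vec();
    let mask_handle = thread::spawn(move || compute_mask(owned, vocab_size));
    let logits = forward_pass(prefix, vocab_size);
    let mask = mask_handle.join().expect("mask thread panicked");
    sample_masked(&logits, &mask)
}
```

With the toy logits, `decode_step(&[1, 2, 3], 10)` picks token 4, while the unconstrained maximum would be token 1. In a real GPU setup the forward pass is itself launched asynchronously, so the overlap would come from starting the mask before the kernels are queued rather than from a second CPU thread.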

This also depends a little on what we do with #963.

Metadata


Assignees: No one assigned

Labels: new feature (New feature or request)

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests
