The code at https://github.com/pytorch/ao/blob/42c23768d379e7d5acd8af0d84ee2a7672a66fcd/torchao/utils.py#L5 currently supports only CUDA; we need to extend it to support CPU and MPS as well.
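One possible shape for this, as a minimal sketch: pick the device type from what is available, dispatch synchronization per device, and time with a host clock instead of anything CUDA-specific. This assumes the CUDA-only part of the linked code is device selection, synchronization, or timing; the linked line is not reproduced here, so that is a guess. The helper names `pick_device_type`, `get_synchronize` and `time_fn_ms` are hypothetical and are not existing torchao APIs. Only `torch.cuda.is_available`, `torch.backends.mps.is_available`, `torch.cuda.synchronize` and `torch.mps.synchronize` (PyTorch 2.0+) come from PyTorch itself.

```python
import time
from typing import Callable


def pick_device_type(cuda_available: bool, mps_available: bool) -> str:
    """Prefer CUDA, then MPS (Apple silicon), then fall back to CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def default_device_type() -> str:
    # torch is imported lazily so the pure-Python helpers work without it installed.
    import torch

    mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return pick_device_type(torch.cuda.is_available(), mps_ok)


def get_synchronize(device_type: str) -> Callable[[], None]:
    """Return a function that blocks until queued work on `device_type` finishes."""
    if device_type == "cpu":
        # CPU ops run synchronously, so there is nothing to wait for.
        return lambda: None
    if device_type not in ("cuda", "mps"):
        raise ValueError(f"unsupported device type: {device_type!r}")
    import torch

    # torch.mps.synchronize requires PyTorch 2.0 or newer.
    return torch.cuda.synchronize if device_type == "cuda" else torch.mps.synchronize


def time_fn_ms(
    fn: Callable[[], object], synchronize: Callable[[], None], num_runs: int = 10
) -> float:
    """Average wall-clock time of fn() in milliseconds, synchronizing around the timed loop."""
    if num_runs < 1:
        raise ValueError("num_runs must be >= 1")
    synchronize()  # drain any work queued before timing starts
    start = time.perf_counter()
    for _ in range(num_runs):
        fn()
    synchronize()  # wait until the last run has actually finished on the device
    return (time.perf_counter() - start) * 1000.0 / num_runs
```

Usage would look like `dev = default_device_type()` followed by `time_fn_ms(lambda: model(x), get_synchronize(dev))`. Using `time.perf_counter` plus explicit synchronization keeps the timing path identical across CUDA, MPS and CPU, at the cost of slightly coarser measurements than CUDA events.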