Skip to content

Commit

Permalink
improve obj_size
Browse files Browse the repository at this point in the history
  • Loading branch information
kongdd committed Oct 7, 2023
1 parent 76a5a43 commit d0293b3
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions src/tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,33 +51,52 @@ seq_len(n) = 1:n
r_range(x) = [minimum(x), maximum(x)]


"""
obj_size(x)
obj_size(dims, T)
# Examples
```julia
dims = (100, 100, 200)
x = zeros(Float32, dims)
obj_size(x)
obj_size(dims, Float32)
```
"""
function obj_size(x)
ans = Base.summarysize(x) / 1024^2
ans = round(ans, digits=2)
print(typeof(x), " | ", size(x), " | ")
printstyled("$ans Mb\n"; color=:blue, bold=true, underline=true)
end

function obj_size(dims, T)
ans = Base.summarysize(T(0)) * prod(dims) / 1024^2
ans = round(ans, digits=2)
print(T, " | ", dims, " | ")
printstyled("$ans Mb\n"; color=:blue, bold=true, underline=true)
end


function squeeze_TailOrHead(A::AbstractArray; type="tail")
dims = size(A)
n = length(dims)

dims_drop = []
dims_drop = []
if type == "head"
inds = 1:n
elseif type == "tail"
inds = n:-1:1
end

for i = inds
if dims[i] == 1
push!(dims_drop, i)
else
break
end
end

if !isempty(dims_drop)
dropdims(A, dims=tuple(dims_drop...))
else
Expand All @@ -97,7 +116,7 @@ function zip_continue(x::AbstractVector{<:Integer})
flag = cumsum([true; diff(x) .!= 1])
grps = unique(flag)
n = grps[end]

res = []
for i = 1:n
inds = findall(flag .== grps[i])
Expand Down Expand Up @@ -144,9 +163,9 @@ export which_isna, which_notna,
mean, weighted_mean, weighted_sum,
seq_along, seq_len,
r_range,
nth,
nth,
selectdim_deep,
length_unique, unique_sort,
length_unique, unique_sort,
squeeze, squeeze_tail, squeeze_head,
abind,
set_seed;
Expand Down

0 comments on commit d0293b3

Please sign in to comment.