Skip to content
This repository has been archived by the owner on Mar 26, 2022. It is now read-only.

Commit

Permalink
Fix import bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
BreezeWhite committed Oct 6, 2021
1 parent 1765967 commit 0a7a09b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 11 deletions.
53 changes: 49 additions & 4 deletions oemer/dewarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import scipy.ndimage
from scipy.interpolate import interp1d, griddata
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt

from .morph import morph_open
from .utils import get_logger
from oemer.morph import morph_open
from oemer.utils import get_logger


logger = get_logger(__name__)
Expand Down Expand Up @@ -288,8 +289,8 @@ def dewarp(img, coords_x, coords_y):
#f_name = "tabi"
img_path = f"../test_imgs/{f_name}.jpg"

#img_path = "../test_imgs/Chihiro/7.jpg"
img_path = "../test_imgs/Gym/2.jpg"
img_path = "../test_imgs/Chihiro/7.jpg"
#img_path = "../test_imgs/Gym/2.jpg"

ori_img = cv2.imread(img_path)
f_name, ext = os.path.splitext(os.path.basename(img_path))
Expand Down Expand Up @@ -324,3 +325,47 @@ def dewarp(img, coords_x, coords_y):
out[..., i] = cv2.remap(out[..., i].astype(np.float32), grid_y.astype(np.float32), mapping.astype(np.float32), cv2.INTER_CUBIC)

mix = np.hstack([ori_img, out])


import random
def teaser():
plt.clf()
plt.rcParams['axes.titlesize'] = 'medium'
plt.subplot(231)
plt.title("Predict")
plt.axis('off')
plt.imshow(st_pred, cmap="Greys")

plt.subplot(232)
plt.title("Morph")
plt.axis('off')
plt.imshow(pred, cmap='Greys')

plt.subplot(233)
plt.title("Quantize")
plt.axis('off')
plt.imshow(grid_map>0, cmap='Greys')

plt.subplot(234)
plt.title("Group")
plt.axis('off')
ggs = set(np.unique(gg_map))
ggs.remove(-1)
_gg_map = np.ones(gg_map.shape+(3,), dtype=np.uint8) * 255
for i in ggs:
ys, xs = np.where(gg_map==i)
for c in range(3):
v = random.randint(0, 255)
_gg_map[ys, xs, c] = v
plt.imshow(_gg_map)

plt.subplot(235)
plt.title("Connect")
plt.axis('off')
plt.imshow(new_gg_map>0, cmap='Greys')

plt.subplot(236)
plt.title("Dewarp")
plt.axis('off')
plt.imshow(out)

6 changes: 6 additions & 0 deletions oemer/ete.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@ def extract(args):
stems_rests = pred["stems_rests"]
else:
# Make predictions
if args.use_tf:
ori_inf_type = os.environ.get("INFERENCE_WITH_TF", None)
os.environ["INFERENCE_WITH_TF"] = "true"
staff, symbols, stems_rests, notehead, clefs_keys = generate_pred(str(img_path))
if args.use_tf:
os.environ["INFERENCE_WITH_TF"] = ori_inf_type
if args.save_cache:
data = {
'staff': staff,
Expand Down Expand Up @@ -197,6 +202,7 @@ def get_parser():
parser = argparse.ArgumentParser("Oemer", description="End-to-end OMR")
parser.add_argument("img_path", help="Path to the image.", type=str)
parser.add_argument("-o", "--output-path", help="Path to output the result file", type=str, default="./")
parser.add_argument("--use-tf", help="Use Tensorflow for model inference. Default is to use Onnxruntime.", action="store_true")
parser.add_argument(
"--save-cache",
help="Save the model predictions and the next time won't need to predict again.",
Expand Down
14 changes: 7 additions & 7 deletions oemer/staffline_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from sklearn.cluster import KMeans
from sklearn.linear_model import LinearRegression

from . import layers
from . import exceptions as E
from .utils import get_logger
from .bbox import find_lines, get_bbox, get_center
from oemer import layers
from oemer import exceptions as E
from oemer.utils import get_logger
from oemer.bbox import find_lines, get_bbox, get_center


logger = get_logger(__name__)
Expand Down Expand Up @@ -692,15 +692,15 @@ def dist(st):

if __name__ == "__main__":
f_name = "last"
#f_name = "tabi"
f_name = "tabi"
#f_name = "tabi_page2"
#f_name = "PXL2"
#f_name = "girl"
#f_name = "1"

pred = pickle.load(open(f"{f_name}.pkl", "rb"))['staff']
pred = pickle.load(open(f"../test_imgs/{f_name}.pkl", "rb"))['staff']
layers.register_layer("staff_pred", pred)
rr = range(10, 193)
rr = range(1130, 1400)
#staffs, zones = extract()
#staffs = extract_part(pred[..., rr], 0)
lines, norm = extract_line(pred[..., rr], 0)
Expand Down

0 comments on commit 0a7a09b

Please sign in to comment.