Add plots to AlignmentJob (#509)
So far, alignment-score plots have only been available in `i6_core.corpus.filter.FilterSegmentsByAlignmentConfidenceJob`. However, it is useful to see the alignment score plot before actually filtering. This PR adds optional (and recommended, although disabled by default via the `plot_alignment_scores` parameter) plotting functionality to `i6_core.mm.AlignmentJob`.
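
A minimal usage sketch of the new parameter (the other constructor arguments follow the existing `AlignmentJob` signature; `crp`, `feature_flow`, and `feature_scorer` are assumed to be set up elsewhere in the pipeline):

    from i6_core.mm.alignment import AlignmentJob

    align_job = AlignmentJob(
        crp=crp,                        # rasr.crp.CommonRasrParameters instance
        feature_flow=feature_flow,      # feature extraction flow network
        feature_scorer=feature_scorer,  # rasr.FeatureScorer instance
        plot_alignment_scores=True,     # opt in; defaults to False for backward compatibility
    )
    # The histogram of per-segment average alignment scores is then available as
    # align_job.out_plot_avg ("score.png").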
Icemole authored May 29, 2024
1 parent d782ce3 commit e84798f
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions mm/alignment.py
@@ -1,5 +1,6 @@
__all__ = ["AlignmentJob", "DumpAlignmentJob", "AMScoresFromAlignmentLogJob"]

import logging
import xml.etree.ElementTree as ET
import math
import os
@@ -15,6 +16,12 @@


class AlignmentJob(rasr.RasrCommand, Job):
"""
Align a dataset with the given feature scorer.
"""

__sis_hash_exclude__ = {"plot_alignment_scores": False}

def __init__(
self,
crp,
@@ -26,6 +33,7 @@ def __init__(
rtf=1.0,
extra_config=None,
extra_post_config=None,
plot_alignment_scores=False,
):
"""
:param rasr.crp.CommonRasrParameters crp:
@@ -37,6 +45,8 @@ def __init__(
:param float rtf:
:param extra_config:
:param extra_post_config:
:param plot_alignment_scores: Whether to plot the alignment scores (normalized over time) or not.
The recommended value is `True`. The default value is `False` for backward-compatibility purposes.
"""
assert isinstance(feature_scorer, rasr.FeatureScorer)

@@ -52,6 +62,7 @@ def __init__(
self.feature_scorer = feature_scorer
self.use_gpu = use_gpu
self.word_boundaries = word_boundaries
self.plot_alignment_scores = plot_alignment_scores

self.out_log_file = self.log_file_output_path("alignment", crp, True)
self.out_single_alignment_caches = dict(
@@ -75,6 +86,8 @@ def __init__(
cached=True,
)
self.out_word_boundary_bundle = self.output_path("word_boundary.cache.bundle", cached=True)
if self.plot_alignment_scores:
self.out_plot_avg = self.output_path("score.png")

self.rqmt = {
"time": max(rtf * crp.corpus_duration / crp.concurrent, 0.5),
@@ -91,6 +104,8 @@ def tasks(self):

yield Task("create_files", mini_task=True)
yield Task("run", resume="run", rqmt=rqmt, args=range(1, self.concurrent + 1))
if self.plot_alignment_scores:
yield Task("plot", resume="plot", rqmt=rqmt)

def create_files(self):
self.write_config(self.config, self.post_config, "alignment.config")
@@ -121,6 +136,39 @@ def run(self, task_id):
self.out_single_word_boundary_caches[task_id].get_path(),
)

def plot(self):
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Parse the files and search for the average alignment score values (normalized over time).
alignment_scores = []
for log_file in self.out_log_file.values():
logging.info("Reading: {}".format(log_file))
file_path = log_file.get_path()
document = ET.parse(util.uopen(file_path))
_seg_list = document.findall(".//segment")
for seg in _seg_list:
avg = seg.find(".//score/avg")
alignment_scores.append(float(avg.text))
del document

np_alignment_scores = np.asarray(alignment_scores)
higher_percentile = np.percentile(np_alignment_scores, 90) # There can be huge outliers.
logging.info(
f"Max {np_alignment_scores.max()}; min {np_alignment_scores.min()}; median {np.median(np_alignment_scores)}"
)
logging.info(f"Total number of segments: {np_alignment_scores.size}; 90-th percentile: {higher_percentile}")

# Plot the data.
matplotlib.use("Agg")
np.clip(np_alignment_scores, np_alignment_scores.min(), higher_percentile, out=np_alignment_scores)
plt.hist(np_alignment_scores, bins=100)
plt.xlabel("Average Maximum-Likelihood Score")
plt.ylabel("Number of Segments")
plt.title("Histogram of Alignment Scores")
plt.savefig(fname=self.out_plot_avg.get_path())

def cleanup_before_run(self, cmd, retry, task_id, *args):
util.backup_if_exists("alignment.log.%d" % task_id)
util.delete_if_exists("alignment.cache.%d" % task_id)