Rebase to axlearn main
jiya-zhang committed Jan 14, 2025
2 parents 89ac1ea + feb8357 commit d63dd6a
Showing 82 changed files with 5,561 additions and 1,986 deletions.
2 changes: 1 addition & 1 deletion CODEOWNERS
@@ -1 +1 @@
-* @ruomingp @markblee
+* @ruomingp @markblee @apple/axlearn-admins
4 changes: 2 additions & 2 deletions axlearn/audio/frontend.py
@@ -68,8 +68,6 @@ class Config(BaseLayer.Config):

         # Number of output channels.
         output_dim: Required[int] = REQUIRED
-        # Number of filters/bands in the output spectrogram.
-        num_filters: Required[int] = REQUIRED
         # Number of input samples per second, e.g., 24000 for 24KHz inputs.
         sample_rate: Required[int] = REQUIRED
         # Size of each frame in ms.
@@ -132,6 +130,8 @@ class LogMelFrontend(BaseFrontend):
     class Config(BaseFrontend.Config):
         """Configures LogMelFrontend."""

+        # Number of filters/bands in the output spectrogram.
+        num_filters: Required[int] = REQUIRED
         # Number of output channels. Should always be 1.
         output_dim: int = 1
         # Optional output transformation. See `normalize_by_mean_std` for an example.
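Note: `num_filters` now lives on LogMelFrontend.Config rather than on the shared BaseFrontend.Config. A minimal configuration sketch under that change, assuming axlearn's usual `default_config().set(...)` pattern (the field values and `name` are illustrative, and other required fields are omitted):

    # Hypothetical configuration; the point is only the new home of `num_filters`.
    cfg = LogMelFrontend.default_config().set(
        name="frontend",
        num_filters=80,      # moved onto LogMelFrontend.Config in this commit
        sample_rate=16000,   # still defined on BaseFrontend.Config
        output_dim=1,        # "Number of output channels. Should always be 1."
    )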
9 changes: 6 additions & 3 deletions axlearn/audio/frontend_utils.py
@@ -250,9 +250,7 @@ def pre_emphasis(x: Tensor, *, coeff: Tensor) -> Tensor:
     return x[..., 1:] - coeff * x[..., :-1]


-def windowing(x: Tensor, *, window_type: WindowType, periodic: bool = True) -> Tensor:
-    """Applies windowing to the input frames of shape `[..., num_windows, window_size]`."""
-    window_size = x.shape[-1]
+def window_coffs(window_size: int, *, window_type: WindowType, periodic: bool = True) -> Tensor:
     is_even = (1 - window_size % 2) * periodic

     if window_type == WindowType.HANN:
@@ -261,7 +259,12 @@ def windowing(x: Tensor, *, window_type: WindowType, periodic: bool = True) -> Tensor:
         coeffs = jnp.hamming(window_size + is_even)[:window_size]
     else:
         raise NotImplementedError(f"Unrecognized window_type {window_type}.")
+    return coeffs
+
+
+def windowing(x: Tensor, *, window_type: WindowType, periodic: bool = True) -> Tensor:
+    """Applies windowing to the input frames of shape `[..., num_windows, window_size]`."""
+    coeffs = window_coffs(x.shape[-1], window_type=window_type, periodic=periodic)
     return (x * coeffs).astype(x.dtype)
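For orientation, a minimal usage sketch of the refactor (illustrative; assumes `jnp` is `jax.numpy` and this module's imports are in scope):

    import jax.numpy as jnp

    # The extracted helper exposes window coefficients without materializing frames.
    coeffs = window_coffs(400, window_type=WindowType.HANN, periodic=True)
    assert coeffs.shape == (400,)

    # `windowing` behaves as before, now delegating coefficient computation.
    frames = jnp.ones((8, 10, 400))  # [batch, num_windows, window_size]
    windowed = windowing(frames, window_type=WindowType.HANN, periodic=True)
    assert windowed.shape == frames.shape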
40 changes: 33 additions & 7 deletions axlearn/audio/frontend_utils_test.py
@@ -31,6 +31,7 @@
     next_power_of_2,
     pre_emphasis,
     sharded_fft,
+    window_coffs,
     windowing,
 )
 from axlearn.audio.test_utils import fake_audio
@@ -160,6 +161,20 @@ def test_window(self, input_shape, window_type: WindowType, periodic: bool):
             windowing(inputs, window_type=window_type, periodic=periodic),
         )

+    @parameterized.product(
+        window_size=[400, 401],
+        window_type=list(WindowType),
+        periodic=[True, False],
+    )
+    def test_window_coffs(self, window_size, window_type: WindowType, periodic: bool):
+        ref_coffs = _ref_window_coffs(
+            window_size=window_size, window_type=window_type, periodic=periodic
+        )
+        test_coeffs = window_coffs(
+            window_size=window_size, window_type=window_type, periodic=periodic
+        )
+        self.assertAllClose(ref_coffs, test_coeffs)
+

 class SpectrogramTest(parameterized.TestCase, tf.test.TestCase):
     """Tests spectrograms."""
@@ -296,19 +311,30 @@ def _ref_pre_emphasis(*, inputs: ArrayLike, coeff: float):
     return inputs[:, :, 1:] - coeff * inputs[:, :, :-1]


-def _ref_window(*, inputs: ArrayLike, window_type: WindowType, **kwargs):
+def _ref_window_coffs(
+    *, window_size: int, window_type: WindowType, periodic: bool = True, dtype=tf.float32
+):
+    if window_type == WindowType.HANN:
+        tf_window = tf.signal.hann_window(window_size, periodic=periodic, dtype=dtype)
+    elif window_type == WindowType.HAMMING:
+        tf_window = tf.signal.hamming_window(window_size, periodic=periodic, dtype=dtype)
+    else:
+        raise NotImplementedError(f"Unrecognized window type: {window_type}")
+    return tf_window
+
+
+def _ref_window(
+    *, inputs: ArrayLike, window_type: WindowType, periodic: bool = True, dtype=tf.float32
+):
     """Lingvo window.
     Reference:
     https://github.com/tensorflow/lingvo/blob/4a9097a212622d99d7f8e2379804dbffdc44a97f/lingvo/tasks/asr/frontend.py#L244
     """
     frame_size = inputs.shape[-1]
-    if window_type == WindowType.HANN:
-        tf_window = tf.signal.hann_window(frame_size, **kwargs)
-    elif window_type == WindowType.HAMMING:
-        tf_window = tf.signal.hamming_window(frame_size, **kwargs)
-    else:
-        raise NotImplementedError(f"Unrecognized window type: {window_type}")
+    tf_window = _ref_window_coffs(
+        window_size=frame_size, window_type=window_type, periodic=periodic, dtype=dtype
+    )
     return inputs * tf_window
115 changes: 63 additions & 52 deletions axlearn/cloud/common/bastion.py
@@ -702,6 +702,8 @@ def __init__(self, cfg: Config):
         tf_io.gfile.makedirs(self._job_history_dir)
         self._project_history_dir = os.path.join(self._output_dir, "history", "projects")
         tf_io.gfile.makedirs(self._project_history_dir)
+        self._scheduler_history_dir = os.path.join(self._output_dir, "history", "scheduler")
+        tf_io.gfile.makedirs(self._scheduler_history_dir)
         # Mapping from project_id to previous job verdicts.
         self._project_history_previous_verdicts = {}
         # Jobs that have fully completed.
@@ -745,12 +747,22 @@ def _append_to_job_history(self, job: Job, *, msg: str, state: JobLifecycleState
             )
         )

-    def _append_to_project_history(
+    def _append_to_history(
         self, jobs: dict[str, JobMetadata], schedule_results: BaseScheduler.ScheduleResults
     ):
+        now = datetime.now(timezone.utc)
+        with tf_io.gfile.GFile(
+            os.path.join(self._scheduler_history_dir, now.strftime("%Y%m%d-%H%M%S")), "a"
+        ) as f:
+            for job_id, verdict in schedule_results.job_verdicts.items():
+                job_metadata = jobs[job_id]
+                f.write(f"{job_id} [{job_metadata}] {verdict}\n")
         for project_id, limits in schedule_results.project_limits.items():
-            job_verdicts = schedule_results.job_verdicts.get(project_id, {})
+            job_verdicts = {
+                job_id: verdict
+                for job_id, verdict in schedule_results.job_verdicts.items()
+                if jobs[job_id].project_id == project_id
+            }
             verdicts = []
            for job_id, verdict in job_verdicts.items():
                verdicts.append((job_id, verdict.should_run(), verdict.metadata))
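For orientation, a sketch of the record this writes (`output_dir`, `jobs`, and `schedule_results` are hypothetical stand-ins with the shapes used above):

    from datetime import datetime, timezone

    # One file per scheduling pass, named by UTC timestamp, e.g.
    # <output_dir>/history/scheduler/20250114-083000, with one line per
    # verdict in the scheduler's priority order.
    now = datetime.now(timezone.utc)
    path = f"{output_dir}/history/scheduler/{now.strftime('%Y%m%d-%H%M%S')}"
    for job_id, verdict in schedule_results.job_verdicts.items():
        print(f"{job_id} [{jobs[job_id]}] {verdict}")  # the line format used above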
@@ -1075,57 +1087,56 @@ def _update_jobs(self):
             dry_run=schedule_options["dry_run"],
             verbosity=schedule_options["verbosity"],
         )
-        self._append_to_project_history(schedulable_jobs, schedule_results)
-        for verdicts in schedule_results.job_verdicts.values():
-            for job_name, verdict in verdicts.items():
-                job = self._active_jobs[job_name]
-                assert job.state.status in {JobStatus.PENDING, JobStatus.ACTIVE}
-
-                if verdict:
-                    old_tier = job.state.metadata.get("tier")
-                    new_tier = verdict.metadata.get("tier")
-                    changed_tiers = old_tier != new_tier
-
-                    jobspec_changed = job.state.metadata.get("updated")
-
-                    # Jobspec changed, trigger a restart of the runner.
-                    if jobspec_changed:
-                        self._append_to_job_history(
-                            job,
-                            msg="UPDATING: Detected updated jobspec. Will restart the runner "
-                            "by sending to PENDING state",
-                            state=JobLifecycleState.UPDATING,
-                        )
-                        job.state.status = JobStatus.PENDING
-                    elif job.state.status == JobStatus.PENDING or not changed_tiers:
-                        # Resume if not running, or keep running if scheduling tier did not change.
-                        job.state.status = JobStatus.ACTIVE
-                    else:
-                        # Job changed scheduling tiers, and must be restarted on the new tier.
-                        # NOTE: this can possibly lead to thrashing of jobs that frequently switch
-                        # tiers. One option is track per-job tier changes and hold off on promoting
-                        # low priority to high priority if it was demoted recently.
-                        # TODO(markblee): Add instrumentation to track frequency of tier changes to
-                        # see whether this is necessary.
-                        assert job.state.status == JobStatus.ACTIVE and changed_tiers
-                        self._append_to_job_history(
-                            job,
-                            msg=f"Rescheduling at a different tier from {old_tier} to {new_tier}",
-                            state=JobLifecycleState.RESCHEDULING,
-                        )
-                        job.state.status = JobStatus.PENDING
-                else:
-                    # Pre-empt/stay queued.
-                    if job.command_proc is not None and _is_proc_complete(job.command_proc):
-                        # As a slight optimization, we avoid pre-empting ACTIVE jobs that are
-                        # complete, since we can directly transition to CLEANING.
-                        job.state.status = JobStatus.ACTIVE
-                    else:
-                        job.state.status = JobStatus.PENDING
-                    # Pending jobs which are not rescheduled should have no tier information.
-                    verdict.metadata.pop("tier", None)
-
-                job.state.metadata = verdict.metadata
+        self._append_to_history(schedulable_jobs, schedule_results)
+        for job_name, verdict in schedule_results.job_verdicts.items():
+            job = self._active_jobs[job_name]
+            assert job.state.status in {JobStatus.PENDING, JobStatus.ACTIVE}
+
+            if verdict:
+                old_tier = job.state.metadata.get("tier")
+                new_tier = verdict.metadata.get("tier")
+                changed_tiers = old_tier != new_tier
+
+                jobspec_changed = job.state.metadata.get("updated")
+
+                # Jobspec changed, trigger a restart of the runner.
+                if jobspec_changed:
+                    self._append_to_job_history(
+                        job,
+                        msg="UPDATING: Detected updated jobspec. Will restart the runner "
+                        "by sending to PENDING state",
+                        state=JobLifecycleState.UPDATING,
+                    )
+                    job.state.status = JobStatus.PENDING
+                elif job.state.status == JobStatus.PENDING or not changed_tiers:
+                    # Resume if not running, or keep running if scheduling tier did not change.
+                    job.state.status = JobStatus.ACTIVE
+                else:
+                    # Job changed scheduling tiers, and must be restarted on the new tier.
+                    # NOTE: this can possibly lead to thrashing of jobs that frequently switch
+                    # tiers. One option is track per-job tier changes and hold off on promoting
+                    # low priority to high priority if it was demoted recently.
+                    # TODO(markblee): Add instrumentation to track frequency of tier changes to
+                    # see whether this is necessary.
+                    assert job.state.status == JobStatus.ACTIVE and changed_tiers
+                    self._append_to_job_history(
+                        job,
+                        msg=f"Rescheduling at a different tier from {old_tier} to {new_tier}",
+                        state=JobLifecycleState.RESCHEDULING,
+                    )
+                    job.state.status = JobStatus.PENDING
+            else:
+                # Pre-empt/stay queued.
+                if job.command_proc is not None and _is_proc_complete(job.command_proc):
+                    # As a slight optimization, we avoid pre-empting ACTIVE jobs that are
+                    # complete, since we can directly transition to CLEANING.
+                    job.state.status = JobStatus.ACTIVE
+                else:
+                    job.state.status = JobStatus.PENDING
+                # Pending jobs which are not rescheduled should have no tier information.
+                verdict.metadata.pop("tier", None)
+
+            job.state.metadata = verdict.metadata

         # TODO(markblee): Parallelize this.
         for job_name, job in self._active_jobs.items():
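The per-job transitions above can be summarized as follows. This is a hypothetical distillation for reading, not code from the commit; `tier_changed` stands in for the `old_tier != new_tier` check:

    def next_status(job, verdict, *, proc_complete: bool) -> JobStatus:
        # Hypothetical summary of the loop body above, not part of the diff.
        if verdict:
            if job.state.metadata.get("updated"):
                return JobStatus.PENDING  # Jobspec changed: restart the runner.
            if job.state.status == JobStatus.PENDING or not tier_changed(job, verdict):
                return JobStatus.ACTIVE   # Resume, or keep running on the same tier.
            return JobStatus.PENDING      # Tier changed: restart on the new tier.
        # Not scheduled: completed processes stay ACTIVE to move straight to CLEANING.
        return JobStatus.ACTIVE if proc_complete else JobStatus.PENDING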
2 changes: 1 addition & 1 deletion axlearn/cloud/common/bastion_test.py
@@ -1484,7 +1484,7 @@ def test_update_scheduler(
     ):
         with self._patch_bastion() as mock_bastion:
             patch_update = mock.patch.object(mock_bastion, "_update_single_job")
-            patch_history = mock.patch.object(mock_bastion, "_append_to_project_history")
+            patch_history = mock.patch.object(mock_bastion, "_append_to_history")
             patch_scheduler = mock.patch.object(mock_bastion, "_scheduler")

             with patch_update, patch_history, patch_scheduler as mock_scheduler:
2 changes: 1 addition & 1 deletion axlearn/cloud/common/cleaner.py
@@ -89,6 +89,6 @@ def sweep(self, jobs: dict[str, JobSpec]) -> Sequence[str]:
            schedule_result = scheduler.schedule(
                dict(my_job=job_spec.metadata),
            )
-           if schedule_result.job_verdicts[job_spec.metadata.project_id]["my_job"].over_limits:
+           if schedule_result.job_verdicts["my_job"].over_limits:
                result.append(job_name)
        return result
30 changes: 12 additions & 18 deletions axlearn/cloud/common/scheduler.py
@@ -154,12 +154,15 @@ class ScheduleResults:
         Attributes:
             project_limits: The effective resource limits.
             project_usages: The resource usages.
-            job_verdicts: A mapping of project_id -> (job_id -> run_or_not).
+            job_verdicts: A mapping of job_id -> run_or_not.
+                The entries will be ordered by descending scheduling priorities (not necessarily
+                JobMetadata.priority), where the higher priority jobs will be scheduled before
+                lower priority ones. The jobs not getting scheduled will also be ordered.
         """

         project_limits: ProjectResourceMap[int]
         project_usages: ProjectResourceMap[int]
-        job_verdicts: dict[str, dict[str, JobVerdict]]
+        job_verdicts: dict[str, JobVerdict]

     def schedule(
         self,
@@ -345,8 +348,8 @@ def traverse_tiers(
                    break
            return tier_usages

-        job_verdicts = collections.defaultdict(dict)
-        while not project_queue.empty() and remaining_limits:
+        job_verdicts = {}
+        while not project_queue.empty():
            project_usage_ratio, _, project_id = project_queue.get()
            job_id, job_metadata = project_jobs[project_id].popleft()

@@ -381,17 +384,10 @@
                "Schedule %s(%s)/%s: %s", project_id, project_usage_ratio, job_id, verdict
            )

-           job_verdicts[project_id][job_id] = verdict
+           job_verdicts[job_id] = verdict
            if project_jobs[project_id]:
                project_queue.put(project_queue_item(project_id))

-       # Remaining jobs are rejected.
-       for project_id, job_queue in project_jobs.items():
-           for job_id, job_metadata in job_queue:
-               job_verdicts[project_id][job_id] = JobVerdict(
-                   over_limits=set(job_metadata.resources.keys())
-               )
-
        return BaseScheduler.ScheduleResults(
            # Treat the usages as the limits.
            project_limits=_recursively_to_dict(project_usages),
@@ -472,14 +468,15 @@ def schedule(
        logging.info("")
        logging.info("==Begin scheduling report")
        logging.info("Total resource limits: %s", resource_limits)
-       for project_id, project_verdicts in schedule_results.job_verdicts.items():
+       for project_id, project_job_queue in project_jobs.items():
            logging.info(
                "Verdicts for Project [%s] Quota [%s] Effective limits [%s]:",
                project_id,
                project_quotas.get(project_id, {}),
                schedule_results.project_limits.get(project_id, {}),
            )
-           for job_name, job_verdict in project_verdicts.items():
+           for job_name, job_metadata in project_job_queue:
+               job_verdict = schedule_results.job_verdicts[job_name]
                logging.info(
                    "Job %s: Resources [%s] Over limits [%s] Should Run? [%s] Metadata [%s]",
                    job_name,
@@ -500,9 +497,6 @@
            schedule_results = BaseScheduler.ScheduleResults(
                project_limits=schedule_results.project_limits,
                project_usages=project_usages,
-               job_verdicts={
-                   project_id: {job_name: JobVerdict() for job_name in project_verdicts}
-                   for project_id, project_verdicts in schedule_results.job_verdicts.items()
-               },
+               job_verdicts={job_name: JobVerdict() for job_name in schedule_results.job_verdicts},
            )
            return schedule_results
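With the nested project_id level gone, downstream code reads verdicts directly by job name. A minimal consumer sketch (the `scheduler` and `jobs` objects here are hypothetical stand-ins):

    import logging

    results = scheduler.schedule(jobs)
    for job_name, verdict in results.job_verdicts.items():
        # Entries are ordered by descending scheduling priority, per the
        # ScheduleResults docstring above.
        if not verdict.should_run():
            logging.info("%s blocked on: %s", job_name, verdict.over_limits)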