Skip to content

Commit

Permalink
unit test mid point
Browse files Browse the repository at this point in the history
Signed-off-by: Trevor Grant <[email protected]>
  • Loading branch information
ibm-peach-fish committed Oct 27, 2023
1 parent ac3a237 commit af417d8
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
24 changes: 14 additions & 10 deletions caikit_ray_backend/ray_submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


# Standard
from time import sleep
import base64
import json
import os
Expand Down Expand Up @@ -77,25 +78,28 @@ def main():
if model_path:
error.type_check("<RYT70238308E>", str, model_path=model_path)

timeout = 60
timeout = 3
if get_config().training_timeout:
try:
timeout = int(get_config().training_timeout)
timeout = float(get_config().training_timeout)
except ValueError:
log.warn(
f"training_timeout: '{get_config().training_timeout}' cannot be converted to int, ignoring"
)

# Finally kick off training
with alog.ContextTimer(log.debug, "Done training %s in: ", module_class):
ray.wait(
[
ray_training_tasks.train_and_save.options(
num_cpus=num_cpus, num_gpus=num_gpus
).remote(module_class, model_path, *args, **kwargs)
],
timeout=timeout
)
task = ray_training_tasks.train_and_save.options(
num_cpus=num_cpus, num_gpus=num_gpus
).remote(module_class, model_path, *args, **kwargs)

ready, _ = ray.wait([task], timeout=timeout)

if ready:
ray.get(task)
else:
ray.cancel(task)
log.error("Task did not complete before time out.")


if __name__ == "__main__":
Expand Down
21 changes: 21 additions & 0 deletions tests/test_ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,27 @@ def test_cancel(mock_ray_cluster, jsonl_file_data_stream):
assert status == TrainingStatus.CANCELED


def test_timeout(mock_ray_cluster, jsonl_file_data_stream):
    """Expect a CANCELED status once the configured training timeout has passed.

    The trainer is configured with a ``training_timeout`` of 3 and pointed at
    the mock Ray cluster. After submitting a ``SampleModule`` training job the
    test waits 5 seconds and then checks the status reported by the future.

    NOTE(review): presumably ``training_timeout`` is in seconds and the
    ``SampleModule`` job runs longer than that -- confirm against the backend.
    """
    connection = {"address": mock_ray_cluster.address}
    backend_config = {"connection": connection, "training_timeout": 3}
    ray_trainer = RayJobTrainModule(backend_config, "ray_backend")

    future = ray_trainer.train(
        SampleModule,
        jsonl_file_data_stream,
        save_path="/tmp",
    )

    # Wait longer than the configured timeout before inspecting the job.
    time.sleep(5)

    final_status = future.get_info().status
    print("Final status was", final_status)
    assert final_status == TrainingStatus.CANCELED


## Test Ray Backend


Expand Down

0 comments on commit af417d8

Please sign in to comment.