Skip to content

Commit

Permalink
refactor(remote-model): typing + minor refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseppeambrosio97 committed Dec 10, 2024
1 parent f49ed48 commit d265a02
Showing 1 changed file with 12 additions and 18 deletions.
30 changes: 12 additions & 18 deletions focoos/remote_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def train(
instance_type: TrainInstance = TrainInstance.ML_G4DN_XLARGE,
volume_size: int = 50,
max_runtime_in_seconds: int = 36000,
):
) -> dict | None:
"""
Initiate the training of a remote model on the Focoos platform.
Expand Down Expand Up @@ -144,13 +144,12 @@ def train(
"hyperparameters": hyperparameters.model_dump(),
},
)
if res.status_code == 200:
return res.json()
else:
if res.status_code != 200:
logger.warning(f"Failed to train model: {res.status_code} {res.text}")
return None
return res.json()

def train_status(self):
def train_status(self) -> dict | None:
"""
Retrieve the current status of the model training.
Expand All @@ -163,13 +162,12 @@ def train_status(self):
ValueError: If the request to get training status fails.
"""
res = self.http_client.get(f"models/{self.model_ref}/train/status")
if res.status_code == 200:
return res.json()
else:
if res.status_code != 200:
logger.error(f"Failed to get train status: {res.status_code} {res.text}")
raise ValueError(
f"Failed to get train status: {res.status_code} {res.text}"
)
return res.json()

def train_logs(self) -> list[str]:
"""
Expand Down Expand Up @@ -297,7 +295,7 @@ def infer(
logger.error(f"Failed to infer: {res.status_code} {res.text}")
raise ValueError(f"Failed to infer: {res.status_code} {res.text}")

def train_metrics(self, period=60) -> Optional[dict]:
def train_metrics(self, period=60) -> dict | None:
"""
Retrieve training metrics for the model over a specified period.
Expand All @@ -310,18 +308,14 @@ def train_metrics(self, period=60) -> Optional[dict]:
Returns:
Optional[dict]: A dictionary containing the training metrics if the request is successful,
or None if the request fails.
Raises:
None explicitly, but may log warnings if the request fails.
"""
res = self.http_client.get(
f"models/{self.model_ref}/train/all-metrics?period={period}&aggregation_type=Average"
)
if res.status_code == 200:
return res.json()
else:
if res.status_code != 200:
logger.warning(f"Failed to get train logs: {res.status_code} {res.text}")
return None
return res.json()

def _log_metrics(self):
"""
Expand Down Expand Up @@ -373,7 +367,7 @@ def _log_metrics(self):
f"Iter {iter:.0f}: Loss {total_loss:.2f}, {eval_metric} {accuracy}"
)

def monitor_train(self, update_period=30):
def monitor_train(self, update_period=30) -> None:
"""
Monitor the training process of the model and log its status periodically.
Expand Down Expand Up @@ -430,7 +424,7 @@ def monitor_train(self, update_period=30):
logger.info(f"Model is not training, status: {status['main_status']}")
return

def stop_training(self):
def stop_training(self) -> None:
"""
Stop the training process of the model.
Expand All @@ -453,7 +447,7 @@ def stop_training(self):
f"Failed to get stop training: {res.status_code} {res.text}"
)

def delete_model(self):
def delete_model(self) -> None:
"""
Delete the model from the system.
Expand Down

0 comments on commit d265a02

Please sign in to comment.