diff --git a/swanlab/data/cloud/log_collector.py b/swanlab/data/cloud/log_collector.py index 50d259598..dbe1edd23 100644 --- a/swanlab/data/cloud/log_collector.py +++ b/swanlab/data/cloud/log_collector.py @@ -15,12 +15,17 @@ from swanlab.error import SyncError from .task_types import UploadType +from psutil import cpu_percent +from swanlab.data.run.main import get_run +from swanlab.data.modules import MediaType, DataWrapper, FloatConvertible, Line + class LogCollectorTask(ThreadTaskABC): """ 日志聚合器,负责收集所有线程注册的日志信息 并且定义日志上传接口 """ + UPLOAD_TIME = 1 """ 每隔多少秒上传一次日志 @@ -114,6 +119,21 @@ def task(self, u: ThreadUtil, *args): swanlog.error(f"upload error: {e}") self.lock = False + def hardware_task(self, u: ThreadUtil, *args): + """ + 定时任务,读取硬件信息添加到记录 + :param u: 线程工具类 + """ + cpu_percent_now = float(cpu_percent()) + run = get_run() + if run == None: + return + else: + exp = run.exp + print(f'My exp = {exp}') + exp.add(key="CPU", data=DataWrapper("CPU", [Line(cpu_percent_now)]), step=None) + print(f'Now CPU usage: {cpu_percent_now}%') + def callback(self, u: ThreadUtil, *args): """ 回调函数,用于结束时的回调 diff --git a/swanlab/data/cloud/start_thread.py b/swanlab/data/cloud/start_thread.py index 909d9ce31..497258979 100644 --- a/swanlab/data/cloud/start_thread.py +++ b/swanlab/data/cloud/start_thread.py @@ -46,20 +46,24 @@ def __init__(self, upload_sleep_time: float = None): args=(), name=self.UPLOAD_THREAD_NAME, sleep_time=upload_sleep_time, - callback=self.collector.callback + callback=self.collector.callback, ) + + # 生成硬件指标上传线程,此线程包含数据记录任务,作为和log子线程同级的子线程 + self.hardware_thread = self.create_thread( + target=self.collector.hardware_task, + args=(), + name="HardwareUploader", + sleep_time=upload_sleep_time, + ) + self.queue = LogQueue(queue=self.__queue, readable=False, writable=True) """ 一个线程安全的队列,用于主线程向数据上传线程通信 """ def create_thread( - self, - target: Callable, - args: Tuple = (), - name: str = None, - sleep_time: float = None, - callback: Callable = None + self, target: Callable, args: Tuple = (), name: str = None, sleep_time: float = None, callback: Callable = None ) -> threading.Thread: """ 创建一个线程 @@ -83,11 +87,7 @@ def create_thread( thread_util = ThreadUtil(q, name) callback = ThreadUtil.wrapper_callback(callback, (thread_util, *args)) if callback is not None else None task = self._create_loop(name, sleep_time, target, (thread_util, *args)) - thread = threading.Thread( - target=task, - daemon=True, - name=name - ) + thread = threading.Thread(target=task, daemon=True, name=name) self.thread_pool[name] = thread if callback is not None: self.__callbacks.append(callback) @@ -101,13 +101,7 @@ def sub_threads(self): """ return {name: thread for name, thread in self.thread_pool.items() if name != self.UPLOAD_THREAD_NAME} - def _create_loop( - self, - name: str, - sleep_time: float, - task: Callable, - args: Tuple[ThreadUtil, ...] - ) -> [Callable]: + def _create_loop(self, name: str, sleep_time: float, task: Callable, args: Tuple[ThreadUtil, ...]) -> [Callable]: """ 创建一个事件循环,循环执行传入线程池的任务 :param name: 线程名称 diff --git a/swanlab/data/run/main.py b/swanlab/data/run/main.py index 4b2da4ea0..b6792c902 100644 --- a/swanlab/data/run/main.py +++ b/swanlab/data/run/main.py @@ -25,6 +25,7 @@ MAX_LIST_LENGTH = 108 + class SwanLabRunState(Enum): """SwanLabRunState is an enumeration class that represents the state of the experiment. We Recommend that you use this enumeration class to represent the state of the experiment. @@ -321,6 +322,7 @@ def log(self, data: dict, step: int = None): v = DataWrapper(k, [Line(v)]) # 数据类型的检查将在创建chart配置的时候完成,因为数据类型错误并不会影响实验进行 metric_info = self.__exp.add(key=k, data=v, step=step) + print(f'Corr exp = {self.__exp}') self.__operator.on_metric_create(metric_info) log_return[metric_info.key] = metric_info