def upload_strategy(df, meta, token=None, **kwargs):
    """Upload a strategy's position weights and metadata to the strategy server.

    Within each symbol, rows are sorted by time and every row whose weight equals
    the previous row's weight is dropped, so only weight changes are uploaded.

    :param df: pd.DataFrame, position weights; must contain the columns
        ``dt``, ``symbol`` and ``weight``, for example:

        =================== ======== ========
        dt                  symbol   weight
        =================== ======== ========
        2017-01-03 09:01:00 ZZSF9001 0
        2017-01-03 09:01:00 DLj9001  0
        2017-01-03 09:01:00 SQag9001 0
        2017-01-03 09:06:00 ZZSF9001 0.136364
        2017-01-03 09:06:00 SQag9001 1
        =================== ======== ========

    :param meta: dict, strategy metadata. It is sent as-is; only ``name`` is read
        here (as ``strategy_name``). It is expected to carry at least name,
        description, base_freq, author and outsample_sdt, for example:

        {'name': 'TS001_3',
         'description': '测试策略:仅用于读写redis测试',
         'base_freq': '1分钟',
         'author': 'ZB',
         'outsample_sdt': '20220101'}

    :param token: str, upload credential; when falsy, the environment variable
        CZSC_TOKEN is used instead
    :param kwargs: dict, other options

        - logger: logger object, defaults to ``loguru.logger``
        - timeout: float, seconds to wait for the server, defaults to 60

    :return: dict, the decoded JSON body returned by the upload endpoint
    :raises KeyError: if ``df`` lacks one of the required columns
    :raises ValueError: if ``df`` contains no rows to upload
    """
    logger = kwargs.pop("logger", None)
    if logger is None:
        logger = loguru.logger
    # Without a timeout requests waits indefinitely on an unresponsive server.
    timeout = kwargs.pop("timeout", 60)

    # Fail early with a readable message instead of a bare pandas KeyError later on.
    missing = [col for col in ("dt", "symbol", "weight") if col not in df.columns]
    if missing:
        raise KeyError(f"df is missing required columns: {missing}")

    df = df.copy()
    df["dt"] = pd.to_datetime(df["dt"])
    logger.info(f"输入数据中有 {len(df)} 条权重信号")

    # Drop rows that repeat the previous weight of the same symbol.
    # fillna(1) keeps the first row of every symbol, whose diff is NaN.
    kept = []
    for _, dfg in df.groupby("symbol"):
        dfg = dfg.sort_values("dt", ascending=True).reset_index(drop=True)
        kept.append(dfg[dfg["weight"].diff().fillna(1) != 0])

    if not kept:
        # pd.concat([]) would raise an opaque "No objects to concatenate" ValueError.
        raise ValueError("df contains no weight rows to upload")

    df = pd.concat(kept, ignore_index=True)
    df = df.sort_values(["dt"]).reset_index(drop=True)
    df["dt"] = df["dt"].dt.strftime("%Y-%m-%d %H:%M:%S")

    logger.info(f"去除单个品种下相邻时间权重相同的数据后,剩余 {len(df)} 条权重信号")

    data = {
        "weights": df[["dt", "symbol", "weight"]].to_json(orient="split"),
        "token": token or os.getenv("CZSC_TOKEN"),
        "strategy_name": meta.get("name"),
        "meta": meta,
    }
    response = requests.post("http://zbczsc.com:9106/upload_strategy", json=data, timeout=timeout)

    # Decode the body once and reuse it for both the log line and the return value.
    result = response.json()
    logger.info(f"上传策略接口返回: {result}")
    return result
symbol_infos[symbol] = {"quote": quote, "target_pos": target_pos, "lots": lots} + if not symbol_infos: + logger.warning(f"没有需要调仓的品种,跳过调仓") + return api + else: + logger.info(f"开始调仓:{[x for x in symbol_infos.keys()]}") + while True: api.wait_update() @@ -439,6 +445,10 @@ def adjust_portfolio(api: TqApi, portfolio, account=None, **kwargs): if (datetime.now() - start_time).seconds > timeout: logger.error(f"调仓超时,已运行 {timeout} 秒") + for symbol, info in symbol_infos.items(): + target_pos: TargetPosTask = info["target_pos"] + target_pos.cancel() + logger.info(f"取消调仓:{symbol}") break return api diff --git a/czsc/signals/__init__.py b/czsc/signals/__init__.py index c01814e67..4bbd960fd 100644 --- a/czsc/signals/__init__.py +++ b/czsc/signals/__init__.py @@ -234,6 +234,7 @@ pos_stop_V240331, pos_stop_V240608, pos_stop_V240614, + pos_stop_V240717, ) diff --git a/czsc/signals/pos.py b/czsc/signals/pos.py index 67bff7ac5..5d0b954ac 100644 --- a/czsc/signals/pos.py +++ b/czsc/signals/pos.py @@ -20,10 +20,10 @@ def pos_ma_V230414(cat: CzscTrader, **kwargs) -> OrderedDict: **信号逻辑:** - 多头止损逻辑如下,反之为空头止损逻辑: + 多头持有状态如下,反之为空头持有状态: - 1. 从多头开仓点开始,在给定对的K线周期 freq1 上向前找 N 个底分型,记为 F1 - 2. 将这 N 个底分型的最低点,记为 L1,如果 L1 的价格低于开仓点的价格,则止损 + 1. 如果持有多头,且开仓后有价格升破MA均线,则为多头升破均线; + 2. 
def pos_stop_V240717(cat: CzscTrader, **kwargs) -> OrderedDict:
    """Stop loss: after a long entry, signal a stop once at least N bars have a low
    below (entry price - ATR * 0.67); mirrored for shorts. Contributor: 谢磊

    参数模板:"{pos_name}_{freq1}N{n}T{timeperiod}_止损V240717"

    **Signal logic:**

    Taking the long stop as the example:

    1. Collect every bar of ``freq1`` whose time is at or after the entry time;
    2. Read the ATR stored on the bar at the entry time (0 when no value is cached);
    3. Count the bars whose low is below (entry price - ATR * 0.67), call it C;
    4. If C >= N, emit the long stop signal.

    The short stop counts bars whose high is above (entry price + ATR * 0.67).

    **Signal list:**

    - Signal('SMA5多头_15分钟N3T20_止损V240717_多头止损_任意_任意_0')
    - Signal('SMA5空头_15分钟N3T20_止损V240717_空头止损_任意_任意_0')

    :param cat: CzscTrader object
    :param kwargs: parameter dict

        - pos_name: str, name of the position to monitor
        - freq1: str, bar frequency to evaluate
        - n: int, number of breaching bars required, defaults to 10
        - timeperiod: int, ATR period, defaults to 20

    :return: OrderedDict
    """
    # NOTE(review): the 参数模板 line in the docstring is kept verbatim because it
    # looks like a machine-read template -- confirm before rewording it.
    from czsc.signals.tas import update_atr_cache

    pos_name = kwargs["pos_name"]
    freq1 = kwargs["freq1"]
    n = int(kwargs.get("n", 10))  # number of breaching bars required
    timeperiod = int(kwargs.get("timeperiod", 20))  # ATR period

    k1, k2, k3 = f"{pos_name}_{freq1}N{n}T{timeperiod}_止损V240717".split("_")
    v1 = "其他"

    # Guard first: indexing cat.kas before this check made the check unreachable
    # for an empty cat.kas.
    if not cat.kas or not hasattr(cat, "positions"):
        return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1)

    c = cat.kas[freq1]
    cache_key = update_atr_cache(c, timeperiod=timeperiod)

    pos = [x for x in cat.positions if x.name == pos_name][0]
    if len(pos.operates) == 0 or pos.operates[-1]["op"] in [Operate.SE, Operate.LE]:
        return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1)

    op = pos.operates[-1]

    # Bars at or after the entry time.
    a_bars = [x for x in c.bars_raw if x.dt >= op["dt"]]

    # Prefer the bar stamped exactly at the entry time. When there is none, fall
    # back to the earliest bar after the entry instead of raising IndexError.
    # NOTE(review): the fallback assumes bars_raw is in chronological order -- confirm.
    atr = 0
    if a_bars:
        open_bar = next((x for x in a_bars if x.dt == op["dt"]), a_bars[0])
        cached = open_bar.cache.get(cache_key)
        atr = cached if cached is not None else 0
    offset = atr * 0.67

    if op["op"] == Operate.LO and sum(1 for x in a_bars if x.low < op["price"] - offset) >= n:
        v1 = "多头止损"

    if op["op"] == Operate.SO and sum(1 for x in a_bars if x.high > op["price"] + offset) >= n:
        v1 = "空头止损"

    return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1)
datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "kwargs": json.dumps(kwargs), + } self.r.hset(key, mapping=meta) def update_last(self, **kwargs): """设置策略最近一次更新时间,以及更新参数【可选】""" - key = f'{self.key_prefix}:LAST:{self.strategy_name}' - last = {'name': self.strategy_name, - 'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'kwargs': json.dumps(kwargs)} + key = f"{self.key_prefix}:LAST:{self.strategy_name}" + last = { + "name": self.strategy_name, + "time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "kwargs": json.dumps(kwargs), + } self.r.hset(key, mapping=last) logger.info(f"更新 {key} 的 last 时间") @property def metadata(self): """获取策略元数据""" - key = f'{self.key_prefix}:META:{self.strategy_name}' + key = f"{self.key_prefix}:META:{self.strategy_name}" return self.r.hgetall(key) @property def heartbeat_time(self): """获取策略的最近一次心跳时间""" - key = f'{self.key_prefix}:{self.heartbeat_prefix}:{self.strategy_name}' + key = f"{self.key_prefix}:{self.heartbeat_prefix}:{self.strategy_name}" return pd.to_datetime(self.r.get(key)) def get_last_times(self, symbols=None): @@ -124,15 +132,15 @@ def get_last_times(self, symbols=None): :return: dict, {symbol: datetime},如{'SFIF9001': datetime(2021, 9, 24, 15, 19, 0)} """ if isinstance(symbols, str): - row = self.r.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbols}:LAST') - return pd.to_datetime(row['dt']) if row else None # type: ignore + row = self.r.hgetall(f"{self.key_prefix}:{self.strategy_name}:{symbols}:LAST") + return pd.to_datetime(row["dt"]) if row else None # type: ignore symbols = symbols if symbols else self.get_symbols() with self.r.pipeline() as pipe: for symbol in symbols: - pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST') + pipe.hgetall(f"{self.key_prefix}:{self.strategy_name}:{symbol}:LAST") rows = pipe.execute() - return {x['symbol']: pd.to_datetime(x['dt']) for x in rows} + return {x["symbol"]: pd.to_datetime(x["dt"]) for x in rows} def publish(self, symbol, dt, weight, price=0, 
ref=None, overwrite=False): """发布单个策略持仓权重 @@ -150,13 +158,13 @@ def publish(self, symbol, dt, weight, price=0, ref=None, overwrite=False): if not overwrite: last_dt = self.get_last_times(symbol) - if last_dt is not None and dt <= last_dt: # type: ignore + if last_dt is not None and dt <= last_dt: # type: ignore logger.warning(f"不允许重复写入,已过滤 {symbol} {dt} 的重复信号") return 0 - udt = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + udt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") key = f'{self.key_prefix}:{self.strategy_name}:{symbol}:{dt.strftime("%Y%m%d%H%M%S")}' - ref = ref if ref else '{}' + ref = ref if ref else "{}" ref_str = json.dumps(ref) if isinstance(ref, dict) else ref return self.lua_publish(keys=[key], args=[1 if overwrite else 0, udt, weight, price, ref_str]) @@ -169,38 +177,38 @@ def publish_dataframe(self, df, overwrite=False, batch_size=10000): :return: 成功发布信号的条数 """ df = df.copy() - df['dt'] = pd.to_datetime(df['dt']) + df["dt"] = pd.to_datetime(df["dt"]) logger.info(f"输入数据中有 {len(df)} 条权重信号") # 去除单个品种下相邻时间权重相同的数据 _res = [] - for _, dfg in df.groupby('symbol'): - dfg = dfg.sort_values('dt', ascending=True).reset_index(drop=True) - dfg = dfg[dfg['weight'].diff().fillna(1) != 0].copy() + for _, dfg in df.groupby("symbol"): + dfg = dfg.sort_values("dt", ascending=True).reset_index(drop=True) + dfg = dfg[dfg["weight"].diff().fillna(1) != 0].copy() _res.append(dfg) df = pd.concat(_res, ignore_index=True) - df = df.sort_values(['dt']).reset_index(drop=True) + df = df.sort_values(["dt"]).reset_index(drop=True) logger.info(f"去除单个品种下相邻时间权重相同的数据后,剩余 {len(df)} 条权重信号") - if 'price' not in df.columns: - df['price'] = 0 - if 'ref' not in df.columns: - df['ref'] = '{}' + if "price" not in df.columns: + df["price"] = 0 + if "ref" not in df.columns: + df["ref"] = "{}" if not overwrite: raw_count = len(df) _time = self.get_last_times() _data = [] - for symbol, dfg in df.groupby('symbol'): + for symbol, dfg in df.groupby("symbol"): last_dt = _time.get(symbol) if last_dt 
is not None: - dfg = dfg[dfg['dt'] > last_dt] + dfg = dfg[dfg["dt"] > last_dt] _data.append(dfg) df = pd.concat(_data, ignore_index=True) logger.info(f"不允许重复写入,已过滤 {raw_count - len(df)} 条重复信号") keys, args = [], [] - for row in df[['symbol', 'dt', 'weight', 'price', 'ref']].to_numpy(): + for row in df[["symbol", "dt", "weight", "price", "ref"]].to_numpy(): key = f'{self.key_prefix}:{self.strategy_name}:{row[0]}:{row[1].strftime("%Y%m%d%H%M%S")}' keys.append(key) @@ -209,18 +217,18 @@ def publish_dataframe(self, df, overwrite=False, batch_size=10000): ref = row[4] args.append(json.dumps(ref) if isinstance(ref, dict) else ref) - udt = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + udt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") overwrite = 1 if overwrite else 0 pub_cnt = 0 len_keys = len(keys) for i in range(0, len_keys, batch_size): if i + batch_size < len_keys: - tmp_keys = keys[i: i + batch_size] - tmp_args = [overwrite, udt] + args[3 * i: 3 * (i + batch_size)] + tmp_keys = keys[i : i + batch_size] + tmp_args = [overwrite, udt] + args[3 * i : 3 * (i + batch_size)] else: - tmp_keys = keys[i: len_keys] - tmp_args = [overwrite, udt] + args[3 * i: 3 * len_keys] + tmp_keys = keys[i:len_keys] + tmp_args = [overwrite, udt] + args[3 * i : 3 * len_keys] logger.info(f"索引 {i},即将发布 {len(tmp_keys)} 条权重信号") pub_cnt += self.lua_publish(keys=tmp_keys, args=tmp_args) logger.info(f"已完成 {pub_cnt} 次发布") @@ -230,24 +238,24 @@ def publish_dataframe(self, df, overwrite=False, batch_size=10000): def __heartbeat(self): while True: - key = f'{self.key_prefix}:{self.heartbeat_prefix}:{self.strategy_name}' + key = f"{self.key_prefix}:{self.heartbeat_prefix}:{self.strategy_name}" try: - self.heartbeat_client.set(key, datetime.now().strftime('%Y-%m-%d %H:%M:%S')) - except Exception: - continue + self.heartbeat_client.set(key, datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + except Exception as e: + logger.error(f"心跳发送失败:{e}") time.sleep(15) def get_keys(self, pattern) -> list: """获取 redis 中指定 
pattern 的 keys""" k = self.r.keys(pattern) - return k if k else [] # type: ignore + return k if k else [] # type: ignore def clear_all(self, with_human=True): """删除该策略所有记录""" - keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}*') - keys.append(f'{self.key_prefix}:META:{self.strategy_name}') - keys.append(f'{self.key_prefix}:LAST:{self.strategy_name}') - keys.append(f'{self.key_prefix}:{self.heartbeat_prefix}:{self.strategy_name}') + keys = self.get_keys(f"{self.key_prefix}:{self.strategy_name}*") + keys.append(f"{self.key_prefix}:META:{self.strategy_name}") + keys.append(f"{self.key_prefix}:LAST:{self.strategy_name}") + keys.append(f"{self.key_prefix}:{self.heartbeat_prefix}:{self.strategy_name}") if len(keys) == 0: logger.warning(f"{self.strategy_name} 没有记录") @@ -255,16 +263,16 @@ def clear_all(self, with_human=True): if with_human: human = input(f"{self.strategy_name} 即将删除 {len(keys)} 条记录,是否确认?(y/n):") - if human.lower() != 'y': + if human.lower() != "y": logger.warning(f"{self.strategy_name} 删除操作已取消") return - self.r.delete(*keys) # type: ignore + self.r.delete(*keys) # type: ignore logger.info(f"{self.strategy_name} 删除了 {len(keys)} 条记录") @staticmethod def register_lua_publish(client): - lua_body = ''' + lua_body = """ local overwrite = ARGV[1] local update_time = ARGV[2] local cnt = 0 @@ -304,13 +312,13 @@ def register_lua_publish(client): end end return cnt -''' +""" return client.register_script(lua_body) def get_symbols(self): """获取策略交易的品种列表""" - keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}*') - symbols = {x.split(":")[2] for x in keys} # type: ignore + keys = self.get_keys(f"{self.key_prefix}:{self.strategy_name}*") + symbols = {x.split(":")[2] for x in keys} # type: ignore return list(symbols) def get_last_weights(self, symbols=None, ignore_zero=True, lua=True): @@ -331,25 +339,25 @@ def get_last_weights(self, symbols=None, ignore_zero=True, lua=True): end return results """ - key_pattern = self.key_prefix + ':' + 
self.strategy_name + ':*:LAST' + key_pattern = self.key_prefix + ":" + self.strategy_name + ":*:LAST" results = self.r.eval(lua_script, 0, key_pattern) - rows = [dict(zip(r[::2], r[1::2])) for r in results] # type: ignore + rows = [dict(zip(r[::2], r[1::2])) for r in results] # type: ignore if symbols: - rows = [r for r in rows if r['symbol'] in symbols] + rows = [r for r in rows if r["symbol"] in symbols] else: symbols = symbols if symbols else self.get_symbols() with self.r.pipeline() as pipe: for symbol in symbols: - pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST') + pipe.hgetall(f"{self.key_prefix}:{self.strategy_name}:{symbol}:LAST") rows = pipe.execute() dfw = pd.DataFrame(rows) - dfw['weight'] = dfw['weight'].astype(float) - dfw['dt'] = pd.to_datetime(dfw['dt']) + dfw["weight"] = dfw["weight"].astype(float) + dfw["dt"] = pd.to_datetime(dfw["dt"]) if ignore_zero: - dfw = dfw[dfw['weight'] != 0].copy().reset_index(drop=True) - dfw = dfw.sort_values(['dt', 'symbol']).reset_index(drop=True) + dfw = dfw[dfw["weight"] != 0].copy().reset_index(drop=True) + dfw = dfw.sort_values(["dt", "symbol"]).reset_index(drop=True) return dfw def get_hist_weights(self, symbol, sdt, edt) -> pd.DataFrame: @@ -360,18 +368,18 @@ def get_hist_weights(self, symbol, sdt, edt) -> pd.DataFrame: :param edt: str, 结束时间, eg: 20220924 10:19:00 :return: pd.DataFrame """ - start_score = pd.to_datetime(sdt).strftime('%Y%m%d%H%M%S') - end_score = pd.to_datetime(edt).strftime('%Y%m%d%H%M%S') - model_key = f'{self.key_prefix}:{self.strategy_name}:{symbol}' + start_score = pd.to_datetime(sdt).strftime("%Y%m%d%H%M%S") + end_score = pd.to_datetime(edt).strftime("%Y%m%d%H%M%S") + model_key = f"{self.key_prefix}:{self.strategy_name}:{symbol}" key_list = self.r.zrangebyscore(model_key, start_score, end_score) if len(key_list) == 0: - logger.warning(f'no history weights: {symbol} - {sdt} - {edt}') + logger.warning(f"no history weights: {symbol} - {sdt} - {edt}") return pd.DataFrame() 
with self.r.pipeline() as pipe: for key in key_list: - pipe.hmget(key, 'weight', 'price', 'ref') + pipe.hmget(key, "weight", "price", "ref") rows = pipe.execute() weights = [] @@ -386,8 +394,8 @@ def get_hist_weights(self, symbol, sdt, edt) -> pd.DataFrame: ref = ref weights.append((self.strategy_name, symbol, dt, weight, price, ref)) - dfw = pd.DataFrame(weights, columns=['strategy_name', 'symbol', 'dt', 'weight', 'price', 'ref']) - dfw = dfw.sort_values('dt').reset_index(drop=True) + dfw = pd.DataFrame(weights, columns=["strategy_name", "symbol", "dt", "weight", "price", "ref"]) + dfw = dfw.sort_values("dt").reset_index(drop=True) return dfw def get_all_weights(self, sdt=None, edt=None, **kwargs) -> pd.DataFrame: @@ -408,30 +416,30 @@ def get_all_weights(self, sdt=None, edt=None, **kwargs) -> pd.DataFrame: end return results """ - key_pattern = self.key_prefix + ':' + self.strategy_name + ':*:*' + key_pattern = self.key_prefix + ":" + self.strategy_name + ":*:*" results = self.r.eval(lua_script, 0, key_pattern) - results = [dict(zip(r[::2], r[1::2])) for r in results] # type: ignore + results = [dict(zip(r[::2], r[1::2])) for r in results] # type: ignore df = pd.DataFrame(results) - df['dt'] = pd.to_datetime(df['dt']) - df['weight'] = df['weight'].astype(float) - df = df.sort_values(['dt', 'symbol']).reset_index(drop=True) + df["dt"] = pd.to_datetime(df["dt"]) + df["weight"] = df["weight"].astype(float) + df = df.sort_values(["dt", "symbol"]).reset_index(drop=True) # df 中的columns:['symbol', 'weight', 'dt', 'update_time', 'price', 'ref'] - df1 = pd.pivot_table(df, index='dt', columns='symbol', values='weight').sort_index().ffill().fillna(0) - df1 = pd.melt(df1.reset_index(), id_vars='dt', value_vars=df1.columns, value_name='weight') # type: ignore + df1 = pd.pivot_table(df, index="dt", columns="symbol", values="weight").sort_index().ffill().fillna(0) + df1 = pd.melt(df1.reset_index(), id_vars="dt", value_vars=df1.columns, value_name="weight") # type: ignore # 加上 
df 中的 update_time 信息 - df1 = df1.merge(df[['dt', 'symbol', 'update_time']], on=['dt', 'symbol'], how='left') - df1 = df1.sort_values(['symbol', 'dt']).reset_index(drop=True) - for _, dfg in df1.groupby('symbol'): - df1.loc[dfg.index, 'update_time'] = dfg['update_time'].ffill().bfill() + df1 = df1.merge(df[["dt", "symbol", "update_time"]], on=["dt", "symbol"], how="left") + df1 = df1.sort_values(["symbol", "dt"]).reset_index(drop=True) + for _, dfg in df1.groupby("symbol"): + df1.loc[dfg.index, "update_time"] = dfg["update_time"].ffill().bfill() if sdt: - df1 = df1[df1['dt'] >= pd.to_datetime(sdt)].reset_index(drop=True) + df1 = df1[df1["dt"] >= pd.to_datetime(sdt)].reset_index(drop=True) if edt: - df1 = df1[df1['dt'] <= pd.to_datetime(edt)].reset_index(drop=True) - df1 = df1.sort_values(['dt', 'symbol']).reset_index(drop=True) + df1 = df1[df1["dt"] <= pd.to_datetime(edt)].reset_index(drop=True) + df1 = df1.sort_values(["dt", "symbol"]).reset_index(drop=True) return df1 @@ -445,8 +453,14 @@ def clear_strategy(strategy_name, redis_url=None, connection_pool=None, key_pref :param kwargs: dict, 其他参数 """ with_human = kwargs.pop("with_human", True) - rwc = RedisWeightsClient(strategy_name, redis_url=redis_url, connection_pool=connection_pool, - key_prefix=key_prefix, send_heartbeat=False, **kwargs) + rwc = RedisWeightsClient( + strategy_name, + redis_url=redis_url, + connection_pool=connection_pool, + key_prefix=key_prefix, + send_heartbeat=False, + **kwargs, + ) rwc.clear_all(with_human) @@ -467,8 +481,14 @@ def get_strategy_weights(strategy_name, redis_url=None, connection_pool=None, ke :return: pd.DataFrame """ kwargs.pop("send_heartbeat", None) - rwc = RedisWeightsClient(strategy_name, redis_url=redis_url, connection_pool=connection_pool, - key_prefix=key_prefix, send_heartbeat=False, **kwargs) + rwc = RedisWeightsClient( + strategy_name, + redis_url=redis_url, + connection_pool=connection_pool, + key_prefix=key_prefix, + send_heartbeat=False, + **kwargs, + ) sdt = 
kwargs.get("sdt") edt = kwargs.get("edt") symbols = kwargs.get("symbols") @@ -482,11 +502,11 @@ def get_strategy_weights(strategy_name, redis_url=None, connection_pool=None, ke df = rwc.get_all_weights(sdt=sdt, edt=edt) if symbols: # 保留指定品种的权重 - not_in = [x for x in symbols if x not in df['symbol'].unique()] + not_in = [x for x in symbols if x not in df["symbol"].unique()] if not_in: logger.warning(f"{strategy_name} 中没有 {not_in} 的权重记录") - df = df[df['symbol'].isin(symbols)].reset_index(drop=True) + df = df[df["symbol"].isin(symbols)].reset_index(drop=True) return df @@ -509,13 +529,13 @@ def get_strategy_mates(redis_url=None, connection_pool=None, key_pattern="Weight r = redis.Redis.from_url(redis_url, decode_responses=True) rows = [] - for key in r.keys(key_pattern): # type: ignore + for key in r.keys(key_pattern): # type: ignore meta = r.hgetall(key) if not meta: logger.warning(f"{key} 没有策略元数据") continue - meta['heartbeat_time'] = r.get(f"{meta['key_prefix']}:{heartbeat_prefix}:{meta['name']}") # type: ignore + meta["heartbeat_time"] = r.get(f"{meta['key_prefix']}:{heartbeat_prefix}:{meta['name']}") # type: ignore rows.append(meta) if len(rows) == 0: @@ -523,9 +543,9 @@ def get_strategy_mates(redis_url=None, connection_pool=None, key_pattern="Weight return pd.DataFrame() df = pd.DataFrame(rows) - df['update_time'] = pd.to_datetime(df['update_time']) - df['heartbeat_time'] = pd.to_datetime(df['heartbeat_time']) - df = df.sort_values('name').reset_index(drop=True) + df["update_time"] = pd.to_datetime(df["update_time"]) + df["heartbeat_time"] = pd.to_datetime(df["heartbeat_time"]) + df = df.sort_values("name").reset_index(drop=True) r.close() return df @@ -551,18 +571,20 @@ def get_heartbeat_time(strategy_name=None, redis_url=None, connection_pool=None, r = redis.Redis.from_url(redis_url, decode_responses=True) if not strategy_name: - dfm = get_strategy_mates(redis_url=redis_url, connection_pool=connection_pool, key_pattern=f"{key_prefix}:META:*") + dfm = 
get_strategy_mates( + redis_url=redis_url, connection_pool=connection_pool, key_pattern=f"{key_prefix}:META:*" + ) if len(dfm) == 0: logger.warning(f"{key_prefix} 下没有策略元数据") return None - strategy_names = dfm['name'].unique().tolist() + strategy_names = dfm["name"].unique().tolist() else: strategy_names = [strategy_name] heartbeat_prefix = kwargs.get("heartbeat_prefix", "heartbeat") res = {} for sn in strategy_names: - hdt = r.get(f'{key_prefix}:{heartbeat_prefix}:{sn}') + hdt = r.get(f"{key_prefix}:{heartbeat_prefix}:{sn}") if hdt: res[sn] = pd.to_datetime(hdt) else: diff --git a/czsc/utils/data_client.py b/czsc/utils/data_client.py index 3d1138a34..86b2178c5 100644 --- a/czsc/utils/data_client.py +++ b/czsc/utils/data_client.py @@ -1,20 +1,21 @@ import os import shutil +import loguru import hashlib import requests import pandas as pd from time import time from pathlib import Path -from loguru import logger from functools import partial -def set_url_token(token, url): +def set_url_token(token, url, **kwargs): """设置指定 URL 数据接口的凭证码,通常一台机器只需要设置一次即可 :param token: 凭证码 :param url: 数据接口地址 """ + logger = kwargs.get("logger", loguru.logger) hash_key = hashlib.md5(str(url).encode("utf-8")).hexdigest() file_token = Path("~").expanduser() / f"{hash_key}.txt" with open(file_token, "w", encoding="utf-8") as f: @@ -22,8 +23,9 @@ def set_url_token(token, url): logger.info(f"{url} 数据访问凭证码已保存到 {file_token}") -def get_url_token(url): +def get_url_token(url, **kwargs): """获取指定 URL 数据接口的凭证码""" + logger = kwargs.get("logger", loguru.logger) hash_key = hashlib.md5(str(url).encode("utf-8")).hexdigest() file_token = Path("~").expanduser() / f"{hash_key}.txt" if file_token.exists(): @@ -62,15 +64,17 @@ def __init__(self, token=None, url="http://api.tushare.pro", timeout=300, **kwar assert self.__token, "请设置czsc_token凭证码,如果没有请联系管理员申请" self.cache_path = Path(kwargs.get("cache_path", os.path.expanduser("~/.quant_data_cache"))) self.cache_path.mkdir(exist_ok=True, parents=True) - + logger = 
import inspect
import pandas as pd


def ARB001(df: pd.DataFrame, **kwargs):
    """Sample futures arbitrage factor.

    ARB stands for arbitrage. The factor watches the price ratio of soybean oil
    (DLy9001) to palm oil (DLp9001):

    - entry: the ratio is above ``entry_ratio`` and below its moving average;
      soybean oil gets weight 1 and palm oil gets weight -1;
    - hold: the weights are carried forward on the following rows;
    - exit: the ratio drops below ``exit_ratio``; both weights return to 0.

    The result is written to the column ``F#ARB001#{tag}``. ``df`` is modified in
    place and also returned.

    :param df: pd.DataFrame, quotes of DLy9001 and DLp9001 with at least the
        columns dt, symbol and close, for example:

        =================== ======== =======
        dt                  symbol   close
        =================== ======== =======
        2017-01-03 00:00:00 DLy9001  5975.07
        2017-01-04 00:00:00 DLy9001  5914.89
        2017-01-05 00:00:00 DLy9001  5909.73
        2017-01-03 00:00:00 DLp9001  5975.07
        2017-01-04 00:00:00 DLp9001  5914.89
        2017-01-05 00:00:00 DLp9001  5909.73
        =================== ======== =======

    :param kwargs: dict, other options

        - window: int, moving-average window, defaults to 20
        - tag: str, factor tag, defaults to "DEFAULT"
        - entry_ratio: float, ratio above which a position may be opened, defaults to 1.2
        - exit_ratio: float, ratio below which the position is closed, defaults to 1.15

    :return: pd.DataFrame, ``df`` with the factor column added
    """
    window = kwargs.get("window", 20)
    tag = kwargs.get("tag", "DEFAULT")
    entry_ratio = kwargs.get("entry_ratio", 1.2)
    exit_ratio = kwargs.get("exit_ratio", 1.15)

    # The factor column is named after this function.
    factor_name = inspect.currentframe().f_code.co_name
    factor_col = f"F#{factor_name}#{tag}"

    # One row per dt, one close column per symbol.
    dfp = pd.pivot_table(df, index="dt", columns="symbol", values="close")
    dfp["ratio"] = dfp["DLy9001"] / dfp["DLp9001"]
    dfp["ratio_ma"] = dfp["ratio"].rolling(window).mean()
    dfp["ratio_diff"] = dfp["ratio"] - dfp["ratio_ma"]

    # Walk the ratio series as a small state machine so an opened position is
    # held until the exit threshold is hit. NaN values compare False on both
    # tests and therefore leave the current state untouched.
    holding = False
    y_weights = []
    for ratio, diff in zip(dfp["ratio"], dfp["ratio_diff"]):
        if ratio > entry_ratio and diff < 0:
            holding = True
        if ratio < exit_ratio:
            holding = False
        y_weights.append(1 if holding else 0)
    dfp["y_weight"] = y_weights
    dfp["p_weight"] = [-w for w in y_weights]

    # Write the weights back per symbol. to_numpy() makes the assignment
    # positional, so duplicate index labels in df cannot break alignment.
    for symbol, weight_col in (("DLy9001", "y_weight"), ("DLp9001", "p_weight")):
        mask = df["symbol"] == symbol
        df.loc[mask, factor_col] = df.loc[mask, "dt"].map(dfp[weight_col]).to_numpy()
    return df


def main():
    """Fetch both contracts, compute the ARB001 factor and show its weight backtest."""
    # Imported here so that ARB001 can be imported without streamlit or czsc.
    import czsc
    import streamlit as st
    from czsc.connectors import cooperation as coo

    # Kept as the first Streamlit call of the script run.
    st.set_page_config(layout="wide")

    df1 = coo.get_raw_bars(symbol="DLy9001", freq="日线", sdt="20170101", edt="20221231", raw_bars=False, fq="后复权")
    df2 = coo.get_raw_bars(symbol="DLp9001", freq="日线", sdt="20170101", edt="20221231", raw_bars=False, fq="后复权")
    df = pd.concat([df1, df2], axis=0, ignore_index=True)
    df = ARB001(df)
    factor = [x for x in df.columns if x.startswith("F#ARB001")][0]
    df["weight"] = df[factor].fillna(0)
    df["price"] = df["close"]

    dfw = df[["dt", "symbol", "price", "weight"]].copy()
    st.title("期货套利研究")
    czsc.show_weight_backtest(
        dfw, fee_rate=0.0002, show_drawdowns=True, show_yearly_stats=True, show_monthly_return=True
    )


if __name__ == "__main__":
    main()
    # How to launch: streamlit run examples/期货套利因子样例.py --theme.base=dark