diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 9c31e2f66..fb8c75f26 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -53,10 +53,12 @@ jobs: uses: actions/checkout@v3 - name: Get Image Tag Name + env: + GITHUB_REF_NAME_ENV: ${{ github.ref_name }} run: | REGEX="(.*)v(.*)\.(.*)\.(.*)" IMAGE_TAG="nightly" - if [[ "${{ github.ref_name }}" =~ $REGEX ]]; then + if [[ "${GITHUB_REF_NAME_ENV}" =~ $REGEX ]]; then IMAGE_TAG="${GITHUB_REF_NAME##*/}" fi echo "IMAGE_TAG=$IMAGE_TAG" >> $GITHUB_ENV diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6e0831d66..1192bd0d7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -47,6 +47,7 @@ jobs: MILABENCH_ARGS: "" MILABENCH_GPU_ARCH: "${{ matrix.arch }}" MILABENCH_DASH: "no" + MILABENCH_EXCLUDE: "${{ matrix.exclude }}" steps: - uses: actions/checkout@v3 @@ -60,7 +61,7 @@ jobs: - name: Pytorch Sanity run: | - if [[ "${{ matrix.arch }}" == "rocm" ]]; then + if [[ "${MILABENCH_GPU_ARCH}" == "rocm" ]]; then groups /opt/rocm/bin/rocminfo fi @@ -96,16 +97,16 @@ jobs: - name: install benchmarks run: | - milabench install --exclude "${{ matrix.exclude }}" + milabench install --exclude "${MILABENCH_EXCLUDE}" - name: prepare benchmarks run: | - milabench prepare --exclude "${{ matrix.exclude }}" + milabench prepare --exclude "${MILABENCH_EXCLUDE}" - name: run benchmarks run: | export PATH="/opt/rocm/bin:$PATH" - milabench run --validations all --exclude "${{ matrix.exclude }}" + milabench run --validations all --exclude "${MILABENCH_EXCLUDE}" - name: Summary run: | diff --git a/benchmarks/dlrm/voirfile.py b/benchmarks/dlrm/voirfile.py index 7a489ecaa..e18491fe3 100644 --- a/benchmarks/dlrm/voirfile.py +++ b/benchmarks/dlrm/voirfile.py @@ -47,12 +47,7 @@ def instrument_main(ov, options: Config): yield ov.phases.load_script # Loss - ( - ov.probe("//run > L") - .throttle(1)["L"] - .map(float) - .give("loss") - ) + (ov.probe("//run > L").throttle(1)["L"].map(float).give("loss")) # Compute Start & End + Batch ov.probe( diff --git a/benchmarks/flops/benchfile.py b/benchmarks/flops/benchfile.py index b00415f0f..0bb601d67 100644 --- a/benchmarks/flops/benchfile.py +++ b/benchmarks/flops/benchfile.py @@ -5,15 +5,15 @@ class FlopsBenchmarch(Package): base_requirements = "requirements.in" prepare_script = "prepare.py" main_script = "main.py" - + def build_run_plan(self) -> "execs.Executor": import milabench.executors as execs - + main = self.dirs.code / self.main_script pack = execs.PackExecutor(self, *self.argv, lazy=True) # pack = execs.VoirExecutor(pack, cwd=main.parent) pack = execs.ActivatorExecutor(pack, use_stdout=True) return pack - + __pack__ = FlopsBenchmarch diff --git a/benchmarks/flops/main.py b/benchmarks/flops/main.py index d72bf7186..5d2aa20cb 100755 --- a/benchmarks/flops/main.py +++ b/benchmarks/flops/main.py @@ -22,34 +22,37 @@ def _worker(state, queue, func, delay): import time - while state['running']: + while state["running"]: queue.put(func()) time.sleep(delay) - + + class Monitor: def __init__(self, delay, func): self.manager = multiprocessing.Manager() self.state = self.manager.dict() - self.state['running'] = True + self.state["running"] = True self.results = multiprocessing.Queue() self.process = multiprocessing.Process( - target=_worker, + target=_worker, args=(self.state, self.results, func, delay), ) - + def start(self): self.process.start() - + def stop(self): - self.state['running'] = False + self.state["running"] = False self.process.join() -def modelflops(model: torch.nn.Module, shape, repeat=10, dtype=torch.float32, unit=TERA): +def modelflops( + model: torch.nn.Module, shape, repeat=10, dtype=torch.float32, unit=TERA +): # Not sure how much thop is correct in its computation # it says it return MAC but I feel its methods is wrong from thop import profile - + # MAC: Multiply–accumulate operation batch = torch.randn(*shape, dtype=dtype, device="cuda:0") @@ -77,7 +80,6 @@ def modelflops(model: torch.nn.Module, shape, repeat=10, dtype=torch.float32, un return (flops * repeat) / (end - start) / unit - def f(N, R=30, m=5000000, n=256, unit=TERA, dtype=torch.float32, log=None): torch.cuda.empty_cache() a = torch.eye(n, dtype=dtype, device="cuda:0") @@ -85,26 +87,22 @@ def f(N, R=30, m=5000000, n=256, unit=TERA, dtype=torch.float32, log=None): y = torch.zeros_like(x) F = N * (2 * m * n * n + 2 * m * n * n) - - for i in range(R): + + for i in range(R): torch.cuda.synchronize() ts = -time.time() - + for _ in range(N): # No allocation in main loop using dual-out strategy y = torch.mm(x, a, out=y) x = torch.mm(y, a, out=x) - + torch.cuda.synchronize() ts += time.time() - + if log is not None: - log({ - "task": "train", - "rate": F / ts / unit, - "units": "Tflops" - }) - + log({"task": "train", "rate": F / ts / unit, "units": "Tflops"}) + torch.cuda.empty_cache() @@ -112,73 +110,62 @@ def setupvoir(): # wtf this do data_file = SmuggleWriter(sys.stdout) # data_file = sys.stdout - + def log(data): if data_file is not None: data["t"] = time.time() print(json.dumps(data), file=data_file) - + while not monitor.results.empty(): print(json.dumps(monitor.results.get()), file=data_file) - + def monitor_fn(): data = { gpu["device"]: { "memory": [ - gpu["memory"]["used"], + gpu["memory"]["used"], gpu["memory"]["total"], ], "load": gpu["utilization"]["compute"], "temperature": gpu["temperature"], - "power": gpu["power"] + "power": gpu["power"], } for gpu in get_gpu_info()["gpus"].values() } return {"task": "main", "gpudata": data, "t": time.time()} - + monitor = Monitor(0.5, monitor_fn) monitor.start() return log, monitor - def main(): dtypes = { - 'bf16': torch.bfloat16, - 'fp16': torch.float16, - 'fp32': torch.float32, + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, } - + parser = ArgumentParser() - parser.add_argument('--repeat', type=int, default=100) - parser.add_argument('--number', type=int, default=100) - parser.add_argument('--m', type=int, default=256) - parser.add_argument('--n', type=int, default=256) - parser.add_argument('--dtype', type=str, default='fp32', choices=dtypes.keys()) - parser.add_argument('--tf32', action='store_true', default=False) - + parser.add_argument("--repeat", type=int, default=100) + parser.add_argument("--number", type=int, default=100) + parser.add_argument("--m", type=int, default=256) + parser.add_argument("--n", type=int, default=256) + parser.add_argument("--dtype", type=str, default="fp32", choices=dtypes.keys()) + parser.add_argument("--tf32", action="store_true", default=False) + args = parser.parse_args() torch.backends.cuda.matmul.allow_tf32 = False if args.tf32: torch.backends.cuda.matmul.allow_tf32 = True - + log, monitor = setupvoir() - f( - args.number, - args.repeat, - args.m, - args.n, - TERA, - dtypes[args.dtype], - log - ) + f(args.number, args.repeat, args.m, args.n, TERA, dtypes[args.dtype], log) monitor.stop() - -if __name__ == "__main__": - main() - +if __name__ == "__main__": + main() diff --git a/benchmarks/llama/benchfile.py b/benchmarks/llama/benchfile.py index 2213b1657..8b253bc92 100644 --- a/benchmarks/llama/benchfile.py +++ b/benchmarks/llama/benchfile.py @@ -11,12 +11,12 @@ class LLAMA(Package): def make_env(self): return { **super().make_env(), - "OMP_NUM_THREADS": str(self.config.get("cpus_per_gpu", 8)) + "OMP_NUM_THREADS": str(self.config.get("cpus_per_gpu", 8)), } - + async def install(self): await super().install() - + def build_prepare_plan(self): return CmdExecutor( self, @@ -36,7 +36,8 @@ def build_run_plan(self): *self.argv, "--cache", str(self.dirs.cache), - use_stdout=True + use_stdout=True, ) + __pack__ = LLAMA diff --git a/benchmarks/llama/main.py b/benchmarks/llama/main.py index ddf76243c..5bb20164e 100755 --- a/benchmarks/llama/main.py +++ b/benchmarks/llama/main.py @@ -19,37 +19,38 @@ def available_models(): models = dict() for size in ("7b", "13b", "70b"): - models[f'llama2-{size}'] = { + models[f"llama2-{size}"] = { "name": f"meta-llama/Llama-2-{size}-chat-hf", - "config": f"llama2_{size}_chat_hf.config" + "config": f"llama2_{size}_chat_hf.config", } - + return models def _worker(state, queue, func, delay): import time - while state['running']: + while state["running"]: queue.put(func()) time.sleep(delay) - + + class Monitor: def __init__(self, delay, func): self.manager = multiprocessing.Manager() self.state = self.manager.dict() - self.state['running'] = True + self.state["running"] = True self.results = multiprocessing.Queue() self.process = multiprocessing.Process( - target=_worker, + target=_worker, args=(self.state, self.results, func, delay), ) - + def start(self): self.process.start() - + def stop(self): - self.state['running'] = False + self.state["running"] = False self.process.join() @@ -57,30 +58,30 @@ def setupvoir(): # wtf this do data_file = SmuggleWriter(sys.stdout) # data_file = sys.stdout - + def log(data): if data_file is not None: data["t"] = time.time() print(json.dumps(data), file=data_file) - + while not monitor.results.empty(): print(json.dumps(monitor.results.get()), file=data_file) - + def monitor_fn(): data = { gpu["device"]: { "memory": [ - gpu["memory"]["used"], + gpu["memory"]["used"], gpu["memory"]["total"], ], "load": gpu["utilization"]["compute"], "temperature": gpu["temperature"], - "power": gpu["power"] + "power": gpu["power"], } for gpu in get_gpu_info()["gpus"].values() } return {"task": "main", "gpudata": data, "t": time.time()} - + monitor = Monitor(0.5, monitor_fn) monitor.start() return log, monitor @@ -93,11 +94,11 @@ def __init__(self, tokenizer): def __call__(self, *args, **kwargs): input_ids = self.tokenizer(*args, **kwargs) - + self.count = 1 for c in input_ids["input_ids"].shape: self.count *= c - + return input_ids def __getattr__(self, attr): @@ -105,7 +106,9 @@ def __getattr__(self, attr): method = getattr(self.tokenizer, attr) return method else: - raise AttributeError(f"'{type(self.tokenizer).__name__}' object has no attribute '{attr}'") + raise AttributeError( + f"'{type(self.tokenizer).__name__}' object has no attribute '{attr}'" + ) def println(*args, **kwargs): @@ -119,14 +122,11 @@ def huggingface_main(args, model, config): from transformers.models.llama.configuration_llama import LlamaConfig from datasets import load_dataset - + # Dataset here println("Dataset") - dataset = load_dataset( - "wikitext", - "wikitext-103-v1" - ) - + dataset = load_dataset("wikitext", "wikitext-103-v1") + println("Tokenizer") # LLAMA tokenizer official tokenizer is hidden behind a login tokenizer = WrappedTokenizer( @@ -136,12 +136,12 @@ def huggingface_main(args, model, config): # Prepare is done if args.prepare: return 0 - + # We do not download LLAMA because it takes too long # we just instantiate an untrained one println("Model") model = LlamaForCausalLM(LlamaConfig.from_dict(config)).cuda() - + println("Pipeline") pipeline = transformers.pipeline( "text-generation", @@ -149,25 +149,25 @@ def huggingface_main(args, model, config): torch_dtype=torch.float16, # device_map="cuda", tokenizer=tokenizer, - device=torch.device("cuda") + device=torch.device("cuda"), ) - + in_token_count = 0 out_token_count = 0 - + start = time.time() - + log, monitor = setupvoir() - + println("Starting") count = 0 for entry in dataset["train"]: text = entry["text"].strip() - + # Titles if text == "" or text.startswith(" = ") or len(text) < 10: continue - + count += 1 sequences = pipeline( text, @@ -177,58 +177,55 @@ def huggingface_main(args, model, config): eos_token_id=tokenizer.eos_token_id, max_length=400, ) - + for seq in sequences: out_token_count += len(seq["generated_text"]) in_token_count += tokenizer.count total = out_token_count + in_token_count - + elapsed = time.time() - start - println(f"{elapsed =}, {total / elapsed =} {in_token_count =} {out_token_count =}") - + println( + f"{elapsed =}, {total / elapsed =} {in_token_count =} {out_token_count =}" + ) + if total > 30: out_token_count = 0 in_token_count = 0 start = time.time() - + if log is not None: - log({ - "task": "train", - "rate": total / elapsed, - "units": "Tok/s" - }) - + log({"task": "train", "rate": total / elapsed, "units": "Tok/s"}) + if count > 40: break - monitor.stop() + def main(): import torch - + models = available_models() - + parser = argparse.ArgumentParser() parser.add_argument("--model", default="llama2-7b", choices=models.keys()) parser.add_argument("--prepare", action="store_true") parser.add_argument("--cache", required=True, type=str) - + # args = parser.parse_args() os.environ["XDG_CACHE_HOME"] = str(args.cache) - + settings = models[args.model] model, config = settings["name"], settings["config"] - - with open(os.path.join(root, 'config', config), 'r') as file: + + with open(os.path.join(root, "config", config), "r") as file: config = json.load(file) with torch.no_grad(): return huggingface_main(args, model, config) - if __name__ == "__main__": main() diff --git a/benchmarks/rwkv/prepare.py b/benchmarks/rwkv/prepare.py index 1e51bb2b1..992e6c099 100755 --- a/benchmarks/rwkv/prepare.py +++ b/benchmarks/rwkv/prepare.py @@ -24,9 +24,7 @@ print("This will compile the appropriate torch extensions.") print("=" * 80) result = subprocess.run( - ["voir", - "--no-dash", "--interval", "1", "--stop", "1", - "train.py", *argv] + ["voir", "--no-dash", "--interval", "1", "--stop", "1", "train.py", *argv] ) print("=" * 80) print("Done") diff --git a/benchmarks/rwkv/rwkv-v4neo/chat.py b/benchmarks/rwkv/rwkv-v4neo/chat.py index d214ba281..19e2b36f9 100644 --- a/benchmarks/rwkv/rwkv-v4neo/chat.py +++ b/benchmarks/rwkv/rwkv-v4neo/chat.py @@ -2,12 +2,13 @@ # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM ######################################################################################################## -print('Loading...') +print("Loading...") from src.model_run import RWKV_RNN import numpy as np import os, copy, types, gc, sys import torch from src.utils import TOKENIZER + try: os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] except: @@ -17,7 +18,7 @@ torch.backends.cuda.matmul.allow_tf32 = True np.set_printoptions(precision=4, suppress=True, linewidth=200) -CHAT_LANG = 'English' # English Chinese +CHAT_LANG = "English" # English Chinese WORD_NAME = [ "20B_tokenizer.json", @@ -28,14 +29,16 @@ args = types.SimpleNamespace() args.RUN_DEVICE = "cuda" # 'cpu' (already very fast) // 'cuda' -args.FLOAT_MODE = "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate) +args.FLOAT_MODE = ( + "fp16" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate) +) args.vocab_size = 50277 args.head_qk = 0 args.pre_ffn = 0 args.grad_cp = 0 args.my_pos_emb = 0 -args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230108-5170' +args.MODEL_NAME = "/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230108-5170" args.n_layer = 40 args.n_embd = 5120 args.ctx_len = 1024 @@ -50,7 +53,7 @@ # args.n_embd = 2560 # args.ctx_len = 1024 -if CHAT_LANG == 'English': +if CHAT_LANG == "English": user = "User" bot = "Bot" interface = ":" @@ -58,7 +61,7 @@ # The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. # The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth. - init_prompt = f''' + init_prompt = f""" The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. {user}{interface} french revolution what year @@ -81,8 +84,8 @@ {bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012. -''' - HELP_MSG = '''Commands: +""" + HELP_MSG = """Commands: say something --> chat with bot. use \\n for new line. +alt --> alternate chat reply +reset --> reset chat @@ -94,9 +97,9 @@ Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results. This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +gen for free generation. -''' -elif CHAT_LANG == 'Chinese': - args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run3z/rwkv-293' +""" +elif CHAT_LANG == "Chinese": + args.MODEL_NAME = "/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run3z/rwkv-293" args.n_layer = 32 args.n_embd = 4096 args.ctx_len = 1024 @@ -105,7 +108,7 @@ bot = "A" interface = ":" - init_prompt = ''' + init_prompt = """ Q: 企鹅会飞吗? A: 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。 @@ -114,8 +117,8 @@ A: 西瓜是一种常见的水果,是一种多年生蔓生藤本植物。西瓜的果实呈圆形或卵形,通常是绿色的,里面有红色或黄色的肉和很多的籽。西瓜味甜,多吃可以增加水分,是夏季非常受欢迎的水果之一。 -''' - HELP_MSG = '''指令: +""" + HELP_MSG = """指令: 直接输入内容 --> 和机器人聊天,用\\n代表换行 +alt --> 让机器人换个回答 +reset --> 重置对话 @@ -126,14 +129,14 @@ +retry --> 换个 +gen / +qa 的回答 现在可以输入内容和机器人聊天(注意它不怎么懂中文,它可能更懂英文)。请经常使用 +reset 重置机器人记忆。 -''' +""" # Load Model os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE MODEL_NAME = args.MODEL_NAME -print(f'loading... {MODEL_NAME}') +print(f"loading... {MODEL_NAME}") model = RWKV_RNN(args) model_tokens = [] @@ -142,15 +145,18 @@ ######################################################################################################## -def run_rnn(tokens, newline_adj = 0): + +def run_rnn(tokens, newline_adj=0): global model_tokens, current_state for i in range(len(tokens)): model_tokens += [int(tokens[i])] if i == len(tokens) - 1: out, current_state = model.forward(model_tokens, current_state) else: - current_state = model.forward(model_tokens, current_state, preprocess_only = True) - + current_state = model.forward( + model_tokens, current_state, preprocess_only=True + ) + # print(f'### model ###\n[{tokenizer.tokenizer.decode(model_tokens)}]') out[0] = -999999999 # disable <|endoftext|> @@ -159,60 +165,67 @@ def run_rnn(tokens, newline_adj = 0): # out[15] += newline_adj / 2 # '.' return out + all_state = {} + + def save_all_stat(srv, name, last_out): - n = f'{name}_{srv}' + n = f"{name}_{srv}" all_state[n] = {} - all_state[n]['out'] = last_out - all_state[n]['rnn'] = copy.deepcopy(current_state) - all_state[n]['token'] = copy.deepcopy(model_tokens) + all_state[n]["out"] = last_out + all_state[n]["rnn"] = copy.deepcopy(current_state) + all_state[n]["token"] = copy.deepcopy(model_tokens) + def load_all_stat(srv, name): global model_tokens, current_state - n = f'{name}_{srv}' - current_state = copy.deepcopy(all_state[n]['rnn']) - model_tokens = copy.deepcopy(all_state[n]['token']) - return all_state[n]['out'] + n = f"{name}_{srv}" + current_state = copy.deepcopy(all_state[n]["rnn"]) + model_tokens = copy.deepcopy(all_state[n]["token"]) + return all_state[n]["out"] + ######################################################################################################## # Run inference -print(f'\nRun prompt...') +print(f"\nRun prompt...") out = run_rnn(tokenizer.tokenizer.encode(init_prompt)) gc.collect() torch.cuda.empty_cache() -save_all_stat('', 'chat_init', out) +save_all_stat("", "chat_init", out) -srv_list = ['dummy_server'] +srv_list = ["dummy_server"] for s in srv_list: - save_all_stat(s, 'chat', out) + save_all_stat(s, "chat", out) + +print(f"### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n") -print(f'### prompt ###\n[{tokenizer.tokenizer.decode(model_tokens)}]\n') def reply_msg(msg): - print(f'{bot}{interface} {msg}\n') + print(f"{bot}{interface} {msg}\n") + def on_message(message): global model_tokens, current_state - srv = 'dummy_server' + srv = "dummy_server" - msg = message.replace('\\n','\n').strip() + msg = message.replace("\\n", "\n").strip() if len(msg) > 1000: - reply_msg('your message is too long (max 1000 tokens)') + reply_msg("your message is too long (max 1000 tokens)") return x_temp = 1.0 x_top_p = 0.85 - if ("-temp=" in msg): + if "-temp=" in msg: x_temp = float(msg.split("-temp=")[1].split(" ")[0]) - msg = msg.replace("-temp="+f'{x_temp:g}', "") + msg = msg.replace("-temp=" + f"{x_temp:g}", "") # print(f"temp: {x_temp}") - if ("-top_p=" in msg): + if "-top_p=" in msg: x_top_p = float(msg.split("-top_p=")[1].split(" ")[0]) - msg = msg.replace("-top_p="+f'{x_top_p:g}', "") + msg = msg.replace("-top_p=" + f"{x_top_p:g}", "") # print(f"top_p: {x_top_p}") if x_temp <= 0.2: x_temp = 0.2 @@ -220,31 +233,35 @@ def on_message(message): x_temp = 5 if x_top_p <= 0: x_top_p = 0 - - if msg == '+reset': - out = load_all_stat('', 'chat_init') - save_all_stat(srv, 'chat', out) + + if msg == "+reset": + out = load_all_stat("", "chat_init") + save_all_stat(srv, "chat", out) reply_msg("Chat reset.") return - elif msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or msg.lower() == '+more' or msg.lower() == '+retry': - - if msg[:5].lower() == '+gen ': - new = '\n' + msg[5:].strip() + elif ( + msg[:5].lower() == "+gen " + or msg[:4].lower() == "+qa " + or msg.lower() == "+more" + or msg.lower() == "+retry" + ): + if msg[:5].lower() == "+gen ": + new = "\n" + msg[5:].strip() # print(f'### prompt ###\n[{new}]') current_state = None out = run_rnn(tokenizer.tokenizer.encode(new)) - save_all_stat(srv, 'gen_0', out) + save_all_stat(srv, "gen_0", out) - elif msg[:4].lower() == '+qa ': - out = load_all_stat('', 'chat_init') + elif msg[:4].lower() == "+qa ": + out = load_all_stat("", "chat_init") real_msg = msg[4:].strip() new = f"{user}{interface} {real_msg}\n\n{bot}{interface}" # print(f'### qa ###\n[{new}]') - + out = run_rnn(tokenizer.tokenizer.encode(new)) - save_all_stat(srv, 'gen_0', out) + save_all_stat(srv, "gen_0", out) # new = f"\nThe following is an excellent Q&A session consists of detailed and factual information.\n\nQ: What is 3+5?\nA: The answer is 8.\n\nQ: {msg[9:].strip()}\nA:" # print(f'### prompt ###\n[{new}]') @@ -252,16 +269,16 @@ def on_message(message): # out = run_rnn(tokenizer.tokenizer.encode(new)) # save_all_stat(srv, 'gen_0', out) - elif msg.lower() == '+more': + elif msg.lower() == "+more": try: - out = load_all_stat(srv, 'gen_1') - save_all_stat(srv, 'gen_0', out) + out = load_all_stat(srv, "gen_1") + save_all_stat(srv, "gen_0", out) except: return - elif msg.lower() == '+retry': + elif msg.lower() == "+retry": try: - out = load_all_stat(srv, 'gen_0') + out = load_all_stat(srv, "gen_0") except: return @@ -276,37 +293,37 @@ def on_message(message): top_p_usual=x_top_p, top_p_newline=x_top_p, ) - if msg[:4].lower() == '+qa ': + if msg[:4].lower() == "+qa ": out = run_rnn([token], newline_adj=-1) else: out = run_rnn([token]) - + xxx = tokenizer.tokenizer.decode(model_tokens[out_last:]) - if '\ufffd' not in xxx: - print(xxx, end='', flush=True) + if "\ufffd" not in xxx: + print(xxx, end="", flush=True) out_last = begin + i + 1 - print('\n') + print("\n") # send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() # print(f'### send ###\n[{send_msg}]') # reply_msg(send_msg) - save_all_stat(srv, 'gen_1', out) + save_all_stat(srv, "gen_1", out) else: - if msg.lower() == '+alt': + if msg.lower() == "+alt": try: - out = load_all_stat(srv, 'chat_pre') + out = load_all_stat(srv, "chat_pre") except: return else: - out = load_all_stat(srv, 'chat') + out = load_all_stat(srv, "chat") new = f"{user}{interface} {msg}\n\n{bot}{interface}" # print(f'### add ###\n[{new}]') out = run_rnn(tokenizer.tokenizer.encode(new), newline_adj=-999999999) - save_all_stat(srv, 'chat_pre', out) + save_all_stat(srv, "chat_pre", out) begin = len(model_tokens) out_last = begin - print(f'{bot}{interface}', end='', flush=True) + print(f"{bot}{interface}", end="", flush=True) for i in range(999): if i <= 0: newline_adj = -999999999 @@ -315,7 +332,7 @@ def on_message(message): elif i <= 130: newline_adj = 0 else: - newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION + newline_adj = (i - 130) * 0.25 # MUST END THE GENERATION token = tokenizer.sample_logits( out, model_tokens, @@ -327,15 +344,15 @@ def on_message(message): out = run_rnn([token], newline_adj=newline_adj) xxx = tokenizer.tokenizer.decode(model_tokens[out_last:]) - if '\ufffd' not in xxx: - print(xxx, end='', flush=True) + if "\ufffd" not in xxx: + print(xxx, end="", flush=True) out_last = begin + i + 1 - + send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]) - if '\n\n' in send_msg: + if "\n\n" in send_msg: send_msg = send_msg.strip() break - + # send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() # if send_msg.endswith(f'{user}{interface}'): # warning: needs to fix state too !!! # send_msg = send_msg[:-len(f'{user}{interface}')].strip() @@ -349,13 +366,14 @@ def on_message(message): # print(f'### send ###\n[{send_msg}]') # reply_msg(send_msg) - save_all_stat(srv, 'chat', out) + save_all_stat(srv, "chat", out) + print(HELP_MSG) while True: - msg = input(f'{user}{interface} ') + msg = input(f"{user}{interface} ") if len(msg.strip()) > 0: on_message(msg) else: - print('Erorr: please say something') + print("Erorr: please say something") diff --git a/benchmarks/rwkv/rwkv-v4neo/img_demoAE.py b/benchmarks/rwkv/rwkv-v4neo/img_demoAE.py index ab0d4edd6..43c0c3cf3 100644 --- a/benchmarks/rwkv/rwkv-v4neo/img_demoAE.py +++ b/benchmarks/rwkv/rwkv-v4neo/img_demoAE.py @@ -9,55 +9,58 @@ from torch.nn import functional as F import torchvision as vision import torchvision.transforms as transforms + np.set_printoptions(precision=4, suppress=True, linewidth=200) -print(f'loading...') +print(f"loading...") ######################################################################################################## -model_prefix = 'test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201' -input_img = 'test/img_ae_test/test0.png' +model_prefix = "test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201" +input_img = "test/img_ae_test/test0.png" ######################################################################################################## + class ToBinary(torch.autograd.Function): @staticmethod def forward(ctx, x): - return torch.floor(x + 0.5) # no need for noise when we have plenty of data + return torch.floor(x + 0.5) # no need for noise when we have plenty of data @staticmethod def backward(ctx, grad_output): - return grad_output.clone() # pass-through + return grad_output.clone() # pass-through + class R_ENCODER(nn.Module): def __init__(self, args): super().__init__() self.args = args dd = 8 - self.Bxx = nn.BatchNorm2d(dd*64) + self.Bxx = nn.BatchNorm2d(dd * 64) self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1) self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1) self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1) - self.B00 = nn.BatchNorm2d(dd*4) - self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) - self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) + self.B00 = nn.BatchNorm2d(dd * 4) + self.C00 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C01 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) + self.C02 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C03 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) - self.B10 = nn.BatchNorm2d(dd*16) - self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) + self.B10 = nn.BatchNorm2d(dd * 16) + self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) - self.B20 = nn.BatchNorm2d(dd*64) - self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) + self.B20 = nn.BatchNorm2d(dd * 64) + self.C20 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C21 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) + self.C22 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C23 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) - self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1) + self.COUT = nn.Conv2d(dd * 64, args.my_img_bit, kernel_size=3, padding=1) def forward(self, img): ACT = F.mish @@ -81,30 +84,31 @@ def forward(self, img): x = self.COUT(x + xx) return torch.sigmoid(x) + class R_DECODER(nn.Module): def __init__(self, args): super().__init__() self.args = args dd = 8 - self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1) - - self.B00 = nn.BatchNorm2d(dd*64) - self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - - self.B10 = nn.BatchNorm2d(dd*16) - self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - - self.B20 = nn.BatchNorm2d(dd*4) - self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) - self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) + self.CIN = nn.Conv2d(args.my_img_bit, dd * 64, kernel_size=3, padding=1) + + self.B00 = nn.BatchNorm2d(dd * 64) + self.C00 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C01 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) + self.C02 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C03 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) + + self.B10 = nn.BatchNorm2d(dd * 16) + self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + + self.B20 = nn.BatchNorm2d(dd * 4) + self.C20 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C21 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) + self.C22 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C23 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1) self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1) @@ -128,30 +132,33 @@ def forward(self, code): x = x + self.Cx1(ACT(self.Cx0(x))) x = self.COUT(x) - + return torch.sigmoid(x) + ######################################################################################################## -print(f'building model...') +print(f"building model...") args = types.SimpleNamespace() args.my_img_bit = 13 encoder = R_ENCODER(args).eval().cuda() decoder = R_DECODER(args).eval().cuda() -zpow = torch.tensor([2**i for i in range(0,13)]).reshape(13,1,1).cuda().long() +zpow = torch.tensor([2**i for i in range(0, 13)]).reshape(13, 1, 1).cuda().long() -encoder.load_state_dict(torch.load(f'{model_prefix}-E.pth')) -decoder.load_state_dict(torch.load(f'{model_prefix}-D.pth')) +encoder.load_state_dict(torch.load(f"{model_prefix}-E.pth")) +decoder.load_state_dict(torch.load(f"{model_prefix}-D.pth")) ######################################################################################################## -print(f'test image...') -img_transform = transforms.Compose([ - transforms.PILToTensor(), - transforms.ConvertImageDtype(torch.float), - transforms.Resize((224, 224)) -]) +print(f"test image...") +img_transform = transforms.Compose( + [ + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Resize((224, 224)), + ] +) with torch.no_grad(): img = img_transform(Image.open(input_img)).unsqueeze(0).cuda() @@ -159,7 +166,7 @@ def forward(self, code): z = ToBinary.apply(z) zz = torch.sum(z.squeeze().long() * zpow, dim=0) - print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n') - + print(f"Code shape = {zz.shape}\n{zz.cpu().numpy()}\n") + out = decoder(z) vision.utils.save_image(out, f"{input_img.split('.')[0]}-out-13bit.jpg") diff --git a/benchmarks/rwkv/rwkv-v4neo/run.py b/benchmarks/rwkv/rwkv-v4neo/run.py index f13e97f08..eb7109cb6 100644 --- a/benchmarks/rwkv/rwkv-v4neo/run.py +++ b/benchmarks/rwkv/rwkv-v4neo/run.py @@ -6,6 +6,7 @@ import math, os, sys, types, time, gc import torch from src.utils import TOKENIZER + try: os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] except: @@ -20,12 +21,14 @@ # Step 1: set model & config (use v4 to run your trained-from-scratch models. v4 and v4neo are compatible) ######################################################################################################## -args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast) -args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU) +args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast) +args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU) # if args.RUN_DEVICE == "cuda": # os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output -os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!! +os.environ[ + "RWKV_JIT_ON" +] = "1" # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!! TOKEN_MODE = "pile" WORD_NAME = [ @@ -58,7 +61,7 @@ # n_embd = 2560 # ctx_len = 1024 -MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047' +MODEL_NAME = "/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047" n_layer = 32 n_embd = 4096 ctx_len = 1024 @@ -129,12 +132,12 @@ ######################################################################################################## -print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...') +print(f"\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...") from src.model_run import RWKV_RNN model = RWKV_RNN(args) -print(f'\nOptimizing speed...') +print(f"\nOptimizing speed...") out, _ = model.forward([187], None) # print(out) gc.collect() @@ -142,10 +145,10 @@ # input(0) -print(f'\nLoading tokenizer {WORD_NAME}...') +print(f"\nLoading tokenizer {WORD_NAME}...") tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) if TOKEN_MODE == "pile": - assert tokenizer.tokenizer.decode([187]) == '\n' + assert tokenizer.tokenizer.decode([187]) == "\n" ######################################################################################################## @@ -165,6 +168,7 @@ time_slot = {} time_ref = time.time_ns() + def record_time(name): if name not in time_slot: time_slot[name] = 1e20 @@ -172,13 +176,14 @@ def record_time(name): if tt < time_slot[name]: time_slot[name] = tt + init_state = None init_out = None state = None out = None for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): - print(("-" * 50) + '\n' + context, end="") + print(("-" * 50) + "\n" + context, end="") time_ref = time.time_ns() ctx = src_ctx.copy() @@ -193,7 +198,7 @@ def record_time(name): gc.collect() torch.cuda.empty_cache() - record_time('preprocess') + record_time("preprocess") out_last = src_len for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): x = ctx[: i + 1] @@ -205,7 +210,14 @@ def record_time(name): else: out, state = model.forward(x, state) if DEBUG_DEBUG: - print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy())) + print( + "model", + np.array(x), + "==>", + np.array(out), + np.max(out.cpu().numpy()), + np.min(out.cpu().numpy()), + ) if TOKEN_MODE == "pile": out[0] = -999999999 # disable <|endoftext|> @@ -224,14 +236,15 @@ def record_time(name): print(char, end="", flush=True) else: char = tokenizer.tokenizer.decode(ctx[out_last:]) - if '\ufffd' not in char: # is valid utf8 string? + if "\ufffd" not in char: # is valid utf8 string? print(char, end="", flush=True) - out_last = i+1 + out_last = i + 1 - record_time('total') + record_time("total") # print(f'\n\n{time_slot}\n\n') print( - f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = '' + f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", + end="", ) -print(("-" * 50) + '\n') +print(("-" * 50) + "\n") diff --git a/benchmarks/rwkv/rwkv-v4neo/src/binidx.py b/benchmarks/rwkv/rwkv-v4neo/src/binidx.py index 369081ad4..8d5b40bfe 100644 --- a/benchmarks/rwkv/rwkv-v4neo/src/binidx.py +++ b/benchmarks/rwkv/rwkv-v4neo/src/binidx.py @@ -7,6 +7,7 @@ from functools import lru_cache from itertools import accumulate + def print_rank_0(*message): pass # """If distributed is initialized print only on rank 0.""" @@ -16,12 +17,14 @@ def print_rank_0(*message): # else: # print(*message, flush=True) + def _warmup_mmap_file(path): pass # with open(path, "rb") as stream: # while stream.read(100 * 1024 * 1024): # pass + dtypes = { 1: np.uint8, 2: np.int8, @@ -33,18 +36,22 @@ def _warmup_mmap_file(path): 8: np.uint16, } + def code(dtype): for k in dtypes.keys(): if dtypes[k] == dtype: return k raise ValueError(dtype) + def index_file_path(prefix_path): return prefix_path + ".idx" + def data_file_path(prefix_path): return prefix_path + ".bin" + class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b"MMIDIDX\x00\x00" @@ -100,7 +107,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._file.close() return _Writer() - + def __init__(self, path, skip_warmup=False): with open(path, "rb") as stream: magic_test = stream.read(9) @@ -217,8 +224,7 @@ def __getitem__(self, idx): elif isinstance(idx, slice): start, stop, step = idx.indices(len(self)) if step != 1: - raise ValueError( - "Slices into indexed_dataset must be contiguous") + raise ValueError("Slices into indexed_dataset must be contiguous") ptr = self._index._pointers[start] sizes = self._index._sizes[idx] offsets = list(accumulate(sizes)) diff --git a/benchmarks/rwkv/rwkv-v4neo/src/dataset.py b/benchmarks/rwkv/rwkv-v4neo/src/dataset.py index 71cbb1a57..6ddafb90b 100644 --- a/benchmarks/rwkv/rwkv-v4neo/src/dataset.py +++ b/benchmarks/rwkv/rwkv-v4neo/src/dataset.py @@ -17,15 +17,24 @@ def __init__(self, args): if args.data_type == "binidx": self.vocab_size = args.vocab_size - rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)") + rank_zero_info( + f"Current vocab size = {self.vocab_size} (make sure it's correct)" + ) if args.my_pile_version == 1: self.data = MMapIndexedDataset(args.data_file) - self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size + self.data_size = ( + len(self.data._bin_buffer) // self.data._index._dtype_size + ) rank_zero_info(f"Data has {self.data_size} tokens.") else: - data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n') - data_list = [i.strip().split(' ') for i in data_list] + data_list = ( + open(args.data_file, "r", encoding="utf-8") + .read() + .strip() + .split("\n") + ) + data_list = [i.strip().split(" ") for i in data_list] self.data = [] self.data_size = int(data_list[-1][-1]) rank_zero_info(f"Data has {self.data_size} chunks.") @@ -37,29 +46,46 @@ def __init__(self, args): # rank_zero_info(self.data) if args.my_qa_mask > 0: - self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document') - self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size + self.data_pile = MMapIndexedDataset( + "/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document" + ) + self.data_pile_size = ( + len(self.data_pile._bin_buffer) // self.data._index._dtype_size + ) if args.my_pile_stage > 0: # assert self.data_size == 332115325534 and self.vocab_size == 50277 self.samples_per_epoch = args.epoch_steps * args.real_bsz assert self.samples_per_epoch == 40320 - rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") + rank_zero_info( + f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########" + ) dataset_slot = self.data_size // args.ctx_len if args.my_pile_stage != 4: assert MaybeIsPrime(args.magic_prime) assert args.magic_prime % 3 == 2 - assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1 + assert ( + args.magic_prime / dataset_slot > 0.99 + and args.magic_prime / dataset_slot <= 1 + ) elif args.data_type == "numpy": self.data = np.load(args.data_file).astype("int") self.vocab_size = args.vocab_size - rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") + rank_zero_info( + "Current vocab size =", self.vocab_size, "(make sure it's correct)" + ) self.data_size = len(self.data) rank_zero_info(f"Data has {self.data_size} tokens.") elif args.data_type == "uint16": - self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len) + self.data = ( + np.fromfile(args.data_file, dtype=np.uint16) + .astype("int32") + .reshape(-1, args.my_sample_len) + ) self.vocab_size = args.vocab_size - rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") + rank_zero_info( + "Current vocab size =", self.vocab_size, "(make sure it's correct)" + ) self.data_size = self.data.shape[0] rank_zero_info(f"Data has {self.data_size} samples.") elif args.data_type == "wds_img": @@ -92,10 +118,14 @@ def __init__(self, args): for u in unique: xxObj[xx] = u xx += 1 - with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file: + with open( + f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le" + ) as vocab_file: vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) self.data_size = len(self.data) - rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.") + rank_zero_info( + f"Data has {self.data_size} tokens, {self.vocab_size} vocab size." + ) self.stoi = {ch: i for i, ch in enumerate(unique)} self.itos = {i: ch for i, ch in enumerate(unique)} @@ -110,36 +140,53 @@ def __getitem__(self, idx): # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}") if args.data_type == "wds_img": + def init_wds(self, bias=0): def identity(x): - return x + return x + import webdataset as wds import torchvision.transforms as transforms + # img_transform = transforms.Compose( # [transforms.CenterCrop(256)] # ) - img_transform = transforms.Compose([ - transforms.CenterCrop(512), - transforms.Resize((args.my_img_size)) - ]) - self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity) + img_transform = transforms.Compose( + [transforms.CenterCrop(512), transforms.Resize((args.my_img_size))] + ) + self.data_raw = ( + wds.WebDataset(args.data_file, resampled=True) + .shuffle( + 10000, + initial=1000, + rng=random.Random(epoch * 100000 + rank + bias * 1e9), + ) + .decode("torchrgb") + .to_tuple("jpg", "json", "txt") + .map_tuple(img_transform, identity, identity) + ) for pp in self.data_raw.pipeline: - if 'Resampled' in str(pp): + if "Resampled" in str(pp): pp.deterministic = True + def worker_seed(): - return rank*100000+epoch+bias*1e9 + return rank * 100000 + epoch + bias * 1e9 + pp.worker_seed = worker_seed self.data = iter(self.data_raw) # print(f"WebDataset loaded for rank {rank} epoch {epoch}") + if self.data == None: init_wds(self) trial = 0 while trial < 10: try: - dd = next(self.data) # jpg, json, txt + dd = next(self.data) # jpg, json, txt break except: - print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]') + print( + f"[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]" + ) self.error_count += 1 init_wds(self, self.error_count) trial += 1 @@ -150,7 +197,7 @@ def worker_seed(): return dd[0], dd[2] else: if args.data_type == "uint16": - i = np.random.randint(0, self.data_size-1) + i = np.random.randint(0, self.data_size - 1) dix = self.data[i] x = torch.tensor(dix[:-1], dtype=torch.long) y = torch.tensor(dix[1:], dtype=torch.long) @@ -203,8 +250,12 @@ def worker_seed(): for j in range(len(data)): if i < data[j][0]: ii = i - i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1] - dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int) + i = (i - (data[j - 1][0] if j > 0 else 0)) % data[j][1] + dix = ( + data[j][2] + .get(idx=0, offset=i, length=req_len) + .astype(int) + ) # print(ii, j, i) break elif args.data_type == "numpy": @@ -220,7 +271,12 @@ def worker_seed(): z_sum = 0 isGood = False for i in range(3, ctx_len): - if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187: + if ( + dix[i] == 27 + and dix[i - 1] == 34 + and dix[i - 2] == 187 + and dix[i - 3] == 187 + ): isGood = True if dix[i] == 0: isGood = False @@ -230,7 +286,9 @@ def worker_seed(): if z_sum == 0: z = [1] * ctx_len i = np.random.randint(0, self.data_pile_size - req_len) - dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int) + dix = self.data_pile.get( + idx=0, offset=i, length=req_len + ).astype(int) z = torch.tensor(z, dtype=torch.bfloat16) x = torch.tensor(dix[:-1], dtype=torch.long) diff --git a/benchmarks/rwkv/rwkv-v4neo/src/model.py b/benchmarks/rwkv/rwkv-v4neo/src/model.py index b79f96d26..0914c160e 100644 --- a/benchmarks/rwkv/rwkv-v4neo/src/model.py +++ b/benchmarks/rwkv/rwkv-v4neo/src/model.py @@ -4,6 +4,7 @@ import os, math, gc, importlib import torch + # torch._C._jit_set_profiling_executor(True) # torch._C._jit_set_profiling_mode(True) import torch.nn as nn @@ -11,16 +12,18 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from pytorch_lightning.strategies import DeepSpeedStrategy -if importlib.util.find_spec('deepspeed'): + +if importlib.util.find_spec("deepspeed"): import deepspeed from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam try: - print('RWKV_MY_TESTING', os.environ["RWKV_MY_TESTING"]) + print("RWKV_MY_TESTING", os.environ["RWKV_MY_TESTING"]) except: - os.environ["RWKV_MY_TESTING"] = '' + os.environ["RWKV_MY_TESTING"] = "" + def __nop(ob): return ob @@ -43,7 +46,23 @@ def __nop(ob): from torch.utils.cpp_extension import load if os.environ["RWKV_FLOAT_MODE"] == "bf16": - wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-t 4", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"]) + wkv_cuda = load( + name=f"wkv_{T_MAX}_bf16", + sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], + verbose=True, + extra_cuda_cflags=[ + "-t 4", + "-std=c++17", + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-DTmax={T_MAX}", + ], + ) + class WKV(torch.autograd.Function): @staticmethod def forward(ctx, B, T, C, w, u, k, v): @@ -56,10 +75,16 @@ def forward(ctx, B, T, C, w, u, k, v): u = u.contiguous() k = k.contiguous() v = v.contiguous() - y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + y = torch.empty( + (B, T, C), + device=w.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) wkv_cuda.forward(B, T, C, w, u, k, v, y) ctx.save_for_backward(w, u, k, v, y) return y + @staticmethod def backward(ctx, gy): B = ctx.B @@ -68,16 +93,51 @@ def backward(ctx, gy): assert T <= T_MAX assert B * C % min(C, 32) == 0 w, u, k, v, y = ctx.saved_tensors - gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) - gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) - gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) - gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16) + gw = torch.empty( + (B, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) + gu = torch.empty( + (B, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) + gk = torch.empty( + (B, T, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) + gv = torch.empty( + (B, T, C), + device=gy.device, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16, + ) wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv) gw = torch.sum(gw, dim=0) gu = torch.sum(gu, dim=0) return (None, None, None, gw, gu, gk, gv) + else: - wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"]) + wkv_cuda = load( + name=f"wkv_{T_MAX}", + sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-DTmax={T_MAX}", + ], + ) + class WKV(torch.autograd.Function): @staticmethod def forward(ctx, B, T, C, w, u, k, v): @@ -96,7 +156,9 @@ def forward(ctx, B, T, C, w, u, k, v): u = u.float().contiguous() k = k.float().contiguous() v = v.float().contiguous() - y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format) + y = torch.empty( + (B, T, C), device=w.device, memory_format=torch.contiguous_format + ) wkv_cuda.forward(B, T, C, w, u, k, v, y) ctx.save_for_backward(w, u, k, v, y) if "32" in os.environ["RWKV_FLOAT_MODE"]: @@ -105,6 +167,7 @@ def forward(ctx, B, T, C, w, u, k, v): return y.half() elif os.environ["RWKV_FLOAT_MODE"] == "bf16": return y.bfloat16() + @staticmethod def backward(ctx, gy): B = ctx.B @@ -113,14 +176,26 @@ def backward(ctx, gy): assert T <= T_MAX assert B * C % min(C, 32) == 0 w, u, k, v, y = ctx.saved_tensors - gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format) - gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format) - gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format) - gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format) + gw = torch.empty( + (B, C), device=gy.device, memory_format=torch.contiguous_format + ) + gu = torch.empty( + (B, C), device=gy.device, memory_format=torch.contiguous_format + ) + gk = torch.empty( + (B, T, C), device=gy.device, memory_format=torch.contiguous_format + ) + gv = torch.empty( + (B, T, C), device=gy.device, memory_format=torch.contiguous_format + ) if "32" in os.environ["RWKV_FLOAT_MODE"]: - wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv) + wkv_cuda.backward( + B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv + ) else: - wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv) + wkv_cuda.backward( + B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv + ) gw = torch.sum(gw, dim=0) gu = torch.sum(gu, dim=0) if "32" in os.environ["RWKV_FLOAT_MODE"]: @@ -128,7 +203,15 @@ def backward(ctx, gy): elif os.environ["RWKV_FLOAT_MODE"] == "fp16": return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) elif os.environ["RWKV_FLOAT_MODE"] == "bf16": - return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) + return ( + None, + None, + None, + gw.bfloat16(), + gu.bfloat16(), + gk.bfloat16(), + gv.bfloat16(), + ) def RUN_CUDA(B, T, C, w, u, k, v): @@ -154,21 +237,27 @@ def __init__(self, args, layer_id): ddd = torch.ones(1, 1, args.n_embd) for i in range(args.n_embd): ddd[0, 0, i] = i / args.n_embd - + # fancy time_decay decay_speed = torch.ones(args.dim_att) for h in range(args.dim_att): - decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1) + decay_speed[h] = -5 + 8 * (h / (args.dim_att - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) self.time_decay = nn.Parameter(decay_speed) # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) # fancy time_first zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(args.dim_att)]) * 0.5 - self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag) + self.time_first = nn.Parameter( + torch.ones(args.dim_att) * math.log(0.3) + zigzag + ) # fancy time_mix self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) - self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_v = nn.Parameter( + torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0)) self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) @@ -177,8 +266,10 @@ def __init__(self, args, layer_id): self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False) self.output = nn.Linear(args.dim_att, args.n_embd, bias=False) - if 'a' in os.environ["RWKV_MY_TESTING"]: - self.register_buffer("att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + if "a" in os.environ["RWKV_MY_TESTING"]: + self.register_buffer( + "att_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) d_qkv = args.n_embd // 16 self.qq = nn.Linear(args.n_embd, d_qkv, bias=False) self.kk = nn.Linear(args.n_embd, d_qkv, bias=False) @@ -187,12 +278,17 @@ def __init__(self, args, layer_id): with torch.no_grad(): self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) - self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_vv = nn.Parameter( + torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + ) + + if "a" not in os.environ["RWKV_MY_TESTING"]: - if 'a' not in os.environ["RWKV_MY_TESTING"]: @MyFunction def jit_func(self, x): - xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr + xx = self.time_shift( + x + ) # Mix x with the previous timestep to produce xk, xv, xr xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) @@ -205,21 +301,26 @@ def jit_func(self, x): def forward(self, x): B, T, C = x.size() # x = (Batch,Time,Channel) sr, k, v = self.jit_func(x) - rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v) + rwkv = sr * RUN_CUDA( + B, T, self.args.dim_att, self.time_decay, self.time_first, k, v + ) return self.output(rwkv) - if 'a' in os.environ["RWKV_MY_TESTING"]: + if "a" in os.environ["RWKV_MY_TESTING"]: + @MyFunction def QKV(self, q, k, v): att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.att_mask == 0, float('-inf')) - att = F.softmax(att, dim = -1) + att = att.masked_fill(self.att_mask == 0, float("-inf")) + att = F.softmax(att, dim=-1) x = att @ v return x @MyFunction def jit_funcQKV(self, x): - xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr + xx = self.time_shift( + x + ) # Mix x with the previous timestep to produce xk, xv, xr xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) @@ -238,12 +339,16 @@ def jit_funcQKV(self, x): def forward(self, x): B, T, C = x.size() # x = (Batch,Time,Channel) sr, k, v, qq, kk, vv = self.jit_funcQKV(x) - rwkv = sr * RUN_CUDA(B, T, self.args.dim_att, self.time_decay, self.time_first, k, v) + rwkv = sr * RUN_CUDA( + B, T, self.args.dim_att, self.time_decay, self.time_first, k, v + ) rwkv = self.output(rwkv) + self.oo(self.QKV(qq, kk, vv)) return rwkv + ######################################################################################################## + class RWKV_ChannelMix(MyModule): def __init__(self, args, layer_id): super().__init__() @@ -258,7 +363,7 @@ def __init__(self, args, layer_id): ddd[0, 0, i] = i / args.n_embd self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0)) - + self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False) self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False) self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False) @@ -273,6 +378,7 @@ def forward(self, x): kv = self.value(k) return torch.sigmoid(self.receptance(xr)) * kv + class MishGLU(MyModule): def __init__(self, args, layer_id): super().__init__() @@ -302,6 +408,7 @@ def forward(self, x): b = self.bb(xb) return self.value(a * F.mish(b)) + ######################################################################################################## # The RWKV Model with our blocks ######################################################################################################## @@ -319,25 +426,31 @@ def __init__(self, args, layer_id): if self.layer_id == 0: self.ln0 = nn.LayerNorm(args.n_embd) if args.my_pos_emb > 0: - self.pos_emb_x = nn.Parameter(torch.zeros((1,args.my_pos_emb,args.n_embd))) - self.pos_emb_y = nn.Parameter(torch.zeros((args.my_pos_emb,1,args.n_embd))) + self.pos_emb_x = nn.Parameter( + torch.zeros((1, args.my_pos_emb, args.n_embd)) + ) + self.pos_emb_y = nn.Parameter( + torch.zeros((args.my_pos_emb, 1, args.n_embd)) + ) if self.layer_id == 0 and self.args.pre_ffn > 0: self.ffnPre = RWKV_ChannelMix(args, 0) else: self.att = RWKV_TimeMix(args, layer_id) - if 'g' in os.environ["RWKV_MY_TESTING"]: + if "g" in os.environ["RWKV_MY_TESTING"]: self.ffn = MishGLU(args, layer_id) else: self.ffn = RWKV_ChannelMix(args, layer_id) - + if args.tiny_att_dim > 0 and self.layer_id == args.tiny_att_layer: self.tiny_ln = nn.LayerNorm(args.n_embd) self.tiny_q = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) self.tiny_k = nn.Linear(args.n_embd, args.tiny_att_dim, bias=False) self.tiny_v = nn.Linear(args.n_embd, args.n_embd, bias=False) - self.register_buffer("tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + self.register_buffer( + "tiny_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) def forward(self, x, x_emb=None): args = self.args @@ -345,7 +458,7 @@ def forward(self, x, x_emb=None): if self.layer_id == 0: x = self.ln0(x) if args.my_pos_emb > 0: - pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T+1, -1)[:-1,:] + pos_emb = (self.pos_emb_x + self.pos_emb_y).reshape(T + 1, -1)[:-1, :] x = x + pos_emb if self.layer_id == 0 and args.pre_ffn > 0: @@ -385,13 +498,13 @@ class RWKV(pl.LightningModule): def __init__(self, args): super().__init__() self.args = args - if not hasattr(args, 'dim_att'): + if not hasattr(args, "dim_att"): args.dim_att = args.n_embd - if not hasattr(args, 'dim_ffn'): + if not hasattr(args, "dim_ffn"): args.dim_ffn = args.n_embd * 4 - if not hasattr(args, 'tiny_att_layer'): + if not hasattr(args, "tiny_att_layer"): args.tiny_att_layer = -1 - if not hasattr(args, 'tiny_att_dim'): + if not hasattr(args, "tiny_att_dim"): args.tiny_att_dim = -1 self.emb = nn.Embedding(args.vocab_size, args.n_embd) @@ -404,7 +517,9 @@ def __init__(self, args): if args.head_qk > 0: self.head_q = nn.Linear(args.n_embd, args.head_qk, bias=False) self.head_k = nn.Linear(args.n_embd, args.head_qk, bias=False) - self.register_buffer("copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len))) + self.register_buffer( + "copy_mask", torch.tril(torch.ones(args.ctx_len, args.ctx_len)) + ) def configure_optimizers(self): args = self.args @@ -436,24 +551,69 @@ def configure_optimizers(self): param_dict = {n: p for n, p in self.named_parameters()} if args.my_pile_stage == 2: optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 2e-3 / args.lr_init}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 5.0},# test: 3e-3 / args.lr_init}, + { + "params": [param_dict[n] for n in lr_1x], + "weight_decay": 0.0, + "my_lr_scale": 1.0, + }, + { + "params": [param_dict[n] for n in lr_2x], + "weight_decay": 0.0, + "my_lr_scale": 5.0, + }, # test: 2e-3 / args.lr_init}, + { + "params": [param_dict[n] for n in lr_3x], + "weight_decay": 0.0, + "my_lr_scale": 5.0, + }, # test: 3e-3 / args.lr_init}, ] else: optim_groups = [ - {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "my_lr_scale": 1.0}, - {"params": [param_dict[n] for n in lr_2x], "weight_decay": 0.0, "my_lr_scale": 2.0}, - {"params": [param_dict[n] for n in lr_3x], "weight_decay": 0.0, "my_lr_scale": 3.0}, + { + "params": [param_dict[n] for n in lr_1x], + "weight_decay": 0.0, + "my_lr_scale": 1.0, + }, + { + "params": [param_dict[n] for n in lr_2x], + "weight_decay": 0.0, + "my_lr_scale": 2.0, + }, + { + "params": [param_dict[n] for n in lr_3x], + "weight_decay": 0.0, + "my_lr_scale": 3.0, + }, ] else: optim_groups = [ - {"params": [p for n, p in self.named_parameters()], "weight_decay": 0.0}, + { + "params": [p for n, p in self.named_parameters()], + "weight_decay": 0.0, + }, ] if self.deepspeed_offload: - return DeepSpeedCPUAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adamw_mode=False, weight_decay=0, amsgrad=False) - return FusedAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + return DeepSpeedCPUAdam( + optim_groups, + lr=self.args.lr_init, + betas=self.args.betas, + eps=self.args.adam_eps, + bias_correction=True, + adamw_mode=False, + weight_decay=0, + amsgrad=False, + ) + return FusedAdam( + optim_groups, + lr=self.args.lr_init, + betas=self.args.betas, + eps=self.args.adam_eps, + bias_correction=True, + adam_w_mode=False, + weight_decay=0, + amsgrad=False, + ) # return ZeroOneAdam(optim_groups, lr=self.args.lr_init, betas=self.args.betas, eps=self.args.adam_eps, bias_correction=True, weight_decay=0, amsgrad=False, cuda_aware=False) @property @@ -521,10 +681,14 @@ def training_step(self, batch, batch_idx): logits = self(idx) if sum_mask == mask.shape[0]: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1) + ) # print('rank', self.global_rank, 'loss', loss.item()) else: - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none') + loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none" + ) # loss_raw = loss loss = torch.sum(loss * mask) / sum_mask @@ -564,7 +728,14 @@ def generate_init_weight(self): gain = 1.0 scale = 1.0 - if "ln_" in n or ".ln" in n or "time_" in n or "_mask" in n or "pos_emb" in n or '.mask.' in n: + if ( + "ln_" in n + or ".ln" in n + or "time_" in n + or "_mask" in n + or "pos_emb" in n + or ".mask." in n + ): m[n] = p else: if n == "emb.weight": @@ -572,7 +743,19 @@ def generate_init_weight(self): else: if shape[0] > shape[1]: gain = math.sqrt(shape[0] / shape[1]) - for kk in [".att.key.", ".att.receptance.", ".att.output.", ".att.key.", ".ffn.value.", ".ffn.receptance.", ".ffnPre.value.", ".ffnPre.receptance.", "head_q.", '.oo.', '.rr.']: + for kk in [ + ".att.key.", + ".att.receptance.", + ".att.output.", + ".att.key.", + ".ffn.value.", + ".ffn.receptance.", + ".ffnPre.value.", + ".ffnPre.receptance.", + "head_q.", + ".oo.", + ".rr.", + ]: if kk in n: scale = 0 if n == "head.weight": @@ -582,7 +765,9 @@ def generate_init_weight(self): if "head_q." in n: scale = 0 - print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}") + print( + f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {n}" + ) if self.args.accelerator.upper() == "GPU": m[n] = torch.empty((shape[0], shape[1]), device="cuda") diff --git a/benchmarks/rwkv/rwkv-v4neo/src/model_img.py b/benchmarks/rwkv/rwkv-v4neo/src/model_img.py index 24337236b..3a9bceb4e 100644 --- a/benchmarks/rwkv/rwkv-v4neo/src/model_img.py +++ b/benchmarks/rwkv/rwkv-v4neo/src/model_img.py @@ -13,10 +13,14 @@ from pytorch_lightning.strategies import DeepSpeedStrategy import deepspeed from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam + # from pytorch_msssim import MS_SSIM + def __nop(ob): return ob + + MyModule = torch.jit.ScriptModule # MyFunction = __nop MyFunction = torch.jit.script_method @@ -24,6 +28,7 @@ def __nop(ob): import clip from transformers import CLIPModel + class L2pooling(nn.Module): def __init__(self, filter_size=5, stride=2, channels=None, pad_off=0): super(L2pooling, self).__init__() @@ -149,55 +154,57 @@ def forward(self, x, y, require_grad=False, batch_average=False): class ToBinary(torch.autograd.Function): @staticmethod - def forward(ctx, x):#, noise_scale): + def forward(ctx, x): # , noise_scale): # if noise_scale > 0: # noise_min = 0.5 - noise_scale / 2 # noise_max = 0.5 + noise_scale / 2 # return torch.floor(x + torch.empty_like(x).uniform_(noise_min, noise_max)) # else: - return torch.floor(x + 0.5) # no need for noise when we have plenty of data + return torch.floor(x + 0.5) # no need for noise when we have plenty of data @staticmethod def backward(ctx, grad_output): - return grad_output.clone()#, None + return grad_output.clone() # , None + ######################################################################################################## + class R_ENCODER(MyModule): def __init__(self, args): super().__init__() self.args = args dd = 8 - self.Bxx = nn.BatchNorm2d(dd*64) + self.Bxx = nn.BatchNorm2d(dd * 64) self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1) self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1) self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1) - self.B00 = nn.BatchNorm2d(dd*4) - self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) - self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) - - self.B10 = nn.BatchNorm2d(dd*16) - self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - - self.B20 = nn.BatchNorm2d(dd*64) - self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) + self.B00 = nn.BatchNorm2d(dd * 4) + self.C00 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C01 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) + self.C02 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C03 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) + + self.B10 = nn.BatchNorm2d(dd * 16) + self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + + self.B20 = nn.BatchNorm2d(dd * 64) + self.C20 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C21 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) + self.C22 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C23 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) # self.B21 = nn.BatchNorm2d(dd*64) # self.C24 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) # self.C25 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) # self.C26 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) # self.C27 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1) + self.COUT = nn.Conv2d(dd * 64, args.my_img_bit, kernel_size=3, padding=1) @MyFunction def forward(self, img): @@ -224,37 +231,39 @@ def forward(self, img): x = self.COUT(x + xx) return torch.sigmoid(x) + ######################################################################################################## + class R_DECODER(MyModule): def __init__(self, args): super().__init__() self.args = args dd = 8 - self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1) + self.CIN = nn.Conv2d(args.my_img_bit, dd * 64, kernel_size=3, padding=1) - self.B00 = nn.BatchNorm2d(dd*64) - self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) - self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) + self.B00 = nn.BatchNorm2d(dd * 64) + self.C00 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C01 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) + self.C02 = nn.Conv2d(dd * 64, 256, kernel_size=3, padding=1) + self.C03 = nn.Conv2d(256, dd * 64, kernel_size=3, padding=1) # self.B01 = nn.BatchNorm2d(dd*64) # self.C04 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) # self.C05 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) # self.C06 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) # self.C07 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) - self.B10 = nn.BatchNorm2d(dd*16) - self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) - self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) - self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) + self.B10 = nn.BatchNorm2d(dd * 16) + self.C10 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C11 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) + self.C12 = nn.Conv2d(dd * 16, 256, kernel_size=3, padding=1) + self.C13 = nn.Conv2d(256, dd * 16, kernel_size=3, padding=1) - self.B20 = nn.BatchNorm2d(dd*4) - self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) - self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) - self.C23 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) + self.B20 = nn.BatchNorm2d(dd * 4) + self.C20 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C21 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) + self.C22 = nn.Conv2d(dd * 4, 256, kernel_size=3, padding=1) + self.C23 = nn.Conv2d(256, dd * 4, kernel_size=3, padding=1) self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1) self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1) @@ -281,47 +290,52 @@ def forward(self, code): x = x + self.Cx1(ACT(self.Cx0(x))) x = self.COUT(x) - + return torch.sigmoid(x) + ########################################################################################################` + def cosine_loss(x, y): x = F.normalize(x, dim=-1) y = F.normalize(y, dim=-1) - return 1 - torch.einsum('ij,ij->i',[x,y]) + return 1 - torch.einsum("ij,ij->i", [x, y]) + class RWKV_IMG(pl.LightningModule): def __init__(self, args): super().__init__() self.args = args - + self.encoder = R_ENCODER(args) self.decoder = R_DECODER(args) self.clip_model = None clip_name = args.my_img_clip - if clip_name == 'B32': - clip_name = 'ViT-B/32' - elif clip_name == 'B16': - clip_name = 'ViT-B/16' - elif clip_name == 'L14': - clip_name = 'ViT-L/14' - elif clip_name == 'OB32': + if clip_name == "B32": + clip_name = "ViT-B/32" + elif clip_name == "B16": + clip_name = "ViT-B/16" + elif clip_name == "L14": + clip_name = "ViT-L/14" + elif clip_name == "OB32": clip_name = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" self.clip_model = CLIPModel.from_pretrained(clip_name) self.clip_model.encode_image = self.clip_model.get_image_features if self.clip_model == None: - self.clip_model, _ = clip.load(clip_name, jit = True) + self.clip_model, _ = clip.load(clip_name, jit=True) self.register_buffer( - "clip_mean", torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1) + "clip_mean", + torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1), ) self.register_buffer( - "clip_std", torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1) + "clip_std", + torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1), ) for n, p in self.named_parameters(): - if 'clip_model' in n: + if "clip_model" in n: p.requires_grad = False self.loss_dists = DISTS() @@ -365,7 +379,7 @@ def deepspeed_offload(self) -> bool: def forward(self, img): z = self.encoder(img) - z = ToBinary.apply(z)#, self.args.my_img_noise_scale) + z = ToBinary.apply(z) # , self.args.my_img_noise_scale) out = self.decoder(z) return out @@ -379,10 +393,12 @@ def training_step(self, batch, batch_idx): if not os.path.exists(img_dir): os.makedirs(img_dir) vision.utils.save_image( - img[:4], f"{img_dir}/{self.trainer.global_step}-src.jpg"#, padding=0 + img[:4], + f"{img_dir}/{self.trainer.global_step}-src.jpg", # , padding=0 ) vision.utils.save_image( - out[:4], f"{img_dir}/{self.trainer.global_step}-out.jpg"#, padding=0 + out[:4], + f"{img_dir}/{self.trainer.global_step}-out.jpg", # , padding=0 ) # loss_ssim = 1 - self.loss_ssim(out, img) @@ -394,7 +410,11 @@ def training_step(self, batch, batch_idx): if args.my_img_l1_scale > 0: loss_l1 = F.l1_loss(out, img) - return loss_dists + loss_clip * args.my_img_clip_scale + loss_l1 * args.my_img_l1_scale + return ( + loss_dists + + loss_clip * args.my_img_clip_scale + + loss_l1 * args.my_img_l1_scale + ) else: return loss_dists + loss_clip * args.my_img_clip_scale @@ -418,7 +438,7 @@ def generate_init_weight(self): scale = 1 p = self.state_dict()[n] shape = p.shape - ss = n.split('.') + ss = n.split(".") # if ss[0] in ['encoder', 'decoder']: # if ss[2] == 'bias': diff --git a/benchmarks/rwkv/rwkv-v4neo/src/model_run.py b/benchmarks/rwkv/rwkv-v4neo/src/model_run.py index 2516e508c..184a35cfa 100644 --- a/benchmarks/rwkv/rwkv-v4neo/src/model_run.py +++ b/benchmarks/rwkv/rwkv-v4neo/src/model_run.py @@ -10,8 +10,12 @@ from typing import List, Dict MyModule = nn.Module + + def __nop(ob): return ob + + MyFunction = __nop # # try torchdynamo @@ -24,14 +28,17 @@ def __nop(ob): MyFunction = torch.jit.script_method RWKV_HEAD_QK_DIM = 0 -print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n') +print( + f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n' +) -DEBUG_TIME = False # True False - show trained time-coeffs +DEBUG_TIME = False # True False - show trained time-coeffs -RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer +RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer ############################################################################################################ + class RWKV_RNN(MyModule): def __init__(self, args): super().__init__() @@ -41,30 +48,32 @@ def __init__(self, args): self.RUN_DEVICE = args.RUN_DEVICE with torch.no_grad(): - w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') + w = torch.load(args.MODEL_NAME + ".pth", map_location="cpu") # refine weights and send to correct device keys = list(w.keys()) - if 'pos_emb_x' in keys: - w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:] + if "pos_emb_x" in keys: + w["pos_emb"] = (w["pos_emb_x"] + w["pos_emb_y"]).reshape( + args.ctx_len + 1, -1 + )[:-1, :] keys = list(w.keys()) print_need_newline = False for x in keys: block_id = 0 - if 'blocks.' in x: - block_id = int(x.split('.')[1]) - if 'att.output.weight' in x: + if "blocks." in x: + block_id = int(x.split(".")[1]) + if "att.output.weight" in x: w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) - if 'ffn.value.weight' in x: + if "ffn.value.weight" in x: w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) - - if '.time_' in x: + + if ".time_" in x: w[x] = w[x].squeeze() if DEBUG_TIME: print(x, w[x].numpy()) - if '.time_decay' in x: + if ".time_decay" in x: w[x] = w[x].float() w[x] = -torch.exp(w[x]) - elif '.time_first' in x: + elif ".time_first" in x: w[x] = w[x].float() else: if self.FLOAT_MODE == "fp32": @@ -75,23 +84,27 @@ def __init__(self, args): w[x] = w[x].half() w[x].requires_grad = False - if args.RUN_DEVICE == 'cuda' and x != 'emb.weight': + if args.RUN_DEVICE == "cuda" and x != "emb.weight": w[x] = w[x].cuda() - if ('blocks.' not in x) or ('blocks.0.' in x): + if ("blocks." not in x) or ("blocks.0." in x): if print_need_newline: - print('\n', end = '') + print("\n", end="") print_need_newline = False - print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device) + print( + x.ljust(40), + str(w[x].dtype).replace("torch.", "").ljust(10), + w[x].device, + ) else: print_need_newline = True - print('.', end = '', flush = True) + print(".", end="", flush=True) # store weights in self.w keys = list(w.keys()) self.w = types.SimpleNamespace() for x in keys: - xx = x.split('.') + xx = x.split(".") here = self.w for i in range(len(xx)): if xx[i].isdigit(): @@ -103,7 +116,7 @@ def __init__(self, args): if i == len(xx) - 1: setattr(here, xx[i], w[x]) elif not hasattr(here, xx[i]): - if xx[i+1].isdigit(): + if xx[i + 1].isdigit(): setattr(here, xx[i], {}) else: setattr(here, xx[i], types.SimpleNamespace()) @@ -119,19 +132,23 @@ def LN(self, x, w): # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp @MyFunction - def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw): + def FF(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw): if self.FLOAT_MODE == "bf16": - xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k) - xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r) - state[5*i+0] = x.float() + xk = x * time_mix_k + state[5 * i + 0].type(torch.bfloat16) * ( + 1 - time_mix_k + ) + xr = x * time_mix_r + state[5 * i + 0].type(torch.bfloat16) * ( + 1 - time_mix_r + ) + state[5 * i + 0] = x.float() elif self.FLOAT_MODE == "fp16": - xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k) - xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r) - state[5*i+0] = x.float() + xk = x * time_mix_k + state[5 * i + 0].half() * (1 - time_mix_k) + xr = x * time_mix_r + state[5 * i + 0].half() * (1 - time_mix_r) + state[5 * i + 0] = x.float() else: - xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k) - xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r) - state[5*i+0] = x + xk = x * time_mix_k + state[5 * i + 0] * (1 - time_mix_k) + xr = x * time_mix_r + state[5 * i + 0] * (1 - time_mix_r) + state[5 * i + 0] = x r = torch.sigmoid(rw @ xr) k = torch.square(torch.relu(kw @ xk)) @@ -140,36 +157,56 @@ def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw): return r * kv @MyFunction - def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): + def SA( + self, + x, + state, + i: int, + time_mix_k, + time_mix_v, + time_mix_r, + time_first, + time_decay, + kw, + vw, + rw, + ow, + ): if self.FLOAT_MODE == "bf16": - xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k) - xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v) - xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r) - state[5*i+1] = x.float() + xk = x * time_mix_k + state[5 * i + 1].type(torch.bfloat16) * ( + 1 - time_mix_k + ) + xv = x * time_mix_v + state[5 * i + 1].type(torch.bfloat16) * ( + 1 - time_mix_v + ) + xr = x * time_mix_r + state[5 * i + 1].type(torch.bfloat16) * ( + 1 - time_mix_r + ) + state[5 * i + 1] = x.float() elif self.FLOAT_MODE == "fp16": - xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k) - xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v) - xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r) - state[5*i+1] = x.float() + xk = x * time_mix_k + state[5 * i + 1].half() * (1 - time_mix_k) + xv = x * time_mix_v + state[5 * i + 1].half() * (1 - time_mix_v) + xr = x * time_mix_r + state[5 * i + 1].half() * (1 - time_mix_r) + state[5 * i + 1] = x.float() else: - xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k) - xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v) - xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r) - state[5*i+1] = x + xk = x * time_mix_k + state[5 * i + 1] * (1 - time_mix_k) + xv = x * time_mix_v + state[5 * i + 1] * (1 - time_mix_v) + xr = x * time_mix_r + state[5 * i + 1] * (1 - time_mix_r) + state[5 * i + 1] = x r = torch.sigmoid(rw @ xr) k = kw @ xk v = vw @ xv - if '16' in self.FLOAT_MODE: + if "16" in self.FLOAT_MODE: kk = k.float() vv = v.float() else: kk = k vv = v - aa = state[5*i+2] - bb = state[5*i+3] - pp = state[5*i+4] + aa = state[5 * i + 2] + bb = state[5 * i + 3] + pp = state[5 * i + 4] ww = time_first + kk p = torch.maximum(pp, ww) e1 = torch.exp(pp - p) @@ -180,52 +217,72 @@ def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, ti p = torch.maximum(ww, kk) e1 = torch.exp(ww - p) e2 = torch.exp(kk - p) - state[5*i+2] = e1 * aa + e2 * vv - state[5*i+3] = e1 * bb + e2 - state[5*i+4] = p + state[5 * i + 2] = e1 * aa + e2 * vv + state[5 * i + 3] = e1 * bb + e2 + state[5 * i + 4] = p if self.FLOAT_MODE == "bf16": wkv = (a / b).type(torch.bfloat16) elif self.FLOAT_MODE == "fp16": wkv = (a / b).half() else: wkv = a / b - + return ow @ (r * wkv) - def forward(self, ctx, state, preprocess_only = False): + def forward(self, ctx, state, preprocess_only=False): with torch.no_grad(): w = self.w args = self.args x = w.emb.weight[ctx[-1]] - if self.RUN_DEVICE == 'cuda': + if self.RUN_DEVICE == "cuda": x = x.cuda() try: - pos_emb = w.pos_emb[len(ctx)-1] + pos_emb = w.pos_emb[len(ctx) - 1] x = x + pos_emb except: - pass + pass if state == None: - state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE) + state = torch.zeros( + args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE + ) for i in range(args.n_layer): - state[5*i+4] -= 1e30 + state[5 * i + 4] -= 1e30 for i in range(args.n_layer): if i == 0: x = self.LN(x, w.blocks[i].ln0) - + ww = w.blocks[i].att - x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i, - ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay, - ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight) - + x = x + self.SA( + self.LN(x, w.blocks[i].ln1), + state, + i, + ww.time_mix_k, + ww.time_mix_v, + ww.time_mix_r, + ww.time_first, + ww.time_decay, + ww.key.weight, + ww.value.weight, + ww.receptance.weight, + ww.output.weight, + ) + ww = w.blocks[i].ffn - x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i, - ww.time_mix_k, ww.time_mix_r, - ww.key.weight, ww.value.weight, ww.receptance.weight) - - if (i+1) % RWKV_RESCALE_LAYER == 0: + x = x + self.FF( + self.LN(x, w.blocks[i].ln2), + state, + i, + ww.time_mix_k, + ww.time_mix_r, + ww.key.weight, + ww.value.weight, + ww.receptance.weight, + ) + + if (i + 1) % RWKV_RESCALE_LAYER == 0: x = x / 2 if preprocess_only: diff --git a/benchmarks/rwkv/rwkv-v4neo/src/trainer.py b/benchmarks/rwkv/rwkv-v4neo/src/trainer.py index 9791ea524..98f229c40 100644 --- a/benchmarks/rwkv/rwkv-v4neo/src/trainer.py +++ b/benchmarks/rwkv/rwkv-v4neo/src/trainer.py @@ -5,6 +5,7 @@ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only from giving import give + def my_save(dd, ff): pass # if '14b-run1' not in ff: @@ -15,6 +16,7 @@ def my_save(dd, ff): # torch.save(dd, fff) # subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True) + class train_callback(pl.Callback): def __init__(self, args): super().__init__() @@ -39,7 +41,9 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): if args.lr_final == 0 or args.lr_init == 0: # linear decay lr = args.lr_init + (args.lr_final - args.lr_init) * progress else: # exp decay - lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1)) + lr = args.lr_init * math.exp( + math.log(args.lr_final / args.lr_init) * pow(progress, 1) + ) if trainer.global_step < w_step: lr = lr * (0.2 + 0.8 * trainer.global_step / w_step) @@ -61,7 +65,9 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): trainer.my_loss_sum = 0 trainer.my_loss_count = 0 trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") - trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n") + trainer.my_log.write( + f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n" + ) try: print(f"\n{trainer.strategy.config}\n") trainer.my_log.write(f"{trainer.strategy.config}\n") @@ -71,6 +77,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): if len(args.wandb) > 0: print("Login to wandb...") import wandb + wandb.init( project=args.wandb, name=args.run_name + " " + args.my_timestamp, @@ -105,19 +112,25 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): # self.log("s", real_step, prog_bar=True, on_step=True) if len(args.wandb) > 0: - lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9} + lll = { + "loss": trainer.my_loss, + "lr": trainer.my_lr, + "Gtokens": real_step * token_per_step / 1e9, + } if kt_s > 0: lll["kt/s"] = kt_s trainer.my_wandb.log(lll, step=int(real_step)) if args.magic_prime > 0: expand_factor = 2 if args.my_qa_mask > 0 else 1 - if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1: + if ( + int(real_step) + == int(args.magic_prime * expand_factor // args.real_bsz) - 1 + ): to_save_dict = pl_module.state_dict() my_save( to_save_dict, f"{args.proj_dir}/rwkv-final.pth", ) - def on_train_epoch_start(self, trainer, pl_module): args = self.args @@ -147,7 +160,9 @@ def on_train_epoch_end(self, trainer, pl_module): # ) # except Exception as e: # print('Error\n\n', e, '\n\n') - trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") + trainer.my_log.write( + f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n" + ) trainer.my_log.flush() trainer.my_loss_sum = 0 @@ -169,22 +184,22 @@ def generate_init_weight(model, init_weight_name): mm[k] = src.reshape(mm[k].shape) except: tmp = mm[k].squeeze().clone() - print(k, src.shape, '-->', mm[k].shape) + print(k, src.shape, "-->", mm[k].shape) ss = src.shape[0] dd = tmp.shape[0] for i in range(dd): pos = i / dd * ss if pos >= ss - 1: - tmp[i] = src[ss-1] + tmp[i] = src[ss - 1] else: p0 = int(math.floor(pos)) ii = pos - p0 - tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii) + tmp[i] = src[p0] * (1 - ii) + src[p0 + 1] * (ii) mm[k] = tmp.reshape(mm[k].shape) sss = src.squeeze().float().cpu().numpy() - print(sss[:10], '...', sss[-10:]) + print(sss[:10], "...", sss[-10:]) mmm = mm[k].squeeze().float().cpu().numpy() - print(mmm[:10], '...', mmm[-10:]) + print(mmm[:10], "...", mmm[-10:]) # print(f"Save to {init_weight_name}...") # torch.save(mm, init_weight_name) diff --git a/benchmarks/rwkv/rwkv-v4neo/src/utils.py b/benchmarks/rwkv/rwkv-v4neo/src/utils.py index ea25990b4..87da098db 100644 --- a/benchmarks/rwkv/rwkv-v4neo/src/utils.py +++ b/benchmarks/rwkv/rwkv-v4neo/src/utils.py @@ -6,6 +6,7 @@ time_slot = {} time_ref = time.time_ns() + def record_time(name): if name not in time_slot: time_slot[name] = 1e20 @@ -13,20 +14,23 @@ def record_time(name): if tt < time_slot[name]: time_slot[name] = tt -class TOKENIZER(): - def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): - if 'list' in str(type(WORD_NAME)): + +class TOKENIZER: + def __init__(self, WORD_NAME, UNKNOWN_CHAR="\ue083"): + if "list" in str(type(WORD_NAME)): self.charMode = False if WORD_NAME[0] == WORD_NAME[1]: from transformers import PreTrainedTokenizerFast + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) else: from transformers import GPT2TokenizerFast + self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) self.vocab_size = len(self.tokenizer) else: self.charMode = True - with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: + with open(WORD_NAME + ".json", "r", encoding="utf-16") as result_file: self.word_table = json.load(result_file) self.vocab_size = len(self.word_table) @@ -37,23 +41,25 @@ def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] def refine_context(self, context): - context = context.strip().split('\n') + context = context.strip().split("\n") for c in range(len(context)): - context[c] = context[c].strip().strip('\u3000').strip('\r') - context = list(filter(lambda c: c != '', context)) - context = '\n' + ('\n'.join(context)).strip() - if context == '': - context = '\n' + context[c] = context[c].strip().strip("\u3000").strip("\r") + context = list(filter(lambda c: c != "", context)) + context = "\n" + ("\n".join(context)).strip() + if context == "": + context = "\n" return context - def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): + def sample_logits( + self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None + ): # out[self.UNKNOWN_CHAR] = -float('Inf') lastChar = int(x[-1]) probs = F.softmax(out, dim=-1) if self.charMode: - if self.itos[lastChar] == '\n': + if self.itos[lastChar] == "\n": top_p = top_p_newline else: top_p = top_p_usual @@ -81,6 +87,7 @@ def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_ out = torch.multinomial(probs, num_samples=1)[0] return out + def MaybeIsPrime(number): if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number): return True @@ -121,7 +128,9 @@ def MillerRabinPrimalityTest(number): if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1): iterationNumber = 1 - while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1): + while (iterationNumber <= timesTwoDividNumber - 1) and ( + randomNumberWithPower != number - 1 + ): randomNumberWithPower = pow(randomNumberWithPower, 2, number) iterationNumber = iterationNumber + 1 if randomNumberWithPower != (number - 1): diff --git a/benchmarks/rwkv/rwkv-v4neo/train.py b/benchmarks/rwkv/rwkv-v4neo/train.py index 6dd8ce166..875d9c4eb 100644 --- a/benchmarks/rwkv/rwkv-v4neo/train.py +++ b/benchmarks/rwkv/rwkv-v4neo/train.py @@ -52,53 +52,91 @@ parser = ArgumentParser() parser.add_argument("--load_model", default="", type=str) # full path, with .pth - parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb - parser.add_argument("--proj_dir", default=os.environ.get("MILABENCH_BASE", ".") + "/proj/rwkv/", type=str) + parser.add_argument( + "--wandb", default="", type=str + ) # wandb project name. if "" then don't use wandb + parser.add_argument( + "--proj_dir", + default=os.environ.get("MILABENCH_BASE", ".") + "/proj/rwkv/", + type=str, + ) parser.add_argument("--random_seed", default="-1", type=int) parser.add_argument("--data_file", default="", type=str) parser.add_argument("--data_type", default="utf-8", type=str) - parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data) + parser.add_argument( + "--vocab_size", default=0, type=int + ) # vocab_size = 0 means auto (for char-level LM and .txt data) parser.add_argument("--ctx_len", default=1024, type=int) - parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps - parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final - parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x - parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs" - - parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU) + parser.add_argument( + "--epoch_steps", default=1000, type=int + ) # a mini "epoch" has [epoch_steps] steps + parser.add_argument( + "--epoch_count", default=500, type=int + ) # train for this many "epochs". will continue afterwards with lr = lr_final + parser.add_argument( + "--epoch_begin", default=0, type=int + ) # if you load a model trained for x "epochs", set epoch_begin = x + parser.add_argument( + "--epoch_save", default=5, type=int + ) # save the model every [epoch_save] "epochs" + + parser.add_argument( + "--micro_bsz", default=12, type=int + ) # micro batch size (batch size per GPU) parser.add_argument("--n_layer", default=6, type=int) parser.add_argument("--n_embd", default=512, type=int) parser.add_argument("--dim_att", default=0, type=int) parser.add_argument("--dim_ffn", default=0, type=int) - parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better) + parser.add_argument( + "--pre_ffn", default=0, type=int + ) # replace first att layer by ffn (sometimes better) parser.add_argument("--head_qk", default=0, type=int) # my headQK trick parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim - parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer + parser.add_argument( + "--tiny_att_layer", default=-999, type=int + ) # tiny attention @ which layer - parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 + parser.add_argument( + "--lr_init", default=6e-4, type=float + ) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 parser.add_argument("--lr_final", default=1e-5, type=float) - parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model + parser.add_argument( + "--warmup_steps", default=0, type=int + ) # try 50 if you load a model parser.add_argument("--beta1", default=0.9, type=float) - parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence + parser.add_argument( + "--beta2", default=0.99, type=float + ) # use 0.999 when your model is close to convergence parser.add_argument("--adam_eps", default=1e-8, type=float) - parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower + parser.add_argument( + "--grad_cp", default=0, type=int + ) # gradient checkpt: saves VRAM, but slower - parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version + parser.add_argument( + "--my_pile_version", default=1, type=int + ) # my special pile version parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode - parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift + parser.add_argument( + "--my_pile_shift", default=-1, type=int + ) # my special pile mode - text shift parser.add_argument("--my_pile_edecay", default=0, type=int) - parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) - parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 200 seems enough + parser.add_argument( + "--layerwise_lr", default=1, type=int + ) # layerwise lr for faster convergence (but slower it/s) + parser.add_argument( + "--ds_bucket_mb", default=200, type=int + ) # deepspeed bucket size in MB. 200 seems enough # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) parser.add_argument("--my_img_version", default=0, type=str) parser.add_argument("--my_img_size", default=0, type=int) parser.add_argument("--my_img_bit", default=0, type=int) - parser.add_argument("--my_img_clip", default='x', type=str) + parser.add_argument("--my_img_clip", default="x", type=str) parser.add_argument("--my_img_clip_scale", default=1, type=float) parser.add_argument("--my_img_l1_scale", default=0, type=float) - parser.add_argument("--my_img_encoder", default='x', type=str) + parser.add_argument("--my_img_encoder", default="x", type=str) # parser.add_argument("--my_img_noise_scale", default=0, type=float) parser.add_argument("--my_sample_len", default=0, type=int) parser.add_argument("--my_ffn_shift", default=1, type=int) @@ -107,7 +145,7 @@ parser.add_argument("--load_partial", default=0, type=int) parser.add_argument("--magic_prime", default=0, type=int) parser.add_argument("--my_qa_mask", default=0, type=int) - parser.add_argument("--my_testing", default='', type=str) + parser.add_argument("--my_testing", default="", type=str) parser = Trainer.add_argparse_args(parser) args = parser.parse_args() @@ -118,18 +156,26 @@ import numpy as np import torch from torch.utils.data import DataLoader + if "deepspeed" in args.strategy: import deepspeed import pytorch_lightning as pl from pytorch_lightning import seed_everything if args.random_seed >= 0: - print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3) + print( + f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" + * 3 + ) seed_everything(args.random_seed) np.set_printoptions(precision=4, suppress=True, linewidth=200) - warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") - warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*") + warnings.filterwarnings( + "ignore", ".*Consider increasing the value of the `num_workers` argument*" + ) + warnings.filterwarnings( + "ignore", ".*The progress bar already tracks a metric with the*" + ) # os.environ["WDS_SHOW_SEED"] = "1" args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") @@ -154,7 +200,9 @@ args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}" args.proj_dir = f"{args.proj_dir}-{args.run_name}" else: - args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + args.run_name = ( + f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + ) if not os.path.exists(args.proj_dir): os.makedirs(args.proj_dir) @@ -242,18 +290,32 @@ ) rank_zero_info(str(vars(args)) + "\n") - assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"] + assert args.data_type in [ + "utf-8", + "utf-16le", + "numpy", + "binidx", + "dummy", + "wds_img", + "uint16", + ] if args.lr_final == 0 or args.lr_init == 0: - rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n") + rank_zero_info( + "\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n" + ) assert args.precision in ["fp32", "tf32", "fp16", "bf16"] os.environ["RWKV_FLOAT_MODE"] = args.precision if args.precision == "fp32": for i in range(10): - rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") + rank_zero_info( + "\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n" + ) if args.precision == "fp16": - rank_zero_info("\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n") + rank_zero_info( + "\n\nNote: you are using fp16 (might overflow). Try bf16 / tf32 for stable training.\n\n" + ) os.environ["RWKV_JIT_ON"] = "1" if "deepspeed_stage_3" in args.strategy: @@ -283,11 +345,13 @@ train_data = MyDataset(args) args.vocab_size = train_data.vocab_size - if args.data_type == 'wds_img': + if args.data_type == "wds_img": from src.model_img import RWKV_IMG + model = RWKV_IMG(args) else: from src.model import RWKV + model = RWKV(args) # if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights? @@ -335,10 +399,22 @@ print(f"{str(shape[0]).ljust(5)} {n}") if "deepspeed" in args.strategy: - trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 - trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 + trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = ( + args.ds_bucket_mb * 1000 * 1000 + ) + trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = ( + args.ds_bucket_mb * 1000 * 1000 + ) # must set shuffle=False, persistent_workers=False (because worker is in another thread) - data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) + data_loader = DataLoader( + train_data, + shuffle=False, + pin_memory=True, + batch_size=args.micro_bsz, + num_workers=1, + persistent_workers=False, + drop_last=True, + ) trainer.fit(model, data_loader) diff --git a/benchmarks/rwkv/rwkv-v4neo/verify.py b/benchmarks/rwkv/rwkv-v4neo/verify.py index 4f56e392f..695e651f2 100644 --- a/benchmarks/rwkv/rwkv-v4neo/verify.py +++ b/benchmarks/rwkv/rwkv-v4neo/verify.py @@ -7,6 +7,7 @@ import os, sys, types import numpy as np import torch + np.set_printoptions(precision=4, suppress=True, linewidth=200) try: os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] @@ -16,23 +17,24 @@ torch.backends.cudnn.allow_tf32 = False torch.backends.cuda.matmul.allow_tf32 = False -os.environ['RWKV_FLOAT_MODE'] = 'bf16' # bf16 or fp32 -os.environ['RWKV_RUN_DEVICE'] = 'cuda' # currently model_train requires CUDA -RUN_DEVICE = os.environ['RWKV_RUN_DEVICE'] +os.environ["RWKV_FLOAT_MODE"] = "bf16" # bf16 or fp32 +os.environ["RWKV_RUN_DEVICE"] = "cuda" # currently model_train requires CUDA +RUN_DEVICE = os.environ["RWKV_RUN_DEVICE"] -TOKEN_MODE = 'pile' +TOKEN_MODE = "pile" -if TOKEN_MODE == 'pile': - WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json'] - MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783' +if TOKEN_MODE == "pile": + WORD_NAME = ["20B_tokenizer.json", "20B_tokenizer.json"] + MODEL_NAME = "/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783" n_layer = 32 n_embd = 2560 ctx_len = 1024 UNKNOWN_CHAR = None from src.utils import TOKENIZER + tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) -if TOKEN_MODE == 'pile': +if TOKEN_MODE == "pile": tokenizer.vocab_size = 50277 ######################################################################################################## @@ -54,23 +56,23 @@ args.my_pos_emb = 0 model_train = RWKV(args).to(RUN_DEVICE) -if os.environ['RWKV_FLOAT_MODE'] == 'fp16': +if os.environ["RWKV_FLOAT_MODE"] == "fp16": model_train = model_train.half() -elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': +elif os.environ["RWKV_FLOAT_MODE"] == "bf16": model_train = model_train.bfloat16() -print('loading ' + MODEL_NAME) -m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu') +print("loading " + MODEL_NAME) +m2 = torch.load(MODEL_NAME + ".pth", map_location="cpu") model_train.load_state_dict(m2) -if os.environ['RWKV_FLOAT_MODE'] == 'fp16': +if os.environ["RWKV_FLOAT_MODE"] == "fp16": model_train = model_train.half() -elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': +elif os.environ["RWKV_FLOAT_MODE"] == "bf16": model_train = model_train.bfloat16() args.MODEL_NAME = MODEL_NAME args.RUN_DEVICE = RUN_DEVICE -args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE'] +args.FLOAT_MODE = os.environ["RWKV_FLOAT_MODE"] model_rnn = RWKV_RNN(args) ######################################################################################################## @@ -78,27 +80,33 @@ print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}") # context = '\nIn a' -context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.' +context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." -if TOKEN_MODE == 'pile': +if TOKEN_MODE == "pile": ctx = tokenizer.tokenizer.encode(context) -print(f'input len {len(ctx)} data {ctx}') +print(f"input len {len(ctx)} data {ctx}") ######################################################################################################## with torch.no_grad(): - print('\nRWKV-train output') - out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy() - print(out, '\n') - - print('\nRWKV-RNN output') + print("\nRWKV-train output") + out = ( + model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0] + .detach() + .cpu() + .float() + .numpy() + ) + print(out, "\n") + + print("\nRWKV-RNN output") state = None out = None src_len = len(ctx) for i in range(src_len): - x = ctx[:i+1] + x = ctx[: i + 1] out, state = model_rnn.forward(x, state) if i < 3 or i >= src_len - 3: print(out.detach().cpu().numpy()) if i == 2: - print('...') + print("...") diff --git a/benchmarks/stargan/stargan/data_loader.py b/benchmarks/stargan/stargan/data_loader.py index d0c5eacb8..2f79594c6 100644 --- a/benchmarks/stargan/stargan/data_loader.py +++ b/benchmarks/stargan/stargan/data_loader.py @@ -23,14 +23,14 @@ def __init__(self, image_dir, attr_path, selected_attrs, transform, mode): self.idx2attr = {} self.preprocess() - if mode == 'train': + if mode == "train": self.num_images = len(self.train_dataset) else: self.num_images = len(self.test_dataset) def preprocess(self): """Preprocess the CelebA attribute file.""" - lines = [line.rstrip() for line in open(self.attr_path, 'r')] + lines = [line.rstrip() for line in open(self.attr_path, "r")] all_attr_names = lines[1].split() for i, attr_name in enumerate(all_attr_names): self.attr2idx[attr_name] = i @@ -47,18 +47,18 @@ def preprocess(self): label = [] for attr_name in self.selected_attrs: idx = self.attr2idx[attr_name] - label.append(values[idx] == '1') + label.append(values[idx] == "1") - if (i+1) < 2000: + if (i + 1) < 2000: self.test_dataset.append([filename, label]) else: self.train_dataset.append([filename, label]) - print('Finished preprocessing the CelebA dataset...') + print("Finished preprocessing the CelebA dataset...") def __getitem__(self, index): """Return one image and its corresponding attribute label.""" - dataset = self.train_dataset if self.mode == 'train' else self.test_dataset + dataset = self.train_dataset if self.mode == "train" else self.test_dataset filename, label = dataset[index] image = Image.open(os.path.join(self.image_dir, filename)) return self.transform(image), torch.FloatTensor(label) @@ -68,11 +68,20 @@ def __len__(self): return self.num_images -def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=128, - batch_size=16, dataset='CelebA', mode='train', num_workers=1): +def get_loader( + image_dir, + attr_path, + selected_attrs, + crop_size=178, + image_size=128, + batch_size=16, + dataset="CelebA", + mode="train", + num_workers=1, +): """Build and return a data loader.""" transform = [] - if mode == 'train': + if mode == "train": transform.append(T.RandomHorizontalFlip()) transform.append(T.CenterCrop(crop_size)) transform.append(T.Resize(image_size)) @@ -80,13 +89,15 @@ def get_loader(image_dir, attr_path, selected_attrs, crop_size=178, image_size=1 transform.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))) transform = T.Compose(transform) - if dataset == 'CelebA': + if dataset == "CelebA": dataset = CelebA(image_dir, attr_path, selected_attrs, transform, mode) - elif dataset == 'RaFD': + elif dataset == "RaFD": dataset = ImageFolder(image_dir, transform) - data_loader = data.DataLoader(dataset=dataset, - batch_size=batch_size, - shuffle=(mode=='train'), - num_workers=num_workers) - return data_loader \ No newline at end of file + data_loader = data.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=(mode == "train"), + num_workers=num_workers, + ) + return data_loader diff --git a/benchmarks/stargan/stargan/logger.py b/benchmarks/stargan/stargan/logger.py index f30431e8b..ffed8a260 100644 --- a/benchmarks/stargan/stargan/logger.py +++ b/benchmarks/stargan/stargan/logger.py @@ -11,4 +11,4 @@ def __init__(self, log_dir): def scalar_summary(self, tag, value, step): """Add scalar summary.""" summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) - self.writer.add_summary(summary, step) \ No newline at end of file + self.writer.add_summary(summary, step) diff --git a/benchmarks/stargan/stargan/main.py b/benchmarks/stargan/stargan/main.py index 754c7efa2..d7b411fdc 100644 --- a/benchmarks/stargan/stargan/main.py +++ b/benchmarks/stargan/stargan/main.py @@ -9,9 +9,9 @@ from torch.utils.data import DataLoader - def str2bool(v): - return v.lower() in ('true') + return v.lower() in ("true") + def main(config): # For fast training. @@ -28,15 +28,32 @@ def main(config): rafd_loader = None synth_loader = None - if config.dataset in ['CelebA', 'Both']: - celeba_loader = get_loader(config.celeba_image_dir, config.attr_path, config.selected_attrs, - config.celeba_crop_size, config.image_size, config.batch_size, - 'CelebA', config.mode, config.num_workers) - if config.dataset in ['RaFD', 'Both']: - rafd_loader = get_loader(config.rafd_image_dir, None, None, - config.rafd_crop_size, config.image_size, config.batch_size, - 'RaFD', config.mode, config.num_workers) + if config.dataset in ["CelebA", "Both"]: + celeba_loader = get_loader( + config.celeba_image_dir, + config.attr_path, + config.selected_attrs, + config.celeba_crop_size, + config.image_size, + config.batch_size, + "CelebA", + config.mode, + config.num_workers, + ) + if config.dataset in ["RaFD", "Both"]: + rafd_loader = get_loader( + config.rafd_image_dir, + None, + None, + config.rafd_crop_size, + config.image_size, + config.batch_size, + "RaFD", + config.mode, + config.num_workers, + ) if config.dataset == "synth": + def igen(): return torch.rand((3, config.image_size, config.image_size)) * 2 - 1 @@ -48,81 +65,158 @@ def ogen(): n=config.batch_size, repeat=10000, ) - synth_loader = DataLoader(synth_dataset, batch_size=config.batch_size, num_workers=config.num_workers) - + synth_loader = DataLoader( + synth_dataset, batch_size=config.batch_size, num_workers=config.num_workers + ) # Solver for training and testing StarGAN. solver = Solver(celeba_loader, rafd_loader, synth_loader, config) - if config.mode == 'train': - if config.dataset in ['CelebA', 'RaFD', 'synth']: + if config.mode == "train": + if config.dataset in ["CelebA", "RaFD", "synth"]: solver.train() - elif config.dataset in ['Both']: + elif config.dataset in ["Both"]: solver.train_multi() - elif config.mode == 'test': - if config.dataset in ['CelebA', 'RaFD', 'synth']: + elif config.mode == "test": + if config.dataset in ["CelebA", "RaFD", "synth"]: solver.test() - elif config.dataset in ['Both']: + elif config.dataset in ["Both"]: solver.test_multi() -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() # Model configuration. - parser.add_argument('--c_dim', type=int, default=5, help='dimension of domain labels (1st dataset)') - parser.add_argument('--c2_dim', type=int, default=8, help='dimension of domain labels (2nd dataset)') - parser.add_argument('--celeba_crop_size', type=int, default=178, help='crop size for the CelebA dataset') - parser.add_argument('--rafd_crop_size', type=int, default=256, help='crop size for the RaFD dataset') - parser.add_argument('--image_size', type=int, default=128, help='image resolution') - parser.add_argument('--g_conv_dim', type=int, default=64, help='number of conv filters in the first layer of G') - parser.add_argument('--d_conv_dim', type=int, default=64, help='number of conv filters in the first layer of D') - parser.add_argument('--g_repeat_num', type=int, default=6, help='number of residual blocks in G') - parser.add_argument('--d_repeat_num', type=int, default=6, help='number of strided conv layers in D') - parser.add_argument('--lambda_cls', type=float, default=1, help='weight for domain classification loss') - parser.add_argument('--lambda_rec', type=float, default=10, help='weight for reconstruction loss') - parser.add_argument('--lambda_gp', type=float, default=10, help='weight for gradient penalty') - + parser.add_argument( + "--c_dim", type=int, default=5, help="dimension of domain labels (1st dataset)" + ) + parser.add_argument( + "--c2_dim", type=int, default=8, help="dimension of domain labels (2nd dataset)" + ) + parser.add_argument( + "--celeba_crop_size", + type=int, + default=178, + help="crop size for the CelebA dataset", + ) + parser.add_argument( + "--rafd_crop_size", type=int, default=256, help="crop size for the RaFD dataset" + ) + parser.add_argument("--image_size", type=int, default=128, help="image resolution") + parser.add_argument( + "--g_conv_dim", + type=int, + default=64, + help="number of conv filters in the first layer of G", + ) + parser.add_argument( + "--d_conv_dim", + type=int, + default=64, + help="number of conv filters in the first layer of D", + ) + parser.add_argument( + "--g_repeat_num", type=int, default=6, help="number of residual blocks in G" + ) + parser.add_argument( + "--d_repeat_num", type=int, default=6, help="number of strided conv layers in D" + ) + parser.add_argument( + "--lambda_cls", + type=float, + default=1, + help="weight for domain classification loss", + ) + parser.add_argument( + "--lambda_rec", type=float, default=10, help="weight for reconstruction loss" + ) + parser.add_argument( + "--lambda_gp", type=float, default=10, help="weight for gradient penalty" + ) + # Training configuration. - parser.add_argument('--dataset', type=str, default='synth', choices=['CelebA', 'RaFD', 'Both', 'synth']) - parser.add_argument('--batch_size', type=int, default=16, help='mini-batch size') - parser.add_argument('--num_iters', type=int, default=200000, help='number of total iterations for training D') - parser.add_argument('--num_iters_decay', type=int, default=100000, help='number of iterations for decaying lr') - parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for G') - parser.add_argument('--d_lr', type=float, default=0.0001, help='learning rate for D') - parser.add_argument('--n_critic', type=int, default=5, help='number of D updates per each G update') - parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer') - parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer') - parser.add_argument('--resume_iters', type=int, default=None, help='resume training from this step') - parser.add_argument('--selected_attrs', '--list', nargs='+', help='selected attributes for the CelebA dataset', - default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']) + parser.add_argument( + "--dataset", + type=str, + default="synth", + choices=["CelebA", "RaFD", "Both", "synth"], + ) + parser.add_argument("--batch_size", type=int, default=16, help="mini-batch size") + parser.add_argument( + "--num_iters", + type=int, + default=200000, + help="number of total iterations for training D", + ) + parser.add_argument( + "--num_iters_decay", + type=int, + default=100000, + help="number of iterations for decaying lr", + ) + parser.add_argument( + "--g_lr", type=float, default=0.0001, help="learning rate for G" + ) + parser.add_argument( + "--d_lr", type=float, default=0.0001, help="learning rate for D" + ) + parser.add_argument( + "--n_critic", type=int, default=5, help="number of D updates per each G update" + ) + parser.add_argument( + "--beta1", type=float, default=0.5, help="beta1 for Adam optimizer" + ) + parser.add_argument( + "--beta2", type=float, default=0.999, help="beta2 for Adam optimizer" + ) + parser.add_argument( + "--resume_iters", type=int, default=None, help="resume training from this step" + ) + parser.add_argument( + "--selected_attrs", + "--list", + nargs="+", + help="selected attributes for the CelebA dataset", + default=["Black_Hair", "Blond_Hair", "Brown_Hair", "Male", "Young"], + ) # Test configuration. - parser.add_argument('--test_iters', type=int, default=200000, help='test model from this step') + parser.add_argument( + "--test_iters", type=int, default=200000, help="test model from this step" + ) # Miscellaneous. - parser.add_argument('--num_workers', type=int, default=1) - parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) - parser.add_argument('--use_tensorboard', type=str2bool, default=False) + parser.add_argument("--num_workers", type=int, default=1) + parser.add_argument("--mode", type=str, default="train", choices=["train", "test"]) + parser.add_argument("--use_tensorboard", type=str2bool, default=False) mbconfig = json.loads(os.environ["MILABENCH_CONFIG"]) datadir = mbconfig["dirs"]["extra"] # Directories. - parser.add_argument('--celeba_image_dir', type=str, default='data/celeba/images') - parser.add_argument('--attr_path', type=str, default='data/celeba/list_attr_celeba.txt') - parser.add_argument('--rafd_image_dir', type=str, default='data/RaFD/train') - parser.add_argument('--log_dir', type=str, default=os.path.join(datadir, 'logs')) - parser.add_argument('--model_save_dir', type=str, default=os.path.join(datadir, 'models')) - parser.add_argument('--sample_dir', type=str, default=os.path.join(datadir, 'samples')) - parser.add_argument('--result_dir', type=str, default=os.path.join(datadir, 'results')) + parser.add_argument("--celeba_image_dir", type=str, default="data/celeba/images") + parser.add_argument( + "--attr_path", type=str, default="data/celeba/list_attr_celeba.txt" + ) + parser.add_argument("--rafd_image_dir", type=str, default="data/RaFD/train") + parser.add_argument("--log_dir", type=str, default=os.path.join(datadir, "logs")) + parser.add_argument( + "--model_save_dir", type=str, default=os.path.join(datadir, "models") + ) + parser.add_argument( + "--sample_dir", type=str, default=os.path.join(datadir, "samples") + ) + parser.add_argument( + "--result_dir", type=str, default=os.path.join(datadir, "results") + ) # Step size. - parser.add_argument('--log_step', type=int, default=10) - parser.add_argument('--sample_step', type=int, default=1000) - parser.add_argument('--model_save_step', type=int, default=10000) - parser.add_argument('--lr_update_step', type=int, default=1000) + parser.add_argument("--log_step", type=int, default=10) + parser.add_argument("--sample_step", type=int, default=1000) + parser.add_argument("--model_save_step", type=int, default=10000) + parser.add_argument("--lr_update_step", type=int, default=1000) config = parser.parse_args() print(config) - main(config) \ No newline at end of file + main(config) diff --git a/benchmarks/stargan/stargan/model.py b/benchmarks/stargan/stargan/model.py index 3d0e62755..a9ecb43e3 100644 --- a/benchmarks/stargan/stargan/model.py +++ b/benchmarks/stargan/stargan/model.py @@ -6,6 +6,7 @@ class ResidualBlock(nn.Module): """Residual Block with instance normalization.""" + def __init__(self, dim_in, dim_out): super(ResidualBlock, self).__init__() self.main = nn.Sequential( @@ -13,7 +14,8 @@ def __init__(self, dim_in, dim_out): nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True), nn.ReLU(inplace=True), nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False), - nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True)) + nn.InstanceNorm2d(dim_out, affine=True, track_running_stats=True), + ) def forward(self, x): return x + self.main(x) @@ -21,19 +23,37 @@ def forward(self, x): class Generator(nn.Module): """Generator network.""" + def __init__(self, conv_dim=64, c_dim=5, repeat_num=6): super(Generator, self).__init__() layers = [] - layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)) - layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True)) + layers.append( + nn.Conv2d( + 3 + c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False + ) + ) + layers.append( + nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True) + ) layers.append(nn.ReLU(inplace=True)) # Down-sampling layers. curr_dim = conv_dim for i in range(2): - layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False)) - layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True, track_running_stats=True)) + layers.append( + nn.Conv2d( + curr_dim, + curr_dim * 2, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ) + ) + layers.append( + nn.InstanceNorm2d(curr_dim * 2, affine=True, track_running_stats=True) + ) layers.append(nn.ReLU(inplace=True)) curr_dim = curr_dim * 2 @@ -43,12 +63,25 @@ def __init__(self, conv_dim=64, c_dim=5, repeat_num=6): # Up-sampling layers. for i in range(2): - layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False)) - layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True)) + layers.append( + nn.ConvTranspose2d( + curr_dim, + curr_dim // 2, + kernel_size=4, + stride=2, + padding=1, + bias=False, + ) + ) + layers.append( + nn.InstanceNorm2d(curr_dim // 2, affine=True, track_running_stats=True) + ) layers.append(nn.ReLU(inplace=True)) curr_dim = curr_dim // 2 - layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False)) + layers.append( + nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False) + ) layers.append(nn.Tanh()) self.main = nn.Sequential(*layers) @@ -64,6 +97,7 @@ def forward(self, x, c): class Discriminator(nn.Module): """Discriminator network with PatchGAN.""" + def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6): super(Discriminator, self).__init__() layers = [] @@ -72,15 +106,19 @@ def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6): curr_dim = conv_dim for i in range(1, repeat_num): - layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1)) + layers.append( + nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1) + ) layers.append(nn.LeakyReLU(0.01)) curr_dim = curr_dim * 2 kernel_size = int(image_size / np.power(2, repeat_num)) self.main = nn.Sequential(*layers) - self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False) + self.conv1 = nn.Conv2d( + curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False + ) self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=kernel_size, bias=False) - + def forward(self, x): h = self.main(x) out_src = self.conv1(h) diff --git a/benchmarks/stargan/stargan/solver.py b/benchmarks/stargan/stargan/solver.py index 00ee93cd9..d45bb6f9e 100644 --- a/benchmarks/stargan/stargan/solver.py +++ b/benchmarks/stargan/stargan/solver.py @@ -53,7 +53,7 @@ def __init__(self, celeba_loader, rafd_loader, synth_loader, config): # Miscellaneous. self.use_tensorboard = config.use_tensorboard - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Directories. self.log_dir = config.log_dir @@ -74,18 +74,31 @@ def __init__(self, celeba_loader, rafd_loader, synth_loader, config): def build_model(self): """Create a generator and a discriminator.""" - if self.dataset in ['CelebA', 'RaFD', 'synth']: + if self.dataset in ["CelebA", "RaFD", "synth"]: self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num) - self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num) - elif self.dataset in ['Both']: - self.G = Generator(self.g_conv_dim, self.c_dim+self.c2_dim+2, self.g_repeat_num) # 2 for mask vector. - self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim+self.c2_dim, self.d_repeat_num) - - self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2]) - self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2]) - self.print_network(self.G, 'G') - self.print_network(self.D, 'D') - + self.D = Discriminator( + self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num + ) + elif self.dataset in ["Both"]: + self.G = Generator( + self.g_conv_dim, self.c_dim + self.c2_dim + 2, self.g_repeat_num + ) # 2 for mask vector. + self.D = Discriminator( + self.image_size, + self.d_conv_dim, + self.c_dim + self.c2_dim, + self.d_repeat_num, + ) + + self.g_optimizer = torch.optim.Adam( + self.G.parameters(), self.g_lr, [self.beta1, self.beta2] + ) + self.d_optimizer = torch.optim.Adam( + self.D.parameters(), self.d_lr, [self.beta1, self.beta2] + ) + self.print_network(self.G, "G") + self.print_network(self.D, "D") + self.G.to(self.device) self.D.to(self.device) @@ -100,23 +113,28 @@ def print_network(self, model, name): def restore_model(self, resume_iters): """Restore the trained generator and discriminator.""" - print('Loading the trained models from step {}...'.format(resume_iters)) - G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters)) - D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters)) - self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage)) - self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage)) + print("Loading the trained models from step {}...".format(resume_iters)) + G_path = os.path.join(self.model_save_dir, "{}-G.ckpt".format(resume_iters)) + D_path = os.path.join(self.model_save_dir, "{}-D.ckpt".format(resume_iters)) + self.G.load_state_dict( + torch.load(G_path, map_location=lambda storage, loc: storage) + ) + self.D.load_state_dict( + torch.load(D_path, map_location=lambda storage, loc: storage) + ) def build_tensorboard(self): """Build a tensorboard logger.""" from logger import Logger + self.logger = Logger(self.log_dir) def update_lr(self, g_lr, d_lr): """Decay learning rates of the generator and discriminator.""" for param_group in self.g_optimizer.param_groups: - param_group['lr'] = g_lr + param_group["lr"] = g_lr for param_group in self.d_optimizer.param_groups: - param_group['lr'] = d_lr + param_group["lr"] = d_lr def reset_grad(self): """Reset the gradient buffers.""" @@ -131,16 +149,18 @@ def denorm(self, x): def gradient_penalty(self, y, x): """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.""" weight = torch.ones(y.size()).to(self.device) - dydx = torch.autograd.grad(outputs=y, - inputs=x, - grad_outputs=weight, - retain_graph=True, - create_graph=True, - only_inputs=True)[0] + dydx = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=weight, + retain_graph=True, + create_graph=True, + only_inputs=True, + )[0] dydx = dydx.view(dydx.size(0), -1) dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1)) - return torch.mean((dydx_l2norm-1)**2) + return torch.mean((dydx_l2norm - 1) ** 2) def label2onehot(self, labels, dim): """Convert label indices to one-hot vectors.""" @@ -149,54 +169,60 @@ def label2onehot(self, labels, dim): out[np.arange(batch_size), labels.long()] = 1 return out - def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None): + def create_labels(self, c_org, c_dim=5, dataset="CelebA", selected_attrs=None): """Generate target domain labels for debugging and testing.""" # Get hair color indices. - if dataset == 'CelebA': + if dataset == "CelebA": hair_color_indices = [] for i, attr_name in enumerate(selected_attrs): - if attr_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']: + if attr_name in ["Black_Hair", "Blond_Hair", "Brown_Hair", "Gray_Hair"]: hair_color_indices.append(i) c_trg_list = [] for i in range(c_dim): - if dataset == 'CelebA': + if dataset == "CelebA": c_trg = c_org.clone() - if i in hair_color_indices: # Set one hair color to 1 and the rest to 0. + if ( + i in hair_color_indices + ): # Set one hair color to 1 and the rest to 0. c_trg[:, i] = 1 for j in hair_color_indices: if j != i: c_trg[:, j] = 0 else: - c_trg[:, i] = (c_trg[:, i] == 0) # Reverse attribute value. - elif dataset == 'RaFD' or dataset == "synth": - c_trg = self.label2onehot(torch.ones(c_org.size(0))*i, c_dim) + c_trg[:, i] = c_trg[:, i] == 0 # Reverse attribute value. + elif dataset == "RaFD" or dataset == "synth": + c_trg = self.label2onehot(torch.ones(c_org.size(0)) * i, c_dim) c_trg_list.append(c_trg.to(self.device)) return c_trg_list - def classification_loss(self, logit, target, dataset='CelebA'): + def classification_loss(self, logit, target, dataset="CelebA"): """Compute binary or softmax cross entropy loss.""" - if dataset == 'CelebA' or dataset == "synth": - return F.binary_cross_entropy_with_logits(logit, target, size_average=False) / logit.size(0) - elif dataset == 'RaFD': + if dataset == "CelebA" or dataset == "synth": + return F.binary_cross_entropy_with_logits( + logit, target, size_average=False + ) / logit.size(0) + elif dataset == "RaFD": return F.cross_entropy(logit, target) def train(self): """Train StarGAN within a single dataset.""" # Set data loader. - if self.dataset == 'CelebA': + if self.dataset == "CelebA": data_loader = self.celeba_loader - elif self.dataset == 'RaFD': + elif self.dataset == "RaFD": data_loader = self.rafd_loader - elif self.dataset == 'synth': + elif self.dataset == "synth": data_loader = self.synth_loader # Fetch fixed inputs for debugging. data_iter = voir.iterate("train", data_loader, report_batch=True) x_fixed, c_org = next(data_iter) x_fixed = x_fixed.to(self.device) - c_fixed_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs) + c_fixed_list = self.create_labels( + c_org, self.c_dim, self.dataset, self.selected_attrs + ) # Learning rate cache for decaying. g_lr = self.g_lr @@ -209,10 +235,9 @@ def train(self): self.restore_model(self.resume_iters) # Start training. - print('Start training...') + print("Start training...") start_time = time.time() for i in range(start_iters, self.num_iters): - # =================================================================================== # # 1. Preprocess input data # # =================================================================================== # @@ -228,18 +253,22 @@ def train(self): rand_idx = torch.randperm(label_org.size(0)) label_trg = label_org[rand_idx] - if self.dataset == 'CelebA' or self.dataset == 'synth': + if self.dataset == "CelebA" or self.dataset == "synth": c_org = label_org.clone() c_trg = label_trg.clone() - elif self.dataset == 'RaFD': + elif self.dataset == "RaFD": c_org = self.label2onehot(label_org, self.c_dim) c_trg = self.label2onehot(label_trg, self.c_dim) - x_real = x_real.to(self.device) # Input images. - c_org = c_org.to(self.device) # Original domain labels. - c_trg = c_trg.to(self.device) # Target domain labels. - label_org = label_org.to(self.device) # Labels for computing classification loss. - label_trg = label_trg.to(self.device) # Labels for computing classification loss. + x_real = x_real.to(self.device) # Input images. + c_org = c_org.to(self.device) # Original domain labels. + c_trg = c_trg.to(self.device) # Target domain labels. + label_org = label_org.to( + self.device + ) # Labels for computing classification loss. + label_trg = label_trg.to( + self.device + ) # Labels for computing classification loss. # =================================================================================== # # 2. Train the discriminator # @@ -247,7 +276,7 @@ def train(self): # Compute loss with real images. out_src, out_cls = self.D(x_real) - d_loss_real = - torch.mean(out_src) + d_loss_real = -torch.mean(out_src) d_loss_cls = self.classification_loss(out_cls, label_org, self.dataset) # Compute loss with fake images. @@ -257,12 +286,19 @@ def train(self): # Compute loss for gradient penalty. alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device) - x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True) + x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_( + True + ) out_src, _ = self.D(x_hat) d_loss_gp = self.gradient_penalty(out_src, x_hat) # Backward and optimize. - d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp + d_loss = ( + d_loss_real + + d_loss_fake + + self.lambda_cls * d_loss_cls + + self.lambda_gp * d_loss_gp + ) give(task="train", loss=d_loss.item()) self.reset_grad() d_loss.backward() @@ -270,20 +306,20 @@ def train(self): # Logging. loss = {} - loss['D/loss_real'] = d_loss_real.item() - loss['D/loss_fake'] = d_loss_fake.item() - loss['D/loss_cls'] = d_loss_cls.item() - loss['D/loss_gp'] = d_loss_gp.item() - + loss["D/loss_real"] = d_loss_real.item() + loss["D/loss_fake"] = d_loss_fake.item() + loss["D/loss_cls"] = d_loss_cls.item() + loss["D/loss_gp"] = d_loss_gp.item() + # =================================================================================== # # 3. Train the generator # # =================================================================================== # - - if (i+1) % self.n_critic == 0: + + if (i + 1) % self.n_critic == 0: # Original-to-target domain. x_fake = self.G(x_real, c_trg) out_src, out_cls = self.D(x_fake) - g_loss_fake = - torch.mean(out_src) + g_loss_fake = -torch.mean(out_src) g_loss_cls = self.classification_loss(out_cls, label_trg, self.dataset) # Target-to-original domain. @@ -291,61 +327,73 @@ def train(self): g_loss_rec = torch.mean(torch.abs(x_real - x_reconst)) # Backward and optimize. - g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls + g_loss = ( + g_loss_fake + + self.lambda_rec * g_loss_rec + + self.lambda_cls * g_loss_cls + ) self.reset_grad() g_loss.backward() self.g_optimizer.step() # Logging. - loss['G/loss_fake'] = g_loss_fake.item() - loss['G/loss_rec'] = g_loss_rec.item() - loss['G/loss_cls'] = g_loss_cls.item() + loss["G/loss_fake"] = g_loss_fake.item() + loss["G/loss_rec"] = g_loss_rec.item() + loss["G/loss_cls"] = g_loss_cls.item() # =================================================================================== # # 4. Miscellaneous # # =================================================================================== # # Print out training information. - if (i+1) % self.log_step == 0: + if (i + 1) % self.log_step == 0: et = time.time() - start_time et = str(datetime.timedelta(seconds=et))[:-7] - log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters) + log = "Elapsed [{}], Iteration [{}/{}]".format( + et, i + 1, self.num_iters + ) for tag, value in loss.items(): log += ", {}: {:.4f}".format(tag, value) print(log) if self.use_tensorboard: for tag, value in loss.items(): - self.logger.scalar_summary(tag, value, i+1) + self.logger.scalar_summary(tag, value, i + 1) # Translate fixed images for debugging. - if (i+1) % self.sample_step == 0: + if (i + 1) % self.sample_step == 0: with torch.no_grad(): x_fake_list = [x_fixed] for c_fixed in c_fixed_list: x_fake_list.append(self.G(x_fixed, c_fixed)) x_concat = torch.cat(x_fake_list, dim=3) - sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1)) - save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0) - print('Saved real and fake images into {}...'.format(sample_path)) + sample_path = os.path.join( + self.sample_dir, "{}-images.jpg".format(i + 1) + ) + save_image( + self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0 + ) + print("Saved real and fake images into {}...".format(sample_path)) # Save model checkpoints. - if (i+1) % self.model_save_step == 0: - G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1)) - D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1)) + if (i + 1) % self.model_save_step == 0: + G_path = os.path.join(self.model_save_dir, "{}-G.ckpt".format(i + 1)) + D_path = os.path.join(self.model_save_dir, "{}-D.ckpt".format(i + 1)) torch.save(self.G.state_dict(), G_path) torch.save(self.D.state_dict(), D_path) - print('Saved model checkpoints into {}...'.format(self.model_save_dir)) + print("Saved model checkpoints into {}...".format(self.model_save_dir)) # Decay learning rates. - if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay): - g_lr -= (self.g_lr / float(self.num_iters_decay)) - d_lr -= (self.d_lr / float(self.num_iters_decay)) + if (i + 1) % self.lr_update_step == 0 and (i + 1) > ( + self.num_iters - self.num_iters_decay + ): + g_lr -= self.g_lr / float(self.num_iters_decay) + d_lr -= self.d_lr / float(self.num_iters_decay) self.update_lr(g_lr, d_lr) - print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) + print("Decayed learning rates, g_lr: {}, d_lr: {}.".format(g_lr, d_lr)) def train_multi(self): - """Train StarGAN with multiple datasets.""" + """Train StarGAN with multiple datasets.""" # Data iterators. celeba_iter = iter(self.celeba_loader) rafd_iter = iter(self.rafd_loader) @@ -353,12 +401,22 @@ def train_multi(self): # Fetch fixed inputs for debugging. x_fixed, c_org = next(celeba_iter) x_fixed = x_fixed.to(self.device) - c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs) - c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD') - zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to(self.device) # Zero vector for CelebA. - zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD. - mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to(self.device) # Mask vector: [1, 0]. - mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to(self.device) # Mask vector: [0, 1]. + c_celeba_list = self.create_labels( + c_org, self.c_dim, "CelebA", self.selected_attrs + ) + c_rafd_list = self.create_labels(c_org, self.c2_dim, "RaFD") + zero_celeba = torch.zeros(x_fixed.size(0), self.c_dim).to( + self.device + ) # Zero vector for CelebA. + zero_rafd = torch.zeros(x_fixed.size(0), self.c2_dim).to( + self.device + ) # Zero vector for RaFD. + mask_celeba = self.label2onehot(torch.zeros(x_fixed.size(0)), 2).to( + self.device + ) # Mask vector: [1, 0]. + mask_rafd = self.label2onehot(torch.ones(x_fixed.size(0)), 2).to( + self.device + ) # Mask vector: [0, 1]. # Learning rate cache for decaying. g_lr = self.g_lr @@ -371,25 +429,24 @@ def train_multi(self): self.restore_model(self.resume_iters) # Start training. - print('Start training...') + print("Start training...") start_time = time.time() for i in range(start_iters, self.num_iters): - for dataset in ['CelebA', 'RaFD']: - + for dataset in ["CelebA", "RaFD"]: # =================================================================================== # # 1. Preprocess input data # # =================================================================================== # - + # Fetch real images and labels. - data_iter = celeba_iter if dataset == 'CelebA' else rafd_iter - + data_iter = celeba_iter if dataset == "CelebA" else rafd_iter + try: x_real, label_org = next(data_iter) except: - if dataset == 'CelebA': + if dataset == "CelebA": celeba_iter = iter(self.celeba_loader) x_real, label_org = next(celeba_iter) - elif dataset == 'RaFD': + elif dataset == "RaFD": rafd_iter = iter(self.rafd_loader) x_real, label_org = next(rafd_iter) @@ -397,14 +454,14 @@ def train_multi(self): rand_idx = torch.randperm(label_org.size(0)) label_trg = label_org[rand_idx] - if dataset == 'CelebA': + if dataset == "CelebA": c_org = label_org.clone() c_trg = label_trg.clone() zero = torch.zeros(x_real.size(0), self.c2_dim) mask = self.label2onehot(torch.zeros(x_real.size(0)), 2) c_org = torch.cat([c_org, zero, mask], dim=1) c_trg = torch.cat([c_trg, zero, mask], dim=1) - elif dataset == 'RaFD': + elif dataset == "RaFD": c_org = self.label2onehot(label_org, self.c2_dim) c_trg = self.label2onehot(label_trg, self.c2_dim) zero = torch.zeros(x_real.size(0), self.c_dim) @@ -412,11 +469,15 @@ def train_multi(self): c_org = torch.cat([zero, c_org, mask], dim=1) c_trg = torch.cat([zero, c_trg, mask], dim=1) - x_real = x_real.to(self.device) # Input images. - c_org = c_org.to(self.device) # Original domain labels. - c_trg = c_trg.to(self.device) # Target domain labels. - label_org = label_org.to(self.device) # Labels for computing classification loss. - label_trg = label_trg.to(self.device) # Labels for computing classification loss. + x_real = x_real.to(self.device) # Input images. + c_org = c_org.to(self.device) # Original domain labels. + c_trg = c_trg.to(self.device) # Target domain labels. + label_org = label_org.to( + self.device + ) # Labels for computing classification loss. + label_trg = label_trg.to( + self.device + ) # Labels for computing classification loss. # =================================================================================== # # 2. Train the discriminator # @@ -424,8 +485,12 @@ def train_multi(self): # Compute loss with real images. out_src, out_cls = self.D(x_real) - out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:] - d_loss_real = - torch.mean(out_src) + out_cls = ( + out_cls[:, : self.c_dim] + if dataset == "CelebA" + else out_cls[:, self.c_dim :] + ) + d_loss_real = -torch.mean(out_src) d_loss_cls = self.classification_loss(out_cls, label_org, dataset) # Compute loss with fake images. @@ -435,33 +500,44 @@ def train_multi(self): # Compute loss for gradient penalty. alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device) - x_hat = (alpha * x_real.data + (1 - alpha) * x_fake.data).requires_grad_(True) + x_hat = ( + alpha * x_real.data + (1 - alpha) * x_fake.data + ).requires_grad_(True) out_src, _ = self.D(x_hat) d_loss_gp = self.gradient_penalty(out_src, x_hat) # Backward and optimize. - d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls + self.lambda_gp * d_loss_gp + d_loss = ( + d_loss_real + + d_loss_fake + + self.lambda_cls * d_loss_cls + + self.lambda_gp * d_loss_gp + ) self.reset_grad() d_loss.backward() self.d_optimizer.step() # Logging. loss = {} - loss['D/loss_real'] = d_loss_real.item() - loss['D/loss_fake'] = d_loss_fake.item() - loss['D/loss_cls'] = d_loss_cls.item() - loss['D/loss_gp'] = d_loss_gp.item() - + loss["D/loss_real"] = d_loss_real.item() + loss["D/loss_fake"] = d_loss_fake.item() + loss["D/loss_cls"] = d_loss_cls.item() + loss["D/loss_gp"] = d_loss_gp.item() + # =================================================================================== # # 3. Train the generator # # =================================================================================== # - if (i+1) % self.n_critic == 0: + if (i + 1) % self.n_critic == 0: # Original-to-target domain. x_fake = self.G(x_real, c_trg) out_src, out_cls = self.D(x_fake) - out_cls = out_cls[:, :self.c_dim] if dataset == 'CelebA' else out_cls[:, self.c_dim:] - g_loss_fake = - torch.mean(out_src) + out_cls = ( + out_cls[:, : self.c_dim] + if dataset == "CelebA" + else out_cls[:, self.c_dim :] + ) + g_loss_fake = -torch.mean(out_src) g_loss_cls = self.classification_loss(out_cls, label_trg, dataset) # Target-to-original domain. @@ -469,35 +545,41 @@ def train_multi(self): g_loss_rec = torch.mean(torch.abs(x_real - x_reconst)) # Backward and optimize. - g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls + g_loss = ( + g_loss_fake + + self.lambda_rec * g_loss_rec + + self.lambda_cls * g_loss_cls + ) self.reset_grad() g_loss.backward() self.g_optimizer.step() # Logging. - loss['G/loss_fake'] = g_loss_fake.item() - loss['G/loss_rec'] = g_loss_rec.item() - loss['G/loss_cls'] = g_loss_cls.item() + loss["G/loss_fake"] = g_loss_fake.item() + loss["G/loss_rec"] = g_loss_rec.item() + loss["G/loss_cls"] = g_loss_cls.item() # =================================================================================== # # 4. Miscellaneous # # =================================================================================== # # Print out training info. - if (i+1) % self.log_step == 0: + if (i + 1) % self.log_step == 0: et = time.time() - start_time et = str(datetime.timedelta(seconds=et))[:-7] - log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format(et, i+1, self.num_iters, dataset) + log = "Elapsed [{}], Iteration [{}/{}], Dataset [{}]".format( + et, i + 1, self.num_iters, dataset + ) for tag, value in loss.items(): log += ", {}: {:.4f}".format(tag, value) print(log) if self.use_tensorboard: for tag, value in loss.items(): - self.logger.scalar_summary(tag, value, i+1) + self.logger.scalar_summary(tag, value, i + 1) # Translate fixed images for debugging. - if (i+1) % self.sample_step == 0: + if (i + 1) % self.sample_step == 0: with torch.no_grad(): x_fake_list = [x_fixed] for c_fixed in c_celeba_list: @@ -507,42 +589,49 @@ def train_multi(self): c_trg = torch.cat([zero_celeba, c_fixed, mask_rafd], dim=1) x_fake_list.append(self.G(x_fixed, c_trg)) x_concat = torch.cat(x_fake_list, dim=3) - sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1)) - save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0) - print('Saved real and fake images into {}...'.format(sample_path)) + sample_path = os.path.join( + self.sample_dir, "{}-images.jpg".format(i + 1) + ) + save_image( + self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0 + ) + print("Saved real and fake images into {}...".format(sample_path)) # Save model checkpoints. - if (i+1) % self.model_save_step == 0: - G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1)) - D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1)) + if (i + 1) % self.model_save_step == 0: + G_path = os.path.join(self.model_save_dir, "{}-G.ckpt".format(i + 1)) + D_path = os.path.join(self.model_save_dir, "{}-D.ckpt".format(i + 1)) torch.save(self.G.state_dict(), G_path) torch.save(self.D.state_dict(), D_path) - print('Saved model checkpoints into {}...'.format(self.model_save_dir)) + print("Saved model checkpoints into {}...".format(self.model_save_dir)) # Decay learning rates. - if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay): - g_lr -= (self.g_lr / float(self.num_iters_decay)) - d_lr -= (self.d_lr / float(self.num_iters_decay)) + if (i + 1) % self.lr_update_step == 0 and (i + 1) > ( + self.num_iters - self.num_iters_decay + ): + g_lr -= self.g_lr / float(self.num_iters_decay) + d_lr -= self.d_lr / float(self.num_iters_decay) self.update_lr(g_lr, d_lr) - print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) + print("Decayed learning rates, g_lr: {}, d_lr: {}.".format(g_lr, d_lr)) def test(self): """Translate images using StarGAN trained on a single dataset.""" # Load the trained generator. self.restore_model(self.test_iters) - + # Set data loader. - if self.dataset == 'CelebA': + if self.dataset == "CelebA": data_loader = self.celeba_loader - elif self.dataset == 'RaFD': + elif self.dataset == "RaFD": data_loader = self.rafd_loader - + with torch.no_grad(): for i, (x_real, c_org) in enumerate(data_loader): - # Prepare input images and target domain labels. x_real = x_real.to(self.device) - c_trg_list = self.create_labels(c_org, self.c_dim, self.dataset, self.selected_attrs) + c_trg_list = self.create_labels( + c_org, self.c_dim, self.dataset, self.selected_attrs + ) # Translate images. x_fake_list = [x_real] @@ -551,26 +640,39 @@ def test(self): # Save the translated images. x_concat = torch.cat(x_fake_list, dim=3) - result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1)) - save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) - print('Saved real and fake images into {}...'.format(result_path)) + result_path = os.path.join( + self.result_dir, "{}-images.jpg".format(i + 1) + ) + save_image( + self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0 + ) + print("Saved real and fake images into {}...".format(result_path)) def test_multi(self): """Translate images using StarGAN trained on multiple datasets.""" # Load the trained generator. self.restore_model(self.test_iters) - + with torch.no_grad(): for i, (x_real, c_org) in enumerate(self.celeba_loader): - # Prepare input images and target domain labels. x_real = x_real.to(self.device) - c_celeba_list = self.create_labels(c_org, self.c_dim, 'CelebA', self.selected_attrs) - c_rafd_list = self.create_labels(c_org, self.c2_dim, 'RaFD') - zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to(self.device) # Zero vector for CelebA. - zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to(self.device) # Zero vector for RaFD. - mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to(self.device) # Mask vector: [1, 0]. - mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to(self.device) # Mask vector: [0, 1]. + c_celeba_list = self.create_labels( + c_org, self.c_dim, "CelebA", self.selected_attrs + ) + c_rafd_list = self.create_labels(c_org, self.c2_dim, "RaFD") + zero_celeba = torch.zeros(x_real.size(0), self.c_dim).to( + self.device + ) # Zero vector for CelebA. + zero_rafd = torch.zeros(x_real.size(0), self.c2_dim).to( + self.device + ) # Zero vector for RaFD. + mask_celeba = self.label2onehot(torch.zeros(x_real.size(0)), 2).to( + self.device + ) # Mask vector: [1, 0]. + mask_rafd = self.label2onehot(torch.ones(x_real.size(0)), 2).to( + self.device + ) # Mask vector: [0, 1]. # Translate images. x_fake_list = [x_real] @@ -583,6 +685,10 @@ def test_multi(self): # Save the translated images. x_concat = torch.cat(x_fake_list, dim=3) - result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1)) - save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0) - print('Saved real and fake images into {}...'.format(result_path)) \ No newline at end of file + result_path = os.path.join( + self.result_dir, "{}-images.jpg".format(i + 1) + ) + save_image( + self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0 + ) + print("Saved real and fake images into {}...".format(result_path)) diff --git a/benchmarks/super-slomo/slomo/data/create_dataset.py b/benchmarks/super-slomo/slomo/data/create_dataset.py index 29e7eee17..dfde2169c 100644 --- a/benchmarks/super-slomo/slomo/data/create_dataset.py +++ b/benchmarks/super-slomo/slomo/data/create_dataset.py @@ -6,13 +6,33 @@ # For parsing commandline arguments parser = argparse.ArgumentParser() -parser.add_argument("--ffmpeg_dir", type=str, required=True, help='path to ffmpeg.exe') -parser.add_argument("--dataset", type=str, default="custom", help='specify if using "adobe240fps" or custom video dataset') -parser.add_argument("--videos_folder", type=str, required=True, help='path to the folder containing videos') -parser.add_argument("--dataset_folder", type=str, required=True, help='path to the output dataset folder') +parser.add_argument("--ffmpeg_dir", type=str, required=True, help="path to ffmpeg.exe") +parser.add_argument( + "--dataset", + type=str, + default="custom", + help='specify if using "adobe240fps" or custom video dataset', +) +parser.add_argument( + "--videos_folder", + type=str, + required=True, + help="path to the folder containing videos", +) +parser.add_argument( + "--dataset_folder", + type=str, + required=True, + help="path to the output dataset folder", +) parser.add_argument("--img_width", type=int, default=640, help="output image width") parser.add_argument("--img_height", type=int, default=360, help="output image height") -parser.add_argument("--train_test_split", type=tuple, default=(90, 10), help="train test split for custom dataset") +parser.add_argument( + "--train_test_split", + type=tuple, + default=(90, 10), + help="train test split for custom dataset", +) args = parser.parse_args() @@ -34,10 +54,17 @@ def extract_frames(videos, inDir, outDir): None """ - for video in videos: os.mkdir(os.path.join(outDir, os.path.splitext(video)[0])) - retn = os.system('{} -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%04d.jpg'.format(os.path.join(args.ffmpeg_dir, "ffmpeg"), os.path.join(inDir, video), args.img_width, args.img_height, os.path.join(outDir, os.path.splitext(video)[0]))) + retn = os.system( + "{} -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%04d.jpg".format( + os.path.join(args.ffmpeg_dir, "ffmpeg"), + os.path.join(inDir, video), + args.img_width, + args.img_height, + os.path.join(outDir, os.path.splitext(video)[0]), + ) + ) if retn: print("Error converting file:{}. Exiting.".format(video)) @@ -59,7 +86,6 @@ def create_clips(root, destination): None """ - folderCounter = -1 files = os.listdir(root) @@ -70,36 +96,40 @@ def create_clips(root, destination): for imageCounter, image in enumerate(images): # Bunch images in groups of 12 frames - if (imageCounter % 12 == 0): - if (imageCounter + 11 >= len(images)): + if imageCounter % 12 == 0: + if imageCounter + 11 >= len(images): break folderCounter += 1 os.mkdir("{}/{}".format(destination, folderCounter)) - move("{}/{}/{}".format(root, file, image), "{}/{}/{}".format(destination, folderCounter, image)) + move( + "{}/{}/{}".format(root, file, image), + "{}/{}/{}".format(destination, folderCounter, image), + ) rmtree(os.path.join(root, file)) + def main(): # Create dataset folder if it doesn't exist already. if not os.path.isdir(args.dataset_folder): os.mkdir(args.dataset_folder) - extractPath = os.path.join(args.dataset_folder, "extracted") - trainPath = os.path.join(args.dataset_folder, "train") - testPath = os.path.join(args.dataset_folder, "test") - validationPath = os.path.join(args.dataset_folder, "validation") + extractPath = os.path.join(args.dataset_folder, "extracted") + trainPath = os.path.join(args.dataset_folder, "train") + testPath = os.path.join(args.dataset_folder, "test") + validationPath = os.path.join(args.dataset_folder, "validation") os.mkdir(extractPath) os.mkdir(trainPath) os.mkdir(testPath) os.mkdir(validationPath) - if(args.dataset == "adobe240fps"): + if args.dataset == "adobe240fps": f = open("adobe240fps/test_list.txt", "r") - videos = f.read().split('\n') + videos = f.read().split("\n") extract_frames(videos, args.videos_folder, extractPath) create_clips(extractPath, testPath) f = open("adobe240fps/train_list.txt", "r") - videos = f.read().split('\n') + videos = f.read().split("\n") extract_frames(videos, args.videos_folder, extractPath) create_clips(extractPath, trainPath) @@ -109,17 +139,18 @@ def main(): for index in indices: move("{}/{}".format(testPath, index), "{}/{}".format(validationPath, index)) - else: # custom dataset - + else: # custom dataset # Extract video names videos = os.listdir(args.videos_folder) # Create random train-test split. - testIndices = random.sample(range(len(videos)), int((args.train_test_split[1] * len(videos)) / 100)) + testIndices = random.sample( + range(len(videos)), int((args.train_test_split[1] * len(videos)) / 100) + ) trainIndices = [x for x in range((len(videos))) if x not in testIndices] # Create list of video names - testVideoNames = [videos[index] for index in testIndices] + testVideoNames = [videos[index] for index in testIndices] trainVideoNames = [videos[index] for index in trainIndices] # Create train-test dataset @@ -130,10 +161,13 @@ def main(): # Select clips at random from test set for validation set. testClips = os.listdir(testPath) - indices = random.sample(range(len(testClips)), min(100, int(len(testClips) / 5))) + indices = random.sample( + range(len(testClips)), min(100, int(len(testClips) / 5)) + ) for index in indices: move("{}/{}".format(testPath, index), "{}/{}".format(validationPath, index)) rmtree(extractPath) + main() diff --git a/benchmarks/super-slomo/slomo/dataloader.py b/benchmarks/super-slomo/slomo/dataloader.py index a008c6f1d..2704e5cca 100644 --- a/benchmarks/super-slomo/slomo/dataloader.py +++ b/benchmarks/super-slomo/slomo/dataloader.py @@ -27,7 +27,6 @@ def _make_dataset(dir): 2D list described above. """ - framesPath = [] # Find and loop over all the clips in root `dir`. for index, folder in enumerate(os.listdir(dir)): @@ -42,6 +41,7 @@ def _make_dataset(dir): framesPath[index].append(os.path.join(clipsFolderPath, image)) return framesPath + def _make_video_dataset(dir): """ Creates a 1D list of all the frames. @@ -60,7 +60,6 @@ def _make_video_dataset(dir): 1D list described above. """ - framesPath = [] # Find and loop over all the frames in root `dir`. for image in sorted(os.listdir(dir)): @@ -68,6 +67,7 @@ def _make_video_dataset(dir): framesPath.append(os.path.join(dir, image)) return framesPath + def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0): """ Opens image at `path` using pil and applies data augmentation. @@ -89,19 +89,22 @@ def _pil_loader(path, cropArea=None, resizeDim=None, frameFlip=0): 2D list described above. """ - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) - with open(path, 'rb') as f: + with open(path, "rb") as f: img = Image.open(f) # Resize image if specified. - resized_img = img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img + resized_img = ( + img.resize(resizeDim, Image.ANTIALIAS) if (resizeDim != None) else img + ) # Crop image if crop area specified. cropped_img = img.crop(cropArea) if (cropArea != None) else resized_img # Flip image horizontally if specified. - flipped_img = cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img - return flipped_img.convert('RGB') - - + flipped_img = ( + cropped_img.transpose(Image.FLIP_LEFT_RIGHT) if frameFlip else cropped_img + ) + return flipped_img.convert("RGB") + + class SuperSloMo(data.Dataset): """ A dataloader for loading N samples arranged in this way: @@ -144,8 +147,14 @@ class SuperSloMo(data.Dataset): Returns printable representation of the dataset object. """ - - def __init__(self, root, transform=None, dim=(640, 360), randomCropSize=(352, 352), train=True): + def __init__( + self, + root, + transform=None, + dim=(640, 360), + randomCropSize=(352, 352), + train=True, + ): """ Parameters ---------- @@ -161,27 +170,26 @@ def __init__(self, root, transform=None, dim=(640, 360), randomCropSize=(352, 35 Dimensions of random crop to be applied. Default: (352, 352) train : boolean, optional Specifies if the dataset is for training or testing/validation. - `True` returns samples with data augmentation like random + `True` returns samples with data augmentation like random flipping, random cropping, etc. while `False` returns the samples without randomization. Default: True """ - # Populate the list with image paths for all the # frame in `root`. framesPath = _make_dataset(root) # Raise error if no images found in root. if len(framesPath) == 0: - raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n")) - + raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n")) + self.randomCropSize = randomCropSize - self.cropX0 = dim[0] - randomCropSize[0] - self.cropY0 = dim[1] - randomCropSize[1] - self.root = root - self.transform = transform - self.train = train + self.cropX0 = dim[0] - randomCropSize[0] + self.cropY0 = dim[1] - randomCropSize[1] + self.root = root + self.transform = transform + self.train = train - self.framesPath = framesPath + self.framesPath = framesPath def __getitem__(self, index): """ @@ -199,28 +207,32 @@ def __getitem__(self, index): Returns ------- tuple - (sample, returnIndex) where sample is - [I0, intermediate_frame, I1] and returnIndex is - the position of `random_intermediate_frame`. + (sample, returnIndex) where sample is + [I0, intermediate_frame, I1] and returnIndex is + the position of `random_intermediate_frame`. e.g.- `returnIndex` of frame next to I0 would be 0 and frame before I1 would be 6. """ - sample = [] - - if (self.train): + + if self.train: ### Data Augmentation ### # To select random 9 frames from 12 frames in a clip firstFrame = random.randint(0, 3) # Apply random crop on the 9 input frames cropX = random.randint(0, self.cropX0) cropY = random.randint(0, self.cropY0) - cropArea = (cropX, cropY, cropX + self.randomCropSize[0], cropY + self.randomCropSize[1]) + cropArea = ( + cropX, + cropY, + cropX + self.randomCropSize[0], + cropY + self.randomCropSize[1], + ) # Random reverse frame - #frameRange = range(firstFrame, firstFrame + 9) if (random.randint(0, 1)) else range(firstFrame + 8, firstFrame - 1, -1) + # frameRange = range(firstFrame, firstFrame + 9) if (random.randint(0, 1)) else range(firstFrame + 8, firstFrame - 1, -1) IFrameIndex = random.randint(firstFrame + 1, firstFrame + 7) - if (random.randint(0, 1)): + if random.randint(0, 1): frameRange = [firstFrame, IFrameIndex, firstFrame + 8] returnIndex = IFrameIndex - firstFrame - 1 else: @@ -233,22 +245,25 @@ def __getitem__(self, index): # For validation/test sets. firstFrame = 0 cropArea = (0, 0, self.randomCropSize[0], self.randomCropSize[1]) - IFrameIndex = ((index) % 7 + 1) + IFrameIndex = (index) % 7 + 1 returnIndex = IFrameIndex - 1 frameRange = [0, IFrameIndex, 8] randomFrameFlip = 0 - + # Loop over for all frames corresponding to the `index`. for frameIndex in frameRange: # Open image using pil and augment the image. - image = _pil_loader(self.framesPath[index][frameIndex], cropArea=cropArea, frameFlip=randomFrameFlip) + image = _pil_loader( + self.framesPath[index][frameIndex], + cropArea=cropArea, + frameFlip=randomFrameFlip, + ) # Apply transformation if specified. if self.transform is not None: image = self.transform(image) sample.append(image) - - return sample, returnIndex + return sample, returnIndex def __len__(self): """ @@ -260,7 +275,6 @@ def __len__(self): number of samples. """ - return len(self.framesPath) def __repr__(self): @@ -273,14 +287,16 @@ def __repr__(self): info. """ - - fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' - fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Root Location: {}\n'.format(self.root) - tmp = ' Transforms (if any): ' - fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + fmt_str = "Dataset " + self.__class__.__name__ + "\n" + fmt_str += " Number of datapoints: {}\n".format(self.__len__()) + fmt_str += " Root Location: {}\n".format(self.root) + tmp = " Transforms (if any): " + fmt_str += "{0}{1}\n".format( + tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp)) + ) return fmt_str - + + class UCI101Test(data.Dataset): """ A dataloader for loading N samples arranged in this way: @@ -317,7 +333,6 @@ class UCI101Test(data.Dataset): Returns printable representation of the dataset object. """ - def __init__(self, root, transform=None): """ Parameters @@ -330,17 +345,16 @@ def __init__(self, root, transform=None): E.g, ``transforms.RandomCrop`` for images. """ - # Populate the list with image paths for all the # frame in `root`. framesPath = _make_dataset(root) # Raise error if no images found in root. if len(framesPath) == 0: - raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n")) + raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n")) - self.root = root - self.framesPath = framesPath - self.transform = transform + self.root = root + self.framesPath = framesPath + self.transform = transform def __getitem__(self, index): """ @@ -357,15 +371,14 @@ def __getitem__(self, index): Returns ------- tuple - (sample, returnIndex) where sample is - [I0, intermediate_frame, I1] and returnIndex is + (sample, returnIndex) where sample is + [I0, intermediate_frame, I1] and returnIndex is the position of `intermediate_frame`. The returnIndex is always 3 and is being returned to maintain compatibility with the `SuperSloMo` dataloader where 3 corresponds to the middle frame. """ - sample = [] # Loop over for all frames corresponding to the `index`. for framePath in self.framesPath[index]: @@ -377,7 +390,6 @@ def __getitem__(self, index): sample.append(image) return sample, 3 - def __len__(self): """ Returns the size of dataset. Invoked as len(datasetObj). @@ -388,7 +400,6 @@ def __len__(self): number of samples. """ - return len(self.framesPath) def __repr__(self): @@ -401,14 +412,16 @@ def __repr__(self): info. """ - - fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' - fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Root Location: {}\n'.format(self.root) - tmp = ' Transforms (if any): ' - fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + fmt_str = "Dataset " + self.__class__.__name__ + "\n" + fmt_str += " Number of datapoints: {}\n".format(self.__len__()) + fmt_str += " Root Location: {}\n".format(self.root) + tmp = " Transforms (if any): " + fmt_str += "{0}{1}\n".format( + tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp)) + ) return fmt_str + class Video(data.Dataset): """ A dataloader for loading all video frames in a folder: @@ -440,7 +453,6 @@ class Video(data.Dataset): Returns printable representation of the dataset object. """ - def __init__(self, root, transform=None): """ Parameters @@ -453,23 +465,22 @@ def __init__(self, root, transform=None): E.g, ``transforms.RandomCrop`` for images. """ - # Populate the list with image paths for all the # frame in `root`. framesPath = _make_video_dataset(root) # Get dimensions of frames - frame = _pil_loader(framesPath[0]) + frame = _pil_loader(framesPath[0]) self.origDim = frame.size - self.dim = int(self.origDim[0] / 32) * 32, int(self.origDim[1] / 32) * 32 + self.dim = int(self.origDim[0] / 32) * 32, int(self.origDim[1] / 32) * 32 # Raise error if no images found in root. if len(framesPath) == 0: - raise(RuntimeError("Found 0 files in: " + root + "\n")) + raise (RuntimeError("Found 0 files in: " + root + "\n")) - self.root = root - self.framesPath = framesPath - self.transform = transform + self.root = root + self.framesPath = framesPath + self.transform = transform def __getitem__(self, index): """ @@ -489,7 +500,6 @@ def __getitem__(self, index): `index` and I1 is the next frame. """ - sample = [] # Loop over for all frames corresponding to the `index`. for framePath in [self.framesPath[index], self.framesPath[index + 1]]: @@ -501,7 +511,6 @@ def __getitem__(self, index): sample.append(image) return sample - def __len__(self): """ Returns the size of dataset. Invoked as len(datasetObj). @@ -512,11 +521,10 @@ def __len__(self): number of samples. """ - # Using `-1` so that dataloader accesses only upto # frames [N-1, N] and not [N, N+1] which because frame # N+1 doesn't exist. - return len(self.framesPath) - 1 + return len(self.framesPath) - 1 def __repr__(self): """ @@ -528,10 +536,11 @@ def __repr__(self): info. """ - - fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' - fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) - fmt_str += ' Root Location: {}\n'.format(self.root) - tmp = ' Transforms (if any): ' - fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) - return fmt_str \ No newline at end of file + fmt_str = "Dataset " + self.__class__.__name__ + "\n" + fmt_str += " Number of datapoints: {}\n".format(self.__len__()) + fmt_str += " Root Location: {}\n".format(self.root) + tmp = " Transforms (if any): " + fmt_str += "{0}{1}\n".format( + tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp)) + ) + return fmt_str diff --git a/benchmarks/super-slomo/slomo/eval.py b/benchmarks/super-slomo/slomo/eval.py index 1c3cb9801..fd3273021 100644 --- a/benchmarks/super-slomo/slomo/eval.py +++ b/benchmarks/super-slomo/slomo/eval.py @@ -21,8 +21,12 @@ mean = [0.429, 0.431, 0.397] mea0 = [-m for m in mean] std = [1] * 3 - trans_forward = transforms.Compose([trans_forward, transforms.Normalize(mean=mean, std=std)]) - trans_backward = transforms.Compose([transforms.Normalize(mean=mea0, std=std), trans_backward]) + trans_forward = transforms.Compose( + [trans_forward, transforms.Normalize(mean=mean, std=std)] + ) + trans_backward = transforms.Compose( + [transforms.Normalize(mean=mea0, std=std), trans_backward] + ) flow = model.UNet(6, 4).to(device) interp = model.UNet(20, 5).to(device) @@ -36,9 +40,9 @@ def setup_back_warp(w, h): def load_models(checkpoint): - states = torch.load(checkpoint, map_location='cpu') - interp.load_state_dict(states['state_dictAT']) - flow.load_state_dict(states['state_dictFC']) + states = torch.load(checkpoint, map_location="cpu") + interp.load_state_dict(states["state_dictAT"]) + flow.load_state_dict(states["state_dictFC"]) def interpolate_batch(frames, factor): @@ -78,8 +82,9 @@ def interpolate_batch(frames, factor): co_eff = [1 - t, t] - ft_p = (co_eff[0] * vt0 * gi0ft0f + co_eff[1] * vt1 * gi1ft1f) / \ - (co_eff[0] * vt0 + co_eff[1] * vt1) + ft_p = (co_eff[0] * vt0 * gi0ft0f + co_eff[1] * vt1 * gi1ft1f) / ( + co_eff[0] * vt0 + co_eff[1] * vt1 + ) frame_buffer.append(ft_p) @@ -97,7 +102,7 @@ def load_batch(video_in, batch_size, batch, w, h): frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame = Image.fromarray(frame) frame = frame.resize((w, h), Image.ANTIALIAS) - frame = frame.convert('RGB') + frame = frame.convert("RGB") frame = trans_forward(frame) batch.append(frame) @@ -108,14 +113,18 @@ def denorm_frame(frame, w0, h0): frame = frame.cpu() frame = trans_backward(frame) frame = frame.resize((w0, h0), Image.BILINEAR) - frame = frame.convert('RGB') + frame = frame.convert("RGB") return np.array(frame)[:, :, ::-1].copy() -def convert_video(source, dest, factor, batch_size=10, output_format='mp4v', output_fps=30): +def convert_video( + source, dest, factor, batch_size=10, output_format="mp4v", output_fps=30 +): vin = cv2.VideoCapture(source) count = vin.get(cv2.CAP_PROP_FRAME_COUNT) - w0, h0 = int(vin.get(cv2.CAP_PROP_FRAME_WIDTH)), int(vin.get(cv2.CAP_PROP_FRAME_HEIGHT)) + w0, h0 = int(vin.get(cv2.CAP_PROP_FRAME_WIDTH)), int( + vin.get(cv2.CAP_PROP_FRAME_HEIGHT) + ) codec = cv2.VideoWriter_fourcc(*output_format) vout = cv2.VideoWriter(dest, codec, float(output_fps), (w0, h0)) @@ -150,28 +159,34 @@ def convert_video(source, dest, factor, batch_size=10, output_format='mp4v', out vout.release() -@click.command('Evaluate Model by converting a low-FPS video to high-fps') -@click.argument('input') -@click.option('--checkpoint', help='Path to model checkpoint') -@click.option('--output', help='Path to output file to save') -@click.option('--batch', default=2, help='Number of frames to process in single forward pass') -@click.option('--scale', default=4, help='Scale Factor of FPS') -@click.option('--fps', default=30, help='FPS of output video') +@click.command("Evaluate Model by converting a low-FPS video to high-fps") +@click.argument("input") +@click.option("--checkpoint", help="Path to model checkpoint") +@click.option("--output", help="Path to output file to save") +@click.option( + "--batch", default=2, help="Number of frames to process in single forward pass" +) +@click.option("--scale", default=4, help="Scale Factor of FPS") +@click.option("--fps", default=30, help="FPS of output video") def main(input, checkpoint, output, batch, scale, fps): - avg = lambda x, n, x0: (x * n/(n+1) + x0 / (n+1), n+1) + avg = lambda x, n, x0: (x * n / (n + 1) + x0 / (n + 1), n + 1) load_models(checkpoint) t0 = time() n0 = 0 fpx = 0 - for dl, fd, fc in convert_video(input, output, int(scale), int(batch), output_fps=int(fps)): + for dl, fd, fc in convert_video( + input, output, int(scale), int(batch), output_fps=int(fps) + ): fpx, n0 = avg(fpx, n0, dl / (time() - t0)) - prg = int(100*fd/fc) + prg = int(100 * fd / fc) eta = (fc - fd) / fpx - print('\rDone: {:03d}% FPS: {:05.2f} ETA: {:.2f}s'.format(prg, fpx, eta) + ' '*5, end='') + print( + "\rDone: {:03d}% FPS: {:05.2f} ETA: {:.2f}s".format(prg, fpx, eta) + + " " * 5, + end="", + ) t0 = time() -if __name__ == '__main__': +if __name__ == "__main__": main() - - diff --git a/benchmarks/super-slomo/slomo/model.py b/benchmarks/super-slomo/slomo/model.py index bc3e3509a..ef706a2a2 100644 --- a/benchmarks/super-slomo/slomo/model.py +++ b/benchmarks/super-slomo/slomo/model.py @@ -10,9 +10,9 @@ class down(nn.Module): """ A class for creating neural network blocks containing layers: - + Average Pooling --> Convlution + Leaky ReLU --> Convolution + Leaky ReLU - + This is used in the UNet Class to create a UNet like NN architecture. ... @@ -24,7 +24,6 @@ class down(nn.Module): block. """ - def __init__(self, inChannels, outChannels, filterSize): """ Parameters @@ -40,12 +39,23 @@ def __init__(self, inChannels, outChannels, filterSize): a N x N filter. """ - super(down, self).__init__() # Initialize convolutional layers. - self.conv1 = nn.Conv2d(inChannels, outChannels, filterSize, stride=1, padding=int((filterSize - 1) / 2)) - self.conv2 = nn.Conv2d(outChannels, outChannels, filterSize, stride=1, padding=int((filterSize - 1) / 2)) - + self.conv1 = nn.Conv2d( + inChannels, + outChannels, + filterSize, + stride=1, + padding=int((filterSize - 1) / 2), + ) + self.conv2 = nn.Conv2d( + outChannels, + outChannels, + filterSize, + stride=1, + padding=int((filterSize - 1) / 2), + ) + def forward(self, x): """ Returns output tensor after passing input `x` to the neural network @@ -62,21 +72,21 @@ def forward(self, x): output of the NN block. """ - # Average pooling with kernel size 2 (2 x 2). x = F.avg_pool2d(x, 2) # Convolution + Leaky ReLU - x = F.leaky_relu(self.conv1(x), negative_slope = 0.1) + x = F.leaky_relu(self.conv1(x), negative_slope=0.1) # Convolution + Leaky ReLU - x = F.leaky_relu(self.conv2(x), negative_slope = 0.1) + x = F.leaky_relu(self.conv2(x), negative_slope=0.1) return x - + + class up(nn.Module): """ A class for creating neural network blocks containing layers: - + Bilinear interpolation --> Convlution + Leaky ReLU --> Convolution + Leaky ReLU - + This is used in the UNet Class to create a UNet like NN architecture. ... @@ -88,7 +98,6 @@ class up(nn.Module): block. """ - def __init__(self, inChannels, outChannels): """ Parameters @@ -101,13 +110,12 @@ def __init__(self, inChannels, outChannels): the second convolutional layer. """ - super(up, self).__init__() # Initialize convolutional layers. - self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1) + self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1) # (2 * outChannels) is used for accommodating skip connection. self.conv2 = nn.Conv2d(2 * outChannels, outChannels, 3, stride=1, padding=1) - + def forward(self, x, skpCn): """ Returns output tensor after passing input `x` to the neural network @@ -127,20 +135,19 @@ def forward(self, x, skpCn): """ # Bilinear interpolation with scaling 2. - x = F.interpolate(x, scale_factor=2, mode='bilinear') + x = F.interpolate(x, scale_factor=2, mode="bilinear") # Convolution + Leaky ReLU - x = F.leaky_relu(self.conv1(x), negative_slope = 0.1) + x = F.leaky_relu(self.conv1(x), negative_slope=0.1) # Convolution + Leaky ReLU on (`x`, `skpCn`) - x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope = 0.1) + x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope=0.1) return x - class UNet(nn.Module): """ A class for creating UNet like architecture as specified by the Super SloMo paper. - + ... Methods @@ -150,7 +157,6 @@ class UNet(nn.Module): block. """ - def __init__(self, inChannels, outChannels): """ Parameters @@ -161,7 +167,6 @@ def __init__(self, inChannels, outChannels): number of output channels for the UNet. """ - super(UNet, self).__init__() # Initialize neural network blocks. self.conv1 = nn.Conv2d(inChannels, 32, 7, stride=1, padding=3) @@ -171,13 +176,13 @@ def __init__(self, inChannels, outChannels): self.down3 = down(128, 256, 3) self.down4 = down(256, 512, 3) self.down5 = down(512, 512, 3) - self.up1 = up(512, 512) - self.up2 = up(512, 256) - self.up3 = up(256, 128) - self.up4 = up(128, 64) - self.up5 = up(64, 32) + self.up1 = up(512, 512) + self.up2 = up(512, 256) + self.up3 = up(256, 128) + self.up4 = up(128, 64) + self.up5 = up(64, 32) self.conv3 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1) - + def forward(self, x): """ Returns output tensor after passing input `x` to the neural network. @@ -193,20 +198,19 @@ def forward(self, x): output of the UNet. """ - - x = F.leaky_relu(self.conv1(x), negative_slope = 0.1) - s1 = F.leaky_relu(self.conv2(x), negative_slope = 0.1) + x = F.leaky_relu(self.conv1(x), negative_slope=0.1) + s1 = F.leaky_relu(self.conv2(x), negative_slope=0.1) s2 = self.down1(s1) s3 = self.down2(s2) s4 = self.down3(s3) s5 = self.down4(s4) - x = self.down5(s5) - x = self.up1(x, s5) - x = self.up2(x, s4) - x = self.up3(x, s3) - x = self.up4(x, s2) - x = self.up5(x, s1) - x = F.leaky_relu(self.conv3(x), negative_slope = 0.1) + x = self.down5(s5) + x = self.up1(x, s5) + x = self.up2(x, s4) + x = self.up3(x, s3) + x = self.up4(x, s2) + x = self.up5(x, s1) + x = F.leaky_relu(self.conv3(x), negative_slope=0.1) return x @@ -216,7 +220,7 @@ class backWarp(nn.Module): This is used for backwarping to an image: - Given optical flow from frame I0 to I1 --> F_0_1 and frame I1, + Given optical flow from frame I0 to I1 --> F_0_1 and frame I1, it generates I0 <-- backwarp(F_0_1, I1). ... @@ -228,7 +232,6 @@ class backWarp(nn.Module): block. """ - def __init__(self, W, H, device): """ Parameters @@ -238,10 +241,9 @@ def __init__(self, W, H, device): H : int height of the image. device : device - computation device (cpu/cuda). + computation device (cpu/cuda). """ - super(backWarp, self).__init__() # create a grid gridX, gridY = np.meshgrid(np.arange(W), np.arange(H)) @@ -249,7 +251,7 @@ def __init__(self, W, H, device): self.H = H self.gridX = torch.tensor(gridX, requires_grad=False, device=device) self.gridY = torch.tensor(gridY, requires_grad=False, device=device) - + def forward(self, img, flow): """ Returns output tensor after passing input `img` and `flow` to the backwarping @@ -269,27 +271,27 @@ def forward(self, img, flow): frame I0. """ - # Extract horizontal and vertical flows. u = flow[:, 0, :, :] v = flow[:, 1, :, :] x = self.gridX.unsqueeze(0).expand_as(u).float() + u y = self.gridY.unsqueeze(0).expand_as(v).float() + v # range -1 to 1 - x = 2*(x/self.W - 0.5) - y = 2*(y/self.H - 0.5) + x = 2 * (x / self.W - 0.5) + y = 2 * (y / self.H - 0.5) # stacking X and Y - grid = torch.stack((x,y), dim=3) + grid = torch.stack((x, y), dim=3) # Sample pixels using bilinear interpolation. imgOut = torch.nn.functional.grid_sample(img, grid) return imgOut # Creating an array of `t` values for the 7 intermediate frames between -# reference frames I0 and I1. +# reference frames I0 and I1. t = np.linspace(0.125, 0.875, 7) -def getFlowCoeff (indices, device): + +def getFlowCoeff(indices, device): """ Gets flow coefficients used for calculating intermediate optical flows from optical flows between I0 and I1: F_0_1 and F_1_0. @@ -309,7 +311,7 @@ def getFlowCoeff (indices, device): indices corresponding to the intermediate frame positions of all samples in the batch. device : device - computation device (cpu/cuda). + computation device (cpu/cuda). Returns ------- @@ -317,17 +319,22 @@ def getFlowCoeff (indices, device): coefficients C00, C01, C10, C11. """ - # Convert indices tensor to numpy array ind = indices.detach().numpy() - C11 = C00 = - (1 - (t[ind])) * (t[ind]) + C11 = C00 = -(1 - (t[ind])) * (t[ind]) C01 = (t[ind]) * (t[ind]) C10 = (1 - (t[ind])) * (1 - (t[ind])) - return torch.Tensor(C00)[None, None, None, :].permute(3, 0, 1, 2).to(device), torch.Tensor(C01)[None, None, None, :].permute(3, 0, 1, 2).to(device), torch.Tensor(C10)[None, None, None, :].permute(3, 0, 1, 2).to(device), torch.Tensor(C11)[None, None, None, :].permute(3, 0, 1, 2).to(device) + return ( + torch.Tensor(C00)[None, None, None, :].permute(3, 0, 1, 2).to(device), + torch.Tensor(C01)[None, None, None, :].permute(3, 0, 1, 2).to(device), + torch.Tensor(C10)[None, None, None, :].permute(3, 0, 1, 2).to(device), + torch.Tensor(C11)[None, None, None, :].permute(3, 0, 1, 2).to(device), + ) + -def getWarpCoeff (indices, device): +def getWarpCoeff(indices, device): """ - Gets coefficients used for calculating final intermediate + Gets coefficients used for calculating final intermediate frame `It_gen` from backwarped images using flows F_t_0 and F_t_1. It_gen = (C0 x V_t_0 x g_I_0_F_t_0 + C1 x V_t_1 x g_I_1_F_t_1) / (C0 x V_t_0 + C1 x V_t_1) @@ -345,7 +352,7 @@ def getWarpCoeff (indices, device): indices corresponding to the intermediate frame positions of all samples in the batch. device : device - computation device (cpu/cuda). + computation device (cpu/cuda). Returns ------- @@ -353,9 +360,10 @@ def getWarpCoeff (indices, device): coefficients C0 and C1. """ - # Convert indices tensor to numpy array ind = indices.detach().numpy() C0 = 1 - t[ind] C1 = t[ind] - return torch.Tensor(C0)[None, None, None, :].permute(3, 0, 1, 2).to(device), torch.Tensor(C1)[None, None, None, :].permute(3, 0, 1, 2).to(device) \ No newline at end of file + return torch.Tensor(C0)[None, None, None, :].permute(3, 0, 1, 2).to( + device + ), torch.Tensor(C1)[None, None, None, :].permute(3, 0, 1, 2).to(device) diff --git a/benchmarks/super-slomo/slomo/synth.py b/benchmarks/super-slomo/slomo/synth.py index 1b69407a4..57835e360 100644 --- a/benchmarks/super-slomo/slomo/synth.py +++ b/benchmarks/super-slomo/slomo/synth.py @@ -1,4 +1,3 @@ - class SyntheticData: def __init__(self, generators, n, repeat): self.n = n diff --git a/benchmarks/super-slomo/slomo/train.py b/benchmarks/super-slomo/slomo/train.py index 0c680cae7..7fea1a045 100644 --- a/benchmarks/super-slomo/slomo/train.py +++ b/benchmarks/super-slomo/slomo/train.py @@ -1,5 +1,4 @@ - -#[Super SloMo] +# [Super SloMo] ##High Quality Estimation of Multiple Intermediate Frames for Video Interpolation import argparse @@ -17,18 +16,59 @@ def main(): - # For parsing commandline arguments parser = argparse.ArgumentParser() - parser.add_argument("--dataset_root", type=str, required=False, help='path to dataset folder containing train-test-validation folders') - parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model') - parser.add_argument("--train_continue", type=bool, default=False, help='If resuming from checkpoint, set to True and set `checkpoint` path. Default: False.') - parser.add_argument("--epochs", type=int, default=200, help='number of epochs to train. Default: 200.') - parser.add_argument("--train_batch_size", type=int, default=6, help='batch size for training. Default: 6.') - parser.add_argument("--validation_batch_size", type=int, default=10, help='batch size for validation. Default: 10.') - parser.add_argument("--init_learning_rate", type=float, default=0.0001, help='set initial learning rate. Default: 0.0001.') - parser.add_argument("--milestones", type=list, default=[100, 150], help='Set to epoch values where you want to decrease learning rate by a factor of 0.1. Default: [100, 150]') - parser.add_argument("--progress_iter", type=int, default=100, help='frequency of reporting progress and validation. N: after every N iterations. Default: 100.') + parser.add_argument( + "--dataset_root", + type=str, + required=False, + help="path to dataset folder containing train-test-validation folders", + ) + parser.add_argument( + "--checkpoint", type=str, help="path of checkpoint for pretrained model" + ) + parser.add_argument( + "--train_continue", + type=bool, + default=False, + help="If resuming from checkpoint, set to True and set `checkpoint` path. Default: False.", + ) + parser.add_argument( + "--epochs", + type=int, + default=200, + help="number of epochs to train. Default: 200.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=6, + help="batch size for training. Default: 6.", + ) + parser.add_argument( + "--validation_batch_size", + type=int, + default=10, + help="batch size for validation. Default: 10.", + ) + parser.add_argument( + "--init_learning_rate", + type=float, + default=0.0001, + help="set initial learning rate. Default: 0.0001.", + ) + parser.add_argument( + "--milestones", + type=list, + default=[100, 150], + help="Set to epoch values where you want to decrease learning rate by a factor of 0.1. Default: [100, 150]", + ) + parser.add_argument( + "--progress_iter", + type=int, + default=100, + help="frequency of reporting progress and validation. N: after every N iterations. Default: 100.", + ) parser.add_argument( "--no-tf32", dest="allow_tf32", @@ -42,29 +82,23 @@ def main(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - ###Initialize flow computation and arbitrary-time flow interpolation CNNs. - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") flowComp = model.UNet(6, 4) flowComp.to(device) ArbTimeFlowIntrp = model.UNet(20, 5) ArbTimeFlowIntrp.to(device) - ###Initialze backward warpers for train and validation datasets - - trainFlowBackWarp = model.backWarp(352, 352, device) - trainFlowBackWarp = trainFlowBackWarp.to(device) + trainFlowBackWarp = model.backWarp(352, 352, device) + trainFlowBackWarp = trainFlowBackWarp.to(device) validationFlowBackWarp = model.backWarp(640, 352, device) validationFlowBackWarp = validationFlowBackWarp.to(device) - ###Load Datasets - # # Channel wise mean calculated on adobe240-fps training dataset # mean = [0.429, 0.431, 0.397] # std = [1, 1, 1] @@ -86,27 +120,20 @@ def ogen(): return torch.randint(0, 7, ()) trainset = SyntheticData( - n=args.train_batch_size, - repeat=10000, - generators=[igen, ogen] + n=args.train_batch_size, repeat=10000, generators=[igen, ogen] ) trainloader = torch.utils.data.DataLoader( - trainset, - batch_size=args.train_batch_size, - num_workers=2 + trainset, batch_size=args.train_batch_size, num_workers=2 ) - ###Utils - + def get_lr(optimizer): for param_group in optimizer.param_groups: - return param_group['lr'] - + return param_group["lr"] ###Loss and Optimizer - L1_lossFn = nn.L1Loss() MSE_LossFn = nn.MSELoss() @@ -114,105 +141,126 @@ def get_lr(optimizer): optimizer = optim.Adam(params, lr=args.init_learning_rate) # scheduler to decrease learning rate by a factor of 10 at milestones. - scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1) + scheduler = optim.lr_scheduler.MultiStepLR( + optimizer, milestones=args.milestones, gamma=0.1 + ) ###Initializing VGG16 model for perceptual loss - vgg16 = torchvision.models.vgg16(pretrained=True) vgg16_conv_4_3 = nn.Sequential(*list(vgg16.children())[0][:22]) vgg16_conv_4_3.to(device) for param in vgg16_conv_4_3.parameters(): - param.requires_grad = False - + param.requires_grad = False ### Initialization - if args.train_continue: dict1 = torch.load(args.checkpoint) - ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT']) - flowComp.load_state_dict(dict1['state_dictFC']) + ArbTimeFlowIntrp.load_state_dict(dict1["state_dictAT"]) + flowComp.load_state_dict(dict1["state_dictFC"]) else: - dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1} - + dict1 = {"loss": [], "valLoss": [], "valPSNR": [], "epoch": -1} ### Training - cLoss = dict1['loss'] - valLoss = dict1['valLoss'] - valPSNR = dict1['valPSNR'] + cLoss = dict1["loss"] + valLoss = dict1["valLoss"] + valPSNR = dict1["valPSNR"] ### Main training loop - for epoch in range(dict1['epoch'] + 1, args.epochs): + for epoch in range(dict1["epoch"] + 1, args.epochs): print("Epoch: ", epoch) - + # Append and reset cLoss.append([]) valLoss.append([]) valPSNR.append([]) iLoss = 0 - - # Increment scheduler count + + # Increment scheduler count scheduler.step() - + # for trainIndex, (trainData, trainFrameIndex) in enumerate(trainloader, 0): - for trainIndex, (trainData, trainFrameIndex) in enumerate(voir.iterate("train", trainloader, report_batch=True, batch_size=lambda batch: batch[1].shape[0]), 0): + for trainIndex, (trainData, trainFrameIndex) in enumerate( + voir.iterate( + "train", + trainloader, + report_batch=True, + batch_size=lambda batch: batch[1].shape[0], + ), + 0, + ): ## Getting the input and the target from the training set frame0, frameT, frame1 = trainData - + I0 = frame0.to(device) I1 = frame1.to(device) IFrame = frameT.to(device) - + optimizer.zero_grad() - + # Calculate flow between reference frames I0 and I1 flowOut = flowComp(torch.cat((I0, I1), dim=1)) - + # Extracting flows between I0 and I1 - F_0_1 and F_1_0 - F_0_1 = flowOut[:,:2,:,:] - F_1_0 = flowOut[:,2:,:,:] - + F_0_1 = flowOut[:, :2, :, :] + F_1_0 = flowOut[:, 2:, :, :] + fCoeff = model.getFlowCoeff(trainFrameIndex, device) - + # Calculate intermediate flows F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0 F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0 - + # Get intermediate frames from the intermediate flows g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0) g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1) - + # Calculate optical flow residuals and visibility maps - intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1)) - + intrpOut = ArbTimeFlowIntrp( + torch.cat( + (I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1 + ) + ) + # Extract optical flow residuals and visibility maps F_t_0_f = intrpOut[:, :2, :, :] + F_t_0 F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1 - V_t_0 = F.sigmoid(intrpOut[:, 4:5, :, :]) - V_t_1 = 1 - V_t_0 - + V_t_0 = F.sigmoid(intrpOut[:, 4:5, :, :]) + V_t_1 = 1 - V_t_0 + # Get intermediate frames from the intermediate flows g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f) g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f) - + wCoeff = model.getWarpCoeff(trainFrameIndex, device) - - # Calculate final intermediate frame - Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1) - + + # Calculate final intermediate frame + Ft_p = ( + wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f + ) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1) + # Loss recnLoss = L1_lossFn(Ft_p, IFrame) - + prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame)) - - warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(trainFlowBackWarp(I1, F_0_1), I0) - - loss_smooth_1_0 = torch.mean(torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :])) - loss_smooth_0_1 = torch.mean(torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :])) + + warpLoss = ( + L1_lossFn(g_I0_F_t_0, IFrame) + + L1_lossFn(g_I1_F_t_1, IFrame) + + L1_lossFn(trainFlowBackWarp(I0, F_1_0), I1) + + L1_lossFn(trainFlowBackWarp(I1, F_0_1), I0) + ) + + loss_smooth_1_0 = torch.mean( + torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:]) + ) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :])) + loss_smooth_0_1 = torch.mean( + torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:]) + ) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :])) loss_smooth = loss_smooth_1_0 + loss_smooth_0_1 - + # Total Loss - Coefficients 204 and 102 are used instead of 0.8 and 0.4 # since the loss in paper is calculated for input pixels in range 0-255 # and the input to our network is in range 0-1 diff --git a/benchmarks/timm/benchfile.py b/benchmarks/timm/benchfile.py index 50f7e69dc..f8d0652e0 100644 --- a/benchmarks/timm/benchfile.py +++ b/benchmarks/timm/benchfile.py @@ -14,17 +14,21 @@ class TimmBenchmarkPack(Package): def make_env(self): return { **super().make_env(), - "OMP_NUM_THREADS": str(self.config.get("cpus_per_gpu", 8)) + "OMP_NUM_THREADS": str(self.config.get("cpus_per_gpu", 8)), } @property def argv(self): return [ *super().argv, - "--data-dir", self.dirs.data, - "--dataset", "FakeImageNet", - "--output", self.dirs.extra / self.logdir.name / self.tag, - "--checkpoint-hist", 1, + "--data-dir", + self.dirs.data, + "--dataset", + "FakeImageNet", + "--output", + self.dirs.extra / self.logdir.name / self.tag, + "--checkpoint-hist", + 1, ] async def install(self): @@ -32,7 +36,9 @@ async def install(self): timm = self.dirs.code / "pytorch-image-models" if not timm.exists(): - timm.clone_subtree("https://github.com/huggingface/pytorch-image-models", BRANCH) + timm.clone_subtree( + "https://github.com/huggingface/pytorch-image-models", BRANCH + ) def build_run_plan(self): # self.config is not the right config for this diff --git a/benchmarks/timm/voirfile.py b/benchmarks/timm/voirfile.py index 19ac71fa5..5f17d8408 100644 --- a/benchmarks/timm/voirfile.py +++ b/benchmarks/timm/voirfile.py @@ -33,12 +33,14 @@ def setup(args): ov.require(dash) ov.require( - log("value", "progress", "rate", "units", "loss", "gpudata", context="task"), + log( + "value", "progress", "rate", "units", "loss", "gpudata", context="task" + ), rate( interval=options.interval, skip=options.skip, sync=torch.cuda.synchronize if torch.cuda.is_available() else None, - batch_size_calc=lambda b: len(b) * args.world_size + batch_size_calc=lambda b: len(b) * args.world_size, ), early_stop(n=options.stop, key="rate", task="train", signal="stop"), gpu_monitor(poll_interval=options.gpu_poll), @@ -46,8 +48,7 @@ def setup(args): # Loss ( - loss_probe - .throttle(1)["loss"] + loss_probe.throttle(1)["loss"] .map(lambda loss: {"task": "train", "loss": float(loss)}) .give() ) diff --git a/benchmarks/torchvision/main.py b/benchmarks/torchvision/main.py index b0a7cf847..843f2246a 100644 --- a/benchmarks/torchvision/main.py +++ b/benchmarks/torchvision/main.py @@ -167,7 +167,7 @@ def main(): if not args.no_cuda: assert torch.cuda.is_available(), "Why is CUDA not available" - + use_cuda = not args.no_cuda torch.manual_seed(args.seed) diff --git a/milabench/cli.py b/milabench/cli.py index c036af5b6..5b6a7599e 100644 --- a/milabench/cli.py +++ b/milabench/cli.py @@ -690,7 +690,7 @@ def report(): title=None, sources=runs, errdata=reports and _error_report(reports), - stream=sys.stdout + stream=sys.stdout, ) def pip(): diff --git a/milabench/report.py b/milabench/report.py index fd05301d5..51a59d162 100644 --- a/milabench/report.py +++ b/milabench/report.py @@ -296,7 +296,7 @@ def _score(column): H.div["collapsible"](lines), ) ) - + out.finalize() diff --git a/milabench/sizer.py b/milabench/sizer.py index 6db8b299f..240296bf6 100644 --- a/milabench/sizer.py +++ b/milabench/sizer.py @@ -16,7 +16,10 @@ def is_autoscale_enabled(): - return os.getenv("MILABENCH_SIZER_AUTO", False) or os.getenv("MILABENCH_SIZER_MULTIPLE") is not None + return ( + os.getenv("MILABENCH_SIZER_AUTO", False) + or os.getenv("MILABENCH_SIZER_MULTIPLE") is not None + ) def getenv(name, type): @@ -24,9 +27,10 @@ def getenv(name, type): if value is not None: return type(value) - + return value + @dataclass class SizerOptions: size: int = getenv("MILABENCH_SIZER_BATCH_SIZE", int) @@ -62,7 +66,7 @@ def to_octet(value: str) -> float: if "io" in value: return float(value.replace("io", "")) - + if "o" in value: return float(value.replace("o", "")) @@ -97,17 +101,17 @@ def benchscaling(self, benchmark): def get_capacity(self, capacity): if self.options.capacity is not None: capacity = self.options.capacity - + if isinstance(capacity, str): capacity = to_octet(capacity) - + return capacity def auto_size(self, benchmark, capacity): capacity = self.get_capacity(capacity) - config = self.benchscaling(benchmark) - + config = self.benchscaling(benchmark) + data = list(sorted(config["model"].items(), key=lambda x: x[0])) mem = [to_octet(v[1]) for v in data] size = [float(v[0]) for v in data] @@ -131,7 +135,7 @@ def auto_size(self, benchmark, capacity): def size(self, benchmark, capacity): config = self.benchscaling(benchmark) - + if self.options.size is not None: return self.options.size @@ -145,11 +149,11 @@ def size(self, benchmark, capacity): def argv(self, benchmark, capacity, argv): """Find the batch size and override it with a new value""" - + config = self.benchscaling(benchmark) if config is None: return argv - + newsize = self.size(benchmark, capacity) if newsize is None: @@ -160,7 +164,7 @@ def argv(self, benchmark, capacity, argv): argname = config.get("arg") if argname is None: return argv - + for i, arg in enumerate(argv): if arg.endswith(argname): break @@ -184,32 +188,31 @@ def scale_argv(pack, argv): return sizer.argv(pack, capacity, argv) - class MemoryUsageExtractor(ValidationLayer): """Extract max memory usage per benchmark to populate the memory model""" - + def __init__(self): self.filepath = getenv("MILABENCH_SIZER_SAVE", str) - + self.memory = deepcopy(sizer_global.get().scaling_config) self.scaling = None self.benchname = None self.batch_size = 0 - self.max_usage = float('-inf') + self.max_usage = float("-inf") self.early_stopped = False - + def on_start(self, entry): if self.filepath is None: return - + argv = entry.data["command"] self.benchname = entry.pack.config["name"] self.batch_size = None - self.max_usage = float('-inf') - + self.max_usage = float("-inf") + config = self.memory.get(self.benchname, dict()) scalingarg = config.get("arg", None) - + if scalingarg is None: self.benchname = None return @@ -219,14 +222,14 @@ def on_start(self, entry): if arg.endswith(scalingarg): found = i break - + if found: - self.batch_size = int(argv[found + 1]) - + self.batch_size = int(argv[found + 1]) + def on_data(self, entry): if self.filepath is None: return - + if entry.data is None: return @@ -236,21 +239,23 @@ def on_data(self, entry): for device, data in gpudata.items(): usage, total = data.get("memory", [0, 1]) current_usage.append(usage) - + self.max_usage = max(*current_usage, self.max_usage) - + def on_stop(self, entry): self.early_stopped = True def on_end(self, entry): if self.filepath is None: return - - if (self.benchname is None or - self.batch_size is None or - self.max_usage == float('-inf')): + + if ( + self.benchname is None + or self.batch_size is None + or self.max_usage == float("-inf") + ): return - + # Only update is successful rc = entry.data["return_code"] if rc == 0 or self.early_stopped: @@ -258,16 +263,12 @@ def on_end(self, entry): model = config.setdefault("model", dict()) model[self.batch_size] = f"{self.max_usage} MiB" config["model"] = dict(sorted(model.items(), key=lambda x: x[0])) - + self.benchname = None self.batch_size = None - self.max_usage = float('-inf') - + self.max_usage = float("-inf") + def report(self, *args): if self.filepath is not None: - with open(self.filepath, 'w') as file: + with open(self.filepath, "w") as file: yaml.dump(self.memory, file) - - - - diff --git a/tests/test_scaler.py b/tests/test_scaler.py index deaae8139..283048c8b 100644 --- a/tests/test_scaler.py +++ b/tests/test_scaler.py @@ -4,9 +4,7 @@ def test_scaler_use_override(multipack, config): - sizer = Sizer( - SizerOptions(size=64, autoscale=False), config("scaling") - ) + sizer = Sizer(SizerOptions(size=64, autoscale=False), config("scaling")) for k, pack in multipack.packs.items(): assert sizer.size(pack, "48Go") == 64 @@ -38,9 +36,7 @@ def test_scaler_use_optimized(multipack, config): @pytest.mark.parametrize("capacity,expected", _values) def test_scaler_autoscaler_lerp(multipack, config, capacity, expected): - sizer = Sizer( - SizerOptions(size=None, autoscale=True), config("scaling") - ) + sizer = Sizer(SizerOptions(size=None, autoscale=True), config("scaling")) for k, pack in multipack.packs.items(): assert sizer.size(pack, capacity) == expected diff --git a/tests/test_summary.py b/tests/test_summary.py index 05865f93c..702cf82d8 100644 --- a/tests/test_summary.py +++ b/tests/test_summary.py @@ -9,11 +9,11 @@ def test_report(runs_folder, capsys, file_regression, config): except SystemExit as exc: assert not exc.code assert exc.code is None - + all = capsys.readouterr() stdout = all.out assert stdout != "" - + stdout = stdout.replace(str(folder), "XXX") file_regression.check(stdout) diff --git a/tests/test_validation.py b/tests/test_validation.py index 9ea69dc1a..eb765508c 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -12,7 +12,7 @@ def replay_validation_scenario(folder, *validation, filename=None): path = folder / filename file = str(path) + ".txt" - + if os.path.isdir(path): files = [path / f for f in os.scandir(path)] gen = interleave(*files) @@ -30,11 +30,10 @@ def replay_validation_scenario(folder, *validation, filename=None): def replay_scenario(folder, name, filename=None): """Replay events from a data file or folder""" return replay_validation_scenario( - folder, - *validation_layers(name), - filename=filename or name + folder, *validation_layers(name), filename=filename or name ) - + + def test_error_layer(replayfolder): log = replay_scenario(replayfolder, "error") assert log.result() != 0 @@ -103,12 +102,17 @@ def test_planning_layer_per_gpu_bad(replayfolder, monkeypatch): log = replay_scenario(replayfolder, "planning", "planning_per_gpu_bad") assert log.result() != 0 - - + + def test_memory_tracking(replayfolder, config): import contextvars from milabench.sizer import ( - MemoryUsageExtractor, Sizer, SizerOptions, sizer_global, system_global) + MemoryUsageExtractor, + Sizer, + SizerOptions, + sizer_global, + system_global, + ) ctx = contextvars.copy_context() @@ -122,18 +126,9 @@ def update_ctx(): config("scaling"), ) sizer_global.set(sizer) - system = system_global.set({ - "gpu": { - "capacity": "41920 MiB" - } - }) - + system = system_global.set({"gpu": {"capacity": "41920 MiB"}}) + ctx.run(update_ctx) layer = MemoryUsageExtractor() - ctx.run(lambda: replay_validation_scenario( - replayfolder, - layer, - filename="usage" - )) - \ No newline at end of file + ctx.run(lambda: replay_validation_scenario(replayfolder, layer, filename="usage"))