From 33910f89442e157e035378af482dc1ba5c558d96 Mon Sep 17 00:00:00 2001 From: veelion Date: Wed, 20 Jul 2022 11:10:35 +0800 Subject: [PATCH 01/62] add concurrent performance testing for websocket_server_main (non-stream) --- tools/websocket/performance-ws.py | 166 ++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100755 tools/websocket/performance-ws.py diff --git a/tools/websocket/performance-ws.py b/tools/websocket/performance-ws.py new file mode 100755 index 000000000..af77dea06 --- /dev/null +++ b/tools/websocket/performance-ws.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# coding:utf-8 + +# Copyright (c) 2022 SDCI Co. Ltd (author: veelion) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import time +import asyncio +import argparse +import websockets +import soundfile as sf +import statistics + + +WS_START = json.dumps({ + 'signal': 'start', + 'nbest': 1, + 'continuous_decoding': False, +}) +WS_END = json.dumps({ + 'signal': 'end' +}) + + +async def ws_rec(data, ws_uri): + begin = time.time() + conn = await websockets.connect(ws_uri, ping_timeout=200) + # step 1: send start + await conn.send(WS_START) + ret = await conn.recv() + # step 2: send audio data + await conn.send(data) + # step 3: send end + await conn.send(WS_END) + # step 4: receive result + texts = [] + while 1: + ret = await conn.recv() + ret = json.loads(ret) + if ret['type'] == 'final_result': + nbest = json.loads(ret['nbest']) + text = nbest[0]['sentence'] + texts.append(text) + elif ret['type'] == 'speech_end': + break + # step 5: close + try: + await conn.close() + except Exception as e: + # this except has no effect, just log as debug + # it seems the server does not send close info, maybe + print(e) + time_cost = time.time() - begin + return { + 'text': ''.join(texts), + 'time': time_cost, + } + + +def get_args(): + parser = argparse.ArgumentParser(description='') + parser.add_argument( + '-u', '--ws_uri', required=True, + help="websocket_server_main's uri, e.g. 
ws://127.0.0.1:10086") + parser.add_argument( + '-w', '--wav_scp', required=True, + help='path to wav_scp_file') + parser.add_argument( + '-t', '--trans', required=True, + help='path to trans_text_file of wavs') + parser.add_argument( + '-s', '--save_to', required=True, + help='path to save transcription') + parser.add_argument( + '-n', '--num_concurrence', type=int, required=True, + help='num of concurrence for query') + args = parser.parse_args() + return args + + +def print_result(info): + length = max([len(k) for k in info]) + for k, v in info.items(): + print(f'\t{k: >{length}} : {v}') + + +async def main(args): + wav_scp = [] + total_duration = 0 + with open(args.wav_scp) as f: + for line in f: + zz = line.strip().split() + assert len(zz) == 2 + data, sr = sf.read(zz[1], dtype='int16') + assert sr == 16000 + duration = (len(data)) / 16000 + total_duration += duration + wav_scp.append((zz[0], data.tobytes())) + print(f'{len(wav_scp) = }, {total_duration = }') + + tasks = [] + failed = 0 + texts = [] + request_times = [] + begin = time.time() + for i, (_uttid, data) in enumerate(wav_scp): + task = asyncio.create_task(ws_rec(data, args.ws_uri)) + tasks.append((_uttid, task)) + if len(tasks) < args.num_concurrence: + continue + print((f'{i=}, start {args.num_concurrence} ' + f'queries @ {time.strftime("%m-%d %H:%M:%S")}')) + for uttid, task in tasks: + result = await task + texts.append(f'{uttid}\t{result["text"]}\n') + request_times.append(result['time']) + tasks = [] + print(f'\tdone @ {time.strftime("%m-%d %H:%M:%S")}') + if tasks: + for uttid, task in tasks: + result = await task + texts.append(f'{uttid}\t{result["text"]}\n') + request_times.append(result['time']) + request_time = time.time() - begin + rtf = request_time / total_duration + print('For all concurrence:') + print_result({ + 'failed': failed, + 'total_duration': total_duration, + 'request_time': request_time, + 'RTF': rtf, + }) + print('For one request:') + print_result({ + 'mean': statistics.mean(request_times), + 'median': statistics.median(request_times), + 'max_time': max(request_times), + 'min_time': min(request_times), + }) + with open(args.save_to, 'w', encoding='utf8') as fsave: + fsave.write(''.join(texts)) + # caculate CER + cmd = (f'python ../compute-wer.py --char=1 --v=1 ' + f'{args.trans} {args.save_to} > ' + f'{args.save_to}-test-{args.num_concurrence}.cer.txt') + print(cmd) + os.system(cmd) + print('done') + + +if __name__ == '__main__': + args = get_args() + asyncio.run(main(args)) From 3f199f554705a4a47af004f26bf7eb13c0695f8b Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 5 Aug 2022 17:30:20 +0800 Subject: [PATCH 02/62] add batch processing to decoder --- runtime/core/decoder/batch_asr_decoder.cc | 215 +++++++++++++++ runtime/core/decoder/batch_asr_decoder.h | 92 +++++++ runtime/core/decoder/batch_asr_model.cc | 22 ++ runtime/core/decoder/batch_asr_model.h | 59 +++++ runtime/core/decoder/batch_torch_asr_model.cc | 244 ++++++++++++++++++ runtime/core/decoder/batch_torch_asr_model.h | 64 +++++ 6 files changed, 696 insertions(+) create mode 100644 runtime/core/decoder/batch_asr_decoder.cc create mode 100644 runtime/core/decoder/batch_asr_decoder.h create mode 100644 runtime/core/decoder/batch_asr_model.cc create mode 100644 runtime/core/decoder/batch_asr_model.h create mode 100644 runtime/core/decoder/batch_torch_asr_model.cc create mode 100644 runtime/core/decoder/batch_torch_asr_model.h diff --git a/runtime/core/decoder/batch_asr_decoder.cc b/runtime/core/decoder/batch_asr_decoder.cc new file mode 
100644 index 000000000..508f1edcc --- /dev/null +++ b/runtime/core/decoder/batch_asr_decoder.cc @@ -0,0 +1,215 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "decoder/batch_asr_decoder.h" + +#include + +#include +#include +#include + +#include "utils/timer.h" + +namespace wenet { + +BatchAsrDecoder::BatchAsrDecoder(std::shared_ptr config, + std::shared_ptr resource, + const DecodeOptions& opts) + : feature_config_(config), + fbank_(config->num_bins, config->sample_rate, config->frame_length, config->frame_shift), + model_(resource->batch_model->Copy()), + post_processor_(resource->post_processor), + symbol_table_(resource->symbol_table), + fst_(resource->fst), + unit_table_(resource->unit_table), + opts_(opts) { + if (opts_.reverse_weight > 0) { + // Check if model has a right to left decoder + CHECK(model_->is_bidirectional_decoder()); + } + if (nullptr == fst_) { + searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts, + resource->context_graph)); + } else { + searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts, + resource->context_graph)); + } +} + +void BatchAsrDecoder::Reset() { + result_.clear(); + batch_result_.clear(); + global_frame_offset_ = 0; + searcher_->Reset(); +} + +void BatchAsrDecoder::Decode(const std::vector>& wavs) { + // 1. calc fbank feature of the batch of wavs + Timer timer; + batch_feature_t batch_feats; + std::vector batch_feats_lens; + VLOG(1) << "wavs : " << wavs.size(); + for (const auto& wav : wavs) { + VLOG(1) << "wav : " << wav.size(); + feature_t feats; + int num_frames = fbank_.Compute(wav, &feats); + VLOG(1) << "feat leng is " << num_frames; + batch_feats.push_back(std::move(feats)); + batch_feats_lens.push_back(num_frames); + } + int feat_time = timer.Elapsed(); + VLOG(1) << "feat_time : " << feat_time; + + // 1.1 feature padding + timer.Reset(); + int max_len = *std::max_element(batch_feats_lens.begin(), batch_feats_lens.end()); + VLOG(1) << "max length feature : " << max_len; + for (auto& feat : batch_feats) { + if (feat.size() == max_len) continue; + int pad_len = max_len - feat.size(); + for (size_t i = 0; i< pad_len; i++) { + std::vector one(feature_config_->num_bins, 0); + feat.push_back(std::move(one)); + } + } + VLOG(1) << "padding time : " << timer.Elapsed(); + timer.Reset(); + + // 2. encoder forward + batch_ctc_log_prob_t batch_ctc_log_probs; + model_->ForwardEncoder(batch_feats, batch_feats_lens, batch_ctc_log_probs); + VLOG(1) << "encoder forward time : " << timer.Elapsed(); + + // 3. ctc search one by one of the batch + // it seems, decoder forward only support 1 encoder_out with n-best ctc search result + int batch_size = wavs.size(); + batch_result_.clear(); + for (size_t i = 0; i < batch_size; i++) { + timer.Reset(); + const auto& ctc_log_probs = batch_ctc_log_probs[i]; + // 3.1. 
ctc search + searcher_->Search(ctc_log_probs); + int search_time = timer.Elapsed(); + VLOG(1) << "search takes " << search_time << " ms"; + + // 3.2. rescoring + timer.Reset(); + AttentionRescoring(i); + VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << "ms."; + + // 3.3. save to batch_result_ + batch_result_.push_back(std::move(result_)); + + // 3.4 reset + searcher_->Reset(); + } +} + +void BatchAsrDecoder::UpdateResult(bool finish) { + const auto& hypotheses = searcher_->Outputs(); + const auto& inputs = searcher_->Inputs(); + const auto& likelihood = searcher_->Likelihood(); + const auto& times = searcher_->Times(); + result_.clear(); + + CHECK_EQ(hypotheses.size(), likelihood.size()); + for (size_t i = 0; i < hypotheses.size(); i++) { + const std::vector& hypothesis = hypotheses[i]; + + DecodeResult path; + path.score = likelihood[i]; + int offset = global_frame_offset_ * feature_frame_shift_in_ms(); + for (size_t j = 0; j < hypothesis.size(); j++) { + std::string word = symbol_table_->Find(hypothesis[j]); + // A detailed explanation of this if-else branch can be found in + // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 + if (searcher_->Type() == kWfstBeamSearch) { + path.sentence += (' ' + word); + } else { + path.sentence += (word); + } + } + + // TimeStamp is only supported in final result + // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to + // various FST operations when building the decoding graph. So here we use + // time stamp of the input(e2e model unit), which is more accurate, and it + // requires the symbol table of the e2e model used in training. + if (unit_table_ != nullptr && finish) { + const std::vector& input = inputs[i]; + const std::vector& time_stamp = times[i]; + CHECK_EQ(input.size(), time_stamp.size()); + for (size_t j = 0; j < input.size(); j++) { + std::string word = unit_table_->Find(input[j]); + int start = time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_ > 0 + ? time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_ + : 0; + if (j > 0) { + start = (time_stamp[j] - time_stamp[j - 1]) * frame_shift_in_ms() < + time_stamp_gap_ + ? (time_stamp[j - 1] + time_stamp[j]) / 2 * + frame_shift_in_ms() + : start; + } + int end = time_stamp[j] * frame_shift_in_ms(); + if (j < input.size() - 1) { + end = (time_stamp[j + 1] - time_stamp[j]) * frame_shift_in_ms() < + time_stamp_gap_ + ? 
(time_stamp[j + 1] + time_stamp[j]) / 2 * + frame_shift_in_ms() + : end; + } + WordPiece word_piece(word, offset + start, offset + end); + path.word_pieces.emplace_back(word_piece); + } + } + + if (post_processor_ != nullptr) { + path.sentence = post_processor_->Process(path.sentence, finish); + } + result_.emplace_back(path); + } +} + +void BatchAsrDecoder::AttentionRescoring(int batch_index) { + searcher_->FinalizeSearch(); + UpdateResult(true); + // No need to do rescoring + if (0.0 == opts_.rescoring_weight) { + return; + } + // Inputs() returns N-best input ids, which is the basic unit for rescoring + // In CtcPrefixBeamSearch, inputs are the same to outputs + const auto& hypotheses = searcher_->Inputs(); + int num_hyps = hypotheses.size(); + if (num_hyps <= 0) { + return; + } + + std::vector rescoring_score; + model_->AttentionRescoring(hypotheses, batch_index, opts_.reverse_weight, + &rescoring_score); + + // Combine ctc score and rescoring score + for (size_t i = 0; i < num_hyps; ++i) { + result_[i].score = opts_.rescoring_weight * rescoring_score[i] + + opts_.ctc_weight * result_[i].score; + } + std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc); +} + +} // namespace wenet diff --git a/runtime/core/decoder/batch_asr_decoder.h b/runtime/core/decoder/batch_asr_decoder.h new file mode 100644 index 000000000..07f851e0e --- /dev/null +++ b/runtime/core/decoder/batch_asr_decoder.h @@ -0,0 +1,92 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
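+// Usage sketch (an illustration, not part of this patch): one decoder is
+// meant to serve one request at a time; call Reset() before each batch of
+// 16 kHz float PCM utterances, then read the n-best lists back:
+//
+//   wenet::BatchAsrDecoder decoder(feature_config, resource, opts);
+//   decoder.Reset();
+//   decoder.Decode(wavs);  // wavs: std::vector<std::vector<float>>
+//   const auto& results = decoder.batch_result();  // one n-best list per wav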
+ + +#ifndef DECODER_BATCH_ASR_DECODER_H_ +#define DECODER_BATCH_ASR_DECODER_H_ + +#include +#include +#include +#include + +#include "fst/fstlib.h" +#include "fst/symbol-table.h" + +#include "decoder/batch_asr_model.h" +#include "decoder/asr_decoder.h" +#include "decoder/context_graph.h" +#include "decoder/ctc_prefix_beam_search.h" +#include "decoder/ctc_wfst_beam_search.h" +#include "decoder/search_interface.h" +#include "frontend/feature_pipeline.h" +#include "post_processor/post_processor.h" +#include "utils/utils.h" +#include "frontend/fbank.h" + +namespace wenet { + +// Torch ASR batch decoder +class BatchAsrDecoder { + public: + BatchAsrDecoder(std::shared_ptr feature_config, + std::shared_ptr resource, + const DecodeOptions& opts); + void Decode(const std::vector>& wavs); + void Reset(); + + int frame_shift_in_ms() const { + return model_->subsampling_rate() * + feature_config_->frame_shift * 1000 / + feature_config_->sample_rate; + } + int feature_frame_shift_in_ms() const { + return feature_config_->frame_shift * 1000 / + feature_config_->sample_rate; + } + const std::vector& result() const { return result_; } + const std::vector>& batch_result() const { return batch_result_; } + + private: + Fbank fbank_; + void AttentionRescoring(int batch_index); + + void UpdateResult(bool finish = false); + + std::shared_ptr feature_config_; + std::shared_ptr model_; + std::shared_ptr post_processor_; + + std::shared_ptr> fst_ = nullptr; + // output symbol table + std::shared_ptr symbol_table_; + // e2e unit symbol table + std::shared_ptr unit_table_ = nullptr; + const DecodeOptions& opts_; + int global_frame_offset_ = 0; + const int time_stamp_gap_ = 100; // timestamp gap between words in a sentence + + std::unique_ptr searcher_; + + std::vector result_; + std::vector> batch_result_; + + public: + WENET_DISALLOW_COPY_AND_ASSIGN(BatchAsrDecoder); +}; + +} // namespace wenet + +#endif // DECODER_BATCH_ASR_DECODER_H_ diff --git a/runtime/core/decoder/batch_asr_model.cc b/runtime/core/decoder/batch_asr_model.cc new file mode 100644 index 000000000..977f17eff --- /dev/null +++ b/runtime/core/decoder/batch_asr_model.cc @@ -0,0 +1,22 @@ +// Copyright 2022 Horizon Robotics. All Rights Reserved. +// Author: binbin.zhang@horizon.ai (Binbin Zhang) + +#include "decoder/batch_asr_model.h" + +#include +#include + +namespace wenet { + +void BatchAsrModel::ForwardEncoder( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + batch_ctc_log_prob_t& batch_ctc_prob) { + batch_ctc_prob.clear(); + this->ForwardEncoderFunc( + batch_feats, + batch_feats_lens, + batch_ctc_prob); + } + +} // namespace wenet diff --git a/runtime/core/decoder/batch_asr_model.h b/runtime/core/decoder/batch_asr_model.h new file mode 100644 index 000000000..b33be2285 --- /dev/null +++ b/runtime/core/decoder/batch_asr_model.h @@ -0,0 +1,59 @@ +// Copyright 2022 Horizon Robotics. All Rights Reserved. 
+// Author: binbin.zhang@horizon.ai (Binbin Zhang) + +#ifndef DECODER_BATCH_ASR_MODEL_H_ +#define DECODER_BATCH_ASR_MODEL_H_ + +#include +#include +#include +#include + +#include "utils/timer.h" +#include "utils/utils.h" + +namespace wenet { + +using feature_t = std::vector>; +using batch_feature_t = std::vector; +using batch_ctc_log_prob_t = std::vector; + +class BatchAsrModel { + + public: + virtual int right_context() const { return right_context_; } + virtual int subsampling_rate() const { return subsampling_rate_; } + virtual int sos() const { return sos_; } + virtual int eos() const { return eos_; } + virtual bool is_bidirectional_decoder() const { + return is_bidirectional_decoder_; + } + + virtual void ForwardEncoder( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + batch_ctc_log_prob_t& batch_ctc_prob); + + virtual void AttentionRescoring(const std::vector>& hyps, + int batch_index, + float reverse_weight, + std::vector* rescoring_score) = 0; + + virtual std::shared_ptr Copy() const = 0; + + protected: + virtual void ForwardEncoderFunc( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + batch_ctc_log_prob_t& batch_ctc_prob) = 0; + + int right_context_ = 1; + int subsampling_rate_ = 1; + int sos_ = 0; + int eos_ = 0; + bool is_bidirectional_decoder_ = false; +}; + +} // namespace wenet + +#endif // DECODER_BATCH_ASR_MODEL_H_ diff --git a/runtime/core/decoder/batch_torch_asr_model.cc b/runtime/core/decoder/batch_torch_asr_model.cc new file mode 100644 index 000000000..98b93b1b7 --- /dev/null +++ b/runtime/core/decoder/batch_torch_asr_model.cc @@ -0,0 +1,244 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#include "decoder/batch_torch_asr_model.h" + +#include +#include +#include +#include + +#include "torch/script.h" +#include "torch/torch.h" + +namespace wenet { + +void BatchTorchAsrModel::InitEngineThreads(int num_threads) { + // For multi-thread performance + at::set_num_threads(num_threads); + // Note: Do not call the set_num_interop_threads function more than once. + // Please see https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ + // ParallelThreadPoolNative.cpp#L54-L56 + at::set_num_interop_threads(1); + VLOG(1) << "Num intra-op threads: " << at::get_num_threads(); + VLOG(1) << "Num inter-op threads: " << at::get_num_interop_threads(); +} + +void BatchTorchAsrModel::Read(const std::string& model_path) { + torch::DeviceType device = at::kCPU; +#ifdef USE_GPU + if (!torch::cuda::is_available()) { + VLOG(1) << "CUDA is not available! Please check your GPU settings"; + throw std::runtime_error("CUDA is not available!"); + } else { + VLOG(1) << "CUDA available! 
Running on GPU"; + device = at::kCUDA; + } +#endif + torch::jit::script::Module model = torch::jit::load(model_path, device); + model_ = std::make_shared(std::move(model)); + torch::NoGradGuard no_grad; + model_->eval(); + torch::jit::IValue o1 = model_->run_method("subsampling_rate"); + CHECK_EQ(o1.isInt(), true); + subsampling_rate_ = o1.toInt(); + torch::jit::IValue o2 = model_->run_method("right_context"); + CHECK_EQ(o2.isInt(), true); + torch::jit::IValue o3 = model_->run_method("sos_symbol"); + CHECK_EQ(o3.isInt(), true); + sos_ = o3.toInt(); + torch::jit::IValue o4 = model_->run_method("eos_symbol"); + CHECK_EQ(o4.isInt(), true); + eos_ = o4.toInt(); + torch::jit::IValue o5 = model_->run_method("is_bidirectional_decoder"); + CHECK_EQ(o5.isBool(), true); + is_bidirectional_decoder_ = o5.toBool(); + + VLOG(1) << "Torch Model Info:"; + VLOG(1) << "\tsubsampling_rate " << subsampling_rate_; + VLOG(1) << "\tsos " << sos_; + VLOG(1) << "\teos " << eos_; + VLOG(1) << "\tis bidirectional decoder " << is_bidirectional_decoder_; +} + +BatchTorchAsrModel::BatchTorchAsrModel(const BatchTorchAsrModel& other) { + // 1. Init the model info + subsampling_rate_ = other.subsampling_rate_; + sos_ = other.sos_; + eos_ = other.eos_; + is_bidirectional_decoder_ = other.is_bidirectional_decoder_; + // 2. Model copy, just copy the model ptr since: + // PyTorch allows using multiple CPU threads during TorchScript model + // inference, please see https://pytorch.org/docs/stable/notes/cpu_ + // threading_torchscript_inference.html + model_ = other.model_; + +} + +std::shared_ptr BatchTorchAsrModel::Copy() const { + auto asr_model = std::make_shared(*this); + return asr_model; +} + +void BatchTorchAsrModel::ForwardEncoderFunc( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + batch_ctc_log_prob_t& out_prob) { + // 1. Prepare libtorch required data + int batch_size = batch_feats.size(); + int num_frames = batch_feats[0].size(); + const int feature_dim = batch_feats[0][0].size(); + torch::Tensor feats = + torch::zeros({batch_size, num_frames, feature_dim}, torch::kFloat); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < num_frames; ++j) { + torch::Tensor row = + torch::from_blob(const_cast(batch_feats[i][j].data()), + {feature_dim}, torch::kFloat).clone(); + feats[i][j] = std::move(row); + } + } + torch::Tensor feats_lens = + torch::from_blob(const_cast(batch_feats_lens.data()), + {batch_size}, torch::kInt).clone(); + + // 2. 
Encoder batch forward +#ifdef USE_GPU + feats = feats.to(at::kCUDA); + feats_lens = feats_lens.to(at::kCUDA); +#endif + torch::NoGradGuard no_grad; + std::vector inputs = {feats, feats_lens}; + + // Refer interfaces in wenet/transformer/asr_model.py + auto outputs = + model_->get_method("forward_encoder_batch")(inputs).toTuple()->elements(); + CHECK_EQ(outputs.size(), 2); + encoder_out_ = outputs[0].toTensor(); // (B, Tmax, dim) + + // The first dimension of returned value is for batchsize + torch::Tensor ctc_log_probs = + model_->run_method("ctc_activation", encoder_out_).toTensor(); +#ifdef USE_GPU + ctc_log_probs = ctc_log_probs.to(at::kCPU); +#endif + + // Copy to output + int num_outputs = ctc_log_probs.size(1); + int output_dim = ctc_log_probs.size(2); + out_prob.resize(batch_size); + for (size_t i = 0; i < batch_size; i++) { + out_prob[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; j++) { + out_prob[i][j].resize(output_dim); + memcpy(out_prob[i][j].data(), ctc_log_probs[i][j].data_ptr(), + sizeof(float) * output_dim); + } + } +} + +float BatchTorchAsrModel::ComputeAttentionScore(const torch::Tensor& prob, + const std::vector& hyp, + int eos) { + float score = 0.0f; + auto accessor = prob.accessor(); + for (size_t j = 0; j < hyp.size(); ++j) { + score += accessor[j][hyp[j]]; + } + score += accessor[hyp.size()][eos]; + return score; +} + +void BatchTorchAsrModel::AttentionRescoring( + const std::vector>& hyps, + int batch_index, + float reverse_weight, + std::vector* rescoring_score) { + CHECK(rescoring_score != nullptr); + int num_hyps = hyps.size(); + rescoring_score->resize(num_hyps, 0.0f); + + if (num_hyps == 0) { + return; + } + + torch::NoGradGuard no_grad; + // Step 1: Prepare input for libtorch + torch::Tensor hyps_length = torch::zeros({num_hyps}, torch::kLong); + int max_hyps_len = 0; + for (size_t i = 0; i < num_hyps; ++i) { + int length = hyps[i].size() + 1; + max_hyps_len = std::max(length, max_hyps_len); + hyps_length[i] = static_cast(length); + } + torch::Tensor hyps_tensor = + torch::zeros({num_hyps, max_hyps_len}, torch::kLong); + for (size_t i = 0; i < num_hyps; ++i) { + const std::vector& hyp = hyps[i]; + hyps_tensor[i][0] = sos_; + for (size_t j = 0; j < hyp.size(); ++j) { + hyps_tensor[i][j + 1] = hyp[j]; + } + } + + // Step 2: Forward attention decoder by hyps and corresponding encoder_outs_ + using namespace torch::indexing; + torch::Tensor encoder_out = encoder_out_.index({Slice(batch_index, batch_index + 1)}); +#ifdef USE_GPU + hyps_tensor = hyps_tensor.to(at::kCUDA); + hyps_length = hyps_length.to(at::kCUDA); + encoder_out = encoder_out.to(at::kCUDA); +#endif + auto outputs = model_ + ->run_method("forward_attention_decoder", hyps_tensor, + hyps_length, encoder_out, reverse_weight) + .toTuple() + ->elements(); +#ifdef USE_GPU + auto probs = outputs[0].toTensor().to(at::kCPU); + auto r_probs = outputs[1].toTensor().to(at::kCPU); +#else + auto probs = outputs[0].toTensor(); + auto r_probs = outputs[1].toTensor(); +#endif + CHECK_EQ(probs.size(0), num_hyps); + CHECK_EQ(probs.size(1), max_hyps_len); + + // Step 3: Compute rescoring score + for (size_t i = 0; i < num_hyps; ++i) { + const std::vector& hyp = hyps[i]; + float score = 0.0f; + // left-to-right decoder score + score = ComputeAttentionScore(probs[i], hyp, eos_); + // Optional: Used for right to left score + float r_score = 0.0f; + if (is_bidirectional_decoder_ && reverse_weight > 0) { + // right-to-left score + CHECK_EQ(r_probs.size(0), num_hyps); + CHECK_EQ(r_probs.size(1), max_hyps_len); 
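+      // The right-to-left decoder scores a hypothesis in reversed token
+      // order, so {t1, t2, t3} is scored as {t3, t2, t1} before the two
+      // directions are combined with reverse_weight below.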
+ std::vector r_hyp(hyp.size()); + std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin()); + // right to left decoder score + r_score = ComputeAttentionScore(r_probs[i], r_hyp, eos_); + } + + // combined left-to-right and right-to-left score + (*rescoring_score)[i] = + score * (1 - reverse_weight) + r_score * reverse_weight; + } +} + +} // namespace wenet diff --git a/runtime/core/decoder/batch_torch_asr_model.h b/runtime/core/decoder/batch_torch_asr_model.h new file mode 100644 index 000000000..182fb56a0 --- /dev/null +++ b/runtime/core/decoder/batch_torch_asr_model.h @@ -0,0 +1,64 @@ +// Copyright (c) 2022 SDCI Co. Ltd (author: veelion) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef DECODER_BATCH_TORCH_ASR_MODEL_H_ +#define DECODER_BATCH_TORCH_ASR_MODEL_H_ + +#include +#include +#include + +#include "torch/script.h" +#include "torch/torch.h" + +#include "decoder/batch_asr_model.h" +#include "utils/utils.h" + +namespace wenet { + +class BatchTorchAsrModel : public BatchAsrModel { +public: + // Note: Do not call the InitEngineThreads function more than once. + static void InitEngineThreads(int num_threads = 1); + + public: + using TorchModule = torch::jit::script::Module; + BatchTorchAsrModel() = default; + BatchTorchAsrModel(const BatchTorchAsrModel& other); + void Read(const std::string& model_path); + std::shared_ptr torch_model() const { return model_; } + void AttentionRescoring(const std::vector>& hyps, + int batch_index, + float reverse_weight, + std::vector* rescoring_score) override; + std::shared_ptr Copy() const override; + + protected: + void ForwardEncoderFunc( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + batch_ctc_log_prob_t& batch_ctc_log_prob) override; + + float ComputeAttentionScore(const torch::Tensor& batch_prob, + const std::vector& hyp, int eos); + + private: + std::shared_ptr model_ = nullptr; + torch::Tensor encoder_out_; +}; + +} // namespace wenet + +#endif // DECODER_BATCH_TORCH_ASR_MODEL_H_ From 455c870bcd8e13c8626fb5d891c20c5a16ffc725 Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 5 Aug 2022 17:30:59 +0800 Subject: [PATCH 03/62] add api_batch_main --- runtime/core/bin/CMakeLists.txt | 3 ++ runtime/core/bin/api_batch_main.cc | 48 ++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 runtime/core/bin/api_batch_main.cc diff --git a/runtime/core/bin/CMakeLists.txt b/runtime/core/bin/CMakeLists.txt index 3eb2b3c47..9d96708e2 100644 --- a/runtime/core/bin/CMakeLists.txt +++ b/runtime/core/bin/CMakeLists.txt @@ -7,6 +7,9 @@ target_link_libraries(label_checker_main PUBLIC decoder) if(TORCH) add_executable(api_main api_main.cc) target_link_libraries(api_main PUBLIC wenet_api) + + add_executable(api_batch_main api_batch_main.cc) + target_link_libraries(api_batch_main PUBLIC wenet_api) endif() if(WEBSOCKET) diff --git a/runtime/core/bin/api_batch_main.cc b/runtime/core/bin/api_batch_main.cc new file mode 100644 index 000000000..9be76f58a --- /dev/null +++ 
b/runtime/core/bin/api_batch_main.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "api/batch_recognizer.h" +#include "api/wenet_api.h" +#include "frontend/wav.h" +#include "utils/flags.h" +#include "utils/timer.h" + +DEFINE_string(model_dir, "", "model dir path"); +DEFINE_string(wav_path, "", "single wave path"); +DEFINE_int32(batch_size, 1, "batch size of input"); +DEFINE_int32(num_threads, 1, "number threads of intraop"); +DEFINE_bool(enable_timestamp, false, "enable timestamps"); + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + wenet_set_log_level(2); + + BatchRecognizer br(FLAGS_model_dir, FLAGS_num_threads); + wenet::WavReader wav_reader(FLAGS_wav_path); + std::vector data; + data.insert(data.end(), wav_reader.data(), wav_reader.data() + wav_reader.num_samples()); + std::vector> wavs; + for (size_t i = 0; i < FLAGS_batch_size - 1; i++) { + wavs.push_back(data); + } + wavs.push_back(std::move(data)); + wenet::Timer timer; + std::string result = br.DecodeData(wavs); + int forward_time = timer.Elapsed(); + VLOG(1) << "Decode() takes " << forward_time << " ms"; + LOG(INFO) << result; + return 0; +} From a67bc9308e42f1ae2e7fb4237cdf8ed2d4e6a4ff Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 5 Aug 2022 17:33:43 +0800 Subject: [PATCH 04/62] add BatchRecognizer to api --- runtime/core/api/batch_recognizer.h | 170 ++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 runtime/core/api/batch_recognizer.h diff --git a/runtime/core/api/batch_recognizer.h b/runtime/core/api/batch_recognizer.h new file mode 100644 index 000000000..bf5ed6c3d --- /dev/null +++ b/runtime/core/api/batch_recognizer.h @@ -0,0 +1,170 @@ +// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
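+// BatchRecognizer wraps BatchAsrDecoder behind a byte-oriented API. A usage
+// sketch (illustrative, with a hypothetical model path): the decoder is
+// created lazily on the first call to Decode()/DecodeData(), so context
+// words, language and other options must be set before that first call:
+//
+//   BatchRecognizer br("/path/to/model_dir", /*num_threads=*/1);
+//   br.set_language("chs");
+//   std::string json = br.DecodeData(wavs);  // wavs: 16 kHz float PCM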
+ + +#include +#include +#include + +#include "decoder/asr_decoder.h" +#include "decoder/batch_asr_decoder.h" +#include "decoder/batch_torch_asr_model.h" +#include "post_processor/post_processor.h" +#include "utils/file.h" +#include "utils/json.h" +#include "utils/string.h" + +class BatchRecognizer { + public: + explicit BatchRecognizer(const std::string& model_dir, int num_threads=1) { + // FeaturePipeline init + feature_config_ = std::make_shared(80, 16000); + // Resource init + resource_ = std::make_shared(); + wenet::BatchTorchAsrModel::InitEngineThreads(num_threads); + std::string model_path = wenet::JoinPath(model_dir, "final.zip"); + CHECK(wenet::FileExists(model_path)); + + auto model = std::make_shared(); + model->Read(model_path); + resource_->batch_model = model; + + // units.txt: E2E model unit + std::string unit_path = wenet::JoinPath(model_dir, "units.txt"); + CHECK(wenet::FileExists(unit_path)); + resource_->unit_table = std::shared_ptr( + fst::SymbolTable::ReadText(unit_path)); + + std::string fst_path = wenet::JoinPath(model_dir, "TLG.fst"); + if (wenet::FileExists(fst_path)) { // With LM + resource_->fst = std::shared_ptr>( + fst::Fst::Read(fst_path)); + + std::string symbol_path = wenet::JoinPath(model_dir, "words.txt"); + CHECK(wenet::FileExists(symbol_path)); + resource_->symbol_table = std::shared_ptr( + fst::SymbolTable::ReadText(symbol_path)); + } else { // Without LM, symbol_table is the same as unit_table + resource_->symbol_table = resource_->unit_table; + } + + // Context config init + context_config_ = std::make_shared(); + decode_options_ = std::make_shared(); + post_process_opts_ = std::make_shared(); + } + + void InitDecoder() { + CHECK(decoder_ == nullptr); + // Optional init context graph + if (context_.size() > 0) { + context_config_->context_score = context_score_; + auto context_graph = + std::make_shared(*context_config_); + context_graph->BuildContextGraph(context_, resource_->symbol_table); + resource_->context_graph = context_graph; + } + // PostProcessor + if (language_ == "chs") { // TODO(Binbin Zhang): CJK(chs, jp, kr) + post_process_opts_->language_type = wenet::kMandarinEnglish; + } else { + post_process_opts_->language_type = wenet::kIndoEuropean; + } + resource_->post_processor = + std::make_shared(*post_process_opts_); + // Init decoder + decoder_ = std::make_shared(feature_config_, resource_, + *decode_options_); + } + + std::string Decode(const std::vector& wavs) { + // Init decoder when it is called first time + if (decoder_ == nullptr) { + InitDecoder(); + } + std::vector> wavs_float; + for (auto& wav : wavs) { + const int16_t* pcm = reinterpret_cast(wav.data()); + int pcm_len = wav.size() / sizeof(int16_t); + std::vector wav_float(pcm_len); + for (size_t i = 0; i < pcm_len; i++) { + wav_float[i] = static_cast(*(pcm + i)); + } + wavs_float.push_back(std::move(wav_float)); + } + decoder_->Reset(); + decoder_->Decode(wavs_float); + return UpdateResult(); + } + + std::string DecodeData(const std::vector>& wavs) { + // Init decoder when it is called first time + if (decoder_ == nullptr) { + InitDecoder(); + } + decoder_->Reset(); + decoder_->Decode(wavs); + return UpdateResult(); + } + + std::string UpdateResult() { + const auto& batch_result = decoder_->batch_result(); + json::JSON obj; + obj["batch_size"] = batch_result.size(); + obj["batch_result"] = json::Array(); + for (const auto& result : batch_result) { + json::JSON batch_one; + batch_one["nbest"] = json::Array(); + for (int i = 0; i < nbest_ && i < result.size(); i++) { + json::JSON 
one; + one["sentence"] = result[i].sentence; + if (enable_timestamp_) { + one["word_pieces"] = json::Array(); + for (const auto& word_piece : result[i].word_pieces) { + json::JSON piece; + piece["word"] = word_piece.word; + piece["start"] = word_piece.start; + piece["end"] = word_piece.end; + one["word_pieces"].append(piece); + } + } + one["sentence"] = result[i].sentence; + batch_one["nbest"].append(one); + } + obj["batch_result"].append(batch_one); + } + return obj.dump(); + } + + void set_nbest(int n) { nbest_ = n; } + void set_enable_timestamp(bool flag) { enable_timestamp_ = flag; } + void AddContext(const char* word) { context_.emplace_back(word); } + void set_context_score(float score) { context_score_ = score; } + void set_language(const char* lang) { language_ = lang; } + + private: + std::shared_ptr feature_config_ = nullptr; + std::shared_ptr resource_ = nullptr; + std::shared_ptr decode_options_ = nullptr; + std::shared_ptr decoder_ = nullptr; + std::shared_ptr context_config_ = nullptr; + std::shared_ptr post_process_opts_ = nullptr; + + int nbest_ = 1; + bool enable_timestamp_ = false; + std::vector context_; + float context_score_; + std::string language_ = "chs"; +}; + From 44501aedb6af9d9819c9bdcb906dc0dd46ddb122 Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 5 Aug 2022 17:34:54 +0800 Subject: [PATCH 05/62] add batch_model pointer to DecodeResource --- runtime/core/decoder/asr_decoder.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/runtime/core/decoder/asr_decoder.h b/runtime/core/decoder/asr_decoder.h index df71f5b7b..31c3a99a7 100644 --- a/runtime/core/decoder/asr_decoder.h +++ b/runtime/core/decoder/asr_decoder.h @@ -26,6 +26,7 @@ #include "fst/symbol-table.h" #include "decoder/asr_model.h" +#include "decoder/batch_asr_model.h" #include "decoder/context_graph.h" #include "decoder/ctc_endpoint.h" #include "decoder/ctc_prefix_beam_search.h" @@ -90,6 +91,7 @@ enum DecodeState { // decoding threads struct DecodeResource { std::shared_ptr model = nullptr; + std::shared_ptr batch_model = nullptr; std::shared_ptr symbol_table = nullptr; std::shared_ptr> fst = nullptr; std::shared_ptr unit_table = nullptr; From 7aabd7bf48d080efe8ad822a5b1c26e78c1c66cd Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 5 Aug 2022 17:35:46 +0800 Subject: [PATCH 06/62] add batch processing source to decoder_srcs --- runtime/core/decoder/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/runtime/core/decoder/CMakeLists.txt b/runtime/core/decoder/CMakeLists.txt index cfa439f42..f89be91c2 100644 --- a/runtime/core/decoder/CMakeLists.txt +++ b/runtime/core/decoder/CMakeLists.txt @@ -5,6 +5,9 @@ set(decoder_srcs ctc_prefix_beam_search.cc ctc_wfst_beam_search.cc ctc_endpoint.cc + batch_asr_decoder.cc + batch_asr_model.cc + batch_torch_asr_model.cc ) if(NOT TORCH AND NOT ONNX) From 7a2870193d46b02645f2bc81e1bac58284dc6fd3 Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 5 Aug 2022 17:36:27 +0800 Subject: [PATCH 07/62] jit export forward_encoder_batch() --- wenet/transformer/asr_model.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 5dd151e49..8bcbf965b 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -726,3 +726,23 @@ def forward_attention_decoder( # r_dccoder_out will be 0.0, if reverse_weight is 0.0 r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) return decoder_out, r_decoder_out + + @torch.jit.export + def 
forward_encoder_batch( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Export interface for c++ call, encode a batch of speech + + Args: + speech: padded input tensor (B, T, D) + speech_lengths: input length (B) + Returns: + encoder output tensor xs, and subsampled masks + encoder_out: padded output tensor (B, T' ~= T/subsample_rate, D) + encoder_mask: torch.Tensor batch padding mask after subsample + (B, 1, T' ~= T/subsample_rate) + """ + encoder_out, encoder_mask = self.encoder(speech, speech_lengths) + return encoder_out, encoder_mask From 19114eefff447a0cb7a5612ba49af93df1209f02 Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 5 Aug 2022 17:39:07 +0800 Subject: [PATCH 08/62] add batch processing to Python binding --- runtime/binding/python/cpp/binding.cc | 10 +++ runtime/binding/python/py/__init__.py | 1 + runtime/binding/python/py/batch_decoder.py | 78 ++++++++++++++++++++++ 3 files changed, 89 insertions(+) create mode 100644 runtime/binding/python/py/batch_decoder.py diff --git a/runtime/binding/python/cpp/binding.cc b/runtime/binding/python/cpp/binding.cc index cff4f545e..52eea9528 100644 --- a/runtime/binding/python/cpp/binding.cc +++ b/runtime/binding/python/cpp/binding.cc @@ -13,8 +13,10 @@ // limitations under the License. #include +#include #include "api/wenet_api.h" +#include "api/batch_recognizer.h" namespace py = pybind11; @@ -37,4 +39,12 @@ PYBIND11_MODULE(_wenet, m) { m.def("wenet_set_language", &wenet_set_language, "set language"); m.def("wenet_set_continuous_decoding", &wenet_set_continuous_decoding, "enable continuous decoding or not"); + py::class_(m, "BatchRecognizer") + .def(py::init()) + .def("set_enable_timestamp", &BatchRecognizer::set_enable_timestamp) + .def("AddContext", &BatchRecognizer::AddContext) + .def("set_context_score", &BatchRecognizer::set_context_score) + .def("set_language", &BatchRecognizer::set_language) + .def("DecodeData", &BatchRecognizer::DecodeData) + .def("Decode", &BatchRecognizer::Decode); } diff --git a/runtime/binding/python/py/__init__.py b/runtime/binding/python/py/__init__.py index 58cd2aef8..2886f52aa 100644 --- a/runtime/binding/python/py/__init__.py +++ b/runtime/binding/python/py/__init__.py @@ -1,2 +1,3 @@ from .decoder import Decoder # noqa +from .batch_decoder import BatchDecoder from _wenet import wenet_set_log_level as set_log_level # noqa diff --git a/runtime/binding/python/py/batch_decoder.py b/runtime/binding/python/py/batch_decoder.py new file mode 100644 index 000000000..c4d3f0761 --- /dev/null +++ b/runtime/binding/python/py/batch_decoder.py @@ -0,0 +1,78 @@ +# Copyright (c) 2022 Binbin Zhang(binbzha@qq.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
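+# Usage sketch (illustrative; 'a.wav' and 'b.wav' are hypothetical 16 kHz,
+# 16-bit mono files, and the built binding is assumed importable as `wenet`):
+#
+#     import wave
+#     import wenet
+#     decoder = wenet.BatchDecoder('/path/to/model_dir')
+#     pcms = []
+#     for path in ('a.wav', 'b.wav'):
+#         with wave.open(path, 'rb') as w:
+#             pcms.append(w.readframes(w.getnframes()))
+#     print(decoder.decode(pcms))  # JSON string, one n-best list per wav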
+ +from typing import List, Optional + +import _wenet + +from .hub import Hub + + +class BatchDecoder: + + def __init__(self, + model_dir: Optional[str] = None, + lang: str = 'chs', + nbest: int = 1, + enable_timestamp: bool = False, + context: Optional[List[str]] = None, + context_score: float = 3.0): + """ Init WeNet decoder + Args: + lang: language type of the model + nbest: nbest number for the final result + enable_timestamp: whether to enable word level timestamp + for the final result + context: context words + context_score: bonus score when the context is matched + """ + if model_dir is None: + model_dir = Hub.get_model_by_lang(lang) + + self.d = _wenet.BatchRecognizer(model_dir) + + self.set_language(lang) + self.enable_timestamp(enable_timestamp) + if context is not None: + self.add_context(context) + self.set_context_score(context_score) + + def __del__(self): + del self.d + + def enable_timestamp(self, flag: bool): + tag = 1 if flag else 0 + self.d.set_enable_timestamp(tag) + + def add_context(self, contexts: List[str]): + for c in contexts: + assert isinstance(c, str) + self.d.AddContext(c) + + def set_context_score(self, score: float): + self.d.set_context_score(score) + + def set_language(self, lang: str): + assert lang in ['chs', 'en'] + self.d.set_language(lang) + + def decode(self, pcms: List[bytes]) -> str: + """ Decode the input data + + Args: + pcms: a list of wav pcm + """ + assert isinstance(pcms[0], bytes) + result = self.d.Decode(pcms) + return result From 0fd6c8976fb2281da4488a95c282f15774fb8360 Mon Sep 17 00:00:00 2001 From: veelion Date: Wed, 17 Aug 2022 16:38:32 +0800 Subject: [PATCH 09/62] before change attention-scoring --- runtime/core/decoder/batch_torch_asr_model.cc | 142 +++++++++--------- 1 file changed, 71 insertions(+), 71 deletions(-) diff --git a/runtime/core/decoder/batch_torch_asr_model.cc b/runtime/core/decoder/batch_torch_asr_model.cc index 98b93b1b7..6c09ffe56 100644 --- a/runtime/core/decoder/batch_torch_asr_model.cc +++ b/runtime/core/decoder/batch_torch_asr_model.cc @@ -16,6 +16,9 @@ #include "decoder/batch_torch_asr_model.h" +#ifdef USE_GPU +#include +#endif #include #include #include @@ -27,6 +30,7 @@ namespace wenet { void BatchTorchAsrModel::InitEngineThreads(int num_threads) { + VLOG(1) << "Num intra-op default threads: " << at::get_num_threads(); // For multi-thread performance at::set_num_threads(num_threads); // Note: Do not call the set_num_interop_threads function more than once. @@ -44,7 +48,7 @@ void BatchTorchAsrModel::Read(const std::string& model_path) { VLOG(1) << "CUDA is not available! Please check your GPU settings"; throw std::runtime_error("CUDA is not available!"); } else { - VLOG(1) << "CUDA available! Running on GPU"; + VLOG(1) << "CUDA is available! Running on GPU"; device = at::kCUDA; } #endif @@ -99,12 +103,12 @@ void BatchTorchAsrModel::ForwardEncoderFunc( batch_ctc_log_prob_t& out_prob) { // 1. 
Prepare libtorch required data int batch_size = batch_feats.size(); - int num_frames = batch_feats[0].size(); + num_frames_ = batch_feats[0].size(); const int feature_dim = batch_feats[0][0].size(); torch::Tensor feats = - torch::zeros({batch_size, num_frames, feature_dim}, torch::kFloat); + torch::zeros({batch_size, num_frames_, feature_dim}, torch::kFloat); for (size_t i = 0; i < batch_size; ++i) { - for (size_t j = 0; j < num_frames; ++j) { + for (size_t j = 0; j < num_frames_; ++j) { torch::Tensor row = torch::from_blob(const_cast(batch_feats[i][j].data()), {feature_dim}, torch::kFloat).clone(); @@ -125,15 +129,18 @@ void BatchTorchAsrModel::ForwardEncoderFunc( // Refer interfaces in wenet/transformer/asr_model.py auto outputs = - model_->get_method("forward_encoder_batch")(inputs).toTuple()->elements(); - CHECK_EQ(outputs.size(), 2); + model_->get_method("batch_forward_encoder")(inputs).toTuple()->elements(); + CHECK_EQ(outputs.size(), 3); encoder_out_ = outputs[0].toTensor(); // (B, Tmax, dim) + encoder_lens_ = outputs[1].toTensor(); // (B,) // The first dimension of returned value is for batchsize - torch::Tensor ctc_log_probs = - model_->run_method("ctc_activation", encoder_out_).toTensor(); + torch::Tensor ctc_log_probs = outputs[2].toTensor(); #ifdef USE_GPU + encoder_out_ = encoder_out_.to(at::kCPU); + encoder_lens_ = encoder_lens_.to(at::kCPU); ctc_log_probs = ctc_log_probs.to(at::kCPU); + c10::cuda::CUDACachingAllocator::emptyCache(); #endif // Copy to output @@ -163,81 +170,74 @@ float BatchTorchAsrModel::ComputeAttentionScore(const torch::Tensor& prob, } void BatchTorchAsrModel::AttentionRescoring( - const std::vector>& hyps, - int batch_index, + const std::vector>>& batch_hyps, float reverse_weight, - std::vector* rescoring_score) { - CHECK(rescoring_score != nullptr); - int num_hyps = hyps.size(); - rescoring_score->resize(num_hyps, 0.0f); - - if (num_hyps == 0) { - return; - } + std::vector>* attention_scores) { + CHECK(attention_scores != nullptr); - torch::NoGradGuard no_grad; // Step 1: Prepare input for libtorch - torch::Tensor hyps_length = torch::zeros({num_hyps}, torch::kLong); + int batch_size = batch_hyps.size(); + int beam_size = batch_hyps[0].size(); // should be 10 + torch::Tensor hyps_lens_sos = torch::zeros({batch_size, beam_size}, torch::kLong); int max_hyps_len = 0; - for (size_t i = 0; i < num_hyps; ++i) { - int length = hyps[i].size() + 1; - max_hyps_len = std::max(length, max_hyps_len); - hyps_length[i] = static_cast(length); + for (size_t i = 0; i < batch_size; i++) { + for (size_t j = 0; j < beam_size; j++) { + int length = batch_hyps[i][j].size() + 1; + max_hyps_len = std::max(length, max_hyps_len); + hyps_lens_sos[i][j] = static_cast(length); + } + } + + // 1.2 add sos and eos to hyps, and padded by ignore_id (-1) + torch::Tensor hyps_pad_sos_eos = + torch::ones({batch_size, beam_size, max_hyps_len + 1}, torch::kLong) * eos_; + for (size_t i = 0; i < batch_size; i++) { + for (size_t j = 0; j < beam_size; j++) { + const std::vector& hyp = batch_hyps[i][j]; + hyps_pad_sos_eos[i][j][0] = sos_; + for (size_t k = 0; k < hyp.size(); k++) { + hyps_pad_sos_eos[i][j][k + 1] = hyp[k]; + } + hyps_pad_sos_eos[i][j][hyp.size() + 1] = eos_; + } } - torch::Tensor hyps_tensor = - torch::zeros({num_hyps, max_hyps_len}, torch::kLong); - for (size_t i = 0; i < num_hyps; ++i) { - const std::vector& hyp = hyps[i]; - hyps_tensor[i][0] = sos_; - for (size_t j = 0; j < hyp.size(); ++j) { - hyps_tensor[i][j + 1] = hyp[j]; + torch::Tensor r_hyps_pad_sos_eos = + 
torch::zeros({batch_size, beam_size, max_hyps_len + 1}, torch::kLong) * eos_; + for (size_t i = 0; i < batch_size; i++) { + for (size_t j = 0; j < beam_size; j++) { + const std::vector& hyp = batch_hyps[i][j]; + hyps_pad_sos_eos[i][j][0] = sos_; + int hyp_size = hyp.size(); + for (size_t k = 0; k < hyp_size; k++) { + hyps_pad_sos_eos[i][j][k + 1] = hyp[hyp_size - 1 - k]; // reverse copy + } + hyps_pad_sos_eos[i][j][hyp_size + 1] = eos_; } } - // Step 2: Forward attention decoder by hyps and corresponding encoder_outs_ - using namespace torch::indexing; - torch::Tensor encoder_out = encoder_out_.index({Slice(batch_index, batch_index + 1)}); + // Step 2: Forward attention decoder #ifdef USE_GPU - hyps_tensor = hyps_tensor.to(at::kCUDA); - hyps_length = hyps_length.to(at::kCUDA); - encoder_out = encoder_out.to(at::kCUDA); + hyps_pad_sos_eos = hyps_pad_sos_eos.to(at::kCUDA); + hyps_lens_sos = hyps_lens_sos.to(at::kCUDA); + r_hyps_pad_sos_eos = r_hyps_pad_sos_eos.to(at::kCUDA); + encoder_lens_ = encoder_lens_.to(at::kCUDA); + encoder_out_ = encoder_out_.to(at::kCUDA); #endif - auto outputs = model_ - ->run_method("forward_attention_decoder", hyps_tensor, - hyps_length, encoder_out, reverse_weight) - .toTuple() - ->elements(); + torch::NoGradGuard no_grad; + auto outputs = model_->run_method("batch_forward_attention_decoder", + encoder_out_, encoder_lens_, + hyps_pad_sos_eos, hyps_lens_sos, + r_hyps_pad_sos_eos, reverse_weight).toTensor(); #ifdef USE_GPU - auto probs = outputs[0].toTensor().to(at::kCPU); - auto r_probs = outputs[1].toTensor().to(at::kCPU); -#else - auto probs = outputs[0].toTensor(); - auto r_probs = outputs[1].toTensor(); + outputs = outputs.to(at::kCPU); + c10::cuda::CUDACachingAllocator::emptyCache(); #endif - CHECK_EQ(probs.size(0), num_hyps); - CHECK_EQ(probs.size(1), max_hyps_len); - - // Step 3: Compute rescoring score - for (size_t i = 0; i < num_hyps; ++i) { - const std::vector& hyp = hyps[i]; - float score = 0.0f; - // left-to-right decoder score - score = ComputeAttentionScore(probs[i], hyp, eos_); - // Optional: Used for right to left score - float r_score = 0.0f; - if (is_bidirectional_decoder_ && reverse_weight > 0) { - // right-to-left score - CHECK_EQ(r_probs.size(0), num_hyps); - CHECK_EQ(r_probs.size(1), max_hyps_len); - std::vector r_hyp(hyp.size()); - std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin()); - // right to left decoder score - r_score = ComputeAttentionScore(r_probs[i], r_hyp, eos_); - } - - // combined left-to-right and right-to-left score - (*rescoring_score)[i] = - score * (1 - reverse_weight) + r_score * reverse_weight; + CHECK_EQ(outputs.size(0), batch_size); + attention_scores->resize(batch_size); + for (size_t i = 0; i < batch_size; i++) { + (*attention_scores)[i].resize(beam_size); + memcpy((*attention_scores)[i].data(), outputs[i].data_ptr(), + sizeof(float) * beam_size); } } From 44392e62945caa546e0f436f2cc80ce4790dbf1b Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 18 Aug 2022 17:52:44 +0800 Subject: [PATCH 10/62] add multi-threads for computing fbank, ctc searching --- runtime/core/decoder/batch_asr_decoder.cc | 225 ++++++++++++++-------- runtime/core/decoder/batch_asr_decoder.h | 23 ++- 2 files changed, 162 insertions(+), 86 deletions(-) diff --git a/runtime/core/decoder/batch_asr_decoder.cc b/runtime/core/decoder/batch_asr_decoder.cc index 508f1edcc..1948b0f42 100644 --- a/runtime/core/decoder/batch_asr_decoder.cc +++ b/runtime/core/decoder/batch_asr_decoder.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ 
-30,101 +31,168 @@ BatchAsrDecoder::BatchAsrDecoder(std::shared_ptr config, std::shared_ptr resource, const DecodeOptions& opts) : feature_config_(config), + beam_size_(opts.ctc_prefix_search_opts.first_beam_size), fbank_(config->num_bins, config->sample_rate, config->frame_length, config->frame_shift), model_(resource->batch_model->Copy()), post_processor_(resource->post_processor), symbol_table_(resource->symbol_table), fst_(resource->fst), unit_table_(resource->unit_table), + resource_(resource), opts_(opts) { if (opts_.reverse_weight > 0) { // Check if model has a right to left decoder CHECK(model_->is_bidirectional_decoder()); } +} + +void BatchAsrDecoder::Reset() { + batch_result_.clear(); +} + +void BatchAsrDecoder::SearchWorker(const ctc_log_prob_t& ctc_log_probs, int index) { + Timer ctc_timer; + SearchInterface* searcher; if (nullptr == fst_) { - searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts, - resource->context_graph)); + searcher = new CtcPrefixBeamSearch(opts_.ctc_prefix_search_opts, + resource_->context_graph); } else { - searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts, - resource->context_graph)); + searcher = new CtcWfstBeamSearch(*fst_, opts_.ctc_wfst_search_opts, + resource_->context_graph); } + // 3.1. ctc search + ctc_timer.Reset(); + searcher->Search(ctc_log_probs); + searcher->FinalizeSearch(); + VLOG(1) << "\tctc search i==" << index << " takes " << ctc_timer.Elapsed() << " ms"; + ctc_timer.Reset(); + std::vector result; + UpdateResult(searcher, result); + std::lock_guard lock(mutex_); + batch_pair_result_.emplace_back(std::make_pair(index, std::move(result))); + const auto& hypotheses = searcher->Inputs(); + if (hypotheses.size() < beam_size_) { + VLOG(2) << "=== searcher->Inputs() size < beam_size_, padding..."; + std::vector> hyps = hypotheses; + int to_pad = beam_size_ - hypotheses.size(); + for (size_t i = 0; i < to_pad; i++) { + std::vector pad = {0}; + hyps.push_back(std::move(pad)); + } + batch_hyps_.emplace_back(std::make_pair(index, std::move(hyps))); + } else { + batch_hyps_.emplace_back(std::make_pair(index, std::move(hypotheses))); + } + delete searcher; } -void BatchAsrDecoder::Reset() { - result_.clear(); - batch_result_.clear(); - global_frame_offset_ = 0; - searcher_->Reset(); +void BatchAsrDecoder::FbankWorker(const std::vector& wav, int index) { + Timer timer; + feature_t feats; + int num_frames = fbank_.Compute(wav, &feats); + std::lock_guard lock(mutex_); + batch_feats_.push_back(std::make_pair(index, std::move(feats))); + batch_feats_lens_.push_back(std::make_pair(index, num_frames)); + VLOG(1) << "\tfeature comput i==" << index << ", takes " << timer.Elapsed() << " ms."; } void BatchAsrDecoder::Decode(const std::vector>& wavs) { // 1. 
calc fbank feature of the batch of wavs Timer timer; + std::vector fbank_threads; + for (size_t i = 0; i < wavs.size(); i++) { + const std::vector& wav = wavs[i]; + std::thread thd(&BatchAsrDecoder::FbankWorker, this, wav, i); + fbank_threads.push_back(std::move(thd)); + } + for(auto& thd : fbank_threads) { + thd.join(); + } + std::sort(batch_feats_.begin(), batch_feats_.end()); + std::sort(batch_feats_lens_.begin(), batch_feats_lens_.end()); batch_feature_t batch_feats; std::vector batch_feats_lens; - VLOG(1) << "wavs : " << wavs.size(); - for (const auto& wav : wavs) { - VLOG(1) << "wav : " << wav.size(); - feature_t feats; - int num_frames = fbank_.Compute(wav, &feats); - VLOG(1) << "feat leng is " << num_frames; - batch_feats.push_back(std::move(feats)); - batch_feats_lens.push_back(num_frames); + for (auto& pair : batch_feats_) { + batch_feats.push_back(std::move(pair.second)); + } + for (auto& pair : batch_feats_lens_) { + batch_feats_lens.push_back(pair.second); + // VLOG(1) << "\t feats lens: " << pair.second; } - int feat_time = timer.Elapsed(); - VLOG(1) << "feat_time : " << feat_time; + VLOG(1) << "feature Compute takes " << timer.Elapsed() << " ms."; // 1.1 feature padding timer.Reset(); int max_len = *std::max_element(batch_feats_lens.begin(), batch_feats_lens.end()); - VLOG(1) << "max length feature : " << max_len; for (auto& feat : batch_feats) { if (feat.size() == max_len) continue; int pad_len = max_len - feat.size(); for (size_t i = 0; i< pad_len; i++) { - std::vector one(feature_config_->num_bins, 0); + std::vector one(feature_config_->num_bins, 0.0); feat.push_back(std::move(one)); } } - VLOG(1) << "padding time : " << timer.Elapsed(); - timer.Reset(); + VLOG(1) << "padding feautre takes " << timer.Elapsed() << " ms."; // 2. encoder forward + timer.Reset(); batch_ctc_log_prob_t batch_ctc_log_probs; model_->ForwardEncoder(batch_feats, batch_feats_lens, batch_ctc_log_probs); - VLOG(1) << "encoder forward time : " << timer.Elapsed(); + VLOG(1) << "encoder forward takes " << timer.Elapsed() << " ms."; // 3. ctc search one by one of the batch - // it seems, decoder forward only support 1 encoder_out with n-best ctc search result + // create batch of tct search result for attention decoding int batch_size = wavs.size(); - batch_result_.clear(); + timer.Reset(); + batch_pair_result_.clear(); + batch_hyps_.clear(); + std::vector search_threads; for (size_t i = 0; i < batch_size; i++) { - timer.Reset(); const auto& ctc_log_probs = batch_ctc_log_probs[i]; - // 3.1. ctc search - searcher_->Search(ctc_log_probs); - int search_time = timer.Elapsed(); - VLOG(1) << "search takes " << search_time << " ms"; - - // 3.2. rescoring - timer.Reset(); - AttentionRescoring(i); - VLOG(1) << "Rescoring cost latency: " << timer.Elapsed() << "ms."; - - // 3.3. save to batch_result_ - batch_result_.push_back(std::move(result_)); - - // 3.4 reset - searcher_->Reset(); + std::thread thd(&BatchAsrDecoder::SearchWorker, this, ctc_log_probs, i); + search_threads.push_back(std::move(thd)); + } + for(auto& thd : search_threads) { + thd.join(); + } + VLOG(1) << "ctc search batch(" << batch_size << ") takes " << timer.Elapsed() << " ms."; + + // 4. 
attention rescoring + timer.Reset(); + std::sort(batch_hyps_.begin(), batch_hyps_.end()); + std::sort(batch_pair_result_.begin(), batch_pair_result_.end(), [](auto& a, auto& b) { + return a.first < b.first; + }); + std::vector>> batch_hyps; + for (auto& pair : batch_hyps_) { + batch_hyps.push_back(std::move(pair.second)); + } + + batch_result_.clear(); + for (auto& pair : batch_pair_result_) { + batch_result_.push_back(std::move(pair.second)); + } + timer.Reset(); + std::vector> attention_scores; + model_->AttentionRescoring(batch_hyps, opts_.reverse_weight, &attention_scores); + VLOG(1) << "attention rescoring takes " << timer.Elapsed() << " ms."; + for (size_t i = 0; i < batch_size; i++) { + std::vector& result = batch_result_[i]; + for (size_t j = 0; j < beam_size_; j++) { + result[j].score = opts_.rescoring_weight * attention_scores[i][j] + + opts_.ctc_weight * result[j].score; + } + std::sort(result.begin(), result.end(), DecodeResult::CompareFunc); } } -void BatchAsrDecoder::UpdateResult(bool finish) { - const auto& hypotheses = searcher_->Outputs(); - const auto& inputs = searcher_->Inputs(); - const auto& likelihood = searcher_->Likelihood(); - const auto& times = searcher_->Times(); - result_.clear(); +void BatchAsrDecoder::UpdateResult(SearchInterface* searcher, std::vector& result) { + bool finish = true; + const auto& hypotheses = searcher->Outputs(); + const auto& inputs = searcher->Inputs(); + const auto& likelihood = searcher->Likelihood(); + const auto& times = searcher->Times(); + result.clear(); CHECK_EQ(hypotheses.size(), likelihood.size()); for (size_t i = 0; i < hypotheses.size(); i++) { @@ -132,12 +200,11 @@ void BatchAsrDecoder::UpdateResult(bool finish) { DecodeResult path; path.score = likelihood[i]; - int offset = global_frame_offset_ * feature_frame_shift_in_ms(); for (size_t j = 0; j < hypothesis.size(); j++) { std::string word = symbol_table_->Find(hypothesis[j]); // A detailed explanation of this if-else branch can be found in // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058 - if (searcher_->Type() == kWfstBeamSearch) { + if (searcher->Type() == kWfstBeamSearch) { path.sentence += (' ' + word); } else { path.sentence += (word); @@ -173,7 +240,7 @@ void BatchAsrDecoder::UpdateResult(bool finish) { frame_shift_in_ms() : end; } - WordPiece word_piece(word, offset + start, offset + end); + WordPiece word_piece(word, start, end); path.word_pieces.emplace_back(word_piece); } } @@ -181,35 +248,39 @@ void BatchAsrDecoder::UpdateResult(bool finish) { if (post_processor_ != nullptr) { path.sentence = post_processor_->Process(path.sentence, finish); } - result_.emplace_back(path); + result.emplace_back(path); } } -void BatchAsrDecoder::AttentionRescoring(int batch_index) { - searcher_->FinalizeSearch(); - UpdateResult(true); - // No need to do rescoring - if (0.0 == opts_.rescoring_weight) { - return; - } - // Inputs() returns N-best input ids, which is the basic unit for rescoring - // In CtcPrefixBeamSearch, inputs are the same to outputs - const auto& hypotheses = searcher_->Inputs(); - int num_hyps = hypotheses.size(); - if (num_hyps <= 0) { - return; - } - - std::vector rescoring_score; - model_->AttentionRescoring(hypotheses, batch_index, opts_.reverse_weight, - &rescoring_score); - - // Combine ctc score and rescoring score - for (size_t i = 0; i < num_hyps; ++i) { - result_[i].score = opts_.rescoring_weight * rescoring_score[i] + - opts_.ctc_weight * result_[i].score; +const std::string BatchAsrDecoder::get_batch_result(int nbest, bool 
enable_timestamp) {
+  json::JSON obj;
+  obj["status"] = "ok";
+  obj["type"] = "final_result";
+  obj["batch_size"] = batch_result_.size();
+  obj["batch_result"] = json::Array();
+  for (const auto& result : batch_result_) {
+    json::JSON batch_one;
+    batch_one["nbest"] = json::Array();
+    for (int i = 0; i < nbest && i < result.size(); i++) {
+      json::JSON one;
+      one["sentence"] = result[i].sentence;
+      // one["score"] = result[i].score;
+      if (enable_timestamp) {
+        one["word_pieces"] = json::Array();
+        for (const auto& word_piece : result[i].word_pieces) {
+          json::JSON piece;
+          piece["word"] = word_piece.word;
+          piece["start"] = word_piece.start;
+          piece["end"] = word_piece.end;
+          one["word_pieces"].append(piece);
+        }
+      }
+      batch_one["nbest"].append(one);
+    }
+    obj["batch_result"].append(batch_one);
+  }
+  return obj.dump();
+}
-  std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
-}
 }  // namespace wenet
diff --git a/runtime/core/decoder/batch_asr_decoder.h b/runtime/core/decoder/batch_asr_decoder.h
index 07f851e0e..1f487da3f 100644
--- a/runtime/core/decoder/batch_asr_decoder.h
+++ b/runtime/core/decoder/batch_asr_decoder.h
@@ -35,6 +35,7 @@
 #include "post_processor/post_processor.h"
 #include "utils/utils.h"
 #include "frontend/fbank.h"
+#include "utils/json.h"
 
 namespace wenet {
 
@@ -56,14 +57,22 @@ class BatchAsrDecoder {
     return feature_config_->frame_shift * 1000 / feature_config_->sample_rate;
   }
 
-  const std::vector<DecodeResult>& result() const { return result_; }
   const std::vector<std::vector<DecodeResult>>& batch_result() const { return batch_result_; }
+  const std::string get_batch_result(int nbest, bool enable_timestamp);
 
  private:
   Fbank fbank_;
-  void AttentionRescoring(int batch_index);
+  void FbankWorker(const std::vector<float>& wav, int index);
+  std::vector<std::pair<int, feature_t>> batch_feats_;  // for FbankWorker
+  std::vector<std::pair<int, int>> batch_feats_lens_;  // for FbankWorker
+
+  void SearchWorker(const ctc_log_prob_t& ctc_log_probs, int index);
+  std::mutex mutex_;
+  std::vector<std::pair<int, std::vector<std::vector<int>>>> batch_hyps_;  // for SearchWorker
+  std::vector<std::pair<int, std::vector<DecodeResult>>> batch_pair_result_;  // for SearchWorker
+  std::vector<std::vector<DecodeResult>> batch_result_;
 
-  void UpdateResult(bool finish = false);
+  void UpdateResult(SearchInterface* searcher, std::vector<DecodeResult>& result);
 
   std::shared_ptr<FeaturePipelineConfig> feature_config_;
   std::shared_ptr<BatchAsrModel> model_;
@@ -74,15 +83,11 @@ class BatchAsrDecoder {
   std::shared_ptr<fst::SymbolTable> symbol_table_;
   // e2e unit symbol table
   std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
+  std::shared_ptr<DecodeResource> resource_ = nullptr;
   const DecodeOptions& opts_;
-  int global_frame_offset_ = 0;
+  int beam_size_;
   const int time_stamp_gap_ = 100;  // timestamp gap between words in a sentence
-  std::unique_ptr<SearchInterface> searcher_;
-
-  std::vector<DecodeResult> result_;
-  std::vector<std::vector<DecodeResult>> batch_result_;
-
  public:
   WENET_DISALLOW_COPY_AND_ASSIGN(BatchAsrDecoder);
 };
From e1e597e805dd83ede470ac2bb91135dd0351fce3 Mon Sep 17 00:00:00 2001
From: veelion
Date: Thu, 18 Aug 2022 17:55:06 +0800
Subject: [PATCH 11/62] call the jit script's batch_forward_attention_decoder()
---
 runtime/core/decoder/batch_asr_model.h        | 10 +-
 runtime/core/decoder/batch_torch_asr_model.cc | 96 ++++++-----------
 runtime/core/decoder/batch_torch_asr_model.h  |  7 +-
 3 files changed, 48 insertions(+), 65 deletions(-)

diff --git a/runtime/core/decoder/batch_asr_model.h b/runtime/core/decoder/batch_asr_model.h
index b33be2285..5469ff483 100644
--- a/runtime/core/decoder/batch_asr_model.h
+++ b/runtime/core/decoder/batch_asr_model.h
@@ -16,7 +16,8 @@ namespace wenet {
 
 using feature_t = std::vector<std::vector<float>>;
 using batch_feature_t = std::vector<feature_t>;
-using batch_ctc_log_prob_t = 
std::vector; +using ctc_log_prob_t = std::vector>; +using batch_ctc_log_prob_t = std::vector; class BatchAsrModel { @@ -34,10 +35,9 @@ class BatchAsrModel { const std::vector& batch_feats_lens, batch_ctc_log_prob_t& batch_ctc_prob); - virtual void AttentionRescoring(const std::vector>& hyps, - int batch_index, - float reverse_weight, - std::vector* rescoring_score) = 0; + virtual void AttentionRescoring(const std::vector>>& batch_hyps, + float reverse_weight, + std::vector>* attention_scores) = 0; virtual std::shared_ptr Copy() const = 0; diff --git a/runtime/core/decoder/batch_torch_asr_model.cc b/runtime/core/decoder/batch_torch_asr_model.cc index 6c09ffe56..5c81fc1c4 100644 --- a/runtime/core/decoder/batch_torch_asr_model.cc +++ b/runtime/core/decoder/batch_torch_asr_model.cc @@ -42,17 +42,16 @@ void BatchTorchAsrModel::InitEngineThreads(int num_threads) { } void BatchTorchAsrModel::Read(const std::string& model_path) { - torch::DeviceType device = at::kCPU; #ifdef USE_GPU if (!torch::cuda::is_available()) { VLOG(1) << "CUDA is not available! Please check your GPU settings"; throw std::runtime_error("CUDA is not available!"); } else { VLOG(1) << "CUDA is available! Running on GPU"; - device = at::kCUDA; + device_ = at::kCUDA; } #endif - torch::jit::script::Module model = torch::jit::load(model_path, device); + torch::jit::script::Module model = torch::jit::load(model_path, device_); model_ = std::make_shared(std::move(model)); torch::NoGradGuard no_grad; model_->eval(); @@ -89,7 +88,7 @@ BatchTorchAsrModel::BatchTorchAsrModel(const BatchTorchAsrModel& other) { // inference, please see https://pytorch.org/docs/stable/notes/cpu_ // threading_torchscript_inference.html model_ = other.model_; - + device_ = other.device_; } std::shared_ptr BatchTorchAsrModel::Copy() const { @@ -103,12 +102,12 @@ void BatchTorchAsrModel::ForwardEncoderFunc( batch_ctc_log_prob_t& out_prob) { // 1. Prepare libtorch required data int batch_size = batch_feats.size(); - num_frames_ = batch_feats[0].size(); + int num_frames = batch_feats[0].size(); const int feature_dim = batch_feats[0][0].size(); torch::Tensor feats = - torch::zeros({batch_size, num_frames_, feature_dim}, torch::kFloat); + torch::zeros({batch_size, num_frames, feature_dim}, torch::kFloat); for (size_t i = 0; i < batch_size; ++i) { - for (size_t j = 0; j < num_frames_; ++j) { + for (size_t j = 0; j < num_frames; ++j) { torch::Tensor row = torch::from_blob(const_cast(batch_feats[i][j].data()), {feature_dim}, torch::kFloat).clone(); @@ -120,28 +119,21 @@ void BatchTorchAsrModel::ForwardEncoderFunc( {batch_size}, torch::kInt).clone(); // 2. 
Encoder batch forward -#ifdef USE_GPU - feats = feats.to(at::kCUDA); - feats_lens = feats_lens.to(at::kCUDA); -#endif + feats = feats.to(device_); + feats_lens = feats_lens.to(device_); torch::NoGradGuard no_grad; std::vector inputs = {feats, feats_lens}; - // Refer interfaces in wenet/transformer/asr_model.py auto outputs = model_->get_method("batch_forward_encoder")(inputs).toTuple()->elements(); CHECK_EQ(outputs.size(), 3); encoder_out_ = outputs[0].toTensor(); // (B, Tmax, dim) encoder_lens_ = outputs[1].toTensor(); // (B,) - - // The first dimension of returned value is for batchsize - torch::Tensor ctc_log_probs = outputs[2].toTensor(); -#ifdef USE_GPU - encoder_out_ = encoder_out_.to(at::kCPU); - encoder_lens_ = encoder_lens_.to(at::kCPU); - ctc_log_probs = ctc_log_probs.to(at::kCPU); - c10::cuda::CUDACachingAllocator::emptyCache(); -#endif + + auto ctc_log_probs = outputs[2].toTensor().to(at::kCPU); + // encoder_out_ = encoder_out_.to(at::kCPU); // to CPU to save GPU memory + // encoder_lens_ = encoder_lens_.to(at::kCPU); + // c10::cuda::CUDACachingAllocator::emptyCache(); // Copy to output int num_outputs = ctc_log_probs.size(1); @@ -174,10 +166,9 @@ void BatchTorchAsrModel::AttentionRescoring( float reverse_weight, std::vector>* attention_scores) { CHECK(attention_scores != nullptr); - // Step 1: Prepare input for libtorch int batch_size = batch_hyps.size(); - int beam_size = batch_hyps[0].size(); // should be 10 + int beam_size = batch_hyps[0].size(); torch::Tensor hyps_lens_sos = torch::zeros({batch_size, beam_size}, torch::kLong); int max_hyps_len = 0; for (size_t i = 0; i < batch_size; i++) { @@ -188,56 +179,47 @@ void BatchTorchAsrModel::AttentionRescoring( } } - // 1.2 add sos and eos to hyps, and padded by ignore_id (-1) - torch::Tensor hyps_pad_sos_eos = - torch::ones({batch_size, beam_size, max_hyps_len + 1}, torch::kLong) * eos_; + // 1.2 add sos to hyps + torch::Tensor hyps_pad_sos = + torch::zeros({batch_size, beam_size, max_hyps_len}, torch::kLong); for (size_t i = 0; i < batch_size; i++) { for (size_t j = 0; j < beam_size; j++) { const std::vector& hyp = batch_hyps[i][j]; - hyps_pad_sos_eos[i][j][0] = sos_; + hyps_pad_sos[i][j][0] = sos_; for (size_t k = 0; k < hyp.size(); k++) { - hyps_pad_sos_eos[i][j][k + 1] = hyp[k]; + hyps_pad_sos[i][j][k + 1] = hyp[k]; } - hyps_pad_sos_eos[i][j][hyp.size() + 1] = eos_; - } - } - torch::Tensor r_hyps_pad_sos_eos = - torch::zeros({batch_size, beam_size, max_hyps_len + 1}, torch::kLong) * eos_; - for (size_t i = 0; i < batch_size; i++) { - for (size_t j = 0; j < beam_size; j++) { - const std::vector& hyp = batch_hyps[i][j]; - hyps_pad_sos_eos[i][j][0] = sos_; - int hyp_size = hyp.size(); - for (size_t k = 0; k < hyp_size; k++) { - hyps_pad_sos_eos[i][j][k + 1] = hyp[hyp_size - 1 - k]; // reverse copy - } - hyps_pad_sos_eos[i][j][hyp_size + 1] = eos_; } } // Step 2: Forward attention decoder -#ifdef USE_GPU - hyps_pad_sos_eos = hyps_pad_sos_eos.to(at::kCUDA); - hyps_lens_sos = hyps_lens_sos.to(at::kCUDA); - r_hyps_pad_sos_eos = r_hyps_pad_sos_eos.to(at::kCUDA); - encoder_lens_ = encoder_lens_.to(at::kCUDA); - encoder_out_ = encoder_out_.to(at::kCUDA); -#endif + hyps_pad_sos = hyps_pad_sos.to(device_); + hyps_lens_sos = hyps_lens_sos.to(device_); + // encoder_lens_ = encoder_lens_.to(device_); + // encoder_out_ = encoder_out_.to(device_); torch::NoGradGuard no_grad; auto outputs = model_->run_method("batch_forward_attention_decoder", encoder_out_, encoder_lens_, - hyps_pad_sos_eos, hyps_lens_sos, - r_hyps_pad_sos_eos, 
reverse_weight).toTensor(); -#ifdef USE_GPU - outputs = outputs.to(at::kCPU); + hyps_pad_sos, hyps_lens_sos, + reverse_weight).toTuple()->elements(); + auto decoder_out = outputs[0].toTensor().to(at::kCPU); + auto r_decoder_out = outputs[1].toTensor().to(at::kCPU); c10::cuda::CUDACachingAllocator::emptyCache(); -#endif - CHECK_EQ(outputs.size(0), batch_size); attention_scores->resize(batch_size); for (size_t i = 0; i < batch_size; i++) { (*attention_scores)[i].resize(beam_size); - memcpy((*attention_scores)[i].data(), outputs[i].data_ptr(), - sizeof(float) * beam_size); + for (size_t j = 0; j < beam_size; ++j) { + const std::vector& hyp = batch_hyps[i][j]; + float score = 0.0f; + score = ComputeAttentionScore(decoder_out[i * beam_size + j], hyp, eos_); + float r_score = 0.0f; + if (is_bidirectional_decoder_ && reverse_weight > 0) { + std::vector r_hyp(hyp.size()); + std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin()); + r_score = ComputeAttentionScore(r_decoder_out[i * beam_size + j], r_hyp, eos_); + } + (*attention_scores)[i][j] = score * (1 - reverse_weight) + r_score * reverse_weight; + } } } diff --git a/runtime/core/decoder/batch_torch_asr_model.h b/runtime/core/decoder/batch_torch_asr_model.h index 182fb56a0..a55a44469 100644 --- a/runtime/core/decoder/batch_torch_asr_model.h +++ b/runtime/core/decoder/batch_torch_asr_model.h @@ -39,10 +39,9 @@ class BatchTorchAsrModel : public BatchAsrModel { BatchTorchAsrModel(const BatchTorchAsrModel& other); void Read(const std::string& model_path); std::shared_ptr torch_model() const { return model_; } - void AttentionRescoring(const std::vector>& hyps, - int batch_index, + void AttentionRescoring(const std::vector>>& batch_hyps, float reverse_weight, - std::vector* rescoring_score) override; + std::vector>* attention_scores) override; std::shared_ptr Copy() const override; protected: @@ -57,6 +56,8 @@ class BatchTorchAsrModel : public BatchAsrModel { private: std::shared_ptr model_ = nullptr; torch::Tensor encoder_out_; + torch::Tensor encoder_lens_; + torch::DeviceType device_; }; } // namespace wenet From c25afd89010d2a0fd8d01092fc2de42c60e2563e Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 18 Aug 2022 17:56:07 +0800 Subject: [PATCH 12/62] add run_batch flag to support BatchTorchAsrModel --- runtime/core/decoder/params.h | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/runtime/core/decoder/params.h b/runtime/core/decoder/params.h index dcabaeadc..37ad6be2f 100644 --- a/runtime/core/decoder/params.h +++ b/runtime/core/decoder/params.h @@ -23,11 +23,13 @@ #include #include "decoder/asr_decoder.h" +#include "decoder/batch_asr_decoder.h" #ifdef USE_ONNX #include "decoder/onnx_asr_model.h" #endif #ifdef USE_TORCH #include "decoder/torch_asr_model.h" +#include "decoder/batch_torch_asr_model.h" #endif #include "frontend/feature_pipeline.h" #include "post_processor/post_processor.h" @@ -88,6 +90,7 @@ DEFINE_int32(language_type, 0, "0x00 = kMandarinEnglish, " "0x01 = kIndoEuropean"); DEFINE_bool(lowercase, true, "lowercase final result if needed"); +DEFINE_bool(run_batch, false, "run websocket server for batch decoding"); namespace wenet { std::shared_ptr InitFeaturePipelineConfigFromFlags() { @@ -132,11 +135,19 @@ std::shared_ptr InitDecodeResourceFromFlags() { #endif } else { #ifdef USE_TORCH - LOG(INFO) << "Reading torch model " << FLAGS_model_path; - TorchAsrModel::InitEngineThreads(FLAGS_num_threads); - auto model = std::make_shared(); - model->Read(FLAGS_model_path); - resource->model = model; + 
if (FLAGS_run_batch) { + LOG(INFO) << "BatchTorchAsrModel Reading torch model " << FLAGS_model_path; + BatchTorchAsrModel::InitEngineThreads(FLAGS_num_threads); + auto model = std::make_shared(); + model->Read(FLAGS_model_path); + resource->batch_model = model; + } else { + LOG(INFO) << "Reading torch model " << FLAGS_model_path; + TorchAsrModel::InitEngineThreads(FLAGS_num_threads); + auto model = std::make_shared(); + model->Read(FLAGS_model_path); + resource->model = model; + } #else LOG(FATAL) << "Please rebuild with cmake options '-DTORCH=ON'."; #endif From 63c42f47fd3c86b8e4a9a7e5fb12adf67c1c5d35 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 18 Aug 2022 17:57:36 +0800 Subject: [PATCH 13/62] replace UpdateResult with decoder's get_batch_result() --- runtime/core/api/batch_recognizer.h | 35 ++++------------------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/runtime/core/api/batch_recognizer.h b/runtime/core/api/batch_recognizer.h index bf5ed6c3d..86cca4785 100644 --- a/runtime/core/api/batch_recognizer.h +++ b/runtime/core/api/batch_recognizer.h @@ -105,9 +105,9 @@ class BatchRecognizer { } decoder_->Reset(); decoder_->Decode(wavs_float); - return UpdateResult(); + return decoder_->get_batch_result(nbest_, enable_timestamp_); } - + std::string DecodeData(const std::vector>& wavs) { // Init decoder when it is called first time if (decoder_ == nullptr) { @@ -115,37 +115,10 @@ class BatchRecognizer { } decoder_->Reset(); decoder_->Decode(wavs); - return UpdateResult(); + return decoder_->get_batch_result(nbest_, enable_timestamp_); } - std::string UpdateResult() { - const auto& batch_result = decoder_->batch_result(); - json::JSON obj; - obj["batch_size"] = batch_result.size(); - obj["batch_result"] = json::Array(); - for (const auto& result : batch_result) { - json::JSON batch_one; - batch_one["nbest"] = json::Array(); - for (int i = 0; i < nbest_ && i < result.size(); i++) { - json::JSON one; - one["sentence"] = result[i].sentence; - if (enable_timestamp_) { - one["word_pieces"] = json::Array(); - for (const auto& word_piece : result[i].word_pieces) { - json::JSON piece; - piece["word"] = word_piece.word; - piece["start"] = word_piece.start; - piece["end"] = word_piece.end; - one["word_pieces"].append(piece); - } - } - one["sentence"] = result[i].sentence; - batch_one["nbest"].append(one); - } - obj["batch_result"].append(batch_one); - } - return obj.dump(); - } + void set_nbest(int n) { nbest_ = n; } void set_enable_timestamp(bool flag) { enable_timestamp_ = flag; } From 238fc8ef7dad957d34af91c0ef2430bab8287efe Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 18 Aug 2022 17:58:52 +0800 Subject: [PATCH 14/62] add FLAGS_enable_timestamp --- runtime/core/bin/api_batch_main.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/runtime/core/bin/api_batch_main.cc b/runtime/core/bin/api_batch_main.cc index 9be76f58a..faa2b133e 100644 --- a/runtime/core/bin/api_batch_main.cc +++ b/runtime/core/bin/api_batch_main.cc @@ -31,6 +31,7 @@ int main(int argc, char* argv[]) { wenet_set_log_level(2); BatchRecognizer br(FLAGS_model_dir, FLAGS_num_threads); + if (FLAGS_enable_timestamp) br.set_enable_timestamp(true); wenet::WavReader wav_reader(FLAGS_wav_path); std::vector data; data.insert(data.end(), wav_reader.data(), wav_reader.data() + wav_reader.num_samples()); From 3576cb4ab6c096bacb8105cae91bdcd1cb909628 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 18 Aug 2022 18:00:15 +0800 Subject: [PATCH 15/62] add FLAGS_run_batch for runing for batch decoding --- 
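Note: in batch mode the server keeps the usual websocket framing but changes the payload contract: the client sends one JSON text frame to start, then a single binary frame carrying all utterances' PCM. Below is a minimal sketch (not part of the patch) of building that start message with the runtime's bundled SimpleJSON (utils/json.h); the field names signal, nbest, enable_timestamp, and batch_lens follow the BatchConnectionHandler added later in this series, and batch_lens holds each utterance's byte count of 16-bit PCM. The helper name MakeBatchStartMessage is hypothetical.

    // Sketch only: start message for one batch request.
    #include <string>
    #include <vector>
    #include "utils/json.h"

    std::string MakeBatchStartMessage(const std::vector<int>& pcm_byte_lens) {
      json::JSON start;
      start["signal"] = "start";
      start["nbest"] = 1;
      start["enable_timestamp"] = false;
      start["batch_lens"] = json::Array();
      for (int len : pcm_byte_lens) {
        start["batch_lens"].append(len);
      }
      // Send as a text frame; the PCM follows as one binary frame.
      return start.dump();
    }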
runtime/core/bin/websocket_server_main.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runtime/core/bin/websocket_server_main.cc b/runtime/core/bin/websocket_server_main.cc index 796d9d2e6..3bf7308b6 100644 --- a/runtime/core/bin/websocket_server_main.cc +++ b/runtime/core/bin/websocket_server_main.cc @@ -29,6 +29,7 @@ int main(int argc, char* argv[]) { wenet::WebSocketServer server(FLAGS_port, feature_config, decode_config, decode_resource); LOG(INFO) << "Listening at port " << FLAGS_port; - server.Start(); + LOG(INFO) << "run for batch decoding: " << FLAGS_run_batch; + server.Start(FLAGS_run_batch); return 0; } From c3a17b10473a959c09cf77b3a13ff973cd7d2948 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 18 Aug 2022 18:01:14 +0800 Subject: [PATCH 16/62] fix: https://github.com/nbsdx/SimpleJSON/issues/4 --- runtime/core/utils/json.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/runtime/core/utils/json.h b/runtime/core/utils/json.h index bf8d94a3e..773bed319 100644 --- a/runtime/core/utils/json.h +++ b/runtime/core/utils/json.h @@ -488,7 +488,8 @@ class JSON { Class Type = Class::Null; }; -JSON Array() { return std::move(JSON::Make(JSON::Class::Array)); } +// fix: https://github.com/nbsdx/SimpleJSON/issues/4 (veelion) +inline JSON Array() { return std::move(JSON::Make(JSON::Class::Array)); } template JSON Array(T... args) { @@ -497,9 +498,9 @@ JSON Array(T... args) { return std::move(arr); } -JSON Object() { return std::move(JSON::Make(JSON::Class::Object)); } +inline JSON Object() { return std::move(JSON::Make(JSON::Class::Object)); } -std::ostream& operator<<(std::ostream& os, const JSON& json) { +inline std::ostream& operator<<(std::ostream& os, const JSON& json) { os << json.dump(); return os; } @@ -744,7 +745,7 @@ JSON parse_next(const string& str, size_t& offset) { // NOLINT } } // namespace -JSON JSON::Load(const string& str) { +inline JSON JSON::Load(const string& str) { size_t offset = 0; return std::move(parse_next(str, offset)); } From 7adcd744784a090d7f5354fb6155f63ac2b6e463 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 18 Aug 2022 18:02:35 +0800 Subject: [PATCH 17/62] support run_batch --- runtime/core/websocket/websocket_server.cc | 18 +++++++++++++----- runtime/core/websocket/websocket_server.h | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/runtime/core/websocket/websocket_server.cc b/runtime/core/websocket/websocket_server.cc index 52ab088f4..84a5e637a 100644 --- a/runtime/core/websocket/websocket_server.cc +++ b/runtime/core/websocket/websocket_server.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "websocket/websocket_server.h" +#include "websocket/batch_connection_handler.h" #include #include @@ -244,7 +245,7 @@ void ConnectionHandler::operator()() { } } -void WebSocketServer::Start() { +void WebSocketServer::Start(bool run_batch) { try { auto const address = asio::ip::make_address("0.0.0.0"); tcp::acceptor acceptor{ioc_, {address, static_cast(port_)}}; @@ -254,10 +255,17 @@ void WebSocketServer::Start() { // Block until we get a connection acceptor.accept(socket); // Launch the session, transferring ownership of the socket - ConnectionHandler handler(std::move(socket), feature_config_, - decode_config_, decode_resource_); - std::thread t(std::move(handler)); - t.detach(); + if (run_batch) { + BatchConnectionHandler handler(std::move(socket), feature_config_, + decode_config_, decode_resource_); + std::thread t(std::move(handler)); + t.detach(); + } else { + ConnectionHandler handler(std::move(socket), feature_config_, + decode_config_, decode_resource_); + std::thread t(std::move(handler)); + t.detach(); + } } } catch (const std::exception& e) { LOG(FATAL) << e.what(); diff --git a/runtime/core/websocket/websocket_server.h b/runtime/core/websocket/websocket_server.h index a12418342..e211faf9c 100644 --- a/runtime/core/websocket/websocket_server.h +++ b/runtime/core/websocket/websocket_server.h @@ -85,7 +85,7 @@ class WebSocketServer { decode_config_(std::move(decode_config)), decode_resource_(std::move(decode_resource)) {} - void Start(); + void Start(bool run_batch=false); private: int port_; From a3045fc5f4437aff43c737380204f45c58dbcf5f Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 19 Aug 2022 10:13:16 +0800 Subject: [PATCH 18/62] add batch_connection_handler.h --- .../core/websocket/batch_connection_handler.h | 215 ++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 runtime/core/websocket/batch_connection_handler.h diff --git a/runtime/core/websocket/batch_connection_handler.h b/runtime/core/websocket/batch_connection_handler.h new file mode 100644 index 000000000..a1ee21ebc --- /dev/null +++ b/runtime/core/websocket/batch_connection_handler.h @@ -0,0 +1,215 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef WEBSOCKET_BATCH_CONNECTION_HANDLER_H_ +#define WEBSOCKET_BATCH_CONNECTION_HANDLER_H_ + +#include +#include +#include +#include +#include + +#include "boost/asio/connect.hpp" +#include "boost/asio/ip/tcp.hpp" +#include "boost/beast/core.hpp" +#include "boost/beast/websocket.hpp" +#include "boost/json/src.hpp" + +#include "decoder/asr_decoder.h" +#include "decoder/batch_asr_decoder.h" +#include "frontend/feature_pipeline.h" +#include "utils/log.h" + +namespace wenet { + +namespace beast = boost::beast; // from +namespace http = beast::http; // from +namespace websocket = beast::websocket; // from +namespace asio = boost::asio; // from +using tcp = boost::asio::ip::tcp; // from +namespace json = boost::json; + +class BatchConnectionHandler { + public: + BatchConnectionHandler( + tcp::socket&& socket, + std::shared_ptr feature_config, + std::shared_ptr decode_config, + std::shared_ptr decode_resource) + : ws_(std::move(socket)), + feature_config_(std::move(feature_config)), + decode_config_(std::move(decode_config)), + decode_resource_(std::move(decode_resource)) {} + + void operator()() { + try { + // Accept the websocket handshake + ws_.accept(); + for (;;) { + // This buffer will hold the incoming message + beast::flat_buffer buffer; + // Read a message + ws_.read(buffer); + if (ws_.got_text()) { + std::string message = beast::buffers_to_string(buffer.data()); + LOG(INFO) << message; + OnText(message); + if (got_end_tag_) { + break; + } + } else { + if (!got_start_tag_) { + OnError("Start signal is expected before binary data"); + } else { + OnSpeechData(buffer); + break; + } + } + } + ws_.close(websocket::close_code::normal); + LOG(INFO) << "ws_ is closed, bye :)"; + } catch (beast::system_error const& se) { + LOG(INFO) << se.code().message(); + // This indicates that the session was closed + if (se.code() == websocket::error::closed) { + OnSpeechEnd(); + } + } catch (std::exception const& e) { + LOG(ERROR) << e.what(); + OnError("Decoder got some exception!"); + } + } + + private: + void OnSpeechStart() { + LOG(INFO) << "Received speech start signal, start reading speech"; + got_start_tag_ = true; + json::value rv = {{"status", "ok"}, {"type", "server_ready"}}; + ws_.text(true); + ws_.write(asio::buffer(json::serialize(rv))); + decoder_ = std::make_shared( + feature_config_, decode_resource_, + *decode_config_); + } + + void OnSpeechEnd() { + LOG(INFO) << "Received speech end signal"; + got_end_tag_ = true; + } + + void OnText(const std::string& message) { + json::value v = json::parse(message); + if (v.is_object()) { + json::object obj = v.get_object(); + if (obj.find("signal") != obj.end()) { + json::string signal = obj["signal"].as_string(); + if (signal == "start") { + if (obj.find("nbest") != obj.end()) { + if (obj["nbest"].is_int64()) { + nbest_ = obj["nbest"].as_int64(); + } else { + OnError("integer is expected for nbest option"); + } + } + if (obj.find("enable_timestamp") != obj.end()) { + if (obj["enable_timestamp"].is_bool()) { + enable_timestamp_ = obj["enable_timestamp"].as_bool(); + } else { + OnError( + "boolean true or false is expected for " + "enable_timestamp option"); + } + } + if (obj.find("batch_lens") != obj.end()) { + if (obj["batch_lens"].is_array()) { + batch_lens_.clear(); + auto& batch_lens = obj["batch_lens"].as_array(); + for (size_t i = 0; i < batch_lens.size(); i++) { + int len = batch_lens[i].as_int64(); + batch_lens_.push_back(len); + } + } else { + OnError("a list of batch_lens should be given"); + } + } + OnSpeechStart(); + } else if (signal 
== "end") { + OnSpeechEnd(); + } else { + OnError("Unexpected signal type"); + } + } else { + OnError("Wrong message header"); + } + } else { + OnError("Wrong protocol"); + } + } + + void OnFinish() { + // Send finish tag + json::value rv = {{"status", "ok"}, {"type", "speech_end"}}; + ws_.text(true); + ws_.write(asio::buffer(json::serialize(rv))); + } + + void OnSpeechData(const beast::flat_buffer& buffer) { + // Read binary PCM data + std::vector> wavs; + size_t total = std::accumulate(batch_lens_.begin(), batch_lens_.end(), 0); + VLOG(1) << "buffer size " << buffer.size() << ", batch_lens_ sum " << total; + CHECK(buffer.size() == total); + const auto* pcm_data = static_cast(buffer.data().data()); + int offset = 0; + for (int len : batch_lens_) { + len /= 2; // lenght of int16_t data + std::vector wav(len); + for (size_t i = 0; i < len; i++) { + wav[i] = static_cast(pcm_data[offset+i]); + } + wavs.push_back(std::move(wav)); + offset += len; + } + CHECK(decoder_ != nullptr); + decoder_->Decode(wavs); + std::string result = decoder_->get_batch_result(nbest_, enable_timestamp_); + ws_.text(true); + ws_.write(asio::buffer(result)); + } + + void OnError(const std::string& message) { + json::value rv = {{"status", "failed"}, {"message", message}}; + ws_.text(true); + ws_.write(asio::buffer(json::serialize(rv))); + // Close websocket + ws_.close(websocket::close_code::normal); + } + + int nbest_ = 1; + bool enable_timestamp_ = false; + std::vector batch_lens_; + websocket::stream ws_; + std::shared_ptr feature_config_; + std::shared_ptr decode_config_; + std::shared_ptr decode_resource_; + + bool got_start_tag_ = false; + bool got_end_tag_ = false; + std::shared_ptr decoder_ = nullptr; +}; + +} // namespace wenet + +#endif // WEBSOCKET_BATCH_CONNECTION_HANDLER_H_ From 7bc634a206eacabe2c6d5b913185dcd31e719b7a Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 19 Aug 2022 10:16:22 +0800 Subject: [PATCH 19/62] jit export batch_forward_attention_decoder() --- wenet/transformer/asr_model.py | 73 +++++++++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 5 deletions(-) diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 8bcbf965b..cc31114c1 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -728,15 +728,15 @@ def forward_attention_decoder( return decoder_out, r_decoder_out @torch.jit.export - def forward_encoder_batch( + def batch_forward_encoder( self, speech: torch.Tensor, speech_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Export interface for c++ call, encode a batch of speech Args: - speech: padded input tensor (B, T, D) + speech: padded input tensor (B, T, F) speech_lengths: input length (B) Returns: encoder output tensor xs, and subsampled masks @@ -744,5 +744,68 @@ def forward_encoder_batch( encoder_mask: torch.Tensor batch padding mask after subsample (B, 1, T' ~= T/subsample_rate) """ - encoder_out, encoder_mask = self.encoder(speech, speech_lengths) - return encoder_out, encoder_mask + encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1, -1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + encoder_out_lens = encoder_out_lens.int() + ctc_log_probs = self.ctc.log_softmax(encoder_out) + return encoder_out, encoder_out_lens, ctc_log_probs + + @torch.jit.export + def batch_forward_attention_decoder( + self, + encoder_out: torch.Tensor, + encoder_lens: torch.Tensor, + hyps_pad_sos: torch.Tensor, + hyps_lens_sos: 
torch.Tensor, + reverse_weight: float = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Export interface for c++ call, forward decoder with batch of + hypothesis from ctc prefix beam search and encoder output + Args: + encoder_out: B x T x F + encoder_lens: B + hyps_pad_sos: B x beam x T2 + hyps with sos and padded by 0 + hyps_lens_sos: B x beam, length for each hyp with sos + reverse_weight: used for verfing whether used right to left decoder, + > 0 will use. + + Returns: + scores: (B, beam) + """ + assert encoder_out.size(0) == hyps_pad_sos.size(0) + B, T, F = encoder_out.shape + bz = hyps_pad_sos.shape[1] + B2 = B * bz + T2 = hyps_pad_sos.shape[2] + # 1. prepare inputs for decoder + encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F) + encoder_mask = torch.ones(B2, 1, T, + dtype=torch.bool, + device=encoder_out.device) + # input for right to left decoder + # this hyps_lens has count token, we need minus it. + hyps = hyps_pad_sos.view(B2, T2) + hyps_lens = hyps_lens_sos.view(B2,) + if reverse_weight > 0: + r_hyps_lens = hyps_lens - 1 + r_hyps = hyps[:, 1:] + max_len = torch.max(r_hyps_lens) + index_range = torch.arange(0, max_len, 1).to(encoder_out.device) + seq_len_expand = r_hyps_lens.unsqueeze(1) + seq_mask = seq_len_expand > index_range # (beam, max_len) + index = (seq_len_expand - 1) - index_range # (beam, max_len) + index = index * seq_mask + r_hyps = torch.gather(r_hyps, 1, index) + r_hyps = torch.where(seq_mask, r_hyps, self.eos) + r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1) + else: + r_hyps = torch.empty(0, device=encoder_out.device) + + # 2. decoding + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, + reverse_weight) + decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + r_decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) + return decoder_out, r_decoder_out From 35e8d1a32bfcb281441a85b2042f18091cba64d9 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 25 Aug 2022 14:36:39 +0800 Subject: [PATCH 20/62] add to decoder_srcs with batch_torch_asr_model.cc, batch_onnx_asr_model.cc --- runtime/core/decoder/CMakeLists.txt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/runtime/core/decoder/CMakeLists.txt b/runtime/core/decoder/CMakeLists.txt index f89be91c2..2a5b89cf4 100644 --- a/runtime/core/decoder/CMakeLists.txt +++ b/runtime/core/decoder/CMakeLists.txt @@ -7,17 +7,16 @@ set(decoder_srcs ctc_endpoint.cc batch_asr_decoder.cc batch_asr_model.cc - batch_torch_asr_model.cc ) if(NOT TORCH AND NOT ONNX) message(FATAL_ERROR "Please build with TORCH or ONNX!!!") endif() if(TORCH) - list(APPEND decoder_srcs torch_asr_model.cc) + list(APPEND decoder_srcs torch_asr_model.cc batch_torch_asr_model.cc) endif() if(ONNX) - list(APPEND decoder_srcs onnx_asr_model.cc) + list(APPEND decoder_srcs onnx_asr_model.cc batch_onnx_asr_model.cc) endif() add_library(decoder STATIC ${decoder_srcs}) From d172dd33fadbb25a9ed174c4ebca7bf97f374cbf Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 25 Aug 2022 14:37:25 +0800 Subject: [PATCH 21/62] remove log msg --- runtime/core/decoder/batch_asr_decoder.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/runtime/core/decoder/batch_asr_decoder.cc b/runtime/core/decoder/batch_asr_decoder.cc index 1948b0f42..6e77deafa 100644 --- a/runtime/core/decoder/batch_asr_decoder.cc +++ b/runtime/core/decoder/batch_asr_decoder.cc @@ -117,7 +117,6 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { } for (auto& pair : batch_feats_lens_) { 
batch_feats_lens.push_back(pair.second); - // VLOG(1) << "\t feats lens: " << pair.second; } VLOG(1) << "feature Compute takes " << timer.Elapsed() << " ms."; From d710a54ce4c920a516a363be5003e7313d3d1520 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 25 Aug 2022 14:38:02 +0800 Subject: [PATCH 22/62] add is_fp16_ --- runtime/core/decoder/batch_asr_model.h | 1 + 1 file changed, 1 insertion(+) diff --git a/runtime/core/decoder/batch_asr_model.h b/runtime/core/decoder/batch_asr_model.h index 5469ff483..413f35a30 100644 --- a/runtime/core/decoder/batch_asr_model.h +++ b/runtime/core/decoder/batch_asr_model.h @@ -52,6 +52,7 @@ class BatchAsrModel { int sos_ = 0; int eos_ = 0; bool is_bidirectional_decoder_ = false; + bool is_fp16_ = false; }; } // namespace wenet From 1329090eae3cfe08393a46b850b77f6c82305da9 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 25 Aug 2022 14:39:12 +0800 Subject: [PATCH 23/62] add is_fp16 to Read() --- runtime/core/decoder/batch_torch_asr_model.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/runtime/core/decoder/batch_torch_asr_model.h b/runtime/core/decoder/batch_torch_asr_model.h index a55a44469..e081690ef 100644 --- a/runtime/core/decoder/batch_torch_asr_model.h +++ b/runtime/core/decoder/batch_torch_asr_model.h @@ -37,8 +37,7 @@ class BatchTorchAsrModel : public BatchAsrModel { using TorchModule = torch::jit::script::Module; BatchTorchAsrModel() = default; BatchTorchAsrModel(const BatchTorchAsrModel& other); - void Read(const std::string& model_path); - std::shared_ptr torch_model() const { return model_; } + void Read(const std::string& model_path, bool is_fp16=false); void AttentionRescoring(const std::vector>>& batch_hyps, float reverse_weight, std::vector>* attention_scores) override; From 3aa76c7146c313ea9a20b07afcc7eddde9357ddd Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 25 Aug 2022 15:04:35 +0800 Subject: [PATCH 24/62] add BatchOnnxAsrModel on GPU --- runtime/core/decoder/batch_onnx_asr_model.cc | 405 +++++++++++++++++++ runtime/core/decoder/batch_onnx_asr_model.h | 83 ++++ runtime/core/decoder/params.h | 22 +- 3 files changed, 504 insertions(+), 6 deletions(-) create mode 100644 runtime/core/decoder/batch_onnx_asr_model.cc create mode 100644 runtime/core/decoder/batch_onnx_asr_model.h diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc new file mode 100644 index 000000000..1ba49cdb5 --- /dev/null +++ b/runtime/core/decoder/batch_onnx_asr_model.cc @@ -0,0 +1,405 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 ZeXuan Li (lizexuan@huya.com) +// Xingchen Song(sxc19@mails.tsinghua.edu.cn) +// hamddct@gmail.com (Mddct) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
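+//
+// Note on precision (a sketch of the data flow below, assuming the fp16
+// export produces half-precision graph inputs/outputs): when is_fp16 is set,
+// Read() loads encoder_fp16.onnx / decoder_fp16.onnx, features are converted
+// element-wise before Run(), and outputs are widened back on the CPU, e.g.
+//   Ort::Float16_t h(Eigen::half(x).x);                  // float -> fp16
+//   float y = Eigen::half_impl::half_to_float(
+//       Eigen::half_impl::raw_uint16_to_half(raw_u16));  // fp16 -> float
+// CTC beam search and the rescoring arithmetic stay in float32.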
+ + +#include "decoder/batch_onnx_asr_model.h" + +#include +#include +#include +#include + +#include "utils/string.h" +#include "utils/Yaml.hpp" + +namespace wenet { + +Ort::Env BatchOnnxAsrModel::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, ""); +Ort::SessionOptions BatchOnnxAsrModel::session_options_ = Ort::SessionOptions(); + +void BatchOnnxAsrModel::InitEngineThreads(int num_threads) { + session_options_.SetIntraOpNumThreads(num_threads); + session_options_.SetInterOpNumThreads(num_threads); +} + +void BatchOnnxAsrModel::GetInputOutputInfo( + const std::shared_ptr& session, + std::vector* in_names, std::vector* out_names) { + Ort::AllocatorWithDefaultOptions allocator; + // Input info + int num_nodes = session->GetInputCount(); + in_names->resize(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + char* name = session->GetInputName(i, allocator); + Ort::TypeInfo type_info = session->GetInputTypeInfo(i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType type = tensor_info.GetElementType(); + std::vector node_dims = tensor_info.GetShape(); + std::stringstream shape; + for (auto j : node_dims) { + shape << j; + shape << " "; + } + LOG(INFO) << "\tInput " << i << " : name=" << name << " type=" << type + << " dims=" << shape.str(); + (*in_names)[i] = name; + } + // Output info + num_nodes = session->GetOutputCount(); + out_names->resize(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + char* name = session->GetOutputName(i, allocator); + Ort::TypeInfo type_info = session->GetOutputTypeInfo(i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType type = tensor_info.GetElementType(); + std::vector node_dims = tensor_info.GetShape(); + std::stringstream shape; + for (auto j : node_dims) { + shape << j; + shape << " "; + } + LOG(INFO) << "\tOutput " << i << " : name=" << name << " type=" << type + << " dims=" << shape.str(); + (*out_names)[i] = name; + } +} + +void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16) { + is_fp16_ = is_fp16; + VLOG(1) << "is_fp16_ " << is_fp16_; + std::vector providers = Ort::GetAvailableProviders(); + VLOG(1) << "providers.size(): " << providers.size(); + bool cuda_is_available = false; + for (auto& prd : providers) { + VLOG(1) << "available provider: " << prd; + if (prd.find("CUDA") != std::string::npos) { + cuda_is_available = true; + } + } + if (!cuda_is_available) { + VLOG(1) << "CUDA is not available! Please check your GPU settings!"; + throw std::runtime_error("CUDA is not available!"); + } + std::string encoder_onnx_path = model_dir + "/encoder.onnx"; + std::string rescore_onnx_path = model_dir + "/decoder.onnx"; + if (is_fp16) { + encoder_onnx_path = model_dir + "/encoder_fp16.onnx"; + rescore_onnx_path = model_dir + "/decoder_fp16.onnx"; + } + + // 1. 
Load sessions
+  std::vector<const char*> keys{
+      "device_id",
+      "gpu_mem_limit",
+      "arena_extend_strategy",
+      "cudnn_conv_algo_search",
+      "do_copy_in_default_stream",
+      "cudnn_conv_use_max_workspace",
+      "cudnn_conv1d_pad_to_nc1d"
+  };
+  std::vector<const char*> values{
+      "0",
+      "2147483648",
+      "kSameAsRequested",
+      "DEFAULT",
+      "1",
+      "1",
+      "1"
+  };
+  VLOG(1) << "prepare cuda options ...";
+  const auto& api = Ort::GetApi();
+  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
+  OrtStatus* error = api.CreateCUDAProviderOptions(&cuda_options);
+  if (error) {
+    api.ReleaseStatus(error);
+    throw std::runtime_error("CreateCUDAProviderOptions error");
+  }
+  error = api.UpdateCUDAProviderOptions(cuda_options, keys.data(),
+                                        values.data(), keys.size());
+  if (error) {
+    api.ReleaseStatus(error);
+    throw std::runtime_error("UpdateCUDAProviderOptions error");
+  }
+  error = api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options_,
+                                                            cuda_options);
+  if (error) {
+    api.ReleaseStatus(error);
+    throw std::runtime_error(
+        "SessionOptionsAppendExecutionProvider_CUDA_V2 error");
+  }
+  api.ReleaseCUDAProviderOptions(cuda_options);
+  VLOG(1) << "done cuda options ...";
+
+  try {
+    encoder_session_ = std::make_shared<Ort::Session>(
+        env_, encoder_onnx_path.c_str(), session_options_);
+    rescore_session_ = std::make_shared<Ort::Session>(
+        env_, rescore_onnx_path.c_str(), session_options_);
+  } catch (std::exception const& e) {
+    LOG(ERROR) << "error when loading onnx model: " << e.what();
+    exit(1);
+  }
+  VLOG(1) << "read onnx model done";
+
+  // 2. Read config
+  std::string config_path = JoinPath(model_dir, "config.yaml");
+  VLOG(1) << "Read " << config_path;
+  Yaml::Node root;
+  Yaml::Parse(root, config_path.c_str());
+  sos_ = root["sos"].As<int>();
+  eos_ = root["eos"].As<int>();
+  is_bidirectional_decoder_ = root["is_bidirectional_decoder"].As<bool>();
+
+  LOG(INFO) << "Onnx Model Info:";
+  LOG(INFO) << "\tsos " << sos_;
+  LOG(INFO) << "\teos " << eos_;
+  LOG(INFO) << "\tis bidirectional decoder " << is_bidirectional_decoder_;
+
+  // 3. Read model nodes
+  LOG(INFO) << "Onnx Encoder:";
+  GetInputOutputInfo(encoder_session_, &encoder_in_names_, &encoder_out_names_);
+  LOG(INFO) << "Onnx Rescore:";
+  GetInputOutputInfo(rescore_session_, &rescore_in_names_, &rescore_out_names_);
+}
+
+BatchOnnxAsrModel::BatchOnnxAsrModel(const BatchOnnxAsrModel& other) {
+  // metadatas
+  sos_ = other.sos_;
+  eos_ = other.eos_;
+  is_bidirectional_decoder_ = other.is_bidirectional_decoder_;
+  is_fp16_ = other.is_fp16_;
+
+  // sessions
+  encoder_session_ = other.encoder_session_;
+  rescore_session_ = other.rescore_session_;
+
+  // node names
+  encoder_in_names_ = other.encoder_in_names_;
+  encoder_out_names_ = other.encoder_out_names_;
+  rescore_in_names_ = other.rescore_in_names_;
+  rescore_out_names_ = other.rescore_out_names_;
+}
+
+std::shared_ptr<BatchAsrModel> BatchOnnxAsrModel::Copy() const {
+  auto asr_model = std::make_shared<BatchOnnxAsrModel>(*this);
+  // Reset the inner states for new decoding
+  return asr_model;
+}
+
+void BatchOnnxAsrModel::ForwardEncoderFunc(
+    const batch_feature_t& batch_feats,
+    const std::vector<int>& batch_feats_lens,
+    batch_ctc_log_prob_t& out_prob) {
+  Ort::MemoryInfo memory_info =
+      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+  // 1. 
Prepare onnx required data + int batch_size = batch_feats.size(); + int num_frames = batch_feats[0].size(); + int feature_dim = batch_feats[0][0].size(); + Ort::Value feats_ort{nullptr}; + + // speech + const int64_t feats_shape[3] = {batch_size, num_frames, feature_dim}; + if (is_fp16_) { + std::vector feats(batch_size * num_frames * feature_dim); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < num_frames; ++j) { + for (size_t k = 0; k < feature_dim; ++k) { + int p = i * num_frames * feature_dim + j * feature_dim + k; + feats[p] = Ort::Float16_t(Eigen::half(batch_feats[i][j][k]).x); + } + } + } + feats_ort = std::move(Ort::Value::CreateTensor( + memory_info, feats.data(), feats.size(), feats_shape, 3)); + } else { + std::vector feats; + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < num_frames; ++j) { + feats.insert(feats.end(), batch_feats[i][j].begin(), batch_feats[i][j].end()); + } + } + feats_ort = std::move(Ort::Value::CreateTensor( + memory_info, feats.data(), feats.size(), feats_shape, 3)); + } + + // speech_lens + const int64_t feats_lens_shape[1] = {batch_size}; + Ort::Value feats_lens_ort = Ort::Value::CreateTensor( + memory_info, const_cast(batch_feats_lens.data()), + batch_feats_lens.size(), feats_lens_shape, 1); + + // 2. Encoder forward + std::vector inputs; + for (auto name : encoder_in_names_) { + if (!strcmp(name, "speech")) { + inputs.push_back(std::move(feats_ort)); + } else if (!strcmp(name, "speech_lengths")) { + inputs.push_back(std::move(feats_lens_ort)); + } + } + + std::vector ort_outputs = encoder_session_->Run( + Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(), + inputs.size(), encoder_out_names_.data(), encoder_out_names_.size()); + + float* ctc_log_probs = nullptr; + auto type_info = ort_outputs[2].GetTensorTypeAndShapeInfo(); + auto out_shape = type_info.GetShape(); + int num_outputs = out_shape[1]; + int output_dim = out_shape[2]; + if (is_fp16_) { + uint16_t* probs = ort_outputs[2].GetTensorMutableData(); + int length = out_shape[0] * out_shape[1] * out_shape[2]; + ctc_log_probs = new float[length]; + for (size_t i = 0; i < length; ++i) { + ctc_log_probs[i] = Eigen::half_impl::half_to_float(Eigen::half_impl::raw_uint16_to_half(probs[i])); + } + } else { + ctc_log_probs = ort_outputs[2].GetTensorMutableData(); + } + + out_prob.resize(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + out_prob[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; j++) { + out_prob[i][j].resize(output_dim); + float* p = ctc_log_probs + (i * num_outputs + j) * output_dim; + memcpy(out_prob[i][j].data(), p, sizeof(float) * output_dim); + } + } + if (is_fp16_) { + delete [] ctc_log_probs; + } + // 3. cache encoder outs + encoder_outs_ = std::move(ort_outputs[0]); +} + +float BatchOnnxAsrModel::ComputeAttentionScore(const float* prob, + const std::vector& hyp, int eos, + int decode_out_len) { + float score = 0.0f; + for (size_t j = 0; j < hyp.size(); ++j) { + score += *(prob + j * decode_out_len + hyp[j]); + } + score += *(prob + hyp.size() * decode_out_len + eos); + return score; +} + +void BatchOnnxAsrModel::AttentionRescoring( + const std::vector>>& batch_hyps, + float reverse_weight, + std::vector>* attention_scores) { + Ort::MemoryInfo memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + // 1. 
prepare input for onnx
+  int batch_size = batch_hyps.size();
+  int beam_size = batch_hyps[0].size();
+
+  // 1.1 generate hyps_lens_sos data for ort
+  std::vector<int64_t> hyps_lens_sos(batch_size * beam_size, 0);  // (batch_size, beam_size)
+  int max_hyps_len = 0;
+  for (size_t i = 0; i < batch_size; ++i) {
+    for (size_t j = 0; j < beam_size; ++j) {
+      int length = batch_hyps[i][j].size() + 1;
+      max_hyps_len = std::max(length, max_hyps_len);
+      hyps_lens_sos[i * beam_size + j] = length;
+    }
+  }
+
+  // 1.2 generate hyps_pad_sos
+  std::vector<int64_t> hyps_pad_sos(batch_size * beam_size * max_hyps_len, 0);
+  for (size_t i = 0; i < batch_size; ++i) {
+    for (size_t j = 0; j < beam_size; ++j) {
+      const std::vector<int>& hyps = batch_hyps[i][j];
+      hyps_pad_sos[i * beam_size * max_hyps_len + j * max_hyps_len] = sos_;
+      for (size_t k = 0; k < hyps.size(); ++k) {
+        hyps_pad_sos[i * beam_size * max_hyps_len + j * max_hyps_len + k + 1] = hyps[k];
+      }
+    }
+  }
+
+  // 2. forward attention decoder
+  const int64_t hyps_lens_shape[] = {batch_size, beam_size};
+  const int64_t hyps_pad_shape[] = {batch_size, beam_size, max_hyps_len};
+
+  Ort::Value hyps_lens_tensor = Ort::Value::CreateTensor<int64_t>(
+      memory_info, hyps_lens_sos.data(), hyps_lens_sos.size(),
+      hyps_lens_shape, 2);
+  Ort::Value hyps_pad_tensor = Ort::Value::CreateTensor<int64_t>(
+      memory_info, hyps_pad_sos.data(), hyps_pad_sos.size(),
+      hyps_pad_shape, 3);
+
+  std::vector<Ort::Value> rescore_inputs;
+  rescore_inputs.emplace_back(std::move(encoder_outs_));
+  rescore_inputs.emplace_back(std::move(hyps_pad_tensor));
+  rescore_inputs.emplace_back(std::move(hyps_lens_tensor));
+
+  std::vector<Ort::Value> rescore_outputs = rescore_session_->Run(
+      Ort::RunOptions{nullptr}, rescore_in_names_.data(), rescore_inputs.data(),
+      rescore_inputs.size(), rescore_out_names_.data(),
+      rescore_out_names_.size());
+
+  auto type_info = rescore_outputs[0].GetTensorTypeAndShapeInfo();
+  std::vector<int64_t> decoder_out_shape = type_info.GetShape();  // (B, beam, T2)
+  float* decoder_outs_data = nullptr;
+  float* r_decoder_outs_data = nullptr;
+  if (is_fp16_) {
+    int length = decoder_out_shape[0] * decoder_out_shape[1] * decoder_out_shape[2];
+    decoder_outs_data = new float[length]();
+    auto outs = rescore_outputs[0].GetTensorMutableData<uint16_t>();
+    for (int i = 0; i < length; ++i) {
+      decoder_outs_data[i] = Eigen::half_impl::half_to_float(
+          Eigen::half_impl::raw_uint16_to_half(outs[i]));
+    }
+    if (is_bidirectional_decoder_ && reverse_weight > 0) {
+      r_decoder_outs_data = new float[length]();
+      auto r_outs = rescore_outputs[1].GetTensorMutableData<uint16_t>();
+      for (int i = 0; i < length; ++i) {
+        r_decoder_outs_data[i] = Eigen::half_impl::half_to_float(
+            Eigen::half_impl::raw_uint16_to_half(r_outs[i]));
+      }
+    }
+  } else {
+    decoder_outs_data = rescore_outputs[0].GetTensorMutableData<float>();
+    if (is_bidirectional_decoder_ && reverse_weight > 0) {
+      r_decoder_outs_data = rescore_outputs[1].GetTensorMutableData<float>();
+    }
+  }
+
+  int decode_out_len = decoder_out_shape[2];
+  attention_scores->clear();
+  for (size_t i = 0; i < batch_size; ++i) {
+    std::vector<float> Y(beam_size);
+    for (size_t j = 0; j < beam_size; ++j) {
+      const std::vector<int>& hyp = batch_hyps[i][j];
+      float score = 0.0f;
+      float* p = decoder_outs_data +
+          (i * beam_size + j) * max_hyps_len * decode_out_len;
+      score = ComputeAttentionScore(p, hyp, eos_, decode_out_len);
+      float r_score = 0.0f;
+      if (is_bidirectional_decoder_ && reverse_weight > 0) {
+        std::vector<int> r_hyp(hyp.size());
+        std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
+        p = r_decoder_outs_data +
+            (i * beam_size + j) * max_hyps_len * decode_out_len;
+        r_score = 
ComputeAttentionScore(p, r_hyp, eos_, decode_out_len); + } + Y[j] = score * (1 - reverse_weight) + r_score * reverse_weight; + } + attention_scores->push_back(std::move(Y)); + } + if (is_fp16_) { + delete [] decoder_outs_data; + if (is_bidirectional_decoder_ && reverse_weight > 0) { + delete [] r_decoder_outs_data; + } + } +} + +} // namespace wenet diff --git a/runtime/core/decoder/batch_onnx_asr_model.h b/runtime/core/decoder/batch_onnx_asr_model.h new file mode 100644 index 000000000..7844b8eea --- /dev/null +++ b/runtime/core/decoder/batch_onnx_asr_model.h @@ -0,0 +1,83 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// 2022 ZeXuan Li (lizexuan@huya.com) +// Xingchen Song(sxc19@mails.tsinghua.edu.cn) +// hamddct@gmail.com (Mddct) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +#ifndef DECODER_BATCH_ONNX_ASR_MODEL_H_ +#define DECODER_BATCH_ONNX_ASR_MODEL_H_ + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +#include "decoder/batch_asr_model.h" +#include "utils/log.h" +#include "utils/utils.h" + +namespace wenet { + +class BatchOnnxAsrModel : public BatchAsrModel { + public: + // Note: Do not call the InitEngineThreads function more than once. + static void InitEngineThreads(int num_threads = 1); + + public: + BatchOnnxAsrModel() = default; + BatchOnnxAsrModel(const BatchOnnxAsrModel& other); + void Read(const std::string& model_dir, bool is_fp16=false); + void AttentionRescoring(const std::vector>>& batch_hyps, + float reverse_weight, + std::vector>* attention_scores) override; + std::shared_ptr Copy() const override; + + void GetInputOutputInfo(const std::shared_ptr& session, + std::vector* in_names, + std::vector* out_names); + + protected: + void ForwardEncoderFunc( + const batch_feature_t& batch_feats, + const std::vector& batch_feats_lens, + batch_ctc_log_prob_t& batch_ctc_log_prob) override; + + float ComputeAttentionScore(const float* prob, const std::vector& hyp, + int eos, int decode_out_len); + + private: + int encoder_output_size_ = 0; + bool is_fp16_ = false; + + // sessions + // NOTE(Mddct): The Env holds the logging state used by all other objects. + // One Env must be created before using any other Onnxruntime functionality. + static Ort::Env env_; // shared environment across threads. 
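+  // session_options_ below is shared process-wide as well, so
+  // InitEngineThreads() only takes effect if called before the first Read();
+  // Copy() then shares the created sessions across decoder instances
+  // instead of reloading the model.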
+ static Ort::SessionOptions session_options_; + std::shared_ptr encoder_session_ = nullptr; + std::shared_ptr rescore_session_ = nullptr; + + // node names + std::vector encoder_in_names_, encoder_out_names_; + std::vector rescore_in_names_, rescore_out_names_; + + // cache encoder outs: [encoder_outs, encoder_outs_lens] + Ort::Value encoder_outs_{nullptr}; +}; + +} // namespace wenet + +#endif // DECODER_BATCH_ONNX_ASR_MODEL_H_ diff --git a/runtime/core/decoder/params.h b/runtime/core/decoder/params.h index 37ad6be2f..9acdac9f9 100644 --- a/runtime/core/decoder/params.h +++ b/runtime/core/decoder/params.h @@ -26,6 +26,7 @@ #include "decoder/batch_asr_decoder.h" #ifdef USE_ONNX #include "decoder/onnx_asr_model.h" +#include "decoder/batch_onnx_asr_model.h" #endif #ifdef USE_TORCH #include "decoder/torch_asr_model.h" @@ -91,6 +92,7 @@ DEFINE_int32(language_type, 0, "0x01 = kIndoEuropean"); DEFINE_bool(lowercase, true, "lowercase final result if needed"); DEFINE_bool(run_batch, false, "run websocket server for batch decoding"); +DEFINE_bool(is_fp16, false, "the model is of fp16"); namespace wenet { std::shared_ptr InitFeaturePipelineConfigFromFlags() { @@ -125,11 +127,19 @@ std::shared_ptr InitDecodeResourceFromFlags() { if (!FLAGS_onnx_dir.empty()) { #ifdef USE_ONNX - LOG(INFO) << "Reading onnx model "; - OnnxAsrModel::InitEngineThreads(FLAGS_num_threads); - auto model = std::make_shared(); - model->Read(FLAGS_onnx_dir); - resource->model = model; + if (FLAGS_run_batch) { + LOG(INFO) << "BatchOnnxAsrModel Reading ONNX model dir: " << FLAGS_onnx_dir; + BatchOnnxAsrModel::InitEngineThreads(FLAGS_num_threads); + auto model = std::make_shared(); + model->Read(FLAGS_onnx_dir, FLAGS_is_fp16); + resource->batch_model = model; + } else { + LOG(INFO) << "Reading onnx model "; + OnnxAsrModel::InitEngineThreads(FLAGS_num_threads); + auto model = std::make_shared(); + model->Read(FLAGS_onnx_dir); + resource->model = model; + } #else LOG(FATAL) << "Please rebuild with cmake options '-DONNX=ON'."; #endif @@ -139,7 +149,7 @@ std::shared_ptr InitDecodeResourceFromFlags() { LOG(INFO) << "BatchTorchAsrModel Reading torch model " << FLAGS_model_path; BatchTorchAsrModel::InitEngineThreads(FLAGS_num_threads); auto model = std::make_shared(); - model->Read(FLAGS_model_path); + model->Read(FLAGS_model_path, FLAGS_is_fp16); resource->batch_model = model; } else { LOG(INFO) << "Reading torch model " << FLAGS_model_path; From 816a13aa7b7bf855ad1d153646cdf9b044872524 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 25 Aug 2022 15:06:44 +0800 Subject: [PATCH 25/62] add Yaml reader --- runtime/core/utils/CMakeLists.txt | 3 +- runtime/core/utils/Yaml.cpp | 2773 +++++++++++++++++++++++++++++ runtime/core/utils/Yaml.hpp | 656 +++++++ 3 files changed, 3431 insertions(+), 1 deletion(-) create mode 100644 runtime/core/utils/Yaml.cpp create mode 100644 runtime/core/utils/Yaml.hpp diff --git a/runtime/core/utils/CMakeLists.txt b/runtime/core/utils/CMakeLists.txt index 686362688..1394fd50c 100644 --- a/runtime/core/utils/CMakeLists.txt +++ b/runtime/core/utils/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(utils STATIC string.cc utils.cc + Yaml.cpp ) if(NOT ANDROID) @@ -9,4 +10,4 @@ if(NOT ANDROID) else() target_link_libraries(utils PUBLIC fst dl) endif() -endif() \ No newline at end of file +endif() diff --git a/runtime/core/utils/Yaml.cpp b/runtime/core/utils/Yaml.cpp new file mode 100644 index 000000000..70adec6f3 --- /dev/null +++ b/runtime/core/utils/Yaml.cpp @@ -0,0 +1,2773 @@ +/* +* MIT License +* +* Copyright(c) 2018 
+*
+* Permission is hereby granted, free of charge, to any person obtaining a copy
+* of this software and associated documentation files(the "Software"), to deal
+* in the Software without restriction, including without limitation the rights
+* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+* copies of the Software, and to permit persons to whom the Software is
+* furnished to do so, subject to the following conditions :
+*
+* The above copyright notice and this permission notice shall be included in all
+* copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
+* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+* SOFTWARE.
+*
+*/
+
+#include "Yaml.hpp"
+#include <memory>
+#include <fstream>
+#include <sstream>
+#include <iostream>
+#include <vector>
+#include <list>
+#include <map>
+
+
+// Implementation access definitions.
+#define NODE_IMP static_cast<NodeImp*>(m_pImp)
+#define NODE_IMP_EXT(node) static_cast<NodeImp*>(node.m_pImp)
+#define TYPE_IMP static_cast<NodeImp*>(m_pImp)->m_pImp
+
+
+#define IT_IMP static_cast<IteratorImp*>(m_pImp)
+
+
+namespace Yaml
+{
+    class ReaderLine;
+
+    // Exception message definitions.
+    static const std::string g_ErrorInvalidCharacter = "Invalid character found.";
+    static const std::string g_ErrorKeyMissing = "Missing key.";
+    static const std::string g_ErrorKeyIncorrect = "Incorrect key.";
+    static const std::string g_ErrorValueIncorrect = "Incorrect value.";
+    static const std::string g_ErrorTabInOffset = "Tab found in offset.";
+    static const std::string g_ErrorBlockSequenceNotAllowed = "Sequence entries are not allowed in this context.";
+    static const std::string g_ErrorUnexpectedDocumentEnd = "Unexpected document end.";
+    static const std::string g_ErrorDiffEntryNotAllowed = "Different entry is not allowed in this context.";
+    static const std::string g_ErrorIncorrectOffset = "Incorrect offset.";
+    static const std::string g_ErrorSequenceError = "Error in sequence node.";
+    static const std::string g_ErrorCannotOpenFile = "Cannot open file.";
+    static const std::string g_ErrorIndentation = "Space indentation is less than 2.";
+    static const std::string g_ErrorInvalidBlockScalar = "Invalid block scalar.";
+    static const std::string g_ErrorInvalidQuote = "Invalid quote.";
+    static const std::string g_EmptyString = "";
+    static Yaml::Node g_NoneNode;
+
+    // Global function definitions. Implemented at end of this source file.
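+    // Usage sketch (illustrative only, not part of the vendored library):
+    // reading a WeNet train.yaml with this parser. The file name and the
+    // keys below are assumptions made for the example.
+    //
+    //   Yaml::Node root;
+    //   Yaml::Parse(root, "train.yaml");  // throws Yaml::Exception on error
+    //   // As<T>() converts scalars; the one-argument form falls back to the
+    //   // default when the key is missing or conversion fails.
+    //   int output_dim = root["output_dim"].As<int>(0);
+    //   int heads = root["encoder_conf"]["attention_heads"].As<int>(4);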
+ static std::string ExceptionMessage(const std::string & message, ReaderLine & line); + static std::string ExceptionMessage(const std::string & message, ReaderLine & line, const size_t errorPos); + static std::string ExceptionMessage(const std::string & message, const size_t errorLine, const size_t errorPos); + static std::string ExceptionMessage(const std::string & message, const size_t errorLine, const std::string & data); + + static bool FindQuote(const std::string & input, size_t & start, size_t & end, size_t searchPos = 0); + static size_t FindNotCited(const std::string & input, char token, size_t & preQuoteCount); + static size_t FindNotCited(const std::string & input, char token); + static bool ValidateQuote(const std::string & input); + static void CopyNode(const Node & from, Node & to); + static bool ShouldBeCited(const std::string & key); + static void AddEscapeTokens(std::string & input, const std::string & tokens); + static void RemoveAllEscapeTokens(std::string & input); + + // Exception implementations + Exception::Exception(const std::string & message, const eType type) : + std::runtime_error(message), + m_Type(type) + { + } + + Exception::eType Exception::Type() const + { + return m_Type; + } + + const char * Exception::Message() const + { + return what(); + } + + InternalException::InternalException(const std::string & message) : + Exception(message, InternalError) + { + + } + + ParsingException::ParsingException(const std::string & message) : + Exception(message, ParsingError) + { + + } + + OperationException::OperationException(const std::string & message) : + Exception(message, OperationError) + { + + } + + + class TypeImp + { + + public: + + virtual ~TypeImp() + { + } + + virtual const std::string & GetData() const = 0; + virtual bool SetData(const std::string & data) = 0; + virtual size_t GetSize() const = 0; + virtual Node * GetNode(const size_t index) = 0; + virtual Node * GetNode(const std::string & key) = 0; + virtual Node * Insert(const size_t index) = 0; + virtual Node * PushFront() = 0; + virtual Node * PushBack() = 0; + virtual void Erase(const size_t index) = 0; + virtual void Erase(const std::string & key) = 0; + + }; + + class SequenceImp : public TypeImp + { + + public: + + ~SequenceImp() + { + for(auto it = m_Sequence.begin(); it != m_Sequence.end(); it++) + { + delete it->second; + } + } + + virtual const std::string & GetData() const + { + return g_EmptyString; + } + + virtual bool SetData(const std::string & data) + { + return false; + } + + virtual size_t GetSize() const + { + return m_Sequence.size(); + } + + virtual Node * GetNode(const size_t index) + { + auto it = m_Sequence.find(index); + if(it != m_Sequence.end()) + { + return it->second; + } + return nullptr; + } + + virtual Node * GetNode(const std::string & key) + { + return nullptr; + } + + virtual Node * Insert(const size_t index) + { + if(m_Sequence.size() == 0) + { + Node * pNode = new Node; + m_Sequence.insert({0, pNode}); + return pNode; + } + + if(index >= m_Sequence.size()) + { + auto it = m_Sequence.end(); + --it; + Node * pNode = new Node; + m_Sequence.insert({it->first, pNode}); + return pNode; + } + + auto it = m_Sequence.cbegin(); + while(it != m_Sequence.cend()) + { + m_Sequence[it->first+1] = it->second; + + if(it->first == index) + { + break; + } + } + + Node * pNode = new Node; + m_Sequence.insert({index, pNode}); + return pNode; + } + + virtual Node * PushFront() + { + for(auto it = m_Sequence.cbegin(); it != m_Sequence.cend(); it++) + { + m_Sequence[it->first+1] = 
it->second; + } + + Node * pNode = new Node; + m_Sequence.insert({0, pNode}); + return pNode; + } + + virtual Node * PushBack() + { + size_t index = 0; + if(m_Sequence.size()) + { + auto it = m_Sequence.end(); + --it; + index = it->first + 1; + } + + Node * pNode = new Node; + m_Sequence.insert({index, pNode}); + return pNode; + } + + virtual void Erase(const size_t index) + { + auto it = m_Sequence.find(index); + if(it == m_Sequence.end()) + { + return; + } + delete it->second; + m_Sequence.erase(index); + } + + virtual void Erase(const std::string & key) + { + } + + std::map m_Sequence; + + }; + + class MapImp : public TypeImp + { + + public: + + ~MapImp() + { + for(auto it = m_Map.begin(); it != m_Map.end(); it++) + { + delete it->second; + } + } + + virtual const std::string & GetData() const + { + return g_EmptyString; + } + + virtual bool SetData(const std::string & data) + { + return false; + } + + virtual size_t GetSize() const + { + return m_Map.size(); + } + + virtual Node * GetNode(const size_t index) + { + return nullptr; + } + + virtual Node * GetNode(const std::string & key) + { + auto it = m_Map.find(key); + if(it == m_Map.end()) + { + Node * pNode = new Node; + m_Map.insert({key, pNode}); + return pNode; + } + return it->second; + } + + virtual Node * Insert(const size_t index) + { + return nullptr; + } + + virtual Node * PushFront() + { + return nullptr; + } + + virtual Node * PushBack() + { + return nullptr; + } + + virtual void Erase(const size_t index) + { + } + + virtual void Erase(const std::string & key) + { + auto it = m_Map.find(key); + if(it == m_Map.end()) + { + return; + } + delete it->second; + m_Map.erase(key); + } + + std::map m_Map; + + }; + + class ScalarImp : public TypeImp + { + + public: + + ~ScalarImp() + { + } + + virtual const std::string & GetData() const + { + return m_Value; + } + + virtual bool SetData(const std::string & data) + { + m_Value = data; + return true; + } + + virtual size_t GetSize() const + { + return 0; + } + + virtual Node * GetNode(const size_t index) + { + return nullptr; + } + + virtual Node * GetNode(const std::string & key) + { + return nullptr; + } + + virtual Node * Insert(const size_t index) + { + return nullptr; + } + + virtual Node * PushFront() + { + return nullptr; + } + + virtual Node * PushBack() + { + return nullptr; + } + + virtual void Erase(const size_t index) + { + } + + virtual void Erase(const std::string & key) + { + } + + std::string m_Value; + + }; + + + // Node implementations. + class NodeImp + { + + public: + + NodeImp() : + m_Type(Node::None), + m_pImp(nullptr) + { + } + + ~NodeImp() + { + Clear(); + } + + void Clear() + { + if(m_pImp != nullptr) + { + delete m_pImp; + m_pImp = nullptr; + } + m_Type = Node::None; + } + + void InitSequence() + { + if(m_Type != Node::SequenceType || m_pImp == nullptr) + { + if(m_pImp) + { + delete m_pImp; + } + m_pImp = new SequenceImp; + m_Type = Node::SequenceType; + } + } + + void InitMap() + { + if(m_Type != Node::MapType || m_pImp == nullptr) + { + if(m_pImp) + { + delete m_pImp; + } + m_pImp = new MapImp; + m_Type = Node::MapType; + } + } + + void InitScalar() + { + if(m_Type != Node::ScalarType || m_pImp == nullptr) + { + if(m_pImp) + { + delete m_pImp; + } + m_pImp = new ScalarImp; + m_Type = Node::ScalarType; + } + + } + + Node::eType m_Type; ///< Type of node. + TypeImp * m_pImp; ///< Imp of type. 
+ + }; + + // Iterator implementation class + class IteratorImp + { + + public: + + virtual ~IteratorImp() + { + } + + virtual Node::eType GetType() const = 0; + virtual void InitBegin(SequenceImp * pSequenceImp) = 0; + virtual void InitEnd(SequenceImp * pSequenceImp) = 0; + virtual void InitBegin(MapImp * pMapImp) = 0; + virtual void InitEnd(MapImp * pMapImp) = 0; + + }; + + class SequenceIteratorImp : public IteratorImp + { + + public: + + virtual Node::eType GetType() const + { + return Node::SequenceType; + } + + virtual void InitBegin(SequenceImp * pSequenceImp) + { + m_Iterator = pSequenceImp->m_Sequence.begin(); + } + + virtual void InitEnd(SequenceImp * pSequenceImp) + { + m_Iterator = pSequenceImp->m_Sequence.end(); + } + + virtual void InitBegin(MapImp * pMapImp) + { + } + + virtual void InitEnd(MapImp * pMapImp) + { + } + + void Copy(const SequenceIteratorImp & it) + { + m_Iterator = it.m_Iterator; + } + + std::map::iterator m_Iterator; + + }; + + class MapIteratorImp : public IteratorImp + { + + public: + + virtual Node::eType GetType() const + { + return Node::MapType; + } + + virtual void InitBegin(SequenceImp * pSequenceImp) + { + } + + virtual void InitEnd(SequenceImp * pSequenceImp) + { + } + + virtual void InitBegin(MapImp * pMapImp) + { + m_Iterator = pMapImp->m_Map.begin(); + } + + virtual void InitEnd(MapImp * pMapImp) + { + m_Iterator = pMapImp->m_Map.end(); + } + + void Copy(const MapIteratorImp & it) + { + m_Iterator = it.m_Iterator; + } + + std::map::iterator m_Iterator; + + }; + + class SequenceConstIteratorImp : public IteratorImp + { + + public: + + virtual Node::eType GetType() const + { + return Node::SequenceType; + } + + virtual void InitBegin(SequenceImp * pSequenceImp) + { + m_Iterator = pSequenceImp->m_Sequence.begin(); + } + + virtual void InitEnd(SequenceImp * pSequenceImp) + { + m_Iterator = pSequenceImp->m_Sequence.end(); + } + + virtual void InitBegin(MapImp * pMapImp) + { + } + + virtual void InitEnd(MapImp * pMapImp) + { + } + + void Copy(const SequenceConstIteratorImp & it) + { + m_Iterator = it.m_Iterator; + } + + std::map::const_iterator m_Iterator; + + }; + + class MapConstIteratorImp : public IteratorImp + { + + public: + + virtual Node::eType GetType() const + { + return Node::MapType; + } + + virtual void InitBegin(SequenceImp * pSequenceImp) + { + } + + virtual void InitEnd(SequenceImp * pSequenceImp) + { + } + + virtual void InitBegin(MapImp * pMapImp) + { + m_Iterator = pMapImp->m_Map.begin(); + } + + virtual void InitEnd(MapImp * pMapImp) + { + m_Iterator = pMapImp->m_Map.end(); + } + + void Copy(const MapConstIteratorImp & it) + { + m_Iterator = it.m_Iterator; + } + + std::map::const_iterator m_Iterator; + + }; + + + // Iterator class + Iterator::Iterator() : + m_Type(None), + m_pImp(nullptr) + { + } + + Iterator::~Iterator() + { + if(m_pImp) + { + switch(m_Type) + { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + + } + } + + Iterator::Iterator(const Iterator & it) : + m_Type(None), + m_pImp(nullptr) + { + *this = it; + } + + Iterator & Iterator::operator = (const Iterator & it) + { + if(m_pImp) + { + switch(m_Type) + { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + m_pImp = nullptr; + m_Type = None; + } + + IteratorImp * pNewImp = nullptr; + + switch(it.m_Type) + { + case SequenceType: + m_Type = SequenceType; + pNewImp = new SequenceIteratorImp; + 
static_cast(pNewImp)->m_Iterator = static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + m_Type = MapType; + pNewImp = new MapIteratorImp; + static_cast(pNewImp)->m_Iterator = static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + m_pImp = pNewImp; + return *this; + } + + std::pair Iterator::operator *() + { + switch(m_Type) + { + case SequenceType: + return { g_EmptyString, *(static_cast(m_pImp)->m_Iterator->second)}; + break; + case MapType: + return {static_cast(m_pImp)->m_Iterator->first, + *(static_cast(m_pImp)->m_Iterator->second)}; + break; + default: + break; + } + + g_NoneNode.Clear(); + return { g_EmptyString, g_NoneNode}; + } + + Iterator & Iterator::operator ++ (int dummy) + { + switch(m_Type) + { + case SequenceType: + static_cast(m_pImp)->m_Iterator++; + break; + case MapType: + static_cast(m_pImp)->m_Iterator++; + break; + default: + break; + } + return *this; + } + + Iterator & Iterator::operator -- (int dummy) + { + switch(m_Type) + { + case SequenceType: + static_cast(m_pImp)->m_Iterator--; + break; + case MapType: + static_cast(m_pImp)->m_Iterator--; + break; + default: + break; + } + return *this; + } + + bool Iterator::operator == (const Iterator & it) + { + if(m_Type != it.m_Type) + { + return false; + } + + switch(m_Type) + { + case SequenceType: + return static_cast(m_pImp)->m_Iterator == static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + return static_cast(m_pImp)->m_Iterator == static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + return false; + } + + bool Iterator::operator != (const Iterator & it) + { + return !(*this == it); + } + + + // Const Iterator class + ConstIterator::ConstIterator() : + m_Type(None), + m_pImp(nullptr) + { + } + + ConstIterator::~ConstIterator() + { + if(m_pImp) + { + switch(m_Type) + { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + + } + } + + ConstIterator::ConstIterator(const ConstIterator & it) : + m_Type(None), + m_pImp(nullptr) + { + *this = it; + } + + ConstIterator & ConstIterator::operator = (const ConstIterator & it) + { + if(m_pImp) + { + switch(m_Type) + { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + m_pImp = nullptr; + m_Type = None; + } + + IteratorImp * pNewImp = nullptr; + + switch(it.m_Type) + { + case SequenceType: + m_Type = SequenceType; + pNewImp = new SequenceConstIteratorImp; + static_cast(pNewImp)->m_Iterator = static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + m_Type = MapType; + pNewImp = new MapConstIteratorImp; + static_cast(pNewImp)->m_Iterator = static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + m_pImp = pNewImp; + return *this; + } + + std::pair ConstIterator::operator *() + { + switch(m_Type) + { + case SequenceType: + return { g_EmptyString, *(static_cast(m_pImp)->m_Iterator->second)}; + break; + case MapType: + return {static_cast(m_pImp)->m_Iterator->first, + *(static_cast(m_pImp)->m_Iterator->second)}; + break; + default: + break; + } + + g_NoneNode.Clear(); + return { g_EmptyString, g_NoneNode}; + } + + ConstIterator & ConstIterator::operator ++ (int dummy) + { + switch(m_Type) + { + case SequenceType: + static_cast(m_pImp)->m_Iterator++; + break; + case MapType: + static_cast(m_pImp)->m_Iterator++; + break; + default: + break; + } + return *this; + } + + ConstIterator & ConstIterator::operator -- (int dummy) + { + 
switch(m_Type) + { + case SequenceType: + static_cast(m_pImp)->m_Iterator--; + break; + case MapType: + static_cast(m_pImp)->m_Iterator--; + break; + default: + break; + } + return *this; + } + + bool ConstIterator::operator == (const ConstIterator & it) + { + if(m_Type != it.m_Type) + { + return false; + } + + switch(m_Type) + { + case SequenceType: + return static_cast(m_pImp)->m_Iterator == static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + return static_cast(m_pImp)->m_Iterator == static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + return false; + } + + bool ConstIterator::operator != (const ConstIterator & it) + { + return !(*this == it); + } + + + // Node class + Node::Node() : + m_pImp(new NodeImp) + { + } + + Node::Node(const Node & node) : + Node() + { + *this = node; + } + + Node::Node(const std::string & value) : + Node() + { + *this = value; + } + + Node::Node(const char * value) : + Node() + { + *this = value; + } + + Node::~Node() + { + delete static_cast(m_pImp); + } + + Node::eType Node::Type() const + { + return NODE_IMP->m_Type; + } + + bool Node::IsNone() const + { + return NODE_IMP->m_Type == Node::None; + } + + bool Node::IsSequence() const + { + return NODE_IMP->m_Type == Node::SequenceType; + } + + bool Node::IsMap() const + { + return NODE_IMP->m_Type == Node::MapType; + } + + bool Node::IsScalar() const + { + return NODE_IMP->m_Type == Node::ScalarType; + } + + void Node::Clear() + { + NODE_IMP->Clear(); + } + + size_t Node::Size() const + { + if(TYPE_IMP == nullptr) + { + return 0; + } + + return TYPE_IMP->GetSize(); + } + + Node & Node::Insert(const size_t index) + { + NODE_IMP->InitSequence(); + return *TYPE_IMP->Insert(index); + } + + Node & Node::PushFront() + { + NODE_IMP->InitSequence(); + return *TYPE_IMP->PushFront(); + } + Node & Node::PushBack() + { + NODE_IMP->InitSequence(); + return *TYPE_IMP->PushBack(); + } + + Node & Node::operator[](const size_t index) + { + NODE_IMP->InitSequence(); + Node * pNode = TYPE_IMP->GetNode(index); + if(pNode == nullptr) + { + g_NoneNode.Clear(); + return g_NoneNode; + } + return *pNode; + } + + Node & Node::operator[](const std::string & key) + { + NODE_IMP->InitMap(); + return *TYPE_IMP->GetNode(key); + } + + void Node::Erase(const size_t index) + { + if(TYPE_IMP == nullptr || NODE_IMP->m_Type != Node::SequenceType) + { + return; + } + + return TYPE_IMP->Erase(index); + } + + void Node::Erase(const std::string & key) + { + if(TYPE_IMP == nullptr || NODE_IMP->m_Type != Node::MapType) + { + return; + } + + return TYPE_IMP->Erase(key); + } + + Node & Node::operator = (const Node & node) + { + NODE_IMP->Clear(); + CopyNode(node, *this); + return *this; + } + + Node & Node::operator = (const std::string & value) + { + NODE_IMP->InitScalar(); + TYPE_IMP->SetData(value); + return *this; + } + + Node & Node::operator = (const char * value) + { + NODE_IMP->InitScalar(); + TYPE_IMP->SetData(value ? 
std::string(value) : ""); + return *this; + } + + Iterator Node::Begin() + { + Iterator it; + + if(TYPE_IMP != nullptr) + { + IteratorImp * pItImp = nullptr; + + switch(NODE_IMP->m_Type) + { + case Node::SequenceType: + it.m_Type = Iterator::SequenceType; + pItImp = new SequenceIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = Iterator::MapType; + pItImp = new MapIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + default: + break; + } + + it.m_pImp = pItImp; + } + + return it; + } + + ConstIterator Node::Begin() const + { + ConstIterator it; + + if(TYPE_IMP != nullptr) + { + IteratorImp * pItImp = nullptr; + + switch(NODE_IMP->m_Type) + { + case Node::SequenceType: + it.m_Type = ConstIterator::SequenceType; + pItImp = new SequenceConstIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = ConstIterator::MapType; + pItImp = new MapConstIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + default: + break; + } + + it.m_pImp = pItImp; + } + + return it; + } + + Iterator Node::End() + { + Iterator it; + + if(TYPE_IMP != nullptr) + { + IteratorImp * pItImp = nullptr; + + switch(NODE_IMP->m_Type) + { + case Node::SequenceType: + it.m_Type = Iterator::SequenceType; + pItImp = new SequenceIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = Iterator::MapType; + pItImp = new MapIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + default: + break; + } + + it.m_pImp = pItImp; + } + + return it; + } + + ConstIterator Node::End() const + { + ConstIterator it; + + if(TYPE_IMP != nullptr) + { + IteratorImp * pItImp = nullptr; + + switch(NODE_IMP->m_Type) + { + case Node::SequenceType: + it.m_Type = ConstIterator::SequenceType; + pItImp = new SequenceConstIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = ConstIterator::MapType; + pItImp = new MapConstIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + default: + break; + } + + it.m_pImp = pItImp; + } + + return it; + } + + const std::string & Node::AsString() const + { + if(TYPE_IMP == nullptr) + { + return g_EmptyString; + } + + return TYPE_IMP->GetData(); + } + + + + // Reader implementations + /** + * @breif Line information structure. + * + */ + class ReaderLine + { + + public: + + /** + * @breif Constructor. + * + */ + ReaderLine(const std::string & data = "", + const size_t no = 0, + const size_t offset = 0, + const Node::eType type = Node::None, + const unsigned char flags = 0) : + Data(data), + No(no), + Offset(offset), + Type(type), + Flags(flags), + NextLine(nullptr) + { + } + + enum eFlag + { + LiteralScalarFlag, ///< Literal scalar type, defined as "|". + FoldedScalarFlag, ///< Folded scalar type, defined as "<". + ScalarNewlineFlag ///< Scalar ends with a newline. + }; + + /** + * @breif Set flag. + * + */ + void SetFlag(const eFlag flag) + { + Flags |= FlagMask[static_cast(flag)]; + } + + /** + * @breif Set flags by mask value. + * + */ + void SetFlags(const unsigned char flags) + { + Flags |= flags; + } + + /** + * @breif Unset flag. + * + */ + void UnsetFlag(const eFlag flag) + { + Flags &= ~FlagMask[static_cast(flag)]; + } + + /** + * @breif Unset flags by mask value. + * + */ + void UnsetFlags(const unsigned char flags) + { + Flags &= ~flags; + } + + /** + * @breif Get flag value. 
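+         *        Returns true if the given flag's bit is set in Flags.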
+ * + */ + bool GetFlag(const eFlag flag) const + { + return Flags & FlagMask[static_cast(flag)]; + } + + /** + * @breif Copy and replace scalar flags from another ReaderLine. + * + */ + void CopyScalarFlags(ReaderLine * from) + { + if (from == nullptr) + { + return; + } + + unsigned char newFlags = from->Flags & (FlagMask[0] | FlagMask[1] | FlagMask[2]); + Flags |= newFlags; + } + + static const unsigned char FlagMask[3]; + + std::string Data; ///< Data of line. + size_t No; ///< Line number. + size_t Offset; ///< Offset to first character in data. + Node::eType Type; ///< Type of line. + unsigned char Flags; ///< Flags of line. + ReaderLine * NextLine; ///< Pointer to next line. + + + + }; + + const unsigned char ReaderLine::FlagMask[3] = { 0x01, 0x02, 0x04 }; + + + /** + * @breif Implementation class of Yaml parsing. + * Parsing incoming stream and outputs a root node. + * + */ + class ParseImp + { + + public: + + /** + * @breif Default constructor. + * + */ + ParseImp() + { + } + + /** + * @breif Destructor. + * + */ + ~ParseImp() + { + ClearLines(); + } + + /** + * @breif Run full parsing procedure. + * + */ + void Parse(Node & root, std::iostream & stream) + { + try + { + root.Clear(); + ReadLines(stream); + PostProcessLines(); + //Print(); + ParseRoot(root); + } + catch(Exception e) + { + root.Clear(); + throw; + } + } + + private: + + /** + * @breif Copy constructor. + * + */ + ParseImp(const ParseImp & copy) + { + + } + + /** + * @breif Read all lines. + * Ignoring: + * - Empty lines. + * - Comments. + * - Document start/end. + * + */ + void ReadLines(std::iostream & stream) + { + std::string line = ""; + size_t lineNo = 0; + bool documentStartFound = false; + bool foundFirstNotEmpty = false; + std::streampos streamPos = 0; + + // Read all lines, as long as the stream is ok. + while (!stream.eof() && !stream.fail()) + { + // Read line + streamPos = stream.tellg(); + std::getline(stream, line); + lineNo++; + + // Remove comment + const size_t commentPos = FindNotCited(line, '#'); + if(commentPos != std::string::npos) + { + line.resize(commentPos); + } + + // Start of document. + if (documentStartFound == false && line == "---") + { + // Erase all lines before this line. + ClearLines(); + documentStartFound = true; + continue; + } + + // End of document. + if (line == "...") + { + break; + } + else if(line == "---") + { + stream.seekg(streamPos); + break; + } + + // Remove trailing return. + if (line.size()) + { + if (line[line.size() - 1] == '\r') + { + line.resize(line.size() - 1); + } + } + + // Validate characters. + for (size_t i = 0; i < line.size(); i++) + { + if (line[i] != '\t' && (line[i] < 32 || line[i] > 125)) + { + throw ParsingException(ExceptionMessage(g_ErrorInvalidCharacter, lineNo, i + 1)); + } + } + + // Validate tabs + const size_t firstTabPos = line.find_first_of('\t'); + size_t startOffset = line.find_first_not_of(" \t"); + + // Make sure no tabs are in the very front. + if (startOffset != std::string::npos) + { + if(firstTabPos < startOffset) + { + throw ParsingException(ExceptionMessage(g_ErrorTabInOffset, lineNo, firstTabPos)); + } + + // Remove front spaces. + line = line.substr(startOffset); + } + else + { + startOffset = 0; + line = ""; + } + + // Add line. + if(foundFirstNotEmpty == false) + { + if(line.size()) + { + foundFirstNotEmpty = true; + } + else + { + continue; + } + } + + ReaderLine * pLine = new ReaderLine(line, lineNo, startOffset); + m_Lines.push_back(pLine); + } + } + + /** + * @breif Run post-processing on all lines. 
+ * Basically split lines into multiple lines if needed, to follow the parsing algorithm. + * + */ + void PostProcessLines() + { + for (auto it = m_Lines.begin(); it != m_Lines.end();) + { + // Sequence. + if (PostProcessSequenceLine(it) == true) + { + continue; + } + + // Mapping. + if (PostProcessMappingLine(it) == true) + { + continue; + } + + // Scalar. + PostProcessScalarLine(it); + } + + // Set next line of all lines. + if (m_Lines.size()) + { + if (m_Lines.back()->Type != Node::ScalarType) + { + throw ParsingException(ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *m_Lines.back())); + } + + if (m_Lines.size() > 1) + { + auto prevEnd = m_Lines.end(); + --prevEnd; + + for (auto it = m_Lines.begin(); it != prevEnd; it++) + { + auto nextIt = it; + ++nextIt; + + (*it)->NextLine = *nextIt; + } + } + } + } + + /** + * @breif Run post-processing and check for sequence. + * Split line into two lines if sequence token is not on it's own line. + * + * @return true if line is sequence, else false. + * + */ + bool PostProcessSequenceLine(std::list::iterator & it) + { + ReaderLine * pLine = *it; + + // Sequence split + if (IsSequenceStart(pLine->Data) == false) + { + return false; + } + + pLine->Type = Node::SequenceType; + + ClearTrailingEmptyLines(++it); + + const size_t valueStart = pLine->Data.find_first_not_of(" \t", 1); + if (valueStart == std::string::npos) + { + return true; + } + + // Create new line and insert + std::string newLine = pLine->Data.substr(valueStart); + it = m_Lines.insert(it, new ReaderLine(newLine, pLine->No, pLine->Offset + valueStart)); + pLine->Data = ""; + + return false; + } + + /** + * @breif Run post-processing and check for mapping. + * Split line into two lines if mapping value is not on it's own line. + * + * @return true if line is mapping, else move on to scalar parsing. + * + */ + bool PostProcessMappingLine(std::list::iterator & it) + { + ReaderLine * pLine = *it; + + // Find map key. + size_t preKeyQuotes = 0; + size_t tokenPos = FindNotCited(pLine->Data, ':', preKeyQuotes); + if (tokenPos == std::string::npos) + { + return false; + } + if(preKeyQuotes > 1) + { + throw ParsingException(ExceptionMessage(g_ErrorKeyIncorrect, *pLine)); + } + + pLine->Type = Node::MapType; + + // Get key + std::string key = pLine->Data.substr(0, tokenPos); + const size_t keyEnd = key.find_last_not_of(" \t"); + if (keyEnd == std::string::npos) + { + throw ParsingException(ExceptionMessage(g_ErrorKeyMissing, *pLine)); + } + key.resize(keyEnd + 1); + + // Handle cited key. + if(preKeyQuotes == 1) + { + if(key.front() != '"' || key.back() != '"') + { + throw ParsingException(ExceptionMessage(g_ErrorKeyIncorrect, *pLine)); + } + + key = key.substr(1, key.size() - 2); + } + RemoveAllEscapeTokens(key); + + // Get value + std::string value = ""; + size_t valueStart = std::string::npos; + if (tokenPos + 1 != pLine->Data.size()) + { + valueStart = pLine->Data.find_first_not_of(" \t", tokenPos + 1); + if (valueStart != std::string::npos) + { + value = pLine->Data.substr(valueStart); + } + } + + // Make sure the value is not a sequence start. + if (IsSequenceStart(value) == true) + { + throw ParsingException(ExceptionMessage(g_ErrorBlockSequenceNotAllowed, *pLine, valueStart)); + } + + pLine->Data = key; + + + // Remove all empty lines after map key. + ClearTrailingEmptyLines(++it); + + // Add new empty line? 
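+            // Three cases follow: no inline value but a deeper-indented line
+            // comes next (nested block, nothing to insert); no inline value at
+            // all (insert an empty scalar line at key offset + 2); or an inline
+            // value, which is moved onto its own synthetic line at its column.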
+ size_t newLineOffset = valueStart; + if(newLineOffset == std::string::npos) + { + if(it != m_Lines.end() && (*it)->Offset > pLine->Offset) + { + return true; + } + + newLineOffset = tokenPos + 2; + } + else + { + newLineOffset += pLine->Offset; + } + + // Add new line with value. + unsigned char dummyBlockFlags = 0; + if(IsBlockScalar(value, pLine->No, dummyBlockFlags) == true) + { + newLineOffset = pLine->Offset; + } + ReaderLine * pNewLine = new ReaderLine(value, pLine->No, newLineOffset, Node::ScalarType); + it = m_Lines.insert(it, pNewLine); + + // Return false in order to handle next line(scalar value). + return false; + } + + /** + * @breif Run post-processing and check for scalar. + * Checking for multi-line scalars. + * + * @return true if scalar search should continue, else false. + * + */ + void PostProcessScalarLine(std::list::iterator & it) + { + ReaderLine * pLine = *it; + pLine->Type = Node::ScalarType; + + size_t parentOffset = pLine->Offset; + if(pLine != m_Lines.front()) + { + std::list::iterator lastIt = it; + --lastIt; + parentOffset = (*lastIt)->Offset; + } + + std::list::iterator lastNotEmpty = it++; + + // Find last empty lines + while(it != m_Lines.end()) + { + pLine = *it; + pLine->Type = Node::ScalarType; + if(pLine->Data.size()) + { + if(pLine->Offset <= parentOffset) + { + break; + } + else + { + lastNotEmpty = it; + } + } + ++it; + } + + ClearTrailingEmptyLines(++lastNotEmpty); + } + + /** + * @breif Process root node and start of document. + * + */ + void ParseRoot(Node & root) + { + // Get first line and start type. + auto it = m_Lines.begin(); + if(it == m_Lines.end()) + { + return; + } + Node::eType type = (*it)->Type; + ReaderLine * pLine = *it; + + // Handle next line. + switch(type) + { + case Node::SequenceType: + ParseSequence(root, it); + break; + case Node::MapType: + ParseMap(root, it); + break; + case Node::ScalarType: + ParseScalar(root, it); + break; + default: + break; + } + + if(it != m_Lines.end()) + { + throw InternalException(ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); + } + + } + + /** + * @breif Process sequence node. + * + */ + void ParseSequence(Node & node, std::list::iterator & it) + { + ReaderLine * pNextLine = nullptr; + while(it != m_Lines.end()) + { + ReaderLine * pLine = *it; + Node & childNode = node.PushBack(); + + // Move to next line, error check. + ++it; + if(it == m_Lines.end()) + { + throw InternalException(ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); + } + + // Handle value of map + Node::eType valueType = (*it)->Type; + switch(valueType) + { + case Node::SequenceType: + ParseSequence(childNode, it); + break; + case Node::MapType: + ParseMap(childNode, it); + break; + case Node::ScalarType: + ParseScalar(childNode, it); + break; + default: + break; + } + + // Check next line. if sequence and correct level, go on, else exit. + // If same level but but of type map = error. + if(it == m_Lines.end() || ((pNextLine = *it)->Offset < pLine->Offset)) + { + break; + } + if(pNextLine->Offset > pLine->Offset) + { + throw ParsingException(ExceptionMessage(g_ErrorIncorrectOffset, *pNextLine)); + } + if(pNextLine->Type != Node::SequenceType) + { + throw InternalException(ExceptionMessage(g_ErrorDiffEntryNotAllowed, *pNextLine)); + } + + } + } + + /** + * @breif Process map node. 
+ * + */ + void ParseMap(Node & node, std::list::iterator & it) + { + ReaderLine * pNextLine = nullptr; + while(it != m_Lines.end()) + { + ReaderLine * pLine = *it; + Node & childNode = node[pLine->Data]; + + // Move to next line, error check. + ++it; + if(it == m_Lines.end()) + { + throw InternalException(ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); + } + + // Handle value of map + Node::eType valueType = (*it)->Type; + switch(valueType) + { + case Node::SequenceType: + ParseSequence(childNode, it); + break; + case Node::MapType: + ParseMap(childNode, it); + break; + case Node::ScalarType: + ParseScalar(childNode, it); + break; + default: + break; + } + + // Check next line. if map and correct level, go on, else exit. + // if same level but but of type map = error. + if(it == m_Lines.end() || ((pNextLine = *it)->Offset < pLine->Offset)) + { + break; + } + if(pNextLine->Offset > pLine->Offset) + { + throw ParsingException(ExceptionMessage(g_ErrorIncorrectOffset, *pNextLine)); + } + if(pNextLine->Type != pLine->Type) + { + throw InternalException(ExceptionMessage(g_ErrorDiffEntryNotAllowed, *pNextLine)); + } + + } + } + + /** + * @breif Process scalar node. + * + */ + void ParseScalar(Node & node, std::list::iterator & it) + { + std::string data = ""; + ReaderLine * pFirstLine = *it; + ReaderLine * pLine = *it; + + // Check if current line is a block scalar. + unsigned char blockFlags = 0; + bool isBlockScalar = IsBlockScalar(pLine->Data, pLine->No, blockFlags); + const bool newLineFlag = static_cast(blockFlags & ReaderLine::FlagMask[static_cast(ReaderLine::ScalarNewlineFlag)]); + const bool foldedFlag = static_cast(blockFlags & ReaderLine::FlagMask[static_cast(ReaderLine::FoldedScalarFlag)]); + const bool literalFlag = static_cast(blockFlags & ReaderLine::FlagMask[static_cast(ReaderLine::LiteralScalarFlag)]); + size_t parentOffset = 0; + + // Find parent offset + if(it != m_Lines.begin()) + { + std::list::iterator parentIt = it; + --parentIt; + parentOffset = (*parentIt)->Offset; + } + + // Move to next iterator/line if current line is a block scalar. 
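+            // A block scalar line only carries the '|' or '>' marker (plus an
+            // optional '-' chomping indicator); the text itself starts on the
+            // following lines, so advance past the marker line first.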
+ if(isBlockScalar) + { + ++it; + if(it == m_Lines.end() || (pLine = *it)->Type != Node::ScalarType) + { + return; + } + } + + // Not a block scalar, cut end spaces/tabs + if(isBlockScalar == false) + { + while(1) + { + pLine = *it; + + if(parentOffset != 0 && pLine->Offset <= parentOffset) + { + throw ParsingException(ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); + } + + const size_t endOffset = pLine->Data.find_last_not_of(" \t"); + if(endOffset == std::string::npos) + { + data += "\n"; + } + else + { + data += pLine->Data.substr(0, endOffset + 1); + } + + // Move to next line + ++it; + if(it == m_Lines.end() || (*it)->Type != Node::ScalarType) + { + break; + } + + data += " "; + } + + if(ValidateQuote(data) == false) + { + throw ParsingException(ExceptionMessage(g_ErrorInvalidQuote, *pFirstLine)); + } + } + // Block scalar + else + { + pLine = *it; + size_t blockOffset = pLine->Offset; + if(blockOffset <= parentOffset) + { + throw ParsingException(ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); + } + + bool addedSpace = false; + while(it != m_Lines.end() && (*it)->Type == Node::ScalarType) + { + pLine = *it; + + const size_t endOffset = pLine->Data.find_last_not_of(" \t"); + if(endOffset != std::string::npos && pLine->Offset < blockOffset) + { + throw ParsingException(ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); + } + + if(endOffset == std::string::npos) + { + if(addedSpace) + { + data[data.size() - 1] = '\n'; + addedSpace = false; + } + else + { + data += "\n"; + } + + ++it; + continue; + } + else + { + if(blockOffset != pLine->Offset && foldedFlag) + { + if(addedSpace) + { + data[data.size() - 1] = '\n'; + addedSpace = false; + } + else + { + data += "\n"; + } + } + data += std::string(pLine->Offset - blockOffset, ' '); + data += pLine->Data; + } + + // Move to next line + ++it; + if(it == m_Lines.end() || (*it)->Type != Node::ScalarType) + { + if(newLineFlag) + { + data += "\n"; + } + break; + } + + if(foldedFlag) + { + data += " "; + addedSpace = true; + } + else if(literalFlag && endOffset != std::string::npos) + { + data += "\n"; + } + } + } + + if(data.size() && (data[0] == '"' || data[0] == '\'')) + { + data = data.substr(1, data.size() - 2 ); + } + + node = data; + } + + /** + * @breif Debug printing. 
+ * + */ + void Print() + { + for (auto it = m_Lines.begin(); it != m_Lines.end(); it++) + { + + ReaderLine * pLine = *it; + + // Print type + if (pLine->Type == Node::SequenceType) + { + std::cout << "seq "; + } + else if (pLine->Type == Node::MapType) + { + std::cout << "map "; + } + else if (pLine->Type == Node::ScalarType) + { + std::cout << "sca "; + } + else + { + std::cout << " "; + } + + // Print flags + if (pLine->GetFlag(ReaderLine::FoldedScalarFlag)) + { + std::cout << "f"; + } + else + { + std::cout << "-"; + } + if (pLine->GetFlag(ReaderLine::LiteralScalarFlag)) + { + std::cout << "l"; + } + else + { + std::cout << "-"; + } + if (pLine->GetFlag(ReaderLine::ScalarNewlineFlag)) + { + std::cout << "n"; + } + else + { + std::cout << "-"; + } + if (pLine->NextLine == nullptr) + { + std::cout << "e"; + } + else + { + std::cout << "-"; + } + + + std::cout << "| "; + std::cout << pLine->No << " "; + std::cout << std::string(pLine->Offset, ' '); + + if (pLine->Type == Node::ScalarType) + { + std::string scalarValue = pLine->Data; + for (size_t i = 0; (i = scalarValue.find("\n", i)) != std::string::npos;) + { + scalarValue.replace(i, 1, "\\n"); + i += 2; + } + std::cout << scalarValue << std::endl; + } + else if (pLine->Type == Node::MapType) + { + std::cout << pLine->Data + ":" << std::endl; + } + else if (pLine->Type == Node::SequenceType) + { + std::cout << "-" << std::endl; + } + else + { + std::cout << "> UNKOWN TYPE <" << std::endl; + } + } + } + + /** + * @breif Clear all read lines. + * + */ + void ClearLines() + { + for (auto it = m_Lines.begin(); it != m_Lines.end(); it++) + { + delete *it; + } + m_Lines.clear(); + } + + void ClearTrailingEmptyLines(std::list::iterator & it) + { + while(it != m_Lines.end()) + { + ReaderLine * pLine = *it; + if(pLine->Data.size() == 0) + { + delete *it; + it = m_Lines.erase(it); + } + else + { + return; + } + + } + } + + static bool IsSequenceStart(const std::string & data) + { + if (data.size() == 0 || data[0] != '-') + { + return false; + } + + if (data.size() >= 2 && data[1] != ' ') + { + return false; + } + + return true; + } + + static bool IsBlockScalar(const std::string & data, const size_t line, unsigned char & flags) + { + flags = 0; + if(data.size() == 0) + { + return false; + } + + if(data[0] == '|') + { + if(data.size() >= 2) + { + if(data[1] != '-' && data[1] != ' ' && data[1] != '\t') + { + throw ParsingException(ExceptionMessage(g_ErrorInvalidBlockScalar, line, data)); + } + } + else + { + flags |= ReaderLine::FlagMask[static_cast(ReaderLine::ScalarNewlineFlag)]; + } + flags |= ReaderLine::FlagMask[static_cast(ReaderLine::LiteralScalarFlag)]; + return true; + } + + if(data[0] == '>') + { + if(data.size() >= 2) + { + if(data[1] != '-' && data[1] != ' ' && data[1] != '\t') + { + throw ParsingException(ExceptionMessage(g_ErrorInvalidBlockScalar, line, data)); + } + } + else + { + flags |= ReaderLine::FlagMask[static_cast(ReaderLine::ScalarNewlineFlag)]; + } + flags |= ReaderLine::FlagMask[static_cast(ReaderLine::FoldedScalarFlag)]; + return true; + } + + return false; + } + + std::list m_Lines; ///< List of lines. 
+ + }; + + // Parsing functions + void Parse(Node & root, const char * filename) + { + std::ifstream f(filename, std::ifstream::binary); + if (f.is_open() == false) + { + throw OperationException(g_ErrorCannotOpenFile); + } + + f.seekg(0, f.end); + size_t fileSize = static_cast(f.tellg()); + f.seekg(0, f.beg); + + std::unique_ptr data(new char[fileSize]); + f.read(data.get(), fileSize); + f.close(); + + Parse(root, data.get(), fileSize); + } + + void Parse(Node & root, std::iostream & stream) + { + ParseImp * pImp = nullptr; + + try + { + pImp = new ParseImp; + pImp->Parse(root, stream); + delete pImp; + } + catch (const Exception e) + { + delete pImp; + throw; + } + } + + void Parse(Node & root, const std::string & string) + { + std::stringstream ss(string); + Parse(root, ss); + } + + void Parse(Node & root, const char * buffer, const size_t size) + { + std::stringstream ss(std::string(buffer, size)); + Parse(root, ss); + } + + + // Serialize configuration structure. + SerializeConfig::SerializeConfig(const size_t spaceIndentation, + const size_t scalarMaxLength, + const bool sequenceMapNewline, + const bool mapScalarNewline) : + SpaceIndentation(spaceIndentation), + ScalarMaxLength(scalarMaxLength), + SequenceMapNewline(sequenceMapNewline), + MapScalarNewline(mapScalarNewline) + { + } + + + // Serialization functions + void Serialize(const Node & root, const char * filename, const SerializeConfig & config) + { + std::stringstream stream; + Serialize(root, stream, config); + + std::ofstream f(filename); + if (f.is_open() == false) + { + throw OperationException(g_ErrorCannotOpenFile); + } + + f.write(stream.str().c_str(), stream.str().size()); + f.close(); + } + + size_t LineFolding(const std::string & input, std::vector & folded, const size_t maxLength) + { + folded.clear(); + if(input.size() == 0) + { + return 0; + } + + size_t currentPos = 0; + size_t lastPos = 0; + size_t spacePos = std::string::npos; + while(currentPos < input.size()) + { + currentPos = lastPos + maxLength; + + if(currentPos < input.size()) + { + spacePos = input.find_first_of(' ', currentPos); + } + + if(spacePos == std::string::npos || currentPos >= input.size()) + { + const std::string endLine = input.substr(lastPos); + if(endLine.size()) + { + folded.push_back(endLine); + } + + return folded.size(); + } + + folded.push_back(input.substr(lastPos, spacePos - lastPos)); + + lastPos = spacePos + 1; + } + + return folded.size(); + } + + static void SerializeLoop(const Node & node, std::iostream & stream, bool useLevel, const size_t level, const SerializeConfig & config) + { + const size_t indention = config.SpaceIndentation; + + switch(node.Type()) + { + case Node::SequenceType: + { + for(auto it = node.Begin(); it != node.End(); it++) + { + const Node & value = (*it).second; + if(value.IsNone()) + { + continue; + } + stream << std::string(level, ' ') << "- "; + useLevel = false; + if(value.IsSequence() || (value.IsMap() && config.SequenceMapNewline == true)) + { + useLevel = true; + stream << "\n"; + } + + SerializeLoop(value, stream, useLevel, level + 2, config); + } + + } + break; + case Node::MapType: + { + size_t count = 0; + for(auto it = node.Begin(); it != node.End(); it++) + { + const Node & value = (*it).second; + if(value.IsNone()) + { + continue; + } + + if(useLevel || count > 0) + { + stream << std::string(level, ' '); + } + + std::string key = (*it).first; + AddEscapeTokens(key, "\\\""); + if(ShouldBeCited(key)) + { + stream << "\"" << key << "\"" << ": "; + } + else + { + stream << key << ": "; + } + + 
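+                    // A scalar value stays on the key's line unless
+                    // MapScalarNewline is set; sequences and maps always
+                    // recurse onto a new line.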
+ useLevel = false; + if(value.IsScalar() == false || (value.IsScalar() && config.MapScalarNewline)) + { + useLevel = true; + stream << "\n"; + } + + SerializeLoop(value, stream, useLevel, level + indention, config); + + useLevel = true; + count++; + } + + } + break; + case Node::ScalarType: + { + const std::string value = node.As(); + + // Empty scalar + if(value.size() == 0) + { + stream << "\n"; + break; + } + + // Get lines of scalar. + std::string line = ""; + std::vector lines; + std::istringstream iss(value); + while (iss.eof() == false) + { + std::getline(iss, line); + lines.push_back(line); + } + + // Block scalar + const std::string & lastLine = lines.back(); + const bool endNewline = lastLine.size() == 0; + if(endNewline) + { + lines.pop_back(); + } + + // Literal + if(lines.size() > 1) + { + stream << "|"; + } + // Folded/plain + else + { + const std::string frontLine = lines.front(); + if(config.ScalarMaxLength == 0 || lines.front().size() <= config.ScalarMaxLength || + LineFolding(frontLine, lines, config.ScalarMaxLength) == 1) + { + if(useLevel) + { + stream << std::string(level, ' '); + } + + if(ShouldBeCited(value)) + { + stream << "\"" << value << "\"\n"; + break; + } + stream << value << "\n"; + break; + } + else + { + stream << ">"; + } + } + + if(endNewline == false) + { + stream << "-"; + } + stream << "\n"; + + + for(auto it = lines.begin(); it != lines.end(); it++) + { + stream << std::string(level, ' ') << (*it) << "\n"; + } + } + break; + + default: + break; + } + } + + void Serialize(const Node & root, std::iostream & stream, const SerializeConfig & config) + { + if(config.SpaceIndentation < 2) + { + throw OperationException(g_ErrorIndentation); + } + + SerializeLoop(root, stream, false, 0, config); + } + + void Serialize(const Node & root, std::string & string, const SerializeConfig & config) + { + std::stringstream stream; + Serialize(root, stream, config); + string = stream.str(); + } + + + + // Static function implementations + std::string ExceptionMessage(const std::string & message, ReaderLine & line) + { + return message + std::string(" Line ") + std::to_string(line.No) + std::string(": ") + line.Data; + } + + std::string ExceptionMessage(const std::string & message, ReaderLine & line, const size_t errorPos) + { + return message + std::string(" Line ") + std::to_string(line.No) + std::string(" column ") + std::to_string(errorPos + 1) + std::string(": ") + line.Data; + } + + std::string ExceptionMessage(const std::string & message, const size_t errorLine, const size_t errorPos) + { + return message + std::string(" Line ") + std::to_string(errorLine) + std::string(" column ") + std::to_string(errorPos); + } + + std::string ExceptionMessage(const std::string & message, const size_t errorLine, const std::string & data) + { + return message + std::string(" Line ") + std::to_string(errorLine) + std::string(": ") + data; + } + + bool FindQuote(const std::string & input, size_t & start, size_t & end, size_t searchPos) + { + start = end = std::string::npos; + size_t qPos = searchPos; + bool foundStart = false; + + while(qPos != std::string::npos) + { + // Find first quote. + qPos = input.find_first_of("\"'", qPos); + if(qPos == std::string::npos) + { + return false; + } + + const char token = input[qPos]; + if(token == '"' && (qPos == 0 || input[qPos-1] != '\\')) + { + // Found start quote. 
+ if(foundStart == false) + { + start = qPos; + foundStart = true; + } + // Found end quote + else + { + end = qPos; + return true; + } + } + + // Check if it's possible for another loop. + if(qPos + 1 == input.size()) + { + return false; + } + qPos++; + } + + return false; + } + + size_t FindNotCited(const std::string & input, char token, size_t & preQuoteCount) + { + preQuoteCount = 0; + size_t tokenPos = input.find_first_of(token); + if(tokenPos == std::string::npos) + { + return std::string::npos; + } + + // Find all quotes + std::vector> quotes; + + size_t quoteStart = 0; + size_t quoteEnd = 0; + while(FindQuote(input, quoteStart, quoteEnd, quoteEnd)) + { + quotes.push_back({quoteStart, quoteEnd}); + + if(quoteEnd + 1 == input.size()) + { + break; + } + quoteEnd++; + } + + if(quotes.size() == 0) + { + return tokenPos; + } + + size_t currentQuoteIndex = 0; + std::pair currentQuote = {0, 0}; + + while(currentQuoteIndex < quotes.size()) + { + currentQuote = quotes[currentQuoteIndex]; + + if(tokenPos < currentQuote.first) + { + return tokenPos; + } + preQuoteCount++; + if(tokenPos <= currentQuote.second) + { + // Find next token + if(tokenPos + 1 == input.size()) + { + return std::string::npos; + } + tokenPos = input.find_first_of(token, tokenPos + 1); + if(tokenPos == std::string::npos) + { + return std::string::npos; + } + } + + currentQuoteIndex++; + } + + return tokenPos; + } + + size_t FindNotCited(const std::string & input, char token) + { + size_t dummy = 0; + return FindNotCited(input, token, dummy); + } + + bool ValidateQuote(const std::string & input) + { + if(input.size() == 0) + { + return true; + } + + char token = 0; + size_t searchPos = 0; + if(input[0] == '\"' || input[0] == '\'') + { + if(input.size() == 1) + { + return false; + } + token = input[0]; + searchPos = 1; + } + + while(searchPos != std::string::npos && searchPos < input.size() - 1) + { + searchPos = input.find_first_of("\"'", searchPos + 1); + if(searchPos == std::string::npos) + { + break; + } + + const char foundToken = input[searchPos]; + + if(input[searchPos] == '\"' || input[searchPos] == '\'') + { + if(token == 0 && input[searchPos-1] != '\\') + { + return false; + } + //if(foundToken == token) + //{ + + /*if(foundToken == token && searchPos == input.size() - 1 && input[searchPos-1] != '\\') + { + return true; + if(searchPos == input.size() - 1) + { + return true; + } + return false; + } + else */ + if(foundToken == token && input[searchPos-1] != '\\') + { + if(searchPos == input.size() - 1) + { + return true; + } + return false; + } + //} + } + } + + return token == 0; + } + + void CopyNode(const Node & from, Node & to) + { + const Node::eType type = from.Type(); + + switch(type) + { + case Node::SequenceType: + for(auto it = from.Begin(); it != from.End(); it++) + { + const Node & currentNode = (*it).second; + Node & newNode = to.PushBack(); + CopyNode(currentNode, newNode); + } + break; + case Node::MapType: + for(auto it = from.Begin(); it != from.End(); it++) + { + const Node & currentNode = (*it).second; + Node & newNode = to[(*it).first]; + CopyNode(currentNode, newNode); + } + break; + case Node::ScalarType: + to = from.As(); + break; + case Node::None: + break; + } + } + + bool ShouldBeCited(const std::string & key) + { + return key.find_first_of("\":{}[],&*#?|-<>=!%@") != std::string::npos; + } + + void AddEscapeTokens(std::string & input, const std::string & tokens) + { + for(auto it = tokens.begin(); it != tokens.end(); it++) + { + const char token = *it; + const std::string replace = 
std::string("\\") + std::string(1, token); + size_t found = input.find_first_of(token); + while(found != std::string::npos) + { + input.replace(found, 1, replace); + found = input.find_first_of(token, found + 2); + } + } + } + + void RemoveAllEscapeTokens(std::string & input) + { + size_t found = input.find_first_of("\\"); + while(found != std::string::npos) + { + if(found + 1 == input.size()) + { + return; + } + + std::string replace(1, input[found + 1]); + input.replace(found, 2, replace); + found = input.find_first_of("\\", found + 1); + } + } + + +} diff --git a/runtime/core/utils/Yaml.hpp b/runtime/core/utils/Yaml.hpp new file mode 100644 index 000000000..586657fb2 --- /dev/null +++ b/runtime/core/utils/Yaml.hpp @@ -0,0 +1,656 @@ +/* +* MIT License +* +* Copyright(c) 2018 Jimmie Bergmann +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files(the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions : +* +* The above copyright notice and this permission notice shall be included in all +* copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +* +*/ + +/* +YAML documentation: +http://yaml.org/spec/1.0/index.html +https://www.codeproject.com/Articles/28720/YAML-Parser-in-C +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +/** +* @breif Namespace wrapping mini-yaml classes. +* +*/ +namespace Yaml +{ + + /** + * @breif Forward declarations. + * + */ + class Node; + + + /** + * @breif Helper classes and functions + * + */ + namespace impl + { + + /** + * @breif Helper functionality, converting string to any data type. + * Strings are left untouched. 
+ * + */ + template + struct StringConverter + { + static T Get(const std::string & data) + { + T type; + std::stringstream ss(data); + ss >> type; + return type; + } + + static T Get(const std::string & data, const T & defaultValue) + { + T type; + std::stringstream ss(data); + ss >> type; + + if(ss.fail()) + { + return defaultValue; + } + + return type; + } + }; + template<> + struct StringConverter + { + static std::string Get(const std::string & data) + { + return data; + } + + static std::string Get(const std::string & data, const std::string & defaultValue) + { + if(data.size() == 0) + { + return defaultValue; + } + return data; + } + }; + + template<> + struct StringConverter + { + static bool Get(const std::string & data) + { + std::string tmpData = data; + std::transform(tmpData.begin(), tmpData.end(), tmpData.begin(), ::tolower); + if(tmpData == "true" || tmpData == "yes" || tmpData == "1") + { + return true; + } + + return false; + } + + static bool Get(const std::string & data, const bool & defaultValue) + { + if(data.size() == 0) + { + return defaultValue; + } + + return Get(data); + } + }; + + } + + + /** + * @breif Exception class. + * + */ + class Exception : public std::runtime_error + { + + public: + + /** + * @breif Enumeration of exception types. + * + */ + enum eType + { + InternalError, ///< Internal error. + ParsingError, ///< Invalid parsing data. + OperationError ///< User operation error. + }; + + /** + * @breif Constructor. + * + * @param message Exception message. + * @param type Type of exception. + * + */ + Exception(const std::string & message, const eType type); + + /** + * @breif Get type of exception. + * + */ + eType Type() const; + + /** + * @breif Get message of exception. + * + */ + const char * Message() const; + + private: + + eType m_Type; ///< Type of exception. + + }; + + + /** + * @breif Internal exception class. + * + * @see Exception + * + */ + class InternalException : public Exception + { + + public: + + /** + * @breif Constructor. + * + * @param message Exception message. + * + */ + InternalException(const std::string & message); + + }; + + + /** + * @breif Parsing exception class. + * + * @see Exception + * + */ + class ParsingException : public Exception + { + + public: + + /** + * @breif Constructor. + * + * @param message Exception message. + * + */ + ParsingException(const std::string & message); + + }; + + + /** + * @breif Operation exception class. + * + * @see Exception + * + */ + class OperationException : public Exception + { + + public: + + /** + * @breif Constructor. + * + * @param message Exception message. + * + */ + OperationException(const std::string & message); + + }; + + + /** + * @breif Iterator class. + * + */ + class Iterator + { + + public: + + friend class Node; + + /** + * @breif Default constructor. + * + */ + Iterator(); + + /** + * @breif Copy constructor. + * + */ + Iterator(const Iterator & it); + + /** + * @breif Assignment operator. + * + */ + Iterator & operator = (const Iterator & it); + + /** + * @breif Destructor. + * + */ + ~Iterator(); + + /** + * @breif Get node of iterator. + * First pair item is the key of map value, empty if type is sequence. + * + */ + std::pair operator *(); + + /** + * @breif Post-increment operator. + * + */ + Iterator & operator ++ (int); + + /** + * @breif Post-decrement operator. + * + */ + Iterator & operator -- (int); + + /** + * @breif Check if iterator is equal to other iterator. 
+ * + */ + bool operator == (const Iterator & it); + + /** + * @breif Check if iterator is not equal to other iterator. + * + */ + bool operator != (const Iterator & it); + + private: + + enum eType + { + None, + SequenceType, + MapType + }; + + eType m_Type; ///< Type of iterator. + void * m_pImp; ///< Implementation of iterator class. + + }; + + + /** + * @breif Constant iterator class. + * + */ + class ConstIterator + { + + public: + + friend class Node; + + /** + * @breif Default constructor. + * + */ + ConstIterator(); + + /** + * @breif Copy constructor. + * + */ + ConstIterator(const ConstIterator & it); + + /** + * @breif Assignment operator. + * + */ + ConstIterator & operator = (const ConstIterator & it); + + /** + * @breif Destructor. + * + */ + ~ConstIterator(); + + /** + * @breif Get node of iterator. + * First pair item is the key of map value, empty if type is sequence. + * + */ + std::pair operator *(); + + /** + * @breif Post-increment operator. + * + */ + ConstIterator & operator ++ (int); + + /** + * @breif Post-decrement operator. + * + */ + ConstIterator & operator -- (int); + + /** + * @breif Check if iterator is equal to other iterator. + * + */ + bool operator == (const ConstIterator & it); + + /** + * @breif Check if iterator is not equal to other iterator. + * + */ + bool operator != (const ConstIterator & it); + + private: + + enum eType + { + None, + SequenceType, + MapType + }; + + eType m_Type; ///< Type of iterator. + void * m_pImp; ///< Implementation of constant iterator class. + + }; + + + /** + * @breif Node class. + * + */ + class Node + { + + public: + + friend class Iterator; + + /** + * @breif Enumeration of node types. + * + */ + enum eType + { + None, + SequenceType, + MapType, + ScalarType + }; + + /** + * @breif Default constructor. + * + */ + Node(); + + /** + * @breif Copy constructor. + * + */ + Node(const Node & node); + + /** + * @breif Assignment constructors. + * Converts node to scalar type if needed. + * + */ + Node(const std::string & value); + Node(const char * value); + + /** + * @breif Destructor. + * + */ + ~Node(); + + /** + * @breif Functions for checking type of node. + * + */ + eType Type() const; + bool IsNone() const; + bool IsSequence() const; + bool IsMap() const; + bool IsScalar() const; + + /** + * @breif Completely clear node. + * + */ + void Clear(); + + /** + * @breif Get node as given template type. + * + */ + template + T As() const + { + return impl::StringConverter::Get(AsString()); + } + + /** + * @breif Get node as given template type. + * + */ + template + T As(const T & defaultValue) const + { + return impl::StringConverter::Get(AsString(), defaultValue); + } + + /** + * @breif Get size of node. + * Nodes of type None or Scalar will return 0. + * + */ + size_t Size() const; + + // Sequence operators + + /** + * @breif Insert sequence item at given index. + * Converts node to sequence type if needed. + * Adding new item to end of sequence if index is larger than sequence size. + * + */ + Node & Insert(const size_t index); + + /** + * @breif Add new sequence index to back. + * Converts node to sequence type if needed. + * + */ + Node & PushFront(); + + /** + * @breif Add new sequence index to front. + * Converts node to sequence type if needed. + * + */ + Node & PushBack(); + + /** + * @breif Get sequence/map item. + * Converts node to sequence/map type if needed. + * + * @param index Sequence index. Returns None type Node if index is unknown. + * @param key Map key. Creates a new node if key is unknown. 
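+        * A minimal usage sketch based on the declarations in this class
+        * (key and values are illustrative only):
+        *
+        *   Yaml::Node root;
+        *   root["beam_size"] = "10";                 // creates/assigns a map entry
+        *   int beam = root["beam_size"].As<int>(1);  // scalar conversion with default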
+ * + */ + Node & operator [] (const size_t index); + Node & operator [] (const std::string & key); + + /** + * @breif Erase item. + * No action if node is not a sequence or map. + * + */ + void Erase(const size_t index); + void Erase(const std::string & key); + + /** + * @breif Assignment operators. + * + */ + Node & operator = (const Node & node); + Node & operator = (const std::string & value); + Node & operator = (const char * value); + + /** + * @breif Get start iterator. + * + */ + Iterator Begin(); + ConstIterator Begin() const; + + /** + * @breif Get end iterator. + * + */ + Iterator End(); + ConstIterator End() const; + + + private: + + /** + * @breif Get as string. If type is scalar, else empty. + * + */ + const std::string & AsString() const; + + void * m_pImp; ///< Implementation of node class. + + }; + + + /** + * @breif Parsing functions. + * Population given root node with deserialized data. + * + * @param root Root node to populate. + * @param filename Path of input file. + * @param stream Input stream. + * @param string String of input data. + * @param buffer Char array of input data. + * @param size Buffer size. + * + * @throw InternalException An internal error occurred. + * @throw ParsingException Invalid input YAML data. + * @throw OperationException If filename or buffer pointer is invalid. + * + */ + void Parse(Node & root, const char * filename); + void Parse(Node & root, std::iostream & stream); + void Parse(Node & root, const std::string & string); + void Parse(Node & root, const char * buffer, const size_t size); + + + /** + * @breif Serialization configuration structure, + * describing output behavior. + * + */ + struct SerializeConfig + { + + /** + * @breif Constructor. + * + * @param spaceIndentation Number of spaces per indentation. + * @param scalarMaxLength Maximum length of scalars. Serialized as folder scalars if exceeded. + * Ignored if equal to 0. + * @param sequenceMapNewline Put maps on a new line if parent node is a sequence. + * @param mapScalarNewline Put scalars on a new line if parent node is a map. + * + */ + SerializeConfig(const size_t spaceIndentation = 2, + const size_t scalarMaxLength = 64, + const bool sequenceMapNewline = false, + const bool mapScalarNewline = false); + + size_t SpaceIndentation; ///< Number of spaces per indentation. + size_t ScalarMaxLength; ///< Maximum length of scalars. Serialized as folder scalars if exceeded. + bool SequenceMapNewline; ///< Put maps on a new line if parent node is a sequence. + bool MapScalarNewline; ///< Put scalars on a new line if parent node is a map. + }; + + + /** + * @breif Serialization functions. + * + * @param root Root node to serialize. + * @param filename Path of output file. + * @param stream Output stream. + * @param string String of output data. + * @param config Serialization configurations. + * + * @throw InternalException An internal error occurred. + * @throw OperationException If filename or buffer pointer is invalid. + * If config is invalid. 
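+    * A minimal round-trip sketch based on the declarations in this header
+    * ("config.yaml" is an illustrative path):
+    *
+    *   Yaml::Node root;
+    *   Yaml::Parse(root, "config.yaml");   // may throw ParsingException
+    *   std::string out;
+    *   Yaml::Serialize(root, out);         // uses the default SerializeConfig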
+ * + */ + void Serialize(const Node & root, const char * filename, const SerializeConfig & config = {2, 64, false, false}); + void Serialize(const Node & root, std::iostream & stream, const SerializeConfig & config = {2, 64, false, false}); + void Serialize(const Node & root, std::string & string, const SerializeConfig & config = {2, 64, false, false}); + +} From 85e5ac52f6ed31855838c71133e517afd61c35fe Mon Sep 17 00:00:00 2001 From: veelion Date: Mon, 29 Aug 2022 11:26:45 +0800 Subject: [PATCH 26/62] export onnx gpu model for c++ runtime --- wenet/bin/export_onnx_gpu_runtime.py | 360 +++++++++++++++++++++++++++ 1 file changed, 360 insertions(+) create mode 100644 wenet/bin/export_onnx_gpu_runtime.py diff --git a/wenet/bin/export_onnx_gpu_runtime.py b/wenet/bin/export_onnx_gpu_runtime.py new file mode 100644 index 000000000..9db2d7fb6 --- /dev/null +++ b/wenet/bin/export_onnx_gpu_runtime.py @@ -0,0 +1,360 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import argparse +import os +import sys + +import torch +import yaml +import logging + +from wenet.utils.checkpoint import load_checkpoint +from wenet.transformer.ctc import CTC +from wenet.transformer.decoder import TransformerDecoder +from wenet.transformer.encoder import BaseEncoder +from wenet.utils.init_model import init_model + +try: + import onnxruntime +except ImportError: + print('Please install onnxruntime-gpu!') + sys.exit(1) + +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) + + +class Encoder(torch.nn.Module): + def __init__(self, + encoder: BaseEncoder, + ctc: CTC, + beam_size: int = 10): + super().__init__() + self.encoder = encoder + self.ctc = ctc + self.beam_size = beam_size + + def forward(self, speech: torch.Tensor, + speech_lengths: torch.Tensor,): + """Encoder + Args: + speech: (Batch, Length, ...) 
+ speech_lengths: (Batch, ) + Returns: + encoder_out: B x T x F + encoder_out_lens: B + ctc_log_probs: B x T x V + beam_log_probs: B x T x beam_size + beam_log_probs_idx: B x T x beam_size + """ + encoder_out, encoder_mask = self.encoder(speech, + speech_lengths, + -1, -1) + encoder_out_lens = encoder_mask.squeeze(1).sum(1) + ctc_log_probs = self.ctc.log_softmax(encoder_out) + encoder_out_lens = encoder_out_lens.int() + return encoder_out, encoder_out_lens, ctc_log_probs + + +class Decoder(torch.nn.Module): + def __init__(self, + decoder: TransformerDecoder, + ctc_weight: float = 0.5, + reverse_weight: float = 0.0, + beam_size: int = 10, + eos: int = 5537): + super().__init__() + self.decoder = decoder + self.ctc_weight = ctc_weight + self.reverse_weight = reverse_weight + self.beam_size = beam_size + self.eos = eos + + def forward(self, + encoder_out: torch.Tensor, + hyps_pad_sos: torch.Tensor, + hyps_lens_sos: torch.Tensor): + """ Export interface for c++ call, forward decoder with batch of + hypothesis from ctc prefix beam search and encoder output + Args: + encoder_out: B x T x F + hyps_pad_sos: B x beam x T2 + hyps with sos and padded by 0 + hyps_lens_sos: B x beam, length for each hyp with sos + + Returns: + decoder_out: B x beam x T2 x V + r_decoder_out: B x beam x T2 x V + """ + B, T, F = encoder_out.shape + bz = hyps_pad_sos.shape[1] + B2 = B * bz + T2 = hyps_pad_sos.shape[2] + # 1. prepare inputs for decoder + encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F) + encoder_mask = torch.ones(B2, 1, T, + dtype=torch.bool, + device=encoder_out.device) + # input for right to left decoder + # this hyps_lens has count token, we need minus it. + hyps = hyps_pad_sos.view(B2, T2) + hyps_lens = hyps_lens_sos.view(B2,) + if self.reverse_weight > 0: + r_hyps_lens = hyps_lens - 1 + r_hyps = hyps[:, 1:] + max_len = torch.max(r_hyps_lens) + index_range = torch.arange(0, max_len, 1).to(encoder_out.device) + seq_len_expand = r_hyps_lens.unsqueeze(1) + seq_mask = seq_len_expand > index_range # (beam, max_len) + index = (seq_len_expand - 1) - index_range # (beam, max_len) + index = index * seq_mask + r_hyps = torch.gather(r_hyps, 1, index) + r_hyps = torch.where(seq_mask, r_hyps, self.eos) + r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1) + else: + r_hyps = torch.empty(0, device=encoder_out.device) + + # 2. 
decoding
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, hyps, hyps_lens, r_hyps,
+            self.reverse_weight)
+        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
+        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
+        # V = decoder_out.shape[-1]
+        # decoder_out = decoder_out.view(B, bz, T2, V)
+        # print("decoder_out.shape", decoder_out.shape)
+        # r_decoder_out = r_decoder_out.view(B, bz, T2, V)
+        return decoder_out, r_decoder_out  # B2 x T2 x V
+
+
+def to_numpy(tensors):
+    out = []
+    # NOTE: torch.tensor is a factory function, not the Tensor class, so this
+    # comparison is always False and a bare Tensor is iterated over its first
+    # dimension below.
+    if type(tensors) == torch.tensor:
+        tensors = [tensors]
+    for tensor in tensors:
+        if tensor.requires_grad:
+            tensor = tensor.detach().cpu().numpy()
+        else:
+            tensor = tensor.cpu().numpy()
+        out.append(tensor)
+    return out
+
+
+def test(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
+    for a, b in zip(xlist, blist):
+        try:
+            torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
+        except AssertionError as error:
+            if tolerate_small_mismatch:
+                print(error)
+            else:
+                raise
+
+
+def export_offline_encoder(model, configs, args, logger, encoder_onnx_path):
+    bz = 32
+    seq_len = 100
+    beam_size = args.beam_size
+    feature_size = configs["input_dim"]
+
+    speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
+    speech_lens = torch.randint(
+        low=10, high=seq_len, size=(bz,), dtype=torch.int32)
+    encoder = Encoder(model.encoder, model.ctc, beam_size)
+    encoder.eval()
+
+    torch.onnx.export(encoder,
+                      (speech, speech_lens),
+                      encoder_onnx_path,
+                      export_params=True,
+                      opset_version=13,
+                      do_constant_folding=True,
+                      input_names=['speech', 'speech_lengths'],
+                      output_names=['encoder_out', 'encoder_out_lens',
+                                    'ctc_log_probs'],
+                      dynamic_axes={
+                          'speech': {0: 'B', 1: 'T'},
+                          'speech_lengths': {0: 'B'},
+                          'encoder_out': {0: 'B', 1: 'T_OUT'},
+                          'encoder_out_lens': {0: 'B'},
+                          'ctc_log_probs': {0: 'B', 1: 'T_OUT'},
+                      },
+                      verbose=False
+                      )
+
+    with torch.no_grad():
+        o0, o1, o2 = encoder(speech, speech_lens)
+
+    providers = ["CUDAExecutionProvider"]
+    ort_session = onnxruntime.InferenceSession(encoder_onnx_path,
+                                               providers=providers)
+    ort_inputs = {'speech': to_numpy(speech),
+                  'speech_lengths': to_numpy(speech_lens)}
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    # check encoder output
+    test(to_numpy([o0, o1, o2]), ort_outs)
+    logger.info("export offline onnx encoder succeed!")
+    is_bidirectional_decoder = 1 if configs['decoder'] == 'bitransformer' else 0
+    onnx_config = {"beam_size": args.beam_size,
+                   "reverse_weight": args.reverse_weight,
+                   "ctc_weight": args.ctc_weight,
+                   "sos": configs["output_dim"] - 1,
+                   "eos": configs["output_dim"] - 1,
+                   "is_bidirectional_decoder": is_bidirectional_decoder,
+                   "fp16": args.fp16}
+    return onnx_config
+
+
+def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path):
+    bz, seq_len = 32, 100
+    beam_size = args.beam_size
+    decoder = Decoder(model.decoder,
+                      model.ctc_weight,
+                      model.reverse_weight,
+                      beam_size,
+                      configs["output_dim"] - 1)
+    decoder.eval()
+
+    hyps_pad_sos_eos = torch.randint(
+        low=3, high=1000, size=(bz, beam_size, seq_len))
+    hyps_lens_sos = torch.randint(
+        low=3, high=seq_len, size=(bz, beam_size), dtype=torch.int32)
+
+    output_size = configs["encoder_conf"]["output_size"]
+    encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32)
+
+    input_names = ['encoder_out', 'hyps_pad_sos', 'hyps_lens_sos', ]
+
+    torch.onnx.export(decoder,
+                      (encoder_out,
+                       hyps_pad_sos_eos, hyps_lens_sos),
+                      decoder_onnx_path,
+                      export_params=True,
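+                      # The dynamic_axes declared a few lines below let the
+                      # batch dim B and the padded hypothesis length T2 vary
+                      # at inference time; opset 13 mirrors the encoder
+                      # export above.
+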
opset_version=13, + do_constant_folding=True, + input_names=input_names, + output_names=['decoder_out', 'r_decoder_out'], + dynamic_axes={'encoder_out': {0: 'B', 1: 'T'}, + 'hyps_pad_sos': {0: 'B', 2: 'T2'}, + 'hyps_lens_sos': {0: 'B'}, + 'decoder_out': {0: 'B'}, + 'r_decoder_out': {0: 'B'}, + }, + verbose=False + ) + with torch.no_grad(): + o0 = decoder(encoder_out, + hyps_pad_sos_eos, + hyps_lens_sos,) + providers = ["CUDAExecutionProvider"] + ort_session = onnxruntime.InferenceSession(decoder_onnx_path, + providers=providers) + + input_tensors = [encoder_out, hyps_pad_sos_eos, + hyps_lens_sos] + ort_inputs = {} + input_tensors = to_numpy(input_tensors) + for idx, name in enumerate(input_names): + ort_inputs[name] = input_tensors[idx] + + # if model.reverse weight == 0, + # the r_hyps_pad will be removed + # from the onnx decoder since it doen't play any role + # if model.reverse_weight == 0: + # del ort_inputs['r_hyps_pad_sos_eos'] + ort_outs = ort_session.run(None, ort_inputs) + + # check decoder output + test(to_numpy(list(o0)), ort_outs, rtol=1e-03, atol=1e-05) + logger.info("export to onnx decoder succeed!") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='export x86_gpu model') + parser.add_argument('--config', required=True, help='config file') + parser.add_argument('--checkpoint', required=True, help='checkpoint model') + parser.add_argument('--cmvn_file', required=False, default='', type=str, + help='global_cmvn file, default path is in config file') + parser.add_argument('--reverse_weight', default=-1.0, type=float, + required=False, + help='reverse weight for bitransformer,' + + 'default value is in config file') + parser.add_argument('--ctc_weight', default=-1.0, type=float, + required=False, + help='ctc weight, default value is in config file') + parser.add_argument('--beam_size', default=10, type=int, required=False, + help="beam size would be ctc output size") + parser.add_argument('--output_onnx_dir', + default="onnx_model", + help='output onnx encoder and decoder directory') + parser.add_argument('--fp16', + action='store_true', + help='whether to export fp16 model, default false') + args = parser.parse_args() + + torch.manual_seed(0) + torch.set_printoptions(precision=10) + + with open(args.config, 'r') as fin: + configs = yaml.load(fin, Loader=yaml.FullLoader) + if args.cmvn_file and os.path.exists(args.cmvn_file): + configs['cmvn_file'] = args.cmvn_file + if args.reverse_weight != -1.0 and 'reverse_weight' in configs['model_conf']: + configs['model_conf']['reverse_weight'] = args.reverse_weight + print("Update reverse weight to", args.reverse_weight) + if args.ctc_weight != -1: + print("Update ctc weight to ", args.ctc_weight) + configs['model_conf']['ctc_weight'] = args.ctc_weight + configs["encoder_conf"]["use_dynamic_chunk"] = False + + model = init_model(configs) + load_checkpoint(model, args.checkpoint) + model.eval() + + if not os.path.exists(args.output_onnx_dir): + os.mkdir(args.output_onnx_dir) + encoder_onnx_path = os.path.join(args.output_onnx_dir, 'encoder.onnx') + export_enc_func = export_offline_encoder + + onnx_config = export_enc_func(model, configs, args, logger, encoder_onnx_path) + + decoder_onnx_path = os.path.join(args.output_onnx_dir, 'decoder.onnx') + export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path) + + if args.fp16: + try: + import onnxmltools + from onnxmltools.utils.float16_converter import convert_float_to_float16 + except ImportError: + import traceback + traceback.print_exc() + 
print('Please install onnxmltools!')
+            sys.exit(1)
+        encoder_onnx_model = onnxmltools.utils.load_model(encoder_onnx_path)
+        encoder_onnx_model = convert_float_to_float16(encoder_onnx_model)
+        encoder_onnx_path = os.path.join(args.output_onnx_dir, 'encoder_fp16.onnx')
+        onnxmltools.utils.save_model(encoder_onnx_model, encoder_onnx_path)
+        decoder_onnx_model = onnxmltools.utils.load_model(decoder_onnx_path)
+        decoder_onnx_model = convert_float_to_float16(decoder_onnx_model)
+        decoder_onnx_path = os.path.join(args.output_onnx_dir, 'decoder_fp16.onnx')
+        onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path)
+    # dump configurations
+
+    config_dir = os.path.join(args.output_onnx_dir, "config.yaml")
+    with open(config_dir, "w") as out:
+        yaml.dump(onnx_config, out)

From e9ffa532fc95b89a11cc854bfddf584c250fa4b3 Mon Sep 17 00:00:00 2001
From: veelion
Date: Tue, 30 Aug 2022 09:18:21 +0800
Subject: improve memory management when is_fp16
---
 runtime/core/decoder/batch_onnx_asr_model.cc | 63 +++++++++-----------
 1 file changed, 27 insertions(+), 36 deletions(-)

diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc
index 1ba49cdb5..fb93d1d00 100644
--- a/runtime/core/decoder/batch_onnx_asr_model.cc
+++ b/runtime/core/decoder/batch_onnx_asr_model.cc
@@ -102,9 +102,10 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16) {
   }

   // 1. Load sessions
+  // config for CUDA
   std::vector<const char*> keys{
       "device_id",
-      "gpu_mem_limit",
+      // "gpu_mem_limit",
       "arena_extend_strategy",
       "cudnn_conv_algo_search",
       "do_copy_in_default_stream",
       "user_compute_stream",
       "default_memory_arena_cfg"
   };
   std::vector<const char*> values{
       "0",
-      "2147483648",
+      // "2147483648",
       "kSameAsRequested",
       "DEFAULT",
       "1",
       "1",
       "1"
   };
-  std::cout << "prepare cuda options ...\n";
+  // release GPU memory: https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580
+
   const auto& api = Ort::GetApi();
   OrtCUDAProviderOptionsV2* cuda_options = nullptr;
-  OrtStatus* error = api.CreateCUDAProviderOptions(&cuda_options);
-  if (error) {
-    api.ReleaseStatus(error);
-    throw std::runtime_error("CreateCUDAProviderOptions error");
-  }
-  error = api.UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), keys.size());
-  if (error) {
-    api.ReleaseStatus(error);
-    throw std::runtime_error("UpdateCUDAProviderOptions error");
-  }
-  error = api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options_, cuda_options);
-  if (error) {
-    api.ReleaseStatus(error);
-    throw std::runtime_error("SessionOptionsAppendExecutionProvider_CUDA_V2 error");
-  }
+  Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options));
+  Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), keys.size()));
+  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options_, cuda_options));
+  api.ReleaseCUDAProviderOptions(cuda_options);
-  std::cout << "done cuda options ...\n";

   try {
     encoder_session_ = std::make_shared<Ort::Session>(
@@ -250,8 +240,10 @@ void BatchOnnxAsrModel::ForwardEncoderFunc(
     }
   }

+  Ort::RunOptions run_option;
+  run_option.AddConfigEntry("kOrtRunOptionsConfigEnableMemoryArenaShrinkage", "gpu:0");
   std::vector<Ort::Value> ort_outputs = encoder_session_->Run(
-      Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(),
+      run_option, encoder_in_names_.data(), inputs.data(),
       inputs.size(), encoder_out_names_.data(), encoder_out_names_.size());

   float* ctc_log_probs = nullptr;
@@ -259,13 +251,15 @@ void BatchOnnxAsrModel::ForwardEncoderFunc(
   auto out_shape = type_info.GetShape();
   int num_outputs = out_shape[1];
   int output_dim = out_shape[2];
+  std::vector<float> ctc_log_probs_data;  // for holding ctc_log_probs converted from fp16
   if (is_fp16_) {
     uint16_t* probs = ort_outputs[2].GetTensorMutableData<uint16_t>();
     int length = out_shape[0] * out_shape[1] * out_shape[2];
-    ctc_log_probs = new float[length];
+    ctc_log_probs_data.resize(length);
     for (size_t i = 0; i < length; ++i) {
-      ctc_log_probs[i] = Eigen::half_impl::half_to_float(Eigen::half_impl::raw_uint16_to_half(probs[i]));
+      ctc_log_probs_data[i] = Eigen::half_impl::half_to_float(Eigen::half_impl::raw_uint16_to_half(probs[i]));
     }
+    ctc_log_probs = ctc_log_probs_data.data();
   } else {
     ctc_log_probs = ort_outputs[2].GetTensorMutableData<float>();
   }
@@ -279,9 +273,6 @@ void BatchOnnxAsrModel::ForwardEncoderFunc(
       memcpy(out_prob[i][j].data(), p, sizeof(float) * output_dim);
     }
   }
-  if (is_fp16_) {
-    delete [] ctc_log_probs;
-  }
   // 3. cache encoder outs
   encoder_outs_ = std::move(ort_outputs[0]);
 }
@@ -344,8 +335,10 @@ void BatchOnnxAsrModel::AttentionRescoring(
   rescore_inputs.emplace_back(std::move(hyps_pad_tensor));
   rescore_inputs.emplace_back(std::move(hyps_lens_tensor));

+  Ort::RunOptions run_option;
+  run_option.AddConfigEntry("kOrtRunOptionsConfigEnableMemoryArenaShrinkage", "gpu:0");
   std::vector<Ort::Value> rescore_outputs = rescore_session_->Run(
-      Ort::RunOptions{nullptr}, rescore_in_names_.data(), rescore_inputs.data(),
+      run_option, rescore_in_names_.data(), rescore_inputs.data(),
       rescore_inputs.size(), rescore_out_names_.data(),
       rescore_out_names_.size());

@@ -353,19 +346,23 @@ void BatchOnnxAsrModel::AttentionRescoring(
   std::vector<int64_t> decoder_out_shape = type_info.GetShape();  // (B, beam, T2)
   float* decoder_outs_data = nullptr;
   float* r_decoder_outs_data = nullptr;
+  std::vector<float> decoder_outs_fp16;  // for holding decoder outs data converted from fp16
+  std::vector<float> r_decoder_outs_fp16;  // for holding reverse decoder outs data converted from fp16
   if (is_fp16_) {
     int length = decoder_out_shape[0] * decoder_out_shape[1] * decoder_out_shape[2];
-    decoder_outs_data = new float[length]();
+    decoder_outs_fp16.resize(length);
     auto outs = rescore_outputs[0].GetTensorMutableData<uint16_t>();
     for (size_t i = 0; i < length; ++i) {
-      decoder_outs_data[i] = Eigen::half_impl::half_to_float(Eigen::half_impl::raw_uint16_to_half(outs[i]));
+      decoder_outs_fp16[i] = Eigen::half_impl::half_to_float(Eigen::half_impl::raw_uint16_to_half(outs[i]));
     }
+    decoder_outs_data = decoder_outs_fp16.data();
     if (is_bidirectional_decoder_ && reverse_weight > 0) {
-      r_decoder_outs_data = new float[length]();
+      r_decoder_outs_fp16.resize(length);
       auto r_outs = rescore_outputs[1].GetTensorMutableData<uint16_t>();
       for (size_t i = 0; i < length; ++i) {
-        r_decoder_outs_data[i] = Eigen::half_impl::half_to_float(Eigen::half_impl::raw_uint16_to_half(r_outs[i]));
+        r_decoder_outs_fp16[i] = Eigen::half_impl::half_to_float(Eigen::half_impl::raw_uint16_to_half(r_outs[i]));
       }
+      r_decoder_outs_data = r_decoder_outs_fp16.data();
     }
   } else {
     decoder_outs_data = rescore_outputs[0].GetTensorMutableData<float>();
@@ -394,12 +391,6 @@ void BatchOnnxAsrModel::AttentionRescoring(
     }
     attention_scores->push_back(std::move(Y));
   }
-  if (is_fp16_) {
-    delete [] decoder_outs_data;
-    if (is_bidirectional_decoder_ && reverse_weight > 0) {
-      delete [] r_decoder_outs_data;
-    }
-  }
 }

}  // namespace wenet

From 5cb3c07d7cb7b0df716776947c9c1b7153635fbd Mon Sep 17 00:00:00 2001
From: veelion
Date: Tue, 30 Aug 2022 15:30:45 +0800
Subject: replace
Eigen::half with --- runtime/core/decoder/batch_onnx_asr_model.cc | 49 ++++++++++++-------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc index fb93d1d00..99e73ae16 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.cc +++ b/runtime/core/decoder/batch_onnx_asr_model.cc @@ -21,10 +21,11 @@ #include #include #include -#include +#include #include "utils/string.h" #include "utils/Yaml.hpp" +#include "utils/timer.h" namespace wenet { @@ -122,7 +123,7 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16) { "1" }; // release GPU memory: https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 - + const auto& api = Ort::GetApi(); OrtCUDAProviderOptionsV2* cuda_options = nullptr; Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options)); @@ -150,7 +151,7 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16) { sos_ = root["sos"].As(); eos_ = root["eos"].As(); is_bidirectional_decoder_ = root["is_bidirectional_decoder"].As(); - + LOG(INFO) << "Onnx Model Info:"; LOG(INFO) << "\tsos " << sos_; LOG(INFO) << "\teos " << eos_; @@ -201,18 +202,20 @@ void BatchOnnxAsrModel::ForwardEncoderFunc( // speech const int64_t feats_shape[3] = {batch_size, num_frames, feature_dim}; + Timer timer; if (is_fp16_) { std::vector feats(batch_size * num_frames * feature_dim); for (size_t i = 0; i < batch_size; ++i) { for (size_t j = 0; j < num_frames; ++j) { for (size_t k = 0; k < feature_dim; ++k) { - int p = i * num_frames * feature_dim + j * feature_dim + k; - feats[p] = Ort::Float16_t(Eigen::half(batch_feats[i][j][k]).x); + int p = i * num_frames * feature_dim + j * feature_dim + k; + feats[p] = Ort::Float16_t(_cvtss_sh(batch_feats[i][j][k], 0)); } } } feats_ort = std::move(Ort::Value::CreateTensor( memory_info, feats.data(), feats.size(), feats_shape, 3)); + VLOG(1) << "feats to fp16 takes " << timer.Elapsed() << " ms."; } else { std::vector feats; for (size_t i = 0; i < batch_size; ++i) { @@ -223,7 +226,7 @@ void BatchOnnxAsrModel::ForwardEncoderFunc( feats_ort = std::move(Ort::Value::CreateTensor( memory_info, feats.data(), feats.size(), feats_shape, 3)); } - + // speech_lens const int64_t feats_lens_shape[1] = {batch_size}; Ort::Value feats_lens_ort = Ort::Value::CreateTensor( @@ -237,29 +240,33 @@ void BatchOnnxAsrModel::ForwardEncoderFunc( inputs.push_back(std::move(feats_ort)); } else if (!strcmp(name, "speech_lengths")) { inputs.push_back(std::move(feats_lens_ort)); - } + } } Ort::RunOptions run_option; run_option.AddConfigEntry("kOrtRunOptionsConfigEnableMemoryArenaShrinkage", "gpu:0"); + timer.Reset(); std::vector ort_outputs = encoder_session_->Run( run_option, encoder_in_names_.data(), inputs.data(), inputs.size(), encoder_out_names_.data(), encoder_out_names_.size()); + VLOG(1) << "\tencoder ->Run() takes " << timer.Elapsed() << " ms."; float* ctc_log_probs = nullptr; auto type_info = ort_outputs[2].GetTensorTypeAndShapeInfo(); auto out_shape = type_info.GetShape(); - int num_outputs = out_shape[1]; + int num_outputs = out_shape[1]; int output_dim = out_shape[2]; std::vector ctc_log_probs_data; // for holding ctc_log_probs converted from fp16 if (is_fp16_) { - uint16_t* probs = ort_outputs[2].GetTensorMutableData(); + timer.Reset(); + Ort::Float16_t* probs = ort_outputs[2].GetTensorMutableData(); int length = out_shape[0] * out_shape[1] * out_shape[2]; ctc_log_probs_data.resize(length); for (size_t i = 0; i < length; ++i) { - 
ctc_log_probs_data[i] = Eigen::half_impl::half_to_float(Eigen::half_impl::raw_uint16_to_half(probs[i]));
+      // _cvtsh_ss is a scalar F16C intrinsic (declared in <immintrin.h>);
+      // it converts one IEEE fp16 value to float per call.
+      ctc_log_probs_data[i] = _cvtsh_ss(probs[i].value);
     }
     ctc_log_probs = ctc_log_probs_data.data();
+    VLOG(1) << "ctc_log_probs from GPU-fp16 to float takes " << timer.Elapsed() << " ms. data length " << length;
   } else {
     ctc_log_probs = ort_outputs[2].GetTensorMutableData<float>();
   }
@@ -297,7 +304,7 @@ void BatchOnnxAsrModel::AttentionRescoring(
   // 1. prepare input for onnx
   int batch_size = batch_hyps.size();
   int beam_size = batch_hyps[0].size();
-
+
   // 1.1 generate hyps_lens_sos data for ort
   std::vector<int64_t> hyps_lens_sos(batch_size * beam_size, 0);  // (batch_size, beam_size)
   int max_hyps_len = 0;
@@ -324,12 +331,12 @@ void BatchOnnxAsrModel::AttentionRescoring(
   // 2. forward attention decoder
   const int64_t hyps_lens_shape[] = {batch_size, beam_size};
   const int64_t hyps_pad_shape[] = {batch_size, beam_size, max_hyps_len};
-
+
   Ort::Value hyps_lens_tensor = Ort::Value::CreateTensor<int64_t>(
       memory_info, hyps_lens_sos.data(), hyps_lens_sos.size(), hyps_lens_shape, 2);
   Ort::Value hyps_pad_tensor = Ort::Value::CreateTensor<int64_t>(
       memory_info, hyps_pad_sos.data(), hyps_pad_sos.size(), hyps_pad_shape, 3);
-
+
   std::vector<Ort::Value> rescore_inputs;
   rescore_inputs.emplace_back(std::move(encoder_outs_));
   rescore_inputs.emplace_back(std::move(hyps_pad_tensor));
@@ -337,10 +344,12 @@ void BatchOnnxAsrModel::AttentionRescoring(
   Ort::RunOptions run_option;
   run_option.AddConfigEntry("kOrtRunOptionsConfigEnableMemoryArenaShrinkage", "gpu:0");
+  Timer timer;
   std::vector<Ort::Value> rescore_outputs = rescore_session_->Run(
       run_option, rescore_in_names_.data(), rescore_inputs.data(),
       rescore_inputs.size(), rescore_out_names_.data(),
       rescore_out_names_.size());
+  VLOG(1) << "decoder->Run() takes " << timer.Elapsed() << " ms.";

   auto type_info = rescore_outputs[0].GetTensorTypeAndShapeInfo();
   std::vector<int64_t> decoder_out_shape = type_info.GetShape();  // (B, beam, T2)
   float* decoder_outs_data = nullptr;
   float* r_decoder_outs_data = nullptr;
   std::vector<float> decoder_outs_fp16;  // for holding decoder outs data converted from fp16
   std::vector<float> r_decoder_outs_fp16;  // for holding reverse decoder outs data converted from fp16
   if (is_fp16_) {
+    Timer timer;
     int length = decoder_out_shape[0] * decoder_out_shape[1] * decoder_out_shape[2];
     decoder_outs_fp16.resize(length);
-    auto outs = rescore_outputs[0].GetTensorMutableData<uint16_t>();
+    auto outs = rescore_outputs[0].GetTensorMutableData<Ort::Float16_t>();
     for (size_t i = 0; i < length; ++i) {
-      decoder_outs_fp16[i] = Eigen::half_impl::half_to_float(Eigen::half_impl::raw_uint16_to_half(outs[i]));
+      decoder_outs_fp16[i] = _cvtsh_ss(outs[i].value);
     }
     decoder_outs_data = decoder_outs_fp16.data();
     if (is_bidirectional_decoder_ && reverse_weight > 0) {
       r_decoder_outs_fp16.resize(length);
-      auto r_outs = rescore_outputs[1].GetTensorMutableData<uint16_t>();
+      auto r_outs = rescore_outputs[1].GetTensorMutableData<Ort::Float16_t>();
       for (size_t i = 0; i < length; ++i) {
-        r_decoder_outs_fp16[i] = Eigen::half_impl::half_to_float(Eigen::half_impl::raw_uint16_to_half(r_outs[i]));
+        r_decoder_outs_fp16[i] = _cvtsh_ss(r_outs[i].value);
       }
       r_decoder_outs_data = r_decoder_outs_fp16.data();
     }
+    VLOG(1) << "decoder_out from fp16 to float takes " << timer.Elapsed() << " ms. 
data length " << length; } else { decoder_outs_data = rescore_outputs[0].GetTensorMutableData(); if (is_bidirectional_decoder_ && reverse_weight > 0) { r_decoder_outs_data = rescore_outputs[1].GetTensorMutableData(); } } - + int decode_out_len = decoder_out_shape[2]; attention_scores->clear(); for (size_t i = 0; i < batch_size; ++i) { @@ -384,7 +395,7 @@ void BatchOnnxAsrModel::AttentionRescoring( if (is_bidirectional_decoder_ && reverse_weight > 0) { std::vector r_hyp(hyp.size()); std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin()); - p = r_decoder_outs_data + (i * beam_size +j) * max_hyps_len * decode_out_len; + p = r_decoder_outs_data + (i * beam_size +j) * max_hyps_len * decode_out_len; r_score = ComputeAttentionScore(p, r_hyp, eos_, decode_out_len); } Y[j] = score * (1 - reverse_weight) + r_score * reverse_weight; From 48602b72597ffb56fa912882452789bf6ff69d73 Mon Sep 17 00:00:00 2001 From: veelion Date: Wed, 31 Aug 2022 13:06:12 +0800 Subject: [PATCH 29/62] let model caculate attention score --- runtime/core/decoder/batch_asr_decoder.cc | 165 +++++++++++------- runtime/core/decoder/batch_asr_decoder.h | 1 + runtime/core/decoder/batch_asr_model.h | 4 +- runtime/core/decoder/batch_onnx_asr_model.cc | 122 ++++++------- runtime/core/decoder/batch_onnx_asr_model.h | 5 +- runtime/core/decoder/batch_torch_asr_model.cc | 61 +++---- runtime/core/decoder/batch_torch_asr_model.h | 9 +- runtime/core/decoder/params.h | 2 +- 8 files changed, 210 insertions(+), 159 deletions(-) diff --git a/runtime/core/decoder/batch_asr_decoder.cc b/runtime/core/decoder/batch_asr_decoder.cc index 6e77deafa..8fad0867f 100644 --- a/runtime/core/decoder/batch_asr_decoder.cc +++ b/runtime/core/decoder/batch_asr_decoder.cc @@ -44,6 +44,13 @@ BatchAsrDecoder::BatchAsrDecoder(std::shared_ptr config, // Check if model has a right to left decoder CHECK(model_->is_bidirectional_decoder()); } + if (nullptr == fst_) { + searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts, + resource->context_graph)); + } else { + searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts, + resource->context_graph)); + } } void BatchAsrDecoder::Reset() { @@ -52,13 +59,13 @@ void BatchAsrDecoder::Reset() { void BatchAsrDecoder::SearchWorker(const ctc_log_prob_t& ctc_log_probs, int index) { Timer ctc_timer; - SearchInterface* searcher; + std::unique_ptr searcher; if (nullptr == fst_) { - searcher = new CtcPrefixBeamSearch(opts_.ctc_prefix_search_opts, - resource_->context_graph); + searcher.reset(new CtcPrefixBeamSearch(opts_.ctc_prefix_search_opts, + resource_->context_graph)); } else { - searcher = new CtcWfstBeamSearch(*fst_, opts_.ctc_wfst_search_opts, - resource_->context_graph); + searcher.reset(new CtcWfstBeamSearch(*fst_, opts_.ctc_wfst_search_opts, + resource_->context_graph)); } // 3.1. 
ctc search ctc_timer.Reset(); @@ -67,7 +74,7 @@ void BatchAsrDecoder::SearchWorker(const ctc_log_prob_t& ctc_log_probs, int inde VLOG(1) << "\tctc search i==" << index << " takes " << ctc_timer.Elapsed() << " ms"; ctc_timer.Reset(); std::vector result; - UpdateResult(searcher, result); + UpdateResult(searcher.get(), result); std::lock_guard lock(mutex_); batch_pair_result_.emplace_back(std::make_pair(index, std::move(result))); const auto& hypotheses = searcher->Inputs(); @@ -83,7 +90,6 @@ void BatchAsrDecoder::SearchWorker(const ctc_log_prob_t& ctc_log_probs, int inde } else { batch_hyps_.emplace_back(std::make_pair(index, std::move(hypotheses))); } - delete searcher; } void BatchAsrDecoder::FbankWorker(const std::vector& wav, int index) { @@ -99,39 +105,49 @@ void BatchAsrDecoder::FbankWorker(const std::vector& wav, int index) { void BatchAsrDecoder::Decode(const std::vector>& wavs) { // 1. calc fbank feature of the batch of wavs Timer timer; - std::vector fbank_threads; - for (size_t i = 0; i < wavs.size(); i++) { - const std::vector& wav = wavs[i]; - std::thread thd(&BatchAsrDecoder::FbankWorker, this, wav, i); - fbank_threads.push_back(std::move(thd)); - } - for(auto& thd : fbank_threads) { - thd.join(); - } - std::sort(batch_feats_.begin(), batch_feats_.end()); - std::sort(batch_feats_lens_.begin(), batch_feats_lens_.end()); batch_feature_t batch_feats; std::vector batch_feats_lens; - for (auto& pair : batch_feats_) { - batch_feats.push_back(std::move(pair.second)); - } - for (auto& pair : batch_feats_lens_) { - batch_feats_lens.push_back(pair.second); + if (wavs.size() > 1) { + std::vector fbank_threads; + for (size_t i = 0; i < wavs.size(); i++) { + const std::vector& wav = wavs[i]; + std::thread thd(&BatchAsrDecoder::FbankWorker, this, wav, i); + fbank_threads.push_back(std::move(thd)); + } + for(auto& thd : fbank_threads) { + thd.join(); + } + std::sort(batch_feats_.begin(), batch_feats_.end()); + std::sort(batch_feats_lens_.begin(), batch_feats_lens_.end()); + for (auto& pair : batch_feats_) { + batch_feats.push_back(std::move(pair.second)); + } + for (auto& pair : batch_feats_lens_) { + batch_feats_lens.push_back(pair.second); + } + } else { + // only one wave + feature_t feats; + int num_frames = fbank_.Compute(wavs[0], &feats); + batch_feats.push_back(feats); + batch_feats_lens.push_back(num_frames); } VLOG(1) << "feature Compute takes " << timer.Elapsed() << " ms."; // 1.1 feature padding timer.Reset(); - int max_len = *std::max_element(batch_feats_lens.begin(), batch_feats_lens.end()); - for (auto& feat : batch_feats) { - if (feat.size() == max_len) continue; - int pad_len = max_len - feat.size(); - for (size_t i = 0; i< pad_len; i++) { - std::vector one(feature_config_->num_bins, 0.0); - feat.push_back(std::move(one)); + if (wavs.size() > 1) { + int max_len = *std::max_element(batch_feats_lens.begin(), batch_feats_lens.end()); + for (auto& feat : batch_feats) { + if (feat.size() == max_len) continue; + int pad_len = max_len - feat.size(); + for (size_t i = 0; i< pad_len; i++) { + std::vector one(feature_config_->num_bins, 0.0); + feat.push_back(std::move(one)); + } } + VLOG(1) << "padding feautre takes " << timer.Elapsed() << " ms."; } - VLOG(1) << "padding feautre takes " << timer.Elapsed() << " ms."; // 2. encoder forward timer.Reset(); @@ -141,45 +157,74 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { // 3. 
ctc search one by one of the batch // create batch of tct search result for attention decoding - int batch_size = wavs.size(); timer.Reset(); - batch_pair_result_.clear(); - batch_hyps_.clear(); - std::vector search_threads; - for (size_t i = 0; i < batch_size; i++) { - const auto& ctc_log_probs = batch_ctc_log_probs[i]; - std::thread thd(&BatchAsrDecoder::SearchWorker, this, ctc_log_probs, i); - search_threads.push_back(std::move(thd)); - } - for(auto& thd : search_threads) { - thd.join(); - } - VLOG(1) << "ctc search batch(" << batch_size << ") takes " << timer.Elapsed() << " ms."; - - // 4. attention rescoring - timer.Reset(); - std::sort(batch_hyps_.begin(), batch_hyps_.end()); - std::sort(batch_pair_result_.begin(), batch_pair_result_.end(), [](auto& a, auto& b) { - return a.first < b.first; - }); + int batch_size = wavs.size(); std::vector>> batch_hyps; - for (auto& pair : batch_hyps_) { - batch_hyps.push_back(std::move(pair.second)); + if (batch_size > 1) { + batch_pair_result_.clear(); + batch_hyps_.clear(); + std::vector search_threads; + for (size_t i = 0; i < batch_size; i++) { + const auto& ctc_log_probs = batch_ctc_log_probs[i]; + std::thread thd(&BatchAsrDecoder::SearchWorker, this, ctc_log_probs, i); + search_threads.push_back(std::move(thd)); + } + for(auto& thd : search_threads) { + thd.join(); + } + std::sort(batch_hyps_.begin(), batch_hyps_.end()); + std::sort(batch_pair_result_.begin(), batch_pair_result_.end(), [](auto& a, auto& b) { + return a.first < b.first; }); + for (auto& pair : batch_hyps_) { + batch_hyps.push_back(std::move(pair.second)); + } + batch_result_.clear(); + for (auto& pair : batch_pair_result_) { + batch_result_.push_back(std::move(pair.second)); + } + } else { + // one wav + searcher_->Search(batch_ctc_log_probs[0]); + searcher_->FinalizeSearch(); + std::vector result; + UpdateResult(searcher_.get(), result); + std::lock_guard lock(mutex_); + batch_result_.push_back(std::move(result)); + const auto& hypotheses = searcher_->Inputs(); + if (hypotheses.size() < beam_size_) { + VLOG(2) << "=== searcher->Inputs() size < beam_size_, padding..."; + std::vector> hyps = hypotheses; + int to_pad = beam_size_ - hypotheses.size(); + for (size_t i = 0; i < to_pad; i++) { + std::vector pad = {0}; + hyps.push_back(std::move(pad)); + } + batch_hyps.push_back(std::move(hyps)); + } else { + batch_hyps.push_back(std::move(hypotheses)); + } } - - batch_result_.clear(); - for (auto& pair : batch_pair_result_) { - batch_result_.push_back(std::move(pair.second)); + VLOG(1) << "ctc search batch(" << batch_size << ") takes " << timer.Elapsed() << " ms."; + VLOG(1) << "1"; + std::vector> ctc_scores(batch_size); + for (int i = 0; i < batch_result_.size(); ++i) { + ctc_scores[i].resize(beam_size_); + for (int j = 0; j < beam_size_; ++j) { + ctc_scores[i][j] = batch_result_[i][j].score; + //std::cout << ctc_scores[i][j] << ", " << batch_result_[i][j].sentence << "\n"; + } + //std::cout << "==============\n"; } + // 4. 
attention rescoring + VLOG(1) << "2"; timer.Reset(); std::vector> attention_scores; - model_->AttentionRescoring(batch_hyps, opts_.reverse_weight, &attention_scores); + model_->AttentionRescoring(batch_hyps, ctc_scores, attention_scores); VLOG(1) << "attention rescoring takes " << timer.Elapsed() << " ms."; for (size_t i = 0; i < batch_size; i++) { std::vector& result = batch_result_[i]; for (size_t j = 0; j < beam_size_; j++) { - result[j].score = opts_.rescoring_weight * attention_scores[i][j] + - opts_.ctc_weight * result[j].score; + result[j].score = attention_scores[i][j]; } std::sort(result.begin(), result.end(), DecodeResult::CompareFunc); } diff --git a/runtime/core/decoder/batch_asr_decoder.h b/runtime/core/decoder/batch_asr_decoder.h index 1f487da3f..19759d140 100644 --- a/runtime/core/decoder/batch_asr_decoder.h +++ b/runtime/core/decoder/batch_asr_decoder.h @@ -87,6 +87,7 @@ class BatchAsrDecoder { const DecodeOptions& opts_; int beam_size_; const int time_stamp_gap_ = 100; // timestamp gap between words in a sentence + std::unique_ptr searcher_; public: WENET_DISALLOW_COPY_AND_ASSIGN(BatchAsrDecoder); diff --git a/runtime/core/decoder/batch_asr_model.h b/runtime/core/decoder/batch_asr_model.h index 413f35a30..c5b86aeef 100644 --- a/runtime/core/decoder/batch_asr_model.h +++ b/runtime/core/decoder/batch_asr_model.h @@ -36,8 +36,8 @@ class BatchAsrModel { batch_ctc_log_prob_t& batch_ctc_prob); virtual void AttentionRescoring(const std::vector>>& batch_hyps, - float reverse_weight, - std::vector>* attention_scores) = 0; + const std::vector>& ctc_scores, + std::vector>& attention_scores) = 0; virtual std::shared_ptr Copy() const = 0; diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc index 99e73ae16..d8bd22bd4 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.cc +++ b/runtime/core/decoder/batch_onnx_asr_model.cc @@ -243,11 +243,11 @@ void BatchOnnxAsrModel::ForwardEncoderFunc( } } - Ort::RunOptions run_option; - run_option.AddConfigEntry("kOrtRunOptionsConfigEnableMemoryArenaShrinkage", "gpu:0"); + // Ort::RunOptions run_option; + // run_option.AddConfigEntry("kOrtRunOptionsConfigEnableMemoryArenaShrinkage", "gpu:0"); timer.Reset(); std::vector ort_outputs = encoder_session_->Run( - run_option, encoder_in_names_.data(), inputs.data(), + Ort::RunOptions(nullptr), encoder_in_names_.data(), inputs.data(), inputs.size(), encoder_out_names_.data(), encoder_out_names_.size()); VLOG(1) << "\tencoder ->Run() takes " << timer.Elapsed() << " ms."; @@ -282,6 +282,7 @@ void BatchOnnxAsrModel::ForwardEncoderFunc( } // 3. cache encoder outs encoder_outs_ = std::move(ort_outputs[0]); + encoder_outs_lens_ = std::move(ort_outputs[1]); } float BatchOnnxAsrModel::ComputeAttentionScore(const float* prob, @@ -297,8 +298,8 @@ float BatchOnnxAsrModel::ComputeAttentionScore(const float* prob, void BatchOnnxAsrModel::AttentionRescoring( const std::vector>>& batch_hyps, - float reverse_weight, - std::vector>* attention_scores) { + const std::vector>& ctc_scores, + std::vector>& attention_scores) { Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); // 1. 
prepare input for onnx @@ -316,18 +317,46 @@ void BatchOnnxAsrModel::AttentionRescoring( } } - // 1.2 generate hyps_pad_sos - std::vector hyps_pad_sos(batch_size * beam_size * max_hyps_len, 0); + // 1.2 generate hyps_pad_sos_eos, r_hyps_pad_sos_eos + std::vector hyps_pad_sos_eos(batch_size * beam_size * (max_hyps_len + 1), 0); + std::vector r_hyps_pad_sos_eos(batch_size * beam_size * (max_hyps_len + 1), 0); for (size_t i = 0; i < batch_size; ++i) { for (size_t j = 0; j < beam_size; ++j) { const std::vector& hyps = batch_hyps[i][j]; - hyps_pad_sos[i * beam_size * max_hyps_len] = sos_; - for (size_t k = 0; k < hyps.size(); ++k) { - hyps_pad_sos[i * beam_size * max_hyps_len + j * max_hyps_len + k + 1] = hyps[k]; + hyps_pad_sos_eos[i * beam_size * max_hyps_len] = sos_; + size_t hyps_len = hyps.size(); + for (size_t k = 0; k < hyps_len; ++k) { + hyps_pad_sos_eos[i * beam_size * max_hyps_len + j * max_hyps_len + k + 1] = hyps[k]; + r_hyps_pad_sos_eos[i * beam_size * max_hyps_len + j * max_hyps_len + k + 1] = hyps[hyps_len - 1 - k]; } + hyps_pad_sos_eos[i * beam_size * max_hyps_len + j * max_hyps_len + hyps.size() + 1] = eos_; + r_hyps_pad_sos_eos[i * beam_size * max_hyps_len + j * max_hyps_len + hyps.size() + 1] = eos_; } } + // 1.3 ctc_scores_data + Ort::Value ctc_scores_tensor{nullptr}; + const int64_t ctc_shape[] = {batch_size, beam_size}; + if (is_fp16_) { + std::vector data(batch_size * beam_size); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < beam_size; ++j) { + int p = i * beam_size + j; + data[p] = Ort::Float16_t(_cvtss_sh(ctc_scores[i][j], 0)); + } + } + ctc_scores_tensor = std::move(Ort::Value::CreateTensor( + memory_info, data.data(), data.size(), ctc_shape, 2)); + } else { + std::vector data(batch_size * beam_size); + for (size_t i = 0; i < batch_size; ++i) { + memcpy(data.data() + i * beam_size, ctc_scores[i].data(), sizeof(float) * beam_size); + } + ctc_scores_tensor = std::move(Ort::Value::CreateTensor( + memory_info, data.data(), data.size(), ctc_shape, 2)); + } + + // 2. 
forward attetion decoder const int64_t hyps_lens_shape[] = {batch_size, beam_size}; const int64_t hyps_pad_shape[] = {batch_size, beam_size, max_hyps_len}; @@ -335,72 +364,45 @@ void BatchOnnxAsrModel::AttentionRescoring( Ort::Value hyps_lens_tensor = Ort::Value::CreateTensor( memory_info, hyps_lens_sos.data(), hyps_lens_sos.size(), hyps_lens_shape, 2); Ort::Value hyps_pad_tensor = Ort::Value::CreateTensor( - memory_info, hyps_pad_sos.data(), hyps_pad_sos.size(), hyps_pad_shape, 3); + memory_info, hyps_pad_sos_eos.data(), hyps_pad_sos_eos.size(), hyps_pad_shape, 3); + Ort::Value r_hyps_pad_tensor = Ort::Value::CreateTensor( + memory_info, r_hyps_pad_sos_eos.data(), r_hyps_pad_sos_eos.size(), hyps_pad_shape, 3); std::vector rescore_inputs; - rescore_inputs.emplace_back(std::move(encoder_outs_)); - rescore_inputs.emplace_back(std::move(hyps_pad_tensor)); - rescore_inputs.emplace_back(std::move(hyps_lens_tensor)); + rescore_inputs.push_back(std::move(encoder_outs_)); + rescore_inputs.push_back(std::move(encoder_outs_lens_)); + rescore_inputs.push_back(std::move(hyps_pad_tensor)); + rescore_inputs.push_back(std::move(hyps_lens_tensor)); + rescore_inputs.push_back(std::move(r_hyps_pad_tensor)); + rescore_inputs.push_back(std::move(ctc_scores_tensor)); - Ort::RunOptions run_option; - run_option.AddConfigEntry("kOrtRunOptionsConfigEnableMemoryArenaShrinkage", "gpu:0"); Timer timer; std::vector rescore_outputs = rescore_session_->Run( - run_option, rescore_in_names_.data(), rescore_inputs.data(), + Ort::RunOptions(nullptr), rescore_in_names_.data(), rescore_inputs.data(), rescore_inputs.size(), rescore_out_names_.data(), rescore_out_names_.size()); VLOG(1) << "decoder->Run() takes " << timer.Elapsed() << " ms."; - auto type_info = rescore_outputs[0].GetTensorTypeAndShapeInfo(); - std::vector decoder_out_shape = type_info.GetShape(); //(B, beam, T2) - float* decoder_outs_data = nullptr; - float* r_decoder_outs_data = nullptr; - std::vector decoder_outs_fp16; // for holding decoder outs data converted from fp16; - std::vector r_decoder_outs_fp16; // for holding decoder outs data converted from fp16; + auto type_info = rescore_outputs[1].GetTensorTypeAndShapeInfo(); + std::vector scores_shape = type_info.GetShape(); //(B, beam, T2) + attention_scores.resize(scores_shape[0]); if (is_fp16_) { Timer timer; - int length = decoder_out_shape[0] * decoder_out_shape[1] * decoder_out_shape[2]; - decoder_outs_fp16.resize(length); - auto outs = rescore_outputs[0].GetTensorMutableData(); - for (size_t i = 0; i < length; ++i) { - decoder_outs_fp16[i] = _cvtsh_ss(outs[i].value); - } - decoder_outs_data = decoder_outs_fp16.data(); - if (is_bidirectional_decoder_ && reverse_weight > 0) { - r_decoder_outs_fp16.reserve(length); - auto r_outs = rescore_outputs[1].GetTensorMutableData(); - for (size_t i = 0; i < length; ++i) { - r_decoder_outs_fp16[i] = _cvtsh_ss(r_outs[i].value); + int length = scores_shape[0] * scores_shape[1]; + auto outs = rescore_outputs[1].GetTensorMutableData(); + for (size_t i = 0; i < scores_shape[0]; ++i) { + attention_scores[i].resize(scores_shape[1]); + for (size_t j = 0; j < scores_shape[1]; ++j) { + attention_scores[i][j] = _cvtsh_ss(outs[i * scores_shape[1] + j].value); } - r_decoder_outs_data = r_decoder_outs_fp16.data(); } VLOG(1) << "decoder_out from fp16 to float takes " << timer.Elapsed() << " ms. 
data length " << length; } else { - decoder_outs_data = rescore_outputs[0].GetTensorMutableData(); - if (is_bidirectional_decoder_ && reverse_weight > 0) { - r_decoder_outs_data = rescore_outputs[1].GetTensorMutableData(); - } - } - - int decode_out_len = decoder_out_shape[2]; - attention_scores->clear(); - for (size_t i = 0; i < batch_size; ++i) { - std::vector Y(beam_size); - for (size_t j = 0; j < beam_size; ++j) { - const std::vector& hyp = batch_hyps[i][j]; - float score = 0.0f; - float* p = decoder_outs_data + (i * beam_size + j) * max_hyps_len * decode_out_len; - score = ComputeAttentionScore(p, hyp, eos_, decode_out_len); - float r_score = 0.0f; - if (is_bidirectional_decoder_ && reverse_weight > 0) { - std::vector r_hyp(hyp.size()); - std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin()); - p = r_decoder_outs_data + (i * beam_size +j) * max_hyps_len * decode_out_len; - r_score = ComputeAttentionScore(p, r_hyp, eos_, decode_out_len); - } - Y[j] = score * (1 - reverse_weight) + r_score * reverse_weight; + auto outs = rescore_outputs[0].GetTensorMutableData(); + for (size_t i = 0; i < scores_shape[0]; ++i) { + attention_scores[i].resize(scores_shape[1]); + memcpy(attention_scores[i].data(), outs + i * scores_shape[1], sizeof(float) * scores_shape[1]); } - attention_scores->push_back(std::move(Y)); } } diff --git a/runtime/core/decoder/batch_onnx_asr_model.h b/runtime/core/decoder/batch_onnx_asr_model.h index 7844b8eea..4359eca68 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.h +++ b/runtime/core/decoder/batch_onnx_asr_model.h @@ -41,8 +41,8 @@ class BatchOnnxAsrModel : public BatchAsrModel { BatchOnnxAsrModel(const BatchOnnxAsrModel& other); void Read(const std::string& model_dir, bool is_fp16=false); void AttentionRescoring(const std::vector>>& batch_hyps, - float reverse_weight, - std::vector>* attention_scores) override; + const std::vector>& ctc_scores, + std::vector>& attention_scores) override; std::shared_ptr Copy() const override; void GetInputOutputInfo(const std::shared_ptr& session, @@ -76,6 +76,7 @@ class BatchOnnxAsrModel : public BatchAsrModel { // cache encoder outs: [encoder_outs, encoder_outs_lens] Ort::Value encoder_outs_{nullptr}; + Ort::Value encoder_outs_lens_{nullptr}; }; } // namespace wenet diff --git a/runtime/core/decoder/batch_torch_asr_model.cc b/runtime/core/decoder/batch_torch_asr_model.cc index 5c81fc1c4..9cc2b2c1a 100644 --- a/runtime/core/decoder/batch_torch_asr_model.cc +++ b/runtime/core/decoder/batch_torch_asr_model.cc @@ -163,9 +163,8 @@ float BatchTorchAsrModel::ComputeAttentionScore(const torch::Tensor& prob, void BatchTorchAsrModel::AttentionRescoring( const std::vector>>& batch_hyps, - float reverse_weight, - std::vector>* attention_scores) { - CHECK(attention_scores != nullptr); + const std::vector>& ctc_scores, + std::vector>& attention_scores) { // Step 1: Prepare input for libtorch int batch_size = batch_hyps.size(); int beam_size = batch_hyps[0].size(); @@ -179,47 +178,49 @@ void BatchTorchAsrModel::AttentionRescoring( } } - // 1.2 add sos to hyps - torch::Tensor hyps_pad_sos = - torch::zeros({batch_size, beam_size, max_hyps_len}, torch::kLong); + // 1.2 add sos, eos to hyps, r_hyps + torch::Tensor hyps_pad_sos_eos = torch::zeros({batch_size, beam_size, max_hyps_len + 1}, torch::kLong); + torch::Tensor r_hyps_pad_sos_eos = torch::zeros({batch_size, beam_size, max_hyps_len + 1}, torch::kLong); for (size_t i = 0; i < batch_size; i++) { for (size_t j = 0; j < beam_size; j++) { const std::vector& hyp = batch_hyps[i][j]; - 
hyps_pad_sos[i][j][0] = sos_; - for (size_t k = 0; k < hyp.size(); k++) { - hyps_pad_sos[i][j][k + 1] = hyp[k]; + hyps_pad_sos_eos[i][j][0] = sos_; + r_hyps_pad_sos_eos[i][j][0] = sos_; + size_t hyps_len = hyp.size(); + for (size_t k = 0; k < hyps_len; k++) { + hyps_pad_sos_eos[i][j][k + 1] = hyp[k]; + r_hyps_pad_sos_eos[i][j][k + 1] = hyp[hyps_len - 1 - k]; } } } + // 1.3 ctc_scores_data + torch::Tensor ctc_scores_tensor = torch::zeros({batch_size, beam_size}, torch::kFloat); + for (size_t i = 0; i < batch_size; ++i) { + auto row = torch::from_blob(const_cast(ctc_scores[i].data()), + {beam_size}, torch::kFloat).clone(); + ctc_scores_tensor[i] = std::move(row); + } + // Step 2: Forward attention decoder - hyps_pad_sos = hyps_pad_sos.to(device_); + hyps_pad_sos_eos = hyps_pad_sos_eos.to(device_); hyps_lens_sos = hyps_lens_sos.to(device_); + r_hyps_pad_sos_eos = r_hyps_pad_sos_eos.to(device_); + ctc_scores_tensor = ctc_scores_tensor.to(device_); // encoder_lens_ = encoder_lens_.to(device_); // encoder_out_ = encoder_out_.to(device_); torch::NoGradGuard no_grad; - auto outputs = model_->run_method("batch_forward_attention_decoder", - encoder_out_, encoder_lens_, - hyps_pad_sos, hyps_lens_sos, - reverse_weight).toTuple()->elements(); - auto decoder_out = outputs[0].toTensor().to(at::kCPU); - auto r_decoder_out = outputs[1].toTensor().to(at::kCPU); + auto outputs = model_->run_method( + "batch_forward_attention_decoder", + encoder_out_, encoder_lens_, + hyps_pad_sos_eos, hyps_lens_sos, + r_hyps_pad_sos_eos, ctc_scores_tensor).toTuple()->elements(); + auto rescores = outputs[1].toTensor().to(at::kCPU); c10::cuda::CUDACachingAllocator::emptyCache(); - attention_scores->resize(batch_size); + attention_scores.resize(batch_size); for (size_t i = 0; i < batch_size; i++) { - (*attention_scores)[i].resize(beam_size); - for (size_t j = 0; j < beam_size; ++j) { - const std::vector& hyp = batch_hyps[i][j]; - float score = 0.0f; - score = ComputeAttentionScore(decoder_out[i * beam_size + j], hyp, eos_); - float r_score = 0.0f; - if (is_bidirectional_decoder_ && reverse_weight > 0) { - std::vector r_hyp(hyp.size()); - std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin()); - r_score = ComputeAttentionScore(r_decoder_out[i * beam_size + j], r_hyp, eos_); - } - (*attention_scores)[i][j] = score * (1 - reverse_weight) + r_score * reverse_weight; - } + attention_scores[i].resize(beam_size); + memcpy(attention_scores[i].data(), rescores[i].data_ptr(), sizeof(float) * beam_size); } } diff --git a/runtime/core/decoder/batch_torch_asr_model.h b/runtime/core/decoder/batch_torch_asr_model.h index e081690ef..b7abfe443 100644 --- a/runtime/core/decoder/batch_torch_asr_model.h +++ b/runtime/core/decoder/batch_torch_asr_model.h @@ -37,10 +37,11 @@ class BatchTorchAsrModel : public BatchAsrModel { using TorchModule = torch::jit::script::Module; BatchTorchAsrModel() = default; BatchTorchAsrModel(const BatchTorchAsrModel& other); - void Read(const std::string& model_path, bool is_fp16=false); - void AttentionRescoring(const std::vector>>& batch_hyps, - float reverse_weight, - std::vector>* attention_scores) override; + void Read(const std::string& model_path); + void AttentionRescoring( + const std::vector>>& batch_hyps, + const std::vector>& ctc_scores, + std::vector>& attention_scores) override; std::shared_ptr Copy() const override; protected: diff --git a/runtime/core/decoder/params.h b/runtime/core/decoder/params.h index 9acdac9f9..7de47dad8 100644 --- a/runtime/core/decoder/params.h +++ 
b/runtime/core/decoder/params.h @@ -149,7 +149,7 @@ std::shared_ptr InitDecodeResourceFromFlags() { LOG(INFO) << "BatchTorchAsrModel Reading torch model " << FLAGS_model_path; BatchTorchAsrModel::InitEngineThreads(FLAGS_num_threads); auto model = std::make_shared(); - model->Read(FLAGS_model_path, FLAGS_is_fp16); + model->Read(FLAGS_model_path); resource->batch_model = model; } else { LOG(INFO) << "Reading torch model " << FLAGS_model_path; From a71815d456953b838b802965595bf54d4f31cbba Mon Sep 17 00:00:00 2001 From: veelion Date: Wed, 31 Aug 2022 13:08:19 +0800 Subject: [PATCH 30/62] let decoder return score --- wenet/bin/export_onnx_gpu.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py index 14f107d5e..fa6a69c6e 100644 --- a/wenet/bin/export_onnx_gpu.py +++ b/wenet/bin/export_onnx_gpu.py @@ -222,6 +222,7 @@ def forward(self, r_decoder_out: B x beam x T2 x V best_index: B """ + print('self.reverse_weight ', self.reverse_weight, 'self.ctc_weight ', self.ctc_weight) B, T, F = encoder_out.shape bz = self.beam_size B2 = B * bz @@ -262,7 +263,7 @@ def forward(self, score = torch.sum(score, axis=1) # B2 score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score best_index = torch.argmax(score, dim=1) - return best_index + return best_index, score def to_numpy(tensors): @@ -500,7 +501,7 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path): ort_outs = ort_session.run(None, ort_inputs) # check decoder output - test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05) + test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05) logger.info("export to onnx decoder succeed!") From 2bfc65cb88178bbd339d54aa3dc592213f89a907 Mon Sep 17 00:00:00 2001 From: veelion Date: Wed, 31 Aug 2022 13:09:17 +0800 Subject: [PATCH 31/62] let attention decoder return score --- wenet/transformer/asr_model.py | 93 ++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index cc31114c1..b38e24d7c 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -752,60 +752,65 @@ def batch_forward_encoder( @torch.jit.export def batch_forward_attention_decoder( - self, - encoder_out: torch.Tensor, - encoder_lens: torch.Tensor, - hyps_pad_sos: torch.Tensor, - hyps_lens_sos: torch.Tensor, - reverse_weight: float = 0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Export interface for c++ call, forward decoder with batch of - hypothesis from ctc prefix beam search and encoder output + self, + encoder_out: torch.Tensor, + encoder_lens: torch.Tensor, + hyps_pad_sos_eos: torch.Tensor, + hyps_lens_sos: torch.Tensor, + r_hyps_pad_sos_eos: torch.Tensor, + ctc_score: torch.Tensor): + """Decoder Args: encoder_out: B x T x F encoder_lens: B - hyps_pad_sos: B x beam x T2 - hyps with sos and padded by 0 + hyps_pad_sos_eos: B x beam x (T2+1), + hyps with sos & eos and padded by ignore id hyps_lens_sos: B x beam, length for each hyp with sos - reverse_weight: used for verfing whether used right to left decoder, - > 0 will use. 
- + r_hyps_pad_sos_eos: B x beam x (T2+1), + reversed hyps with sos & eos and padded by ignore id + ctc_score: B x beam, ctc score for each hyp Returns: - scores: (B, beam) + best_index: B + score: B x beam """ - assert encoder_out.size(0) == hyps_pad_sos.size(0) B, T, F = encoder_out.shape - bz = hyps_pad_sos.shape[1] + bz = hyps_pad_sos_eos.shape[1] # beam_size B2 = B * bz - T2 = hyps_pad_sos.shape[2] - # 1. prepare inputs for decoder encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F) - encoder_mask = torch.ones(B2, 1, T, - dtype=torch.bool, - device=encoder_out.device) - # input for right to left decoder - # this hyps_lens has count token, we need minus it. - hyps = hyps_pad_sos.view(B2, T2) + encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1) + encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T) + T2 = hyps_pad_sos_eos.shape[2] - 1 + hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1) hyps_lens = hyps_lens_sos.view(B2,) - if reverse_weight > 0: - r_hyps_lens = hyps_lens - 1 - r_hyps = hyps[:, 1:] - max_len = torch.max(r_hyps_lens) - index_range = torch.arange(0, max_len, 1).to(encoder_out.device) - seq_len_expand = r_hyps_lens.unsqueeze(1) - seq_mask = seq_len_expand > index_range # (beam, max_len) - index = (seq_len_expand - 1) - index_range # (beam, max_len) - index = index * seq_mask - r_hyps = torch.gather(r_hyps, 1, index) - r_hyps = torch.where(seq_mask, r_hyps, self.eos) - r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1) - else: - r_hyps = torch.empty(0, device=encoder_out.device) + hyps_pad_sos = hyps_pad[:, :-1].contiguous() + hyps_pad_eos = hyps_pad[:, 1:].contiguous() + + r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1) + r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous() + r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous() - # 2. 
decoding decoder_out, r_decoder_out, _ = self.decoder( - encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, - reverse_weight) + encoder_out, encoder_mask, hyps_pad_sos, hyps_lens, r_hyps_pad_sos, + self.reverse_weight) decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) - r_decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1) - return decoder_out, r_decoder_out + V = decoder_out.shape[-1] + decoder_out = decoder_out.view(B2, T2, V) + mask = ~make_pad_mask(hyps_lens, T2) # B2 x T2 + # mask index, remove ignore id + index = torch.unsqueeze(hyps_pad_eos * mask, 2) + score = decoder_out.gather(2, index).squeeze(2) # B2 X T2 + # mask padded part + score = score * mask + decoder_out = decoder_out.view(B, bz, T2, V) + if self.reverse_weight > 0: + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + r_decoder_out = r_decoder_out.view(B2, T2, V) + index = torch.unsqueeze(r_hyps_pad_eos * mask, 2) + r_score = r_decoder_out.gather(2, index).squeeze(2) + r_score = r_score * mask + score = score * (1 - self.reverse_weight) + self.reverse_weight * r_score + r_decoder_out = r_decoder_out.view(B, bz, T2, V) + score = torch.sum(score, dim=1) # B2 + score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score + best_index = torch.argmax(score, dim=1) + return best_index, score From 0544b95648968ab369d5bea8c1a421fcaf49ebed Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 1 Sep 2022 14:13:53 +0800 Subject: [PATCH 32/62] remove lock if only one wav --- runtime/core/decoder/batch_asr_decoder.cc | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/runtime/core/decoder/batch_asr_decoder.cc b/runtime/core/decoder/batch_asr_decoder.cc index 8fad0867f..4506e3734 100644 --- a/runtime/core/decoder/batch_asr_decoder.cc +++ b/runtime/core/decoder/batch_asr_decoder.cc @@ -55,6 +55,7 @@ BatchAsrDecoder::BatchAsrDecoder(std::shared_ptr config, void BatchAsrDecoder::Reset() { batch_result_.clear(); + searcher_->Reset(); } void BatchAsrDecoder::SearchWorker(const ctc_log_prob_t& ctc_log_probs, int index) { @@ -71,11 +72,10 @@ void BatchAsrDecoder::SearchWorker(const ctc_log_prob_t& ctc_log_probs, int inde ctc_timer.Reset(); searcher->Search(ctc_log_probs); searcher->FinalizeSearch(); - VLOG(1) << "\tctc search i==" << index << " takes " << ctc_timer.Elapsed() << " ms"; - ctc_timer.Reset(); std::vector result; UpdateResult(searcher.get(), result); - std::lock_guard lock(mutex_); + VLOG(1) << "\tctc search i==" << index << " takes " << ctc_timer.Elapsed() << " ms"; + std::lock_guard lock(mutex_); batch_pair_result_.emplace_back(std::make_pair(index, std::move(result))); const auto& hypotheses = searcher->Inputs(); if (hypotheses.size() < beam_size_) { @@ -135,8 +135,8 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { VLOG(1) << "feature Compute takes " << timer.Elapsed() << " ms."; // 1.1 feature padding - timer.Reset(); if (wavs.size() > 1) { + timer.Reset(); int max_len = *std::max_element(batch_feats_lens.begin(), batch_feats_lens.end()); for (auto& feat : batch_feats) { if (feat.size() == max_len) continue; @@ -184,11 +184,11 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { } } else { // one wav + VLOG(1) << "=== ctc search for one wav! 
" << batch_ctc_log_probs[0].size(); searcher_->Search(batch_ctc_log_probs[0]); searcher_->FinalizeSearch(); std::vector result; UpdateResult(searcher_.get(), result); - std::lock_guard lock(mutex_); batch_result_.push_back(std::move(result)); const auto& hypotheses = searcher_->Inputs(); if (hypotheses.size() < beam_size_) { @@ -205,18 +205,14 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { } } VLOG(1) << "ctc search batch(" << batch_size << ") takes " << timer.Elapsed() << " ms."; - VLOG(1) << "1"; std::vector> ctc_scores(batch_size); for (int i = 0; i < batch_result_.size(); ++i) { ctc_scores[i].resize(beam_size_); for (int j = 0; j < beam_size_; ++j) { ctc_scores[i][j] = batch_result_[i][j].score; - //std::cout << ctc_scores[i][j] << ", " << batch_result_[i][j].sentence << "\n"; } - //std::cout << "==============\n"; } // 4. attention rescoring - VLOG(1) << "2"; timer.Reset(); std::vector> attention_scores; model_->AttentionRescoring(batch_hyps, ctc_scores, attention_scores); From ed673b047ada5a7f130f1a52df024996fe9149ae Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 1 Sep 2022 14:14:26 +0800 Subject: [PATCH 33/62] add gpu_id flag for BatchOnnxAsrModel --- runtime/core/decoder/params.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/runtime/core/decoder/params.h b/runtime/core/decoder/params.h index 7de47dad8..330b25e51 100644 --- a/runtime/core/decoder/params.h +++ b/runtime/core/decoder/params.h @@ -93,6 +93,7 @@ DEFINE_int32(language_type, 0, DEFINE_bool(lowercase, true, "lowercase final result if needed"); DEFINE_bool(run_batch, false, "run websocket server for batch decoding"); DEFINE_bool(is_fp16, false, "the model is of fp16"); +DEFINE_int32(gpu_id, 0, "which GPU to use"); namespace wenet { std::shared_ptr InitFeaturePipelineConfigFromFlags() { @@ -131,7 +132,7 @@ std::shared_ptr InitDecodeResourceFromFlags() { LOG(INFO) << "BatchOnnxAsrModel Reading ONNX model dir: " << FLAGS_onnx_dir; BatchOnnxAsrModel::InitEngineThreads(FLAGS_num_threads); auto model = std::make_shared(); - model->Read(FLAGS_onnx_dir, FLAGS_is_fp16); + model->Read(FLAGS_onnx_dir, FLAGS_is_fp16, FLAGS_gpu_id); resource->batch_model = model; } else { LOG(INFO) << "Reading onnx model "; From d9d676c3d15f51505d86980c10a0d973682ff2ab Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 1 Sep 2022 14:14:56 +0800 Subject: [PATCH 34/62] add gpu_id flag for BatchOnnxAsrModel --- runtime/core/decoder/batch_onnx_asr_model.h | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/runtime/core/decoder/batch_onnx_asr_model.h b/runtime/core/decoder/batch_onnx_asr_model.h index 4359eca68..64be657f4 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.h +++ b/runtime/core/decoder/batch_onnx_asr_model.h @@ -39,7 +39,7 @@ class BatchOnnxAsrModel : public BatchAsrModel { public: BatchOnnxAsrModel() = default; BatchOnnxAsrModel(const BatchOnnxAsrModel& other); - void Read(const std::string& model_dir, bool is_fp16=false); + void Read(const std::string& model_dir, bool is_fp16=false, int gpu_id=0); void AttentionRescoring(const std::vector>>& batch_hyps, const std::vector>& ctc_scores, std::vector>& attention_scores) override; @@ -55,9 +55,6 @@ class BatchOnnxAsrModel : public BatchAsrModel { const std::vector& batch_feats_lens, batch_ctc_log_prob_t& batch_ctc_log_prob) override; - float ComputeAttentionScore(const float* prob, const std::vector& hyp, - int eos, int decode_out_len); - private: int encoder_output_size_ = 0; bool is_fp16_ = false; From 
6afce87b57891b1edc5a4b23a21aac44e5f7cbf2 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 1 Sep 2022 14:16:10 +0800 Subject: [PATCH 35/62] fix memory issue of CreateTensor() --- runtime/core/decoder/batch_onnx_asr_model.cc | 70 ++++++++++---------- 1 file changed, 34 insertions(+), 36 deletions(-) diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc index d8bd22bd4..2b2b3b649 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.cc +++ b/runtime/core/decoder/batch_onnx_asr_model.cc @@ -79,7 +79,7 @@ void BatchOnnxAsrModel::GetInputOutputInfo( } } -void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16) { +void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16, int gpu_id) { is_fp16_ = is_fp16; VLOG(1) << "is_fp16_ " << is_fp16_; std::vector providers = Ort::GetAvailableProviders(); @@ -104,6 +104,7 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16) { // 1. Load sessions // config for CUDA + std::string device_id = std::to_string(gpu_id); std::vector keys{ "device_id", // "gpu_mem_limit", @@ -111,16 +112,16 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16) { "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", - "cudnn_conv1d_pad_to_nc1d" + // "cudnn_conv1d_pad_to_nc1d" // supported from 1.12.0 }; std::vector values{ - "0", + device_id.data(), // "2147483648", "kSameAsRequested", "DEFAULT", "1", "1", - "1" + //"1" }; // release GPU memory: https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 @@ -198,33 +199,42 @@ void BatchOnnxAsrModel::ForwardEncoderFunc( int batch_size = batch_feats.size(); int num_frames = batch_feats[0].size(); int feature_dim = batch_feats[0][0].size(); + + // generate data for CreateTensor Ort::Value feats_ort{nullptr}; + // https://github.com/microsoft/onnxruntime/issues/9629#issuecomment-963828881 + // Ort::Value::CreateTensor does NOT copy the data + std::vector feats_fp16; // for holding feats of fp16 + std::vector feats_fp32; // for holding feats of float // speech const int64_t feats_shape[3] = {batch_size, num_frames, feature_dim}; Timer timer; if (is_fp16_) { - std::vector feats(batch_size * num_frames * feature_dim); + feats_fp16.resize(batch_size * num_frames * feature_dim); for (size_t i = 0; i < batch_size; ++i) { for (size_t j = 0; j < num_frames; ++j) { for (size_t k = 0; k < feature_dim; ++k) { int p = i * num_frames * feature_dim + j * feature_dim + k; - feats[p] = Ort::Float16_t(_cvtss_sh(batch_feats[i][j][k], 0)); + feats_fp16[p] = Ort::Float16_t(_cvtss_sh(batch_feats[i][j][k], 0)); } } } - feats_ort = std::move(Ort::Value::CreateTensor( - memory_info, feats.data(), feats.size(), feats_shape, 3)); + auto tensor = Ort::Value::CreateTensor( + memory_info, + feats_fp16.data(), + feats_fp16.size(), + feats_shape, 3); + feats_ort = std::move(tensor); VLOG(1) << "feats to fp16 takes " << timer.Elapsed() << " ms."; } else { - std::vector feats; for (size_t i = 0; i < batch_size; ++i) { for (size_t j = 0; j < num_frames; ++j) { - feats.insert(feats.end(), batch_feats[i][j].begin(), batch_feats[i][j].end()); + feats_fp32.insert(feats_fp32.end(), batch_feats[i][j].begin(), batch_feats[i][j].end()); } } feats_ort = std::move(Ort::Value::CreateTensor( - memory_info, feats.data(), feats.size(), feats_shape, 3)); + memory_info, feats_fp32.data(), feats_fp32.size(), feats_shape, 3)); } // speech_lens @@ -243,11 +253,9 @@ void BatchOnnxAsrModel::ForwardEncoderFunc( } } - // 
Ort::RunOptions run_option; - // run_option.AddConfigEntry("kOrtRunOptionsConfigEnableMemoryArenaShrinkage", "gpu:0"); timer.Reset(); std::vector ort_outputs = encoder_session_->Run( - Ort::RunOptions(nullptr), encoder_in_names_.data(), inputs.data(), + Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(), inputs.size(), encoder_out_names_.data(), encoder_out_names_.size()); VLOG(1) << "\tencoder ->Run() takes " << timer.Elapsed() << " ms."; @@ -259,11 +267,11 @@ void BatchOnnxAsrModel::ForwardEncoderFunc( std::vector ctc_log_probs_data; // for holding ctc_log_probs converted from fp16 if (is_fp16_) { timer.Reset(); - Ort::Float16_t* probs = ort_outputs[2].GetTensorMutableData(); + auto probs = ort_outputs[2].GetTensorMutableData(); int length = out_shape[0] * out_shape[1] * out_shape[2]; ctc_log_probs_data.resize(length); for (size_t i = 0; i < length; ++i) { - ctc_log_probs_data[i] = _cvtsh_ss(probs[i].value); + ctc_log_probs_data[i] = _cvtsh_ss(probs[i]); } ctc_log_probs = ctc_log_probs_data.data(); VLOG(1) << "ctc_log_probs from GPU-fp16 to float takes " << timer.Elapsed() << " ms. data lenght " << length; @@ -285,17 +293,6 @@ void BatchOnnxAsrModel::ForwardEncoderFunc( encoder_outs_lens_ = std::move(ort_outputs[1]); } -float BatchOnnxAsrModel::ComputeAttentionScore(const float* prob, - const std::vector& hyp, int eos, - int decode_out_len) { - float score = 0.0f; - for (size_t j = 0; j < hyp.size(); ++j) { - score += *(prob + j * decode_out_len + hyp[j]); - } - score += *(prob + hyp.size() * decode_out_len + eos); - return score; -} - void BatchOnnxAsrModel::AttentionRescoring( const std::vector>>& batch_hyps, const std::vector>& ctc_scores, @@ -336,27 +333,28 @@ void BatchOnnxAsrModel::AttentionRescoring( // 1.3 ctc_scores_data Ort::Value ctc_scores_tensor{nullptr}; + std::vector ctc_fp16; + std::vector ctc_fp32; const int64_t ctc_shape[] = {batch_size, beam_size}; if (is_fp16_) { - std::vector data(batch_size * beam_size); + ctc_fp16.resize(batch_size * beam_size); for (size_t i = 0; i < batch_size; ++i) { for (size_t j = 0; j < beam_size; ++j) { int p = i * beam_size + j; - data[p] = Ort::Float16_t(_cvtss_sh(ctc_scores[i][j], 0)); + ctc_fp16[p] = Ort::Float16_t(_cvtss_sh(ctc_scores[i][j], 0)); } } ctc_scores_tensor = std::move(Ort::Value::CreateTensor( - memory_info, data.data(), data.size(), ctc_shape, 2)); + memory_info, ctc_fp16.data(), ctc_fp16.size(), ctc_shape, 2)); } else { - std::vector data(batch_size * beam_size); + ctc_fp32.resize(batch_size * beam_size); for (size_t i = 0; i < batch_size; ++i) { - memcpy(data.data() + i * beam_size, ctc_scores[i].data(), sizeof(float) * beam_size); + memcpy(ctc_fp32.data() + i * beam_size, ctc_scores[i].data(), sizeof(float) * beam_size); } ctc_scores_tensor = std::move(Ort::Value::CreateTensor( - memory_info, data.data(), data.size(), ctc_shape, 2)); + memory_info, ctc_fp32.data(), ctc_fp32.size(), ctc_shape, 2)); } - // 2. 
forward attetion decoder const int64_t hyps_lens_shape[] = {batch_size, beam_size}; const int64_t hyps_pad_shape[] = {batch_size, beam_size, max_hyps_len}; @@ -383,8 +381,8 @@ void BatchOnnxAsrModel::AttentionRescoring( rescore_out_names_.size()); VLOG(1) << "decoder->Run() takes " << timer.Elapsed() << " ms."; - auto type_info = rescore_outputs[1].GetTensorTypeAndShapeInfo(); - std::vector scores_shape = type_info.GetShape(); //(B, beam, T2) + //(B, beam, T2) + auto scores_shape = rescore_outputs[1].GetTensorTypeAndShapeInfo().GetShape(); attention_scores.resize(scores_shape[0]); if (is_fp16_) { Timer timer; From 8ea5a3e70d279ad328c5e0b140d2f9b04ece57d7 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 1 Sep 2022 14:18:13 +0800 Subject: [PATCH 36/62] add decoder_main_batch --- runtime/core/bin/CMakeLists.txt | 3 + runtime/core/bin/decoder_main_batch.cc | 82 ++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 runtime/core/bin/decoder_main_batch.cc diff --git a/runtime/core/bin/CMakeLists.txt b/runtime/core/bin/CMakeLists.txt index 9d96708e2..daafe05a8 100644 --- a/runtime/core/bin/CMakeLists.txt +++ b/runtime/core/bin/CMakeLists.txt @@ -1,6 +1,9 @@ add_executable(decoder_main decoder_main.cc) target_link_libraries(decoder_main PUBLIC decoder) +add_executable(decoder_main_batch decoder_main_batch.cc) +target_link_libraries(decoder_main_batch PUBLIC decoder) + add_executable(label_checker_main label_checker_main.cc) target_link_libraries(label_checker_main PUBLIC decoder) diff --git a/runtime/core/bin/decoder_main_batch.cc b/runtime/core/bin/decoder_main_batch.cc new file mode 100644 index 000000000..f9a2efb17 --- /dev/null +++ b/runtime/core/bin/decoder_main_batch.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
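+// Usage sketch (hypothetical invocation, not part of this patch; the flag
+// names below come from this file and decoder/params.h, but the exact set
+// depends on how the runtime is built):
+//   libtorch:  ./decoder_main_batch --wav_path test.wav --batch_size 8 --model_path final.zip
+//   onnx GPU:  ./decoder_main_batch --wav_path test.wav --batch_size 8 --onnx_dir model_dir --is_fp16=true --gpu_id 0
+// The single input wav is replicated batch_size times, so the reported RTF
+// reflects batch-decoding throughput on one utterance.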
+ +#include +#include +#include + +#include "decoder/params.h" +#include "frontend/wav.h" +#include "utils/flags.h" +#include "utils/string.h" +#include "utils/timer.h" +#include "utils/utils.h" + +DEFINE_string(wav_path, "", "single wave path"); +DEFINE_int32(thread_num, 1, "num of decode thread"); +DEFINE_int32(batch_size, 1, "batch size of input"); + +std::shared_ptr g_decode_config; +std::shared_ptr g_feature_config; +std::shared_ptr g_decode_resource; + +int g_total_waves_dur = 0; +int g_total_decode_time = 0; + +// using namespace wenet; + +void decode(const std::string& wav) { + wenet::WavReader wav_reader(wav); + std::vector wav_data; + int num_samples = wav_reader.num_samples(); + wav_data.insert(wav_data.end(), wav_reader.data(), wav_reader.data() + num_samples); + std::vector> batch_wav_data; + int wav_dur = static_cast(static_cast(num_samples) / wav_reader.sample_rate() * 1000); + for (int i = 0; i < FLAGS_batch_size; ++i) { + batch_wav_data.push_back(wav_data); + g_total_waves_dur += wav_dur; + } + + std::unique_ptr decoder = + std::make_unique(g_feature_config, g_decode_resource, *g_decode_config); + wenet::Timer timer; + decoder->Decode(batch_wav_data); + int decode_time = timer.Elapsed(); + std::string result = decoder->get_batch_result(1, false); + std::cout << result << std::endl; + + LOG(INFO) << "batch_size : " << FLAGS_batch_size << std::endl; + LOG(INFO) << "Total: decoded " << g_total_waves_dur << "ms audio taken " + << decode_time << "ms."; + LOG(INFO) << "RTF: " << std::setprecision(4) + << static_cast(decode_time) / g_total_waves_dur; +} + + +int main(int argc, char* argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, false); + google::InitGoogleLogging(argv[0]); + + g_decode_config = wenet::InitDecodeOptionsFromFlags(); + g_feature_config = wenet::InitFeaturePipelineConfigFromFlags(); + g_decode_resource = wenet::InitDecodeResourceFromFlags(); + + if (FLAGS_wav_path.empty()) { + LOG(FATAL) << "Please provide the wave path."; + } + LOG(INFO) << "decoding " << FLAGS_wav_path; + decode(FLAGS_wav_path); + + return 0; +} From b3d82aac1ee4cd45aa70579c125ae728e15e085c Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 1 Sep 2022 14:19:53 +0800 Subject: [PATCH 37/62] re-link --- runtime/onnxruntime/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/onnxruntime/CMakeLists.txt b/runtime/onnxruntime/CMakeLists.txt index 38174afb1..492b84242 120000 --- a/runtime/onnxruntime/CMakeLists.txt +++ b/runtime/onnxruntime/CMakeLists.txt @@ -1 +1 @@ -../LibTorch/CMakeLists.txt \ No newline at end of file +../libtorch/CMakeLists.txt \ No newline at end of file From b6d5cdf923369eae03498c78ddb964a4f431d38e Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 1 Sep 2022 16:45:58 +0800 Subject: [PATCH 38/62] config cudnn_conv1d_pad_to_nc1d --- runtime/core/decoder/batch_onnx_asr_model.cc | 24 ++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc index 2b2b3b649..fec9123c3 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.cc +++ b/runtime/core/decoder/batch_onnx_asr_model.cc @@ -23,6 +23,7 @@ #include #include +#include "glog/logging.h" #include "utils/string.h" #include "utils/Yaml.hpp" #include "utils/timer.h" @@ -112,7 +113,7 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16, int gpu "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", - // "cudnn_conv1d_pad_to_nc1d" 
// supported from 1.12.0 + "cudnn_conv1d_pad_to_nc1d" // supported from 1.12.0 }; std::vector values{ device_id.data(), // "2147483648", "kSameAsRequested", "DEFAULT", "1", "1", - //"1" + "1" }; // release GPU memory: https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 @@ -367,6 +368,25 @@ void BatchOnnxAsrModel::AttentionRescoring( memory_info, r_hyps_pad_sos_eos.data(), r_hyps_pad_sos_eos.size(), hyps_pad_shape, 3); std::vector rescore_inputs; + /* + for (auto name : rescore_in_names_) { + if (!strcmp(name, "encoder_out")) { + rescore_inputs.push_back(std::move(encoder_outs_)); + } else if (!strcmp(name, "encoder_out_lens")) { + rescore_inputs.push_back(std::move(encoder_outs_lens_)); + } else if (!strcmp(name, "hyps_pad_sos_eos")) { + rescore_inputs.push_back(std::move(hyps_pad_tensor)); + } else if (!strcmp(name, "hyps_lens_sos")) { + rescore_inputs.push_back(std::move(hyps_lens_tensor)); + } else if (!strcmp(name, "r_hyps_pad_sos_eos")) { + rescore_inputs.push_back(std::move(r_hyps_pad_tensor)); + } else if (!strcmp(name, "ctc_score")) { + rescore_inputs.push_back(std::move(ctc_scores_tensor)); + } else { + VLOG(1) << "invalid input name " << name; + } + } + */ rescore_inputs.push_back(std::move(encoder_outs_)); rescore_inputs.push_back(std::move(encoder_outs_lens_)); rescore_inputs.push_back(std::move(hyps_pad_tensor)); rescore_inputs.push_back(std::move(hyps_lens_tensor)); rescore_inputs.push_back(std::move(r_hyps_pad_tensor)); rescore_inputs.push_back(std::move(ctc_scores_tensor)); From 38e983d35d2ff12fd1eb3699047ef3c740e80bea Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 1 Sep 2022 17:53:40 +0800 Subject: [PATCH 39/62] use encoder's output beam_log_probs/index (topk) for ctc_search, which saves a lot of memory-conversion time when is_fp16, compared with ctc_log_probs --- runtime/core/decoder/CMakeLists.txt | 1 - runtime/core/decoder/batch_asr_decoder.cc | 20 ++-- runtime/core/decoder/batch_asr_decoder.h | 5 +- runtime/core/decoder/batch_asr_model.cc | 7 +- runtime/core/decoder/batch_asr_model.h | 8 +- runtime/core/decoder/batch_onnx_asr_model.cc | 56 ++++++--- runtime/core/decoder/batch_onnx_asr_model.h | 7 +- .../core/decoder/ctc_prefix_beam_search.cc | 106 ++++++++++++++++++ runtime/core/decoder/ctc_prefix_beam_search.h | 2 + runtime/core/decoder/ctc_wfst_beam_search.h | 2 + runtime/core/decoder/search_interface.h | 2 + 11 files changed, 176 insertions(+), 40 deletions(-) diff --git a/runtime/core/decoder/CMakeLists.txt b/runtime/core/decoder/CMakeLists.txt index 2a5b89cf4..bd11940cb 100644 --- a/runtime/core/decoder/CMakeLists.txt +++ b/runtime/core/decoder/CMakeLists.txt @@ -6,7 +6,6 @@ set(decoder_srcs ctc_wfst_beam_search.cc ctc_endpoint.cc batch_asr_decoder.cc - batch_asr_model.cc ) if(NOT TORCH AND NOT ONNX) diff --git a/runtime/core/decoder/batch_asr_decoder.cc b/runtime/core/decoder/batch_asr_decoder.cc index 4506e3734..7a96728bf 100644 --- a/runtime/core/decoder/batch_asr_decoder.cc +++ b/runtime/core/decoder/batch_asr_decoder.cc @@ -58,7 +58,10 @@ void BatchAsrDecoder::Reset() { searcher_->Reset(); } -void BatchAsrDecoder::SearchWorker(const ctc_log_prob_t& ctc_log_probs, int index) { +void BatchAsrDecoder::SearchWorker( + const std::vector>& topk_scores, + const std::vector>& topk_indexs, + int index) { Timer ctc_timer; std::unique_ptr searcher; if (nullptr == fst_) { @@ -71,11 +72,10 @@ void BatchAsrDecoder::SearchWorker(const ctc_log_prob_t& ctc_log_probs, int inde } // 3.1.
ctc search ctc_timer.Reset(); - searcher->Search(ctc_log_probs); + searcher->Search(topk_scores, topk_indexs); searcher->FinalizeSearch(); std::vector result; UpdateResult(searcher.get(), result); @@ -151,8 +154,9 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { // 2. encoder forward timer.Reset(); - batch_ctc_log_prob_t batch_ctc_log_probs; - model_->ForwardEncoder(batch_feats, batch_feats_lens, batch_ctc_log_probs); + std::vector>> batch_topk_scores; + std::vector>> batch_topk_indexs; + model_->ForwardEncoder(batch_feats, batch_feats_lens, batch_topk_scores, batch_topk_indexs); VLOG(1) << "encoder forward takes " << timer.Elapsed() << " ms."; // 3. ctc search one by one of the batch @@ -165,8 +169,9 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { batch_hyps_.clear(); std::vector search_threads; for (size_t i = 0; i < batch_size; i++) { - const auto& ctc_log_probs = batch_ctc_log_probs[i]; - std::thread thd(&BatchAsrDecoder::SearchWorker, this, ctc_log_probs, i); + const auto& topk_scores = batch_topk_scores[i]; + const auto& topk_indexs = batch_topk_indexs[i]; + std::thread thd(&BatchAsrDecoder::SearchWorker, this, topk_scores, topk_indexs, i); search_threads.push_back(std::move(thd)); } for(auto& thd : search_threads) { @@ -184,8 +189,7 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { } } else { // one wav - VLOG(1) << "=== ctc search for one wav! " << batch_ctc_log_probs[0].size(); - searcher_->Search(batch_ctc_log_probs[0]); + searcher_->Search(batch_topk_scores[0], batch_topk_indexs[0]); searcher_->FinalizeSearch(); std::vector result; UpdateResult(searcher_.get(), result); diff --git a/runtime/core/decoder/batch_asr_decoder.h b/runtime/core/decoder/batch_asr_decoder.h index 19759d140..d144ff98a 100644 --- a/runtime/core/decoder/batch_asr_decoder.h +++ b/runtime/core/decoder/batch_asr_decoder.h @@ -66,7 +66,10 @@ class BatchAsrDecoder { std::vector> batch_feats_; // for FbankWorker std::vector> batch_feats_lens_; // for FbankWorker - void SearchWorker(const ctc_log_prob_t& ctc_log_probs, int index); + void SearchWorker( + const std::vector>& topk_scores, + const std::vector>& topk_indexs, + int index); std::mutex mutex_; std::vector>>> batch_hyps_; // for SearchWorker std::vector>> batch_pair_result_; // for SearchWorker diff --git a/runtime/core/decoder/batch_asr_model.cc b/runtime/core/decoder/batch_asr_model.cc index 977f17eff..263113b59 100644 --- a/runtime/core/decoder/batch_asr_model.cc +++ b/runtime/core/decoder/batch_asr_model.cc @@ -11,12 +11,13 @@ namespace wenet { void BatchAsrModel::ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, - batch_ctc_log_prob_t& batch_ctc_prob) { - batch_ctc_prob.clear(); + std::vector>>& batch_topk_scores, + std::vector>>& batch_topk_indexs) { this->ForwardEncoderFunc( batch_feats, batch_feats_lens, - batch_ctc_prob); + batch_topk_scores, + batch_topk_indexs); } } // namespace wenet diff --git a/runtime/core/decoder/batch_asr_model.h b/runtime/core/decoder/batch_asr_model.h index c5b86aeef..aefbff8a6 100644 --- a/runtime/core/decoder/batch_asr_model.h +++ b/runtime/core/decoder/batch_asr_model.h @@ -33,7 +33,8 @@ class BatchAsrModel { virtual void ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, - batch_ctc_log_prob_t& batch_ctc_prob); + std::vector>>& batch_topk_scores, + std::vector>>& batch_topk_indexs) = 0; virtual void AttentionRescoring(const std::vector>>& batch_hyps, const std::vector>& ctc_scores, @@ -42,11 +43,6 @@ class 
BatchAsrModel { virtual std::shared_ptr Copy() const = 0; protected: - virtual void ForwardEncoderFunc( - const batch_feature_t& batch_feats, - const std::vector& batch_feats_lens, - batch_ctc_log_prob_t& batch_ctc_prob) = 0; - int right_context_ = 1; int subsampling_rate_ = 1; int sos_ = 0; diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc index fec9123c3..2c03f5da7 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.cc +++ b/runtime/core/decoder/batch_onnx_asr_model.cc @@ -190,10 +190,11 @@ std::shared_ptr BatchOnnxAsrModel::Copy() const { return asr_model; } -void BatchOnnxAsrModel::ForwardEncoderFunc( +void BatchOnnxAsrModel::ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, - batch_ctc_log_prob_t& out_prob) { + std::vector>>& batch_topk_scores, + std::vector>>& batch_topk_indexs) { Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); // 1. Prepare onnx required data @@ -260,33 +261,54 @@ void BatchOnnxAsrModel::ForwardEncoderFunc( inputs.size(), encoder_out_names_.data(), encoder_out_names_.size()); VLOG(1) << "\tencoder ->Run() takes " << timer.Elapsed() << " ms."; - float* ctc_log_probs = nullptr; - auto type_info = ort_outputs[2].GetTensorTypeAndShapeInfo(); - auto out_shape = type_info.GetShape(); + // get topk_scores + auto out_shape = ort_outputs[3].GetTensorTypeAndShapeInfo().GetShape(); int num_outputs = out_shape[1]; int output_dim = out_shape[2]; - std::vector ctc_log_probs_data; // for holding ctc_log_probs converted from fp16 + float* topk_scores_ptr = nullptr; + std::vector topk_scores_data; // for holding topk_scores converted from fp16 if (is_fp16_) { timer.Reset(); - auto probs = ort_outputs[2].GetTensorMutableData(); + auto probs = ort_outputs[3].GetTensorMutableData(); int length = out_shape[0] * out_shape[1] * out_shape[2]; - ctc_log_probs_data.resize(length); + topk_scores_data.resize(length); for (size_t i = 0; i < length; ++i) { - ctc_log_probs_data[i] = _cvtsh_ss(probs[i]); + topk_scores_data[i] = _cvtsh_ss(probs[i]); } - ctc_log_probs = ctc_log_probs_data.data(); - VLOG(1) << "ctc_log_probs from GPU-fp16 to float takes " << timer.Elapsed() << " ms. data lenght " << length; + topk_scores_ptr = topk_scores_data.data(); + VLOG(1) << "topk_scores from GPU-fp16 to float takes " << timer.Elapsed() << " ms. 
data length " << length; } else { - ctc_log_probs = ort_outputs[2].GetTensorMutableData(); + topk_scores_ptr = ort_outputs[3].GetTensorMutableData(); } - out_prob.resize(batch_size); + batch_topk_scores.resize(batch_size); for (size_t i = 0; i < batch_size; ++i) { - out_prob[i].resize(num_outputs); + batch_topk_scores[i].resize(num_outputs); for (size_t j = 0; j < num_outputs; j++) { - out_prob[i][j].resize(output_dim); - float* p = ctc_log_probs + (i * num_outputs + j) * output_dim; - memcpy(out_prob[i][j].data(), p, sizeof(float) * output_dim); + batch_topk_scores[i][j].resize(output_dim); + float* p = topk_scores_ptr + (i * num_outputs + j) * output_dim; + memcpy(batch_topk_scores[i][j].data(), p, sizeof(float) * output_dim); + } + } + // get batch_topk_indexs + std::vector topk_indexs_data; // for holding topk_indexs converted from fp16 + timer.Reset(); + auto probs = ort_outputs[4].GetTensorMutableData(); + int length = out_shape[0] * out_shape[1] * out_shape[2]; + topk_indexs_data.resize(length); + for (size_t i = 0; i < length; ++i) { + topk_indexs_data[i] = probs[i]; + } + int32_t* topk_indexs_ptr = topk_indexs_data.data(); + VLOG(1) << "topk_indexs from GPU-fp16 to float takes " << timer.Elapsed() << " ms. data length " << length; + + batch_topk_indexs.resize(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + batch_topk_indexs[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; j++) { + batch_topk_indexs[i][j].resize(output_dim); + int32_t* p = topk_indexs_ptr + (i * num_outputs + j) * output_dim; + memcpy(batch_topk_indexs[i][j].data(), p, sizeof(int32_t) * output_dim); } } // 3. cache encoder outs diff --git a/runtime/core/decoder/batch_onnx_asr_model.h b/runtime/core/decoder/batch_onnx_asr_model.h index 64be657f4..578d48a33 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.h +++ b/runtime/core/decoder/batch_onnx_asr_model.h @@ -48,12 +48,11 @@ class BatchOnnxAsrModel : public BatchAsrModel { void GetInputOutputInfo(const std::shared_ptr& session, std::vector* in_names, std::vector* out_names); - - protected: - void ForwardEncoderFunc( + void ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, - batch_ctc_log_prob_t& batch_ctc_log_prob) override; + std::vector>>& batch_topk_scores, + std::vector>>& batch_topk_indexs) override; private: int encoder_output_size_ = 0; diff --git a/runtime/core/decoder/ctc_prefix_beam_search.cc b/runtime/core/decoder/ctc_prefix_beam_search.cc index 154c8864b..8b9c4cacb 100644 --- a/runtime/core/decoder/ctc_prefix_beam_search.cc +++ b/runtime/core/decoder/ctc_prefix_beam_search.cc @@ -209,6 +209,112 @@ void CtcPrefixBeamSearch::Search(const std::vector>& logp) { } } +void CtcPrefixBeamSearch::Search( + const std::vector>& topk_scores, + const std::vector>& topk_indexs) { + if (topk_scores.size() == 0) return; + int first_beam_size = + std::min(static_cast(topk_scores[0].size()), opts_.first_beam_size); + for (int t = 0; t < topk_scores.size(); ++t, ++abs_time_step_) { + std::unordered_map, PrefixScore, PrefixHash> next_hyps; + // 1. First beam prune, only select topk candidates + auto& topk_score = topk_scores[t]; + auto& topk_index = topk_indexs[t]; + + // 2.
Token passing + for (int i = 0; i < topk_index.size(); ++i) { + int id = topk_index[i]; + auto prob = topk_score[i]; + for (const auto& it : cur_hyps_) { + const std::vector& prefix = it.first; + const PrefixScore& prefix_score = it.second; + // If prefix doesn't exist in next_hyps, next_hyps[prefix] will insert + // PrefixScore(-inf, -inf) by default, since the default constructor + // of PrefixScore will set fields s(blank ending score) and + // ns(none blank ending score) to -inf, respectively. + if (id == opts_.blank) { + // Case 0: *a + ε => *a + PrefixScore& next_score = next_hyps[prefix]; + next_score.s = LogAdd(next_score.s, prefix_score.score() + prob); + next_score.v_s = prefix_score.viterbi_score() + prob; + next_score.times_s = prefix_score.times(); + // Prefix not changed, copy the context from prefix. + if (context_graph_ && !next_score.has_context) { + next_score.CopyContext(prefix_score); + next_score.has_context = true; + } + } else if (!prefix.empty() && id == prefix.back()) { + // Case 1: *a + a => *a + PrefixScore& next_score1 = next_hyps[prefix]; + next_score1.ns = LogAdd(next_score1.ns, prefix_score.ns + prob); + if (next_score1.v_ns < prefix_score.v_ns + prob) { + next_score1.v_ns = prefix_score.v_ns + prob; + if (next_score1.cur_token_prob < prob) { + next_score1.cur_token_prob = prob; + next_score1.times_ns = prefix_score.times_ns; + CHECK_GT(next_score1.times_ns.size(), 0); + next_score1.times_ns.back() = abs_time_step_; + } + } + if (context_graph_ && !next_score1.has_context) { + next_score1.CopyContext(prefix_score); + next_score1.has_context = true; + } + + // Case 2: *aε + a => *aa + std::vector new_prefix(prefix); + new_prefix.emplace_back(id); + PrefixScore& next_score2 = next_hyps[new_prefix]; + next_score2.ns = LogAdd(next_score2.ns, prefix_score.s + prob); + if (next_score2.v_ns < prefix_score.v_s + prob) { + next_score2.v_ns = prefix_score.v_s + prob; + next_score2.cur_token_prob = prob; + next_score2.times_ns = prefix_score.times_s; + next_score2.times_ns.emplace_back(abs_time_step_); + } + if (context_graph_ && !next_score2.has_context) { + // Prefix changed, calculate the context score. + next_score2.UpdateContext(context_graph_, prefix_score, id, + prefix.size()); + next_score2.has_context = true; + } + } else { + // Case 3: *a + b => *ab, *aε + b => *ab + std::vector new_prefix(prefix); + new_prefix.emplace_back(id); + PrefixScore& next_score = next_hyps[new_prefix]; + next_score.ns = LogAdd(next_score.ns, prefix_score.score() + prob); + if (next_score.v_ns < prefix_score.viterbi_score() + prob) { + next_score.v_ns = prefix_score.viterbi_score() + prob; + next_score.cur_token_prob = prob; + next_score.times_ns = prefix_score.times(); + next_score.times_ns.emplace_back(abs_time_step_); + } + if (context_graph_ && !next_score.has_context) { + // Calculate the context score. + next_score.UpdateContext(context_graph_, prefix_score, id, + prefix.size()); + next_score.has_context = true; + } + } + } + } + + // 3. Second beam prune, only keep top n best paths + std::vector, PrefixScore>> arr(next_hyps.begin(), + next_hyps.end()); + int second_beam_size = + std::min(static_cast(arr.size()), opts_.second_beam_size); + std::nth_element(arr.begin(), arr.begin() + second_beam_size, arr.end(), + PrefixScoreCompare); + arr.resize(second_beam_size); + std::sort(arr.begin(), arr.end(), PrefixScoreCompare); + + // 4. 
Update cur_hyps_ and get new result + UpdateHypotheses(arr); + } +} + void CtcPrefixBeamSearch::FinalizeSearch() { UpdateFinalContext(); } void CtcPrefixBeamSearch::UpdateFinalContext() { diff --git a/runtime/core/decoder/ctc_prefix_beam_search.h b/runtime/core/decoder/ctc_prefix_beam_search.h index f44ec23c3..743752f97 100644 --- a/runtime/core/decoder/ctc_prefix_beam_search.h +++ b/runtime/core/decoder/ctc_prefix_beam_search.h @@ -99,6 +99,8 @@ class CtcPrefixBeamSearch : public SearchInterface { const std::shared_ptr& context_graph = nullptr); void Search(const std::vector>& logp) override; + void Search(const std::vector>& topk_scores, + const std::vector>& topk_indexs) override; void Reset() override; void FinalizeSearch() override; SearchType Type() const override { return SearchType::kPrefixBeamSearch; } diff --git a/runtime/core/decoder/ctc_wfst_beam_search.h b/runtime/core/decoder/ctc_wfst_beam_search.h index 56967743d..0215ed1eb 100644 --- a/runtime/core/decoder/ctc_wfst_beam_search.h +++ b/runtime/core/decoder/ctc_wfst_beam_search.h @@ -63,6 +63,8 @@ class CtcWfstBeamSearch : public SearchInterface { const fst::Fst& fst, const CtcWfstBeamSearchOptions& opts, const std::shared_ptr& context_graph); void Search(const std::vector>& logp) override; + void Search(const std::vector>& topk_scores, + const std::vector>& topk_indexs) override {}; void Reset() override; void FinalizeSearch() override; SearchType Type() const override { return SearchType::kWfstBeamSearch; } diff --git a/runtime/core/decoder/search_interface.h b/runtime/core/decoder/search_interface.h index 25bad2670..72722fb63 100644 --- a/runtime/core/decoder/search_interface.h +++ b/runtime/core/decoder/search_interface.h @@ -29,6 +29,8 @@ class SearchInterface { public: virtual ~SearchInterface() {} virtual void Search(const std::vector>& logp) = 0; + virtual void Search(const std::vector>& topk_scores, + const std::vector>& topk_indexs) = 0; virtual void Reset() = 0; virtual void FinalizeSearch() = 0; From fcacb20918953bd4f4ba102b25427ac36fe93ef3 Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 2 Sep 2022 09:06:19 +0800 Subject: [PATCH 40/62] make ForwardEncoder() output topk --- runtime/core/decoder/batch_torch_asr_model.cc | 48 +++++++++---------- runtime/core/decoder/batch_torch_asr_model.h | 9 ++-- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/runtime/core/decoder/batch_torch_asr_model.cc b/runtime/core/decoder/batch_torch_asr_model.cc index 9cc2b2c1a..f2e11fb28 100644 --- a/runtime/core/decoder/batch_torch_asr_model.cc +++ b/runtime/core/decoder/batch_torch_asr_model.cc @@ -96,10 +96,11 @@ std::shared_ptr BatchTorchAsrModel::Copy() const { return asr_model; } -void BatchTorchAsrModel::ForwardEncoderFunc( +void BatchTorchAsrModel::ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, - batch_ctc_log_prob_t& out_prob) { + std::vector>>& batch_topk_scores, + std::vector>>& batch_topk_indexs) { // 1. 
Prepare libtorch required data int batch_size = batch_feats.size(); int num_frames = batch_feats[0].size(); @@ -126,39 +127,34 @@ void BatchTorchAsrModel::ForwardEncoderFunc( auto outputs = model_->get_method("batch_forward_encoder")(inputs).toTuple()->elements(); - CHECK_EQ(outputs.size(), 3); + CHECK_EQ(outputs.size(), 5); encoder_out_ = outputs[0].toTensor(); // (B, Tmax, dim) encoder_lens_ = outputs[1].toTensor(); // (B,) - auto ctc_log_probs = outputs[2].toTensor().to(at::kCPU); - // encoder_out_ = encoder_out_.to(at::kCPU); // to CPU to save GPU memory - // encoder_lens_ = encoder_lens_.to(at::kCPU); - // c10::cuda::CUDACachingAllocator::emptyCache(); - - // Copy to output - int num_outputs = ctc_log_probs.size(1); - int output_dim = ctc_log_probs.size(2); - out_prob.resize(batch_size); + // Copy topk_scores + auto topk_scores = outputs[3].toTensor().to(at::kCPU); + int num_outputs = topk_scores.size(1); + int output_dim = topk_scores.size(2); + batch_topk_scores.resize(batch_size); for (size_t i = 0; i < batch_size; i++) { - out_prob[i].resize(num_outputs); + batch_topk_scores[i].resize(num_outputs); for (size_t j = 0; j < num_outputs; j++) { - out_prob[i][j].resize(output_dim); - memcpy(out_prob[i][j].data(), ctc_log_probs[i][j].data_ptr(), + batch_topk_scores[i][j].resize(output_dim); + memcpy(batch_topk_scores[i][j].data(), topk_scores[i][j].data_ptr(), sizeof(float) * output_dim); } } -} - -float BatchTorchAsrModel::ComputeAttentionScore(const torch::Tensor& prob, - const std::vector& hyp, - int eos) { - float score = 0.0f; - auto accessor = prob.accessor(); - for (size_t j = 0; j < hyp.size(); ++j) { - score += accessor[j][hyp[j]]; + // copy topk_indexes + auto topk_indexes = outputs[4].toTensor().to(at::kCPU); + batch_topk_indexs.resize(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + batch_topk_indexs[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; ++j) { + batch_topk_indexs[i][j].resize(output_dim); + memcpy(batch_topk_indexs[i][j].data(), topk_indexes[i][j].data_ptr(), + sizeof(int) * output_dim); + } } - score += accessor[hyp.size()][eos]; - return score; } void BatchTorchAsrModel::AttentionRescoring( diff --git a/runtime/core/decoder/batch_torch_asr_model.h b/runtime/core/decoder/batch_torch_asr_model.h index b7abfe443..a49598f45 100644 --- a/runtime/core/decoder/batch_torch_asr_model.h +++ b/runtime/core/decoder/batch_torch_asr_model.h @@ -44,14 +44,11 @@ class BatchTorchAsrModel : public BatchAsrModel { std::vector>& attention_scores) override; std::shared_ptr Copy() const override; - protected: - void ForwardEncoderFunc( + void ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, - batch_ctc_log_prob_t& batch_ctc_log_prob) override; - - float ComputeAttentionScore(const torch::Tensor& batch_prob, - const std::vector& hyp, int eos); + std::vector>>& batch_topk_scores, + std::vector>>& batch_topk_indexs) override; private: std::shared_ptr model_ = nullptr; From 5b28502564fbf5313a2b66aa5d76fe01891a67f0 Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 2 Sep 2022 09:10:42 +0800 Subject: [PATCH 41/62] make batch_forward_encoder() to return topk ctc_log_probs --- wenet/transformer/asr_model.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index b38e24d7c..6b0e49404 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -732,23 +732,30 @@ def batch_forward_encoder( self, speech: 
torch.Tensor, speech_lengths: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + beam_size: int = 10, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Export interface for c++ call, encode a batch of speech Args: speech: padded input tensor (B, T, F) speech_lengths: input length (B) Returns: - encoder output tensor xs, and subsampled masks - encoder_out: padded output tensor (B, T' ~= T/subsample_rate, D) - encoder_mask: torch.Tensor batch padding mask after subsample - (B, 1, T' ~= T/subsample_rate) + encoder_out: B x T x F + encoder_out_lens: B + ctc_log_probs: B x T x V + beam_log_probs: B x T x beam_size + beam_log_probs_idx: B x T x beam_size """ - encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1, -1) + encoder_out, encoder_mask = self.encoder( + speech, + speech_lengths, -1, -1) encoder_out_lens = encoder_mask.squeeze(1).sum(1) encoder_out_lens = encoder_out_lens.int() ctc_log_probs = self.ctc.log_softmax(encoder_out) - return encoder_out, encoder_out_lens, ctc_log_probs + beam_log_probs, beam_log_probs_idx = torch.topk( + ctc_log_probs, beam_size, dim=2) + return encoder_out, encoder_out_lens, ctc_log_probs, \ + beam_log_probs, beam_log_probs_idx @torch.jit.export def batch_forward_attention_decoder( From 2ba32c3e9b522e66837877a1c30a3ca18ad61b63 Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 21 Oct 2022 17:16:01 +0800 Subject: [PATCH 42/62] only emptyCache() if USE_GPU --- runtime/core/decoder/batch_torch_asr_model.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/runtime/core/decoder/batch_torch_asr_model.cc b/runtime/core/decoder/batch_torch_asr_model.cc index f2e11fb28..a538dd6eb 100644 --- a/runtime/core/decoder/batch_torch_asr_model.cc +++ b/runtime/core/decoder/batch_torch_asr_model.cc @@ -130,7 +130,7 @@ void BatchTorchAsrModel::ForwardEncoder( CHECK_EQ(outputs.size(), 5); encoder_out_ = outputs[0].toTensor(); // (B, Tmax, dim) encoder_lens_ = outputs[1].toTensor(); // (B,) - + // Copy topk_scores auto topk_scores = outputs[3].toTensor().to(at::kCPU); int num_outputs = topk_scores.size(1); @@ -212,7 +212,9 @@ void BatchTorchAsrModel::AttentionRescoring( hyps_pad_sos_eos, hyps_lens_sos, r_hyps_pad_sos_eos, ctc_scores_tensor).toTuple()->elements(); auto rescores = outputs[1].toTensor().to(at::kCPU); +#ifdef USE_GPU c10::cuda::CUDACachingAllocator::emptyCache(); +#endif attention_scores.resize(batch_size); for (size_t i = 0; i < batch_size; i++) { attention_scores[i].resize(beam_size); From e2259dbb8b0752a767687db6f15766f564fef263 Mon Sep 17 00:00:00 2001 From: veelion Date: Tue, 25 Oct 2022 08:43:40 +0800 Subject: [PATCH 43/62] support GPU --- runtime/core/cmake/onnx.cmake | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/runtime/core/cmake/onnx.cmake b/runtime/core/cmake/onnx.cmake index 70756dd00..2a90c56f6 100644 --- a/runtime/core/cmake/onnx.cmake +++ b/runtime/core/cmake/onnx.cmake @@ -1,11 +1,27 @@ if(ONNX) set(ONNX_VERSION "1.9.0") + if(GPU) + add_definitions(-DUSE_GPU) + set(ONNX_VERSION "1.12.1") + endif() if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows") set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-win-x64-${ONNX_VERSION}.zip") set(URL_HASH "SHA256=484b08c55867963bd8f74cc39d7c9b6199260f1184839cc40f37e50304597364") elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux") - set(ONNX_URL
"https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz") - set(URL_HASH "SHA256=f386ab80e9d6d41f14ed9e61bff4acc6bf375770691bc3ba883ba0ba3cabca7f") + if(GPU) + set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-gpu-${ONNX_VERSION}.tgz") + if(${ONNX_VERSION} STREQUAL "1.12.1") + set(URL_HASH "SHA256=41fcb4b0bb162c2788240d5f21d18714238817a78fb68e5733c5caef326a7306") + elseif(${ONNX_VERSION} STREQUAL "1.12.0") + set(URL_HASH "SHA256=bc2e615314df0a871c560b7af6d4ce5896f351d23cad476562d2715208c9c7f7") + elseif(${ONNX_VERSION} STREQUAL "1.11.1") + set(URL_HASH "SHA256=b96e3e266f66f6e1293841e0a5b5ec3b0a602512d68e5cc73c014546092c87c8") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c") + else() + set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz") + set(URL_HASH "SHA256=f386ab80e9d6d41f14ed9e61bff4acc6bf375770691bc3ba883ba0ba3cabca7f") + endif() elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-x64-${ONNX_VERSION}.tgz") set(URL_HASH "SHA256=71517c8571186eddd31e78134ac441571494fc2f524153165f4a2fec22940d66") From 457d9b06db2f729c2e177d687216a39e4cf57962 Mon Sep 17 00:00:00 2001 From: veelion Date: Tue, 25 Oct 2022 08:44:13 +0800 Subject: [PATCH 44/62] add more pytorch version --- runtime/core/cmake/libtorch.cmake | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/runtime/core/cmake/libtorch.cmake b/runtime/core/cmake/libtorch.cmake index 40a64ff84..bd4a9248f 100644 --- a/runtime/core/cmake/libtorch.cmake +++ b/runtime/core/cmake/libtorch.cmake @@ -1,6 +1,6 @@ if(TORCH) if(NOT ANDROID) - set(PYTORCH_VERSION "1.10.0") + set(PYTORCH_VERSION "1.12.0") if(GPU) add_definitions(-DUSE_GPU) set(CUDA_NAME "cu113") @@ -20,10 +20,18 @@ if(TORCH) if(CXX11_ABI) if(NOT GPU) set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${PYTORCH_VERSION}%2Bcpu.zip") - set(URL_HASH "SHA256=6d7be1073d1bd76f6563572b2aa5548ad51d5bc241d6895e3181b7dc25554426") + if(${PYTORCH_VERSION} STREQUAL "1.12.0") + set(URL_HASH "SHA256=0f0f36219862a4ed0ad0522c4de97e9e189194b44eb09036d2b94bea456260c6") + else() + set(URL_HASH "SHA256=6d7be1073d1bd76f6563572b2aa5548ad51d5bc241d6895e3181b7dc25554426") + endif() else() set(LIBTORCH_URL "https://download.pytorch.org/libtorch/${CUDA_NAME}/libtorch-cxx11-abi-shared-with-deps-${PYTORCH_VERSION}%2B${CUDA_NAME}.zip") - set(URL_HASH "SHA256=190e963e739d5f7c2dcf94b3994de8fcd335706a4ebb333812ea7d8c841beb06") + if(${PYTORCH_VERSION} STREQUAL "1.12.0" AND ${CUDA_NAME} STREQUAL "cu113") + set(URL_HASH "SHA256=80f089939de20e68e3fcad4dfa72a26c8bf91b5e77b11042f671f39ebac35865") + else() + set(URL_HASH "SHA256=190e963e739d5f7c2dcf94b3994de8fcd335706a4ebb333812ea7d8c841beb06") + endif() endif() else() if(NOT GPU) @@ -31,7 +39,11 @@ if(TORCH) set(URL_HASH "SHA256=16961222938b205a6a767b0b0b9f5e3b1f8740aa1f3475580e33cfd5952b1a44") else() set(LIBTORCH_URL "https://download.pytorch.org/libtorch/${CUDA_NAME}/libtorch-shared-with-deps-${PYTORCH_VERSION}%2B${CUDA_NAME}.zip") - set(URL_HASH "SHA256=0996a6a4ea8bbc1137b4fb0476eeca25b5efd8ed38955218dec1b73929090053") + if(${PYTORCH_VERSION} STREQUAL "1.12.0" AND ${CUDA_NAME} STREQUAL "cu113") + set(URL_HASH "SHA256=8e35371403f7052d9e9b43bcff383980dbde4df028986dc1dab539953481d55f") 
+ else() + set(URL_HASH "SHA256=0996a6a4ea8bbc1137b4fb0476eeca25b5efd8ed38955218dec1b73929090053") + endif() endif() endif() elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin") @@ -57,7 +69,7 @@ if(TORCH) file(COPY ${TORCH_DLLS} DESTINATION ${CMAKE_BINARY_DIR}) endif() else() - # Change version in runtime/android/app/build.gradle. + # Change version in runtime/device/android/wenet/app/build.gradle. file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers") file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}") find_library(PYTORCH_LIBRARY pytorch_jni From 14ffae6087d5ed94c12c2eff98efc69e07f20b63 Mon Sep 17 00:00:00 2001 From: veelion Date: Tue, 25 Oct 2022 08:49:19 +0800 Subject: [PATCH 45/62] save eos, sos to onnx_config for onnxruntime of C++ --- wenet/bin/export_onnx_gpu.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py index fa6a69c6e..e0c90e908 100644 --- a/wenet/bin/export_onnx_gpu.py +++ b/wenet/bin/export_onnx_gpu.py @@ -333,10 +333,14 @@ def export_offline_encoder(model, configs, args, logger, encoder_onnx_path): # check encoder output test(to_numpy([o0, o1, o2, o3, o4]), ort_outs) + is_bidirectional_decoder = 1 if configs['decoder'] == 'bitransformer' else 0 logger.info("export offline onnx encoder succeed!") onnx_config = {"beam_size": args.beam_size, - "reverse_weight": args.reverse_weight, - "ctc_weight": args.ctc_weight, + "reverse_weight": configs['model_conf']['reverse_weight'], + "ctc_weight": configs['model_conf']['ctc_weight'], + "sos": configs["output_dim"] - 1, + "eos": configs["output_dim"] - 1, + "is_bidirectional_decoder": is_bidirectional_decoder, "fp16": args.fp16} return onnx_config From 1e9faaf4ff483fb467ef49a0803b2cb1d51ece1f Mon Sep 17 00:00:00 2001 From: veelion Date: Tue, 25 Oct 2022 10:07:44 +0800 Subject: [PATCH 46/62] transformer decoder has no 'reverse_weight' in config --- wenet/bin/export_onnx_gpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py index e0c90e908..77f278336 100644 --- a/wenet/bin/export_onnx_gpu.py +++ b/wenet/bin/export_onnx_gpu.py @@ -335,8 +335,9 @@ def export_offline_encoder(model, configs, args, logger, encoder_onnx_path): test(to_numpy([o0, o1, o2, o3, o4]), ort_outs) is_bidirectional_decoder = 1 if configs['decoder'] == 'bitransformer' else 0 logger.info("export offline onnx encoder succeed!") + reverse_weight = configs['model_conf'].get('reverse_weight', 0) onnx_config = {"beam_size": args.beam_size, - "reverse_weight": configs['model_conf']['reverse_weight'], + "reverse_weight": reverse_weight, "ctc_weight": configs['model_conf']['ctc_weight'], "sos": configs["output_dim"] - 1, "eos": configs["output_dim"] - 1, "is_bidirectional_decoder": is_bidirectional_decoder, "fp16": args.fp16} return onnx_config From a73b792100a9ea9899029e41e21a1b7ddc01e064 Mon Sep 17 00:00:00 2001 From: veelion Date: Tue, 25 Oct 2022 10:13:05 +0800 Subject: [PATCH 47/62] fix rescore_inputs --- runtime/core/decoder/batch_onnx_asr_model.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc index 2c03f5da7..3ec64a2a3 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.cc +++ b/runtime/core/decoder/batch_onnx_asr_model.cc @@ -390,7 +390,6 @@ void BatchOnnxAsrModel::AttentionRescoring( memory_info, r_hyps_pad_sos_eos.data(), r_hyps_pad_sos_eos.size(), hyps_pad_shape, 3); std::vector rescore_inputs; - /* for (auto name : rescore_in_names_) { if
(!strcmp(name, "encoder_out")) { rescore_inputs.push_back(std::move(encoder_outs_)); @@ -408,13 +407,6 @@ void BatchOnnxAsrModel::AttentionRescoring( VLOG(1) << "invalid input name " << name; } } - */ - rescore_inputs.push_back(std::move(encoder_outs_)); - rescore_inputs.push_back(std::move(encoder_outs_lens_)); - rescore_inputs.push_back(std::move(hyps_pad_tensor)); - rescore_inputs.push_back(std::move(hyps_lens_tensor)); - rescore_inputs.push_back(std::move(r_hyps_pad_tensor)); - rescore_inputs.push_back(std::move(ctc_scores_tensor)); Timer timer; std::vector rescore_outputs = rescore_session_->Run( From 65eb60809e43e0bf7148fd6f21fe91d55e11ed26 Mon Sep 17 00:00:00 2001 From: veelion Date: Wed, 26 Oct 2022 10:16:41 +0800 Subject: [PATCH 48/62] release GPU memory --- runtime/core/decoder/batch_onnx_asr_model.cc | 50 ++++++++++++++++++-- runtime/core/decoder/batch_onnx_asr_model.h | 1 + 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc index 3ec64a2a3..9f09567be 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.cc +++ b/runtime/core/decoder/batch_onnx_asr_model.cc @@ -30,8 +30,9 @@ namespace wenet { -Ort::Env BatchOnnxAsrModel::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, ""); +Ort::Env BatchOnnxAsrModel::env_ = Ort::Env(ORT_LOGGING_LEVEL_VERBOSE, ""); Ort::SessionOptions BatchOnnxAsrModel::session_options_ = Ort::SessionOptions(); +Ort::RunOptions BatchOnnxAsrModel::run_option_ = Ort::RunOptions(); void BatchOnnxAsrModel::InitEngineThreads(int num_threads) { session_options_.SetIntraOpNumThreads(num_threads); @@ -103,6 +104,47 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16, int gpu rescore_onnx_path = model_dir + "/decoder_fp16.onnx"; } + // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options_, 0)); + // release GPU memory: https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 + // 1. Not allocate weights memory through the arena + session_options_.AddConfigEntry("kOrtSessionOptionsUseDeviceAllocatorForInitializers", "1"); + // 2. Configure the arena to have high enough initial chunk to support most Run() calls. 
See "initial_chunk_size_bytes" + const char* keys[] = {"max_mem", "arena_extend_strategy", "initial_chunk_size_bytes", "max_dead_bytes_per_chunk", "initial_growth_chunk_size_bytes"}; + const size_t values[] = {0 /*let ort pick default max memory*/, 0, 1024, 0, 256}; + + OrtArenaCfg* arena_cfg = nullptr; + const auto& api = Ort::GetApi(); + auto zz = api.CreateArenaCfgV2(keys, values, 5, &arena_cfg); + //auto zz = api.CreateArenaCfg(0, 0, 1024, 0, &arena_cfg); + VLOG(1) << "CreateArenaCfgV2: " << zz << ", arena_cfg: " << arena_cfg; + std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); + + OrtCUDAProviderOptions cuda_options{}; + + cuda_options.device_id = 0; + cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive; + cuda_options.gpu_mem_limit = std::numeric_limits::max(); + cuda_options.arena_extend_strategy = 0; + cuda_options.do_copy_in_default_stream = true; + cuda_options.has_user_compute_stream = 0; + cuda_options.user_compute_stream = nullptr; + cuda_options.default_memory_arena_cfg = arena_cfg; + + session_options_.AppendExecutionProvider_CUDA(cuda_options); + run_option_.AddConfigEntry("kOrtRunOptionsConfigEnableMemoryArenaShrinkage", "gpu:0"); + + /* share memory between sessions + //Ort::MemoryInfo* info_cuda = new Ort::MemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); + OrtMemoryInfo* info_cuda = nullptr; + auto error = api.CreateMemoryInfo("Cpu", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault, &info_cuda); + VLOG(1) << "CreateMemoryInfo() error " << error; + OrtEnv* env_ptr = (OrtEnv*)(env_); + auto err = api.CreateAndRegisterAllocator(env_ptr, info_cuda, arena_cfg); + VLOG(1) << "CreateAndRegisterAllocator() error " << err; + */ + + + /* // 1. Load sessions // config for CUDA std::string device_id = std::to_string(gpu_id); @@ -133,10 +175,12 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16, int gpu Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options_, cuda_options)); api.ReleaseCUDAProviderOptions(cuda_options); + */ try { encoder_session_ = std::make_shared( env_, encoder_onnx_path.c_str(), session_options_); + VLOG(1) << "========================= encoder_session_ done"; rescore_session_ = std::make_shared( env_, rescore_onnx_path.c_str(), session_options_); } catch (std::exception const& e) { @@ -257,7 +301,7 @@ void BatchOnnxAsrModel::ForwardEncoder( timer.Reset(); std::vector ort_outputs = encoder_session_->Run( - Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(), + run_option_, encoder_in_names_.data(), inputs.data(), inputs.size(), encoder_out_names_.data(), encoder_out_names_.size()); VLOG(1) << "\tencoder ->Run() takes " << timer.Elapsed() << " ms."; @@ -410,7 +454,7 @@ void BatchOnnxAsrModel::AttentionRescoring( Timer timer; std::vector rescore_outputs = rescore_session_->Run( - Ort::RunOptions(nullptr), rescore_in_names_.data(), rescore_inputs.data(), + run_option_, rescore_in_names_.data(), rescore_inputs.data(), rescore_inputs.size(), rescore_out_names_.data(), rescore_out_names_.size()); VLOG(1) << "decoder->Run() takes " << timer.Elapsed() << " ms."; diff --git a/runtime/core/decoder/batch_onnx_asr_model.h b/runtime/core/decoder/batch_onnx_asr_model.h index 578d48a33..99daa2303 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.h +++ b/runtime/core/decoder/batch_onnx_asr_model.h @@ -63,6 +63,7 @@ class BatchOnnxAsrModel : public BatchAsrModel { // One Env must be created before using any other 
Onnxruntime functionality.
   static Ort::Env env_;  // shared environment across threads.
   static Ort::SessionOptions session_options_;
+  static Ort::RunOptions run_option_;

   std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
   std::shared_ptr<Ort::Session> rescore_session_ = nullptr;

From 7d1700a0f6a8adaa106e79756e95675b2921b8ea Mon Sep 17 00:00:00 2001
From: veelion
Date: Thu, 3 Nov 2022 13:38:30 +0800
Subject: [PATCH 49/62] add onnx_version 1.13.1

---
 runtime/core/cmake/onnx.cmake | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/runtime/core/cmake/onnx.cmake b/runtime/core/cmake/onnx.cmake
index 2a90c56f6..c3b5028a1 100644
--- a/runtime/core/cmake/onnx.cmake
+++ b/runtime/core/cmake/onnx.cmake
@@ -2,7 +2,7 @@ if(ONNX)
   set(ONNX_VERSION "1.9.0")
   if(GPU)
     add_definitions(-DUSE_GPU)
-    set(ONNX_VERSION "1.12.1")
+    set(ONNX_VERSION "1.13.1")
   endif()
   if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
     set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-win-x64-${ONNX_VERSION}.zip")
@@ -16,6 +16,8 @@ if(ONNX)
         set(URL_HASH "SHA256=bc2e615314df0a871c560b7af6d4ce5896f351d23cad476562d2715208c9c7f7")
       elseif(${ONNX_VERSION} STREQUAL "1.11.1")
         set(URL_HASH "SHA256=b96e3e266f66f6e1293841e0a5b5ec3b0a602512d68e5cc73c014546092c87c8")
+      elseif(${ONNX_VERSION} STREQUAL "1.13.1")
+        set(URL_HASH "SHA256=7725c232c78b9b49037fa7409f3ae255ba81d9a7e1af910c2443b1174171d8b1")
       endif()
       set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c")
     else()

From 14d6cd58a9d0804cb65c3c74162bdbe79111cf5b Mon Sep 17 00:00:00 2001
From: veelion
Date: Thu, 3 Nov 2022 13:44:00 +0800
Subject: [PATCH 50/62] replace GetInputName() with GetInputNameAllocated(),
 because GetInputName() is deprecated in 1.13.1

---
 runtime/core/decoder/batch_onnx_asr_model.cc | 70 +++++++++-----------
 runtime/core/decoder/batch_onnx_asr_model.h  |  1 +
 runtime/core/decoder/onnx_asr_model.cc       | 38 ++++++-----
 runtime/core/decoder/onnx_asr_model.h        |  1 +
 4 files changed, 54 insertions(+), 56 deletions(-)

diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc
index 9f09567be..576a79647 100644
--- a/runtime/core/decoder/batch_onnx_asr_model.cc
+++ b/runtime/core/decoder/batch_onnx_asr_model.cc
@@ -27,12 +27,15 @@
 #include "utils/string.h"
 #include "utils/Yaml.hpp"
 #include "utils/timer.h"
+#include "onnxruntime_run_options_config_keys.h"
+#include "onnxruntime_session_options_config_keys.h"

 namespace wenet {

 Ort::Env BatchOnnxAsrModel::env_ = Ort::Env(ORT_LOGGING_LEVEL_VERBOSE, "");
 Ort::SessionOptions BatchOnnxAsrModel::session_options_ = Ort::SessionOptions();
 Ort::RunOptions BatchOnnxAsrModel::run_option_ = Ort::RunOptions();
+std::vector<Ort::AllocatedStringPtr> BatchOnnxAsrModel::node_names_;

 void BatchOnnxAsrModel::InitEngineThreads(int num_threads) {
   session_options_.SetIntraOpNumThreads(num_threads);
@@ -47,7 +50,7 @@ void BatchOnnxAsrModel::GetInputOutputInfo(
   int num_nodes = session->GetInputCount();
   in_names->resize(num_nodes);
   for (int i = 0; i < num_nodes; ++i) {
-    char* name = session->GetInputName(i, allocator);
+    auto name = session->GetInputNameAllocated(i, allocator);
     Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
     auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
     ONNXTensorElementDataType type = tensor_info.GetElementType();
@@ -57,15 +60,16 @@ void BatchOnnxAsrModel::GetInputOutputInfo(
       shape << j;
       shape << " ";
     }
-    LOG(INFO) << "\tInput " << i << " : name=" << name << " type=" << type
+    LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type
              << " dims=" <<
shape.str(); - (*in_names)[i] = name; + node_names_.push_back(std::move(name)); + (*in_names)[i] = node_names_.back().get(); } // Output info num_nodes = session->GetOutputCount(); out_names->resize(num_nodes); for (int i = 0; i < num_nodes; ++i) { - char* name = session->GetOutputName(i, allocator); + auto name = session->GetOutputNameAllocated(i, allocator); Ort::TypeInfo type_info = session->GetOutputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); ONNXTensorElementDataType type = tensor_info.GetElementType(); @@ -75,9 +79,10 @@ void BatchOnnxAsrModel::GetInputOutputInfo( shape << j; shape << " "; } - LOG(INFO) << "\tOutput " << i << " : name=" << name << " type=" << type + LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type << " dims=" << shape.str(); - (*out_names)[i] = name; + node_names_.push_back(std::move(name)); + (*out_names)[i] = node_names_.back().get(); } } @@ -106,8 +111,8 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16, int gpu // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options_, 0)); // release GPU memory: https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 - // 1. Not allocate weights memory through the arena - session_options_.AddConfigEntry("kOrtSessionOptionsUseDeviceAllocatorForInitializers", "1"); + // 1. Not allocate weights memory through the arena + session_options_.AddConfigEntry(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "1"); // 2. Configure the arena to have high enough initial chunk to support most Run() calls. See "initial_chunk_size_bytes" const char* keys[] = {"max_mem", "arena_extend_strategy", "initial_chunk_size_bytes", "max_dead_bytes_per_chunk", "initial_growth_chunk_size_bytes"}; const size_t values[] = {0 /*let ort pick default max memory*/, 0, 1024, 0, 256}; @@ -115,72 +120,57 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16, int gpu OrtArenaCfg* arena_cfg = nullptr; const auto& api = Ort::GetApi(); auto zz = api.CreateArenaCfgV2(keys, values, 5, &arena_cfg); - //auto zz = api.CreateArenaCfg(0, 0, 1024, 0, &arena_cfg); - VLOG(1) << "CreateArenaCfgV2: " << zz << ", arena_cfg: " << arena_cfg; std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); OrtCUDAProviderOptions cuda_options{}; - + cuda_options.device_id = 0; cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive; - cuda_options.gpu_mem_limit = std::numeric_limits::max(); - cuda_options.arena_extend_strategy = 0; + //cuda_options.gpu_mem_limit = 16 * 1024 * 1024 * 1024ul; + cuda_options.arena_extend_strategy = 1; cuda_options.do_copy_in_default_stream = true; cuda_options.has_user_compute_stream = 0; cuda_options.user_compute_stream = nullptr; - cuda_options.default_memory_arena_cfg = arena_cfg; - + // TODO: arena_cfg didn't work, it blocked when session.Run() + // Just comment this out until find a work way. 
+ // cuda_options.default_memory_arena_cfg = arena_cfg; session_options_.AppendExecutionProvider_CUDA(cuda_options); - run_option_.AddConfigEntry("kOrtRunOptionsConfigEnableMemoryArenaShrinkage", "gpu:0"); - - /* share memory between sessions - //Ort::MemoryInfo* info_cuda = new Ort::MemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); - OrtMemoryInfo* info_cuda = nullptr; - auto error = api.CreateMemoryInfo("Cpu", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault, &info_cuda); - VLOG(1) << "CreateMemoryInfo() error " << error; - OrtEnv* env_ptr = (OrtEnv*)(env_); - auto err = api.CreateAndRegisterAllocator(env_ptr, info_cuda, arena_cfg); - VLOG(1) << "CreateAndRegisterAllocator() error " << err; - */ - - /* + /* TODO: In the future use OrtCUDAProviderOptionsV2 until it support ArenaCfg // 1. Load sessions // config for CUDA std::string device_id = std::to_string(gpu_id); - std::vector keys{ + std::vector keys2{ "device_id", - // "gpu_mem_limit", + "gpu_mem_limit", "arena_extend_strategy", "cudnn_conv_algo_search", "do_copy_in_default_stream", "cudnn_conv_use_max_workspace", "cudnn_conv1d_pad_to_nc1d" // supported from 1.12.0 }; - std::vector values{ + std::vector values2{ device_id.data(), - // "2147483648", + //"2147483648", + "8589934592", "kSameAsRequested", "DEFAULT", "1", "1", "1" }; - // release GPU memory: https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 const auto& api = Ort::GetApi(); OrtCUDAProviderOptionsV2* cuda_options = nullptr; Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options)); - Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys.data(), values.data(), keys.size())); + Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys2.data(), values2.data(), keys2.size())); Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options_, cuda_options)); - api.ReleaseCUDAProviderOptions(cuda_options); */ try { encoder_session_ = std::make_shared( env_, encoder_onnx_path.c_str(), session_options_); - VLOG(1) << "========================= encoder_session_ done"; rescore_session_ = std::make_shared( env_, rescore_onnx_path.c_str(), session_options_); } catch (std::exception const& e) { @@ -300,8 +290,10 @@ void BatchOnnxAsrModel::ForwardEncoder( } timer.Reset(); + //Ort::RunOptions ro; + //ro.AddConfigEntry(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "gpu:0"); std::vector ort_outputs = encoder_session_->Run( - run_option_, encoder_in_names_.data(), inputs.data(), + Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(), inputs.size(), encoder_out_names_.data(), encoder_out_names_.size()); VLOG(1) << "\tencoder ->Run() takes " << timer.Elapsed() << " ms."; @@ -453,8 +445,10 @@ void BatchOnnxAsrModel::AttentionRescoring( } Timer timer; + Ort::RunOptions ro; + ro.AddConfigEntry(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "gpu:0"); std::vector rescore_outputs = rescore_session_->Run( - run_option_, rescore_in_names_.data(), rescore_inputs.data(), + ro, rescore_in_names_.data(), rescore_inputs.data(), rescore_inputs.size(), rescore_out_names_.data(), rescore_out_names_.size()); VLOG(1) << "decoder->Run() takes " << timer.Elapsed() << " ms."; diff --git a/runtime/core/decoder/batch_onnx_asr_model.h b/runtime/core/decoder/batch_onnx_asr_model.h index 99daa2303..ff2f4aa38 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.h +++ b/runtime/core/decoder/batch_onnx_asr_model.h @@ -68,6 +68,7 @@ class BatchOnnxAsrModel : public BatchAsrModel { 
std::shared_ptr rescore_session_ = nullptr; // node names + static std::vector node_names_; std::vector encoder_in_names_, encoder_out_names_; std::vector rescore_in_names_, rescore_out_names_; diff --git a/runtime/core/decoder/onnx_asr_model.cc b/runtime/core/decoder/onnx_asr_model.cc index 3097e2020..f10b28a2b 100644 --- a/runtime/core/decoder/onnx_asr_model.cc +++ b/runtime/core/decoder/onnx_asr_model.cc @@ -42,7 +42,7 @@ void OnnxAsrModel::GetInputOutputInfo( int num_nodes = session->GetInputCount(); in_names->resize(num_nodes); for (int i = 0; i < num_nodes; ++i) { - char* name = session->GetInputName(i, allocator); + auto name = session->GetInputNameAllocated(i, allocator); Ort::TypeInfo type_info = session->GetInputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); ONNXTensorElementDataType type = tensor_info.GetElementType(); @@ -52,15 +52,16 @@ void OnnxAsrModel::GetInputOutputInfo( shape << j; shape << " "; } - LOG(INFO) << "\tInput " << i << " : name=" << name << " type=" << type + LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type << " dims=" << shape.str(); - (*in_names)[i] = name; + node_names_.push_back(std::move(name)); + (*in_names)[i] = node_names_.back().get(); } // Output info num_nodes = session->GetOutputCount(); out_names->resize(num_nodes); for (int i = 0; i < num_nodes; ++i) { - char* name = session->GetOutputName(i, allocator); + auto name = session->GetOutputNameAllocated(i, allocator); Ort::TypeInfo type_info = session->GetOutputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); ONNXTensorElementDataType type = tensor_info.GetElementType(); @@ -70,9 +71,10 @@ void OnnxAsrModel::GetInputOutputInfo( shape << j; shape << " "; } - LOG(INFO) << "\tOutput " << i << " : name=" << name << " type=" << type + LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type << " dims=" << shape.str(); - (*out_names)[i] = name; + node_names_.push_back(std::move(name)); + (*out_names)[i] = node_names_.back().get(); } } @@ -108,24 +110,24 @@ void OnnxAsrModel::Read(const std::string& model_dir) { Ort::AllocatorWithDefaultOptions allocator; encoder_output_size_ = - atoi(model_metadata.LookupCustomMetadataMap("output_size", allocator)); + atoi(model_metadata.LookupCustomMetadataMapAllocated("output_size", allocator).get()); num_blocks_ = - atoi(model_metadata.LookupCustomMetadataMap("num_blocks", allocator)); - head_ = atoi(model_metadata.LookupCustomMetadataMap("head", allocator)); + atoi(model_metadata.LookupCustomMetadataMapAllocated("num_blocks", allocator).get()); + head_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("head", allocator).get()); cnn_module_kernel_ = atoi( - model_metadata.LookupCustomMetadataMap("cnn_module_kernel", allocator)); + model_metadata.LookupCustomMetadataMapAllocated("cnn_module_kernel", allocator).get()); subsampling_rate_ = atoi( - model_metadata.LookupCustomMetadataMap("subsampling_rate", allocator)); + model_metadata.LookupCustomMetadataMapAllocated("subsampling_rate", allocator).get()); right_context_ = - atoi(model_metadata.LookupCustomMetadataMap("right_context", allocator)); - sos_ = atoi(model_metadata.LookupCustomMetadataMap("sos_symbol", allocator)); - eos_ = atoi(model_metadata.LookupCustomMetadataMap("eos_symbol", allocator)); - is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMap( - "is_bidirectional_decoder", allocator)); + atoi(model_metadata.LookupCustomMetadataMapAllocated("right_context", allocator).get()); + sos_ = 
atoi(model_metadata.LookupCustomMetadataMapAllocated("sos_symbol", allocator).get());
+  eos_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("eos_symbol", allocator).get());
+  is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMapAllocated(
+      "is_bidirectional_decoder", allocator).get());
   chunk_size_ =
-      atoi(model_metadata.LookupCustomMetadataMap("chunk_size", allocator));
+      atoi(model_metadata.LookupCustomMetadataMapAllocated("chunk_size", allocator).get());
   num_left_chunks_ =
-      atoi(model_metadata.LookupCustomMetadataMap("left_chunks", allocator));
+      atoi(model_metadata.LookupCustomMetadataMapAllocated("left_chunks", allocator).get());

   LOG(INFO) << "Onnx Model Info:";
   LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
diff --git a/runtime/core/decoder/onnx_asr_model.h b/runtime/core/decoder/onnx_asr_model.h
index 906bd0d68..f904df5ba 100644
--- a/runtime/core/decoder/onnx_asr_model.h
+++ b/runtime/core/decoder/onnx_asr_model.h
@@ -72,6 +72,7 @@ class OnnxAsrModel : public AsrModel {
   std::shared_ptr<Ort::Session> ctc_session_ = nullptr;

   // node names
+  std::vector<Ort::AllocatedStringPtr> node_names_;
   std::vector<const char*> encoder_in_names_, encoder_out_names_;
   std::vector<const char*> ctc_in_names_, ctc_out_names_;
   std::vector<const char*> rescore_in_names_, rescore_out_names_;

From 4ae1c65f0730dc43fc557df19ea8729d253a1266 Mon Sep 17 00:00:00 2001
From: veelion
Date: Thu, 3 Nov 2022 16:39:14 +0800
Subject: [PATCH 51/62] add description of 'run_batch' mode

---
 runtime/libtorch/README.md    | 30 ++++++++++++++++++++++++++++++
 runtime/libtorch/README_CN.md | 29 +++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+)

diff --git a/runtime/libtorch/README.md b/runtime/libtorch/README.md
index bccfcc4fe..a928000d8 100644
--- a/runtime/libtorch/README.md
+++ b/runtime/libtorch/README.md
@@ -131,6 +131,36 @@ Here is a demo for command line based websocket server/client interaction.

 ![Runtime server demo](../../../docs/images/runtime_server.gif)

+#### run_batch (offline) mode on GPU
+
+When the websocket server is started with the option `--run_batch`, it works in `run_batch` mode, which accepts a batch of wav data (batch_size >= 1) per request. Both encoding and decoding then take advantage of GPU batch processing to improve speed.
+
+This mode is implemented for both libtorch and onnxruntime, but libtorch performs better because of a GPU memory release issue in onnxruntime.
+
+Test results:
+
+* hardware-1:
+  Platinum 8358P CPU @ 2.60GHz 15 cores + 80G memory, A5000 * 1 + 24G memory
+
+* hardware-2:
+  Platinum 8369B CPU @ 2.90GHz 32 cores + 120GB memory, A100-SXM4-80GB * 1 + 80GB memory
+
+* data:
+  3000 wavs with different durations in range [0.6, 15] seconds.
+
+| hardware | websocket_server | concurrency | batch_size | RTF | CER |
+| --- | --- | --- | --- | --- | --- |
+| hardware-1 | libtorch(CPU) | 30 | - | 0.01666 | 8.90 |
+| hardware-1 | libtorch(GPU) | 10 | - | 0.00831 | 8.90 |
+| hardware-1 | libtorch(GPU+batch) | 20 | 8 | 0.00339 | 9.61 |
+| hardware-2 | libtorch(CPU) | 48 | - | 0.00753 | 8.90 |
+| hardware-2 | libtorch(GPU) | 48 | - | 0.00234 | 8.90 |
+| hardware-2 | libtorch(GPU+batch) | 48 | 8 | 0.00110 | 9.61 |
+
+On the same CPU, the GPU is 2~3 times faster than the CPU, and run_batch mode is a further 2x faster than the non-batch GPU mode.
+
+
+
 ### gRPC

 Why grpc? You may find your answer in https://grpc.io/.
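A note on the RTF column in the tables above: real-time factor is total decoding wall time divided by total audio duration, so an RTF of 0.00110 means one second of audio is decoded in roughly 1.1 ms. The bookkeeping is the same as what decoder_main_batch.cc accumulates; the sketch below is only illustrative, and the struct and names are not the exact globals used in that binary:

    #include <chrono>

    // Illustrative RTF accounting: sum audio duration and decode time over
    // all requests, then divide. Durations are in milliseconds.
    struct RtfMeter {
      double total_audio_ms = 0.0;
      double total_decode_ms = 0.0;

      template <typename DecodeFn>
      void Measure(double wav_duration_ms, DecodeFn&& decode) {
        auto begin = std::chrono::steady_clock::now();
        decode();  // e.g. decoder->Decode(batch_wav_data)
        auto end = std::chrono::steady_clock::now();
        total_decode_ms +=
            std::chrono::duration<double, std::milli>(end - begin).count();
        total_audio_ms += wav_duration_ms;
      }

      // RTF < 1 is faster than real time; 0.00110 is roughly 900x real time.
      double Rtf() const { return total_decode_ms / total_audio_ms; }
    };

Summing before dividing weights every second of audio equally, so a few long wavs cannot be hidden by many short ones.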
diff --git a/runtime/libtorch/README_CN.md b/runtime/libtorch/README_CN.md
index ee74968bd..526ed177d 100644
--- a/runtime/libtorch/README_CN.md
+++ b/runtime/libtorch/README_CN.md
@@ -110,6 +110,35 @@ model_dir=./20210602_unified_transformer_server

 上述服务启动后,会监听 10086 端口。若想使用其他端口,请修改 `--port` 对应的参数.

+#### run_batch (非流式) 模式运行在GPU上
+
+启动 Websocket server 时添加 `--run_batch`,即可开启 `run_batch` 模式,它的输入是一批wav数据(batch_size >= 1),模型在编码和解码阶段都可以利用GPU的批处理能力,从而提高推理速度。
+
+该模式在 libtorch 和 onnxruntime 库上都已经实现,但是 libtorch 的表现更好(更大的并发性能),因为 onnxruntime 目前没有办法清除显存缓存而导致并发较大时显存不足。
+
+测试结果:
+
+* hardware-1:
+  Platinum 8358P CPU @ 2.60GHz 15 cores + 80G memory, A5000 * 1 + 24G memory
+
+* hardware-2:
+  Platinum 8369B CPU @ 2.90GHz 32 cores + 120GB memory, A100-SXM4-80GB * 1 + 80GB memory
+
+* data:
+  3000 wavs with different durations in range [0.6, 15] seconds.
+
+| hardware | websocket_server | concurrency | batch_size | RTF | CER |
+| --- | --- | --- | --- | --- | --- |
+| hardware-1 | libtorch(CPU) | 30 | 1 | 0.01666 | 8.90 |
+| hardware-1 | libtorch(GPU) | 10 | 1 | 0.00831 | 8.90 |
+| hardware-1 | libtorch(GPU+batch) | 20 | 8 | 0.00339 | 9.61 |
+| hardware-2 | libtorch(CPU) | 48 | 1 | 0.00753 | 8.90 |
+| hardware-2 | libtorch(GPU) | 48 | 1 | 0.00234 | 8.90 |
+| hardware-2 | libtorch(GPU+batch) | 48 | 8 | 0.00110 | 9.61 |
+
+可以看出,同样的CPU下,GPU(batch_size == 1) 是 CPU 速度的 2-3 倍, 而 run_batch 速度又是 GPU(batch_size==1) 的 2.x 倍。
+

 ### websocket 识别客户端

 客户端按 websocket 协议去请求服务,可以用不同语言来实现客户端。我们提供了两种客户端,一种是基于 C++ 的命令行工具。一种是基于网页形式的可视化客户端。

From 068e4a77657e2a0d8f05b54e208ac605a3c27025 Mon Sep 17 00:00:00 2001
From: veelion
Date: Thu, 3 Nov 2022 16:52:35 +0800
Subject: [PATCH 52/62] fix batch_size

---
 runtime/libtorch/README.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/runtime/libtorch/README.md b/runtime/libtorch/README.md
index a928000d8..507c18df2 100644
--- a/runtime/libtorch/README.md
+++ b/runtime/libtorch/README.md
@@ -150,11 +150,11 @@ Test results:

 | hardware | websocket_server | concurrency | batch_size | RTF | CER |
 | --- | --- | --- | --- | --- | --- |
-| hardware-1 | libtorch(CPU) | 30 | - | 0.01666 | 8.90 |
-| hardware-1 | libtorch(GPU) | 10 | - | 0.00831 | 8.90 |
+| hardware-1 | libtorch(CPU) | 30 | 1 | 0.01666 | 8.90 |
+| hardware-1 | libtorch(GPU) | 10 | 1 | 0.00831 | 8.90 |
 | hardware-1 | libtorch(GPU+batch) | 20 | 8 | 0.00339 | 9.61 |
-| hardware-2 | libtorch(CPU) | 48 | - | 0.00753 | 8.90 |
-| hardware-2 | libtorch(GPU) | 48 | - | 0.00234 | 8.90 |
+| hardware-2 | libtorch(CPU) | 48 | 1 | 0.00753 | 8.90 |
+| hardware-2 | libtorch(GPU) | 48 | 1 | 0.00234 | 8.90 |
 | hardware-2 | libtorch(GPU+batch) | 48 | 8 | 0.00110 | 9.61 |

 On the same CPU, the GPU is 2~3 times faster than the CPU, and run_batch mode is a further 2x faster than the non-batch GPU mode.

From aa1ac47ae8ef72790cbc764428195b7675b1118e Mon Sep 17 00:00:00 2001
From: veelion
Date: Thu, 3 Nov 2022 16:55:17 +0800
Subject: [PATCH 53/62] note the slightly higher CER of run_batch mode

---
 runtime/libtorch/README.md    | 2 +-
 runtime/libtorch/README_CN.md | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/runtime/libtorch/README.md b/runtime/libtorch/README.md
index 507c18df2..ee01e25f4 100644
--- a/runtime/libtorch/README.md
+++ b/runtime/libtorch/README.md
@@ -157,7 +157,7 @@ Test results:
 | hardware-2 | libtorch(GPU) | 48 | 1 | 0.00234 | 8.90 |
 | hardware-2 | libtorch(GPU+batch) | 48 | 8 | 0.00110 | 9.61 |

-On the same CPU, the GPU is 2~3 times faster than the CPU, and run_batch mode is a further 2x faster than the non-batch GPU mode.
+On the same CPU, the GPU is 2~3 times faster than the CPU, and run_batch mode is a further 2x faster than the non-batch GPU mode, but the CER is slightly higher.

diff --git a/runtime/libtorch/README_CN.md b/runtime/libtorch/README_CN.md
index 526ed177d..91323c980 100644
--- a/runtime/libtorch/README_CN.md
+++ b/runtime/libtorch/README_CN.md
@@ -136,7 +136,7 @@ model_dir=./20210602_unified_transformer_server
 | hardware-2 | libtorch(GPU) | 48 | 1 | 0.00234 | 8.90 |
 | hardware-2 | libtorch(GPU+batch) | 48 | 8 | 0.00110 | 9.61 |

-可以看出,同样的CPU下,GPU(batch_size == 1) 是 CPU 速度的 2-3 倍, 而 run_batch 速度又是 GPU(batch_size==1) 的 2.x 倍。
+可以看出,同样的CPU下,GPU(batch_size == 1) 是 CPU 速度的 2-3 倍, 而 run_batch 速度又是 GPU(batch_size==1) 的 2.x 倍,但是CER有所提高。

 ### websocket 识别客户端

From ba93bd9ab752b3debe7bc7fd31ac698b4103d071 Mon Sep 17 00:00:00 2001
From: veelion
Date: Thu, 3 Nov 2022 17:14:27 +0800
Subject: [PATCH 54/62] remove trailing whitespace

---
 runtime/libtorch/README.md    | 2 +-
 runtime/libtorch/README_CN.md | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/runtime/libtorch/README.md b/runtime/libtorch/README.md
index ee01e25f4..0e97cd92e 100644
--- a/runtime/libtorch/README.md
+++ b/runtime/libtorch/README.md
@@ -149,7 +149,7 @@ Test results:
   3000 wavs with different durations in range [0.6, 15] seconds.
| hardware | websocket_server | concurrency | batch_size | RTF | CER | -| --- | --- | --- | --- | --- | --- | +| --- | --- | --- | --- | --- | --- | | hardware-1 | libtorch(CPU) | 30 | 1 | 0.01666 | 8.90 | | hardware-1 | libtorch(GPU) | 10 | 1 | 0.00831 | 8.90 | | hardware-1 | libtorch(GPU+batch) | 20 | 8 | 0.00339 | 9.61 | From bbc7c1580e11088b131491d00b490a2a52defc57 Mon Sep 17 00:00:00 2001 From: veelion Date: Thu, 3 Nov 2022 17:17:34 +0800 Subject: [PATCH 55/62] fix flake8 error --- runtime/binding/python/py/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/binding/python/py/__init__.py b/runtime/binding/python/py/__init__.py index 2886f52aa..dcb72e016 100644 --- a/runtime/binding/python/py/__init__.py +++ b/runtime/binding/python/py/__init__.py @@ -1,3 +1,3 @@ from .decoder import Decoder # noqa -from .batch_decoder import BatchDecoder +from .batch_decoder import BatchDecoder # noqa from _wenet import wenet_set_log_level as set_log_level # noqa From fb9e43688590d901382b50224a2f07e5501d2712 Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 4 Nov 2022 11:10:26 +0800 Subject: [PATCH 56/62] fix cpplint error --- runtime/core/api/batch_recognizer.h | 11 +- runtime/core/bin/api_batch_main.cc | 4 +- runtime/core/bin/decoder_main_batch.cc | 10 +- runtime/core/decoder/batch_asr_decoder.cc | 96 +- runtime/core/decoder/batch_asr_decoder.h | 15 +- runtime/core/decoder/batch_asr_model.cc | 24 - runtime/core/decoder/batch_asr_model.h | 12 +- runtime/core/decoder/batch_onnx_asr_model.cc | 141 +- runtime/core/decoder/batch_onnx_asr_model.h | 15 +- runtime/core/decoder/batch_torch_asr_model.cc | 17 +- runtime/core/decoder/batch_torch_asr_model.h | 4 +- runtime/core/decoder/onnx_asr_model.cc | 52 +- runtime/core/decoder/params.h | 6 +- runtime/core/utils/Yaml.cpp | 4582 ++++++++--------- runtime/core/utils/Yaml.hpp | 1165 ++--- .../core/websocket/batch_connection_handler.h | 1 + runtime/core/websocket/websocket_server.cc | 2 +- runtime/core/websocket/websocket_server.h | 2 +- 18 files changed, 2798 insertions(+), 3361 deletions(-) delete mode 100644 runtime/core/decoder/batch_asr_model.cc diff --git a/runtime/core/api/batch_recognizer.h b/runtime/core/api/batch_recognizer.h index 86cca4785..02418a19d 100644 --- a/runtime/core/api/batch_recognizer.h +++ b/runtime/core/api/batch_recognizer.h @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. 
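+// cpplint's build/header_guard rule requires every header to be wrapped in
+// a guard macro derived from the file path (api/batch_recognizer.h becomes
+// API_BATCH_RECOGNIZER_H_), closed by a matching '#endif  // ...' comment;
+// that is what the guard lines added below and at the end of this file do.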
+#ifndef API_BATCH_RECOGNIZER_H_ +#define API_BATCH_RECOGNIZER_H_ #include #include #include +#include #include "decoder/asr_decoder.h" #include "decoder/batch_asr_decoder.h" @@ -27,7 +30,7 @@ class BatchRecognizer { public: - explicit BatchRecognizer(const std::string& model_dir, int num_threads=1) { + explicit BatchRecognizer(const std::string& model_dir, int num_threads = 1) { // FeaturePipeline init feature_config_ = std::make_shared(80, 16000); // Resource init @@ -84,8 +87,9 @@ class BatchRecognizer { resource_->post_processor = std::make_shared(*post_process_opts_); // Init decoder - decoder_ = std::make_shared(feature_config_, resource_, - *decode_options_); + decoder_ = std::make_shared( + feature_config_, resource_, + *decode_options_); } std::string Decode(const std::vector& wavs) { @@ -141,3 +145,4 @@ class BatchRecognizer { std::string language_ = "chs"; }; +#endif // API_BATCH_RECOGNIZER_H_ diff --git a/runtime/core/bin/api_batch_main.cc b/runtime/core/bin/api_batch_main.cc index faa2b133e..e80321c21 100644 --- a/runtime/core/bin/api_batch_main.cc +++ b/runtime/core/bin/api_batch_main.cc @@ -34,7 +34,9 @@ int main(int argc, char* argv[]) { if (FLAGS_enable_timestamp) br.set_enable_timestamp(true); wenet::WavReader wav_reader(FLAGS_wav_path); std::vector data; - data.insert(data.end(), wav_reader.data(), wav_reader.data() + wav_reader.num_samples()); + data.insert( + data.end(), wav_reader.data(), + wav_reader.data() + wav_reader.num_samples()); std::vector> wavs; for (size_t i = 0; i < FLAGS_batch_size - 1; i++) { wavs.push_back(data); diff --git a/runtime/core/bin/decoder_main_batch.cc b/runtime/core/bin/decoder_main_batch.cc index f9a2efb17..3420baa26 100644 --- a/runtime/core/bin/decoder_main_batch.cc +++ b/runtime/core/bin/decoder_main_batch.cc @@ -40,16 +40,18 @@ void decode(const std::string& wav) { wenet::WavReader wav_reader(wav); std::vector wav_data; int num_samples = wav_reader.num_samples(); - wav_data.insert(wav_data.end(), wav_reader.data(), wav_reader.data() + num_samples); + wav_data.insert( + wav_data.end(), wav_reader.data(), wav_reader.data() + num_samples); std::vector> batch_wav_data; - int wav_dur = static_cast(static_cast(num_samples) / wav_reader.sample_rate() * 1000); + int wav_dur = static_cast( + static_cast(num_samples) / wav_reader.sample_rate() * 1000); for (int i = 0; i < FLAGS_batch_size; ++i) { batch_wav_data.push_back(wav_data); g_total_waves_dur += wav_dur; } - std::unique_ptr decoder = - std::make_unique(g_feature_config, g_decode_resource, *g_decode_config); + auto decoder = std::make_unique( + g_feature_config, g_decode_resource, *g_decode_config); wenet::Timer timer; decoder->Decode(batch_wav_data); int decode_time = timer.Elapsed(); diff --git a/runtime/core/decoder/batch_asr_decoder.cc b/runtime/core/decoder/batch_asr_decoder.cc index 369c6fd01..d2a6137a7 100644 --- a/runtime/core/decoder/batch_asr_decoder.cc +++ b/runtime/core/decoder/batch_asr_decoder.cc @@ -33,7 +33,8 @@ BatchAsrDecoder::BatchAsrDecoder(std::shared_ptr config, const DecodeOptions& opts) : feature_config_(config), beam_size_(opts.ctc_prefix_search_opts.first_beam_size), - fbank_(config->num_bins, config->sample_rate, config->frame_length, config->frame_shift), + fbank_(config->num_bins, config->sample_rate, + config->frame_length, config->frame_shift), model_(resource->batch_model->Copy()), post_processor_(resource->post_processor), symbol_table_(resource->symbol_table), @@ -77,8 +78,9 @@ void BatchAsrDecoder::SearchWorker( searcher->Search(topk_scores, topk_indexs); 
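  // Each worker thread owns its searcher instance; FinalizeSearch()
  // flushes the remaining prefix-beam state so that the Outputs()/Inputs()
  // consumed below cover the whole utterance. Only the shared batch_*
  // result vectors need the mutex_ taken further down.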
searcher->FinalizeSearch(); std::vector result; - UpdateResult(searcher.get(), result); - VLOG(1) << "\tctc search i==" << index << " takes " << ctc_timer.Elapsed() << " ms"; + UpdateResult(searcher.get(), &result); + VLOG(1) << "\tctc search i==" << index + << " takes " << ctc_timer.Elapsed() << " ms"; std::lock_guard lock(mutex_); batch_pair_result_.emplace_back(std::make_pair(index, std::move(result))); const auto& hypotheses = searcher->Inputs(); @@ -103,7 +105,8 @@ void BatchAsrDecoder::FbankWorker(const std::vector& wav, int index) { std::lock_guard lock(mutex_); batch_feats_.push_back(std::make_pair(index, std::move(feats))); batch_feats_lens_.push_back(std::make_pair(index, num_frames)); - VLOG(1) << "\tfeature comput i==" << index << ", takes " << timer.Elapsed() << " ms."; + VLOG(1) << "\tfeature comput i==" << index + << ", takes " << timer.Elapsed() << " ms."; } void BatchAsrDecoder::Decode(const std::vector>& wavs) { @@ -118,7 +121,7 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { std::thread thd(&BatchAsrDecoder::FbankWorker, this, wav, i); fbank_threads.push_back(std::move(thd)); } - for(auto& thd : fbank_threads) { + for (auto& thd : fbank_threads) { thd.join(); } std::sort(batch_feats_.begin(), batch_feats_.end()); @@ -141,7 +144,8 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { // 1.1 feature padding if (wavs.size() > 1) { timer.Reset(); - int max_len = *std::max_element(batch_feats_lens.begin(), batch_feats_lens.end()); + int max_len = *std::max_element( + batch_feats_lens.begin(), batch_feats_lens.end()); for (auto& feat : batch_feats) { if (feat.size() == max_len) continue; int pad_len = max_len - feat.size(); @@ -157,7 +161,8 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { timer.Reset(); std::vector>> batch_topk_scores; std::vector>> batch_topk_indexs; - model_->ForwardEncoder(batch_feats, batch_feats_lens, batch_topk_scores, batch_topk_indexs); + model_->ForwardEncoder( + batch_feats, batch_feats_lens, &batch_topk_scores, &batch_topk_indexs); VLOG(1) << "encoder forward takes " << timer.Elapsed() << " ms."; // 3. 
ctc search one by one of the batch @@ -172,15 +177,17 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { for (size_t i = 0; i < batch_size; i++) { const auto& topk_scores = batch_topk_scores[i]; const auto& topk_indexs = batch_topk_indexs[i]; - std::thread thd(&BatchAsrDecoder::SearchWorker, this, topk_scores, topk_indexs, i); + std::thread thd( + &BatchAsrDecoder::SearchWorker, this, topk_scores, topk_indexs, i); search_threads.push_back(std::move(thd)); } - for(auto& thd : search_threads) { + for (auto& thd : search_threads) { thd.join(); } std::sort(batch_hyps_.begin(), batch_hyps_.end()); - std::sort(batch_pair_result_.begin(), batch_pair_result_.end(), [](auto& a, auto& b) { - return a.first < b.first; }); + std::sort(batch_pair_result_.begin(), batch_pair_result_.end(), + [](auto& a, auto& b) { + return a.first < b.first; }); for (auto& pair : batch_hyps_) { batch_hyps.push_back(std::move(pair.second)); } @@ -193,7 +200,7 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { searcher_->Search(batch_topk_scores[0], batch_topk_indexs[0]); searcher_->FinalizeSearch(); std::vector result; - UpdateResult(searcher_.get(), result); + UpdateResult(searcher_.get(), &result); batch_result_.push_back(std::move(result)); const auto& hypotheses = searcher_->Inputs(); if (hypotheses.size() < beam_size_) { @@ -209,7 +216,8 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { batch_hyps.push_back(std::move(hypotheses)); } } - VLOG(1) << "ctc search batch(" << batch_size << ") takes " << timer.Elapsed() << " ms."; + VLOG(1) << "ctc search batch(" << batch_size << ") takes " + << timer.Elapsed() << " ms."; std::vector> ctc_scores(batch_size); for (int i = 0; i < batch_result_.size(); ++i) { ctc_scores[i].resize(beam_size_); @@ -220,7 +228,7 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { // 4. 
attention rescoring timer.Reset(); std::vector> attention_scores; - model_->AttentionRescoring(batch_hyps, ctc_scores, attention_scores); + model_->AttentionRescoring(batch_hyps, ctc_scores, &attention_scores); VLOG(1) << "attention rescoring takes " << timer.Elapsed() << " ms."; for (size_t i = 0; i < batch_size; i++) { std::vector& result = batch_result_[i]; @@ -231,13 +239,14 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) { } } -void BatchAsrDecoder::UpdateResult(SearchInterface* searcher, std::vector& result) { +void BatchAsrDecoder::UpdateResult(SearchInterface* searcher, + std::vector* result) { bool finish = true; const auto& hypotheses = searcher->Outputs(); const auto& inputs = searcher->Inputs(); const auto& likelihood = searcher->Likelihood(); const auto& times = searcher->Times(); - result.clear(); + result->clear(); CHECK_EQ(hypotheses.size(), likelihood.size()); for (size_t i = 0; i < hypotheses.size(); i++) { @@ -293,39 +302,40 @@ void BatchAsrDecoder::UpdateResult(SearchInterface* searcher, std::vectorProcess(path.sentence, finish); } - result.emplace_back(path); + result->emplace_back(path); } } -const std::string BatchAsrDecoder::get_batch_result(int nbest, bool enable_timestamp) { - json::JSON obj; - obj["status"] = "ok"; - obj["type"] = "final_result"; - obj["batch_size"] = batch_result_.size(); - obj["batch_result"] = json::Array(); - for (const auto& result : batch_result_) { - json::JSON batch_one; - batch_one["nbest"] = json::Array(); - for (int i = 0; i < nbest && i < result.size(); i++) { - json::JSON one; - one["sentence"] = result[i].sentence; - // one["score"] = result[i].score; - if (enable_timestamp) { - one["word_pieces"] = json::Array(); - for (const auto& word_piece : result[i].word_pieces) { - json::JSON piece; - piece["word"] = word_piece.word; - piece["start"] = word_piece.start; - piece["end"] = word_piece.end; - one["word_pieces"].append(piece); - } +const std::string BatchAsrDecoder::get_batch_result(int nbest, + bool enable_timestamp) { + json::JSON obj; + obj["status"] = "ok"; + obj["type"] = "final_result"; + obj["batch_size"] = batch_result_.size(); + obj["batch_result"] = json::Array(); + for (const auto& result : batch_result_) { + json::JSON batch_one; + batch_one["nbest"] = json::Array(); + for (int i = 0; i < nbest && i < result.size(); i++) { + json::JSON one; + one["sentence"] = result[i].sentence; + // one["score"] = result[i].score; + if (enable_timestamp) { + one["word_pieces"] = json::Array(); + for (const auto& word_piece : result[i].word_pieces) { + json::JSON piece; + piece["word"] = word_piece.word; + piece["start"] = word_piece.start; + piece["end"] = word_piece.end; + one["word_pieces"].append(piece); } - one["sentence"] = result[i].sentence; - batch_one["nbest"].append(one); } - obj["batch_result"].append(batch_one); + one["sentence"] = result[i].sentence; + batch_one["nbest"].append(one); } - return obj.dump(); + obj["batch_result"].append(batch_one); + } + return obj.dump(); } } // namespace wenet diff --git a/runtime/core/decoder/batch_asr_decoder.h b/runtime/core/decoder/batch_asr_decoder.h index e34c684f9..992b40061 100644 --- a/runtime/core/decoder/batch_asr_decoder.h +++ b/runtime/core/decoder/batch_asr_decoder.h @@ -58,25 +58,28 @@ class BatchAsrDecoder { return feature_config_->frame_shift * 1000 / feature_config_->sample_rate; } - const std::vector>& batch_result() const { return batch_result_; } + const std::vector>& batch_result() const { + return batch_result_; } const std::string get_batch_result(int 
nbest, bool enable_timestamp); private: Fbank fbank_; void FbankWorker(const std::vector& wav, int index); - std::vector> batch_feats_; // for FbankWorker - std::vector> batch_feats_lens_; // for FbankWorker + std::vector> batch_feats_; // for FbankWorker + std::vector> batch_feats_lens_; // for FbankWorker void SearchWorker( const std::vector>& topk_scores, const std::vector>& topk_indexs, int index); std::mutex mutex_; - std::vector>>> batch_hyps_; // for SearchWorker - std::vector>> batch_pair_result_; // for SearchWorker + // for SearchWorker + std::vector>>> batch_hyps_; + std::vector>> batch_pair_result_; std::vector> batch_result_; - void UpdateResult(SearchInterface* searcher, std::vector& result); + void UpdateResult(SearchInterface* searcher, + std::vector* result); std::shared_ptr feature_config_; std::shared_ptr model_; diff --git a/runtime/core/decoder/batch_asr_model.cc b/runtime/core/decoder/batch_asr_model.cc deleted file mode 100644 index 7ba35db79..000000000 --- a/runtime/core/decoder/batch_asr_model.cc +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2022 Horizon Robotics. All Rights Reserved. -// Author: binbin.zhang@horizon.ai (Binbin Zhang) -// Copyright (c) 2022 SoundDataConverge Co.LTD (Weiliang Chong) - -#include "decoder/batch_asr_model.h" - -#include -#include - -namespace wenet { - -void BatchAsrModel::ForwardEncoder( - const batch_feature_t& batch_feats, - const std::vector& batch_feats_lens, - std::vector>>& batch_topk_scores, - std::vector>>& batch_topk_indexs) { - this->ForwardEncoderFunc( - batch_feats, - batch_feats_lens, - batch_topk_scores, - batch_topk_indexs); - } - -} // namespace wenet diff --git a/runtime/core/decoder/batch_asr_model.h b/runtime/core/decoder/batch_asr_model.h index 148b0e3cc..53f686f3a 100644 --- a/runtime/core/decoder/batch_asr_model.h +++ b/runtime/core/decoder/batch_asr_model.h @@ -21,7 +21,6 @@ using ctc_log_prob_t = std::vector>; using batch_ctc_log_prob_t = std::vector; class BatchAsrModel { - public: virtual int right_context() const { return right_context_; } virtual int subsampling_rate() const { return subsampling_rate_; } @@ -34,12 +33,13 @@ class BatchAsrModel { virtual void ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, - std::vector>>& batch_topk_scores, - std::vector>>& batch_topk_indexs) = 0; + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) = 0; - virtual void AttentionRescoring(const std::vector>>& batch_hyps, - const std::vector>& ctc_scores, - std::vector>& attention_scores) = 0; + virtual void AttentionRescoring( + const std::vector>>& batch_hyps, + const std::vector>& ctc_scores, + std::vector>* attention_scores) = 0; virtual std::shared_ptr Copy() const = 0; diff --git a/runtime/core/decoder/batch_onnx_asr_model.cc b/runtime/core/decoder/batch_onnx_asr_model.cc index 49ebd0106..89cd89a9d 100644 --- a/runtime/core/decoder/batch_onnx_asr_model.cc +++ b/runtime/core/decoder/batch_onnx_asr_model.cc @@ -19,17 +19,15 @@ #include "decoder/batch_onnx_asr_model.h" +#include #include #include #include -#include #include "glog/logging.h" #include "utils/string.h" #include "utils/Yaml.hpp" #include "utils/timer.h" -#include "onnxruntime_run_options_config_keys.h" -#include "onnxruntime_session_options_config_keys.h" namespace wenet { @@ -61,8 +59,8 @@ void BatchOnnxAsrModel::GetInputOutputInfo( shape << j; shape << " "; } - LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type - << " dims=" << shape.str(); + LOG(INFO) << "\tInput " << i << 
" : name=" << name.get() + << " type=" << type << " dims=" << shape.str(); node_names_.push_back(std::move(name)); (*in_names)[i] = node_names_.back().get(); } @@ -80,14 +78,15 @@ void BatchOnnxAsrModel::GetInputOutputInfo( shape << j; shape << " "; } - LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type - << " dims=" << shape.str(); + LOG(INFO) << "\tOutput " << i << " : name=" << name.get() + << " type=" << type << " dims=" << shape.str(); node_names_.push_back(std::move(name)); (*out_names)[i] = node_names_.back().get(); } } -void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16, int gpu_id) { +void BatchOnnxAsrModel::Read(const std::string& model_dir, + bool is_fp16, int gpu_id) { is_fp16_ = is_fp16; VLOG(1) << "is_fp16_ " << is_fp16_; std::vector providers = Ort::GetAvailableProviders(); @@ -110,34 +109,40 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16, int gpu rescore_onnx_path = model_dir + "/decoder_fp16.onnx"; } - // Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options_, 0)); - // release GPU memory: https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 + // release GPU memory: + // https://github.com/microsoft/onnxruntime/issues/9509#issuecomment-951546580 // 1. Not allocate weights memory through the arena - session_options_.AddConfigEntry(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "1"); - // 2. Configure the arena to have high enough initial chunk to support most Run() calls. See "initial_chunk_size_bytes" - const char* keys[] = {"max_mem", "arena_extend_strategy", "initial_chunk_size_bytes", "max_dead_bytes_per_chunk", "initial_growth_chunk_size_bytes"}; - const size_t values[] = {0 /*let ort pick default max memory*/, 0, 1024, 0, 256}; + session_options_.AddConfigEntry( + kOrtSessionOptionsUseDeviceAllocatorForInitializers, "1"); + // 2. Configure the arena to have high enough initial chunk + // to support most Run() calls. See "initial_chunk_size_bytes" + const char* keys[] = { + "max_mem", "arena_extend_strategy", "initial_chunk_size_bytes", + "max_dead_bytes_per_chunk", "initial_growth_chunk_size_bytes"}; + const size_t values[] = {0, 0, 1024, 0, 256}; OrtArenaCfg* arena_cfg = nullptr; const auto& api = Ort::GetApi(); auto zz = api.CreateArenaCfgV2(keys, values, 5, &arena_cfg); - std::unique_ptr rel_arena_cfg(arena_cfg, api.ReleaseArenaCfg); + std::unique_ptr rel_arena_cfg( + arena_cfg, api.ReleaseArenaCfg); OrtCUDAProviderOptions cuda_options{}; cuda_options.device_id = 0; - cuda_options.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive; - //cuda_options.gpu_mem_limit = 16 * 1024 * 1024 * 1024ul; + cuda_options.cudnn_conv_algo_search = + OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive; + // cuda_options.gpu_mem_limit = 16 * 1024 * 1024 * 1024ul; cuda_options.arena_extend_strategy = 1; cuda_options.do_copy_in_default_stream = true; cuda_options.has_user_compute_stream = 0; cuda_options.user_compute_stream = nullptr; - // TODO: arena_cfg didn't work, it blocked when session.Run() + // TODO(veelion): arena_cfg didn't work, it blocked when session.Run() // Just comment this out until find a work way. // cuda_options.default_memory_arena_cfg = arena_cfg; session_options_.AppendExecutionProvider_CUDA(cuda_options); - /* TODO: In the future use OrtCUDAProviderOptionsV2 until it support ArenaCfg + /* TODO(veelion): use OrtCUDAProviderOptionsV2 until it support ArenaCfg // 1. 
Load sessions // config for CUDA std::string device_id = std::to_string(gpu_id); @@ -164,8 +169,10 @@ void BatchOnnxAsrModel::Read(const std::string& model_dir, bool is_fp16, int gpu const auto& api = Ort::GetApi(); OrtCUDAProviderOptionsV2* cuda_options = nullptr; Ort::ThrowOnError(api.CreateCUDAProviderOptions(&cuda_options)); - Ort::ThrowOnError(api.UpdateCUDAProviderOptions(cuda_options, keys2.data(), values2.data(), keys2.size())); - Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2(session_options_, cuda_options)); + Ort::ThrowOnError(api.UpdateCUDAProviderOptions( + cuda_options, keys2.data(), values2.data(), keys2.size())); + Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_CUDA_V2( + session_options_, cuda_options)); api.ReleaseCUDAProviderOptions(cuda_options); */ @@ -228,8 +235,8 @@ std::shared_ptr BatchOnnxAsrModel::Copy() const { void BatchOnnxAsrModel::ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, - std::vector>>& batch_topk_scores, - std::vector>>& batch_topk_indexs) { + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) { Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); // 1. Prepare onnx required data @@ -241,7 +248,7 @@ void BatchOnnxAsrModel::ForwardEncoder( Ort::Value feats_ort{nullptr}; // https://github.com/microsoft/onnxruntime/issues/9629#issuecomment-963828881 // Ort::Value::CreateTensor does NOT copy the data - std::vector feats_fp16; // for holding feats of fp16 + std::vector feats_fp16; // for holding feats of fp16 std::vector feats_fp32; // for holding feats of float // speech @@ -267,7 +274,8 @@ void BatchOnnxAsrModel::ForwardEncoder( } else { for (size_t i = 0; i < batch_size; ++i) { for (size_t j = 0; j < num_frames; ++j) { - feats_fp32.insert(feats_fp32.end(), batch_feats[i][j].begin(), batch_feats[i][j].end()); + feats_fp32.insert(feats_fp32.end(), batch_feats[i][j].begin(), + batch_feats[i][j].end()); } } feats_ort = std::move(Ort::Value::CreateTensor( @@ -291,8 +299,8 @@ void BatchOnnxAsrModel::ForwardEncoder( } timer.Reset(); - //Ort::RunOptions ro; - //ro.AddConfigEntry(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "gpu:0"); + // Ort::RunOptions ro; + // ro.AddConfigEntry(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "gpu:0"); std::vector ort_outputs = encoder_session_->Run( Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(), inputs.size(), encoder_out_names_.data(), encoder_out_names_.size()); @@ -303,7 +311,7 @@ void BatchOnnxAsrModel::ForwardEncoder( int num_outputs = out_shape[1]; int output_dim = out_shape[2]; float* topk_scores_ptr = nullptr; - std::vector topk_scores_data; // for holding topk_scores converted from fp16 + std::vector topk_scores_data; // for holding topk_scores in fp16 if (is_fp16_) { timer.Reset(); auto probs = ort_outputs[3].GetTensorMutableData(); @@ -313,22 +321,23 @@ void BatchOnnxAsrModel::ForwardEncoder( topk_scores_data[i] = _cvtsh_ss(probs[i]); } topk_scores_ptr = topk_scores_data.data(); - VLOG(1) << "topk_scores from GPU-fp16 to float takes " << timer.Elapsed() << " ms. data lenght " << length; + VLOG(1) << "topk_scores from GPU-fp16 to float takes " << timer.Elapsed() + << " ms. 
data length " << length;
   } else {
     topk_scores_ptr = ort_outputs[3].GetTensorMutableData<float>();
   }

-  batch_topk_scores.resize(batch_size);
+  batch_topk_scores->resize(batch_size);
   for (size_t i = 0; i < batch_size; ++i) {
-    batch_topk_scores[i].resize(num_outputs);
+    (*batch_topk_scores)[i].resize(num_outputs);
     for (size_t j = 0; j < num_outputs; j++) {
-      batch_topk_scores[i][j].resize(output_dim);
+      (*batch_topk_scores)[i][j].resize(output_dim);
       float* p = topk_scores_ptr + (i * num_outputs + j) * output_dim;
-      memcpy(batch_topk_scores[i][j].data(), p, sizeof(float) * output_dim);
+      memcpy((*batch_topk_scores)[i][j].data(), p, sizeof(float) * output_dim);
     }
   }

   // get batch_topk_indexs
-  std::vector<int32_t> topk_indexs_data;  // for holding topk_indexs converted from fp16
+  std::vector<int32_t> topk_indexs_data;  // for holding topk_indexs from fp16
   timer.Reset();
   auto probs = ort_outputs[4].GetTensorMutableData<int64_t>();
   int length = out_shape[0] * out_shape[1] * out_shape[2];
   topk_indexs_data.resize(length);
   for (size_t i = 0; i < length; ++i) {
     topk_indexs_data[i] = probs[i];
   }
   int32_t* topk_indexs_ptr = topk_indexs_data.data();
-  VLOG(1) << "topk_indexs from GPU-fp16 to float takes " << timer.Elapsed() << " ms. data lenght " << length;
+  VLOG(1) << "topk_indexs from GPU-fp16 to float takes "
+          << timer.Elapsed() << " ms. data length " << length;

-  batch_topk_indexs.resize(batch_size);
+  batch_topk_indexs->resize(batch_size);
   for (size_t i = 0; i < batch_size; ++i) {
-    batch_topk_indexs[i].resize(num_outputs);
+    (*batch_topk_indexs)[i].resize(num_outputs);
     for (size_t j = 0; j < num_outputs; j++) {
-      batch_topk_indexs[i][j].resize(output_dim);
+      (*batch_topk_indexs)[i][j].resize(output_dim);
       int32_t* p = topk_indexs_ptr + (i * num_outputs + j) * output_dim;
-      memcpy(batch_topk_indexs[i][j].data(), p, sizeof(int32_t) * output_dim);
+      memcpy((*batch_topk_indexs)[i][j].data(), p,
+             sizeof(int32_t) * output_dim);
     }
   }
   // 3. cache encoder outs
   encoder_outs_ = std::move(ort_outputs[0]);
   encoder_outs_lens_ = std::move(ort_outputs[1]);
 }

 void BatchOnnxAsrModel::AttentionRescoring(
     const std::vector<std::vector<std::vector<int>>>& batch_hyps,
     const std::vector<std::vector<float>>& ctc_scores,
-    std::vector<std::vector<float>>& attention_scores) {
+    std::vector<std::vector<float>>* attention_scores) {
   Ort::MemoryInfo memory_info =
       Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
   // 1.
prepare input for onnx int batch_size = batch_hyps.size(); int beam_size = batch_hyps[0].size(); - // 1.1 generate hyps_lens_sos data for ort - std::vector hyps_lens_sos(batch_size * beam_size, 0); // (batch_size, beam_size) + // 1.1 generate hyps_lens_sos data for ort (batch_size, beam_size) + std::vector hyps_lens_sos(batch_size * beam_size, 0); int max_hyps_len = 0; for (size_t i = 0; i < batch_size; ++i) { for (size_t j = 0; j < beam_size; ++j) { @@ -375,19 +386,24 @@ void BatchOnnxAsrModel::AttentionRescoring( } // 1.2 generate hyps_pad_sos_eos, r_hyps_pad_sos_eos - std::vector hyps_pad_sos_eos(batch_size * beam_size * (max_hyps_len + 1), 0); - std::vector r_hyps_pad_sos_eos(batch_size * beam_size * (max_hyps_len + 1), 0); + std::vector hyps_pad_sos_eos( + batch_size * beam_size * (max_hyps_len + 1), 0); + std::vector r_hyps_pad_sos_eos( + batch_size * beam_size * (max_hyps_len + 1), 0); for (size_t i = 0; i < batch_size; ++i) { for (size_t j = 0; j < beam_size; ++j) { const std::vector& hyps = batch_hyps[i][j]; hyps_pad_sos_eos[i * beam_size * max_hyps_len] = sos_; size_t hyps_len = hyps.size(); for (size_t k = 0; k < hyps_len; ++k) { - hyps_pad_sos_eos[i * beam_size * max_hyps_len + j * max_hyps_len + k + 1] = hyps[k]; - r_hyps_pad_sos_eos[i * beam_size * max_hyps_len + j * max_hyps_len + k + 1] = hyps[hyps_len - 1 - k]; + size_t p = i * beam_size * max_hyps_len + j * max_hyps_len + k + 1; + hyps_pad_sos_eos[p] = hyps[k]; + r_hyps_pad_sos_eos[p] = hyps[hyps_len - 1 - k]; } - hyps_pad_sos_eos[i * beam_size * max_hyps_len + j * max_hyps_len + hyps.size() + 1] = eos_; - r_hyps_pad_sos_eos[i * beam_size * max_hyps_len + j * max_hyps_len + hyps.size() + 1] = eos_; + size_t p = i * beam_size * max_hyps_len + + j * max_hyps_len + hyps.size() + 1; + hyps_pad_sos_eos[p] = eos_; + r_hyps_pad_sos_eos[p] = eos_; } } @@ -409,7 +425,8 @@ void BatchOnnxAsrModel::AttentionRescoring( } else { ctc_fp32.resize(batch_size * beam_size); for (size_t i = 0; i < batch_size; ++i) { - memcpy(ctc_fp32.data() + i * beam_size, ctc_scores[i].data(), sizeof(float) * beam_size); + memcpy(ctc_fp32.data() + i * beam_size, + ctc_scores[i].data(), sizeof(float) * beam_size); } ctc_scores_tensor = std::move(Ort::Value::CreateTensor( memory_info, ctc_fp32.data(), ctc_fp32.size(), ctc_shape, 2)); @@ -420,11 +437,14 @@ void BatchOnnxAsrModel::AttentionRescoring( const int64_t hyps_pad_shape[] = {batch_size, beam_size, max_hyps_len}; Ort::Value hyps_lens_tensor = Ort::Value::CreateTensor( - memory_info, hyps_lens_sos.data(), hyps_lens_sos.size(), hyps_lens_shape, 2); + memory_info, hyps_lens_sos.data(), + hyps_lens_sos.size(), hyps_lens_shape, 2); Ort::Value hyps_pad_tensor = Ort::Value::CreateTensor( - memory_info, hyps_pad_sos_eos.data(), hyps_pad_sos_eos.size(), hyps_pad_shape, 3); + memory_info, hyps_pad_sos_eos.data(), + hyps_pad_sos_eos.size(), hyps_pad_shape, 3); Ort::Value r_hyps_pad_tensor = Ort::Value::CreateTensor( - memory_info, r_hyps_pad_sos_eos.data(), r_hyps_pad_sos_eos.size(), hyps_pad_shape, 3); + memory_info, r_hyps_pad_sos_eos.data(), + r_hyps_pad_sos_eos.size(), hyps_pad_shape, 3); std::vector rescore_inputs; for (auto name : rescore_in_names_) { @@ -454,25 +474,28 @@ void BatchOnnxAsrModel::AttentionRescoring( rescore_out_names_.size()); VLOG(1) << "decoder->Run() takes " << timer.Elapsed() << " ms."; - //(B, beam, T2) + // (B, beam, T2) auto scores_shape = rescore_outputs[1].GetTensorTypeAndShapeInfo().GetShape(); - attention_scores.resize(scores_shape[0]); + 
@@ -409,7 +425,8 @@ void BatchOnnxAsrModel::AttentionRescoring(
   } else {
     ctc_fp32.resize(batch_size * beam_size);
     for (size_t i = 0; i < batch_size; ++i) {
-      memcpy(ctc_fp32.data() + i * beam_size, ctc_scores[i].data(), sizeof(float) * beam_size);
+      memcpy(ctc_fp32.data() + i * beam_size,
+             ctc_scores[i].data(), sizeof(float) * beam_size);
     }
     ctc_scores_tensor = std::move(Ort::Value::CreateTensor<float>(
         memory_info, ctc_fp32.data(), ctc_fp32.size(), ctc_shape, 2));
@@ -420,11 +437,14 @@ void BatchOnnxAsrModel::AttentionRescoring(
   const int64_t hyps_pad_shape[] = {batch_size, beam_size, max_hyps_len};

   Ort::Value hyps_lens_tensor = Ort::Value::CreateTensor<int64_t>(
-      memory_info, hyps_lens_sos.data(), hyps_lens_sos.size(), hyps_lens_shape, 2);
+      memory_info, hyps_lens_sos.data(),
+      hyps_lens_sos.size(), hyps_lens_shape, 2);
   Ort::Value hyps_pad_tensor = Ort::Value::CreateTensor<int64_t>(
-      memory_info, hyps_pad_sos_eos.data(), hyps_pad_sos_eos.size(), hyps_pad_shape, 3);
+      memory_info, hyps_pad_sos_eos.data(),
+      hyps_pad_sos_eos.size(), hyps_pad_shape, 3);
   Ort::Value r_hyps_pad_tensor = Ort::Value::CreateTensor<int64_t>(
-      memory_info, r_hyps_pad_sos_eos.data(), r_hyps_pad_sos_eos.size(), hyps_pad_shape, 3);
+      memory_info, r_hyps_pad_sos_eos.data(),
+      r_hyps_pad_sos_eos.size(), hyps_pad_shape, 3);

   std::vector<Ort::Value> rescore_inputs;

   for (auto name : rescore_in_names_) {
@@ -454,25 +474,28 @@ void BatchOnnxAsrModel::AttentionRescoring(
                             rescore_out_names_.size());
   VLOG(1) << "decoder->Run() takes " << timer.Elapsed() << " ms.";

-  //(B, beam, T2)
+  // (B, beam, T2)
   auto scores_shape = rescore_outputs[1].GetTensorTypeAndShapeInfo().GetShape();
-  attention_scores.resize(scores_shape[0]);
+  attention_scores->resize(scores_shape[0]);
   if (is_fp16_) {
     Timer timer;
     int length = scores_shape[0] * scores_shape[1];
     auto outs = rescore_outputs[1].GetTensorMutableData<Ort::Float16_t>();
     for (size_t i = 0; i < scores_shape[0]; ++i) {
-      attention_scores[i].resize(scores_shape[1]);
+      (*attention_scores)[i].resize(scores_shape[1]);
       for (size_t j = 0; j < scores_shape[1]; ++j) {
-        attention_scores[i][j] = _cvtsh_ss(outs[i * scores_shape[1] + j].value);
+        (*attention_scores)[i][j] = _cvtsh_ss(
+            outs[i * scores_shape[1] + j].value);
       }
     }
-    VLOG(1) << "decoder_out from fp16 to float takes " << timer.Elapsed() << " ms. data length " << length;
+    VLOG(1) << "decoder_out from fp16 to float takes "
+            << timer.Elapsed() << " ms. data length " << length;
   } else {
     auto outs = rescore_outputs[0].GetTensorMutableData<float>();
     for (size_t i = 0; i < scores_shape[0]; ++i) {
-      attention_scores[i].resize(scores_shape[1]);
-      memcpy(attention_scores[i].data(), outs + i * scores_shape[1], sizeof(float) * scores_shape[1]);
+      (*attention_scores)[i].resize(scores_shape[1]);
+      memcpy((*attention_scores)[i].data(), outs + i * scores_shape[1],
+             sizeof(float) * scores_shape[1]);
     }
   }
 }
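The fp16 path above widens Ort::Float16_t values one by one with _cvtsh_ss, reading the raw half-precision bits from the .value field. _cvtsh_ss is the F16C intrinsic from <immintrin.h>, so this code path assumes an x86 build with F16C enabled (e.g. -mf16c); a minimal sketch of the conversion it performs:

    #include <immintrin.h>
    #include <cstdint>

    // Widen one IEEE-754 half, given as its raw 16 bits, to a float.
    float HalfBitsToFloat(uint16_t bits) {
      return _cvtsh_ss(bits);
    }
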
diff --git a/runtime/core/decoder/batch_onnx_asr_model.h b/runtime/core/decoder/batch_onnx_asr_model.h
index ffa510028..67e61d0f3 100644
--- a/runtime/core/decoder/batch_onnx_asr_model.h
+++ b/runtime/core/decoder/batch_onnx_asr_model.h
@@ -29,6 +29,8 @@
 #include "decoder/batch_asr_model.h"
 #include "utils/log.h"
 #include "utils/utils.h"
+#include "onnxruntime_run_options_config_keys.h"  // NOLINT
+#include "onnxruntime_session_options_config_keys.h"  // NOLINT

 namespace wenet {

@@ -40,10 +42,11 @@ class BatchOnnxAsrModel : public BatchAsrModel {
  public:
   BatchOnnxAsrModel() = default;
   BatchOnnxAsrModel(const BatchOnnxAsrModel& other);
-  void Read(const std::string& model_dir, bool is_fp16=false, int gpu_id=0);
-  void AttentionRescoring(const std::vector<std::vector<std::vector<int>>>& batch_hyps,
-                          const std::vector<std::vector<float>>& ctc_scores,
-                          std::vector<std::vector<float>>& attention_scores) override;
+  void Read(const std::string& model_dir, bool is_fp16 = false, int gpu_id = 0);
+  void AttentionRescoring(
+      const std::vector<std::vector<std::vector<int>>>& batch_hyps,
+      const std::vector<std::vector<float>>& ctc_scores,
+      std::vector<std::vector<float>>* attention_scores) override;
   std::shared_ptr<BatchAsrModel> Copy() const override;

   void GetInputOutputInfo(const std::shared_ptr<Ort::Session>& session,
@@ -52,8 +55,8 @@ class BatchOnnxAsrModel : public BatchAsrModel {
   void ForwardEncoder(
       const batch_feature_t& batch_feats,
       const std::vector<int>& batch_feats_lens,
-      std::vector<std::vector<std::vector<float>>>& batch_topk_scores,
-      std::vector<std::vector<std::vector<int32_t>>>& batch_topk_indexs) override;
+      std::vector<std::vector<std::vector<float>>>* batch_topk_scores,
+      std::vector<std::vector<std::vector<int32_t>>>* batch_topk_indexs) override;  // NOLINT

  private:
   int encoder_output_size_ = 0;
diff --git a/runtime/core/decoder/batch_torch_asr_model.cc b/runtime/core/decoder/batch_torch_asr_model.cc
index e829a84bb..e6f13e928 100644
--- a/runtime/core/decoder/batch_torch_asr_model.cc
+++ b/runtime/core/decoder/batch_torch_asr_model.cc
@@ -130,7 +130,7 @@ void BatchTorchAsrModel::ForwardEncoder(
       model_->get_method("batch_forward_encoder")(inputs).toTuple()->elements();
   CHECK_EQ(outputs.size(), 5);
   encoder_out_ = outputs[0].toTensor();  // (B, Tmax, dim)
-  encoder_lens_ = outputs[1].toTensor(); // (B,)
+  encoder_lens_ = outputs[1].toTensor();  // (B,)

   // Copy topk_scores
   auto topk_scores = outputs[3].toTensor().to(at::kCPU);
@@ -165,7 +165,8 @@ void BatchTorchAsrModel::AttentionRescoring(
   // Step 1: Prepare input for libtorch
   int batch_size = batch_hyps.size();
   int beam_size = batch_hyps[0].size();
-  torch::Tensor hyps_lens_sos = torch::zeros({batch_size, beam_size}, torch::kLong);
+  torch::Tensor hyps_lens_sos = torch::zeros(
+      {batch_size, beam_size}, torch::kLong);
   int max_hyps_len = 0;
   for (size_t i = 0; i < batch_size; i++) {
     for (size_t j = 0; j < beam_size; j++) {
@@ -176,8 +177,10 @@ void BatchTorchAsrModel::AttentionRescoring(
   }

   // 1.2 add sos, eos to hyps, r_hyps
-  torch::Tensor hyps_pad_sos_eos = torch::zeros({batch_size, beam_size, max_hyps_len + 1}, torch::kLong);
-  torch::Tensor r_hyps_pad_sos_eos = torch::zeros({batch_size, beam_size, max_hyps_len + 1}, torch::kLong);
+  torch::Tensor hyps_pad_sos_eos = torch::zeros(
+      {batch_size, beam_size, max_hyps_len + 1}, torch::kLong);
+  torch::Tensor r_hyps_pad_sos_eos = torch::zeros(
+      {batch_size, beam_size, max_hyps_len + 1}, torch::kLong);
   for (size_t i = 0; i < batch_size; i++) {
     for (size_t j = 0; j < beam_size; j++) {
       const std::vector<int>& hyp = batch_hyps[i][j];
@@ -192,7 +195,8 @@ void BatchTorchAsrModel::AttentionRescoring(
   }

   // 1.3 ctc_scores_data
-  torch::Tensor ctc_scores_tensor = torch::zeros({batch_size, beam_size}, torch::kFloat);
+  torch::Tensor ctc_scores_tensor = torch::zeros(
+      {batch_size, beam_size}, torch::kFloat);
   for (size_t i = 0; i < batch_size; ++i) {
     auto row = torch::from_blob(const_cast<float*>(ctc_scores[i].data()),
                                 {beam_size}, torch::kFloat).clone();
@@ -219,7 +223,8 @@ void BatchTorchAsrModel::AttentionRescoring(
   attention_scores.resize(batch_size);
   for (size_t i = 0; i < batch_size; i++) {
     attention_scores[i].resize(beam_size);
-    memcpy(attention_scores[i].data(), rescores[i].data_ptr(), sizeof(float) * beam_size);
+    memcpy(attention_scores[i].data(), rescores[i].data_ptr(),
+           sizeof(float) * beam_size);
   }
 }
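One detail worth noting in the ctc_scores loop above: torch::from_blob only wraps the caller's memory and never takes ownership, so the .clone() is what gives the row tensor its own storage before ctc_scores can go out of scope. A small sketch of the same pattern (RowTensor is hypothetical, not part of the patch):

    #include <torch/torch.h>
    #include <vector>

    torch::Tensor RowTensor(const std::vector<float>& scores) {
      auto view = torch::from_blob(const_cast<float*>(scores.data()),
                                   {static_cast<int64_t>(scores.size())},
                                   torch::kFloat);
      return view.clone();  // own the data; `view` dies with `scores`
    }
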
diff --git a/runtime/core/decoder/batch_torch_asr_model.h b/runtime/core/decoder/batch_torch_asr_model.h
index 8358903d6..80daef9b8 100644
--- a/runtime/core/decoder/batch_torch_asr_model.h
+++ b/runtime/core/decoder/batch_torch_asr_model.h
@@ -31,7 +31,7 @@
 namespace wenet {

 class BatchTorchAsrModel : public BatchAsrModel {
-public:
+ public:
   // Note: Do not call the InitEngineThreads function more than once.
   static void InitEngineThreads(int num_threads = 1);
@@ -50,7 +50,7 @@ class BatchTorchAsrModel : public BatchAsrModel {
       const batch_feature_t& batch_feats,
       const std::vector<int>& batch_feats_lens,
       std::vector<std::vector<std::vector<float>>>& batch_topk_scores,
-      std::vector<std::vector<std::vector<int32_t>>>& batch_topk_indexs) override;
+      std::vector<std::vector<std::vector<int32_t>>>& batch_topk_indexs) override;  // NOLINT

  private:
   std::shared_ptr<torch::jit::script::Module> model_ = nullptr;
diff --git a/runtime/core/decoder/onnx_asr_model.cc b/runtime/core/decoder/onnx_asr_model.cc
index f10b28a2b..7ce07e330 100644
--- a/runtime/core/decoder/onnx_asr_model.cc
+++ b/runtime/core/decoder/onnx_asr_model.cc
@@ -71,8 +71,8 @@ void OnnxAsrModel::GetInputOutputInfo(
       shape << j;
       shape << " ";
     }
-    LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type
-              << " dims=" << shape.str();
+    LOG(INFO) << "\tOutput " << i << " : name=" << name.get()
+              << " type=" << type << " dims=" << shape.str();
     node_names_.push_back(std::move(name));
     (*out_names)[i] = node_names_.back().get();
   }
@@ -109,25 +109,39 @@ void OnnxAsrModel::Read(const std::string& model_dir) {
   auto model_metadata = encoder_session_->GetModelMetadata();

   Ort::AllocatorWithDefaultOptions allocator;
-  encoder_output_size_ =
-      atoi(model_metadata.LookupCustomMetadataMapAllocated("output_size", allocator).get());
-  num_blocks_ =
-      atoi(model_metadata.LookupCustomMetadataMapAllocated("num_blocks", allocator).get());
-  head_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("head", allocator).get());
+  encoder_output_size_ = atoi(
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "output_size", allocator).get());
+  num_blocks_ = atoi(
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "num_blocks", allocator).get());
+  head_ = atoi(
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "head", allocator).get());
   cnn_module_kernel_ = atoi(
-      model_metadata.LookupCustomMetadataMapAllocated("cnn_module_kernel", allocator).get());
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "cnn_module_kernel", allocator).get());
   subsampling_rate_ = atoi(
-      model_metadata.LookupCustomMetadataMapAllocated("subsampling_rate", allocator).get());
-  right_context_ =
-      atoi(model_metadata.LookupCustomMetadataMapAllocated("right_context", allocator).get());
-  sos_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("sos_symbol", allocator).get());
-  eos_ = atoi(model_metadata.LookupCustomMetadataMapAllocated("eos_symbol", allocator).get());
-  is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMapAllocated(
-      "is_bidirectional_decoder", allocator).get());
-  chunk_size_ =
-      atoi(model_metadata.LookupCustomMetadataMapAllocated("chunk_size", allocator).get());
-  num_left_chunks_ =
-      atoi(model_metadata.LookupCustomMetadataMapAllocated("left_chunks", allocator).get());
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "subsampling_rate", allocator).get());
+  right_context_ = atoi(
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "right_context", allocator).get());
+  sos_ = atoi(
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "sos_symbol", allocator).get());
+  eos_ = atoi(
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "eos_symbol", allocator).get());
+  is_bidirectional_decoder_ = atoi(
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "is_bidirectional_decoder", allocator).get());
+  chunk_size_ = atoi(
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "chunk_size", allocator).get());
+  num_left_chunks_ = atoi(
+      model_metadata.LookupCustomMetadataMapAllocated(
+          "left_chunks", allocator).get());

   LOG(INFO) << "Onnx Model Info:";
   LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
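The Read() reflow above repeats one pattern eleven times: look up a custom-metadata key exported with the model, then atoi the allocated string it returns. If this block were being refactored rather than just re-wrapped, a helper in this spirit could fold the repetition (MetaInt is hypothetical, not part of the patch):

    #include <cstdlib>
    #include <onnxruntime_cxx_api.h>

    int MetaInt(Ort::ModelMetadata& meta, const char* key,
                Ort::AllocatorWithDefaultOptions& alloc) {
      return atoi(meta.LookupCustomMetadataMapAllocated(key, alloc).get());
    }
    // usage: head_ = MetaInt(model_metadata, "head", allocator);
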
Info:"; LOG(INFO) << "\tencoder_output_size " << encoder_output_size_; diff --git a/runtime/core/decoder/params.h b/runtime/core/decoder/params.h index 9347191ee..6094e7b17 100644 --- a/runtime/core/decoder/params.h +++ b/runtime/core/decoder/params.h @@ -136,7 +136,8 @@ std::shared_ptr InitDecodeResourceFromFlags() { if (!FLAGS_onnx_dir.empty()) { #ifdef USE_ONNX if (FLAGS_run_batch) { - LOG(INFO) << "BatchOnnxAsrModel Reading ONNX model dir: " << FLAGS_onnx_dir; + LOG(INFO) << "BatchOnnxAsrModel Reading ONNX model dir: " + << FLAGS_onnx_dir; BatchOnnxAsrModel::InitEngineThreads(FLAGS_num_threads); auto model = std::make_shared(); model->Read(FLAGS_onnx_dir, FLAGS_is_fp16, FLAGS_gpu_id); @@ -154,7 +155,8 @@ std::shared_ptr InitDecodeResourceFromFlags() { } else if (!FLAGS_model_path.empty()) { #ifdef USE_TORCH if (FLAGS_run_batch) { - LOG(INFO) << "BatchTorchAsrModel Reading torch model " << FLAGS_model_path; + LOG(INFO) << "BatchTorchAsrModel Reading torch model " + << FLAGS_model_path; BatchTorchAsrModel::InitEngineThreads(FLAGS_num_threads); auto model = std::make_shared(); model->Read(FLAGS_model_path); diff --git a/runtime/core/utils/Yaml.cpp b/runtime/core/utils/Yaml.cpp index 70adec6f3..4e5494183 100644 --- a/runtime/core/utils/Yaml.cpp +++ b/runtime/core/utils/Yaml.cpp @@ -1,2773 +1,2215 @@ +// Copyright (c) From https://github.com/jimmiebergmann/mini-yaml +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) /* -* MIT License -* -* Copyright(c) 2018 Jimmie Bergmann -* -* Permission is hereby granted, free of charge, to any person obtaining a copy -* of this software and associated documentation files(the "Software"), to deal -* in the Software without restriction, including without limitation the rights -* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -* copies of the Software, and to permit persons to whom the Software is -* furnished to do so, subject to the following conditions : -* -* The above copyright notice and this permission notice shall be included in all -* copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE -* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -* SOFTWARE. -* -*/ - -#include "Yaml.hpp" -#include -#include -#include -#include -#include -#include -#include - - -// Implementation access definitions. -#define NODE_IMP static_cast(m_pImp) -#define NODE_IMP_EXT(node) static_cast(node.m_pImp) -#define TYPE_IMP static_cast(m_pImp)->m_pImp - - -#define IT_IMP static_cast(m_pImp) - - -namespace Yaml -{ - class ReaderLine; - - // Exception message definitions. 
- static const std::string g_ErrorInvalidCharacter = "Invalid character found."; - static const std::string g_ErrorKeyMissing = "Missing key."; - static const std::string g_ErrorKeyIncorrect = "Incorrect key."; - static const std::string g_ErrorValueIncorrect = "Incorrect value."; - static const std::string g_ErrorTabInOffset = "Tab found in offset."; - static const std::string g_ErrorBlockSequenceNotAllowed = "Sequence entries are not allowed in this context."; - static const std::string g_ErrorUnexpectedDocumentEnd = "Unexpected document end."; - static const std::string g_ErrorDiffEntryNotAllowed = "Different entry is not allowed in this context."; - static const std::string g_ErrorIncorrectOffset = "Incorrect offset."; - static const std::string g_ErrorSequenceError = "Error in sequence node."; - static const std::string g_ErrorCannotOpenFile = "Cannot open file."; - static const std::string g_ErrorIndentation = "Space indentation is less than 2."; - static const std::string g_ErrorInvalidBlockScalar = "Invalid block scalar."; - static const std::string g_ErrorInvalidQuote = "Invalid quote."; - static const std::string g_EmptyString = ""; - static Yaml::Node g_NoneNode; - - // Global function definitions. Implemented at end of this source file. - static std::string ExceptionMessage(const std::string & message, ReaderLine & line); - static std::string ExceptionMessage(const std::string & message, ReaderLine & line, const size_t errorPos); - static std::string ExceptionMessage(const std::string & message, const size_t errorLine, const size_t errorPos); - static std::string ExceptionMessage(const std::string & message, const size_t errorLine, const std::string & data); - - static bool FindQuote(const std::string & input, size_t & start, size_t & end, size_t searchPos = 0); - static size_t FindNotCited(const std::string & input, char token, size_t & preQuoteCount); - static size_t FindNotCited(const std::string & input, char token); - static bool ValidateQuote(const std::string & input); - static void CopyNode(const Node & from, Node & to); - static bool ShouldBeCited(const std::string & key); - static void AddEscapeTokens(std::string & input, const std::string & tokens); - static void RemoveAllEscapeTokens(std::string & input); - - // Exception implementations - Exception::Exception(const std::string & message, const eType type) : - std::runtime_error(message), - m_Type(type) - { - } - - Exception::eType Exception::Type() const - { - return m_Type; - } - - const char * Exception::Message() const - { - return what(); - } - - InternalException::InternalException(const std::string & message) : - Exception(message, InternalError) - { - - } - - ParsingException::ParsingException(const std::string & message) : - Exception(message, ParsingError) - { - - } - - OperationException::OperationException(const std::string & message) : - Exception(message, OperationError) - { - - } - - - class TypeImp - { - - public: - - virtual ~TypeImp() - { - } - - virtual const std::string & GetData() const = 0; - virtual bool SetData(const std::string & data) = 0; - virtual size_t GetSize() const = 0; - virtual Node * GetNode(const size_t index) = 0; - virtual Node * GetNode(const std::string & key) = 0; - virtual Node * Insert(const size_t index) = 0; - virtual Node * PushFront() = 0; - virtual Node * PushBack() = 0; - virtual void Erase(const size_t index) = 0; - virtual void Erase(const std::string & key) = 0; - - }; - - class SequenceImp : public TypeImp - { - - public: - - ~SequenceImp() - { - for(auto it = 
m_Sequence.begin(); it != m_Sequence.end(); it++) - { - delete it->second; - } - } - - virtual const std::string & GetData() const - { - return g_EmptyString; - } - - virtual bool SetData(const std::string & data) - { - return false; - } - - virtual size_t GetSize() const - { - return m_Sequence.size(); - } - - virtual Node * GetNode(const size_t index) - { - auto it = m_Sequence.find(index); - if(it != m_Sequence.end()) - { - return it->second; - } - return nullptr; - } - - virtual Node * GetNode(const std::string & key) - { - return nullptr; - } - - virtual Node * Insert(const size_t index) - { - if(m_Sequence.size() == 0) - { - Node * pNode = new Node; - m_Sequence.insert({0, pNode}); - return pNode; - } - - if(index >= m_Sequence.size()) - { - auto it = m_Sequence.end(); - --it; - Node * pNode = new Node; - m_Sequence.insert({it->first, pNode}); - return pNode; - } - - auto it = m_Sequence.cbegin(); - while(it != m_Sequence.cend()) - { - m_Sequence[it->first+1] = it->second; - - if(it->first == index) - { - break; - } - } - - Node * pNode = new Node; - m_Sequence.insert({index, pNode}); - return pNode; - } - - virtual Node * PushFront() - { - for(auto it = m_Sequence.cbegin(); it != m_Sequence.cend(); it++) - { - m_Sequence[it->first+1] = it->second; - } - - Node * pNode = new Node; - m_Sequence.insert({0, pNode}); - return pNode; - } - - virtual Node * PushBack() - { - size_t index = 0; - if(m_Sequence.size()) - { - auto it = m_Sequence.end(); - --it; - index = it->first + 1; - } - - Node * pNode = new Node; - m_Sequence.insert({index, pNode}); - return pNode; - } - - virtual void Erase(const size_t index) - { - auto it = m_Sequence.find(index); - if(it == m_Sequence.end()) - { - return; - } - delete it->second; - m_Sequence.erase(index); - } - - virtual void Erase(const std::string & key) - { - } - - std::map m_Sequence; - - }; - - class MapImp : public TypeImp - { - - public: - - ~MapImp() - { - for(auto it = m_Map.begin(); it != m_Map.end(); it++) - { - delete it->second; - } - } - - virtual const std::string & GetData() const - { - return g_EmptyString; - } - - virtual bool SetData(const std::string & data) - { - return false; - } - - virtual size_t GetSize() const - { - return m_Map.size(); - } - - virtual Node * GetNode(const size_t index) - { - return nullptr; - } - - virtual Node * GetNode(const std::string & key) - { - auto it = m_Map.find(key); - if(it == m_Map.end()) - { - Node * pNode = new Node; - m_Map.insert({key, pNode}); - return pNode; - } - return it->second; - } - - virtual Node * Insert(const size_t index) - { - return nullptr; - } - - virtual Node * PushFront() - { - return nullptr; - } - - virtual Node * PushBack() - { - return nullptr; - } - - virtual void Erase(const size_t index) - { - } - - virtual void Erase(const std::string & key) - { - auto it = m_Map.find(key); - if(it == m_Map.end()) - { - return; - } - delete it->second; - m_Map.erase(key); - } - - std::map m_Map; - - }; - - class ScalarImp : public TypeImp - { - - public: - - ~ScalarImp() - { - } - - virtual const std::string & GetData() const - { - return m_Value; - } - - virtual bool SetData(const std::string & data) - { - m_Value = data; - return true; - } - - virtual size_t GetSize() const - { - return 0; - } - - virtual Node * GetNode(const size_t index) - { - return nullptr; - } - - virtual Node * GetNode(const std::string & key) - { - return nullptr; - } - - virtual Node * Insert(const size_t index) - { - return nullptr; - } - - virtual Node * PushFront() - { - return nullptr; - } - - virtual 
Node * PushBack() - { - return nullptr; - } - - virtual void Erase(const size_t index) - { - } - - virtual void Erase(const std::string & key) - { - } - - std::string m_Value; - - }; - - - // Node implementations. - class NodeImp - { - - public: - - NodeImp() : - m_Type(Node::None), - m_pImp(nullptr) - { - } - - ~NodeImp() - { - Clear(); - } - - void Clear() - { - if(m_pImp != nullptr) - { - delete m_pImp; - m_pImp = nullptr; - } - m_Type = Node::None; - } - - void InitSequence() - { - if(m_Type != Node::SequenceType || m_pImp == nullptr) - { - if(m_pImp) - { - delete m_pImp; - } - m_pImp = new SequenceImp; - m_Type = Node::SequenceType; - } - } - - void InitMap() - { - if(m_Type != Node::MapType || m_pImp == nullptr) - { - if(m_pImp) - { - delete m_pImp; - } - m_pImp = new MapImp; - m_Type = Node::MapType; - } - } - - void InitScalar() - { - if(m_Type != Node::ScalarType || m_pImp == nullptr) - { - if(m_pImp) - { - delete m_pImp; - } - m_pImp = new ScalarImp; - m_Type = Node::ScalarType; - } - - } - - Node::eType m_Type; ///< Type of node. - TypeImp * m_pImp; ///< Imp of type. - - }; - - // Iterator implementation class - class IteratorImp - { - - public: - - virtual ~IteratorImp() - { - } - - virtual Node::eType GetType() const = 0; - virtual void InitBegin(SequenceImp * pSequenceImp) = 0; - virtual void InitEnd(SequenceImp * pSequenceImp) = 0; - virtual void InitBegin(MapImp * pMapImp) = 0; - virtual void InitEnd(MapImp * pMapImp) = 0; - - }; - - class SequenceIteratorImp : public IteratorImp - { - - public: - - virtual Node::eType GetType() const - { - return Node::SequenceType; - } - - virtual void InitBegin(SequenceImp * pSequenceImp) - { - m_Iterator = pSequenceImp->m_Sequence.begin(); - } - - virtual void InitEnd(SequenceImp * pSequenceImp) - { - m_Iterator = pSequenceImp->m_Sequence.end(); - } - - virtual void InitBegin(MapImp * pMapImp) - { - } - - virtual void InitEnd(MapImp * pMapImp) - { - } - - void Copy(const SequenceIteratorImp & it) - { - m_Iterator = it.m_Iterator; - } - - std::map::iterator m_Iterator; - - }; - - class MapIteratorImp : public IteratorImp - { - - public: - - virtual Node::eType GetType() const - { - return Node::MapType; - } - - virtual void InitBegin(SequenceImp * pSequenceImp) - { - } - - virtual void InitEnd(SequenceImp * pSequenceImp) - { - } - - virtual void InitBegin(MapImp * pMapImp) - { - m_Iterator = pMapImp->m_Map.begin(); - } - - virtual void InitEnd(MapImp * pMapImp) - { - m_Iterator = pMapImp->m_Map.end(); - } - - void Copy(const MapIteratorImp & it) - { - m_Iterator = it.m_Iterator; - } - - std::map::iterator m_Iterator; - - }; - - class SequenceConstIteratorImp : public IteratorImp - { - - public: - - virtual Node::eType GetType() const - { - return Node::SequenceType; - } - - virtual void InitBegin(SequenceImp * pSequenceImp) - { - m_Iterator = pSequenceImp->m_Sequence.begin(); - } - - virtual void InitEnd(SequenceImp * pSequenceImp) - { - m_Iterator = pSequenceImp->m_Sequence.end(); - } - - virtual void InitBegin(MapImp * pMapImp) - { - } - - virtual void InitEnd(MapImp * pMapImp) - { - } - - void Copy(const SequenceConstIteratorImp & it) - { - m_Iterator = it.m_Iterator; - } - - std::map::const_iterator m_Iterator; - - }; - - class MapConstIteratorImp : public IteratorImp - { - - public: - - virtual Node::eType GetType() const - { - return Node::MapType; - } - - virtual void InitBegin(SequenceImp * pSequenceImp) - { - } - - virtual void InitEnd(SequenceImp * pSequenceImp) - { - } - - virtual void InitBegin(MapImp * pMapImp) - { - 
m_Iterator = pMapImp->m_Map.begin(); - } - - virtual void InitEnd(MapImp * pMapImp) - { - m_Iterator = pMapImp->m_Map.end(); - } - - void Copy(const MapConstIteratorImp & it) - { - m_Iterator = it.m_Iterator; - } - - std::map::const_iterator m_Iterator; - - }; - - - // Iterator class - Iterator::Iterator() : - m_Type(None), - m_pImp(nullptr) - { - } - - Iterator::~Iterator() - { - if(m_pImp) - { - switch(m_Type) - { - case SequenceType: - delete static_cast(m_pImp); - break; - case MapType: - delete static_cast(m_pImp); - break; - default: - break; - } - - } - } - - Iterator::Iterator(const Iterator & it) : - m_Type(None), - m_pImp(nullptr) - { - *this = it; - } - - Iterator & Iterator::operator = (const Iterator & it) - { - if(m_pImp) - { - switch(m_Type) - { - case SequenceType: - delete static_cast(m_pImp); - break; - case MapType: - delete static_cast(m_pImp); - break; - default: - break; - } - m_pImp = nullptr; - m_Type = None; - } - - IteratorImp * pNewImp = nullptr; - - switch(it.m_Type) - { - case SequenceType: - m_Type = SequenceType; - pNewImp = new SequenceIteratorImp; - static_cast(pNewImp)->m_Iterator = static_cast(it.m_pImp)->m_Iterator; - break; - case MapType: - m_Type = MapType; - pNewImp = new MapIteratorImp; - static_cast(pNewImp)->m_Iterator = static_cast(it.m_pImp)->m_Iterator; - break; - default: - break; - } - - m_pImp = pNewImp; - return *this; - } - - std::pair Iterator::operator *() - { - switch(m_Type) - { - case SequenceType: - return { g_EmptyString, *(static_cast(m_pImp)->m_Iterator->second)}; - break; - case MapType: - return {static_cast(m_pImp)->m_Iterator->first, - *(static_cast(m_pImp)->m_Iterator->second)}; - break; - default: - break; - } - - g_NoneNode.Clear(); - return { g_EmptyString, g_NoneNode}; - } - - Iterator & Iterator::operator ++ (int dummy) - { - switch(m_Type) - { - case SequenceType: - static_cast(m_pImp)->m_Iterator++; - break; - case MapType: - static_cast(m_pImp)->m_Iterator++; - break; - default: - break; - } - return *this; - } - - Iterator & Iterator::operator -- (int dummy) - { - switch(m_Type) - { - case SequenceType: - static_cast(m_pImp)->m_Iterator--; - break; - case MapType: - static_cast(m_pImp)->m_Iterator--; - break; - default: - break; - } - return *this; - } - - bool Iterator::operator == (const Iterator & it) - { - if(m_Type != it.m_Type) - { - return false; - } - - switch(m_Type) - { - case SequenceType: - return static_cast(m_pImp)->m_Iterator == static_cast(it.m_pImp)->m_Iterator; - break; - case MapType: - return static_cast(m_pImp)->m_Iterator == static_cast(it.m_pImp)->m_Iterator; - break; - default: - break; - } - - return false; - } - - bool Iterator::operator != (const Iterator & it) - { - return !(*this == it); - } - - - // Const Iterator class - ConstIterator::ConstIterator() : - m_Type(None), - m_pImp(nullptr) - { - } - - ConstIterator::~ConstIterator() - { - if(m_pImp) - { - switch(m_Type) - { - case SequenceType: - delete static_cast(m_pImp); - break; - case MapType: - delete static_cast(m_pImp); - break; - default: - break; - } - - } - } - - ConstIterator::ConstIterator(const ConstIterator & it) : - m_Type(None), - m_pImp(nullptr) - { - *this = it; - } - - ConstIterator & ConstIterator::operator = (const ConstIterator & it) - { - if(m_pImp) - { - switch(m_Type) - { - case SequenceType: - delete static_cast(m_pImp); - break; - case MapType: - delete static_cast(m_pImp); - break; - default: - break; - } - m_pImp = nullptr; - m_Type = None; - } - - IteratorImp * pNewImp = nullptr; - - switch(it.m_Type) - { - 
case SequenceType: - m_Type = SequenceType; - pNewImp = new SequenceConstIteratorImp; - static_cast(pNewImp)->m_Iterator = static_cast(it.m_pImp)->m_Iterator; - break; - case MapType: - m_Type = MapType; - pNewImp = new MapConstIteratorImp; - static_cast(pNewImp)->m_Iterator = static_cast(it.m_pImp)->m_Iterator; - break; - default: - break; - } - - m_pImp = pNewImp; - return *this; - } - - std::pair ConstIterator::operator *() - { - switch(m_Type) - { - case SequenceType: - return { g_EmptyString, *(static_cast(m_pImp)->m_Iterator->second)}; - break; - case MapType: - return {static_cast(m_pImp)->m_Iterator->first, - *(static_cast(m_pImp)->m_Iterator->second)}; - break; - default: - break; - } - - g_NoneNode.Clear(); - return { g_EmptyString, g_NoneNode}; - } - - ConstIterator & ConstIterator::operator ++ (int dummy) - { - switch(m_Type) - { - case SequenceType: - static_cast(m_pImp)->m_Iterator++; - break; - case MapType: - static_cast(m_pImp)->m_Iterator++; - break; - default: - break; - } - return *this; - } - - ConstIterator & ConstIterator::operator -- (int dummy) - { - switch(m_Type) - { - case SequenceType: - static_cast(m_pImp)->m_Iterator--; - break; - case MapType: - static_cast(m_pImp)->m_Iterator--; - break; - default: - break; - } - return *this; - } + * MIT License + * + * Copyright(c) 2018 Jimmie Bergmann + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files(the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions : + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ + +#include "Yaml.hpp" // NOLINT - bool ConstIterator::operator == (const ConstIterator & it) - { - if(m_Type != it.m_Type) - { - return false; - } +#include - switch(m_Type) - { - case SequenceType: - return static_cast(m_pImp)->m_Iterator == static_cast(it.m_pImp)->m_Iterator; - break; - case MapType: - return static_cast(m_pImp)->m_Iterator == static_cast(it.m_pImp)->m_Iterator; - break; - default: - break; - } +#include +#include +#include +#include +#include +#include - return false; +// Implementation access definitions. +#define NODE_IMP static_cast(m_pImp) +#define NODE_IMP_EXT(node) static_cast(node.m_pImp) +#define TYPE_IMP static_cast(m_pImp)->m_pImp + +#define IT_IMP static_cast(m_pImp) + +namespace Yaml { +class ReaderLine; + +// Exception message definitions. 
+static const char* g_ErrorInvalidCharacter = "Invalid character found."; +static const char* g_ErrorKeyMissing = "Missing key."; +static const char* g_ErrorKeyIncorrect = "Incorrect key."; +static const char* g_ErrorValueIncorrect = "Incorrect value."; +static const char* g_ErrorTabInOffset = "Tab found in offset."; +static const char* g_ErrorBlockSequenceNotAllowed = + "Sequence entries are not allowed in this context."; +static const char* g_ErrorUnexpectedDocumentEnd = + "Unexpected document end."; +static const char* g_ErrorDiffEntryNotAllowed = + "Different entry is not allowed in this context."; +static const char* g_ErrorIncorrectOffset = "Incorrect offset."; +static const char* g_ErrorSequenceError = "Error in sequence node."; +static const char* g_ErrorCannotOpenFile = "Cannot open file."; +static const char* g_ErrorIndentation = + "Space indentation is less than 2."; +static const char* g_ErrorInvalidBlockScalar = "Invalid block scalar."; +static const char* g_ErrorInvalidQuote = "Invalid quote."; +static const char* g_EmptyString = ""; +static Yaml::Node g_NoneNode; + +// Global function definitions. Implemented at end of this source file. +static std::string ExceptionMessage(const std::string &message, + ReaderLine &line); // NOLINT +static std::string ExceptionMessage( + const std::string &message, + ReaderLine &line, const size_t errorPos); // NOLINT +static std::string ExceptionMessage(const std::string &message, + const size_t errorLine, + const size_t errorPos); +static std::string ExceptionMessage(const std::string &message, + const size_t errorLine, + const std::string &data); + +static bool FindQuote( + const std::string &input, size_t &start, size_t &end, // NOLINT + size_t searchPos = 0); +static size_t FindNotCited(const std::string &input, char token, + size_t &preQuoteCount); // NOLINT +static size_t FindNotCited(const std::string &input, char token); +static bool ValidateQuote(const std::string &input); +static void CopyNode(const Node &from, Node &to); // NOLINT +static bool ShouldBeCited(const std::string &key); +static void AddEscapeTokens( + std::string &input, const std::string &tokens); // NOLINT +static void RemoveAllEscapeTokens(std::string &input); // NOLINT + +// Exception implementations +Exception::Exception(const std::string &message, const eType type) + : std::runtime_error(message), m_Type(type) {} + +Exception::eType Exception::Type() const { return m_Type; } + +const char *Exception::Message() const { return what(); } + +InternalException::InternalException(const std::string &message) + : Exception(message, InternalError) {} + +ParsingException::ParsingException(const std::string &message) + : Exception(message, ParsingError) {} + +OperationException::OperationException(const std::string &message) + : Exception(message, OperationError) {} + +class TypeImp { + public: + virtual ~TypeImp() {} + + virtual const std::string &GetData() const = 0; + virtual bool SetData(const std::string &data) = 0; + virtual size_t GetSize() const = 0; + virtual Node *GetNode(const size_t index) = 0; + virtual Node *GetNode(const std::string &key) = 0; + virtual Node *Insert(const size_t index) = 0; + virtual Node *PushFront() = 0; + virtual Node *PushBack() = 0; + virtual void Erase(const size_t index) = 0; + virtual void Erase(const std::string &key) = 0; +}; + +class SequenceImp : public TypeImp { + public: + ~SequenceImp() { + for (auto it = m_Sequence.begin(); it != m_Sequence.end(); it++) { + delete it->second; + } + } + + virtual const std::string &GetData() 
const { return g_EmptyString; } + + virtual bool SetData(const std::string &data) { return false; } + + virtual size_t GetSize() const { return m_Sequence.size(); } + + virtual Node *GetNode(const size_t index) { + auto it = m_Sequence.find(index); + if (it != m_Sequence.end()) { + return it->second; + } + return nullptr; + } + + virtual Node *GetNode(const std::string &key) { return nullptr; } + + virtual Node *Insert(const size_t index) { + if (m_Sequence.size() == 0) { + Node *pNode = new Node; + m_Sequence.insert({0, pNode}); + return pNode; } - bool ConstIterator::operator != (const ConstIterator & it) - { - return !(*this == it); + if (index >= m_Sequence.size()) { + auto it = m_Sequence.end(); + --it; + Node *pNode = new Node; + m_Sequence.insert({it->first, pNode}); + return pNode; } + auto it = m_Sequence.cbegin(); + while (it != m_Sequence.cend()) { + m_Sequence[it->first + 1] = it->second; - // Node class - Node::Node() : - m_pImp(new NodeImp) - { + if (it->first == index) { + break; + } } - Node::Node(const Node & node) : - Node() - { - *this = node; - } + Node *pNode = new Node; + m_Sequence.insert({index, pNode}); + return pNode; + } - Node::Node(const std::string & value) : - Node() - { - *this = value; + virtual Node *PushFront() { + for (auto it = m_Sequence.cbegin(); it != m_Sequence.cend(); it++) { + m_Sequence[it->first + 1] = it->second; } - Node::Node(const char * value) : - Node() - { - *this = value; - } + Node *pNode = new Node; + m_Sequence.insert({0, pNode}); + return pNode; + } - Node::~Node() - { - delete static_cast(m_pImp); + virtual Node *PushBack() { + size_t index = 0; + if (m_Sequence.size()) { + auto it = m_Sequence.end(); + --it; + index = it->first + 1; } - Node::eType Node::Type() const - { - return NODE_IMP->m_Type; - } + Node *pNode = new Node; + m_Sequence.insert({index, pNode}); + return pNode; + } - bool Node::IsNone() const - { - return NODE_IMP->m_Type == Node::None; + virtual void Erase(const size_t index) { + auto it = m_Sequence.find(index); + if (it == m_Sequence.end()) { + return; } + delete it->second; + m_Sequence.erase(index); + } - bool Node::IsSequence() const - { - return NODE_IMP->m_Type == Node::SequenceType; - } + virtual void Erase(const std::string &key) {} - bool Node::IsMap() const - { - return NODE_IMP->m_Type == Node::MapType; - } + std::map m_Sequence; +}; - bool Node::IsScalar() const - { - return NODE_IMP->m_Type == Node::ScalarType; +class MapImp : public TypeImp { + public: + ~MapImp() { + for (auto it = m_Map.begin(); it != m_Map.end(); it++) { + delete it->second; } + } - void Node::Clear() - { - NODE_IMP->Clear(); - } + virtual const std::string &GetData() const { return g_EmptyString; } - size_t Node::Size() const - { - if(TYPE_IMP == nullptr) - { - return 0; - } + virtual bool SetData(const std::string &data) { return false; } - return TYPE_IMP->GetSize(); - } + virtual size_t GetSize() const { return m_Map.size(); } - Node & Node::Insert(const size_t index) - { - NODE_IMP->InitSequence(); - return *TYPE_IMP->Insert(index); - } + virtual Node *GetNode(const size_t index) { return nullptr; } - Node & Node::PushFront() - { - NODE_IMP->InitSequence(); - return *TYPE_IMP->PushFront(); - } - Node & Node::PushBack() - { - NODE_IMP->InitSequence(); - return *TYPE_IMP->PushBack(); + virtual Node *GetNode(const std::string &key) { + auto it = m_Map.find(key); + if (it == m_Map.end()) { + Node *pNode = new Node; + m_Map.insert({key, pNode}); + return pNode; } + return it->second; + } - Node & Node::operator[](const size_t 
index) - { - NODE_IMP->InitSequence(); - Node * pNode = TYPE_IMP->GetNode(index); - if(pNode == nullptr) - { - g_NoneNode.Clear(); - return g_NoneNode; - } - return *pNode; - } + virtual Node *Insert(const size_t index) { return nullptr; } - Node & Node::operator[](const std::string & key) - { - NODE_IMP->InitMap(); - return *TYPE_IMP->GetNode(key); - } + virtual Node *PushFront() { return nullptr; } - void Node::Erase(const size_t index) - { - if(TYPE_IMP == nullptr || NODE_IMP->m_Type != Node::SequenceType) - { - return; - } + virtual Node *PushBack() { return nullptr; } - return TYPE_IMP->Erase(index); + virtual void Erase(const size_t index) {} + + virtual void Erase(const std::string &key) { + auto it = m_Map.find(key); + if (it == m_Map.end()) { + return; } + delete it->second; + m_Map.erase(key); + } - void Node::Erase(const std::string & key) - { - if(TYPE_IMP == nullptr || NODE_IMP->m_Type != Node::MapType) - { - return; - } + std::map m_Map; +}; - return TYPE_IMP->Erase(key); - } +class ScalarImp : public TypeImp { + public: + ~ScalarImp() {} - Node & Node::operator = (const Node & node) - { - NODE_IMP->Clear(); - CopyNode(node, *this); - return *this; - } + virtual const std::string &GetData() const { return m_Value; } - Node & Node::operator = (const std::string & value) - { - NODE_IMP->InitScalar(); - TYPE_IMP->SetData(value); - return *this; - } + virtual bool SetData(const std::string &data) { + m_Value = data; + return true; + } - Node & Node::operator = (const char * value) - { - NODE_IMP->InitScalar(); - TYPE_IMP->SetData(value ? std::string(value) : ""); - return *this; - } + virtual size_t GetSize() const { return 0; } - Iterator Node::Begin() - { - Iterator it; + virtual Node *GetNode(const size_t index) { return nullptr; } - if(TYPE_IMP != nullptr) - { - IteratorImp * pItImp = nullptr; + virtual Node *GetNode(const std::string &key) { return nullptr; } - switch(NODE_IMP->m_Type) - { - case Node::SequenceType: - it.m_Type = Iterator::SequenceType; - pItImp = new SequenceIteratorImp; - pItImp->InitBegin(static_cast(TYPE_IMP)); - break; - case Node::MapType: - it.m_Type = Iterator::MapType; - pItImp = new MapIteratorImp; - pItImp->InitBegin(static_cast(TYPE_IMP)); - break; - default: - break; - } + virtual Node *Insert(const size_t index) { return nullptr; } - it.m_pImp = pItImp; - } + virtual Node *PushFront() { return nullptr; } - return it; - } - - ConstIterator Node::Begin() const - { - ConstIterator it; - - if(TYPE_IMP != nullptr) - { - IteratorImp * pItImp = nullptr; - - switch(NODE_IMP->m_Type) - { - case Node::SequenceType: - it.m_Type = ConstIterator::SequenceType; - pItImp = new SequenceConstIteratorImp; - pItImp->InitBegin(static_cast(TYPE_IMP)); - break; - case Node::MapType: - it.m_Type = ConstIterator::MapType; - pItImp = new MapConstIteratorImp; - pItImp->InitBegin(static_cast(TYPE_IMP)); - break; - default: - break; - } + virtual Node *PushBack() { return nullptr; } - it.m_pImp = pItImp; - } + virtual void Erase(const size_t index) {} - return it; - } - - Iterator Node::End() - { - Iterator it; - - if(TYPE_IMP != nullptr) - { - IteratorImp * pItImp = nullptr; - - switch(NODE_IMP->m_Type) - { - case Node::SequenceType: - it.m_Type = Iterator::SequenceType; - pItImp = new SequenceIteratorImp; - pItImp->InitEnd(static_cast(TYPE_IMP)); - break; - case Node::MapType: - it.m_Type = Iterator::MapType; - pItImp = new MapIteratorImp; - pItImp->InitEnd(static_cast(TYPE_IMP)); - break; - default: - break; - } + virtual void Erase(const std::string &key) {} - 
it.m_pImp = pItImp; - } + std::string m_Value; +}; - return it; - } - - ConstIterator Node::End() const - { - ConstIterator it; - - if(TYPE_IMP != nullptr) - { - IteratorImp * pItImp = nullptr; - - switch(NODE_IMP->m_Type) - { - case Node::SequenceType: - it.m_Type = ConstIterator::SequenceType; - pItImp = new SequenceConstIteratorImp; - pItImp->InitEnd(static_cast(TYPE_IMP)); - break; - case Node::MapType: - it.m_Type = ConstIterator::MapType; - pItImp = new MapConstIteratorImp; - pItImp->InitEnd(static_cast(TYPE_IMP)); - break; - default: - break; - } +// Node implementations. +class NodeImp { + public: + NodeImp() : m_Type(Node::None), m_pImp(nullptr) {} - it.m_pImp = pItImp; - } + ~NodeImp() { Clear(); } - return it; + void Clear() { + if (m_pImp != nullptr) { + delete m_pImp; + m_pImp = nullptr; } + m_Type = Node::None; + } - const std::string & Node::AsString() const - { - if(TYPE_IMP == nullptr) - { - return g_EmptyString; - } - - return TYPE_IMP->GetData(); - } - - - - // Reader implementations - /** - * @breif Line information structure. - * - */ - class ReaderLine - { - - public: - - /** - * @breif Constructor. - * - */ - ReaderLine(const std::string & data = "", - const size_t no = 0, - const size_t offset = 0, - const Node::eType type = Node::None, - const unsigned char flags = 0) : - Data(data), - No(no), - Offset(offset), - Type(type), - Flags(flags), - NextLine(nullptr) - { - } + void InitSequence() { + if (m_Type != Node::SequenceType || m_pImp == nullptr) { + if (m_pImp) { + delete m_pImp; + } + m_pImp = new SequenceImp; + m_Type = Node::SequenceType; + } + } - enum eFlag - { - LiteralScalarFlag, ///< Literal scalar type, defined as "|". - FoldedScalarFlag, ///< Folded scalar type, defined as "<". - ScalarNewlineFlag ///< Scalar ends with a newline. - }; - - /** - * @breif Set flag. - * - */ - void SetFlag(const eFlag flag) - { - Flags |= FlagMask[static_cast(flag)]; - } + void InitMap() { + if (m_Type != Node::MapType || m_pImp == nullptr) { + if (m_pImp) { + delete m_pImp; + } + m_pImp = new MapImp; + m_Type = Node::MapType; + } + } - /** - * @breif Set flags by mask value. - * - */ - void SetFlags(const unsigned char flags) - { - Flags |= flags; - } + void InitScalar() { + if (m_Type != Node::ScalarType || m_pImp == nullptr) { + if (m_pImp) { + delete m_pImp; + } + m_pImp = new ScalarImp; + m_Type = Node::ScalarType; + } + } - /** - * @breif Unset flag. - * - */ - void UnsetFlag(const eFlag flag) - { - Flags &= ~FlagMask[static_cast(flag)]; - } + Node::eType m_Type; ///< Type of node. + TypeImp *m_pImp; ///< Imp of type. +}; - /** - * @breif Unset flags by mask value. - * - */ - void UnsetFlags(const unsigned char flags) - { - Flags &= ~flags; - } +// Iterator implementation class +class IteratorImp { + public: + virtual ~IteratorImp() {} - /** - * @breif Get flag value. - * - */ - bool GetFlag(const eFlag flag) const - { - return Flags & FlagMask[static_cast(flag)]; - } + virtual Node::eType GetType() const = 0; + virtual void InitBegin(SequenceImp *pSequenceImp) = 0; + virtual void InitEnd(SequenceImp *pSequenceImp) = 0; + virtual void InitBegin(MapImp *pMapImp) = 0; + virtual void InitEnd(MapImp *pMapImp) = 0; +}; - /** - * @breif Copy and replace scalar flags from another ReaderLine. 
- * - */ - void CopyScalarFlags(ReaderLine * from) - { - if (from == nullptr) - { - return; - } +class SequenceIteratorImp : public IteratorImp { + public: + virtual Node::eType GetType() const { return Node::SequenceType; } - unsigned char newFlags = from->Flags & (FlagMask[0] | FlagMask[1] | FlagMask[2]); - Flags |= newFlags; - } + virtual void InitBegin(SequenceImp *pSequenceImp) { + m_Iterator = pSequenceImp->m_Sequence.begin(); + } - static const unsigned char FlagMask[3]; + virtual void InitEnd(SequenceImp *pSequenceImp) { + m_Iterator = pSequenceImp->m_Sequence.end(); + } - std::string Data; ///< Data of line. - size_t No; ///< Line number. - size_t Offset; ///< Offset to first character in data. - Node::eType Type; ///< Type of line. - unsigned char Flags; ///< Flags of line. - ReaderLine * NextLine; ///< Pointer to next line. + virtual void InitBegin(MapImp *pMapImp) {} + virtual void InitEnd(MapImp *pMapImp) {} + void Copy(const SequenceIteratorImp &it) { m_Iterator = it.m_Iterator; } - }; + std::map::iterator m_Iterator; +}; - const unsigned char ReaderLine::FlagMask[3] = { 0x01, 0x02, 0x04 }; +class MapIteratorImp : public IteratorImp { + public: + virtual Node::eType GetType() const { return Node::MapType; } + virtual void InitBegin(SequenceImp *pSequenceImp) {} - /** - * @breif Implementation class of Yaml parsing. - * Parsing incoming stream and outputs a root node. - * - */ - class ParseImp - { + virtual void InitEnd(SequenceImp *pSequenceImp) {} - public: + virtual void InitBegin(MapImp *pMapImp) { + m_Iterator = pMapImp->m_Map.begin(); + } - /** - * @breif Default constructor. - * - */ - ParseImp() - { - } + virtual void InitEnd(MapImp *pMapImp) { m_Iterator = pMapImp->m_Map.end(); } - /** - * @breif Destructor. - * - */ - ~ParseImp() - { - ClearLines(); - } + void Copy(const MapIteratorImp &it) { m_Iterator = it.m_Iterator; } - /** - * @breif Run full parsing procedure. - * - */ - void Parse(Node & root, std::iostream & stream) - { - try - { - root.Clear(); - ReadLines(stream); - PostProcessLines(); - //Print(); - ParseRoot(root); - } - catch(Exception e) - { - root.Clear(); - throw; - } - } + std::map::iterator m_Iterator; +}; - private: +class SequenceConstIteratorImp : public IteratorImp { + public: + virtual Node::eType GetType() const { return Node::SequenceType; } - /** - * @breif Copy constructor. - * - */ - ParseImp(const ParseImp & copy) - { + virtual void InitBegin(SequenceImp *pSequenceImp) { + m_Iterator = pSequenceImp->m_Sequence.begin(); + } - } + virtual void InitEnd(SequenceImp *pSequenceImp) { + m_Iterator = pSequenceImp->m_Sequence.end(); + } - /** - * @breif Read all lines. - * Ignoring: - * - Empty lines. - * - Comments. - * - Document start/end. - * - */ - void ReadLines(std::iostream & stream) - { - std::string line = ""; - size_t lineNo = 0; - bool documentStartFound = false; - bool foundFirstNotEmpty = false; - std::streampos streamPos = 0; - - // Read all lines, as long as the stream is ok. - while (!stream.eof() && !stream.fail()) - { - // Read line - streamPos = stream.tellg(); - std::getline(stream, line); - lineNo++; - - // Remove comment - const size_t commentPos = FindNotCited(line, '#'); - if(commentPos != std::string::npos) - { - line.resize(commentPos); - } - - // Start of document. - if (documentStartFound == false && line == "---") - { - // Erase all lines before this line. - ClearLines(); - documentStartFound = true; - continue; - } - - // End of document. 
- if (line == "...") - { - break; - } - else if(line == "---") - { - stream.seekg(streamPos); - break; - } - - // Remove trailing return. - if (line.size()) - { - if (line[line.size() - 1] == '\r') - { - line.resize(line.size() - 1); - } - } - - // Validate characters. - for (size_t i = 0; i < line.size(); i++) - { - if (line[i] != '\t' && (line[i] < 32 || line[i] > 125)) - { - throw ParsingException(ExceptionMessage(g_ErrorInvalidCharacter, lineNo, i + 1)); - } - } - - // Validate tabs - const size_t firstTabPos = line.find_first_of('\t'); - size_t startOffset = line.find_first_not_of(" \t"); - - // Make sure no tabs are in the very front. - if (startOffset != std::string::npos) - { - if(firstTabPos < startOffset) - { - throw ParsingException(ExceptionMessage(g_ErrorTabInOffset, lineNo, firstTabPos)); - } - - // Remove front spaces. - line = line.substr(startOffset); - } - else - { - startOffset = 0; - line = ""; - } - - // Add line. - if(foundFirstNotEmpty == false) - { - if(line.size()) - { - foundFirstNotEmpty = true; - } - else - { - continue; - } - } - - ReaderLine * pLine = new ReaderLine(line, lineNo, startOffset); - m_Lines.push_back(pLine); - } - } + virtual void InitBegin(MapImp *pMapImp) {} - /** - * @breif Run post-processing on all lines. - * Basically split lines into multiple lines if needed, to follow the parsing algorithm. - * - */ - void PostProcessLines() - { - for (auto it = m_Lines.begin(); it != m_Lines.end();) - { - // Sequence. - if (PostProcessSequenceLine(it) == true) - { - continue; - } - - // Mapping. - if (PostProcessMappingLine(it) == true) - { - continue; - } - - // Scalar. - PostProcessScalarLine(it); - } + virtual void InitEnd(MapImp *pMapImp) {} - // Set next line of all lines. - if (m_Lines.size()) - { - if (m_Lines.back()->Type != Node::ScalarType) - { - throw ParsingException(ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *m_Lines.back())); - } - - if (m_Lines.size() > 1) - { - auto prevEnd = m_Lines.end(); - --prevEnd; - - for (auto it = m_Lines.begin(); it != prevEnd; it++) - { - auto nextIt = it; - ++nextIt; - - (*it)->NextLine = *nextIt; - } - } - } - } + void Copy(const SequenceConstIteratorImp &it) { m_Iterator = it.m_Iterator; } - /** - * @breif Run post-processing and check for sequence. - * Split line into two lines if sequence token is not on it's own line. - * - * @return true if line is sequence, else false. - * - */ - bool PostProcessSequenceLine(std::list::iterator & it) - { - ReaderLine * pLine = *it; - - // Sequence split - if (IsSequenceStart(pLine->Data) == false) - { - return false; - } + std::map::const_iterator m_Iterator; +}; - pLine->Type = Node::SequenceType; +class MapConstIteratorImp : public IteratorImp { + public: + virtual Node::eType GetType() const { return Node::MapType; } - ClearTrailingEmptyLines(++it); + virtual void InitBegin(SequenceImp *pSequenceImp) {} - const size_t valueStart = pLine->Data.find_first_not_of(" \t", 1); - if (valueStart == std::string::npos) - { - return true; - } + virtual void InitEnd(SequenceImp *pSequenceImp) {} - // Create new line and insert - std::string newLine = pLine->Data.substr(valueStart); - it = m_Lines.insert(it, new ReaderLine(newLine, pLine->No, pLine->Offset + valueStart)); - pLine->Data = ""; + virtual void InitBegin(MapImp *pMapImp) { + m_Iterator = pMapImp->m_Map.begin(); + } - return false; - } + virtual void InitEnd(MapImp *pMapImp) { m_Iterator = pMapImp->m_Map.end(); } - /** - * @breif Run post-processing and check for mapping. 
- * Split line into two lines if mapping value is not on it's own line. - * - * @return true if line is mapping, else move on to scalar parsing. - * - */ - bool PostProcessMappingLine(std::list::iterator & it) - { - ReaderLine * pLine = *it; - - // Find map key. - size_t preKeyQuotes = 0; - size_t tokenPos = FindNotCited(pLine->Data, ':', preKeyQuotes); - if (tokenPos == std::string::npos) - { - return false; - } - if(preKeyQuotes > 1) - { - throw ParsingException(ExceptionMessage(g_ErrorKeyIncorrect, *pLine)); - } + void Copy(const MapConstIteratorImp &it) { m_Iterator = it.m_Iterator; } - pLine->Type = Node::MapType; + std::map::const_iterator m_Iterator; +}; - // Get key - std::string key = pLine->Data.substr(0, tokenPos); - const size_t keyEnd = key.find_last_not_of(" \t"); - if (keyEnd == std::string::npos) - { - throw ParsingException(ExceptionMessage(g_ErrorKeyMissing, *pLine)); - } - key.resize(keyEnd + 1); +// Iterator class +Iterator::Iterator() : m_Type(None), m_pImp(nullptr) {} - // Handle cited key. - if(preKeyQuotes == 1) - { - if(key.front() != '"' || key.back() != '"') - { - throw ParsingException(ExceptionMessage(g_ErrorKeyIncorrect, *pLine)); - } +Iterator::~Iterator() { + if (m_pImp) { + switch (m_Type) { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + } +} - key = key.substr(1, key.size() - 2); - } - RemoveAllEscapeTokens(key); - - // Get value - std::string value = ""; - size_t valueStart = std::string::npos; - if (tokenPos + 1 != pLine->Data.size()) - { - valueStart = pLine->Data.find_first_not_of(" \t", tokenPos + 1); - if (valueStart != std::string::npos) - { - value = pLine->Data.substr(valueStart); - } - } +Iterator::Iterator(const Iterator &it) : m_Type(None), m_pImp(nullptr) { + *this = it; +} - // Make sure the value is not a sequence start. - if (IsSequenceStart(value) == true) - { - throw ParsingException(ExceptionMessage(g_ErrorBlockSequenceNotAllowed, *pLine, valueStart)); - } +Iterator &Iterator::operator=(const Iterator &it) { + if (m_pImp) { + switch (m_Type) { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + m_pImp = nullptr; + m_Type = None; + } + + IteratorImp *pNewImp = nullptr; + + switch (it.m_Type) { + case SequenceType: + m_Type = SequenceType; + pNewImp = new SequenceIteratorImp; + static_cast(pNewImp)->m_Iterator = + static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + m_Type = MapType; + pNewImp = new MapIteratorImp; + static_cast(pNewImp)->m_Iterator = + static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + m_pImp = pNewImp; + return *this; +} - pLine->Data = key; +std::pair Iterator::operator*() { + switch (m_Type) { + case SequenceType: + return { + g_EmptyString, + *(static_cast(m_pImp)->m_Iterator->second)}; + break; + case MapType: + return {static_cast(m_pImp)->m_Iterator->first, + *(static_cast(m_pImp)->m_Iterator->second)}; + break; + default: + break; + } + + g_NoneNode.Clear(); + return {g_EmptyString, g_NoneNode}; +} +Iterator &Iterator::operator++(int dummy) { + switch (m_Type) { + case SequenceType: + static_cast(m_pImp)->m_Iterator++; + break; + case MapType: + static_cast(m_pImp)->m_Iterator++; + break; + default: + break; + } + return *this; +} - // Remove all empty lines after map key. 
- ClearTrailingEmptyLines(++it); +Iterator &Iterator::operator--(int dummy) { + switch (m_Type) { + case SequenceType: + static_cast(m_pImp)->m_Iterator--; + break; + case MapType: + static_cast(m_pImp)->m_Iterator--; + break; + default: + break; + } + return *this; +} - // Add new empty line? - size_t newLineOffset = valueStart; - if(newLineOffset == std::string::npos) - { - if(it != m_Lines.end() && (*it)->Offset > pLine->Offset) - { - return true; - } +bool Iterator::operator==(const Iterator &it) { + if (m_Type != it.m_Type) { + return false; + } + + switch (m_Type) { + case SequenceType: + return static_cast(m_pImp)->m_Iterator == + static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + return static_cast(m_pImp)->m_Iterator == + static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + return false; +} - newLineOffset = tokenPos + 2; - } - else - { - newLineOffset += pLine->Offset; - } +bool Iterator::operator!=(const Iterator &it) { return !(*this == it); } - // Add new line with value. - unsigned char dummyBlockFlags = 0; - if(IsBlockScalar(value, pLine->No, dummyBlockFlags) == true) - { - newLineOffset = pLine->Offset; - } - ReaderLine * pNewLine = new ReaderLine(value, pLine->No, newLineOffset, Node::ScalarType); - it = m_Lines.insert(it, pNewLine); +// Const Iterator class +ConstIterator::ConstIterator() : m_Type(None), m_pImp(nullptr) {} - // Return false in order to handle next line(scalar value). - return false; - } +ConstIterator::~ConstIterator() { + if (m_pImp) { + switch (m_Type) { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + } +} - /** - * @breif Run post-processing and check for scalar. - * Checking for multi-line scalars. - * - * @return true if scalar search should continue, else false. 
- * - */ - void PostProcessScalarLine(std::list::iterator & it) - { - ReaderLine * pLine = *it; - pLine->Type = Node::ScalarType; - - size_t parentOffset = pLine->Offset; - if(pLine != m_Lines.front()) - { - std::list::iterator lastIt = it; - --lastIt; - parentOffset = (*lastIt)->Offset; - } +ConstIterator::ConstIterator(const ConstIterator &it) + : m_Type(None), m_pImp(nullptr) { + *this = it; +} - std::list::iterator lastNotEmpty = it++; - - // Find last empty lines - while(it != m_Lines.end()) - { - pLine = *it; - pLine->Type = Node::ScalarType; - if(pLine->Data.size()) - { - if(pLine->Offset <= parentOffset) - { - break; - } - else - { - lastNotEmpty = it; - } - } - ++it; - } +ConstIterator &ConstIterator::operator=(const ConstIterator &it) { + if (m_pImp) { + switch (m_Type) { + case SequenceType: + delete static_cast(m_pImp); + break; + case MapType: + delete static_cast(m_pImp); + break; + default: + break; + } + m_pImp = nullptr; + m_Type = None; + } + + IteratorImp *pNewImp = nullptr; + + switch (it.m_Type) { + case SequenceType: + m_Type = SequenceType; + pNewImp = new SequenceConstIteratorImp; + static_cast(pNewImp)->m_Iterator = + static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + m_Type = MapType; + pNewImp = new MapConstIteratorImp; + static_cast(pNewImp)->m_Iterator = + static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + m_pImp = pNewImp; + return *this; +} - ClearTrailingEmptyLines(++lastNotEmpty); - } +std::pair ConstIterator::operator*() { + switch (m_Type) { + case SequenceType: + return {g_EmptyString, *(static_cast(m_pImp) + ->m_Iterator->second)}; + break; + case MapType: + return { + static_cast(m_pImp)->m_Iterator->first, + *(static_cast(m_pImp)->m_Iterator->second)}; + break; + default: + break; + } + + g_NoneNode.Clear(); + return {g_EmptyString, g_NoneNode}; +} - /** - * @breif Process root node and start of document. - * - */ - void ParseRoot(Node & root) - { - // Get first line and start type. - auto it = m_Lines.begin(); - if(it == m_Lines.end()) - { - return; - } - Node::eType type = (*it)->Type; - ReaderLine * pLine = *it; - - // Handle next line. - switch(type) - { - case Node::SequenceType: - ParseSequence(root, it); - break; - case Node::MapType: - ParseMap(root, it); - break; - case Node::ScalarType: - ParseScalar(root, it); - break; - default: - break; - } +ConstIterator &ConstIterator::operator++(int dummy) { + switch (m_Type) { + case SequenceType: + static_cast(m_pImp)->m_Iterator++; + break; + case MapType: + static_cast(m_pImp)->m_Iterator++; + break; + default: + break; + } + return *this; +} - if(it != m_Lines.end()) - { - throw InternalException(ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); - } +ConstIterator &ConstIterator::operator--(int dummy) { + switch (m_Type) { + case SequenceType: + static_cast(m_pImp)->m_Iterator--; + break; + case MapType: + static_cast(m_pImp)->m_Iterator--; + break; + default: + break; + } + return *this; +} - } +bool ConstIterator::operator==(const ConstIterator &it) { + if (m_Type != it.m_Type) { + return false; + } + + switch (m_Type) { + case SequenceType: + return static_cast(m_pImp)->m_Iterator == + static_cast(it.m_pImp)->m_Iterator; + break; + case MapType: + return static_cast(m_pImp)->m_Iterator == + static_cast(it.m_pImp)->m_Iterator; + break; + default: + break; + } + + return false; +} - /** - * @breif Process sequence node. 
- * - */ - void ParseSequence(Node & node, std::list::iterator & it) - { - ReaderLine * pNextLine = nullptr; - while(it != m_Lines.end()) - { - ReaderLine * pLine = *it; - Node & childNode = node.PushBack(); - - // Move to next line, error check. - ++it; - if(it == m_Lines.end()) - { - throw InternalException(ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); - } - - // Handle value of map - Node::eType valueType = (*it)->Type; - switch(valueType) - { - case Node::SequenceType: - ParseSequence(childNode, it); - break; - case Node::MapType: - ParseMap(childNode, it); - break; - case Node::ScalarType: - ParseScalar(childNode, it); - break; - default: - break; - } - - // Check next line. if sequence and correct level, go on, else exit. - // If same level but but of type map = error. - if(it == m_Lines.end() || ((pNextLine = *it)->Offset < pLine->Offset)) - { - break; - } - if(pNextLine->Offset > pLine->Offset) - { - throw ParsingException(ExceptionMessage(g_ErrorIncorrectOffset, *pNextLine)); - } - if(pNextLine->Type != Node::SequenceType) - { - throw InternalException(ExceptionMessage(g_ErrorDiffEntryNotAllowed, *pNextLine)); - } +bool ConstIterator::operator!=(const ConstIterator &it) { + return !(*this == it); +} - } - } +// Node class +Node::Node() : m_pImp(new NodeImp) {} - /** - * @breif Process map node. - * - */ - void ParseMap(Node & node, std::list::iterator & it) - { - ReaderLine * pNextLine = nullptr; - while(it != m_Lines.end()) - { - ReaderLine * pLine = *it; - Node & childNode = node[pLine->Data]; - - // Move to next line, error check. - ++it; - if(it == m_Lines.end()) - { - throw InternalException(ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); - } - - // Handle value of map - Node::eType valueType = (*it)->Type; - switch(valueType) - { - case Node::SequenceType: - ParseSequence(childNode, it); - break; - case Node::MapType: - ParseMap(childNode, it); - break; - case Node::ScalarType: - ParseScalar(childNode, it); - break; - default: - break; - } - - // Check next line. if map and correct level, go on, else exit. - // if same level but but of type map = error. - if(it == m_Lines.end() || ((pNextLine = *it)->Offset < pLine->Offset)) - { - break; - } - if(pNextLine->Offset > pLine->Offset) - { - throw ParsingException(ExceptionMessage(g_ErrorIncorrectOffset, *pNextLine)); - } - if(pNextLine->Type != pLine->Type) - { - throw InternalException(ExceptionMessage(g_ErrorDiffEntryNotAllowed, *pNextLine)); - } +Node::Node(const Node &node) : Node() { *this = node; } - } - } +Node::Node(const std::string &value) : Node() { *this = value; } - /** - * @breif Process scalar node. - * - */ - void ParseScalar(Node & node, std::list::iterator & it) - { - std::string data = ""; - ReaderLine * pFirstLine = *it; - ReaderLine * pLine = *it; - - // Check if current line is a block scalar. 
- unsigned char blockFlags = 0; - bool isBlockScalar = IsBlockScalar(pLine->Data, pLine->No, blockFlags); - const bool newLineFlag = static_cast(blockFlags & ReaderLine::FlagMask[static_cast(ReaderLine::ScalarNewlineFlag)]); - const bool foldedFlag = static_cast(blockFlags & ReaderLine::FlagMask[static_cast(ReaderLine::FoldedScalarFlag)]); - const bool literalFlag = static_cast(blockFlags & ReaderLine::FlagMask[static_cast(ReaderLine::LiteralScalarFlag)]); - size_t parentOffset = 0; - - // Find parent offset - if(it != m_Lines.begin()) - { - std::list::iterator parentIt = it; - --parentIt; - parentOffset = (*parentIt)->Offset; - } +Node::Node(const char *value) : Node() { *this = value; } - // Move to next iterator/line if current line is a block scalar. - if(isBlockScalar) - { - ++it; - if(it == m_Lines.end() || (pLine = *it)->Type != Node::ScalarType) - { - return; - } - } +Node::~Node() { delete static_cast(m_pImp); } - // Not a block scalar, cut end spaces/tabs - if(isBlockScalar == false) - { - while(1) - { - pLine = *it; - - if(parentOffset != 0 && pLine->Offset <= parentOffset) - { - throw ParsingException(ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); - } - - const size_t endOffset = pLine->Data.find_last_not_of(" \t"); - if(endOffset == std::string::npos) - { - data += "\n"; - } - else - { - data += pLine->Data.substr(0, endOffset + 1); - } - - // Move to next line - ++it; - if(it == m_Lines.end() || (*it)->Type != Node::ScalarType) - { - break; - } - - data += " "; - } - - if(ValidateQuote(data) == false) - { - throw ParsingException(ExceptionMessage(g_ErrorInvalidQuote, *pFirstLine)); - } - } - // Block scalar - else - { - pLine = *it; - size_t blockOffset = pLine->Offset; - if(blockOffset <= parentOffset) - { - throw ParsingException(ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); - } - - bool addedSpace = false; - while(it != m_Lines.end() && (*it)->Type == Node::ScalarType) - { - pLine = *it; - - const size_t endOffset = pLine->Data.find_last_not_of(" \t"); - if(endOffset != std::string::npos && pLine->Offset < blockOffset) - { - throw ParsingException(ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); - } - - if(endOffset == std::string::npos) - { - if(addedSpace) - { - data[data.size() - 1] = '\n'; - addedSpace = false; - } - else - { - data += "\n"; - } - - ++it; - continue; - } - else - { - if(blockOffset != pLine->Offset && foldedFlag) - { - if(addedSpace) - { - data[data.size() - 1] = '\n'; - addedSpace = false; - } - else - { - data += "\n"; - } - } - data += std::string(pLine->Offset - blockOffset, ' '); - data += pLine->Data; - } - - // Move to next line - ++it; - if(it == m_Lines.end() || (*it)->Type != Node::ScalarType) - { - if(newLineFlag) - { - data += "\n"; - } - break; - } - - if(foldedFlag) - { - data += " "; - addedSpace = true; - } - else if(literalFlag && endOffset != std::string::npos) - { - data += "\n"; - } - } - } +Node::eType Node::Type() const { return NODE_IMP->m_Type; } - if(data.size() && (data[0] == '"' || data[0] == '\'')) - { - data = data.substr(1, data.size() - 2 ); - } +bool Node::IsNone() const { return NODE_IMP->m_Type == Node::None; } - node = data; - } +bool Node::IsSequence() const { return NODE_IMP->m_Type == Node::SequenceType; } - /** - * @breif Debug printing. 
- * - */ - void Print() - { - for (auto it = m_Lines.begin(); it != m_Lines.end(); it++) - { - - ReaderLine * pLine = *it; - - // Print type - if (pLine->Type == Node::SequenceType) - { - std::cout << "seq "; - } - else if (pLine->Type == Node::MapType) - { - std::cout << "map "; - } - else if (pLine->Type == Node::ScalarType) - { - std::cout << "sca "; - } - else - { - std::cout << " "; - } - - // Print flags - if (pLine->GetFlag(ReaderLine::FoldedScalarFlag)) - { - std::cout << "f"; - } - else - { - std::cout << "-"; - } - if (pLine->GetFlag(ReaderLine::LiteralScalarFlag)) - { - std::cout << "l"; - } - else - { - std::cout << "-"; - } - if (pLine->GetFlag(ReaderLine::ScalarNewlineFlag)) - { - std::cout << "n"; - } - else - { - std::cout << "-"; - } - if (pLine->NextLine == nullptr) - { - std::cout << "e"; - } - else - { - std::cout << "-"; - } - - - std::cout << "| "; - std::cout << pLine->No << " "; - std::cout << std::string(pLine->Offset, ' '); - - if (pLine->Type == Node::ScalarType) - { - std::string scalarValue = pLine->Data; - for (size_t i = 0; (i = scalarValue.find("\n", i)) != std::string::npos;) - { - scalarValue.replace(i, 1, "\\n"); - i += 2; - } - std::cout << scalarValue << std::endl; - } - else if (pLine->Type == Node::MapType) - { - std::cout << pLine->Data + ":" << std::endl; - } - else if (pLine->Type == Node::SequenceType) - { - std::cout << "-" << std::endl; - } - else - { - std::cout << "> UNKOWN TYPE <" << std::endl; - } - } - } +bool Node::IsMap() const { return NODE_IMP->m_Type == Node::MapType; } - /** - * @breif Clear all read lines. - * - */ - void ClearLines() - { - for (auto it = m_Lines.begin(); it != m_Lines.end(); it++) - { - delete *it; - } - m_Lines.clear(); - } +bool Node::IsScalar() const { return NODE_IMP->m_Type == Node::ScalarType; } - void ClearTrailingEmptyLines(std::list::iterator & it) - { - while(it != m_Lines.end()) - { - ReaderLine * pLine = *it; - if(pLine->Data.size() == 0) - { - delete *it; - it = m_Lines.erase(it); - } - else - { - return; - } +void Node::Clear() { NODE_IMP->Clear(); } - } - } +size_t Node::Size() const { + if (TYPE_IMP == nullptr) { + return 0; + } - static bool IsSequenceStart(const std::string & data) - { - if (data.size() == 0 || data[0] != '-') - { - return false; - } + return TYPE_IMP->GetSize(); +} - if (data.size() >= 2 && data[1] != ' ') - { - return false; - } +Node &Node::Insert(const size_t index) { + NODE_IMP->InitSequence(); + return *TYPE_IMP->Insert(index); +} - return true; - } +Node &Node::PushFront() { + NODE_IMP->InitSequence(); + return *TYPE_IMP->PushFront(); +} +Node &Node::PushBack() { + NODE_IMP->InitSequence(); + return *TYPE_IMP->PushBack(); +} - static bool IsBlockScalar(const std::string & data, const size_t line, unsigned char & flags) - { - flags = 0; - if(data.size() == 0) - { - return false; - } +Node &Node::operator[](const size_t index) { + NODE_IMP->InitSequence(); + Node *pNode = TYPE_IMP->GetNode(index); + if (pNode == nullptr) { + g_NoneNode.Clear(); + return g_NoneNode; + } + return *pNode; +} - if(data[0] == '|') - { - if(data.size() >= 2) - { - if(data[1] != '-' && data[1] != ' ' && data[1] != '\t') - { - throw ParsingException(ExceptionMessage(g_ErrorInvalidBlockScalar, line, data)); - } - } - else - { - flags |= ReaderLine::FlagMask[static_cast(ReaderLine::ScalarNewlineFlag)]; - } - flags |= ReaderLine::FlagMask[static_cast(ReaderLine::LiteralScalarFlag)]; - return true; - } +Node &Node::operator[](const std::string &key) { + NODE_IMP->InitMap(); + return 
*TYPE_IMP->GetNode(key); +} - if(data[0] == '>') - { - if(data.size() >= 2) - { - if(data[1] != '-' && data[1] != ' ' && data[1] != '\t') - { - throw ParsingException(ExceptionMessage(g_ErrorInvalidBlockScalar, line, data)); - } - } - else - { - flags |= ReaderLine::FlagMask[static_cast(ReaderLine::ScalarNewlineFlag)]; - } - flags |= ReaderLine::FlagMask[static_cast(ReaderLine::FoldedScalarFlag)]; - return true; - } +void Node::Erase(const size_t index) { + if (TYPE_IMP == nullptr || NODE_IMP->m_Type != Node::SequenceType) { + return; + } - return false; - } + return TYPE_IMP->Erase(index); +} - std::list m_Lines; ///< List of lines. +void Node::Erase(const std::string &key) { + if (TYPE_IMP == nullptr || NODE_IMP->m_Type != Node::MapType) { + return; + } - }; + return TYPE_IMP->Erase(key); +} - // Parsing functions - void Parse(Node & root, const char * filename) - { - std::ifstream f(filename, std::ifstream::binary); - if (f.is_open() == false) - { - throw OperationException(g_ErrorCannotOpenFile); - } +Node &Node::operator=(const Node &node) { + NODE_IMP->Clear(); + CopyNode(node, *this); + return *this; +} - f.seekg(0, f.end); - size_t fileSize = static_cast(f.tellg()); - f.seekg(0, f.beg); +Node &Node::operator=(const std::string &value) { + NODE_IMP->InitScalar(); + TYPE_IMP->SetData(value); + return *this; +} - std::unique_ptr data(new char[fileSize]); - f.read(data.get(), fileSize); - f.close(); +Node &Node::operator=(const char *value) { + NODE_IMP->InitScalar(); + TYPE_IMP->SetData(value ? std::string(value) : ""); + return *this; +} - Parse(root, data.get(), fileSize); - } +Iterator Node::Begin() { + Iterator it; - void Parse(Node & root, std::iostream & stream) - { - ParseImp * pImp = nullptr; + if (TYPE_IMP != nullptr) { + IteratorImp *pItImp = nullptr; - try - { - pImp = new ParseImp; - pImp->Parse(root, stream); - delete pImp; - } - catch (const Exception e) - { - delete pImp; - throw; - } + switch (NODE_IMP->m_Type) { + case Node::SequenceType: + it.m_Type = Iterator::SequenceType; + pItImp = new SequenceIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = Iterator::MapType; + pItImp = new MapIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + default: + break; } - void Parse(Node & root, const std::string & string) - { - std::stringstream ss(string); - Parse(root, ss); - } + it.m_pImp = pItImp; + } - void Parse(Node & root, const char * buffer, const size_t size) - { - std::stringstream ss(std::string(buffer, size)); - Parse(root, ss); - } + return it; +} +ConstIterator Node::Begin() const { + ConstIterator it; - // Serialize configuration structure. 
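// A minimal usage sketch (not part of the patch) of the Node API implemented
// above, assuming only the public interface declared in Yaml.hpp; the key
// names used here are hypothetical:
//
//   Yaml::Node root;
//   root["model"] = "conformer";        // InitMap(), then scalar assignment
//   root["beam"] = "10";                // scalars are stored as strings
//   int beam = root["beam"].As<int>();  // impl::StringConverter parses "10"
//   root["chunks"].PushBack() = "16";   // InitSequence(), append a scalar
//   for (auto it = root.Begin(); it != root.End(); it++) {
//     // (*it).first is the map key; it is empty for sequence nodes.
//   }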
- SerializeConfig::SerializeConfig(const size_t spaceIndentation, - const size_t scalarMaxLength, - const bool sequenceMapNewline, - const bool mapScalarNewline) : - SpaceIndentation(spaceIndentation), - ScalarMaxLength(scalarMaxLength), - SequenceMapNewline(sequenceMapNewline), - MapScalarNewline(mapScalarNewline) - { + if (TYPE_IMP != nullptr) { + IteratorImp *pItImp = nullptr; + + switch (NODE_IMP->m_Type) { + case Node::SequenceType: + it.m_Type = ConstIterator::SequenceType; + pItImp = new SequenceConstIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = ConstIterator::MapType; + pItImp = new MapConstIteratorImp; + pItImp->InitBegin(static_cast(TYPE_IMP)); + break; + default: + break; } + it.m_pImp = pItImp; + } + + return it; +} - // Serialization functions - void Serialize(const Node & root, const char * filename, const SerializeConfig & config) - { - std::stringstream stream; - Serialize(root, stream, config); +Iterator Node::End() { + Iterator it; - std::ofstream f(filename); - if (f.is_open() == false) - { - throw OperationException(g_ErrorCannotOpenFile); - } + if (TYPE_IMP != nullptr) { + IteratorImp *pItImp = nullptr; - f.write(stream.str().c_str(), stream.str().size()); - f.close(); + switch (NODE_IMP->m_Type) { + case Node::SequenceType: + it.m_Type = Iterator::SequenceType; + pItImp = new SequenceIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = Iterator::MapType; + pItImp = new MapIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + default: + break; } - size_t LineFolding(const std::string & input, std::vector & folded, const size_t maxLength) - { - folded.clear(); - if(input.size() == 0) - { - return 0; - } + it.m_pImp = pItImp; + } - size_t currentPos = 0; - size_t lastPos = 0; - size_t spacePos = std::string::npos; - while(currentPos < input.size()) - { - currentPos = lastPos + maxLength; + return it; +} - if(currentPos < input.size()) - { - spacePos = input.find_first_of(' ', currentPos); - } +ConstIterator Node::End() const { + ConstIterator it; - if(spacePos == std::string::npos || currentPos >= input.size()) - { - const std::string endLine = input.substr(lastPos); - if(endLine.size()) - { - folded.push_back(endLine); - } + if (TYPE_IMP != nullptr) { + IteratorImp *pItImp = nullptr; - return folded.size(); - } + switch (NODE_IMP->m_Type) { + case Node::SequenceType: + it.m_Type = ConstIterator::SequenceType; + pItImp = new SequenceConstIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + case Node::MapType: + it.m_Type = ConstIterator::MapType; + pItImp = new MapConstIteratorImp; + pItImp->InitEnd(static_cast(TYPE_IMP)); + break; + default: + break; + } - folded.push_back(input.substr(lastPos, spacePos - lastPos)); + it.m_pImp = pItImp; + } - lastPos = spacePos + 1; - } + return it; +} - return folded.size(); - } - - static void SerializeLoop(const Node & node, std::iostream & stream, bool useLevel, const size_t level, const SerializeConfig & config) - { - const size_t indention = config.SpaceIndentation; - - switch(node.Type()) - { - case Node::SequenceType: - { - for(auto it = node.Begin(); it != node.End(); it++) - { - const Node & value = (*it).second; - if(value.IsNone()) - { - continue; - } - stream << std::string(level, ' ') << "- "; - useLevel = false; - if(value.IsSequence() || (value.IsMap() && config.SequenceMapNewline == true)) - { - useLevel = true; - stream << "\n"; - } - - SerializeLoop(value, stream, useLevel, level + 2, 
config); - } +const std::string &Node::AsString() const { + if (TYPE_IMP == nullptr) { + return g_EmptyString; + } - } - break; - case Node::MapType: - { - size_t count = 0; - for(auto it = node.Begin(); it != node.End(); it++) - { - const Node & value = (*it).second; - if(value.IsNone()) - { - continue; - } - - if(useLevel || count > 0) - { - stream << std::string(level, ' '); - } - - std::string key = (*it).first; - AddEscapeTokens(key, "\\\""); - if(ShouldBeCited(key)) - { - stream << "\"" << key << "\"" << ": "; - } - else - { - stream << key << ": "; - } - - - useLevel = false; - if(value.IsScalar() == false || (value.IsScalar() && config.MapScalarNewline)) - { - useLevel = true; - stream << "\n"; - } - - SerializeLoop(value, stream, useLevel, level + indention, config); - - useLevel = true; - count++; - } + return TYPE_IMP->GetData(); +} - } - break; - case Node::ScalarType: - { - const std::string value = node.As<std::string>(); - - // Empty scalar - if(value.size() == 0) - { - stream << "\n"; - break; - } - - // Get lines of scalar. - std::string line = ""; - std::vector<std::string> lines; - std::istringstream iss(value); - while (iss.eof() == false) - { - std::getline(iss, line); - lines.push_back(line); - } - - // Block scalar - const std::string & lastLine = lines.back(); - const bool endNewline = lastLine.size() == 0; - if(endNewline) - { - lines.pop_back(); - } - - // Literal - if(lines.size() > 1) - { - stream << "|"; - } - // Folded/plain - else - { - const std::string frontLine = lines.front(); - if(config.ScalarMaxLength == 0 || lines.front().size() <= config.ScalarMaxLength || - LineFolding(frontLine, lines, config.ScalarMaxLength) == 1) - { - if(useLevel) - { - stream << std::string(level, ' '); - } - - if(ShouldBeCited(value)) - { - stream << "\"" << value << "\"\n"; - break; - } - stream << value << "\n"; - break; - } - else - { - stream << ">"; - } - } - - if(endNewline == false) - { - stream << "-"; - } - stream << "\n"; - - - for(auto it = lines.begin(); it != lines.end(); it++) - { - stream << std::string(level, ' ') << (*it) << "\n"; - } - } - break; +// Reader implementations +/** + * @brief Line information structure. + * + */ +class ReaderLine { + public: + /** + * @brief Constructor. + * + */ + ReaderLine(const std::string &data = "", const size_t no = 0, + const size_t offset = 0, const Node::eType type = Node::None, + const unsigned char flags = 0) + : Data(data), + No(no), + Offset(offset), + Type(type), + Flags(flags), + NextLine(nullptr) {} + + enum eFlag { + LiteralScalarFlag, ///< Literal scalar type, defined as "|". + FoldedScalarFlag, ///< Folded scalar type, defined as ">". + ScalarNewlineFlag ///< Scalar ends with a newline. + }; + + /** + * @brief Set flag. + * + */ + void SetFlag(const eFlag flag) { + Flags |= FlagMask[static_cast<size_t>(flag)]; + } + + /** + * @brief Set flags by mask value. + * + */ + void SetFlags(const unsigned char flags) { Flags |= flags; } + + /** + * @brief Unset flag. + * + */ + void UnsetFlag(const eFlag flag) { + Flags &= ~FlagMask[static_cast<size_t>(flag)]; + } + + /** + * @brief Unset flags by mask value. + * + */ + void UnsetFlags(const unsigned char flags) { Flags &= ~flags; } + + /** + * @brief Get flag value. + * + */ + bool GetFlag(const eFlag flag) const { + return Flags & FlagMask[static_cast<size_t>(flag)]; + } + + /** + * @brief Copy and replace scalar flags from another ReaderLine.
+ * + */ + void CopyScalarFlags(ReaderLine *from) { + if (from == nullptr) { + return; + } + + unsigned char newFlags = + from->Flags & (FlagMask[0] | FlagMask[1] | FlagMask[2]); + Flags |= newFlags; + } + + static const unsigned char FlagMask[3]; + + std::string Data; ///< Data of line. + size_t No; ///< Line number. + size_t Offset; ///< Offset to first character in data. + Node::eType Type; ///< Type of line. + unsigned char Flags; ///< Flags of line. + ReaderLine *NextLine; ///< Pointer to next line. +}; + +const unsigned char ReaderLine::FlagMask[3] = {0x01, 0x02, 0x04}; + +/** + * @brief Implementation class of Yaml parsing. + * Parses an incoming stream and outputs a root node. + * + */ +class ParseImp { + public: + /** + * @brief Default constructor. + * + */ + ParseImp() {} + + /** + * @brief Destructor. + * + */ + ~ParseImp() { ClearLines(); } + + /** + * @brief Run full parsing procedure. + * + */ + void Parse(Node &root, std::iostream &stream) { // NOLINT + try { + root.Clear(); + ReadLines(stream); + PostProcessLines(); + // Print(); + ParseRoot(root); + } catch (Exception e) { + root.Clear(); + throw; + } + } + + private: + /** + * @brief Copy constructor. + * + */ + ParseImp(const ParseImp &copy) {} + + /** + * @brief Read all lines. + * Ignoring: + * - Empty lines. + * - Comments. + * - Document start/end. + * + */ + void ReadLines(std::iostream &stream) { + std::string line = ""; + size_t lineNo = 0; + bool documentStartFound = false; + bool foundFirstNotEmpty = false; + std::streampos streamPos = 0; + + // Read all lines, as long as the stream is ok. + while (!stream.eof() && !stream.fail()) { + // Read line + streamPos = stream.tellg(); + std::getline(stream, line); + lineNo++; + + // Remove comment + const size_t commentPos = FindNotCited(line, '#'); + if (commentPos != std::string::npos) { + line.resize(commentPos); + } + + // Start of document. + if (documentStartFound == false && line == "---") { + // Erase all lines before this line. + ClearLines(); + documentStartFound = true; + continue; + } + + // End of document. + if (line == "...") { + break; + } else if (line == "---") { + stream.seekg(streamPos); + break; + } + + // Remove trailing return. + if (line.size()) { + if (line[line.size() - 1] == '\r') { + line.resize(line.size() - 1); + } + } + + // Validate characters. + for (size_t i = 0; i < line.size(); i++) { + if (line[i] != '\t' && (line[i] < 32 || line[i] > 125)) { + throw ParsingException( + ExceptionMessage(g_ErrorInvalidCharacter, lineNo, i + 1)); + } + } + + // Validate tabs + const size_t firstTabPos = line.find_first_of('\t'); + size_t startOffset = line.find_first_not_of(" \t"); + + // Make sure no tabs are in the very front. + if (startOffset != std::string::npos) { + if (firstTabPos < startOffset) { + throw ParsingException( + ExceptionMessage(g_ErrorTabInOffset, lineNo, firstTabPos)); + } + + // Remove front spaces. + line = line.substr(startOffset); + } else { + startOffset = 0; + line = ""; + } + + // Add line. + if (foundFirstNotEmpty == false) { + if (line.size()) { + foundFirstNotEmpty = true; + } else { + continue; + } + } + + ReaderLine *pLine = new ReaderLine(line, lineNo, startOffset); + m_Lines.push_back(pLine); + } + } + + /** + * @brief Run post-processing on all lines. + * Basically split lines into multiple lines if needed, to follow the + * parsing algorithm. + * + */ + void PostProcessLines() { + for (auto it = m_Lines.begin(); it != m_Lines.end();) {
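// A worked example (not part of the patch) of what this loop does: a raw
// line such as "- key: value" is classified in the order sequence ->
// mapping -> scalar. PostProcessSequenceLine() keeps the "-" and re-inserts
// "key: value" as a fresh line, which the next iteration then splits again
// into a map line holding "key" plus a scalar line holding "value".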
+ // Sequence. + if (PostProcessSequenceLine(it) == true) { + continue; + } + + // Mapping. + if (PostProcessMappingLine(it) == true) { + continue; + } + + // Scalar. + PostProcessScalarLine(it); + } + + // Set next line of all lines. + if (m_Lines.size()) { + if (m_Lines.back()->Type != Node::ScalarType) { + throw ParsingException( + ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *m_Lines.back())); + } + + if (m_Lines.size() > 1) { + auto prevEnd = m_Lines.end(); + --prevEnd; + + for (auto it = m_Lines.begin(); it != prevEnd; it++) { + auto nextIt = it; + ++nextIt; + + (*it)->NextLine = *nextIt; + } + } + } + } + + /** + * @brief Run post-processing and check for sequence. + * Split line into two lines if sequence token is not on its own line. + * + * @return true if line is sequence, else false. + * + */ + bool PostProcessSequenceLine( + std::list<ReaderLine *>::iterator &it) { // NOLINT + ReaderLine *pLine = *it; - // Sequence split + if (IsSequenceStart(pLine->Data) == false) { + return false; } + pLine->Type = Node::SequenceType; + ClearTrailingEmptyLines(++it); + + const size_t valueStart = pLine->Data.find_first_not_of(" \t", 1); + if (valueStart == std::string::npos) { + return true; } + // Create new line and insert + std::string newLine = pLine->Data.substr(valueStart); + it = m_Lines.insert( + it, new ReaderLine(newLine, pLine->No, pLine->Offset + valueStart)); + pLine->Data = ""; + + return false; + } + + /** + * @brief Run post-processing and check for mapping. + * Split line into two lines if mapping value is not on its own line. + * + * @return true if line is mapping, else move on to scalar parsing. + * + */ + bool PostProcessMappingLine( + std::list<ReaderLine *>::iterator &it) { // NOLINT + ReaderLine *pLine = *it; + + // Find map key. + size_t preKeyQuotes = 0; + size_t tokenPos = FindNotCited(pLine->Data, ':', preKeyQuotes); + if (tokenPos == std::string::npos) { + return false; + } + if (preKeyQuotes > 1) { + throw ParsingException(ExceptionMessage(g_ErrorKeyIncorrect, *pLine)); + } + + pLine->Type = Node::MapType; + + // Get key + std::string key = pLine->Data.substr(0, tokenPos); + const size_t keyEnd = key.find_last_not_of(" \t"); + if (keyEnd == std::string::npos) { + throw ParsingException(ExceptionMessage(g_ErrorKeyMissing, *pLine)); + } + key.resize(keyEnd + 1); + + // Handle cited key. + if (preKeyQuotes == 1) { + if (key.front() != '"' || key.back() != '"') { + throw ParsingException(ExceptionMessage(g_ErrorKeyIncorrect, *pLine)); + } + + key = key.substr(1, key.size() - 2); } + RemoveAllEscapeTokens(key); + + // Get value + std::string value = ""; + size_t valueStart = std::string::npos; + if (tokenPos + 1 != pLine->Data.size()) { + valueStart = pLine->Data.find_first_not_of(" \t", tokenPos + 1); + if (valueStart != std::string::npos) { + value = pLine->Data.substr(valueStart); + } + } + + // Make sure the value is not a sequence start. + if (IsSequenceStart(value) == true) { + throw ParsingException( + ExceptionMessage(g_ErrorBlockSequenceNotAllowed, *pLine, valueStart)); + } + + pLine->Data = key; + + // Remove all empty lines after map key.
+ ClearTrailingEmptyLines(++it); + + // Add new empty line? + size_t newLineOffset = valueStart; + if (newLineOffset == std::string::npos) { + if (it != m_Lines.end() && (*it)->Offset > pLine->Offset) { + return true; + } + + newLineOffset = tokenPos + 2; + } else { + newLineOffset += pLine->Offset; + } + + // Add new line with value. + unsigned char dummyBlockFlags = 0; + if (IsBlockScalar(value, pLine->No, dummyBlockFlags) == true) { + newLineOffset = pLine->Offset; + } + ReaderLine *pNewLine = + new ReaderLine(value, pLine->No, newLineOffset, Node::ScalarType); + it = m_Lines.insert(it, pNewLine); + + // Return false in order to handle next line(scalar value). + return false; + } + + /** + * @brief Run post-processing and check for scalar. + * Checking for multi-line scalars. + * + * @return true if scalar search should continue, else false. + * + */ + void PostProcessScalarLine( + std::list<ReaderLine *>::iterator &it) { // NOLINT + ReaderLine *pLine = *it; + pLine->Type = Node::ScalarType; + + size_t parentOffset = pLine->Offset; + if (pLine != m_Lines.front()) { + std::list<ReaderLine *>::iterator lastIt = it; + --lastIt; + parentOffset = (*lastIt)->Offset; + } + + std::list<ReaderLine *>::iterator lastNotEmpty = it++; + + // Find last empty lines + while (it != m_Lines.end()) { + pLine = *it; + pLine->Type = Node::ScalarType; + if (pLine->Data.size()) { + if (pLine->Offset <= parentOffset) { + break; + } else { + lastNotEmpty = it; + } + } + ++it; + } + + ClearTrailingEmptyLines(++lastNotEmpty); + } + + /** + * @brief Process root node and start of document. + * + */ + void ParseRoot(Node &root) { // NOLINT + // Get first line and start type. + auto it = m_Lines.begin(); + if (it == m_Lines.end()) { + return; + } + Node::eType type = (*it)->Type; + ReaderLine *pLine = *it; + + // Handle next line. + switch (type) { + case Node::SequenceType: + ParseSequence(root, it); + break; + case Node::MapType: + ParseMap(root, it); + break; + case Node::ScalarType: + ParseScalar(root, it); + break; + default: + break; + } + + if (it != m_Lines.end()) { + throw InternalException( + ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); + } + } + + /** + * @brief Process sequence node. + * + */ + void ParseSequence( + Node &node, std::list<ReaderLine *>::iterator &it) { // NOLINT + ReaderLine *pNextLine = nullptr; + while (it != m_Lines.end()) { + ReaderLine *pLine = *it; + Node &childNode = node.PushBack(); + + // Move to next line, error check. + ++it; + if (it == m_Lines.end()) { + throw InternalException( + ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); + } + + // Handle value of map + Node::eType valueType = (*it)->Type; + switch (valueType) { + case Node::SequenceType: + ParseSequence(childNode, it); + break; + case Node::MapType: + ParseMap(childNode, it); + break; + case Node::ScalarType: + ParseScalar(childNode, it); + break; + default: + break; + } + + // Check next line. if sequence and correct level, go on, else exit. + // If same level but of type map = error. + if (it == m_Lines.end() || ((pNextLine = *it)->Offset < pLine->Offset)) { + break; + } + if (pNextLine->Offset > pLine->Offset) { + throw ParsingException( + ExceptionMessage(g_ErrorIncorrectOffset, *pNextLine)); + } + if (pNextLine->Type != Node::SequenceType) { + throw InternalException( + ExceptionMessage(g_ErrorDiffEntryNotAllowed, *pNextLine)); + } + } + }
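// A small trace (not part of the patch) of how these parse routines recurse:
// for the document
//
//   entries:
//     - one
//     - two
//
// ParseRoot() dispatches to ParseMap() for "entries"; the deeper-indented
// SequenceType lines send it into ParseSequence(), and each "-" entry lands
// in ParseScalar() through the value-type switch above, filling
// node.PushBack().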
+ /** + * @brief Process map node. + * + */ + void ParseMap(Node &node, // NOLINT + std::list<ReaderLine *>::iterator &it) { // NOLINT + ReaderLine *pNextLine = nullptr; + while (it != m_Lines.end()) { + ReaderLine *pLine = *it; + Node &childNode = node[pLine->Data]; + + // Move to next line, error check. + ++it; + if (it == m_Lines.end()) { + throw InternalException( + ExceptionMessage(g_ErrorUnexpectedDocumentEnd, *pLine)); + } + + // Handle value of map + Node::eType valueType = (*it)->Type; + switch (valueType) { + case Node::SequenceType: + ParseSequence(childNode, it); + break; + case Node::MapType: + ParseMap(childNode, it); + break; + case Node::ScalarType: + ParseScalar(childNode, it); + break; + default: + break; + } + + // Check next line. if map and correct level, go on, else exit. + // if same level but of type map = error. + if (it == m_Lines.end() || ((pNextLine = *it)->Offset < pLine->Offset)) { + break; + } + if (pNextLine->Offset > pLine->Offset) { + throw ParsingException( + ExceptionMessage(g_ErrorIncorrectOffset, *pNextLine)); + } + if (pNextLine->Type != pLine->Type) { + throw InternalException( + ExceptionMessage(g_ErrorDiffEntryNotAllowed, *pNextLine)); + } + } + } + + /** + * @brief Process scalar node. + * + */ + void ParseScalar( + Node &node, std::list<ReaderLine *>::iterator &it) { // NOLINT + std::string data = ""; + ReaderLine *pFirstLine = *it; + ReaderLine *pLine = *it; + + // Check if current line is a block scalar. + unsigned char blockFlags = 0; + bool isBlockScalar = IsBlockScalar(pLine->Data, pLine->No, blockFlags); + const bool newLineFlag = + static_cast<bool>(blockFlags & ReaderLine::FlagMask[static_cast<size_t>( + ReaderLine::ScalarNewlineFlag)]); + const bool foldedFlag = + static_cast<bool>(blockFlags & ReaderLine::FlagMask[static_cast<size_t>( + ReaderLine::FoldedScalarFlag)]); + const bool literalFlag = + static_cast<bool>(blockFlags & ReaderLine::FlagMask[static_cast<size_t>( + ReaderLine::LiteralScalarFlag)]); + size_t parentOffset = 0; + + // Find parent offset + if (it != m_Lines.begin()) { + std::list<ReaderLine *>::iterator parentIt = it; + --parentIt; + parentOffset = (*parentIt)->Offset; + }
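// For reference (not part of the patch), the flag bits decoded above map to
// the YAML block-scalar indicators as set by IsBlockScalar() further down in
// this class:
//
//   "|"   -> LiteralScalarFlag | ScalarNewlineFlag   (keep final newline)
//   "|-"  -> LiteralScalarFlag                       (final newline chomped)
//   ">"   -> FoldedScalarFlag | ScalarNewlineFlag
//   ">-"  -> FoldedScalarFlag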
+ // Move to next iterator/line if current line is a block scalar. + if (isBlockScalar) { + ++it; + if (it == m_Lines.end() || (pLine = *it)->Type != Node::ScalarType) { + return; + } + } + + // Not a block scalar, cut end spaces/tabs + if (isBlockScalar == false) { + while (1) { + pLine = *it; + + if (parentOffset != 0 && pLine->Offset <= parentOffset) { + throw ParsingException( + ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); + } + + const size_t endOffset = pLine->Data.find_last_not_of(" \t"); + if (endOffset == std::string::npos) { + data += "\n"; + } else { + data += pLine->Data.substr(0, endOffset + 1); + } + + // Move to next line + ++it; + if (it == m_Lines.end() || (*it)->Type != Node::ScalarType) { + break; + } + + data += " "; + } + + if (ValidateQuote(data) == false) { + throw ParsingException( + ExceptionMessage(g_ErrorInvalidQuote, *pFirstLine)); + } + } else { // Block scalar + pLine = *it; + size_t blockOffset = pLine->Offset; + if (blockOffset <= parentOffset) { + throw ParsingException( + ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); + } + + bool addedSpace = false; + while (it != m_Lines.end() && (*it)->Type == Node::ScalarType) { + pLine = *it; + + const size_t endOffset = pLine->Data.find_last_not_of(" \t"); + if (endOffset != std::string::npos && pLine->Offset < blockOffset) { + throw ParsingException( + ExceptionMessage(g_ErrorIncorrectOffset, *pLine)); + } + + if (endOffset == std::string::npos) { + if (addedSpace) { + data[data.size() - 1] = '\n'; + addedSpace = false; + } else { + data += "\n"; + } + + ++it; + continue; + } else { + if (blockOffset != pLine->Offset && foldedFlag) { + if (addedSpace) { + data[data.size() - 1] = '\n'; + addedSpace = false; + } else { + data += "\n"; + } + } + data += std::string(pLine->Offset - blockOffset, ' '); + data += pLine->Data; + } + + // Move to next line + ++it; + if (it == m_Lines.end() || (*it)->Type != Node::ScalarType) { + if (newLineFlag) { + data += "\n"; + } + break; + } + + if (foldedFlag) { + data += " "; + addedSpace = true; + } else if (literalFlag && endOffset != std::string::npos) { + data += "\n"; + } + } + } + + if (data.size() && (data[0] == '"' || data[0] == '\'')) { + data = data.substr(1, data.size() - 2); + } + + node = data; + } + + /** + * @brief Debug printing.
+ * + */ + void Print() { + for (auto it = m_Lines.begin(); it != m_Lines.end(); it++) { + ReaderLine *pLine = *it; + + // Print type + if (pLine->Type == Node::SequenceType) { + std::cout << "seq "; + } else if (pLine->Type == Node::MapType) { + std::cout << "map "; + } else if (pLine->Type == Node::ScalarType) { + std::cout << "sca "; + } else { + std::cout << " "; + } + + // Print flags + if (pLine->GetFlag(ReaderLine::FoldedScalarFlag)) { + std::cout << "f"; + } else { + std::cout << "-"; + } + if (pLine->GetFlag(ReaderLine::LiteralScalarFlag)) { + std::cout << "l"; + } else { + std::cout << "-"; + } + if (pLine->GetFlag(ReaderLine::ScalarNewlineFlag)) { + std::cout << "n"; + } else { + std::cout << "-"; + } + if (pLine->NextLine == nullptr) { + std::cout << "e"; + } else { + std::cout << "-"; + } + + std::cout << "| "; + std::cout << pLine->No << " "; + std::cout << std::string(pLine->Offset, ' '); + + if (pLine->Type == Node::ScalarType) { + std::string scalarValue = pLine->Data; + for (size_t i = 0; + (i = scalarValue.find("\n", i)) != std::string::npos;) { + scalarValue.replace(i, 1, "\\n"); + i += 2; + } + std::cout << scalarValue << std::endl; + } else if (pLine->Type == Node::MapType) { + std::cout << pLine->Data + ":" << std::endl; + } else if (pLine->Type == Node::SequenceType) { + std::cout << "-" << std::endl; + } else { + std::cout << "> UNKNOWN TYPE <" << std::endl; + } + } + } + + /** + * @brief Clear all read lines. + * + */ + void ClearLines() { + for (auto it = m_Lines.begin(); it != m_Lines.end(); it++) { + delete *it; + } + m_Lines.clear(); + } + + void ClearTrailingEmptyLines( + std::list<ReaderLine *>::iterator &it) { // NOLINT + while (it != m_Lines.end()) { + ReaderLine *pLine = *it; + if (pLine->Data.size() == 0) { + delete *it; + it = m_Lines.erase(it); + } else { + return; + } + } + } + + static bool IsSequenceStart(const std::string &data) { + if (data.size() == 0 || data[0] != '-') { + return false; + } + + if (data.size() >= 2 && data[1] != ' ') { + return false; + } + + return true; + } + + static bool IsBlockScalar(const std::string &data, + const size_t line, unsigned char &flags) { // NOLINT + flags = 0; + if (data.size() == 0) { + return false; + } + + if (data[0] == '|') { + if (data.size() >= 2) { + if (data[1] != '-' && data[1] != ' ' && data[1] != '\t') { + throw ParsingException( + ExceptionMessage(g_ErrorInvalidBlockScalar, line, data)); + } + } else { + flags |= ReaderLine::FlagMask[static_cast<size_t>( + ReaderLine::ScalarNewlineFlag)]; + } + flags |= ReaderLine::FlagMask[static_cast<size_t>( + ReaderLine::LiteralScalarFlag)]; + return true; + } + + if (data[0] == '>') { + if (data.size() >= 2) { + if (data[1] != '-' && data[1] != ' ' && data[1] != '\t') { + throw ParsingException( + ExceptionMessage(g_ErrorInvalidBlockScalar, line, data)); + } + } else { + flags |= ReaderLine::FlagMask[static_cast<size_t>( + ReaderLine::ScalarNewlineFlag)]; + } + flags |= ReaderLine::FlagMask[static_cast<size_t>( + ReaderLine::FoldedScalarFlag)]; + return true; + } + + return false; + } + + std::list<ReaderLine *> m_Lines; ///< List of lines.
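// End-to-end sketch (not part of the patch): ParseImp is driven through the
// free Parse() overloads that follow; the YAML text below is a hypothetical
// example.
//
//   Yaml::Node root;
//   Yaml::Parse(root, std::string("beam: 10\nchunks:\n  - 16"));
//   // root["beam"].As<int>() == 10, root["chunks"][0].As<int>() == 16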
+}; + +// Parsing functions +void Parse(Node &root, const char *filename) { // NOLINT + std::ifstream f(filename, std::ifstream::binary); + if (f.is_open() == false) { + throw OperationException(g_ErrorCannotOpenFile); + } + + f.seekg(0, f.end); + size_t fileSize = static_cast<size_t>(f.tellg()); + f.seekg(0, f.beg); + + std::unique_ptr<char[]> data(new char[fileSize]); + f.read(data.get(), fileSize); + f.close(); + + Parse(root, data.get(), fileSize); +} +void Parse(Node &root, std::iostream &stream) { // NOLINT + ParseImp *pImp = nullptr; + + try { + pImp = new ParseImp; + pImp->Parse(root, stream); + delete pImp; + } catch (const Exception e) { + delete pImp; + throw; + } +} +void Parse(Node &root, const std::string &string) { // NOLINT + std::stringstream ss(string); + Parse(root, ss); +} - // Static function implementations - std::string ExceptionMessage(const std::string & message, ReaderLine & line) - { - return message + std::string(" Line ") + std::to_string(line.No) + std::string(": ") + line.Data; - } +void Parse(Node &root, const char *buffer, const size_t size) { // NOLINT + std::stringstream ss(std::string(buffer, size)); + Parse(root, ss); +} - std::string ExceptionMessage(const std::string & message, ReaderLine & line, const size_t errorPos) - { - return message + std::string(" Line ") + std::to_string(line.No) + std::string(" column ") + std::to_string(errorPos + 1) + std::string(": ") + line.Data; - } +// Serialize configuration structure. +SerializeConfig::SerializeConfig(const size_t spaceIndentation, + const size_t scalarMaxLength, + const bool sequenceMapNewline, + const bool mapScalarNewline) + : SpaceIndentation(spaceIndentation), + ScalarMaxLength(scalarMaxLength), + SequenceMapNewline(sequenceMapNewline), + MapScalarNewline(mapScalarNewline) {} + +// Serialization functions +void Serialize(const Node &root, const char *filename, + const SerializeConfig &config) { + std::stringstream stream; + Serialize(root, stream, config); + + std::ofstream f(filename); + if (f.is_open() == false) { + throw OperationException(g_ErrorCannotOpenFile); + } + + f.write(stream.str().c_str(), stream.str().size()); + f.close(); +} - std::string ExceptionMessage(const std::string & message, const size_t errorLine, const size_t errorPos) - { - return message + std::string(" Line ") + std::to_string(errorLine) + std::string(" column ") + std::to_string(errorPos); - } +size_t LineFolding(const std::string &input, + std::vector<std::string> &folded, // NOLINT + const size_t maxLength) { + folded.clear(); + if (input.size() == 0) { + return 0; + } - std::string ExceptionMessage(const std::string & message, const size_t errorLine, const std::string & data) - { - return message + std::string(" Line ") + std::to_string(errorLine) + std::string(": ") + data; + size_t currentPos = 0; + size_t lastPos = 0; + size_t spacePos = std::string::npos; + while (currentPos < input.size()) { + currentPos = lastPos + maxLength; + + if (currentPos < input.size()) { + spacePos = input.find_first_of(' ', currentPos); } - bool FindQuote(const std::string & input, size_t & start, size_t & end, size_t searchPos) - { - start = end = std::string::npos; - size_t qPos = searchPos; - bool foundStart = false; + if (spacePos == std::string::npos || currentPos >= input.size()) { + const std::string endLine = input.substr(lastPos); + if (endLine.size()) { + folded.push_back(endLine); + } - while(qPos != std::string::npos) - { - // Find first quote.
- qPos = input.find_first_of("\"'", qPos); - if(qPos == std::string::npos) - { - return false; - } + return folded.size(); + } - const char token = input[qPos]; - if(token == '"' && (qPos == 0 || input[qPos-1] != '\\')) - { - // Found start quote. - if(foundStart == false) - { - start = qPos; - foundStart = true; - } - // Found end quote - else - { - end = qPos; - return true; - } - } + folded.push_back(input.substr(lastPos, spacePos - lastPos)); - // Check if it's possible for another loop. - if(qPos + 1 == input.size()) - { - return false; - } - qPos++; - } + lastPos = spacePos + 1; + } - return false; - } + return folded.size(); +} - size_t FindNotCited(const std::string & input, char token, size_t & preQuoteCount) - { - preQuoteCount = 0; - size_t tokenPos = input.find_first_of(token); - if(tokenPos == std::string::npos) - { - return std::string::npos; - } +static void SerializeLoop(const Node &node, std::iostream &stream, + bool useLevel, const size_t level, + const SerializeConfig &config) { + const size_t indention = config.SpaceIndentation; + + switch (node.Type()) { + case Node::SequenceType: { + for (auto it = node.Begin(); it != node.End(); it++) { + const Node &value = (*it).second; + if (value.IsNone()) { + continue; + } + stream << std::string(level, ' ') << "- "; + useLevel = false; + if (value.IsSequence() || + (value.IsMap() && config.SequenceMapNewline == true)) { + useLevel = true; + stream << "\n"; + } + + SerializeLoop(value, stream, useLevel, level + 2, config); + } + } break; + case Node::MapType: { + size_t count = 0; + for (auto it = node.Begin(); it != node.End(); it++) { + const Node &value = (*it).second; + if (value.IsNone()) { + continue; + } + + if (useLevel || count > 0) { + stream << std::string(level, ' '); + } + + std::string key = (*it).first; + AddEscapeTokens(key, "\\\""); + if (ShouldBeCited(key)) { + stream << "\"" << key << "\"" + << ": "; + } else { + stream << key << ": "; + } + + useLevel = false; + if (value.IsScalar() == false || + (value.IsScalar() && config.MapScalarNewline)) { + useLevel = true; + stream << "\n"; + } + + SerializeLoop(value, stream, useLevel, level + indention, config); + + useLevel = true; + count++; + } + } break; + case Node::ScalarType: { + const std::string value = node.As(); + + // Empty scalar + if (value.size() == 0) { + stream << "\n"; + break; + } + + // Get lines of scalar. 
+ std::string line = ""; + std::vector lines; + std::istringstream iss(value); + while (iss.eof() == false) { + std::getline(iss, line); + lines.push_back(line); + } + + // Block scalar + const std::string &lastLine = lines.back(); + const bool endNewline = lastLine.size() == 0; + if (endNewline) { + lines.pop_back(); + } + + // Literal + if (lines.size() > 1) { + stream << "|"; + } else { // Folded/plain + const std::string frontLine = lines.front(); + if (config.ScalarMaxLength == 0 || + lines.front().size() <= config.ScalarMaxLength || + LineFolding(frontLine, lines, config.ScalarMaxLength) == 1) { + if (useLevel) { + stream << std::string(level, ' '); + } + + if (ShouldBeCited(value)) { + stream << "\"" << value << "\"\n"; + break; + } + stream << value << "\n"; + break; + } else { + stream << ">"; + } + } + + if (endNewline == false) { + stream << "-"; + } + stream << "\n"; + + for (auto it = lines.begin(); it != lines.end(); it++) { + stream << std::string(level, ' ') << (*it) << "\n"; + } + } break; + + default: + break; + } +} - // Find all quotes - std::vector> quotes; +void Serialize(const Node &root, std::iostream &stream, + const SerializeConfig &config) { + if (config.SpaceIndentation < 2) { + throw OperationException(g_ErrorIndentation); + } - size_t quoteStart = 0; - size_t quoteEnd = 0; - while(FindQuote(input, quoteStart, quoteEnd, quoteEnd)) - { - quotes.push_back({quoteStart, quoteEnd}); + SerializeLoop(root, stream, false, 0, config); +} - if(quoteEnd + 1 == input.size()) - { - break; - } - quoteEnd++; - } +void Serialize(const Node &root, + std::string &string, // NOLINT + const SerializeConfig &config) { + std::stringstream stream; + Serialize(root, stream, config); + string = stream.str(); +} - if(quotes.size() == 0) - { - return tokenPos; - } +// Static function implementations +std::string ExceptionMessage(const std::string &message, + ReaderLine &line) { // NOLINT + return message + std::string(" Line ") + std::to_string(line.No) + + std::string(": ") + line.Data; +} + +std::string ExceptionMessage(const std::string &message, + ReaderLine &line, const size_t errorPos) { // NOLINT + return message + std::string(" Line ") + std::to_string(line.No) + + std::string(" column ") + std::to_string(errorPos + 1) + + std::string(": ") + line.Data; +} - size_t currentQuoteIndex = 0; - std::pair currentQuote = {0, 0}; +std::string ExceptionMessage(const std::string &message, const size_t errorLine, + const size_t errorPos) { + return message + std::string(" Line ") + std::to_string(errorLine) + + std::string(" column ") + std::to_string(errorPos); +} - while(currentQuoteIndex < quotes.size()) - { - currentQuote = quotes[currentQuoteIndex]; +std::string ExceptionMessage(const std::string &message, const size_t errorLine, + const std::string &data) { + return message + std::string(" Line ") + std::to_string(errorLine) + + std::string(": ") + data; +} - if(tokenPos < currentQuote.first) - { - return tokenPos; - } - preQuoteCount++; - if(tokenPos <= currentQuote.second) - { - // Find next token - if(tokenPos + 1 == input.size()) - { - return std::string::npos; - } - tokenPos = input.find_first_of(token, tokenPos + 1); - if(tokenPos == std::string::npos) - { - return std::string::npos; - } - } +bool FindQuote(const std::string &input, + size_t &start, size_t &end, // NOLINT + size_t searchPos) { + start = end = std::string::npos; + size_t qPos = searchPos; + bool foundStart = false; - currentQuoteIndex++; - } + while (qPos != std::string::npos) { + // Find first quote. 
+ qPos = input.find_first_of("\"'", qPos); + if (qPos == std::string::npos) { + return false; + } - return tokenPos; + const char token = input[qPos]; + if (token == '"' && (qPos == 0 || input[qPos - 1] != '\\')) { + // Found start quote. + if (foundStart == false) { + start = qPos; + foundStart = true; + } else { + end = qPos; + return true; + } } - size_t FindNotCited(const std::string & input, char token) - { - size_t dummy = 0; - return FindNotCited(input, token, dummy); + // Check if it's possible for another loop. + if (qPos + 1 == input.size()) { + return false; } + qPos++; + } - bool ValidateQuote(const std::string & input) - { - if(input.size() == 0) - { - return true; - } + return false; +} - char token = 0; - size_t searchPos = 0; - if(input[0] == '\"' || input[0] == '\'') - { - if(input.size() == 1) - { - return false; - } - token = input[0]; - searchPos = 1; - } +size_t FindNotCited(const std::string &input, char token, + size_t &preQuoteCount) { // NOLINT + preQuoteCount = 0; + size_t tokenPos = input.find_first_of(token); + if (tokenPos == std::string::npos) { + return std::string::npos; + } - while(searchPos != std::string::npos && searchPos < input.size() - 1) - { - searchPos = input.find_first_of("\"'", searchPos + 1); - if(searchPos == std::string::npos) - { - break; - } + // Find all quotes + std::vector> quotes; - const char foundToken = input[searchPos]; - - if(input[searchPos] == '\"' || input[searchPos] == '\'') - { - if(token == 0 && input[searchPos-1] != '\\') - { - return false; - } - //if(foundToken == token) - //{ - - /*if(foundToken == token && searchPos == input.size() - 1 && input[searchPos-1] != '\\') - { - return true; - if(searchPos == input.size() - 1) - { - return true; - } - return false; - } - else */ - if(foundToken == token && input[searchPos-1] != '\\') - { - if(searchPos == input.size() - 1) - { - return true; - } - return false; - } - //} - } - } + size_t quoteStart = 0; + size_t quoteEnd = 0; + while (FindQuote(input, quoteStart, quoteEnd, quoteEnd)) { + quotes.push_back({quoteStart, quoteEnd}); - return token == 0; + if (quoteEnd + 1 == input.size()) { + break; } + quoteEnd++; + } - void CopyNode(const Node & from, Node & to) - { - const Node::eType type = from.Type(); + if (quotes.size() == 0) { + return tokenPos; + } - switch(type) - { - case Node::SequenceType: - for(auto it = from.Begin(); it != from.End(); it++) - { - const Node & currentNode = (*it).second; - Node & newNode = to.PushBack(); - CopyNode(currentNode, newNode); - } - break; - case Node::MapType: - for(auto it = from.Begin(); it != from.End(); it++) - { - const Node & currentNode = (*it).second; - Node & newNode = to[(*it).first]; - CopyNode(currentNode, newNode); - } - break; - case Node::ScalarType: - to = from.As(); - break; - case Node::None: - break; - } + size_t currentQuoteIndex = 0; + std::pair currentQuote = {0, 0}; + + while (currentQuoteIndex < quotes.size()) { + currentQuote = quotes[currentQuoteIndex]; + + if (tokenPos < currentQuote.first) { + return tokenPos; + } + preQuoteCount++; + if (tokenPos <= currentQuote.second) { + // Find next token + if (tokenPos + 1 == input.size()) { + return std::string::npos; + } + tokenPos = input.find_first_of(token, tokenPos + 1); + if (tokenPos == std::string::npos) { + return std::string::npos; + } + } + + currentQuoteIndex++; + } + + return tokenPos; +} + +size_t FindNotCited(const std::string &input, char token) { + size_t dummy = 0; + return FindNotCited(input, token, dummy); +} + +bool ValidateQuote(const std::string 
&input) { + if (input.size() == 0) { + return true; + } + + char token = 0; + size_t searchPos = 0; + if (input[0] == '\"' || input[0] == '\'') { + if (input.size() == 1) { + return false; } + token = input[0]; + searchPos = 1; + } - bool ShouldBeCited(const std::string & key) - { - return key.find_first_of("\":{}[],&*#?|-<>=!%@") != std::string::npos; + while (searchPos != std::string::npos && searchPos < input.size() - 1) { + searchPos = input.find_first_of("\"'", searchPos + 1); + if (searchPos == std::string::npos) { + break; } - void AddEscapeTokens(std::string & input, const std::string & tokens) - { - for(auto it = tokens.begin(); it != tokens.end(); it++) - { - const char token = *it; - const std::string replace = std::string("\\") + std::string(1, token); - size_t found = input.find_first_of(token); - while(found != std::string::npos) - { - input.replace(found, 1, replace); - found = input.find_first_of(token, found + 2); - } + const char foundToken = input[searchPos]; + + if (input[searchPos] == '\"' || input[searchPos] == '\'') { + if (token == 0 && input[searchPos - 1] != '\\') { + return false; + } + // if(foundToken == token) + //{ + + /*if(foundToken == token && searchPos == input.size() - 1 && + input[searchPos-1] != '\\') + { + return true; + if(searchPos == input.size() - 1) + { + return true; + } + return false; + } + else */ + if (foundToken == token && input[searchPos - 1] != '\\') { + if (searchPos == input.size() - 1) { + return true; } + return false; + } + //} } + } - void RemoveAllEscapeTokens(std::string & input) - { - size_t found = input.find_first_of("\\"); - while(found != std::string::npos) - { - if(found + 1 == input.size()) - { - return; - } + return token == 0; +} - std::string replace(1, input[found + 1]); - input.replace(found, 2, replace); - found = input.find_first_of("\\", found + 1); - } +void CopyNode(const Node &from, Node &to) { // NOLINT + const Node::eType type = from.Type(); + + switch (type) { + case Node::SequenceType: + for (auto it = from.Begin(); it != from.End(); it++) { + const Node ¤tNode = (*it).second; + Node &newNode = to.PushBack(); + CopyNode(currentNode, newNode); + } + break; + case Node::MapType: + for (auto it = from.Begin(); it != from.End(); it++) { + const Node ¤tNode = (*it).second; + Node &newNode = to[(*it).first]; + CopyNode(currentNode, newNode); + } + break; + case Node::ScalarType: + to = from.As(); + break; + case Node::None: + break; + } +} + +bool ShouldBeCited(const std::string &key) { + return key.find_first_of("\":{}[],&*#?|-<>=!%@") != std::string::npos; +} + +void AddEscapeTokens(std::string &input, const std::string &tokens) { // NOLINT + for (auto it = tokens.begin(); it != tokens.end(); it++) { + const char token = *it; + const std::string replace = std::string("\\") + std::string(1, token); + size_t found = input.find_first_of(token); + while (found != std::string::npos) { + input.replace(found, 1, replace); + found = input.find_first_of(token, found + 2); } + } +} +void RemoveAllEscapeTokens(std::string &input) { // NOLINT + size_t found = input.find_first_of("\\"); + while (found != std::string::npos) { + if (found + 1 == input.size()) { + return; + } + std::string replace(1, input[found + 1]); + input.replace(found, 2, replace); + found = input.find_first_of("\\", found + 1); + } } + +} // namespace Yaml diff --git a/runtime/core/utils/Yaml.hpp b/runtime/core/utils/Yaml.hpp index 586657fb2..2eb302735 100644 --- a/runtime/core/utils/Yaml.hpp +++ b/runtime/core/utils/Yaml.hpp @@ -1,3 +1,5 @@ +// 
Copyright (c) From https://github.com/jimmiebergmann/mini-yaml +// 2022 SoundDataConverge Co.LTD (Weiliang Chong) /* * MIT License * @@ -37,620 +39,567 @@ YAML documentation: #include #include #include +#include /** * @breif Namespace wrapping mini-yaml classes. * */ -namespace Yaml -{ - - /** - * @breif Forward declarations. - * - */ - class Node; - - - /** - * @breif Helper classes and functions - * - */ - namespace impl - { - - /** - * @breif Helper functionality, converting string to any data type. - * Strings are left untouched. - * - */ - template<typename T> - struct StringConverter - { - static T Get(const std::string & data) - { - T type; - std::stringstream ss(data); - ss >> type; - return type; - } - - static T Get(const std::string & data, const T & defaultValue) - { - T type; - std::stringstream ss(data); - ss >> type; - - if(ss.fail()) - { - return defaultValue; - } - - return type; - } - }; - template<> - struct StringConverter<std::string> - { - static std::string Get(const std::string & data) - { - return data; - } - - static std::string Get(const std::string & data, const std::string & defaultValue) - { - if(data.size() == 0) - { - return defaultValue; - } - return data; - } - }; - - template<> - struct StringConverter<bool> - { - static bool Get(const std::string & data) - { - std::string tmpData = data; - std::transform(tmpData.begin(), tmpData.end(), tmpData.begin(), ::tolower); - if(tmpData == "true" || tmpData == "yes" || tmpData == "1") - { - return true; - } - - return false; - } - - static bool Get(const std::string & data, const bool & defaultValue) - { - if(data.size() == 0) - { - return defaultValue; - } - - return Get(data); - } - }; +namespace Yaml { +/** +* @brief Forward declarations. +* +*/ +class Node; + + +/** +* @brief Helper classes and functions +* +*/ +namespace impl { +/** +* @brief Helper functionality, converting string to any data type. +* Strings are left untouched. +* +*/ +template <typename T> +struct StringConverter { + static T Get(const std::string & data) { + T type; + std::stringstream ss(data); + ss >> type; + return type; + } + + static T Get(const std::string & data, const T & defaultValue) { + T type; + std::stringstream ss(data); + ss >> type; + + if (ss.fail()) { + return defaultValue; + } + return type; + } +}; +template <> +struct StringConverter<std::string> { + static std::string Get(const std::string & data) { + return data; + } + + static std::string Get( + const std::string & data, const std::string & defaultValue) { + if (data.size() == 0) { + return defaultValue; } + return data; + } +}; + +template <> +struct StringConverter<bool> { + static bool Get(const std::string & data) { + std::string tmpData = data; + std::transform(tmpData.begin(), tmpData.end(), tmpData.begin(), ::tolower); + if (tmpData == "true" || tmpData == "yes" || tmpData == "1") { + return true; + } + + return false; + } + + static bool Get(const std::string & data, const bool & defaultValue) { + if (data.size() == 0) { + return defaultValue; + } + + return Get(data); + } +}; + +} // namespace impl - /** - * @breif Exception class. - * - */ - class Exception : public std::runtime_error - { - - public: - - /** - * @breif Enumeration of exception types. - * - */ - enum eType - { - InternalError, ///< Internal error. - ParsingError, ///< Invalid parsing data. - OperationError ///< User operation error. - }; - - /** - * @breif Constructor. - * - * @param message Exception message. - * @param type Type of exception. - * - */ - Exception(const std::string & message, const eType type); - - /** - * @breif Get type of exception.
- * - */ - eType Type() const; - - /** - * @breif Get message of exception. - * - */ - const char * Message() const; - - private: - - eType m_Type; ///< Type of exception. - - }; - - - /** - * @breif Internal exception class. - * - * @see Exception - * - */ - class InternalException : public Exception - { - - public: - - /** - * @breif Constructor. - * - * @param message Exception message. - * - */ - InternalException(const std::string & message); - - }; - - - /** - * @breif Parsing exception class. - * - * @see Exception - * - */ - class ParsingException : public Exception - { - - public: - - /** - * @breif Constructor. - * - * @param message Exception message. - * - */ - ParsingException(const std::string & message); - - }; - - - /** - * @breif Operation exception class. - * - * @see Exception - * - */ - class OperationException : public Exception - { - - public: - - /** - * @breif Constructor. - * - * @param message Exception message. - * - */ - OperationException(const std::string & message); - - }; - - - /** - * @breif Iterator class. - * - */ - class Iterator - { - - public: - - friend class Node; - - /** - * @breif Default constructor. - * - */ - Iterator(); - - /** - * @breif Copy constructor. - * - */ - Iterator(const Iterator & it); - - /** - * @breif Assignment operator. - * - */ - Iterator & operator = (const Iterator & it); - - /** - * @breif Destructor. - * - */ - ~Iterator(); - - /** - * @breif Get node of iterator. - * First pair item is the key of map value, empty if type is sequence. - * - */ - std::pair operator *(); - - /** - * @breif Post-increment operator. - * - */ - Iterator & operator ++ (int); - - /** - * @breif Post-decrement operator. - * - */ - Iterator & operator -- (int); - - /** - * @breif Check if iterator is equal to other iterator. - * - */ - bool operator == (const Iterator & it); - - /** - * @breif Check if iterator is not equal to other iterator. - * - */ - bool operator != (const Iterator & it); - - private: - - enum eType - { - None, - SequenceType, - MapType - }; - - eType m_Type; ///< Type of iterator. - void * m_pImp; ///< Implementation of iterator class. - - }; - - - /** - * @breif Constant iterator class. - * - */ - class ConstIterator - { - - public: - - friend class Node; - - /** - * @breif Default constructor. - * - */ - ConstIterator(); - - /** - * @breif Copy constructor. - * - */ - ConstIterator(const ConstIterator & it); - - /** - * @breif Assignment operator. - * - */ - ConstIterator & operator = (const ConstIterator & it); - - /** - * @breif Destructor. - * - */ - ~ConstIterator(); - - /** - * @breif Get node of iterator. - * First pair item is the key of map value, empty if type is sequence. - * - */ - std::pair operator *(); - - /** - * @breif Post-increment operator. - * - */ - ConstIterator & operator ++ (int); - - /** - * @breif Post-decrement operator. - * - */ - ConstIterator & operator -- (int); - - /** - * @breif Check if iterator is equal to other iterator. - * - */ - bool operator == (const ConstIterator & it); - - /** - * @breif Check if iterator is not equal to other iterator. - * - */ - bool operator != (const ConstIterator & it); - - private: - - enum eType - { - None, - SequenceType, - MapType - }; - - eType m_Type; ///< Type of iterator. - void * m_pImp; ///< Implementation of constant iterator class. - - }; - - - /** - * @breif Node class. - * - */ - class Node - { - - public: - - friend class Iterator; - - /** - * @breif Enumeration of node types. 
- * - */ - enum eType - { - None, - SequenceType, - MapType, - ScalarType - }; - - /** - * @breif Default constructor. - * - */ - Node(); - - /** - * @breif Copy constructor. - * - */ - Node(const Node & node); - - /** - * @breif Assignment constructors. - * Converts node to scalar type if needed. - * - */ - Node(const std::string & value); - Node(const char * value); - - /** - * @breif Destructor. - * - */ - ~Node(); - - /** - * @breif Functions for checking type of node. - * - */ - eType Type() const; - bool IsNone() const; - bool IsSequence() const; - bool IsMap() const; - bool IsScalar() const; - - /** - * @breif Completely clear node. - * - */ - void Clear(); - - /** - * @breif Get node as given template type. - * - */ - template - T As() const - { - return impl::StringConverter::Get(AsString()); - } - - /** - * @breif Get node as given template type. - * - */ - template - T As(const T & defaultValue) const - { - return impl::StringConverter::Get(AsString(), defaultValue); - } - - /** - * @breif Get size of node. - * Nodes of type None or Scalar will return 0. - * - */ - size_t Size() const; - - // Sequence operators - - /** - * @breif Insert sequence item at given index. - * Converts node to sequence type if needed. - * Adding new item to end of sequence if index is larger than sequence size. - * - */ - Node & Insert(const size_t index); - - /** - * @breif Add new sequence index to back. - * Converts node to sequence type if needed. - * - */ - Node & PushFront(); - - /** - * @breif Add new sequence index to front. - * Converts node to sequence type if needed. - * - */ - Node & PushBack(); - - /** - * @breif Get sequence/map item. - * Converts node to sequence/map type if needed. - * - * @param index Sequence index. Returns None type Node if index is unknown. - * @param key Map key. Creates a new node if key is unknown. - * - */ - Node & operator [] (const size_t index); - Node & operator [] (const std::string & key); - - /** - * @breif Erase item. - * No action if node is not a sequence or map. - * - */ - void Erase(const size_t index); - void Erase(const std::string & key); - - /** - * @breif Assignment operators. - * - */ - Node & operator = (const Node & node); - Node & operator = (const std::string & value); - Node & operator = (const char * value); - - /** - * @breif Get start iterator. - * - */ - Iterator Begin(); - ConstIterator Begin() const; - - /** - * @breif Get end iterator. - * - */ - Iterator End(); - ConstIterator End() const; - - - private: - - /** - * @breif Get as string. If type is scalar, else empty. - * - */ - const std::string & AsString() const; - - void * m_pImp; ///< Implementation of node class. - - }; - - - /** - * @breif Parsing functions. - * Population given root node with deserialized data. - * - * @param root Root node to populate. - * @param filename Path of input file. - * @param stream Input stream. - * @param string String of input data. - * @param buffer Char array of input data. - * @param size Buffer size. - * - * @throw InternalException An internal error occurred. - * @throw ParsingException Invalid input YAML data. - * @throw OperationException If filename or buffer pointer is invalid. - * - */ - void Parse(Node & root, const char * filename); - void Parse(Node & root, std::iostream & stream); - void Parse(Node & root, const std::string & string); - void Parse(Node & root, const char * buffer, const size_t size); - - - /** - * @breif Serialization configuration structure, - * describing output behavior. 
- * - */ - struct SerializeConfig - { - - /** - * @breif Constructor. - * - * @param spaceIndentation Number of spaces per indentation. - * @param scalarMaxLength Maximum length of scalars. Serialized as folder scalars if exceeded. - * Ignored if equal to 0. - * @param sequenceMapNewline Put maps on a new line if parent node is a sequence. - * @param mapScalarNewline Put scalars on a new line if parent node is a map. - * - */ - SerializeConfig(const size_t spaceIndentation = 2, - const size_t scalarMaxLength = 64, - const bool sequenceMapNewline = false, - const bool mapScalarNewline = false); - - size_t SpaceIndentation; ///< Number of spaces per indentation. - size_t ScalarMaxLength; ///< Maximum length of scalars. Serialized as folder scalars if exceeded. - bool SequenceMapNewline; ///< Put maps on a new line if parent node is a sequence. - bool MapScalarNewline; ///< Put scalars on a new line if parent node is a map. - }; - - - /** - * @breif Serialization functions. - * - * @param root Root node to serialize. - * @param filename Path of output file. - * @param stream Output stream. - * @param string String of output data. - * @param config Serialization configurations. - * - * @throw InternalException An internal error occurred. - * @throw OperationException If filename or buffer pointer is invalid. - * If config is invalid. - * - */ - void Serialize(const Node & root, const char * filename, const SerializeConfig & config = {2, 64, false, false}); - void Serialize(const Node & root, std::iostream & stream, const SerializeConfig & config = {2, 64, false, false}); - void Serialize(const Node & root, std::string & string, const SerializeConfig & config = {2, 64, false, false}); - -} +/** +* @breif Exception class. +* +*/ +class Exception : public std::runtime_error { + public: + /** + * @breif Enumeration of exception types. + * + */ + enum eType { + InternalError, ///< Internal error. + ParsingError, ///< Invalid parsing data. + OperationError ///< User operation error. + }; + + /** + * @breif Constructor. + * + * @param message Exception message. + * @param type Type of exception. + * + */ + Exception(const std::string & message, const eType type); + + /** + * @breif Get type of exception. + * + */ + eType Type() const; + + /** + * @breif Get message of exception. + * + */ + const char * Message() const; + + private: + eType m_Type; ///< Type of exception. +}; + +/** +* @breif Internal exception class. +* +* @see Exception +* +*/ +class InternalException : public Exception { + public: + /** + * @breif Constructor. + * + * @param message Exception message. + * + */ + explicit InternalException(const std::string & message); +}; + +/** +* @breif Parsing exception class. +* +* @see Exception +* +*/ +class ParsingException : public Exception { + public: + /** + * @breif Constructor. + * + * @param message Exception message. + * + */ + explicit ParsingException(const std::string & message); +}; + +/** +* @breif Operation exception class. +* +* @see Exception +* +*/ +class OperationException : public Exception { + public: + /** + * @breif Constructor. + * + * @param message Exception message. + * + */ + explicit OperationException(const std::string & message); +}; + +/** +* @breif Iterator class. +* +*/ +class Iterator { + public: + friend class Node; + + /** + * @breif Default constructor. + * + */ + Iterator(); + + /** + * @breif Copy constructor. + * + */ + Iterator(const Iterator & it); + + /** + * @breif Assignment operator. 
+ * + */ + Iterator & operator = (const Iterator & it); + + /** + * @breif Destructor. + * + */ + ~Iterator(); + + /** + * @breif Get node of iterator. + * First pair item is the key of map value, empty if type is sequence. + * + */ + std::pair operator *(); + + /** + * @breif Post-increment operator. + * + */ + Iterator & operator++ (int); + + /** + * @breif Post-decrement operator. + * + */ + Iterator & operator-- (int); + + /** + * @breif Check if iterator is equal to other iterator. + * + */ + bool operator == (const Iterator & it); + + /** + * @breif Check if iterator is not equal to other iterator. + * + */ + bool operator != (const Iterator & it); + + private: + enum eType { + None, + SequenceType, + MapType + }; + + eType m_Type; // Type of iterator. + void * m_pImp; // Implementation of iterator class. +}; + +/** +* @breif Constant iterator class. +* +*/ +class ConstIterator { + public: + friend class Node; + + /** + * @breif Default constructor. + * + */ + ConstIterator(); + + /** + * @breif Copy constructor. + * + */ + ConstIterator(const ConstIterator & it); + + /** + * @breif Assignment operator. + * + */ + ConstIterator & operator = (const ConstIterator & it); + + /** + * @breif Destructor. + * + */ + ~ConstIterator(); + + /** + * @breif Get node of iterator. + * First pair item is the key of map value, empty if type is sequence. + * + */ + std::pair operator *(); + + /** + * @breif Post-increment operator. + * + */ + ConstIterator & operator++ (int); + + /** + * @breif Post-decrement operator. + * + */ + ConstIterator & operator-- (int); + + /** + * @breif Check if iterator is equal to other iterator. + * + */ + bool operator == (const ConstIterator & it); + + /** + * @breif Check if iterator is not equal to other iterator. + * + */ + bool operator != (const ConstIterator & it); + + private: + enum eType { + None, + SequenceType, + MapType + }; + + eType m_Type; // Type of iterator. + void * m_pImp; // Implementation of constant iterator class. +}; + +/** +* @breif Node class. +* +*/ +class Node { + public: + friend class Iterator; + + /** + * @breif Enumeration of node types. + * + */ + enum eType { + None, + SequenceType, + MapType, + ScalarType + }; + + /** + * @breif Default constructor. + * + */ + Node(); + + /** + * @breif Copy constructor. + * + */ + Node(const Node & node); + + /** + * @breif Assignment constructors. + * Converts node to scalar type if needed. + * + */ + explicit Node(const std::string & value); + explicit Node(const char * value); + + /** + * @breif Destructor. + * + */ + ~Node(); + + /** + * @breif Functions for checking type of node. + * + */ + eType Type() const; + bool IsNone() const; + bool IsSequence() const; + bool IsMap() const; + bool IsScalar() const; + + /** + * @breif Completely clear node. + * + */ + void Clear(); + + /** + * @breif Get node as given template type. + * + */ + template + T As() const { + return impl::StringConverter::Get(AsString()); + } + + /** + * @breif Get node as given template type. + * + */ + template + T As(const T & defaultValue) const { + return impl::StringConverter::Get(AsString(), defaultValue); + } + + /** + * @breif Get size of node. + * Nodes of type None or Scalar will return 0. + * + */ + size_t Size() const; + + // Sequence operators + + /** + * @breif Insert sequence item at given index. + * Converts node to sequence type if needed. + * Adding new item to end of sequence if index is larger than sequence size. 
+ * + */ + Node & Insert(const size_t index); + + /** + * @breif Add new sequence index to back. + * Converts node to sequence type if needed. + * + */ + Node & PushFront(); + + /** + * @breif Add new sequence index to front. + * Converts node to sequence type if needed. + * + */ + Node & PushBack(); + + /** + * @breif Get sequence/map item. + * Converts node to sequence/map type if needed. + * + * @param index Sequence index. Returns None type Node if index is unknown. + * @param key Map key. Creates a new node if key is unknown. + * + */ + Node & operator[] (const size_t index); + Node & operator[] (const std::string & key); + + /** + * @breif Erase item. + * No action if node is not a sequence or map. + * + */ + void Erase(const size_t index); + void Erase(const std::string & key); + + /** + * @breif Assignment operators. + * + */ + Node & operator = (const Node & node); + Node & operator = (const std::string & value); + Node & operator = (const char * value); + + /** + * @breif Get start iterator. + * + */ + Iterator Begin(); + ConstIterator Begin() const; + + /** + * @breif Get end iterator. + * + */ + Iterator End(); + ConstIterator End() const; + + private: + /** + * @breif Get as string. If type is scalar, else empty. + * + */ + const std::string & AsString() const; + + void * m_pImp; // Implementation of node class. +}; + + +/** +* @breif Parsing functions. +* Population given root node with deserialized data. +* +* @param root Root node to populate. +* @param filename Path of input file. +* @param stream Input stream. +* @param string String of input data. +* @param buffer Char array of input data. +* @param size Buffer size. +* +* @throw InternalException An internal error occurred. +* @throw ParsingException Invalid input YAML data. +* @throw OperationException If filename or buffer pointer is invalid. +* +*/ +void Parse(Node & root, const char * filename); // NOLINT +void Parse(Node & root, std::iostream & stream); // NOLINT +void Parse(Node & root, const std::string & string); // NOLINT +void Parse(Node & root, const char * buffer, const size_t size); // NOLINT + + +/** +* @breif Serialization configuration structure, +* describing output behavior. +* +*/ +struct SerializeConfig { + /** + * @breif Constructor. + * + * @param spaceIndentation Number of spaces per indentation. + * @param scalarMaxLength Maximum length of scalars. Serialized as folder scalars if exceeded. + * Ignored if equal to 0. + * @param sequenceMapNewline Put maps on a new line if parent node is a sequence. + * @param mapScalarNewline Put scalars on a new line if parent node is a map. + * + */ + SerializeConfig(const size_t spaceIndentation = 2, + const size_t scalarMaxLength = 64, + const bool sequenceMapNewline = false, + const bool mapScalarNewline = false); + + size_t SpaceIndentation; // Number of spaces per indentation. + // Maximum length of scalars. Serialized as folder scalars if exceeded. + size_t ScalarMaxLength; + // Put maps on a new line if parent node is a sequence. + bool SequenceMapNewline; + // Put scalars on a new line if parent node is a map. + bool MapScalarNewline; +}; + + +/** +* @breif Serialization functions. +* +* @param root Root node to serialize. +* @param filename Path of output file. +* @param stream Output stream. +* @param string String of output data. +* @param config Serialization configurations. +* +* @throw InternalException An internal error occurred. +* @throw OperationException If filename or buffer pointer is invalid. +* If config is invalid. 
+* +*/ +void Serialize( + const Node & root, const char * filename, + const SerializeConfig & config = {2, 64, false, false}); +void Serialize( + const Node & root, std::iostream & stream, // NOLINT + const SerializeConfig & config = {2, 64, false, false}); +void Serialize( + const Node & root, std::string & string, // NOLINT + const SerializeConfig & config = {2, 64, false, false}); + +} // namespace Yaml diff --git a/runtime/core/websocket/batch_connection_handler.h b/runtime/core/websocket/batch_connection_handler.h index fe4197c26..f6b6225a2 100644 --- a/runtime/core/websocket/batch_connection_handler.h +++ b/runtime/core/websocket/batch_connection_handler.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "boost/asio/connect.hpp" #include "boost/asio/ip/tcp.hpp" diff --git a/runtime/core/websocket/websocket_server.cc b/runtime/core/websocket/websocket_server.cc index ed7e1c211..708d93623 100644 --- a/runtime/core/websocket/websocket_server.cc +++ b/runtime/core/websocket/websocket_server.cc @@ -14,12 +14,12 @@ // limitations under the License. #include "websocket/websocket_server.h" -#include "websocket/batch_connection_handler.h" #include #include #include +#include "websocket/batch_connection_handler.h" #include "boost/json/src.hpp" #include "utils/log.h" diff --git a/runtime/core/websocket/websocket_server.h b/runtime/core/websocket/websocket_server.h index e211faf9c..5714f3251 100644 --- a/runtime/core/websocket/websocket_server.h +++ b/runtime/core/websocket/websocket_server.h @@ -85,7 +85,7 @@ class WebSocketServer { decode_config_(std::move(decode_config)), decode_resource_(std::move(decode_resource)) {} - void Start(bool run_batch=false); + void Start(bool run_batch = false); private: int port_; From cd85c840311e5b9b48d6e20fc51275c2df532596 Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 4 Nov 2022 11:15:57 +0800 Subject: [PATCH 57/62] fix flake8 error --- wenet/bin/export_onnx_gpu.py | 1 - wenet/transformer/asr_model.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py index 77f278336..de2beb88c 100644 --- a/wenet/bin/export_onnx_gpu.py +++ b/wenet/bin/export_onnx_gpu.py @@ -222,7 +222,6 @@ def forward(self, r_decoder_out: B x beam x T2 x V best_index: B """ - print('self.reverse_weight ', self.reverse_weight, 'self.ctc_weight ', self.ctc_weight) B, T, F = encoder_out.shape bz = self.beam_size B2 = B * bz diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py index 3b55dfeaa..e3dbbcfba 100644 --- a/wenet/transformer/asr_model.py +++ b/wenet/transformer/asr_model.py @@ -917,8 +917,7 @@ def batch_forward_encoder( beam_log_probs_idx: B x T x beam_size """ encoder_out, encoder_mask = self.encoder( - speech, - speech_lengths, -1, -1) + speech, speech_lengths, -1, -1) encoder_out_lens = encoder_mask.squeeze(1).sum(1) encoder_out_lens = encoder_out_lens.int() ctc_log_probs = self.ctc.log_softmax(encoder_out) From e8a0a2480adf35f5288b2e947c0a7b2ec221ff9e Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 4 Nov 2022 11:28:07 +0800 Subject: [PATCH 58/62] pytorch version back to 1.10.0 --- runtime/core/cmake/libtorch.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/core/cmake/libtorch.cmake b/runtime/core/cmake/libtorch.cmake index bd4a9248f..3d178c859 100644 --- a/runtime/core/cmake/libtorch.cmake +++ b/runtime/core/cmake/libtorch.cmake @@ -1,6 +1,6 @@ if(TORCH) if(NOT ANDROID) - set(PYTORCH_VERSION "1.12.0") + set(PYTORCH_VERSION "1.10.0") 
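To make the restyled mini-yaml interface above concrete, here is a minimal usage sketch. It is not part of the patch itself; the header path "utils/Yaml.hpp", the file name, and the config keys are assumptions for illustration only:

// Hypothetical smoke test for the restyled Yaml API above.
#include <iostream>
#include "utils/Yaml.hpp"

int main() {
  Yaml::Node root;
  try {
    Yaml::Parse(root, "config.yaml");  // hypothetical input file
  } catch (const Yaml::Exception& e) {
    std::cerr << "yaml error (type " << e.Type() << "): " << e.what() << "\n";
    return 1;
  }
  // Typed scalar access with a default, routed through impl::StringConverter.
  int num_bins = root["num_mel_bins"].As<int>(80);               // hypothetical key
  bool streaming = root["continuous_decoding"].As<bool>(false);  // hypothetical key
  std::cout << num_bins << " " << streaming << "\n";
  return 0;
}

Parse() throws on bad input rather than returning an error code, which is why the sketch wraps it in a try/catch; As<T>(default) falls back to the default when conversion fails or the scalar is empty.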
if(GPU) add_definitions(-DUSE_GPU) set(CUDA_NAME "cu113") From 1e2af87bcfb6fee06aa0bf8c10df7f4b4a394634 Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 4 Nov 2022 11:57:46 +0800 Subject: [PATCH 59/62] change reference to pointer of non-const object --- runtime/core/decoder/batch_torch_asr_model.cc | 28 +++++++++---------- runtime/core/decoder/batch_torch_asr_model.h | 6 ++-- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/runtime/core/decoder/batch_torch_asr_model.cc b/runtime/core/decoder/batch_torch_asr_model.cc index e6f13e928..58dea24b1 100644 --- a/runtime/core/decoder/batch_torch_asr_model.cc +++ b/runtime/core/decoder/batch_torch_asr_model.cc @@ -100,8 +100,8 @@ std::shared_ptr BatchTorchAsrModel::Copy() const { void BatchTorchAsrModel::ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, - std::vector>>& batch_topk_scores, - std::vector>>& batch_topk_indexs) { + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) { // 1. Prepare libtorch required data int batch_size = batch_feats.size(); int num_frames = batch_feats[0].size(); @@ -136,23 +136,23 @@ void BatchTorchAsrModel::ForwardEncoder( auto topk_scores = outputs[3].toTensor().to(at::kCPU); int num_outputs = topk_scores.size(1); int output_dim = topk_scores.size(2); - batch_topk_scores.resize(batch_size); + batch_topk_scores->resize(batch_size); for (size_t i = 0; i < batch_size; i++) { - batch_topk_scores[i].resize(num_outputs); + (*batch_topk_scores)[i].resize(num_outputs); for (size_t j = 0; j < num_outputs; j++) { - batch_topk_scores[i][j].resize(output_dim); - memcpy(batch_topk_scores[i][j].data(), topk_scores[i][j].data_ptr(), + (*batch_topk_scores)[i][j].resize(output_dim); + memcpy((*batch_topk_scores)[i][j].data(), topk_scores[i][j].data_ptr(), sizeof(float) * output_dim); } } // copy topk_indexes auto topk_indexes = outputs[4].toTensor().to(at::kCPU); - batch_topk_indexs.resize(batch_size); + batch_topk_indexs->resize(batch_size); for (size_t i = 0; i < batch_size; ++i) { - batch_topk_indexs[i].resize(num_outputs); + (*batch_topk_indexs)[i].resize(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { - batch_topk_indexs[i][j].resize(output_dim); - memcpy(batch_topk_indexs[i][j].data(), topk_indexes[i][j].data_ptr(), + (*batch_topk_indexs)[i][j].resize(output_dim); + memcpy((*batch_topk_indexs)[i][j].data(), topk_indexes[i][j].data_ptr(), sizeof(int) * output_dim); } } @@ -161,7 +161,7 @@ void BatchTorchAsrModel::ForwardEncoder( void BatchTorchAsrModel::AttentionRescoring( const std::vector>>& batch_hyps, const std::vector>& ctc_scores, - std::vector>& attention_scores) { + std::vector>* attention_scores) { // Step 1: Prepare input for libtorch int batch_size = batch_hyps.size(); int beam_size = batch_hyps[0].size(); @@ -220,10 +220,10 @@ void BatchTorchAsrModel::AttentionRescoring( #ifdef USE_GPU c10::cuda::CUDACachingAllocator::emptyCache(); #endif - attention_scores.resize(batch_size); + attention_scores->resize(batch_size); for (size_t i = 0; i < batch_size; i++) { - attention_scores[i].resize(beam_size); - memcpy(attention_scores[i].data(), rescores[i].data_ptr(), + (*attention_scores)[i].resize(beam_size); + memcpy((*attention_scores)[i].data(), rescores[i].data_ptr(), sizeof(float) * beam_size); } } diff --git a/runtime/core/decoder/batch_torch_asr_model.h b/runtime/core/decoder/batch_torch_asr_model.h index 80daef9b8..5c74482e4 100644 --- a/runtime/core/decoder/batch_torch_asr_model.h +++ b/runtime/core/decoder/batch_torch_asr_model.h @@ 
-43,14 +43,14 @@ class BatchTorchAsrModel : public BatchAsrModel { void AttentionRescoring( const std::vector>>& batch_hyps, const std::vector>& ctc_scores, - std::vector>& attention_scores) override; + std::vector>* attention_scores) override; std::shared_ptr Copy() const override; void ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, - std::vector>>& batch_topk_scores, - std::vector>>& batch_topk_indexs) override; // NOLINT + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) override; // NOLINT private: std::shared_ptr model_ = nullptr; From acea6da6f0342bb289e9426803734ad405c4704d Mon Sep 17 00:00:00 2001 From: veelion Date: Fri, 4 Nov 2022 13:17:58 +0800 Subject: [PATCH 60/62] fix github action build error --- runtime/core/utils/Yaml.cpp | 38 ++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/runtime/core/utils/Yaml.cpp b/runtime/core/utils/Yaml.cpp index 4e5494183..77ed14ca5 100644 --- a/runtime/core/utils/Yaml.cpp +++ b/runtime/core/utils/Yaml.cpp @@ -47,25 +47,25 @@ namespace Yaml { class ReaderLine; // Exception message definitions. -static const char* g_ErrorInvalidCharacter = "Invalid character found."; -static const char* g_ErrorKeyMissing = "Missing key."; -static const char* g_ErrorKeyIncorrect = "Incorrect key."; -static const char* g_ErrorValueIncorrect = "Incorrect value."; -static const char* g_ErrorTabInOffset = "Tab found in offset."; -static const char* g_ErrorBlockSequenceNotAllowed = - "Sequence entries are not allowed in this context."; -static const char* g_ErrorUnexpectedDocumentEnd = - "Unexpected document end."; -static const char* g_ErrorDiffEntryNotAllowed = - "Different entry is not allowed in this context."; -static const char* g_ErrorIncorrectOffset = "Incorrect offset."; -static const char* g_ErrorSequenceError = "Error in sequence node."; -static const char* g_ErrorCannotOpenFile = "Cannot open file."; -static const char* g_ErrorIndentation = - "Space indentation is less than 2."; -static const char* g_ErrorInvalidBlockScalar = "Invalid block scalar."; -static const char* g_ErrorInvalidQuote = "Invalid quote."; -static const char* g_EmptyString = ""; +static const std::string g_ErrorInvalidCharacter = "Invalid character found."; // NOLINT +static const std::string g_ErrorKeyMissing = "Missing key."; // NOLINT +static const std::string g_ErrorKeyIncorrect = "Incorrect key."; // NOLINT +static const std::string g_ErrorValueIncorrect = "Incorrect value."; // NOLINT +static const std::string g_ErrorTabInOffset = "Tab found in offset."; // NOLINT +static const std::string g_ErrorBlockSequenceNotAllowed = // NOLINT + "Sequence entries are not allowed in this context."; // NOLINT +static const std::string g_ErrorUnexpectedDocumentEnd = // NOLINT + "Unexpected document end."; // NOLINT +static const std::string g_ErrorDiffEntryNotAllowed = // NOLINT + "Different entry is not allowed in this context."; // NOLINT +static const std::string g_ErrorIncorrectOffset = "Incorrect offset."; // NOLINT +static const std::string g_ErrorSequenceError = "Error in sequence node."; // NOLINT +static const std::string g_ErrorCannotOpenFile = "Cannot open file."; // NOLINT +static const std::string g_ErrorIndentation = // NOLINT + "Space indentation is less than 2."; // NOLINT +static const std::string g_ErrorInvalidBlockScalar = "Invalid block scalar."; // NOLINT +static const std::string g_ErrorInvalidQuote = "Invalid quote."; // NOLINT +static const std::string 
g_EmptyString = ""; // NOLINT static Yaml::Node g_NoneNode; // Global function definitions. Implemented at end of this source file. From 8fda52ae3d2e2282d30ca62abeb2bf743f5b5279 Mon Sep 17 00:00:00 2001 From: veelion Date: Wed, 30 Nov 2022 15:57:13 +0800 Subject: [PATCH 61/62] supported GPU-compute feature(fbank) by kaldifeat --- runtime/core/bin/CMakeLists.txt | 2 +- runtime/core/cmake/kaldifeat.cmake | 32 +++++++++++ runtime/core/decoder/batch_asr_decoder.cc | 43 +++++++++++---- runtime/core/decoder/batch_asr_decoder.h | 7 +++ runtime/core/decoder/batch_asr_model.h | 6 ++ runtime/core/decoder/batch_torch_asr_model.cc | 55 +++++++++++++++++++ runtime/core/decoder/batch_torch_asr_model.h | 5 ++ runtime/core/websocket/CMakeLists.txt | 2 +- runtime/libtorch/CMakeLists.txt | 2 + 9 files changed, 142 insertions(+), 12 deletions(-) create mode 100644 runtime/core/cmake/kaldifeat.cmake diff --git a/runtime/core/bin/CMakeLists.txt b/runtime/core/bin/CMakeLists.txt index 727597de9..46252f254 100644 --- a/runtime/core/bin/CMakeLists.txt +++ b/runtime/core/bin/CMakeLists.txt @@ -2,7 +2,7 @@ add_executable(decoder_main decoder_main.cc) target_link_libraries(decoder_main PUBLIC decoder) add_executable(decoder_main_batch decoder_main_batch.cc) -target_link_libraries(decoder_main_batch PUBLIC decoder) +target_link_libraries(decoder_main_batch PUBLIC decoder kaldifeat_core) add_executable(label_checker_main label_checker_main.cc) target_link_libraries(label_checker_main PUBLIC decoder) diff --git a/runtime/core/cmake/kaldifeat.cmake b/runtime/core/cmake/kaldifeat.cmake new file mode 100644 index 000000000..544826a84 --- /dev/null +++ b/runtime/core/cmake/kaldifeat.cmake @@ -0,0 +1,32 @@ +# Copyright 2022 veelion (veelion@gmail.com) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
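For context on the pointer-based output parameters that PATCH 59 above introduced, a sketch of what a call site now looks like. The exact template arguments were lost in extraction, so the element types here (float scores and int indexes, matching the sizeof(float)/sizeof(int) memcpy calls) and the shared_ptr holder are assumptions:

// Hypothetical caller of the pointer-based batch encoder API from PATCH 59.
#include <memory>
#include <vector>
#include "decoder/batch_asr_model.h"  // assumed header path

void RunEncoder(const std::shared_ptr<wenet::BatchAsrModel>& model,
                const wenet::batch_feature_t& batch_feats,
                const std::vector<int>& batch_feats_lens) {
  std::vector<std::vector<std::vector<float>>> batch_topk_scores;
  std::vector<std::vector<std::vector<int>>> batch_topk_indexs;
  // Outputs are now passed by address, so mutation is visible at the call
  // site -- the motivation for replacing non-const reference parameters.
  model->ForwardEncoder(batch_feats, batch_feats_lens,
                        &batch_topk_scores, &batch_topk_indexs);
}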
 +
+if(GPU)
+  set(kaldifeat_URL "https://github.com/csukuangfj/kaldifeat/archive/refs/tags/v1.21.zip")
+  set(kaldifeat_HASH "SHA256=10652d930dee12d71d04da3f5b3b1bd618fa2f1af6723eb0e70d7267bfa57fe1")
+  set(kaldifeat_BUILD_TESTS OFF CACHE BOOL "" FORCE)
+  set(kaldifeat_BUILD_PYMODULE OFF CACHE BOOL "" FORCE)
+  set(PYTHON_EXECUTABLE "python")
+  list(REMOVE_AT CMAKE_MODULE_PATH 0)  # hide wenet's cmake/xx.cmake from kaldifeat's
+
+  FetchContent_Declare(kaldifeat
+    URL ${kaldifeat_URL}
+    URL_HASH ${kaldifeat_HASH}
+  )
+  FetchContent_MakeAvailable(kaldifeat)
+  include_directories(
+    ${kaldifeat_SOURCE_DIR}
+  )
+  list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)  # use wenet's cmake/xx.cmake
+endif()
diff --git a/runtime/core/decoder/batch_asr_decoder.cc b/runtime/core/decoder/batch_asr_decoder.cc
index d2a6137a7..892328d39 100644
--- a/runtime/core/decoder/batch_asr_decoder.cc
+++ b/runtime/core/decoder/batch_asr_decoder.cc
@@ -35,6 +35,7 @@ BatchAsrDecoder::BatchAsrDecoder(std::shared_ptr config,
       beam_size_(opts.ctc_prefix_search_opts.first_beam_size),
       fbank_(config->num_bins, config->sample_rate,
              config->frame_length, config->frame_shift),
+      fbank_cuda_(config->num_bins, config->sample_rate),
       model_(resource->batch_model->Copy()),
       post_processor_(resource->post_processor),
       symbol_table_(resource->symbol_table),
@@ -109,11 +110,13 @@ void BatchAsrDecoder::FbankWorker(const std::vector& wav, int index) {
           << ", takes " << timer.Elapsed() << " ms.";
 }

-void BatchAsrDecoder::Decode(const std::vector>& wavs) {
-  // 1. calc fbank feature of the batch of wavs
+void BatchAsrDecoder::ComputeFeatureCpu(
+    const std::vector>& wavs,
+    batch_feature_t* feats,
+    std::vector* feats_lens) {
   Timer timer;
-  batch_feature_t batch_feats;
-  std::vector batch_feats_lens;
+  batch_feature_t& batch_feats = *feats;
+  std::vector& batch_feats_lens = *feats_lens;
   if (wavs.size() > 1) {
     std::vector fbank_threads;
     for (size_t i = 0; i < wavs.size(); i++) {
@@ -155,15 +158,35 @@ void BatchAsrDecoder::Decode(const std::vector>& wavs) {
       }
     }
     VLOG(1) << "padding feature takes " << timer.Elapsed() << " ms.";
-  }
+  }
+}

-  // 2. encoder forward
-  timer.Reset();
   std::vector>> batch_topk_scores;
   std::vector>> batch_topk_indexs;
-  model_->ForwardEncoder(
-      batch_feats, batch_feats_lens, &batch_topk_scores, &batch_topk_indexs);
-  VLOG(1) << "encoder forward takes " << timer.Elapsed() << " ms.";
+void BatchAsrDecoder::Decode(const std::vector>& wavs) {
+  // 1. calc fbank feature of the batch of wavs
+  std::vector>> batch_topk_scores;
+  std::vector>> batch_topk_indexs;
+  Timer timer;
+  bool gpu_feature = true;  // GPU feature extraction is hard-wired on for now
+  if (gpu_feature) {
+    std::vector batch_feats_lens;
+    timer.Reset();
+    auto batch_feats = fbank_cuda_.Compute(wavs, &batch_feats_lens);
+    VLOG(1) << "fbank_cuda_.Compute() takes " << timer.Elapsed() << " ms.";
+    timer.Reset();
+    // 2. encoder forward
+    model_->ForwardEncoder(
+        batch_feats, batch_feats_lens, &batch_topk_scores, &batch_topk_indexs);
+    VLOG(1) << "encoder forward takes " << timer.Elapsed() << " ms.";
+  } else {
+    batch_feature_t batch_feats;
+    std::vector batch_feats_lens;
+    ComputeFeatureCpu(wavs, &batch_feats, &batch_feats_lens);
+    timer.Reset();
+    // 2. encoder forward
+    model_->ForwardEncoder(
+        batch_feats, batch_feats_lens, &batch_topk_scores, &batch_topk_indexs);
+    VLOG(1) << "encoder forward takes " << timer.Elapsed() << " ms.";
+  }

   // 3. ctc search one by one of the batch
   // create batch of ctc search result for attention decoding
diff --git a/runtime/core/decoder/batch_asr_decoder.h b/runtime/core/decoder/batch_asr_decoder.h
index 992b40061..a3ec6c3e9 100644
--- a/runtime/core/decoder/batch_asr_decoder.h
+++ b/runtime/core/decoder/batch_asr_decoder.h
@@ -37,6 +37,7 @@
 #include "utils/utils.h"
 #include "frontend/fbank.h"
 #include "utils/json.h"
+#include "frontend/fbank_cuda.h"

 namespace wenet {

@@ -64,6 +65,12 @@ class BatchAsrDecoder {
  private:
   Fbank fbank_;
+  FbankCuda fbank_cuda_;
+
+  void ComputeFeatureCpu(
+      const std::vector>& wavs,
+      batch_feature_t* batch_feats,
+      std::vector* batch_feats_lens);
   void FbankWorker(const std::vector& wav, int index);
   std::vector> batch_feats_;  // for FbankWorker
   std::vector> batch_feats_lens_;  // for FbankWorker
diff --git a/runtime/core/decoder/batch_asr_model.h b/runtime/core/decoder/batch_asr_model.h
index 53f686f3a..073d821a4 100644
--- a/runtime/core/decoder/batch_asr_model.h
+++ b/runtime/core/decoder/batch_asr_model.h
@@ -10,6 +10,7 @@
 #include
 #include
+#include "torch/torch.h"
 #include "utils/timer.h"
 #include "utils/utils.h"

@@ -35,6 +36,11 @@ class BatchAsrModel {
       const std::vector& batch_feats_lens,
       std::vector>>* batch_topk_scores,
       std::vector>>* batch_topk_indexs) = 0;
+  virtual void ForwardEncoder(
+      const std::vector& batch_feats,
+      const std::vector& batch_feats_lens,
+      std::vector>>* batch_topk_scores,
+      std::vector>>* batch_topk_indexs) {}

   virtual void AttentionRescoring(
       const std::vector>>& batch_hyps,
diff --git a/runtime/core/decoder/batch_torch_asr_model.cc b/runtime/core/decoder/batch_torch_asr_model.cc
index 58dea24b1..18a00dbf8 100644
--- a/runtime/core/decoder/batch_torch_asr_model.cc
+++ b/runtime/core/decoder/batch_torch_asr_model.cc
@@ -97,6 +97,59 @@ std::shared_ptr BatchTorchAsrModel::Copy() const {
   return asr_model;
 }

+void BatchTorchAsrModel::ForwardEncoder(
+    const std::vector& batch_feats,
+    const std::vector& batch_feats_lens,
+    std::vector>>* batch_topk_scores,
+    std::vector>>* batch_topk_indexs) {
+  // 1. Prepare libtorch required data
+  int batch_size = batch_feats_lens.size();
+  torch::Tensor feats_lens =
+      torch::from_blob(const_cast(batch_feats_lens.data()),
+                       {batch_size}, torch::kInt).clone();
+  // Note: math.log(1e-10) is -23.025850929940457
+  auto feats = torch::nn::utils::rnn::pad_sequence(batch_feats, true,
+                                                   -23.025850929940457f);
+
+  // 2.
Encoder batch forward + feats = feats.to(device_); + feats_lens = feats_lens.to(device_); + torch::NoGradGuard no_grad; + std::vector inputs = {feats, feats_lens}; + + auto outputs = + model_->get_method("batch_forward_encoder")(inputs).toTuple()->elements(); + VLOG(1) << "batch_forward_encoder done"; + CHECK_EQ(outputs.size(), 5); + encoder_out_ = outputs[0].toTensor(); // (B, Tmax, dim) + encoder_lens_ = outputs[1].toTensor(); // (B,) + + // Copy topk_scores + auto topk_scores = outputs[3].toTensor().to(at::kCPU); + int num_outputs = topk_scores.size(1); + int output_dim = topk_scores.size(2); + batch_topk_scores->resize(batch_size); + for (size_t i = 0; i < batch_size; i++) { + (*batch_topk_scores)[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; j++) { + (*batch_topk_scores)[i][j].resize(output_dim); + memcpy((*batch_topk_scores)[i][j].data(), topk_scores[i][j].data_ptr(), + sizeof(float) * output_dim); + } + } + // copy topk_indexes + auto topk_indexes = outputs[4].toTensor().to(at::kCPU); + batch_topk_indexs->resize(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + (*batch_topk_indexs)[i].resize(num_outputs); + for (size_t j = 0; j < num_outputs; ++j) { + (*batch_topk_indexs)[i][j].resize(output_dim); + memcpy((*batch_topk_indexs)[i][j].data(), topk_indexes[i][j].data_ptr(), + sizeof(int) * output_dim); + } + } +} + void BatchTorchAsrModel::ForwardEncoder( const batch_feature_t& batch_feats, const std::vector& batch_feats_lens, @@ -106,6 +159,7 @@ void BatchTorchAsrModel::ForwardEncoder( int batch_size = batch_feats.size(); int num_frames = batch_feats[0].size(); const int feature_dim = batch_feats[0][0].size(); + Timer timer; torch::Tensor feats = torch::zeros({batch_size, num_frames, feature_dim}, torch::kFloat); for (size_t i = 0; i < batch_size; ++i) { @@ -116,6 +170,7 @@ void BatchTorchAsrModel::ForwardEncoder( feats[i][j] = std::move(row); } } + VLOG(1) << "feature to Tensor takes " << timer.Elapsed() << " ms."; torch::Tensor feats_lens = torch::from_blob(const_cast(batch_feats_lens.data()), {batch_size}, torch::kInt).clone(); diff --git a/runtime/core/decoder/batch_torch_asr_model.h b/runtime/core/decoder/batch_torch_asr_model.h index 5c74482e4..c1b4a6d74 100644 --- a/runtime/core/decoder/batch_torch_asr_model.h +++ b/runtime/core/decoder/batch_torch_asr_model.h @@ -51,6 +51,11 @@ class BatchTorchAsrModel : public BatchAsrModel { const std::vector& batch_feats_lens, std::vector>>* batch_topk_scores, std::vector>>* batch_topk_indexs) override; // NOLINT + void ForwardEncoder( + const std::vector& batch_feats, + const std::vector& batch_feats_lens, + std::vector>>* batch_topk_scores, + std::vector>>* batch_topk_indexs) override; // NOLINT private: std::shared_ptr model_ = nullptr; diff --git a/runtime/core/websocket/CMakeLists.txt b/runtime/core/websocket/CMakeLists.txt index 67447c42d..451866710 100644 --- a/runtime/core/websocket/CMakeLists.txt +++ b/runtime/core/websocket/CMakeLists.txt @@ -2,4 +2,4 @@ add_library(websocket STATIC websocket_client.cc websocket_server.cc ) -target_link_libraries(websocket PUBLIC decoder) +target_link_libraries(websocket PUBLIC decoder kaldifeat_core) diff --git a/runtime/libtorch/CMakeLists.txt b/runtime/libtorch/CMakeLists.txt index a02f37ac7..f2dee2023 100644 --- a/runtime/libtorch/CMakeLists.txt +++ b/runtime/libtorch/CMakeLists.txt @@ -33,6 +33,8 @@ endif() # Include all dependency if(TORCH) + include(kaldifeat) + include(FetchContent) # use wenet's, disable kaldifeat's custom: cmake/Modules/FetchContent 
 include(libtorch)
 endif()
 if(ONNX)

From e0b4e4279d7f676d7032a5010ac13090fe3c5ff7 Mon Sep 17 00:00:00 2001
From: veelion
Date: Wed, 30 Nov 2022 15:58:16 +0800
Subject: [PATCH 62/62] add fbank_cuda.h

---
 runtime/core/frontend/fbank_cuda.h | 62 ++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 runtime/core/frontend/fbank_cuda.h

diff --git a/runtime/core/frontend/fbank_cuda.h b/runtime/core/frontend/fbank_cuda.h
new file mode 100644
index 000000000..3b7df06f7
--- /dev/null
+++ b/runtime/core/frontend/fbank_cuda.h
@@ -0,0 +1,62 @@
+#ifndef FRONTEND_FBANK_CUDA_H_
+#define FRONTEND_FBANK_CUDA_H_
+
+#include <memory>
+#include <vector>
+
+#include "kaldifeat/csrc/feature-fbank.h"
+
+namespace wenet {
+
+class FbankCuda {
+ public:
+  FbankCuda(int num_bins, int sample_rate) {
+    fbank_opts_.mel_opts.num_bins = num_bins;
+    fbank_opts_.frame_opts.samp_freq = sample_rate;
+    fbank_opts_.frame_opts.dither = 0;
+    fbank_opts_.frame_opts.frame_shift_ms = 10.0;
+    fbank_opts_.frame_opts.frame_length_ms = 25.0;
+    fbank_opts_.device = torch::Device(torch::kCUDA, 0);
+    fbank_ = std::make_shared<kaldifeat::Fbank>(fbank_opts_);
+    device_ = torch::kCUDA;
+  }
+
+  torch::Tensor Compute(torch::Tensor wave_data) {
+    return fbank_->ComputeFeatures(wave_data, 1.0f);
+  }
+
+  std::vector<torch::Tensor> Compute(
+      const std::vector<std::vector<float>> &wave_data,
+      std::vector<int> *num_frames) {
+    const auto &frame_opts = fbank_->GetOptions().frame_opts;
+    // int64 copy feeds split_with_sizes(); the int copy goes to the caller.
+    std::vector<int64_t> num_frames_vec;
+    num_frames_vec.reserve(wave_data.size());
+
+    std::vector<torch::Tensor> strided_vec;
+    strided_vec.reserve(wave_data.size());
+
+    for (const auto &w : wave_data) {
+      torch::Tensor t = torch::from_blob(
+          const_cast<float *>(w.data()),
+          {static_cast<int64_t>(w.size())}, torch::kFloat).to(device_);
+      // t = t / 32768.0;
+      torch::Tensor strided = kaldifeat::GetStrided(t, frame_opts);
+      num_frames_vec.push_back(strided.size(0));
+      num_frames->push_back(strided.size(0));
+      strided_vec.emplace_back(std::move(strided));
+    }
+
+    // Frame all utterances, run one fused fbank computation on the GPU,
+    // then split the result back into per-utterance feature tensors.
+    torch::Tensor strided = torch::cat(strided_vec, 0);
+    torch::Tensor features =
+        fbank_->ComputeFeatures(strided, /*vtln_warp*/ 1.0f);
+    auto ans = features.split_with_sizes(num_frames_vec, /*dim*/ 0);
+    return ans;
+  }
+
+ private:
+  kaldifeat::FbankOptions fbank_opts_;
+  std::shared_ptr<kaldifeat::Fbank> fbank_;
+  torch::DeviceType device_;
+};
+
+}  // namespace wenet
+
+#endif  // FRONTEND_FBANK_CUDA_H_
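Finally, a short sketch of how the new FbankCuda might be driven end to end. It requires a CUDA build of libtorch and a visible GPU; the synthetic one-second silence buffers stand in for real 16 kHz PCM audio, and the vector element types are assumptions since the extracted header lost its template arguments:

// Minimal driver for FbankCuda (assumes CUDA is available at device 0).
#include <vector>
#include "frontend/fbank_cuda.h"

int main() {
  wenet::FbankCuda fbank(80, 16000);  // 80 mel bins, 16 kHz input
  // Two dummy waveforms; real code would load decoded PCM samples here.
  std::vector<std::vector<float>> wavs(2, std::vector<float>(16000, 0.0f));
  std::vector<int> num_frames;
  // One (num_frames[i] x 80) CUDA tensor comes back per utterance.
  std::vector<torch::Tensor> feats = fbank.Compute(wavs, &num_frames);
  return feats.size() == wavs.size() ? 0 : 1;
}

The batching trick is the same one the header uses internally: every utterance is framed first, all frames are concatenated into one tensor so the GPU runs a single fbank computation, and split_with_sizes() restores per-utterance boundaries afterwards.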