From 3d60e3e6b1c4a3602db88e1c2bbe0f67b186ff5f Mon Sep 17 00:00:00 2001 From: Tilman Kerl Date: Sat, 16 Dec 2023 19:13:34 +0100 Subject: [PATCH] fix param passing for payload --- chat_doc/inference/chat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chat_doc/inference/chat.py b/chat_doc/inference/chat.py index b4c41c8..90c3780 100644 --- a/chat_doc/inference/chat.py +++ b/chat_doc/inference/chat.py @@ -44,6 +44,7 @@ def _postprocess(self, prediction: str) -> str: cleaned_pred = prediction.split("###")[1].split("\n")[1].strip() # replace all hyperlinks with "Llama Hospital" using regex + # regex from https://gist.github.com/gruber/8891611 cleaned_pred = re.sub( r'(?i)\b((?:https?:(?:/{1,3}|[a-z0-9%])|[a-z0-9.\-]+[.](?:com|net|org|edu|gov|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|name|post|pro|tel|travel|xxx|ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|ba|bb|bd|be|bf|bg|bh|bi|bj|bm|bn|bo|br|bs|bt|bv|bw|by|bz|ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cs|cu|cv|cx|cy|cz|dd|de|dj|dk|dm|do|dz|ec|ee|eg|eh|er|es|et|eu|fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|gs|gt|gu|gw|gy|hk|hm|hn|hr|ht|hu|id|ie|il|im|in|io|iq|ir|is|it|je|jm|jo|jp|ke|kg|kh|ki|km|kn|kp|kr|kw|ky|kz|la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|ma|mc|md|me|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|na|nc|ne|nf|ng|ni|nl|no|np|nr|nu|nz|om|pa|pe|pf|pg|ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|rs|ru|rw|sa|sb|sc|sd|se|sg|sh|si|sj|Ja|sk|sl|sm|sn|so|sr|ss|st|su|sv|sx|sy|sz|tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|ua|ug|uk|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|za|zm|zw)/)(?:[^\s()<>{}\[\]]+|\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\))+(?:\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\)|[^\s`!()\[\]{};:\'".,<>?«»“”‘’])|(?:(? str: def predict(self, input_text: str, history: str = "", qa=False) -> str: prompt = self.template.create_prompt(input_text=input_text, history=history) - prediction = self.model.predict(self._payload(prompt), qa=qa)[0]["generated_text"] + prediction = self.model.predict(self._payload(prompt, qa=qa))[0]["generated_text"] if qa: return self._postprocess_qa(prediction)