Skip to content

Commit

Permalink
fix semaphore feedback call (#238)
Browse files Browse the repository at this point in the history
- it should store new tags

SDCP-765
  • Loading branch information
petrjasek authored May 14, 2024
1 parent 3e71603 commit 0cbf813
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 29 deletions.
109 changes: 80 additions & 29 deletions server/cp/ai/semaphore.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import os
import logging
import requests
import xml.etree.ElementTree as ET
from superdesk.text_checkers.ai.base import AIServiceBase
import traceback
import superdesk
import json
from typing import Any, Dict, List, Mapping, Optional, TypedDict, Union
from typing import (
Dict,
List,
Literal,
Mapping,
Optional,
TypedDict,
Union,
overload,
)


logger = logging.getLogger(__name__)
Expand All @@ -18,13 +26,32 @@
ResponseType = Mapping[str, Union[str, List[str]]]


class RequestType(TypedDict, total=False):
    """Older, loosely typed request payload in which every key is optional.

    Kept so that annotations elsewhere in this module that still name
    ``RequestType`` continue to resolve.
    """

    searchString: str
    data: Any


class SearchData(TypedDict):
    """Payload of a ``search`` operation."""

    # Text to look up; ``analyze_parent_info`` appends it to the search URL.
    searchString: str


class OperationRequest(TypedDict):
    """Older wrapper that carries the request payload under ``item``."""

    item: RequestType
class Item(TypedDict):
    """Content item fields handed to Semaphore.

    This is the shape accepted by ``Semaphore.analyze`` and
    ``Semaphore.html_to_xml``, and the ``item`` entry of ``FeedbackData``.
    All keys are required.
    """

    guid: str
    abstract: str
    body_html: str
    headline: str
    language: str
    slugline: str


class Tag(TypedDict):
    """A single tag entry.

    Lists of these are carried in ``FeedbackData["tags"]`` and returned by
    ``extract_manual_tags``. All keys are required.
    """

    altids: Dict[str, str]
    description: str
    name: str
    original_source: str
    qcode: str
    scheme: str
    source: str


class FeedbackData(TypedDict):
    """Payload of a ``feedback`` operation."""

    # The content item the feedback refers to.
    item: Item
    # Tags grouped by tag type; the loop in ``extract_manual_tags`` gives
    # 'subject' and 'person' as example type keys.
    tags: Dict[str, List[Tag]]


class Semaphore(AIServiceBase):
Expand All @@ -37,34 +64,35 @@ class Semaphore(AIServiceBase):
name = "semaphore"
label = "Semaphore autotagging service"

def __init__(self, data):
def __init__(self, app):
# SEMAPHORE_BASE_URL OR TOKEN_ENDPOINT Goes Here
self.base_url = os.getenv("SEMAPHORE_BASE_URL")
self.base_url = app.config.get("SEMAPHORE_BASE_URL")

# SEMAPHORE_ANALYZE_URL Goes Here
self.analyze_url = os.getenv("SEMAPHORE_ANALYZE_URL")
self.analyze_url = app.config.get("SEMAPHORE_ANALYZE_URL")

# SEMAPHORE_API_KEY Goes Here
self.api_key = os.getenv("SEMAPHORE_API_KEY")
self.api_key = app.config.get("SEMAPHORE_API_KEY")

# SEMAPHORE_SEARCH_URL Goes Here
self.search_url = os.getenv("SEMAPHORE_SEARCH_URL")
self.search_url = app.config.get("SEMAPHORE_SEARCH_URL")

# SEMAPHORE_GET_PARENT_URL Goes Here
self.get_parent_url = os.getenv("SEMAPHORE_GET_PARENT_URL")
self.get_parent_url = app.config.get("SEMAPHORE_GET_PARENT_URL")

# SEMAPHORE_CREATE_TAG_URL Goes Here
self.create_tag_url = os.getenv("SEMAPHORE_CREATE_TAG_URL")
self.create_tag_url = app.config.get("SEMAPHORE_CREATE_TAG_URL")

# SEMAPHORE_CREATE_TAG_TASK Goes Here
self.create_tag_task = os.getenv("SEMAPHORE_CREATE_TAG_TASK")
self.create_tag_task = app.config.get("SEMAPHORE_CREATE_TAG_TASK")

# SEMAPHORE_CREATE_TAG_QUERY Goes Here
self.create_tag_query = os.getenv("SEMAPHORE_CREATE_TAG_QUERY")
self.create_tag_query = app.config.get("SEMAPHORE_CREATE_TAG_QUERY")

def get_access_token(self):
"""Get access token for Semaphore."""
url = self.base_url
assert url

payload = f"grant_type=apikey&key={self.api_key}"
headers = {"Content-Type": "application/x-www-form-urlencoded"}
Expand Down Expand Up @@ -104,15 +132,15 @@ def fetch_parent_info(self, qcode):
return []

# Renamed from Analyze2 to analyze_parent_info.
def analyze_parent_info(self, html_content: RequestType) -> ResponseType:
def analyze_parent_info(self, data: SearchData) -> ResponseType:
try:
if not self.base_url or not self.api_key:
logger.warning(
"Semaphore Search is not configured properly, can't analyze content"
)
return {}

query = html_content["searchString"]
query = data["searchString"]

new_url = self.search_url + query + ".json"

Expand Down Expand Up @@ -266,7 +294,7 @@ def convert_to_desired_format(input_data):
)
return {}

def create_tag_in_semaphore(self, html_content: RequestType) -> ResponseType:
def create_tag_in_semaphore(self, data: FeedbackData) -> ResponseType:
result_summary: Dict[str, List[str]] = {
"created_tags": [],
"failed_tags": [],
Expand All @@ -293,7 +321,7 @@ def create_tag_in_semaphore(self, html_content: RequestType) -> ResponseType:
"Content-Type": "application/ld+json",
}

manual_tags = extract_manual_tags(html_content["data"])
manual_tags = extract_manual_tags(data)

for item in manual_tags:
# print(item)
Expand Down Expand Up @@ -356,20 +384,38 @@ def create_tag_in_semaphore(self, html_content: RequestType) -> ResponseType:

return result_summary

@overload
def data_operation(  # noqa: E704
    self,
    verb: str,
    operation: Literal["feedback"],
    name: Optional[str],
    data: FeedbackData,
) -> ResponseType: ...

@overload
def data_operation(  # noqa: E704
    self,
    verb: str,
    operation: Literal["search"],
    name: Optional[str],
    data: SearchData,
) -> ResponseType: ...

def data_operation(
    self,
    verb: str,
    operation: Literal["search", "feedback"],
    name: Optional[str],
    data,
) -> ResponseType:
    """Dispatch a data operation to the matching handler.

    :param verb: not used by this implementation.
    :param operation: ``"feedback"`` forwards ``data`` to
        ``create_tag_in_semaphore``; ``"search"`` forwards it to ``search``.
    :param name: not used by this implementation.
    :param data: operation payload, ``FeedbackData`` for feedback and
        ``SearchData`` for search (see the overloads).
    :return: the handler's response, or an empty dict when ``operation``
        is neither of the two supported values.
    """
    if operation == "feedback":
        # Feedback carries the whole payload (item and tags), not just the
        # item, so that new tags can be stored.
        return self.create_tag_in_semaphore(data)
    if operation == "search":
        return self.search(data)
    return {}

def search(self, data) -> ResponseType:
def search(self, data: SearchData) -> ResponseType:
try:
print(
"----------------------------------------------------------------------"
Expand All @@ -394,7 +440,7 @@ def search(self, data) -> ResponseType:
pass
return {}

def analyze(self, html_content: RequestType, tags=None) -> ResponseType:
def analyze(self, item: Item, tags=None) -> ResponseType:
try:
if not self.base_url or not self.api_key:
logger.warning(
Expand All @@ -403,7 +449,7 @@ def analyze(self, html_content: RequestType, tags=None) -> ResponseType:
return {}

# Convert HTML to XML
xml_payload = self.html_to_xml(html_content)
xml_payload = self.html_to_xml(item)

payload = {"XML_INPUT": xml_payload}

Expand All @@ -417,6 +463,7 @@ def analyze(self, html_content: RequestType, tags=None) -> ResponseType:
except Exception as e:
traceback.print_exc()
logger.error(f"An error occurred while making the request: {str(e)}")
raise

root = response.text

Expand Down Expand Up @@ -564,7 +611,7 @@ def adjust_score(score, existing_scores):
logger.error(f"An error occurred. We are in analyze exception: {str(e)}")
return {}

def html_to_xml(self, html_content) -> str:
def html_to_xml(self, html_content: Item) -> str:
def clean_html_content(input_str):
# Remove full HTML tags using regular expressions
your_string = input_str.replace("<p>", "")
Expand Down Expand Up @@ -605,8 +652,8 @@ def clean_html_content(input_str):
return xml_output


def extract_manual_tags(data):
manual_tags = []
def extract_manual_tags(data: FeedbackData) -> List[Tag]:
manual_tags: List[Tag] = []

if "tags" in data:
# Loop through each tag type (like 'subject', 'person', etc.)
Expand Down Expand Up @@ -642,7 +689,11 @@ def replace_qcodes(output_data):
)

# Create a mapping from semaphore_id to qcode
semaphore_to_qcode = {item["semaphore_id"]: item["qcode"] for item in cv["items"]}
semaphore_to_qcode = {
item["semaphore_id"]: item["qcode"]
for item in cv["items"]
if item.get("semaphore_id")
}

# Define a function to replace qcodes in a given list
def replace_in_list(data_list):
Expand Down
9 changes: 9 additions & 0 deletions server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,12 @@
ONCLUSIVE_SERVER_TIMEZONE = os.environ.get("ONCLUSIVE_SERVER_TIMEZONE", "Europe/London")

PLANNING_JSON_ASSIGNED_INFO_EXTENDED = True

# Semaphore integration settings, each read from the environment variable of
# the same name; a variable that is not set yields None.
SEMAPHORE_BASE_URL = os.environ.get("SEMAPHORE_BASE_URL")
SEMAPHORE_ANALYZE_URL = os.environ.get("SEMAPHORE_ANALYZE_URL")
SEMAPHORE_API_KEY = os.environ.get("SEMAPHORE_API_KEY")
SEMAPHORE_SEARCH_URL = os.environ.get("SEMAPHORE_SEARCH_URL")
SEMAPHORE_GET_PARENT_URL = os.environ.get("SEMAPHORE_GET_PARENT_URL")
SEMAPHORE_CREATE_TAG_URL = os.environ.get("SEMAPHORE_CREATE_TAG_URL")
SEMAPHORE_CREATE_TAG_TASK = os.environ.get("SEMAPHORE_CREATE_TAG_TASK")
SEMAPHORE_CREATE_TAG_QUERY = os.environ.get("SEMAPHORE_CREATE_TAG_QUERY")

0 comments on commit 0cbf813

Please sign in to comment.