Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
bgajdero committed Mar 13, 2024
1 parent 888e1af commit 3e7e338
Show file tree
Hide file tree
Showing 28 changed files with 1,607 additions and 0 deletions.
Binary file added .DS_Store
Binary file not shown.
21 changes: 21 additions & 0 deletions PyQAWrapper.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: PyQAWrapper
channels:
- conda-forge
- anaconda
- defaults
dependencies:
- python=3.9
- nltk
- tqdm
- pandas==1.4.2
- pip
- pip:
- torch==2.0.*
- transformers==4.34.0
- pyyaml
- unidecode
- SPARQLWrapper==2.0.0
- python-dateutil
- Werkzeug==2.3.7
- requests
- requests_toolbelt
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,12 @@
# QAWrapper
Question and Answer Wrapper

## Setup
`conda env create -f PyQAWrapper.yml`

## Running the examples
`conda activate PyQAWrapper`

`python -m src.qa0`
`python -m src.qa1`

41 changes: 41 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "QAWrapper"
version = "0.0.1"
description = "Python wrapper for QuestionAnswer Interface"
authors = [
{ name = "Bart Gajderowicz", email = "[email protected]" },
]
requires-python = ">=3.9"
readme = "README.md"


[project.urls]
repository = "https://github.com/csse-uoft/QAWrapper"

[tool.setuptools.packages.find]
# Look for importable packages inside the src/ directory.
where = ["src"]
# `include` takes shell-style glob patterns, not regular expressions.
# The previous value ".*" only matched package names beginning with a dot,
# so no package was ever discovered; "*" matches every package under `where`.
include = ["*"]
exclude = ["tests"]
namespaces = false


# NOTE(review): the [build-system] table of this file selects
# setuptools.build_meta as the build backend, so this [tool.poetry] table is
# presumably not read when the package is built -- confirm, and consider
# moving `keywords` and `classifiers` under [project] if they are meant to
# appear in the published metadata.
[tool.poetry]
keywords = ['python', 'qa', 'bert']
classifiers= [
"Development Status :: 1 - Planning",
"Topic :: Database",
"Intended Audience :: Education",
"Intended Audience :: Developers",
"Intended Audience :: Information Technology",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3.9",
"Environment :: Console",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS :: MacOS X",
"Operating System :: Microsoft :: Windows",
]

5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import setuptools

def _run_setup():
    """Invoke setuptools with no explicit arguments.

    Nothing is passed to ``setup()``, so all package configuration has to
    come from the project's declarative configuration files.
    """
    setuptools.setup()


# Only build/install when executed as a script, never on import.
if __name__ == "__main__":
    _run_setup()

Binary file added src/.DS_Store
Binary file not shown.
Empty file added src/__init__.py
Empty file.
Binary file added src/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added src/__pycache__/qa.cpython-39.pyc
Binary file not shown.
Binary file added src/__pycache__/qa0.cpython-39.pyc
Binary file not shown.
Binary file added src/__pycache__/qa1.cpython-39.pyc
Binary file not shown.
Binary file added src/qa/.DS_Store
Binary file not shown.
Empty file added src/qa/__init__.py
Empty file.
Binary file added src/qa/__pycache__/__init__.cpython-39.pyc
Binary file not shown.
Binary file added src/qa/__pycache__/qa.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
6 changes: 6 additions & 0 deletions src/qa/aggregate_scores/layer0_partial2_aggregate.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
,TP,FP,FN,TN,precision,recall,f1,accuracy,mcc
program_name,44,37,157,1424,0.543209877,0.218905473,0.312056738,0.883273165,0.293140598
client,66,15,426,1148,0.814814815,0.134146341,0.230366492,0.733534743,0.256864741
need_satisfier,56,24,619,954,0.7,0.082962963,0.148344371,0.611010284,0.133812428
outcome,39,41,172,1415,0.4875,0.184834123,0.268041237,0.872225555,0.243717974
catchment_area,20,60,66,1542,0.25,0.23255814,0.240963855,0.92535545,0.201911366
6 changes: 6 additions & 0 deletions src/qa/aggregate_scores/layer1_partial2_aggregate.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
,TP,FP,FN,TN,precision,recall,f1,accuracy,mcc
program name,33,47,154,1436,0.4125,0.17647058823529413,0.24719101123595508,0.8796407185628743,0.21377189897555318
client,53,28,436,1148,0.654320987654321,0.1083844580777096,0.18596491228070175,0.7213213213213213,0.1790521999048717
need satisfier,57,24,619,953,0.7037037037037037,0.08431952662721894,0.15059445178335534,0.6110102843315185,0.13608889790392295
need outcome,40,41,171,1414,0.49382716049382713,0.1895734597156398,0.273972602739726,0.8727490996398559,0.24957726830930263
catchment area,15,66,71,1542,0.18518518518518517,0.1744186046511628,0.17964071856287422,0.9191263282172373,0.13721550003321117
21 changes: 21 additions & 0 deletions src/qa/aggregate_scores/layer1_partial2_aggregate_comb.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
,TP,FP,FN,TN,precision,recall,f1,accuracy,mcc
catchment_area_program_name,22,55,166,1606,0.285714286,0.117021277,0.166037736,0.880475933,0.126938534
need_satisfier_program_name,30,49,160,1600,0.379746835,0.157894737,0.223048327,0.886351278,0.192414013
outlier_program_name,30,48,158,1598,0.384615385,0.159574468,0.22556391,0.887677208,0.196021955
client_program_name,34,45,153,1593,0.430379747,0.181818182,0.255639098,0.891506849,0.230004585
catchment_area_client,48,31,439,1587,0.607594937,0.098562628,0.169611307,0.77672209,0.176180698
outlier_client,38,41,450,1598,0.481012658,0.077868852,0.134038801,0.769158439,0.117515576
need_satisfier_client,52,24,439,1587,0.684210526,0.105906314,0.183421517,0.779733587,0.206274278
program_name_client,51,30,436,1584,0.62962963,0.104722793,0.179577465,0.778200857,0.188790392
client_need_satisfier,58,22,616,1573,0.725,0.086053412,0.153846154,0.718818863,0.179036769
program_name_need_satisfier,59,22,616,1573,0.728395062,0.087407407,0.156084656,0.718942731,0.181396561
catchment_area_need_satisfier,51,28,623,1580,0.64556962,0.075667656,0.135458167,0.714723926,0.145372848
outlier_need_satisfier,44,36,627,1584,0.55,0.06557377,0.117177097,0.710606722,0.1074692
catchment_area_outlier,28,50,184,1601,0.358974359,0.132075472,0.193103448,0.874396135,0.16139258
need_satisfier_outlier,36,44,174,1591,0.45,0.171428571,0.248275862,0.881842818,0.225356913
program_name_outlier,35,45,176,1593,0.4375,0.165876777,0.240549828,0.880475933,0.216291121
client_outlier,27,52,182,1599,0.341772152,0.129186603,0.1875,0.874193548,0.152986703
need_satisfier_catchment_area,9,71,76,1618,0.1125,0.105882353,0.109090909,0.917136415,0.06571339
client_catchment_area,13,67,74,1616,0.1625,0.149425287,0.155688623,0.920338983,0.114073563
outlier_catchment_area,5,75,80,1622,0.0625,0.058823529,0.060606061,0.91301908,0.01505592
program_name_catchment_area,16,65,70,1612,0.197530864,0.186046512,0.191616766,0.923425978,0.151538904
156 changes: 156 additions & 0 deletions src/qa/qa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import torch
# Pick the torch device once, at import time: Apple's Metal backend ("mps")
# when torch reports it as available, otherwise the CPU.
device_str = 'mps' if torch.backends.mps.is_available() else 'cpu'

if device_str == 'mps':
    # Allocate a one-element tensor on the MPS device and print it as a
    # quick smoke test that the backend is usable.
    mps_device = torch.device(device_str)
    x = torch.ones(1, device=mps_device)
    print("MPS found", x)
else:
    print ("MPS device not found.")

print("Device found: ", device_str)
from transformers import pipeline, AutoModelForQuestionAnswering
from abc import ABC


class QAModel:
    """
    Container for the question-answering model objects.

    Everything here is a class attribute, so the pipeline and the model are
    created once, when this class body is executed, and are reached through
    the class itself (e.g. ``QAModel.nlp``).
    ...
    Attributes
    ----------
    model_name : str
        Name of the pretrained QA checkpoint
    nlp : transformers.pipeline
        'question-answering' pipeline built from ``model_name`` on the
        module-level ``device_str`` device
    model : transformers.AutoModelForQuestionAnswering
        QA model loaded from the same ``model_name``
    """

    # Alternative checkpoint kept for reference:
    #   model_name = "deepset/deberta-v3-large-squad2"  (DebertaForQuestionAnswering)
    model_name = "deepset/roberta-base-squad2"

    # The same checkpoint name is used for both the model and its tokenizer.
    nlp = pipeline(
        'question-answering',
        model=model_name,
        tokenizer=model_name,
        device=device_str,
    )

    model = AutoModelForQuestionAnswering.from_pretrained(model_name)



class QA(ABC):
    """
    Base class that defines the default workflow for the QA knowledge source.

    Subclasses provide ``ENTITIES`` and implement ``set_questions``,
    ``set_answers`` and ``get_agg_score``. Ranking, top-answer selection and
    output formatting are implemented here.
    ...
    Attributes
    ----------
    AGG_PARTIAL : NoneType
        placeholder, overridden by qa0 and qa1
    ENTITIES : list
        entity types accepted by the constructor, overridden by qa0 and qa1
    context : str
        text to be the context of QA input
    entity : str
        type of entity that questions should target
    QAs : NoneType, dict
        formatted questions and answers; once populated, maps each entity to
        a list of dicts that carry at least the keys 'score' and 'answer'
    qa_info : dict
        top QA results (one answer dict per entity)
    display_info : dict
        QA results formatted for user display
    Methods
    ----------
    set_questions()
        Set questions on QAs from the question template
    set_answers(ner)
        Run the model and add the outputted information to QAs
    get_agg_score(entity)
        Subclass hook; not implemented in this base class
    rank_answers()
        Sort QAs by ranking, where the default is by score (best first)
    top_answers()
        Add the top ranked answer for each entity to qa_info
    run_qa()
        Do all from setting the question to selecting the top answers
    get_display_info()
        Format top answers to be displayed to users
    fetch_info()
        Fetch all original records of this Q&A
    """

    AGG_PARTIAL = None  # overridden by qa0 and qa1
    ENTITIES = []  # overridden by qa0 and qa1

    def __init__(self, context, entity):
        """Instantiate only if the entity is valid. Raise exception otherwise.
        :param context: text to be the context of QA input
        :param entity: type of entity that questions should target; must be
            one of ``ENTITIES``
        :raises ValueError: if ``entity`` is not listed in ``ENTITIES``
        """
        if entity not in self.ENTITIES:
            # ValueError is a subclass of Exception, so callers that catch
            # the generic Exception keep working.
            raise ValueError(f"Invalid entity type: {entity}")
        self.context = context
        self.QAs = None
        self.qa_info = {}
        self.display_info = {}
        self.entity = entity

    def set_questions(self):
        """
        Set questions on QAs from the question template
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError

    def set_answers(self, ner):
        """
        Run the model and add the outputted information to QAs
        :param ner: value forwarded unchanged from ``run_qa``; its meaning is
            defined by the subclass implementation
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError

    def get_agg_score(self, entity):
        """
        Subclass hook that is not used by this base class.
        NOTE(review): presumably returns an aggregate score for ``entity``
        -- confirm against the qa0/qa1 implementations.
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError

    def rank_answers(self):
        """
        Sort QAs by ranking, where the default is by score

        Candidates are ordered from highest to lowest score so that index 0
        of every list is the best answer, which is what ``top_answers``
        relies on.
        """
        for entity in self.QAs:
            self.QAs[entity] = sorted(
                self.QAs[entity], key=lambda d: d['score'], reverse=True
            )

    def top_answers(self):
        """Add the top ranked answer for each entity to qa_info
        :return: top answers; entities without any candidate answer are
            omitted
        """
        self.rank_answers()
        # An entity with an empty candidate list has no "top" answer, so it
        # is skipped instead of raising IndexError.
        self.qa_info = {
            key: ranked[0] for key, ranked in self.QAs.items() if ranked
        }
        return self.qa_info

    def run_qa(self, ner=False):
        """Do all from setting the question to selecting the top answers
        :param ner: value forwarded to ``set_answers``
        :return: top answers
        """
        # The print calls below are progress traces around each stage.
        print('in run_qa 1', ner)
        self.set_questions()
        print('in run_qa 2', ner)
        self.set_answers(ner)
        print('in run_qa 3', ner)
        self.qa_info = self.top_answers()
        print('in run_qa 4', ner, self.qa_info)
        return self.qa_info

    def get_display_info(self):
        """Format top answers to be displayed to users
        :return: formatted QA outputs to be displayed to users, as a dict
            mapping each entity in ``qa_info`` to its 'answer' value
        """
        display_info = {
            entity: info["answer"] for entity, info in self.qa_info.items()
        }
        self.display_info = display_info
        return display_info

    def fetch_info(self):
        """Fetch all original records of this Q&A
        :return: dictionary of the context and QAs
        """
        return {"context": self.context, "QAs": self.QAs}
Loading

0 comments on commit 3e7e338

Please sign in to comment.