-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
28 changed files
with
1,607 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
name: PyQAWrapper | ||
channels: | ||
- conda-forge | ||
- anaconda | ||
- defaults | ||
dependencies: | ||
- python=3.9 | ||
- nltk | ||
- tqdm | ||
- pandas==1.4.2 | ||
- pip | ||
- pip: | ||
- torch==2.0.* | ||
- transformers==4.34.0 | ||
- pyyaml | ||
- unidecode | ||
- SPARQLWrapper==2.0.0 | ||
- python-dateutil | ||
- Werkzeug==2.3.7 | ||
- requests | ||
- requests_toolbelt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,12 @@ | ||
# QAWrapper | ||
Question and Answer Wrapper | ||
|
||
## Setup | ||
`conda env create -f PyQAWrapper.yml` | ||
|
||
## Running the examples
`conda activate PyQAWrapper` | ||
|
||
`python -m src.qa0` | ||
`python -m src.qa1` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "QAWrapper"
version = "0.0.1"
description = "Python wrapper for QuestionAnswer Interface"
authors = [
    { name = "Bart Gajderowicz", email = "[email protected]" },
]
requires-python = ">=3.9"
readme = "README.md"
# Keywords and classifiers belong in [project]: the build backend is
# setuptools, which does not read a [tool.poetry] table.
keywords = ['python', 'qa', 'bert']
classifiers = [
    "Development Status :: 1 - Planning",
    "Topic :: Database",
    "Intended Audience :: Education",
    "Intended Audience :: Developers",
    "Intended Audience :: Information Technology",
    "Intended Audience :: Science/Research",
    "Programming Language :: Python :: 3.9",
    "Environment :: Console",
    "Operating System :: POSIX :: Linux",
    "Operating System :: MacOS :: MacOS X",
    "Operating System :: Microsoft :: Windows",
]

[project.urls]
repository = "https://github.com/csse-uoft/QAWrapper"

[tool.setuptools.packages.find]
where = ["src"]
# Patterns here are shell-style globs, not regular expressions: ".*" would
# only match package names that start with a literal dot.
include = ["*"]
exclude = ["tests"]
namespaces = false
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import setuptools | ||
|
||
# Script entry point: only run the build when this file is executed directly.
# setup() is called without arguments, so the project metadata has to come
# from declarative configuration (presumably the accompanying pyproject.toml
# -- confirm).
if __name__ == "__main__":
    setuptools.setup()
|
Binary file not shown.
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
,TP,FP,FN,TN,precision,recall,f1,accuracy,mcc | ||
program_name,44,37,157,1424,0.543209877,0.218905473,0.312056738,0.883273165,0.293140598 | ||
client,66,15,426,1148,0.814814815,0.134146341,0.230366492,0.733534743,0.256864741 | ||
need_satisfier,56,24,619,954,0.7,0.082962963,0.148344371,0.611010284,0.133812428 | ||
outcome,39,41,172,1415,0.4875,0.184834123,0.268041237,0.872225555,0.243717974 | ||
catchment_area,20,60,66,1542,0.25,0.23255814,0.240963855,0.92535545,0.201911366 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
,TP,FP,FN,TN,precision,recall,f1,accuracy,mcc | ||
program name,33,47,154,1436,0.4125,0.17647058823529413,0.24719101123595508,0.8796407185628743,0.21377189897555318 | ||
client,53,28,436,1148,0.654320987654321,0.1083844580777096,0.18596491228070175,0.7213213213213213,0.1790521999048717 | ||
need satisfier,57,24,619,953,0.7037037037037037,0.08431952662721894,0.15059445178335534,0.6110102843315185,0.13608889790392295 | ||
need outcome,40,41,171,1414,0.49382716049382713,0.1895734597156398,0.273972602739726,0.8727490996398559,0.24957726830930263 | ||
catchment area,15,66,71,1542,0.18518518518518517,0.1744186046511628,0.17964071856287422,0.9191263282172373,0.13721550003321117 |
21 changes: 21 additions & 0 deletions
21
src/qa/aggregate_scores/layer1_partial2_aggregate_comb.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
,TP,FP,FN,TN,precision,recall,f1,accuracy,mcc | ||
catchment_area_program_name,22,55,166,1606,0.285714286,0.117021277,0.166037736,0.880475933,0.126938534 | ||
need_satisfier_program_name,30,49,160,1600,0.379746835,0.157894737,0.223048327,0.886351278,0.192414013 | ||
outlier_program_name,30,48,158,1598,0.384615385,0.159574468,0.22556391,0.887677208,0.196021955 | ||
client_program_name,34,45,153,1593,0.430379747,0.181818182,0.255639098,0.891506849,0.230004585 | ||
catchment_area_client,48,31,439,1587,0.607594937,0.098562628,0.169611307,0.77672209,0.176180698 | ||
outlier_client,38,41,450,1598,0.481012658,0.077868852,0.134038801,0.769158439,0.117515576 | ||
need_satisfier_client,52,24,439,1587,0.684210526,0.105906314,0.183421517,0.779733587,0.206274278 | ||
program_name_client,51,30,436,1584,0.62962963,0.104722793,0.179577465,0.778200857,0.188790392 | ||
client_need_satisfier,58,22,616,1573,0.725,0.086053412,0.153846154,0.718818863,0.179036769 | ||
program_name_need_satisfier,59,22,616,1573,0.728395062,0.087407407,0.156084656,0.718942731,0.181396561 | ||
catchment_area_need_satisfier,51,28,623,1580,0.64556962,0.075667656,0.135458167,0.714723926,0.145372848 | ||
outlier_need_satisfier,44,36,627,1584,0.55,0.06557377,0.117177097,0.710606722,0.1074692 | ||
catchment_area_outlier,28,50,184,1601,0.358974359,0.132075472,0.193103448,0.874396135,0.16139258 | ||
need_satisfier_outlier,36,44,174,1591,0.45,0.171428571,0.248275862,0.881842818,0.225356913 | ||
program_name_outlier,35,45,176,1593,0.4375,0.165876777,0.240549828,0.880475933,0.216291121 | ||
client_outlier,27,52,182,1599,0.341772152,0.129186603,0.1875,0.874193548,0.152986703 | ||
need_satisfier_catchment_area,9,71,76,1618,0.1125,0.105882353,0.109090909,0.917136415,0.06571339 | ||
client_catchment_area,13,67,74,1616,0.1625,0.149425287,0.155688623,0.920338983,0.114073563 | ||
outlier_catchment_area,5,75,80,1622,0.0625,0.058823529,0.060606061,0.91301908,0.01505592 | ||
program_name_catchment_area,16,65,70,1612,0.197530864,0.186046512,0.191616766,0.923425978,0.151538904 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import torch | ||
# Choose the torch device once, at import time: use the MPS backend when
# torch reports it as available, otherwise fall back to the CPU.
device_str = 'mps' if torch.backends.mps.is_available() else 'cpu'
if device_str == 'mps':
    mps_device = torch.device(device_str)
    # Allocate a one-element tensor on the device and show it alongside the
    # confirmation message.
    x = torch.ones(1, device=mps_device)
    print("MPS found", x)
else:
    print ("MPS device not found.")
print("Device found: ", device_str)
from transformers import pipeline, AutoModelForQuestionAnswering | ||
from abc import ABC | ||
|
||
|
||
class QAModel:
    """
    Holder for the question-answering pipeline and model used by the QA classes.

    Everything is defined at class level, so the pipeline and the model are
    created when this class body is executed (i.e. when the module is imported).
    ...
    Attributes
    ----------
    model_name : str
        Name of the pretrained QA checkpoint to load
    nlp : transformers.pipeline
        'question-answering' pipeline built from ``model_name`` (used for both
        the model and the tokenizer) on the device named by ``device_str``
    model : transformers.AutoModelForQuestionAnswering
        QA model loaded from ``model_name``
    """
    # Alternative checkpoint kept for reference:
    #   "deepset/deberta-v3-large-squad2" (DebertaForQuestionAnswering)
    model_name = "deepset/roberta-base-squad2"
    nlp = pipeline(
        'question-answering',
        model=model_name,
        tokenizer=model_name,
        device=device_str,
    )
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
|
||
|
||
|
||
class QA(ABC):
    """
    An abstract class that defines default use cases for the QA knowledge source.
    ...
    Attributes
    ----------
    context : str
        text to be the context of QA input
    entity : str
        type of entity that questions should target; must be in ENTITIES
    QAs : NoneType, dict
        formatted questions and answers, keyed by entity; each value is a
        list of answer dicts (the defaults here read 'score' and 'answer')
    qa_info : dict
        top QA results, one answer dict per entity
    display_info : dict
        QA results formatted for user display
    Methods
    ----------
    set_questions()
        Set questions on QAs from the question template
    set_answers()
        Run the model and add the outputted information to QAs
    get_agg_score()
        Aggregate score lookup, provided by subclasses
    rank_answers()
        Sort QAs by ranking, where the default is by score (best first)
    top_answers()
        Add the top ranked answer for each entity to qa_info
    run_qa()
        Do all from setting the question to selecting the top answers
    get_display_info()
        Format top answers to be displayed to users
    fetch_info()
        Fetch all original records of this Q&A
    """

    AGG_PARTIAL = None  # overridden by qa0 and qa1
    ENTITIES = []  # overridden by qa0 and qa1

    def __init__(self, context, entity):
        """Instantiate only if the entity is valid. Raise exception otherwise.
        :param context: text to be the context of QA input
        :param entity: type of entity that questions should target
        :raises ValueError: if ``entity`` is not listed in ``ENTITIES``
        """
        if entity not in self.ENTITIES:
            # ValueError is a subclass of Exception, so callers that catch
            # the previous generic Exception keep working.
            raise ValueError(f"Invalid entity type: {entity}")
        self.context = context
        self.QAs = None
        self.qa_info = {}
        self.display_info = {}
        self.entity = entity

    def set_questions(self):
        """
        Set questions on QAs from the question template.
        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def set_answers(self, ner):
        """
        Run the model and add the outputted information to QAs.
        Must be implemented by subclasses.
        :param ner: flag forwarded from run_qa (presumably toggles named-entity
            handling in subclasses -- confirm against the implementations)
        """
        raise NotImplementedError

    def get_agg_score(self, entity):
        """
        Return the aggregate score for the given entity.
        Must be implemented by subclasses; the exact semantics are defined there.
        :param entity: entity name to look up
        """
        raise NotImplementedError

    def rank_answers(self):
        """
        Sort QAs by ranking, where the default is by score.
        Answers are ordered from highest to lowest score so that index 0 of
        each entity's list is the best candidate, which is what top_answers
        selects.
        """
        for entity in self.QAs:
            self.QAs[entity] = sorted(
                self.QAs[entity], key=lambda d: d['score'], reverse=True
            )

    def top_answers(self):
        """Add the top ranked answer for each entity to qa_info
        Entities whose candidate list is empty are left out instead of
        raising an IndexError.
        :return: top answers
        """
        self.rank_answers()
        self.qa_info = {
            key: answers[0] for key, answers in self.QAs.items() if answers
        }
        return self.qa_info

    def run_qa(self, ner=False):
        """Do all from setting the question to selecting the top answers
        :param ner: flag passed through to set_answers
        :return: top answers
        """
        # NOTE(review): the prints below are progress traces kept as-is so the
        # console output stays unchanged; consider moving them to logging.
        print('in run_qa 1', ner)
        self.set_questions()
        print('in run_qa 2', ner)
        self.set_answers(ner)
        print('in run_qa 3', ner)
        self.qa_info = self.top_answers()
        print('in run_qa 4', ner, self.qa_info)
        return self.qa_info

    def get_display_info(self):
        """Format top answers to be displayed to users
        :return: formatted QA outputs to be displayed to users
        """
        display_info = {
            entity: info["answer"] for entity, info in self.qa_info.items()
        }
        self.display_info = display_info
        return display_info

    def fetch_info(self):
        """Fetch all original records of this Q&A
        :return: dictionary of the context and QAs
        """
        return {"context": self.context, "QAs": self.QAs}
Oops, something went wrong.