diff --git a/DOTS/feat.py b/DOTS/feat.py
index a4bf03a..12e831c 100644
--- a/DOTS/feat.py
+++ b/DOTS/feat.py
@@ -111,3 +111,61 @@ def g_feat(text, top_k=3, n_topics=42):
 
     top_3_indices
     return df2, top_3_indices
+
+
+def gliner_feat(text, hits, labels=['disaster', 'weather', "avalanche",
+                "biological", "blizzard",
+                "chemical", "contamination", "cyber",
+                "drought",
+                "earthquake", "explosion",
+                "fire", "flood",
+                "heat", "hurricane",
+                "infrastructure",
+                "landslide",
+                "nuclear",
+                "pandemic", "power",
+                "radiological", "riot",
+                "sabotage",
+                "terror", "tornado", "transport", "tsunami",
+                "volcano"]):
+    # Tag each article with GLiNER entities for the given labels, then pivot to
+    # one row per article and one column of matched entity text per label.
+    from gliner import GLiNER
+    model = GLiNER.from_pretrained("urchade/gliner_base")
+
+    df = pd.DataFrame(columns=['Title', 'URL', 'Text', 'Label', 'Date', 'Location'])
+
+    for hit, article in zip(hits[:100], text[:100]):
+        entities = model.predict_entities(''.join(article), labels)
+        for entity in entities:
+            row = pd.DataFrame({'Title': [hit['_source']['metadata']['title']],
+                                'URL': [hit['_source']['metadata']['link']],
+                                'Text': [entity["text"]],
+                                'Label': [entity["label"]]})
+            df = pd.concat([df, row], ignore_index=True)
+    df2 = df.groupby(['Title', 'URL', 'Label']).agg({'Text': ' '.join}).reset_index()
+    df_pivot = df2.pivot(index=['Title', 'URL'], columns='Label', values='Text')
+    df_pivot.reset_index(inplace=True)
+    df_pivot = df_pivot.dropna(subset=df_pivot.columns[2:], how='all')
+    return df_pivot
+
+def count_gliner(df, label='disaster'):
+    # Unique entity strings for one label, ordered by how often they occur.
+    from collections import Counter
+    lst = df.Text[df.Label == label]
+    counter = Counter(lst)
+    sorted_lst = sorted(lst, key=lambda x: -counter[x])
+    return pd.unique(sorted_lst)
+
+def gpy_gliner(df_pivot):
+    # Embed the per-label entity features with UMAP, cluster with DBSCAN,
+    # and report the ten most common top features across articles.
+    g = graphistry.nodes(df_pivot.drop(['URL', 'Title'], axis=1))
+    g2 = g.umap()  # df_pivot.drop(['URL','Title'],axis=1),**topic_model)
+    g2 = g2.dbscan()  # min_dist=1, min_samples=3)
+    # g3 = g2.transform_dbscan(df_pivot.drop(['URL','Title'],axis=1),return_graph=False)
+    df2 = pd.DataFrame(g2.get_matrix())
+
+    max_index_per_row = df2.idxmax(axis=1)
+    top_3_indices = max_index_per_row.value_counts().index[:10]
+    return top_3_indices, g2, df2
diff --git a/demo/pipe_g_feat.ipynb b/demo/pipe_g_feat.ipynb
index 2f10a88..c7e9afc 100644
--- a/demo/pipe_g_feat.ipynb
+++ b/demo/pipe_g_feat.ipynb
@@ -40,23 +40,6 @@
     "from DOTS.ingestion_utils import safe_iter_pull, iter_pull, reduce_newlines, scrape_selenium_headless"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import getpass\n",
-    "import os\n",
-    "\n",
-    "# os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n",
-    "\n",
-    "from langchain_community.document_loaders import TextLoader\n",
-    "from langchain_community.vectorstores import OpenSearchVectorSearch\n",
-    "from langchain_openai import OpenAIEmbeddings\n",
-    "from langchain_text_splitters import CharacterTextSplitter"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": 4,
@@ -139,7 +122,28 @@   "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## GLiNER"
+    "## GLiNER featurize"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 220,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ImportError",
+     "evalue": "cannot import name 'gliner_feat' from 'DOTS.feat' (/Users/apple/WRK/dcolinmorgan/dots/DOTS/feat.py)",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[220], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mDOTS\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mfeat\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m gliner_feat\n",
+      "\u001b[0;31mImportError\u001b[0m: cannot import name 'gliner_feat' from 'DOTS.feat' (/Users/apple/WRK/dcolinmorgan/dots/DOTS/feat.py)"
+     ]
+    }
+   ],
+   "source": [
+    "from DOTS.feat import gliner_feat"
+   ]
+  },
   {
    "cell_type": "code",
@@ -187,21 +191,37 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 153,
+   "execution_count": 186,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "1464\n"
+      "252\n"
      ]
     }
    ],
    "source": [
+    "%%time\n",
     "import pandas as pd\n",
-    "\n",
-    "labels=['disaster', 'flood', 'weather', 'hurricane', 'fire', 'date', 'location']#, 'politician', 'corporation', 'government_agency', 'military']\n",
+    "# labels https://github.com/graphistry/graphistrygpt/blob/4b761af233c52109fb6bca2aec4e54d8c0ffa4ad/graphistrygpt/plugins/dt/models.py#L108\n",
+    "# labels=['disaster', 'flood', 'weather', 'hurricane', 'fire', 'date', 'location']#, 'politician', 'corporation', 'government_agency', 'military']\n",
+    "labels=[ \"avalanche\",\n",
+    "    \"biological\",\"blizzard\",\n",
+    "    \"chemical\",\"contamination\",\"cyber\",\n",
+    "    \"drought\",\n",
+    "    \"earthquake\",\"explosion\",\n",
+    "    \"fire\",\"flood\",\n",
+    "    \"heat\",\"hurricane\",\n",
+    "    \"infrastructure\",\n",
+    "    \"landslide\",\n",
+    "    \"nuclear\",\n",
+    "    \"pandemic\",\"power\",\n",
+    "    \"radiological\",\"riot\",\n",
+    "    \"sabotage\",\n",
+    "    \"terror\",\"tornado\",\"transport\",\"tsunami\",\n",
+    "    \"volcano\"]\n",
     "df = pd.DataFrame(columns=['Title','URL', 'Text', 'Label', 'Date','Location'])\n",
     "\n",
     "for hit,article in zip(hh[:100],text[:100]):\n",
@@ -217,7 +237,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 155,
+   "execution_count": 187,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -226,26 +246,403 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 196,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Index(['biological', 'cyber', 'earthquake', 'explosion', 'fire', 'flood',\n",
+       "       'heat', 'infrastructure', 'landslide', 'power', 'riot', 'terror',\n",
+       "       'tornado', 'transport'],\n",
+       "      dtype='object', name='Label')"
+      ]
+     },
+     "execution_count": 196,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "df_pivot.columns[2:]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 197,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
       [pandas HTML preview of df_pivot elided: 69 rows × 16 columns, Title and URL plus one column of detected entity text per label, NaN where no entity of that label was found]
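Usage sketch (illustrative, not part of the patch): a minimal call of the new helpers, assuming `hits` holds the OpenSearch hits and `text` the matching article bodies built earlier in the notebook, and that graphistry.register(...) has already been called.

    from DOTS.feat import gliner_feat, gpy_gliner

    df_pivot = gliner_feat(text, hits)               # one row per article, one entity column per label
    top_labels, g2, features = gpy_gliner(df_pivot)  # UMAP + DBSCAN over the label features
    print(top_labels)                                # ten most common top features across articles
    g2.plot()                                        # open the clustered graph in Graphistry

Note that count_gliner expects the long-format frame with Text and Label columns that gliner_feat builds internally, not the pivoted frame it returns.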