# Default entity labels passed to GLiNER: two generic terms followed by
# specific hazard categories. Kept as a tuple so the default can never be
# mutated between calls.
_GLINER_DEFAULT_LABELS = (
    'disaster', 'weather', 'avalanche',
    'biological', 'blizzard',
    'chemical', 'contamination', 'cyber',
    'drought',
    'earthquake', 'explosion',
    'fire', 'flood',
    'heat', 'hurricane',
    'infrastructure',
    'landslide',
    'nuclear',
    'pandemic', 'power',
    'radiological', 'riot',
    'sabotage',
    'terror', 'tornado', 'transport', 'tsunami',
    'volcano',
)


def gliner_feat(text, hits, labels=None, *, model=None, max_articles=100):
    """Extract labelled entities from articles and pivot them into one row per article.

    Each article is run through ``model.predict_entities``; every returned
    entity contributes its ``"text"`` under its ``"label"``. Entities sharing
    the same (Title, URL, Label) are joined with a single space.

    Args:
        text: Sequence of articles, aligned by position with ``hits``. Each
            article is passed through ``''.join(...)``, so it may be a string
            or an iterable of string fragments.
        hits: Sequence of search hits; each one is read as
            ``hit['_source']['metadata']['title']`` and ``[...]['link']``.
        labels: Entity labels to look for. Defaults to the hazard vocabulary
            in ``_GLINER_DEFAULT_LABELS``.
        model: Optional object exposing ``predict_entities(text, labels)``.
            When omitted, ``urchade/gliner_base`` is loaded via GLiNER.
        max_articles: Only the first ``max_articles`` hit/article pairs are
            processed (``None`` processes all of them).

    Returns:
        pandas.DataFrame with columns ``Title``, ``URL`` and one column per
        label that was actually found. Articles without any entity do not
        appear. If nothing is found at all, an empty frame with just the
        ``Title`` and ``URL`` columns is returned.

    Note:
        ``zip`` stops at the shorter of ``hits`` and ``text``; a length
        mismatch is silently truncated.
    """
    if model is None:
        # Imported lazily: only callers that need the default model pay for
        # (and require) the gliner dependency.
        from gliner import GLiNER
        model = GLiNER.from_pretrained("urchade/gliner_base")
    labels = list(_GLINER_DEFAULT_LABELS if labels is None else labels)

    # Collect plain dict rows and build the frame once; concatenating a
    # DataFrame per entity is quadratic in the number of entities.
    rows = []
    for hit, article in zip(hits[:max_articles], text[:max_articles]):
        entities = model.predict_entities(''.join(article), labels)
        if not entities:
            continue
        metadata = hit['_source']['metadata']
        title, url = metadata['title'], metadata['link']
        for entity in entities:
            rows.append({
                'Title': title,
                'URL': url,
                'Text': entity["text"],
                'Label': entity["label"],
            })

    if not rows:
        # Nothing to group or pivot; return a well-formed empty result.
        return pd.DataFrame(columns=['Title', 'URL'])

    df = pd.DataFrame(rows, columns=['Title', 'URL', 'Text', 'Label'])
    df2 = df.groupby(['Title', 'URL', 'Label']).agg({'Text': ' '.join}).reset_index()
    df_pivot = df2.pivot(index=['Title', 'URL'], columns='Label', values='Text')
    df_pivot.reset_index(inplace=True)
    # Columns 0 and 1 are Title/URL; drop rows that have no label value at all.
    df_pivot = df_pivot.dropna(subset=df_pivot.columns[2:], how='all')
    return df_pivot


def count_gliner(df, label='disaster'):
    """Return the distinct entity texts for ``label``, most frequent first.

    Args:
        df: DataFrame with ``Text`` and ``Label`` columns (one row per
            extracted entity).
        label: Value of ``Label`` to filter on.

    Returns:
        numpy.ndarray (object dtype) of unique ``Text`` values ordered by
        descending frequency; ties keep first-occurrence order. Empty when
        the label does not occur.
    """
    from collections import Counter

    values = list(df.Text[df.Label == label])
    # most_common() sorts stably by count, and Counter keeps insertion order,
    # so equally frequent texts stay in order of first appearance.
    ordered = [value for value, _ in Counter(values).most_common()]
    return pd.Series(ordered, dtype=object).to_numpy()


def gpy_gliner(df_pivot, top_k=10):
    """Run graphistry UMAP + DBSCAN on the pivoted entity table and rank feature columns.

    The ``URL`` and ``Title`` columns are dropped, the remaining label columns
    are loaded as graphistry nodes, and ``umap()`` followed by ``dbscan()`` is
    applied. The matrix returned by ``get_matrix()`` is then scanned row by
    row for its maximum column, and the columns that are the row maximum most
    often are reported.

    Args:
        df_pivot: DataFrame containing ``URL`` and ``Title`` columns plus the
            label columns (e.g. the output of ``gliner_feat``).
        top_k: How many of the most frequent row-maximum columns to return.

    Returns:
        tuple ``(top_indices, g2, df2)``: a pandas Index of up to ``top_k``
        column names, the graphistry object after ``umap().dbscan()``, and
        the ``get_matrix()`` result wrapped in a DataFrame.
    """
    features = df_pivot.drop(['URL', 'Title'], axis=1)
    g2 = graphistry.nodes(features).umap().dbscan()
    df2 = pd.DataFrame(g2.get_matrix())

    # For each row, the column holding its largest value; then rank columns
    # by how many rows they dominate.
    dominant_column = df2.idxmax(axis=1)
    top_indices = dominant_column.value_counts().index[:top_k]
    return top_indices, g2, df2
@@ -187,21 +191,37 @@ }, { "cell_type": "code", - "execution_count": 153, + "execution_count": 186, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1464\n" + "252\n" ] } ], "source": [ + "%%time\n", "import pandas as pd\n", - "\n", - "labels=['disaster', 'flood', 'weather', 'hurricane', 'fire', 'date', 'location']#, 'politician', 'corporation', 'government_agency', 'military']\n", + "# labels https://github.com/graphistry/graphistrygpt/blob/4b761af233c52109fb6bca2aec4e54d8c0ffa4ad/graphistrygpt/plugins/dt/models.py#L108\n", + "# labels=['disaster', 'flood', 'weather', 'hurricane', 'fire', 'date', 'location']#, 'politician', 'corporation', 'government_agency', 'military']\n", + "labels=[ \"avalanche\",\n", + " \"biological\",\"blizzard\",\n", + " \"chemical\",\"contamination\",\"cyber\",\n", + " \"drought\",\n", + " \"earthquake\",\"explosion\",\n", + " \"fire\",\"flood\",\n", + " \"heat\",\"hurricane\",\n", + " \"infrastructure\",\n", + " \"landslide\",\n", + " \"nuclear\",\n", + " \"pandemic\",\"power\",\n", + " \"radiological\",\"riot\",\n", + " \"sabotage\",\n", + " \"terror\",\"tornado\",\"transport\",\"tsunami\",\n", + " \"volcano\"]\n", "df = pd.DataFrame(columns=['Title','URL', 'Text', 'Label', 'Date','Location'])\n", "\n", "for hit,article in zip(hh[:100],text[:100]):\n", @@ -217,7 +237,7 @@ }, { "cell_type": "code", - "execution_count": 155, + "execution_count": 187, "metadata": {}, "outputs": [], "source": [ @@ -226,26 +246,403 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 196, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Index(['biological', 'cyber', 'earthquake', 'explosion', 'fire', 'flood',\n", + " 'heat', 'infrastructure', 'landslide', 'power', 'riot', 'terror',\n", + " 'tornado', 'transport'],\n", + " dtype='object', name='Label')" + ] + }, + "execution_count": 196, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"df_pivot.columns[2:]" + ] + }, + { + "cell_type": "code", + "execution_count": 197, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
LabelTitleURLbiologicalcyberearthquakeexplosionfirefloodheatinfrastructurelandslidepowerriotterrortornadotransport
01 dead, 2 children wounded after 13 y/o goes o...https://news.google.com/rss/articles/CBMiX2h0d...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNmass shooting mass shootingNaNambulances
11 dead, 2 injured in fire at northwest Indiana...https://news.google.com/rss/articles/CBMiogFod...NaNbrowserNaNNaNfireNaNNaNMount Zion Suburban ApartmentsNaNNaNNaNNaNNaNNaN
210th suspect held in Moscow terror attack - Th...https://news.google.com/rss/articles/CBMieWh0d...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNmassacreNaNNaNNaN
32 people rescued from Humber River following l...https://news.google.com/rss/articles/CBMibmh0d...NaNNaNNaNNaNNaNNaNNaNHumber River culvertslandslide landslide landslideNaNNaNNaNNaNtruck car
422 terrorists arrested during 232 IBOs conduct...https://news.google.com/rss/articles/CBMiYWh0d...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNAl-QaedaNaNNaN
...................................................
64Unexploded WW2 bomb detonated on Jersey's east...https://news.google.com/rss/articles/CBMiOmh0d...NaNNaNNaNcontrolled explosion Guatemala landfill fireRussian strikes Guatemala landfill fireNaNNaNSeymour Tower Eiffel TowerNaNNaNNaNNaNNaNNaN
65Utah Man Arrested for Assaulting Officers Duri...https://news.google.com/rss/articles/CBMieWh0d...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaNcivil disorderNaNNaNNaN
66Video shows flames shooting from Philadelphia ...https://news.google.com/rss/articles/CBMifmh0d...NaNNaNNaNNaNfire fire fireNaNNaNNaNNaNNaNNaNNaNNaNNaN
67Video: Storm expected to bring heavy snow, gus...https://news.google.com/rss/articles/CBMiUGh0d...NaNNaNNaNNaNNaNNaNNaNNaNNaNpower outagesNaNNaNNaNNaN
68WFFD respond to truck fire, Old Iowa Park Road...https://news.google.com/rss/articles/CBMiXWh0d...NaNNaNNaNNaNtruck fireNaNNaNgas meterNaNNaNNaNNaNNaNNaN
\n", + "

69 rows × 16 columns

\n", + "
" + ], + "text/plain": [ + "Label Title \\\n", + "0 1 dead, 2 children wounded after 13 y/o goes o... \n", + "1 1 dead, 2 injured in fire at northwest Indiana... \n", + "2 10th suspect held in Moscow terror attack - Th... \n", + "3 2 people rescued from Humber River following l... \n", + "4 22 terrorists arrested during 232 IBOs conduct... \n", + ".. ... \n", + "64 Unexploded WW2 bomb detonated on Jersey's east... \n", + "65 Utah Man Arrested for Assaulting Officers Duri... \n", + "66 Video shows flames shooting from Philadelphia ... \n", + "67 Video: Storm expected to bring heavy snow, gus... \n", + "68 WFFD respond to truck fire, Old Iowa Park Road... \n", + "\n", + "Label URL biological cyber \\\n", + "0 https://news.google.com/rss/articles/CBMiX2h0d... NaN NaN \n", + "1 https://news.google.com/rss/articles/CBMiogFod... NaN browser \n", + "2 https://news.google.com/rss/articles/CBMieWh0d... NaN NaN \n", + "3 https://news.google.com/rss/articles/CBMibmh0d... NaN NaN \n", + "4 https://news.google.com/rss/articles/CBMiYWh0d... NaN NaN \n", + ".. ... ... ... \n", + "64 https://news.google.com/rss/articles/CBMiOmh0d... NaN NaN \n", + "65 https://news.google.com/rss/articles/CBMieWh0d... NaN NaN \n", + "66 https://news.google.com/rss/articles/CBMifmh0d... NaN NaN \n", + "67 https://news.google.com/rss/articles/CBMiUGh0d... NaN NaN \n", + "68 https://news.google.com/rss/articles/CBMiXWh0d... NaN NaN \n", + "\n", + "Label earthquake explosion \\\n", + "0 NaN NaN \n", + "1 NaN NaN \n", + "2 NaN NaN \n", + "3 NaN NaN \n", + "4 NaN NaN \n", + ".. ... ... \n", + "64 NaN controlled explosion Guatemala landfill fire \n", + "65 NaN NaN \n", + "66 NaN NaN \n", + "67 NaN NaN \n", + "68 NaN NaN \n", + "\n", + "Label fire flood heat \\\n", + "0 NaN NaN NaN \n", + "1 fire NaN NaN \n", + "2 NaN NaN NaN \n", + "3 NaN NaN NaN \n", + "4 NaN NaN NaN \n", + ".. ... ... ... 
\n", + "64 Russian strikes Guatemala landfill fire NaN NaN \n", + "65 NaN NaN NaN \n", + "66 fire fire fire NaN NaN \n", + "67 NaN NaN NaN \n", + "68 truck fire NaN NaN \n", + "\n", + "Label infrastructure landslide \\\n", + "0 NaN NaN \n", + "1 Mount Zion Suburban Apartments NaN \n", + "2 NaN NaN \n", + "3 Humber River culverts landslide landslide landslide \n", + "4 NaN NaN \n", + ".. ... ... \n", + "64 Seymour Tower Eiffel Tower NaN \n", + "65 NaN NaN \n", + "66 NaN NaN \n", + "67 NaN NaN \n", + "68 gas meter NaN \n", + "\n", + "Label power riot terror tornado \\\n", + "0 NaN NaN mass shooting mass shooting NaN \n", + "1 NaN NaN NaN NaN \n", + "2 NaN massacre NaN NaN \n", + "3 NaN NaN NaN NaN \n", + "4 NaN NaN Al-Qaeda NaN \n", + ".. ... ... ... ... \n", + "64 NaN NaN NaN NaN \n", + "65 NaN civil disorder NaN NaN \n", + "66 NaN NaN NaN NaN \n", + "67 power outages NaN NaN NaN \n", + "68 NaN NaN NaN NaN \n", + "\n", + "Label transport \n", + "0 ambulances \n", + "1 NaN \n", + "2 NaN \n", + "3 truck car \n", + "4 NaN \n", + ".. ... 
\n", + "64 NaN \n", + "65 NaN \n", + "66 NaN \n", + "67 NaN \n", + "68 NaN \n", + "\n", + "[69 rows x 16 columns]" + ] + }, + "execution_count": 197, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "df_pivot = df2.pivot(index=['Title','URL'], columns='Label', values='Text')\n", "df_pivot.reset_index(inplace=True)\n", - "df_pivot = df_pivot.dropna(subset=['disaster','fire','flood','weather'],how='all')\n", + "df_pivot = df_pivot.dropna(subset=df_pivot.columns[2:],how='all')\n", "df_pivot" ] }, { "cell_type": "code", - "execution_count": 182, + "execution_count": 198, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "['landslide' 'World War Two' 'gas leak' 'famine' 'snow event' 'flooding']\n" + "['spring Chinook' 'smolts']\n" ] } ], @@ -254,7 +651,7 @@ "from collections import Counter\n", "\n", "# Your list\n", - "lst = df.Text[df.Label=='disaster']\n", + "lst = df.Text[df.Label=='biological']\n", "\n", "# Count the frequency of each element\n", "counter = Counter(lst)\n", @@ -267,31 +664,14 @@ }, { "cell_type": "code", - "execution_count": 185, + "execution_count": 199, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "['snow' 'rain' 'tornadoes' 'showers' 'sunshine' 'wind' 'sleet' 'storm'\n", - " 'jet stream' 'winds' 'severe weather' 'straight-line winds' 'tornado'\n", - " 'wet snow' 'power outages' 'Wind gusts' 'wind gusts' 'gusty winds'\n", - " 'Polar air' 'storms' 'Rain' 'floodwaters' 'snow bomb' 'rainfall'\n", - " 'snowfall' 'flood waters' 'solar eclipse' 'tornado watches' 'heavy rain'\n", - " 'wet flakes' 'large hail' 'severe thunderstorms' 'cold front'\n", - " '60 mph will be the primary concern. 
However a tornado'\n", - " 'tornado warning\\xa0means residents should take shelter immediately because a tornado'\n", - " 'Tornadoes' 'EF2 tornado' 'straight line winds' 'EF-0 tornado' 'ENE wind'\n", - " 'wintry mix' 'Tropical air' 'Atlantic weather systems' 'sub-tropical air'\n", - " 'Jet stream' 'water tables' 'tropical air' 'flood water'\n", - " 'Wind gusts of 60-70+ mph will be likely. Tree damage' 'Wind Advisories'\n", - " 'rain showers' 'lightning' 'flooding' 'thunderstorms' 'hail'\n", - " 'Spring sunshine' 'wintry weather' 'Spring sunshineRead' 'temperatures'\n", - " 'windy' 'mild' 'Showers' 'geopolitical climate' 'Straight line winds'\n", - " 'clouds' 'S' 'ESE' 'WSW' 'water damage' 'ice'\n", - " 'wind gusts could hit 80 kilometres an hour.The rain'\n", - " 'daytime temperatures' 'icy']\n" + "['3.7 magnitude earthquake' '3.3 quake' '2.3-magnitude earthquake']\n" ] } ], @@ -300,7 +680,7 @@ "from collections import Counter\n", "\n", "# Your list\n", - "lst = df.Text[df.Label=='weather']\n", + "lst = df.Text[df.Label=='earthquake']\n", "\n", "# Count the frequency of each element\n", "counter = Counter(lst)\n", @@ -313,15 +693,16 @@ }, { "cell_type": "code", - "execution_count": 183, + "execution_count": 200, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "['fire' 'six-alarm fire' 'Israeli fire' 'Fire' 'truck fire'\n", - " 'multi-alarm fire' 'brush fire']\n" + "['fire' 'six-alarm fire' 'Russian strikes' 'Guatemala landfill fire'\n", + " 'Israeli fire' 'Fire' 'flames' 'Blaze' 'truck fire' 'multi-alarm fire'\n", + " 'brush fire']\n" ] } ], @@ -343,7 +724,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 201, "metadata": {}, "outputs": [], "source": [ @@ -364,41 +745,38 @@ }, { "cell_type": "code", - "execution_count": 178, + "execution_count": 203, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "* Ignoring target column of shape (38, 0) in UMAP fit, as it is 
not one dimensional" + "* Ignoring target column of shape (69, 0) in UMAP fit, as it is not one dimensional" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 1.07 s, sys: 240 ms, total: 1.31 s\n", - "Wall time: 1.15 s\n" + "CPU times: user 1.13 s, sys: 244 ms, total: 1.37 s\n", + "Wall time: 1.16 s\n" ] }, { "data": { "text/plain": [ - "Index(['disaster_nan', 'disaster_landslide flooding landslide landslide',\n", - " 'disaster_famine', 'disaster_snow event', 'disaster_World War Two',\n", - " 'disaster_gas leak'],\n", - " dtype='object')" + "Index(['biological_nan', 'cyber_nan'], dtype='object')" ] }, - "execution_count": 178, + "execution_count": 203, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", - "g=graphistry.nodes(df_pivot.drop(['URL','Title','date'],axis=1))\n", + "g=graphistry.nodes(df_pivot.drop(['URL','Title'],axis=1))\n", "g2 = g.umap()#df_pivot.drop(['URL','Title'],axis=1),**topic_model)\n", "g2 = g2.dbscan() #min_dist=1, min_samples=3)\n", "# g3 = g2.transform_dbscan(df_pivot.drop(['URL','Title'],axis=1),return_graph=False)\n", @@ -411,26 +789,26 @@ }, { "cell_type": "code", - "execution_count": 179, + "execution_count": 204, "metadata": {}, "outputs": [], "source": [ - "import seaborn as sns\n", - "import matplotlib.colors as mcolors\n", - "palette = sns.color_palette(\"hls\", 10)\n", - "hex_palette = [mcolors.rgb2hex(color) for color in palette]" + "# import seaborn as sns\n", + "# import matplotlib.colors as mcolors\n", + "# palette = sns.color_palette(\"hls\", 10)\n", + "# hex_palette = [mcolors.rgb2hex(color) for color in palette]" ] }, { "cell_type": "code", - "execution_count": 180, + "execution_count": 205, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", - " \n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 210, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "g2.plot()" ] @@ -572,34 +1089,52 @@ "cell_type": 
"markdown", "metadata": {}, "source": [ - "## featurize" + "## distilroberta featurize" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 212, "metadata": {}, "outputs": [], "source": [ "from DOTS.feat import featurize_stories\n", - "import graphistry\n", - "import umap" + "import graphistry" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 213, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "['severe weather', 'falling rain', 'kindergarten']" + ] + }, + "execution_count": 213, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "featurize_stories(str(articles[711]), top_k = 3, max_len=512)" + "featurize_stories(str(articles[211]), top_k = 3, max_len=512)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 214, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 622/622 [16:14<00:00, 1.57s/it]\n" + ] + } + ], "source": [ "rank_articles=[]\n", "from tqdm import tqdm\n", @@ -615,16 +1150,7 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "os.getcwd()" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 215, "metadata": {}, "outputs": [], "source": [ @@ -633,7 +1159,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 216, "metadata": {}, "outputs": [], "source": [ @@ -653,7 +1179,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 217, "metadata": {}, "outputs": [], "source": [ @@ -663,9 +1189,35 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 218, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "* Ignoring target column of shape (539, 0) in UMAP fit, as it is not one dimensional" + ] + }, + { + "data": { + "text/plain": 
[ + "Index(['2: weather, severe, death', '2: thunderstorms, storms, widespread',\n", + " '2: belotserkovsky, impoverishment, urmărește',\n", + " '2: youssef, hussein, county', '0: firefighters, night, fires',\n", + " '2: dahdouh, pourahmadi, damage',\n", + " '2: firefighters, firefighter, daughters',\n", + " '2: afternoon, aftershocks, enter',\n", + " '2: authorities, mississippi, puppies',\n", + " '1: israeli, warplanes, rezidențiale'],\n", + " dtype='object')" + ] + }, + "execution_count": 218, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# import pandas as pd\n", "data=pd.DataFrame(flattened_list) # each ranked feature is a row\n", @@ -692,9 +1244,37 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 219, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 219, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "g22.plot() # .encode_point_color('_dbscan',palette=[\"hotpink\", \"dodgerblue\"],as_continuous=True).plot()" ] diff --git a/main.py b/main.py index e7c6286..44548dd 100644 --- a/main.py +++ b/main.py @@ -25,12 +25,12 @@ def _input(): # Setup argument parser parser = argparse.ArgumentParser(description='Process OS data for dynamic features.') parser.add_argument('-n', type=int, default=100, help='Number of data items to get') - parser.add_argument('-f', type=int, default=5, help='Number of features per item to get') + # parser.add_argument('-f', type=int, default=5, help='Number of features per item to get') parser.add_argument('-o', type=str, default='dots_feats.csv', help='Output file name') # parser.add_argument('-p', type=int, default=1, help='Parallelize requests') # parser.add_argument('-t', type=int, default=1, help='Scroll Timeout in minutes, if using "d=1" large data set') - parser.add_argument('-d', type=int, default=2, help='0 for OS, 1 for 
test_gnews, 2 for lobstr') - # parser.add_argument('-e', type=datetime, default=20231231, help='end date') + parser.add_argument('-d', type=int, default=1, help='0 for OS, 1 for test_gnews, 2 for lobstr') + parser.add_argument('-e', type=int, default=1, help='0 for distilroberta, 1 for pyg, 2 for gliner') args, unknown = parser.parse_known_args() return args @@ -81,7 +81,7 @@ def main(args): RR = dataloader else: RR = articles - if f == 0: + if args.e == 0: for j,i in tqdm(enumerate(RR), total=len(RR), desc="featurizing articles"): try: @@ -101,12 +101,16 @@ def main(args): with open('DOTS/output/full_'+dname+args.o, 'a', newline='') as file: writer = csv.writer(file) writer.writerows(rank_articles) - elif f == 1: + elif args.e == 1: df2, top_3_indices = g_feat(articles, top_k=3, n_topics=42) with open('DOTS/output/g_feats_'+dname+args.o, 'a', newline='') as file: writer = csv.writer(file) writer.writerows(top_3_indices) df2.to_csv('DOTS/output/g_full_'+dname+args.o) + elif args.e == 2: + df2 = g_feat(articles, hits) + df2.to_csv('DOTS/output/gliner_full_'+dname+args.o,sep='\t') + # flattened_list = [item for sublist in rank_articles for item in sublist] # import pandas as pd diff --git a/setup.py b/setup.py index a0f8f08..aace10a 100755 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ 'selenium', 'webdriver_manager,' 'undetected_chromedriver', + 'gliner', ]