Skip to content

Commit

Permalink
full comp demo
Browse files Browse the repository at this point in the history
  • Loading branch information
dcolinmorgan committed Apr 12, 2024
1 parent 36e5f91 commit 7080061
Show file tree
Hide file tree
Showing 4 changed files with 770 additions and 131 deletions.
54 changes: 54 additions & 0 deletions DOTS/feat.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,57 @@ def g_feat(text, top_k=3, n_topics=42):
top_3_indices

return df2, top_3_indices


# Default entity labels passed to GLiNER. Kept as a tuple so the default can
# never be mutated between calls.
_GLINER_DEFAULT_LABELS = (
    'disaster', 'weather', "avalanche",
    "biological", "blizzard",
    "chemical", "contamination", "cyber",
    "drought",
    "earthquake", "explosion",
    "fire", "flood",
    "heat", "hurricane",
    "infrastructure",
    "landslide",
    "nuclear",
    "pandemic", "power",
    "radiological", "riot",
    "sabotage",
    "terror", "tornado", "transport", "tsunami",
    "volcano",
)


def gliner_feat(text, hits, labels=None, max_articles=100):
    """Extract labelled entities from articles with GLiNER and pivot by label.

    Args:
        text: Sequence of articles, aligned with ``hits``. Each article is
            passed through ``''.join(...)`` before prediction, so it may be a
            string or an iterable of string fragments.
        hits: Sequence of search hits aligned with ``text``. Each hit must
            expose ``hit['_source']['metadata']['title']`` and
            ``hit['_source']['metadata']['link']``.
        labels: Entity labels to ask the model for. Defaults to the built-in
            disaster-related label set when ``None``.
        max_articles: Only the first ``max_articles`` hit/article pairs are
            processed (``None`` processes all of them).

    Returns:
        A DataFrame with one row per ``(Title, URL)`` pair and one column per
        label that was found; each cell holds the space-joined entity texts
        for that label (NaN where the label did not occur). When no entity is
        found at all, an empty DataFrame with ``Title`` and ``URL`` columns
        is returned.
    """
    from gliner import GLiNER

    labels = list(_GLINER_DEFAULT_LABELS if labels is None else labels)
    model = GLiNER.from_pretrained("urchade/gliner_base")

    # Collect plain dict rows and build the frame once; concatenating a
    # one-row DataFrame per entity is quadratic in the number of entities.
    rows = []
    for hit, article in zip(hits[:max_articles], text[:max_articles]):
        for entity in model.predict_entities(''.join(article), labels):
            rows.append({
                'Title': hit['_source']['metadata']['title'],
                'URL': hit['_source']['metadata']['link'],
                'Text': entity["text"],
                'Label': entity["label"],
            })

    if not rows:
        # Nothing was extracted: hand back an empty frame with the
        # identifying columns instead of pivoting an empty frame.
        return pd.DataFrame(columns=['Title', 'URL'])

    df = pd.DataFrame(rows, columns=['Title', 'URL', 'Text', 'Label'])

    # One row per (article, label) with all matching entity texts joined.
    df2 = (
        df.groupby(['Title', 'URL', 'Label'])
        .agg({'Text': ' '.join})
        .reset_index()
    )
    # Spread labels into columns, keeping Title/URL as ordinary columns.
    df_pivot = df2.pivot(index=['Title', 'URL'], columns='Label', values='Text')
    df_pivot.reset_index(inplace=True)
    # Drop articles that ended up with no value in any label column.
    df_pivot = df_pivot.dropna(subset=df_pivot.columns[2:], how='all')
    return df_pivot

def count_gliner(df, label='disaster'):
    """Return the distinct entity texts for ``label``, most frequent first.

    Args:
        df: DataFrame with a ``Text`` column (entity text) and a ``Label``
            column (entity label), one row per extracted entity.
        label: Label whose entity texts should be ranked.

    Returns:
        A NumPy object array of the unique ``Text`` values whose ``Label``
        equals ``label``, ordered by descending frequency; ties keep the
        order in which the texts first appear in ``df``. Empty when the label
        does not occur.
    """
    from collections import Counter

    # Filter on the requested label (previously hard-coded to 'earthquake').
    texts = df.Text[df.Label == label]
    # most_common() is a stable descending sort over first-seen order, so it
    # yields each distinct text once, in frequency order.
    ranked = [text for text, _ in Counter(texts).most_common()]
    # Go through a Series so the result stays an ndarray without relying on
    # pd.unique() accepting a plain list.
    return pd.Series(ranked, dtype=object).unique()

def gpy_gliner(df_pivot):
    """Embed and cluster the pivoted label table with graphistry.

    The ``URL`` and ``Title`` columns are dropped, the remaining columns are
    loaded as graphistry nodes, and ``umap()`` followed by ``dbscan()`` is
    run with default settings.

    Args:
        df_pivot: DataFrame containing ``URL`` and ``Title`` columns plus the
            feature columns to embed.

    Returns:
        A tuple ``(top_columns, g2, df2)`` where ``df2`` is the matrix from
        ``g2.get_matrix()`` as a DataFrame, ``g2`` is the graphistry object
        after UMAP and DBSCAN, and ``top_columns`` holds up to ten column
        labels that are most often the per-row maximum of ``df2``, most
        frequent first.
    """
    features = df_pivot.drop(['URL', 'Title'], axis=1)
    g2 = graphistry.nodes(features).umap().dbscan()

    df2 = pd.DataFrame(g2.get_matrix())

    # For every row take the column holding its largest value, then rank the
    # columns by how many rows they win.
    winners = df2.idxmax(axis=1)
    ranked_columns = winners.value_counts().index
    return ranked_columns[:10], g2, df2
Loading

0 comments on commit 7080061

Please sign in to comment.