diff --git a/convokit/coordination/coordination.py b/convokit/coordination/coordination.py
index 1d524479..a181dc35 100644
--- a/convokit/coordination/coordination.py
+++ b/convokit/coordination/coordination.py
@@ -486,9 +486,10 @@ def _scores_over_utterances(
                 target = utt1.speaker
                 if speaker == target:
                     continue
-                speaker, target = Coordination._annot_speaker(
-                    speaker, utt2, split_by_attribs
-                ), Coordination._annot_speaker(target, utt1, split_by_attribs)
+                speaker, target = (
+                    Coordination._annot_speaker(speaker, utt2, split_by_attribs),
+                    Coordination._annot_speaker(target, utt1, split_by_attribs),
+                )

                 speaker_filter = speaker_utterance_selector(utt2, utt1)
                 target_filter = target_utterance_selector(utt2, utt1)
diff --git a/convokit/expected_context_framework/demos/parliament_demo.ipynb b/convokit/expected_context_framework/demos/parliament_demo.ipynb
index 7c674b0a..b437f903 100644
--- a/convokit/expected_context_framework/demos/parliament_demo.ipynb
+++ b/convokit/expected_context_framework/demos/parliament_demo.ipynb
@@ -18,7 +18,8 @@
    "outputs": [],
    "source": [
     "import warnings\n",
-    "warnings.filterwarnings('ignore')"
+    "\n",
+    "warnings.filterwarnings(\"ignore\")"
    ]
   },
   {
@@ -56,14 +57,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# OPTION 1: DOWNLOAD CORPUS \n",
+    "# OPTION 1: DOWNLOAD CORPUS\n",
     "# UNCOMMENT THESE LINES TO DOWNLOAD CORPUS\n",
     "# DATA_DIR = ''\n",
     "# PARL_CORPUS_PATH = download('parliament-corpus', data_dir=DATA_DIR)\n",
     "\n",
     "# OPTION 2: READ PREVIOUSLY-DOWNLOADED CORPUS FROM DISK\n",
     "# UNCOMMENT THIS LINE AND REPLACE WITH THE DIRECTORY WHERE THE TENNIS-CORPUS IS LOCATED\n",
-    "PARL_CORPUS_PATH = ''"
+    "PARL_CORPUS_PATH = \"\""
    ]
   },
   {
@@ -107,7 +108,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "parl_corpus.load_info('utterance',['arcs','q_arcs'])"
+    "parl_corpus.load_info(\"utterance\", [\"arcs\", \"q_arcs\"])"
    ]
   },
   {
@@ -123,7 +124,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from convokit.expected_context_framework import ColNormedTfidfTransformer, ExpectedContextModelTransformer"
+    "from convokit.expected_context_framework import (\n",
+    "    ColNormedTfidfTransformer,\n",
+    "    ExpectedContextModelTransformer,\n",
+    ")"
    ]
   },
   {
@@ -151,9 +155,12 @@
     }
    ],
    "source": [
-    "q_tfidf_obj = ColNormedTfidfTransformer(input_field='q_arcs', output_field='q_arc_tfidf',\n",
-    "                                        min_df=100, max_df=.1, binary=False)\n",
-    "q_tfidf_obj.fit(parl_corpus, selector=lambda x: x.meta['is_question'] and x.meta['pair_has_features'])\n",
+    "q_tfidf_obj = ColNormedTfidfTransformer(\n",
+    "    input_field=\"q_arcs\", output_field=\"q_arc_tfidf\", min_df=100, max_df=0.1, binary=False\n",
+    ")\n",
+    "q_tfidf_obj.fit(\n",
+    "    parl_corpus, selector=lambda x: x.meta[\"is_question\"] and x.meta[\"pair_has_features\"]\n",
+    ")\n",
     "print(len(q_tfidf_obj.get_vocabulary()))"
    ]
   },
   {
@@ -171,9 +178,10 @@
     }
    ],
    "source": [
-    "a_tfidf_obj = ColNormedTfidfTransformer(input_field='arcs', output_field='arc_tfidf',\n",
-    "                                        min_df=100, max_df=.1, binary=False)\n",
-    "a_tfidf_obj.fit(parl_corpus, selector=lambda x: x.meta['is_answer'] and x.meta['pair_has_features'])\n",
+    "a_tfidf_obj = ColNormedTfidfTransformer(\n",
+    "    input_field=\"arcs\", output_field=\"arc_tfidf\", min_df=100, max_df=0.1, binary=False\n",
+    ")\n",
+    "a_tfidf_obj.fit(parl_corpus, selector=lambda x: x.meta[\"is_answer\"] and x.meta[\"pair_has_features\"])\n",
     "print(len(a_tfidf_obj.get_vocabulary()))"
    ]
   },
   {
@@ -183,8 +191,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "_ = q_tfidf_obj.transform(parl_corpus, selector=lambda x: x.meta['is_question'] and x.meta['pair_has_features'])\n",
-    "_ = a_tfidf_obj.transform(parl_corpus, selector=lambda x: x.meta['is_answer'] and x.meta['pair_has_features'])"
+    "_ = q_tfidf_obj.transform(\n",
+    "    parl_corpus, selector=lambda x: x.meta[\"is_question\"] and x.meta[\"pair_has_features\"]\n",
+    ")\n",
+    "_ = a_tfidf_obj.transform(\n",
+    "    parl_corpus, selector=lambda x: x.meta[\"is_answer\"] and x.meta[\"pair_has_features\"]\n",
+    ")"
    ]
   },
   {
@@ -205,10 +217,15 @@
    "outputs": [],
    "source": [
     "q_ec = ExpectedContextModelTransformer(\n",
-    "    context_field='next_id', output_prefix='fw', \n",
-    "    vect_field='q_arc_tfidf', context_vect_field='arc_tfidf',\n",
-    "    n_svd_dims=25, n_clusters=8,\n",
-    "    random_state=1000, cluster_random_state=1000)"
+    "    context_field=\"next_id\",\n",
+    "    output_prefix=\"fw\",\n",
+    "    vect_field=\"q_arc_tfidf\",\n",
+    "    context_vect_field=\"arc_tfidf\",\n",
+    "    n_svd_dims=25,\n",
+    "    n_clusters=8,\n",
+    "    random_state=1000,\n",
+    "    cluster_random_state=1000,\n",
+    ")"
    ]
   },
   {
@@ -224,8 +241,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "q_ec.fit(parl_corpus, selector=lambda x: x.meta['is_question'] and (x.meta.get('q_arc_tfidf__n_feats',0)>0),\n",
-    "         context_selector=lambda x: x.meta['is_answer'] and (x.meta.get('arc_tfidf__n_feats',0)>0))"
+    "q_ec.fit(\n",
+    "    parl_corpus,\n",
+    "    selector=lambda x: x.meta[\"is_question\"] and (x.meta.get(\"q_arc_tfidf__n_feats\", 0) > 0),\n",
+    "    context_selector=lambda x: x.meta[\"is_answer\"] and (x.meta.get(\"arc_tfidf__n_feats\", 0) > 0),\n",
+    ")"
    ]
   },
   {
@@ -729,8 +749,18 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "q_ec.set_cluster_names(['shared_concern', 'request_assurance', 'prompt_comment', 'agreement', \n",
-    "                        'issue_update', 'question_premises', 'accept_propose', 'demand_account'])"
+    "q_ec.set_cluster_names(\n",
+    "    [\n",
+    "        \"shared_concern\",\n",
+    "        \"request_assurance\",\n",
+    "        \"prompt_comment\",\n",
+    "        \"agreement\",\n",
+    "        \"issue_update\",\n",
+    "        \"question_premises\",\n",
+    "        \"accept_propose\",\n",
+    "        \"demand_account\",\n",
+    "    ]\n",
+    ")"
    ]
   },
   {
@@ -860,10 +890,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "_ = q_ec.transform(parl_corpus, selector=lambda x: x.meta['is_question'] and (x.meta.get('q_arc_tfidf__n_feats',0)>0))\n",
+    "_ = q_ec.transform(\n",
+    "    parl_corpus,\n",
+    "    selector=lambda x: x.meta[\"is_question\"] and (x.meta.get(\"q_arc_tfidf__n_feats\", 0) > 0),\n",
+    ")\n",
     "\n",
     "# this call derives representations of answers, following our method.\n",
-    "_ = q_ec.transform_context_utts(parl_corpus, selector=lambda x: x.meta['is_answer'] and (x.meta.get('arc_tfidf__n_feats',0)>0)) "
+    "_ = q_ec.transform_context_utts(\n",
+    "    parl_corpus,\n",
+    "    selector=lambda x: x.meta[\"is_answer\"] and (x.meta.get(\"arc_tfidf__n_feats\", 0) > 0),\n",
+    ")"
    ]
   },
   {
@@ -887,7 +923,7 @@
     }
    ],
    "source": [
-    "ut_eg_id = '2010-03-25c.364.5'\n",
+    "ut_eg_id = \"2010-03-25c.364.5\"\n",
     "eg_ut = parl_corpus.get_utterance(ut_eg_id)\n",
     "print(eg_ut.text)"
    ]
   },
   {
@@ -916,7 +952,7 @@
     }
    ],
    "source": [
-    "eg_ut.meta['fw_clustering.cluster'], eg_ut.meta['fw_clustering.cluster_dist']"
+    "eg_ut.meta[\"fw_clustering.cluster\"], eg_ut.meta[\"fw_clustering.cluster_dist\"]"
    ]
   },
   {
@@ -947,7 +983,7 @@
     }
    ],
    "source": [
-    "parl_corpus.get_vectors('fw_repr',[ut_eg_id])\n",
+    "parl_corpus.get_vectors(\"fw_repr\", [ut_eg_id])\n",
     "# technical note: for an explanation of why there are only 24 dimensions, instead of 25, see the `snip_first_dim` parameter in the documentation"
    ]
   },
   {
@@ -975,7 +1011,7 @@
     }
    ],
    "source": [
-    "eg_ut.meta['fw_range']"
+    "eg_ut.meta[\"fw_range\"]"
    ]
   },
   {
@@ -996,10 +1032,23 @@
    "outputs": [],
    "source": [
     "for ut in parl_corpus.iter_utterances():\n",
-    "    ut.meta['speaker'] = ut.speaker.id\n",
-    "utt_meta_df = parl_corpus.get_attribute_table('utterance',\n",
-    "        ['fw_clustering.cluster','govt', 'govt_coarse','is_question','is_answer',\n",
-    "         'is_incumbent','is_oppn','speaker','party', 'tenure','next_id'])"
+    "    ut.meta[\"speaker\"] = ut.speaker.id\n",
+    "utt_meta_df = parl_corpus.get_attribute_table(\n",
+    "    \"utterance\",\n",
+    "    [\n",
+    "        \"fw_clustering.cluster\",\n",
+    "        \"govt\",\n",
+    "        \"govt_coarse\",\n",
+    "        \"is_question\",\n",
+    "        \"is_answer\",\n",
+    "        \"is_incumbent\",\n",
+    "        \"is_oppn\",\n",
+    "        \"speaker\",\n",
+    "        \"party\",\n",
+    "        \"tenure\",\n",
+    "        \"next_id\",\n",
+    "    ],\n",
+    ")"
    ]
   },
   {
@@ -1015,10 +1064,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "utt_meta_sub = utt_meta_df[((utt_meta_df.is_incumbent == True) | (utt_meta_df.is_oppn == True))\n",
-    "        & (utt_meta_df.speaker != '') & (utt_meta_df.party.notnull())\n",
-    "        & (utt_meta_df.govt_coarse != 'thatcher+major')\n",
-    "        & (utt_meta_df['fw_clustering.cluster'].notnull())].copy()"
+    "utt_meta_sub = utt_meta_df[\n",
+    "    ((utt_meta_df.is_incumbent == True) | (utt_meta_df.is_oppn == True))\n",
+    "    & (utt_meta_df.speaker != \"\")\n",
+    "    & (utt_meta_df.party.notnull())\n",
+    "    & (utt_meta_df.govt_coarse != \"thatcher+major\")\n",
+    "    & (utt_meta_df[\"fw_clustering.cluster\"].notnull())\n",
+    "].copy()"
    ]
   },
   {
@@ -1045,7 +1097,9 @@
     "        val_false = sum((col == val) & ~bool_col)\n",
     "        nval_true = sum((col != val) & bool_col)\n",
     "        nval_false = sum((col != val) & ~bool_col)\n",
-    "        log_odds_entries.append({'val': val, 'log_odds': np.log((val_true/val_false)/(nval_true/nval_false))})\n",
+    "        log_odds_entries.append(\n",
+    "            {\"val\": val, \"log_odds\": np.log((val_true / val_false) / (nval_true / nval_false))}\n",
+    "        )\n",
     "    return log_odds_entries"
    ]
   },
   {
@@ -1055,12 +1109,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "log_odds_party = [] \n",
+    "log_odds_party = []\n",
     "for cname in q_ec.get_cluster_names():\n",
-    "    entry = compute_log_odds(utt_meta_sub['fw_clustering.cluster'],utt_meta_sub['is_incumbent'],\n",
-    "                             val_subset=[cname])\n",
+    "    entry = compute_log_odds(\n",
+    "        utt_meta_sub[\"fw_clustering.cluster\"], utt_meta_sub[\"is_incumbent\"], val_subset=[cname]\n",
+    "    )\n",
     "    log_odds_party += entry\n",
-    "log_odds_party_df = pd.DataFrame(log_odds_party).set_index('val')"
+    "log_odds_party_df = pd.DataFrame(log_odds_party).set_index(\"val\")"
    ]
   },
   {
@@ -1069,7 +1124,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "type_order = log_odds_party_df.sort_values('log_odds').index\n"
+    "type_order = log_odds_party_df.sort_values(\"log_odds\").index"
    ]
   },
   {
@@ -1078,9 +1133,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "display_names = ['Demand for account', 'Questioning premises', 'Prompt for comment',\n",
-    "                 'Accept and propose', 'Req. for assurance', 'Issue update', \n",
-    "                 'Shared concerns', 'Agreement']"
+    "display_names = [\n",
+    "    \"Demand for account\",\n",
+    "    \"Questioning premises\",\n",
+    "    \"Prompt for comment\",\n",
+    "    \"Accept and propose\",\n",
+    "    \"Req. for assurance\",\n",
+    "    \"Issue update\",\n",
+    "    \"Shared concerns\",\n",
+    "    \"Agreement\",\n",
+    "]"
    ]
   },
   {
@@ -1090,6 +1152,7 @@
    "outputs": [],
    "source": [
     "from matplotlib import pyplot as plt\n",
+    "\n",
     "%matplotlib inline"
    ]
   },
   {
@@ -1112,21 +1175,21 @@
     }
    ],
    "source": [
-    "fig, ax = plt.subplots(figsize=(4,6))\n",
+    "fig, ax = plt.subplots(figsize=(4, 6))\n",
     "ax.set_xlim(-1.5, 1.5)\n",
-    "ax.set_ylim(-.5,7.5)\n",
-    "for i,cname in enumerate(type_order):\n",
+    "ax.set_ylim(-0.5, 7.5)\n",
+    "for i, cname in enumerate(type_order):\n",
     "    log_odds = log_odds_party_df.loc[cname].log_odds\n",
-    "    ax.scatter([log_odds], [i], color='black',s=49)\n",
-    "    ax.plot([-1.25,1.25],[i,i],'--', color='grey', linewidth=.5)\n",
-    "ax.plot([0,0],[-2,8], color='grey', linewidth=1)\n",
+    "    ax.scatter([log_odds], [i], color=\"black\", s=49)\n",
+    "    ax.plot([-1.25, 1.25], [i, i], \"--\", color=\"grey\", linewidth=0.5)\n",
+    "ax.plot([0, 0], [-2, 8], color=\"grey\", linewidth=1)\n",
     "ax.invert_yaxis()\n",
     "ax.set_yticks(np.arange(len(type_order)))\n",
     "ax.set_yticklabels(display_names, fontsize=14)\n",
-    "ax.set_xticklabels([-1.5,-1,-.5,0,.5,1,1.5], fontsize=14)\n",
-    "plt.rc('xtick',labelsize=12)\n",
-    "plt.rc('ytick',labelsize=12)\n",
-    "ax.set_xlabel('log odds ratio', fontsize=16)\n",
+    "ax.set_xticklabels([-1.5, -1, -0.5, 0, 0.5, 1, 1.5], fontsize=14)\n",
+    "plt.rc(\"xtick\", labelsize=12)\n",
+    "plt.rc(\"ytick\", labelsize=12)\n",
+    "ax.set_xlabel(\"log odds ratio\", fontsize=16)\n",
     "None"
    ]
   },
   {
@@ -1157,8 +1220,19 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "med_tenures = pd.concat([utt_meta_sub[utt_meta_sub.is_incumbent].groupby('fw_clustering.cluster').tenure.median().rename('govt'),\n",
-    "utt_meta_sub[~utt_meta_sub.is_incumbent].groupby('fw_clustering.cluster').tenure.median().rename('oppn')], axis=1)\n",
+    "med_tenures = pd.concat(\n",
+    "    [\n",
+    "        utt_meta_sub[utt_meta_sub.is_incumbent]\n",
+    "        .groupby(\"fw_clustering.cluster\")\n",
+    "        .tenure.median()\n",
+    "        .rename(\"govt\"),\n",
+    "        utt_meta_sub[~utt_meta_sub.is_incumbent]\n",
+    "        .groupby(\"fw_clustering.cluster\")\n",
+    "        .tenure.median()\n",
+    "        .rename(\"oppn\"),\n",
+    "    ],\n",
+    "    axis=1,\n",
+    ")\n",
     "med_in_tenure = utt_meta_sub[utt_meta_sub.is_incumbent].tenure.median()\n",
     "med_op_tenure = utt_meta_sub[~utt_meta_sub.is_incumbent].tenure.median()"
    ]
@@ -1284,23 +1358,25 @@
     }
    ],
    "source": [
-    "fig, ax = plt.subplots(figsize=(4,6))\n",
+    "fig, ax = plt.subplots(figsize=(4, 6))\n",
     "ax.set_xlim(2, 13)\n",
-    "ax.set_ylim(-.5,7.5)\n",
-    "for i,cname in enumerate(type_order):\n",
-    "    ax.scatter([med_tenures.loc[cname].govt],[i-.05], s=49, color='blue')\n",
-    "    ax.scatter([med_tenures.loc[cname].oppn],[i+.05], s=49, color='red', facecolor='white',marker='s')\n",
-    "    ax.plot([.5,14.5],[i,i],'--', color='grey', linewidth=.5)\n",
-    "ax.plot([med_in_tenure, med_in_tenure],[-2,8], color='blue',linewidth=1)\n",
-    "ax.plot([med_op_tenure, med_op_tenure],[-2,8], '--', color='red', linewidth=1)\n",
+    "ax.set_ylim(-0.5, 7.5)\n",
+    "for i, cname in enumerate(type_order):\n",
+    "    ax.scatter([med_tenures.loc[cname].govt], [i - 0.05], s=49, color=\"blue\")\n",
+    "    ax.scatter(\n",
+    "        [med_tenures.loc[cname].oppn], [i + 0.05], s=49, color=\"red\", facecolor=\"white\", marker=\"s\"\n",
+    "    )\n",
+    "    ax.plot([0.5, 14.5], [i, i], \"--\", color=\"grey\", linewidth=0.5)\n",
+    "ax.plot([med_in_tenure, med_in_tenure], [-2, 8], color=\"blue\", linewidth=1)\n",
+    "ax.plot([med_op_tenure, med_op_tenure], [-2, 8], \"--\", color=\"red\", linewidth=1)\n",
     "ax.invert_yaxis()\n",
-    "ax.set_xticks([5,10])\n",
-    "ax.set_xticklabels([5,10], fontsize=14)\n",
+    "ax.set_xticks([5, 10])\n",
+    "ax.set_xticklabels([5, 10], fontsize=14)\n",
     "ax.set_yticks(np.arange(8))\n",
     "ax.set_yticklabels(display_names, fontsize=14)\n",
-    "ax.set_xlabel('median tenure', fontsize=16)\n",
-    "plt.rc('xtick',labelsize=12)\n",
-    "plt.rc('ytick',labelsize=12)"
+    "ax.set_xlabel(\"median tenure\", fontsize=16)\n",
+    "plt.rc(\"xtick\", labelsize=12)\n",
+    "plt.rc(\"ytick\", labelsize=12)"
    ]
   },
   {
@@ -1340,10 +1416,15 @@
    "outputs": [],
    "source": [
     "a_ec = ExpectedContextModelTransformer(\n",
-    "    context_field='reply_to', output_prefix='bk', \n",
-    "    vect_field='arc_tfidf', context_vect_field='q_arc_tfidf',\n",
-    "    n_svd_dims=15, n_clusters=5,\n",
-    "    random_state=1000, cluster_random_state=1000)"
+    "    context_field=\"reply_to\",\n",
+    "    output_prefix=\"bk\",\n",
+    "    vect_field=\"arc_tfidf\",\n",
+    "    context_vect_field=\"q_arc_tfidf\",\n",
+    "    n_svd_dims=15,\n",
+    "    n_clusters=5,\n",
+    "    random_state=1000,\n",
+    "    cluster_random_state=1000,\n",
+    ")"
    ]
   },
   {
@@ -1352,8 +1433,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "a_ec.fit(parl_corpus, selector=lambda x: x.meta['is_answer'] and (x.meta.get('arc_tfidf__n_feats',0)>0),\n",
-    "         context_selector=lambda x: x.meta['is_question'] and (x.meta.get('q_arc_tfidf__n_feats',0)>0))"
+    "a_ec.fit(\n",
+    "    parl_corpus,\n",
+    "    selector=lambda x: x.meta[\"is_answer\"] and (x.meta.get(\"arc_tfidf__n_feats\", 0) > 0),\n",
+    "    context_selector=lambda x: x.meta[\"is_question\"]\n",
+    "    and (x.meta.get(\"q_arc_tfidf__n_feats\", 0) > 0),\n",
+    ")"
    ]
   },
   {
@@ -1670,7 +1755,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "a_ec.set_cluster_names(['progress_report', 'statement', 'endorsement', 'comment', 'commitment'])"
+    "a_ec.set_cluster_names([\"progress_report\", \"statement\", \"endorsement\", \"comment\", \"commitment\"])"
    ]
   },
   {
@@ -1776,7 +1861,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "_ = a_ec.transform(parl_corpus, selector=lambda x: x.meta['is_answer'] and (x.meta.get('arc_tfidf__n_feats',0)>0))"
+    "_ = a_ec.transform(\n",
+    "    parl_corpus,\n",
+    "    selector=lambda x: x.meta[\"is_answer\"] and (x.meta.get(\"arc_tfidf__n_feats\", 0) > 0),\n",
+    ")"
    ]
   },
   {
@@ -1787,12 +1875,17 @@
    },
    "outputs": [],
    "source": [
-    "a_utt_meta_df = parl_corpus.get_attribute_table('utterance',\n",
-    "        ['bk_clustering.cluster'])\n",
-    "a_utt_meta_sub = a_utt_meta_df.join(utt_meta_df[((utt_meta_df.is_incumbent == True) | (utt_meta_df.is_oppn == True))\n",
-    "        & (utt_meta_df.speaker != '') & (utt_meta_df.party.notnull())\n",
-    "        & (utt_meta_df.govt_coarse != 'thatcher+major')].set_index('next_id'), how='inner')\n",
-    "a_utt_meta_sub = a_utt_meta_sub[a_utt_meta_sub['bk_clustering.cluster'].notnull()]"
+    "a_utt_meta_df = parl_corpus.get_attribute_table(\"utterance\", [\"bk_clustering.cluster\"])\n",
+    "a_utt_meta_sub = a_utt_meta_df.join(\n",
+    "    utt_meta_df[\n",
+    "        ((utt_meta_df.is_incumbent == True) | (utt_meta_df.is_oppn == True))\n",
+    "        & (utt_meta_df.speaker != \"\")\n",
+    "        & (utt_meta_df.party.notnull())\n",
+    "        & (utt_meta_df.govt_coarse != \"thatcher+major\")\n",
+    "    ].set_index(\"next_id\"),\n",
+    "    how=\"inner\",\n",
+    ")\n",
+    "a_utt_meta_sub = a_utt_meta_sub[a_utt_meta_sub[\"bk_clustering.cluster\"].notnull()]"
    ]
   },
   {
@@ -1821,12 +1914,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "log_odds_party_answer = [] \n",
+    "log_odds_party_answer = []\n",
     "for cname in a_ec.get_cluster_names():\n",
-    "    entry = compute_log_odds(a_utt_meta_sub['bk_clustering.cluster'],a_utt_meta_sub['is_incumbent'],\n",
-    "                             val_subset=[cname])\n",
+    "    entry = compute_log_odds(\n",
+    "        a_utt_meta_sub[\"bk_clustering.cluster\"], a_utt_meta_sub[\"is_incumbent\"], val_subset=[cname]\n",
+    "    )\n",
     "    log_odds_party_answer += entry\n",
-    "log_odds_party_answer_df = pd.DataFrame(log_odds_party_answer).set_index('val')"
+    "log_odds_party_answer_df = pd.DataFrame(log_odds_party_answer).set_index(\"val\")"
    ]
   },
   {
@@ -1835,7 +1929,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "a_type_order = log_odds_party_answer_df.sort_values('log_odds').index"
+    "a_type_order = log_odds_party_answer_df.sort_values(\"log_odds\").index"
    ]
   },
   {
@@ -1844,7 +1938,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "a_display_names = ['Statement', 'Comment', 'Progress report', 'Commitment', 'Endorsement']"
+    "a_display_names = [\"Statement\", \"Comment\", \"Progress report\", \"Commitment\", \"Endorsement\"]"
    ]
   },
   {
@@ -1866,20 +1960,20 @@
     }
    ],
    "source": [
-    "fig, ax = plt.subplots(figsize=(3,4))\n",
+    "fig, ax = plt.subplots(figsize=(3, 4))\n",
     "ax.set_xlim(-1.5, 1.5)\n",
-    "ax.set_ylim(-.5,4.5)\n",
-    "for i,cname in enumerate(a_type_order):\n",
+    "ax.set_ylim(-0.5, 4.5)\n",
+    "for i, cname in enumerate(a_type_order):\n",
     "    log_odds = log_odds_party_answer_df.loc[cname].log_odds\n",
-    "    ax.scatter([log_odds], [i], color='black',s=49) \n",
-    "    ax.plot([-1.25,1.25],[i,i],'--', color='grey', linewidth=.5)\n",
-    "ax.plot([0,0],[-2,5], color='grey', linewidth=1)\n",
+    "    ax.scatter([log_odds], [i], color=\"black\", s=49)\n",
+    "    ax.plot([-1.25, 1.25], [i, i], \"--\", color=\"grey\", linewidth=0.5)\n",
+    "ax.plot([0, 0], [-2, 5], color=\"grey\", linewidth=1)\n",
     "ax.invert_yaxis()\n",
     "ax.set_yticks(np.arange(len(a_type_order)))\n",
     "ax.set_yticklabels(a_display_names, fontsize=14)\n",
-    "ax.set_xlabel('log odds ratio', fontsize=16)\n",
-    "ax.set_xticks([-1,0,1])\n",
-    "ax.set_xticklabels([-1,0,1], fontsize=14)\n",
+    "ax.set_xlabel(\"log odds ratio\", fontsize=16)\n",
+    "ax.set_xticks([-1, 0, 1])\n",
+    "ax.set_xticklabels([-1, 0, 1], fontsize=14)\n",
     "None"
    ]
   },
   {
@@ -1905,8 +1999,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "utt_range_df = parl_corpus.get_attribute_table('utterance',\n",
-    "        ['fw_clustering.cluster','fw_range'])\n",
+    "utt_range_df = parl_corpus.get_attribute_table(\"utterance\", [\"fw_clustering.cluster\", \"fw_range\"])\n",
     "utt_range_df = utt_range_df[utt_range_df.fw_range.notnull()].copy()"
    ]
   },
   {
@@ -1916,7 +2009,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "fw_range_distrs = utt_range_df.groupby('fw_clustering.cluster').fw_range.describe().sort_values('50%')\n",
+    "fw_range_distrs = (\n",
+    "    utt_range_df.groupby(\"fw_clustering.cluster\").fw_range.describe().sort_values(\"50%\")\n",
+    ")\n",
     "med_range_full = utt_range_df.fw_range.median()"
    ]
   },
   {
@@ -1939,23 +2034,23 @@
     }
    ],
    "source": [
-    "fig, ax = plt.subplots(figsize=(4,6))\n",
-    "ax.set_xlim(.7, .9)\n",
-    "ax.set_ylim(-.5,7.5)\n",
-    "for i,cname in enumerate(type_order):\n",
-    "    med_range = fw_range_distrs.loc[cname]['50%']\n",
-    "    left = fw_range_distrs.loc[cname]['25%']\n",
-    "    right = fw_range_distrs.loc[cname]['75%']\n",
-    "    ax.scatter([med_range], [i], color='black',s=49)\n",
-    "    ax.plot([left,right],[i,i], color='black')\n",
-    "    ax.plot([-1.25,1.25],[i,i],'--', color='grey', linewidth=.5)\n",
-    "ax.plot([med_range_full,med_range_full],[-2,8], '--', color='grey', linewidth=1)\n",
+    "fig, ax = plt.subplots(figsize=(4, 6))\n",
+    "ax.set_xlim(0.7, 0.9)\n",
+    "ax.set_ylim(-0.5, 7.5)\n",
+    "for i, cname in enumerate(type_order):\n",
+    "    med_range = fw_range_distrs.loc[cname][\"50%\"]\n",
+    "    left = fw_range_distrs.loc[cname][\"25%\"]\n",
+    "    right = fw_range_distrs.loc[cname][\"75%\"]\n",
+    "    ax.scatter([med_range], [i], color=\"black\", s=49)\n",
+    "    ax.plot([left, right], [i, i], color=\"black\")\n",
+    "    ax.plot([-1.25, 1.25], [i, i], \"--\", color=\"grey\", linewidth=0.5)\n",
+    "ax.plot([med_range_full, med_range_full], [-2, 8], \"--\", color=\"grey\", linewidth=1)\n",
     "ax.invert_yaxis()\n",
     "ax.set_yticks(np.arange(len(type_order)))\n",
     "ax.set_yticklabels(display_names, fontsize=14)\n",
-    "ax.set_xlabel('$\\overrightarrow{\\Sigma}$', fontsize=16)\n",
-    "plt.rc('xtick',labelsize=14)\n",
-    "plt.rc('ytick',labelsize=14)\n"
+    "ax.set_xlabel(\"$\\overrightarrow{\\Sigma}$\", fontsize=16)\n",
+    "plt.rc(\"xtick\", labelsize=14)\n",
+    "plt.rc(\"ytick\", labelsize=14)"
    ]
   },
   {
@@ -2025,13 +2120,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "q_pipe = ExpectedContextModelPipeline(context_field='next_id', output_prefix='fw',\n",
-    "        text_field='q_arcs', context_text_field='arcs', share_tfidf_models=False,\n",
-    "        text_pipe=parliament_arc_pipeline(), \n",
-    "        tfidf_params={'binary': False, 'min_df': 100, 'max_df': .1}, \n",
-    "        min_terms=1,\n",
-    "        n_svd_dims=25, n_clusters=8, cluster_on='utts',\n",
-    "        random_state=1000, cluster_random_state=1000)"
+    "q_pipe = ExpectedContextModelPipeline(\n",
+    "    context_field=\"next_id\",\n",
+    "    output_prefix=\"fw\",\n",
+    "    text_field=\"q_arcs\",\n",
+    "    context_text_field=\"arcs\",\n",
+    "    share_tfidf_models=False,\n",
+    "    text_pipe=parliament_arc_pipeline(),\n",
+    "    tfidf_params={\"binary\": False, \"min_df\": 100, \"max_df\": 0.1},\n",
+    "    min_terms=1,\n",
+    "    n_svd_dims=25,\n",
+    "    n_clusters=8,\n",
+    "    cluster_on=\"utts\",\n",
+    "    random_state=1000,\n",
+    "    cluster_random_state=1000,\n",
+    ")"
    ]
   },
   {
@@ -2040,9 +2143,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "q_pipe.fit(parl_corpus,\n",
-    "        selector=lambda x: x.meta['is_question'] and x.meta['pair_has_features'],\n",
-    "        context_selector=lambda x: x.meta['is_answer'] and x.meta['pair_has_features'])"
+    "q_pipe.fit(\n",
+    "    parl_corpus,\n",
+    "    selector=lambda x: x.meta[\"is_question\"] and x.meta[\"pair_has_features\"],\n",
+    "    context_selector=lambda x: x.meta[\"is_answer\"] and x.meta[\"pair_has_features\"],\n",
+    ")"
    ]
   },
   {
@@ -2321,8 +2426,18 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "q_pipe.set_cluster_names(['shared_concern', 'request_assurance', 'prompt_comment', 'agreement', \n",
-    "                          'issue_update', 'question_premises', 'accept_propose', 'demand_account'])"
+    "q_pipe.set_cluster_names(\n",
+    "    [\n",
+    "        \"shared_concern\",\n",
+    "        \"request_assurance\",\n",
+    "        \"prompt_comment\",\n",
+    "        \"agreement\",\n",
+    "        \"issue_update\",\n",
+    "        \"question_premises\",\n",
+    "        \"accept_propose\",\n",
+    "        \"demand_account\",\n",
+    "    ]\n",
+    ")"
    ]
   },
   {
@@ -2339,7 +2454,8 @@
    "outputs": [],
    "source": [
     "new_ut = q_pipe.transform_utterance(\n",
-    "    'Can the Minister please explain why the reopening was delayed?')"
+    "    \"Can the Minister please explain why the reopening was delayed?\"\n",
+    ")"
    ]
   },
   {
@@ -2356,7 +2472,7 @@
     }
    ],
    "source": [
-    "print('question type:', new_ut.meta['fw_clustering.cluster'])"
+    "print(\"question type:\", new_ut.meta[\"fw_clustering.cluster\"])"
    ]
   },
   {
@@ -2408,7 +2524,7 @@
    "source": [
     "# note that different versions of SpaCy may produce different outputs, since the\n",
     "# dependency parses may change from version to version\n",
-    "new_ut.meta['fw_repr']"
+    "new_ut.meta[\"fw_repr\"]"
    ]
   }
  ],
diff --git a/convokit/expected_context_framework/demos/scotus_orientation_demo.ipynb b/convokit/expected_context_framework/demos/scotus_orientation_demo.ipynb
index 0de52a3b..cb047f4c 100644
--- a/convokit/expected_context_framework/demos/scotus_orientation_demo.ipynb
+++ b/convokit/expected_context_framework/demos/scotus_orientation_demo.ipynb
@@ -21,7 +21,8 @@
    "outputs": [],
    "source": [
     "import warnings\n",
-    "warnings.filterwarnings('ignore')"
+    "\n",
+    "warnings.filterwarnings(\"ignore\")"
    ]
   },
   {
@@ -74,7 +75,7 @@
    },
    "outputs": [],
    "source": [
-    "# OPTION 1: DOWNLOAD CORPUS \n",
+    "# OPTION 1: DOWNLOAD CORPUS\n",
     "# UNCOMMENT THESE LINES TO DOWNLOAD CORPUS\n",
     "# DATA_DIR = ''\n",
     "# SCOTUS_CORPUS_PATH = download('supreme-corpus', data_dir=DATA_DIR)\n",
@@ -129,7 +130,7 @@
    },
    "outputs": [],
    "source": [
-    "scotus_corpus.load_info('utterance',['arcs','tokens'])"
+    "scotus_corpus.load_info(\"utterance\", [\"arcs\", \"tokens\"])"
    ]
   },
   {
@@ -148,8 +149,10 @@
    "outputs": [],
    "source": [
     "from convokit.text_processing import TextProcessor\n",
-    "wordcounter = TextProcessor(input_field='tokens', output_field='wordcount',\n",
-    "                            proc_fn=lambda x: len(x.split()))\n",
+    "\n",
+    "wordcounter = TextProcessor(\n",
+    "    input_field=\"tokens\", output_field=\"wordcount\", proc_fn=lambda x: len(x.split())\n",
+    ")\n",
     "scotus_corpus = wordcounter.transform(scotus_corpus)"
    ]
   },
   {
@@ -169,7 +172,7 @@
    "outputs": [],
    "source": [
     "for ut in scotus_corpus.iter_utterances(selector=lambda x: x.reply_to is not None):\n",
-    "    scotus_corpus.get_utterance(ut.reply_to).meta['next_id'] = ut.id"
+    "    scotus_corpus.get_utterance(ut.reply_to).meta[\"next_id\"] = ut.id"
    ]
   },
   {
@@ -217,20 +220,24 @@
    "outputs": [],
    "source": [
     "for ut in scotus_corpus.iter_utterances():\n",
-    "    ut.meta['is_valid_context'] = (ut.meta['speaker_type'] == 'A')\\\n",
-    "        and (ut.meta['arcs'] != '')\\\n",
-    "        and (ut.meta['wordcount'] >= min_wc_context)\\\n",
-    "        and (ut.meta['wordcount'] <= max_wc_context) \n",
+    "    ut.meta[\"is_valid_context\"] = (\n",
+    "        (ut.meta[\"speaker_type\"] == \"A\")\n",
+    "        and (ut.meta[\"arcs\"] != \"\")\n",
+    "        and (ut.meta[\"wordcount\"] >= min_wc_context)\n",
+    "        and (ut.meta[\"wordcount\"] <= max_wc_context)\n",
+    "    )\n",
     "for ut in scotus_corpus.iter_utterances():\n",
-    "    if ('next_id' not in ut.meta) or (ut.reply_to is None): \n",
-    "        ut.meta['is_valid_utt'] = False\n",
+    "    if (\"next_id\" not in ut.meta) or (ut.reply_to is None):\n",
+    "        ut.meta[\"is_valid_utt\"] = False\n",
     "    else:\n",
-    "        ut.meta['is_valid_utt'] = (ut.meta['speaker_type'] == 'J')\\\n",
-    "            and (ut.meta['arcs'] != '')\\\n",
-    "            and (ut.meta['wordcount'] >= min_wc)\\\n",
-    "            and (ut.meta['wordcount'] <= max_wc)\\\n",
-    "            and scotus_corpus.get_utterance(ut.meta['next_id']).meta['is_valid_context']\\\n",
-    "            and scotus_corpus.get_utterance(ut.reply_to).meta['is_valid_context']"
+    "        ut.meta[\"is_valid_utt\"] = (\n",
+    "            (ut.meta[\"speaker_type\"] == \"J\")\n",
+    "            and (ut.meta[\"arcs\"] != \"\")\n",
+    "            and (ut.meta[\"wordcount\"] >= min_wc)\n",
+    "            and (ut.meta[\"wordcount\"] <= max_wc)\n",
+    "            and scotus_corpus.get_utterance(ut.meta[\"next_id\"]).meta[\"is_valid_context\"]\n",
+    "            and scotus_corpus.get_utterance(ut.reply_to).meta[\"is_valid_context\"]\n",
+    "        )"
    ]
   },
   {
@@ -257,7 +264,7 @@
     }
    ],
    "source": [
-    "sum(ut.meta['is_valid_utt'] for ut in scotus_corpus.iter_utterances())"
+    "sum(ut.meta[\"is_valid_utt\"] for ut in scotus_corpus.iter_utterances())"
    ]
   },
   {
@@ -277,7 +284,7 @@
     }
    ],
    "source": [
-    "sum(ut.meta['is_valid_context'] for ut in scotus_corpus.iter_utterances())"
+    "sum(ut.meta[\"is_valid_context\"] for ut in scotus_corpus.iter_utterances())"
    ]
   },
   {
@@ -324,10 +331,16 @@
    },
    "outputs": [],
    "source": [
-    "j_tfidf_obj = ColNormedTfidfTransformer(input_field='arcs', output_field='j_tfidf', binary=True, \n",
-    "                                        min_df=250, max_df=1., max_features=2000)\n",
-    "_ = j_tfidf_obj.fit(scotus_corpus, selector=lambda x: x.meta['is_valid_utt'])\n",
-    "_ = j_tfidf_obj.transform(scotus_corpus, selector=lambda x: x.meta['is_valid_utt'])"
+    "j_tfidf_obj = ColNormedTfidfTransformer(\n",
+    "    input_field=\"arcs\",\n",
+    "    output_field=\"j_tfidf\",\n",
+    "    binary=True,\n",
+    "    min_df=250,\n",
+    "    max_df=1.0,\n",
+    "    max_features=2000,\n",
+    ")\n",
+    "_ = j_tfidf_obj.fit(scotus_corpus, selector=lambda x: x.meta[\"is_valid_utt\"])\n",
+    "_ = j_tfidf_obj.transform(scotus_corpus, selector=lambda x: x.meta[\"is_valid_utt\"])"
    ]
   },
   {
@@ -338,10 +351,16 @@
    },
    "outputs": [],
    "source": [
-    "a_tfidf_obj = ColNormedTfidfTransformer(input_field='arcs', output_field='a_tfidf', binary=True, \n",
-    "                                        min_df=250, max_df=1., max_features=2000)\n",
-    "_ = a_tfidf_obj.fit(scotus_corpus, selector=lambda x: x.meta['is_valid_context'])\n",
-    "_ = a_tfidf_obj.transform(scotus_corpus, selector=lambda x: x.meta['is_valid_context'])"
+    "a_tfidf_obj = ColNormedTfidfTransformer(\n",
+    "    input_field=\"arcs\",\n",
+    "    output_field=\"a_tfidf\",\n",
+    "    binary=True,\n",
+    "    min_df=250,\n",
+    "    max_df=1.0,\n",
+    "    max_features=2000,\n",
+    ")\n",
+    "_ = a_tfidf_obj.fit(scotus_corpus, selector=lambda x: x.meta[\"is_valid_context\"])\n",
+    "_ = a_tfidf_obj.transform(scotus_corpus, selector=lambda x: x.meta[\"is_valid_context\"])"
    ]
   },
   {
@@ -361,10 +380,14 @@
    },
    "outputs": [],
    "source": [
-    "dual_context_model = DualContextWrapper(context_fields=['reply_to','next_id'], output_prefixes=['bk','fw'],\n",
-    "                                        vect_field='j_tfidf', context_vect_field='a_tfidf', \n",
-    "                                        n_svd_dims=15,\n",
-    "                                        random_state=1000)"
+    "dual_context_model = DualContextWrapper(\n",
+    "    context_fields=[\"reply_to\", \"next_id\"],\n",
+    "    output_prefixes=[\"bk\", \"fw\"],\n",
+    "    vect_field=\"j_tfidf\",\n",
+    "    context_vect_field=\"a_tfidf\",\n",
+    "    n_svd_dims=15,\n",
+    "    random_state=1000,\n",
+    ")"
    ]
   },
   {
@@ -375,8 +398,11 @@
    },
    "outputs": [],
    "source": [
-    "dual_context_model.fit(scotus_corpus, selector=lambda x: x.meta['is_valid_utt'],\n",
-    "                       context_selector=lambda x: x.meta['is_valid_context'])"
+    "dual_context_model.fit(\n",
+    "    scotus_corpus,\n",
+    "    selector=lambda x: x.meta[\"is_valid_utt\"],\n",
+    "    context_selector=lambda x: x.meta[\"is_valid_context\"],\n",
+    ")"
    ]
   },
   {
@@ -748,10 +774,10 @@
     }
    ],
    "source": [
-    "print('\\nhigh orientation')\n",
-    "display(term_df.sort_values('orn')[['orn']].tail(20))\n",
-    "print('low orientation')\n",
-    "display(term_df.sort_values('orn')[['orn']].head(20))"
+    "print(\"\\nhigh orientation\")\n",
+    "display(term_df.sort_values(\"orn\")[[\"orn\"]].tail(20))\n",
+    "print(\"low orientation\")\n",
+    "display(term_df.sort_values(\"orn\")[[\"orn\"]].head(20))"
    ]
   },
   {
@@ -792,19 +818,24 @@
    "outputs": [],
    "source": [
     "sentence_utts = []\n",
-    "for ut in scotus_corpus.iter_utterances(selector=lambda x: x.meta['is_valid_utt']):\n",
-    "    sents = ut.meta['arcs'].split('\\n')\n",
-    "    tok_sents = ut.meta['tokens'].split('\\n')\n",
+    "for ut in scotus_corpus.iter_utterances(selector=lambda x: x.meta[\"is_valid_utt\"]):\n",
+    "    sents = ut.meta[\"arcs\"].split(\"\\n\")\n",
+    "    tok_sents = ut.meta[\"tokens\"].split(\"\\n\")\n",
     "    for i, (sent, tok_sent) in enumerate(zip(sents, tok_sents)):\n",
-    "        utt_id = ut.id + '_' + '%03d' % i\n",
+    "        utt_id = ut.id + \"_\" + \"%03d\" % i\n",
    "        speaker = ut.speaker\n",
     "        text = tok_sent\n",
-    "        meta = {'arcs': sent, 'utt_id': ut.id, 'speaker': ut.speaker.id}\n",
-    "        sentence_utts.append(Utterance(\n",
-    "            id=utt_id, speaker=speaker, text=text,\n",
-    "            reply_to=ut.reply_to, conversation_id=ut.conversation_id,\n",
-    "            meta=meta\n",
-    "        ))"
+    "        meta = {\"arcs\": sent, \"utt_id\": ut.id, \"speaker\": ut.speaker.id}\n",
+    "        sentence_utts.append(\n",
+    "            Utterance(\n",
+    "                id=utt_id,\n",
+    "                speaker=speaker,\n",
+    "                text=text,\n",
+    "                reply_to=ut.reply_to,\n",
+    "                conversation_id=ut.conversation_id,\n",
+    "                meta=meta,\n",
+    "            )\n",
+    "        )"
    ]
   },
   {
@@ -853,7 +884,9 @@
    "outputs": [],
    "source": [
     "_ = j_tfidf_obj.transform(sentence_corpus)\n",
-    "_ = dual_context_model.transform(sentence_corpus, selector=lambda x: x.meta['j_tfidf__n_feats'] >= 1)\n"
+    "_ = dual_context_model.transform(\n",
+    "    sentence_corpus, selector=lambda x: x.meta[\"j_tfidf__n_feats\"] >= 1\n",
+    ")"
    ]
   },
   {
@@ -870,9 +903,9 @@
     }
    ],
    "source": [
-    "ut_eg_id = '20030__1_029_000'\n",
+    "ut_eg_id = \"20030__1_029_000\"\n",
     "eg_ut = sentence_corpus.get_utterance(ut_eg_id)\n",
-    "print(eg_ut.speaker.meta['name'], ':',eg_ut.text)"
+    "print(eg_ut.speaker.meta[\"name\"], \":\", eg_ut.text)"
    ]
   },
   {
@@ -892,7 +925,7 @@
     }
    ],
    "source": [
-    "eg_ut.meta['orn']"
+    "eg_ut.meta[\"orn\"]"
    ]
   },
   {
@@ -919,10 +952,13 @@
    },
    "outputs": [],
    "source": [
-    "sent_df = sentence_corpus.get_attribute_table('utterance',['orn','j_tfidf__n_feats'])\n",
-    "text_df = pd.DataFrame([{'id': ut._id, 'text': ut.text, 'speaker': ut.speaker.meta['name']}\n",
-    "        for ut in sentence_corpus.iter_utterances()\n",
-    "]).set_index('id')\n",
+    "sent_df = sentence_corpus.get_attribute_table(\"utterance\", [\"orn\", \"j_tfidf__n_feats\"])\n",
+    "text_df = pd.DataFrame(\n",
+    "    [\n",
+    "        {\"id\": ut._id, \"text\": ut.text, \"speaker\": ut.speaker.meta[\"name\"]}\n",
+    "        for ut in sentence_corpus.iter_utterances()\n",
+    "    ]\n",
+    ").set_index(\"id\")\n",
     "sent_df = sent_df.join(text_df)"
    ]
   },
   {
@@ -963,10 +999,12 @@
    },
    "outputs": [],
    "source": [
-    "low_subset = sent_df[(sent_df.j_tfidf__n_feats >= 30)\n",
-    "        & (sent_df.orn < sent_df.orn.quantile(.1))].sample(10,random_state=9)\n",
-    "high_subset = sent_df[(sent_df.j_tfidf__n_feats >= 30)\n",
-    "        & (sent_df.orn > sent_df.orn.quantile(.9))].sample(10,random_state=9)"
+    "low_subset = sent_df[\n",
+    "    (sent_df.j_tfidf__n_feats >= 30) & (sent_df.orn < sent_df.orn.quantile(0.1))\n",
+    "].sample(10, random_state=9)\n",
+    "high_subset = sent_df[\n",
+    "    (sent_df.j_tfidf__n_feats >= 30) & (sent_df.orn > sent_df.orn.quantile(0.9))\n",
+    "].sample(10, random_state=9)"
    ]
   },
   {
@@ -1019,9 +1057,9 @@
     }
    ],
    "source": [
-    "for id, row in high_subset.sort_values('orn', ascending=False).iterrows():\n",
-    "    print(id,row.speaker, 'orientation:',row.orn)\n",
-    "    print('>', row.text)\n",
+    "for id, row in high_subset.sort_values(\"orn\", ascending=False).iterrows():\n",
+    "    print(id, row.speaker, \"orientation:\", row.orn)\n",
+    "    print(\">\", row.text)\n",
     "    print()"
    ]
   },
   {
@@ -1068,9 +1106,9 @@
     }
    ],
    "source": [
-    "for id, row in low_subset.sort_values('orn').iterrows():\n",
-    "    print(id,row.speaker, 'orientation:',row.orn)\n",
-    "    print('>', row.text)\n",
+    "for id, row in low_subset.sort_values(\"orn\").iterrows():\n",
+    "    print(id, row.speaker, \"orientation:\", row.orn)\n",
+    "    print(\">\", row.text)\n",
     "    print()"
    ]
   },
   {
@@ -1140,11 +1178,16 @@
    },
    "outputs": [],
    "source": [
-    "dual_pipe = DualContextPipeline(context_fields=['reply_to','next_id'], \n",
-    "                                output_prefixes=['bk','fw'], share_tfidf_models=False,\n",
-    "                                text_field='arcs', text_pipe=scotus_arc_pipeline(), \n",
-    "                                tfidf_params={'binary': True, 'min_df': 250, 'max_features': 2000}, \n",
-    "                                n_svd_dims=15, random_state=1000)"
+    "dual_pipe = DualContextPipeline(\n",
+    "    context_fields=[\"reply_to\", \"next_id\"],\n",
+    "    output_prefixes=[\"bk\", \"fw\"],\n",
+    "    share_tfidf_models=False,\n",
+    "    text_field=\"arcs\",\n",
+    "    text_pipe=scotus_arc_pipeline(),\n",
+    "    tfidf_params={\"binary\": True, \"min_df\": 250, \"max_features\": 2000},\n",
+    "    n_svd_dims=15,\n",
+    "    random_state=1000,\n",
+    ")"
    ]
   },
   {
@@ -1155,9 +1198,11 @@
    },
    "outputs": [],
    "source": [
-    "dual_pipe.fit(scotus_corpus,\n",
-    "        selector=lambda x: x.meta['is_valid_utt'],\n",
-    "        context_selector=lambda x: x.meta['is_valid_context'])"
+    "dual_pipe.fit(\n",
+    "    scotus_corpus,\n",
+    "    selector=lambda x: x.meta[\"is_valid_utt\"],\n",
+    "    context_selector=lambda x: x.meta[\"is_valid_context\"],\n",
+    ")"
    ]
   },
   {
@@ -1482,10 +1527,10 @@
     }
    ],
    "source": [
-    "print('\\nhigh orientation')\n",
-    "display(term_df_new.sort_values('orn')[['orn']].tail(20))\n",
-    "print('low orientation')\n",
-    "display(term_df_new.sort_values('orn')[['orn']].head(20))"
+    "print(\"\\nhigh orientation\")\n",
+    "display(term_df_new.sort_values(\"orn\")[[\"orn\"]].tail(20))\n",
+    "print(\"low orientation\")\n",
+    "display(term_df_new.sort_values(\"orn\")[[\"orn\"]].head(20))"
    ]
   },
   {
@@ -1503,7 +1548,7 @@
    },
    "outputs": [],
    "source": [
-    "eg_ut_new = dual_pipe.transform_utterance('What is the difference between these statutes?')"
+    "eg_ut_new = dual_pipe.transform_utterance(\"What is the difference between these statutes?\")"
    ]
   },
   {
@@ -1520,7 +1565,7 @@
     }
    ],
    "source": [
-    "print('orientation:', eg_ut_new.meta['orn'])"
+    "print(\"orientation:\", eg_ut_new.meta[\"orn\"])"
    ]
   },
   {
diff --git a/convokit/expected_context_framework/demos/switchboard_exploration_demo.ipynb b/convokit/expected_context_framework/demos/switchboard_exploration_demo.ipynb
index d330b462..c233d41a 100644
--- a/convokit/expected_context_framework/demos/switchboard_exploration_demo.ipynb
+++ b/convokit/expected_context_framework/demos/switchboard_exploration_demo.ipynb
@@ -22,7 +22,8 @@
    "outputs": [],
    "source": [
     "import warnings\n",
-    "warnings.filterwarnings('ignore')"
+    "\n",
+    "warnings.filterwarnings(\"ignore\")"
    ]
   },
   {
@@ -79,7 +80,7 @@
    },
    "outputs": [],
    "source": [
-    "# OPTION 1: DOWNLOAD CORPUS \n",
+    "# OPTION 1: DOWNLOAD CORPUS\n",
     "# UNCOMMENT THESE LINES TO DOWNLOAD CORPUS\n",
     "# DATA_DIR = ''\n",
     "# SW_CORPUS_PATH = download('switchboard-processed-corpus', data_dir=DATA_DIR)\n",
@@ -127,7 +128,7 @@
    },
    "outputs": [],
    "source": [
-    "utt_eg_id = '3496-79'"
+    "utt_eg_id = \"3496-79\""
    ]
   },
   {
@@ -154,7 +155,7 @@
     }
    ],
    "source": [
-    "sw_corpus.get_utterance(utt_eg_id).meta['alpha_text']"
+    "sw_corpus.get_utterance(utt_eg_id).meta[\"alpha_text\"]"
    ]
   },
   {
@@ -194,19 +195,19 @@
    "source": [
     "topic_counts = defaultdict(set)\n",
     "for ut in sw_corpus.iter_utterances():\n",
-    "    topic = sw_corpus.get_conversation(ut.conversation_id).meta['topic']\n",
-    "    for x in set(ut.meta['alpha_text'].lower().split()):\n",
+    "    topic = sw_corpus.get_conversation(ut.conversation_id).meta[\"topic\"]\n",
+    "    for x in set(ut.meta[\"alpha_text\"].lower().split()):\n",
     "        topic_counts[x].add(topic)\n",
     "topic_counts = {x: len(y) for x, y in topic_counts.items()}\n",
     "\n",
     "word_convo_counts = defaultdict(set)\n",
     "for ut in sw_corpus.iter_utterances():\n",
-    "    for x in set(ut.meta['alpha_text'].lower().split()):\n",
+    "    for x in set(ut.meta[\"alpha_text\"].lower().split()):\n",
     "        word_convo_counts[x].add(ut.conversation_id)\n",
-    "word_convo_counts = {x: len(y) for x, y in word_convo_counts.items()} \n",
+    "word_convo_counts = {x: len(y) for x, y in word_convo_counts.items()}\n",
     "\n",
-    "min_topic_words = set(x for x,y in topic_counts.items() if y >= 33)\n",
-    "min_convo_words = set(x for x,y in word_convo_counts.items() if y >= 200)\n",
+    "min_topic_words = set(x for x, y in topic_counts.items() if y >= 33)\n",
+    "min_convo_words = set(x for x, y in word_convo_counts.items() if y >= 200)\n",
     "vocab = sorted(min_topic_words.intersection(min_convo_words))"
    ]
   },
   {
@@ -247,7 +248,10 @@
    },
    "outputs": [],
    "source": [
-    "from convokit.expected_context_framework import ColNormedTfidfTransformer, ExpectedContextModelTransformer"
+    "from convokit.expected_context_framework import (\n",
+    "    ColNormedTfidfTransformer,\n",
+    "    ExpectedContextModelTransformer,\n",
+    ")"
    ]
   },
   {
@@ -267,7 +271,9 @@
    },
    "outputs": [],
    "source": [
-    "tfidf_obj = ColNormedTfidfTransformer(input_field='alpha_text', output_field='col_normed_tfidf', binary=True, vocabulary=vocab)\n",
+    "tfidf_obj = ColNormedTfidfTransformer(\n",
+    "    input_field=\"alpha_text\", output_field=\"col_normed_tfidf\", binary=True, vocabulary=vocab\n",
+    ")\n",
     "_ = tfidf_obj.fit(sw_corpus)\n",
     "_ = tfidf_obj.transform(sw_corpus)"
    ]
   },
   {
@@ -296,10 +302,16 @@
    },
    "outputs": [],
    "source": [
-    "ec_fw = ExpectedContextModelTransformer(context_field='next_id', output_prefix='fw', \n",
-    "                                        vect_field='col_normed_tfidf', context_vect_field='col_normed_tfidf', \n",
-    "                                        n_svd_dims=15, n_clusters=2,\n",
-    "                                        random_state=1000, cluster_random_state=1000)"
+    "ec_fw = ExpectedContextModelTransformer(\n",
+    "    context_field=\"next_id\",\n",
+    "    output_prefix=\"fw\",\n",
+    "    vect_field=\"col_normed_tfidf\",\n",
+    "    context_vect_field=\"col_normed_tfidf\",\n",
+    "    n_svd_dims=15,\n",
+    "    n_clusters=2,\n",
+    "    random_state=1000,\n",
+    "    cluster_random_state=1000,\n",
+    ")"
    ]
   },
   {
@@ -317,8 +329,11 @@
    },
    "outputs": [],
    "source": [
-    "ec_fw.fit(sw_corpus, selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>=5, \n",
-    "          context_selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>= 5)"
+    "ec_fw.fit(\n",
+    "    sw_corpus,\n",
+    "    selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5,\n",
+    "    context_selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5,\n",
+    ")"
    ]
   },
   {
@@ -338,13 +353,22 @@
    },
    "outputs": [],
    "source": [
-    "ec_bk = ExpectedContextModelTransformer(context_field='reply_to', output_prefix='bk', \n",
-    "                                        vect_field='col_normed_tfidf', context_vect_field='col_normed_tfidf', \n",
-    "                                        n_svd_dims=15, n_clusters=2,\n",
-    "                                        random_state=1000, cluster_random_state=1000,\n",
-    "                                        model=ec_fw)\n",
-    "ec_bk.fit(sw_corpus, selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>=5, \n",
-    "          context_selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>= 5)"
+    "ec_bk = ExpectedContextModelTransformer(\n",
+    "    context_field=\"reply_to\",\n",
+    "    output_prefix=\"bk\",\n",
+    "    vect_field=\"col_normed_tfidf\",\n",
+    "    context_vect_field=\"col_normed_tfidf\",\n",
+    "    n_svd_dims=15,\n",
+    "    n_clusters=2,\n",
+    "    random_state=1000,\n",
+    "    cluster_random_state=1000,\n",
+    "    model=ec_fw,\n",
+    ")\n",
+    "ec_bk.fit(\n",
+    "    sw_corpus,\n",
+    "    selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5,\n",
+    "    context_selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5,\n",
+    ")"
    ]
   },
   {
@@ -665,8 +689,8 @@
    },
    "outputs": [],
    "source": [
-    "ec_fw.set_cluster_names(['commentary','personal'])\n",
-    "ec_bk.set_cluster_names(['personal', 'commentary'])"
+    "ec_fw.set_cluster_names([\"commentary\", \"personal\"])\n",
+    "ec_bk.set_cluster_names([\"personal\", \"commentary\"])"
    ]
   },
   {
@@ -757,9 +781,13 @@
    },
    "outputs": [],
    "source": [
-    "term_df = pd.DataFrame({'index': ec_fw.get_terms(),\n",
-    "                        'fw_range': ec_fw.get_term_ranges(),\n",
-    "                        'bk_range': ec_bk.get_term_ranges()}).set_index('index')"
+    "term_df = pd.DataFrame(\n",
+    "    {\n",
+    "        \"index\": ec_fw.get_terms(),\n",
+    "        \"fw_range\": ec_fw.get_term_ranges(),\n",
+    "        \"bk_range\": ec_bk.get_term_ranges(),\n",
+    "    }\n",
+    ").set_index(\"index\")"
    ]
   },
   {
@@ -865,10 +893,8 @@
    },
    "outputs": [],
    "source": [
-    "term_df['orn'] = term_df.bk_range - term_df.fw_range\n",
-    "term_df['shift'] = paired_distances(\n",
-    "    ec_fw.ec_model.term_reprs, ec_bk.ec_model.term_reprs\n",
-    "    )"
+    "term_df[\"orn\"] = term_df.bk_range - term_df.fw_range\n",
+    "term_df[\"shift\"] = paired_distances(ec_fw.ec_model.term_reprs, ec_bk.ec_model.term_reprs)"
    ]
   },
   {
@@ -1271,15 +1297,15 @@
     }
    ],
    "source": [
-    "k=10\n",
-    "print('low orientation')\n",
-    "display(term_df.sort_values('orn').head(k)[['orn']])\n",
-    "print('high orientation')\n",
-    "display(term_df.sort_values('orn').tail(k)[['orn']])\n",
-    "print('\\nlow shift')\n",
-    "display(term_df.sort_values('shift').head(k)[['shift']])\n",
-    "print('high shift')\n",
-    "display(term_df.sort_values('shift').tail(k)[['shift']])"
+    "k = 10\n",
+    "print(\"low orientation\")\n",
+    "display(term_df.sort_values(\"orn\").head(k)[[\"orn\"]])\n",
+    "print(\"high orientation\")\n",
+    "display(term_df.sort_values(\"orn\").tail(k)[[\"orn\"]])\n",
+    "print(\"\\nlow shift\")\n",
+    "display(term_df.sort_values(\"shift\").head(k)[[\"shift\"]])\n",
+    "print(\"high shift\")\n",
+    "display(term_df.sort_values(\"shift\").tail(k)[[\"shift\"]])"
    ]
   },
   {
@@ -1308,8 +1334,8 @@
    },
    "outputs": [],
    "source": [
-    "_ = ec_fw.transform(sw_corpus, selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>=5)\n",
-    "_ = ec_bk.transform(sw_corpus, selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>=5)"
+    "_ = ec_fw.transform(sw_corpus, selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5)\n",
+    "_ = ec_bk.transform(sw_corpus, selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5)"
    ]
   },
   {
@@ -1364,8 +1390,8 @@
    ],
    "source": [
     "eg_ut = sw_corpus.get_utterance(utt_eg_id)\n",
-    "print('Forwards range:', eg_ut.meta['fw_range'])\n",
-    "print('Backwards range:', eg_ut.meta['bk_range'])"
+    "print(\"Forwards range:\", eg_ut.meta[\"fw_range\"])\n",
+    "print(\"Backwards range:\", eg_ut.meta[\"bk_range\"])"
    ]
   },
   {
@@ -1390,8 +1416,8 @@
     }
    ],
    "source": [
-    "print('Forwards cluster:', eg_ut.meta['fw_clustering.cluster'])\n",
-    "print('Backwards cluster:', eg_ut.meta['bk_clustering.cluster'])"
+    "print(\"Forwards cluster:\", eg_ut.meta[\"fw_clustering.cluster\"])\n",
+    "print(\"Backwards cluster:\", eg_ut.meta[\"bk_clustering.cluster\"])"
    ]
   },
   {
@@ -1409,8 +1435,10 @@
    },
    "outputs": [],
    "source": [
-    "for ut in sw_corpus.iter_utterances(selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>=5):\n",
-    "    ut.meta['orn'] = ut.meta['bk_range'] - ut.meta['fw_range']"
+    "for ut in sw_corpus.iter_utterances(\n",
+    "    selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5\n",
+    "):\n",
+    "    ut.meta[\"orn\"] = ut.meta[\"bk_range\"] - ut.meta[\"fw_range\"]"
    ]
   },
   {
@@ -1428,9 +1456,9 @@
    },
    "outputs": [],
    "source": [
-    "utt_shifts = paired_distances(sw_corpus.get_vectors('fw_repr'), sw_corpus.get_vectors('bk_repr'))\n",
-    "for id, shift in zip(sw_corpus.get_vector_matrix('fw_repr').ids, utt_shifts):\n",
-    "    sw_corpus.get_utterance(id).meta['shift'] = shift"
+    "utt_shifts = paired_distances(sw_corpus.get_vectors(\"fw_repr\"), sw_corpus.get_vectors(\"bk_repr\"))\n",
+    "for id, shift in zip(sw_corpus.get_vector_matrix(\"fw_repr\").ids, utt_shifts):\n",
+    "    sw_corpus.get_utterance(id).meta[\"shift\"] = shift"
    ]
   },
   {
@@ -1448,8 +1476,8 @@
     }
    ],
    "source": [
-    "print('shift:', eg_ut.meta['shift'])\n",
-    "print('orientation:', eg_ut.meta['orn'])"
+    "print(\"shift:\", eg_ut.meta[\"shift\"])\n",
+    "print(\"orientation:\", eg_ut.meta[\"orn\"])"
    ]
   },
   {
@@ -1494,10 +1522,10 @@
    },
    "outputs": [],
    "source": [
-    "df = sw_corpus.get_attribute_table('utterance',\n",
-    "        ['bk_clustering.cluster', 'fw_clustering.cluster',\n",
-    "        'orn', 'shift', 'tags'])\n",
-    "df = df[df['bk_clustering.cluster'].notnull()]"
+    "df = sw_corpus.get_attribute_table(\n",
+    "    \"utterance\", [\"bk_clustering.cluster\", \"fw_clustering.cluster\", \"orn\", \"shift\", \"tags\"]\n",
+    ")\n",
+    "df = df[df[\"bk_clustering.cluster\"].notnull()]"
    ]
   },
   {
@@ -1515,9 +1543,9 @@
    },
    "outputs": [],
    "source": [
-    "tag_subset = ['aa', 'b', 'ba', 'h', 'ny', 'qw', 'qy', 'sd', 'sv'] \n",
+    "tag_subset = [\"aa\", \"b\", \"ba\", \"h\", \"ny\", \"qw\", \"qy\", \"sd\", \"sv\"]\n",
     "for tag in tag_subset:\n",
-    "    df['has_' + tag] = df.tags.apply(lambda x: tag in x.split())"
+    "    df[\"has_\" + tag] = df.tags.apply(lambda x: tag in x.split())"
    ]
   },
   {
@@ -1546,7 +1574,9 @@
     "        val_false = sum((col == val) & ~bool_col)\n",
     "        nval_true = sum((col != val) & bool_col)\n",
     "        nval_false = sum((col != val) & ~bool_col)\n",
-    "        log_odds_entries.append({'val': val, 'log_odds': np.log((val_true/val_false)/(nval_true/nval_false))})\n",
+    "        log_odds_entries.append(\n",
+    "            {\"val\": val, \"log_odds\": np.log((val_true / val_false) / (nval_true / nval_false))}\n",
+    "        )\n",
     "    return log_odds_entries"
    ]
   },
   {
@@ -1560,10 +1590,10 @@
    "source": [
     "bk_log_odds = []\n",
     "for tag in tag_subset:\n",
-    "    entry = compute_log_odds(df['bk_clustering.cluster'],df['has_' + tag], ['commentary'])[0]\n",
-    "    entry['tag'] = tag\n",
+    "    entry = compute_log_odds(df[\"bk_clustering.cluster\"], df[\"has_\" + tag], [\"commentary\"])[0]\n",
+    "    entry[\"tag\"] = tag\n",
     "    bk_log_odds.append(entry)\n",
-    "bk_log_odds_df = pd.DataFrame(bk_log_odds).set_index('tag').sort_values('log_odds')[['log_odds']]"
+    "bk_log_odds_df = pd.DataFrame(bk_log_odds).set_index(\"tag\").sort_values(\"log_odds\")[[\"log_odds\"]]"
    ]
   },
   {
@@ -1576,10 +1606,10 @@
    "source": [
     "fw_log_odds = []\n",
     "for tag in tag_subset:\n",
-    "    entry = compute_log_odds(df['fw_clustering.cluster'],df['has_' + tag], ['commentary'])[0]\n",
-    "    entry['tag'] = tag\n",
+    "    entry = compute_log_odds(df[\"fw_clustering.cluster\"], df[\"has_\" + tag], [\"commentary\"])[0]\n",
+    "    entry[\"tag\"] = tag\n",
     "    fw_log_odds.append(entry)\n",
-    "fw_log_odds_df = pd.DataFrame(fw_log_odds).set_index('tag').sort_values('log_odds')[['log_odds']]"
+    "fw_log_odds_df = pd.DataFrame(fw_log_odds).set_index(\"tag\").sort_values(\"log_odds\")[[\"log_odds\"]]"
    ]
   },
   {
@@ -1724,10 +1754,10 @@
     }
    ],
    "source": [
-    "print('forwards types vs labels')\n",
+    "print(\"forwards types vs labels\")\n",
     "display(fw_log_odds_df.T)\n",
-    "print('--------------------------')\n",
-    "print('backwards types vs labels')\n",
+    "print(\"--------------------------\")\n",
+    "print(\"backwards types vs labels\")\n",
     "display(bk_log_odds_df.T)"
    ]
   },
   {
@@ -1779,14 +1809,17 @@
     "    s = np.sqrt(((n1 - 1) * s1 + (n2 - 1) * s2) / (n1 + n2 - 2))\n",
     "    u1, u2 = np.mean(d1), np.mean(d2)\n",
     "    return (u1 - u2) / s\n",
+    "\n",
+    "\n",
     "def get_pstars(p):\n",
-    "    if p < 0.001: \n",
-    "        return '***'\n",
+    "    if p < 0.001:\n",
+    "        return \"***\"\n",
     "    elif p < 0.01:\n",
-    "        return '**'\n",
+    "        return \"**\"\n",
     "    elif p < 0.05:\n",
-    "        return '*'\n",
-    "    else: return ''"
+    "        return \"*\"\n",
+    "    else:\n",
+    "        return \"\""
    ]
   },
   {
@@ -1797,17 +1830,16 @@
    },
    "outputs": [],
    "source": [
-    "stat_col = 'orn'\n",
+    "stat_col = \"orn\"\n",
     "entries = []\n",
     "for tag in tag_subset:\n",
-    "    has = df[df['has_' + tag]][stat_col]\n",
-    "    hasnt = df[~df['has_' + tag]][stat_col]\n",
-    "    entry = {'tag': tag, 'pval': stats.mannwhitneyu(has, hasnt)[1],\n",
-    "             'cd': cohend(has, hasnt)}\n",
-    "    entry['ps'] = get_pstars(entry['pval'] * len(tag_subset))\n",
+    "    has = df[df[\"has_\" + tag]][stat_col]\n",
+    "    hasnt = df[~df[\"has_\" + tag]][stat_col]\n",
+    "    entry = {\"tag\": tag, \"pval\": stats.mannwhitneyu(has, hasnt)[1], \"cd\": cohend(has, hasnt)}\n",
+    "    entry[\"ps\"] = get_pstars(entry[\"pval\"] * len(tag_subset))\n",
     "    entries.append(entry)\n",
-    "orn_stat_df = pd.DataFrame(entries).set_index('tag').sort_values('cd')\n",
-    "orn_stat_df = orn_stat_df[np.abs(orn_stat_df.cd) >= .1]"
+    "orn_stat_df = pd.DataFrame(entries).set_index(\"tag\").sort_values(\"cd\")\n",
+    "orn_stat_df = orn_stat_df[np.abs(orn_stat_df.cd) >= 0.1]"
    ]
   },
   {
@@ -1818,17 +1850,16 @@
    },
    "outputs": [],
    "source": [
-    "stat_col = 'shift'\n",
+    "stat_col = \"shift\"\n",
     "entries = []\n",
     "for tag in tag_subset:\n",
-    "    has = df[df['has_' + tag]][stat_col]\n",
-    "    hasnt = df[~df['has_' + tag]][stat_col]\n",
-    "    entry = {'tag': tag, 'pval': stats.mannwhitneyu(has, hasnt)[1],\n",
-    "             'cd': cohend(has, hasnt)}\n",
-    "    entry['ps'] = get_pstars(entry['pval'] * len(tag_subset))\n",
+    "    has = df[df[\"has_\" + tag]][stat_col]\n",
+    "    hasnt = df[~df[\"has_\" + tag]][stat_col]\n",
+    "    entry = {\"tag\": tag, \"pval\": stats.mannwhitneyu(has, hasnt)[1], \"cd\": cohend(has, hasnt)}\n",
+    "    entry[\"ps\"] = get_pstars(entry[\"pval\"] * len(tag_subset))\n",
    "    entries.append(entry)\n",
-    "shift_stat_df = pd.DataFrame(entries).set_index('tag').sort_values('cd')\n",
-    "shift_stat_df = shift_stat_df[np.abs(shift_stat_df.cd) >= .1]"
+    "shift_stat_df = pd.DataFrame(entries).set_index(\"tag\").sort_values(\"cd\")\n",
+    "shift_stat_df = shift_stat_df[np.abs(shift_stat_df.cd) >= 0.1]"
    ]
   },
   {
@@ -1999,10 +2030,10 @@
     }
    ],
    "source": [
-    "print('orientation vs labels')\n",
+    "print(\"orientation vs labels\")\n",
     "display(orn_stat_df.T)\n",
-    "print('--------------------------')\n",
-    "print('shift vs labels')\n",
+    "print(\"--------------------------\")\n",
+    "print(\"shift vs labels\")\n",
     "display(shift_stat_df.T)"
    ]
   },
   {
@@ -2044,7 +2075,7 @@
    },
    "outputs": [],
    "source": [
-    "FW_MODEL_PATH = os.path.join(SW_CORPUS_PATH, 'fw')"
+    "FW_MODEL_PATH = os.path.join(SW_CORPUS_PATH, \"fw\")"
    ]
   },
   {
@@ -2102,9 +2133,16 @@
    },
    "outputs": [],
    "source": [
-    "ec_fw_new = ExpectedContextModelTransformer('next_id', 'fw_new', 'col_normed_tfidf', 'col_normed_tfidf', \n",
-    "                                            n_svd_dims=15, n_clusters=2,\n",
-    "                                            random_state=1000, cluster_random_state=1000)"
+    "ec_fw_new = ExpectedContextModelTransformer(\n",
+    "    \"next_id\",\n",
+    "    \"fw_new\",\n",
+    "    \"col_normed_tfidf\",\n",
+    "    \"col_normed_tfidf\",\n",
+    "    n_svd_dims=15,\n",
+    "    n_clusters=2,\n",
+    "    random_state=1000,\n",
+    "    cluster_random_state=1000,\n",
+    ")"
    ]
   },
   {
@@ -2133,7 +2171,9 @@
    },
    "outputs": [],
    "source": [
-    "_ = ec_fw_new.transform(sw_corpus, selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>=5)"
+    "_ = ec_fw_new.transform(\n",
+    "    sw_corpus, selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5\n",
+    ")"
    ]
   },
   {
@@ -2153,7 +2193,7 @@
     }
    ],
    "source": [
-    "np.allclose(sw_corpus.get_vectors('fw_repr'), sw_corpus.get_vectors('fw_new_repr'))"
+    "np.allclose(sw_corpus.get_vectors(\"fw_repr\"), sw_corpus.get_vectors(\"fw_new_repr\"))"
    ]
   },
   {
@@ -2198,7 +2238,7 @@
    "source": [
     "# see `demo_text_pipelines.py` in this demo's directory for details\n",
     "# in short, this pipeline will either output the `alpha_text` metadata field\n",
-    "# of an utterance, or write the utterance's `text` attribute into the `alpha_text` \n",
+    "# of an utterance, or write the utterance's `text` attribute into the `alpha_text`\n",
     "# metadata field\n",
     "from demo_text_pipelines import switchboard_text_pipeline"
    ]
   },
   {
@@ -2224,13 +2264,19 @@
    },
    "outputs": [],
    "source": [
-    "fw_pipe = ExpectedContextModelPipeline(context_field='next_id', output_prefix='fw',\n",
-    "            text_field='alpha_text',\n",
-    "            text_pipe=switchboard_text_pipeline(), \n",
-    "            tfidf_params={'binary': True, 'vocabulary': vocab}, \n",
-    "            min_terms=5,\n",
-    "            n_svd_dims=15, n_clusters=2, cluster_on='utts',\n",
-    "            random_state=1000, cluster_random_state=1000)"
+    "fw_pipe = ExpectedContextModelPipeline(\n",
+    "    context_field=\"next_id\",\n",
+    "    output_prefix=\"fw\",\n",
+    "    text_field=\"alpha_text\",\n",
+    "    text_pipe=switchboard_text_pipeline(),\n",
+    "    tfidf_params={\"binary\": True, \"vocabulary\": vocab},\n",
+    "    min_terms=5,\n",
+    "    n_svd_dims=15,\n",
+    "    n_clusters=2,\n",
+    "    cluster_on=\"utts\",\n",
+    "    random_state=1000,\n",
+    "    cluster_random_state=1000,\n",
+    ")"
    ]
   },
   {
@@ -2261,14 +2307,20 @@
    },
    "outputs": [],
    "source": [
-    "bk_pipe = ExpectedContextModelPipeline(context_field='reply_to', output_prefix='bk',\n",
-    "            text_field='alpha_text',\n",
-    "            text_pipe=switchboard_text_pipeline(), \n",
-    "            tfidf_params={'binary': True, 'vocabulary': vocab}, \n",
-    "            min_terms=5,\n",
-    "            ec_model=fw_pipe,\n",
-    "            n_svd_dims=15, n_clusters=2, cluster_on='utts',\n",
-    "            random_state=1000, cluster_random_state=1000)"
+    "bk_pipe = ExpectedContextModelPipeline(\n",
+    "    context_field=\"reply_to\",\n",
+    "    output_prefix=\"bk\",\n",
+    "    text_field=\"alpha_text\",\n",
+    "    text_pipe=switchboard_text_pipeline(),\n",
+    "    tfidf_params={\"binary\": True, \"vocabulary\": vocab},\n",
+    "    min_terms=5,\n",
+    "    ec_model=fw_pipe,\n",
+    "    n_svd_dims=15,\n",
+    "    n_clusters=2,\n",
+    "    cluster_on=\"utts\",\n",
+    "    random_state=1000,\n",
+    "    cluster_random_state=1000,\n",
+    ")"
    ]
   },
   {
@@ -2297,8 +2349,8 @@
    },
    "outputs": [],
    "source": [
-    "fw_pipe.set_cluster_names(['commentary','personal'])\n",
-    "bk_pipe.set_cluster_names(['personal', 'commentary'])"
+    "fw_pipe.set_cluster_names([\"commentary\", \"personal\"])\n",
+    "bk_pipe.set_cluster_names([\"personal\", \"commentary\"])"
    ]
   },
   {
@@ -2316,7 +2368,7 @@
    },
    "outputs": [],
    "source": [
-    "eg_ut_new = fw_pipe.transform_utterance('How old were you when you left ?')\n",
+    "eg_ut_new = fw_pipe.transform_utterance(\"How old were you when you left ?\")\n",
     "eg_ut_new = bk_pipe.transform_utterance(eg_ut_new)"
    ]
   },
   {
@@ -2357,7 +2409,7 @@
     }
    ],
    "source": [
-    "eg_ut_new.meta['fw_repr']"
+    "eg_ut_new.meta[\"fw_repr\"]"
    ]
   },
   {
@@ -2378,10 +2430,10 @@
    ],
    "source": [
     "# note these attributes have the exact same values as those of eg_ut, computed above\n",
-    "print('Forwards range:', eg_ut_new.meta['fw_range'])\n",
-    "print('Backwards range:', eg_ut_new.meta['bk_range'])\n",
-    "print('Forwards cluster:', eg_ut_new.meta['fw_clustering.cluster'])\n",
-    "print('Backwards cluster:', eg_ut_new.meta['bk_clustering.cluster'])"
+    "print(\"Forwards range:\", eg_ut_new.meta[\"fw_range\"])\n",
+    "print(\"Backwards range:\", eg_ut_new.meta[\"bk_range\"])\n",
+    "print(\"Forwards cluster:\", eg_ut_new.meta[\"fw_clustering.cluster\"])\n",
+    "print(\"Backwards cluster:\", eg_ut_new.meta[\"bk_clustering.cluster\"])"
    ]
   }
  ],
diff --git a/convokit/expected_context_framework/demos/switchboard_exploration_dual_demo.ipynb b/convokit/expected_context_framework/demos/switchboard_exploration_dual_demo.ipynb
index 7e7ed112..cd7da4c2 100644
--- a/convokit/expected_context_framework/demos/switchboard_exploration_dual_demo.ipynb
+++ b/convokit/expected_context_framework/demos/switchboard_exploration_dual_demo.ipynb
@@ -23,7 +23,8 @@
    "outputs": [],
    "source": [
     "import warnings\n",
-    "warnings.filterwarnings('ignore')"
+    "\n",
+    "warnings.filterwarnings(\"ignore\")"
    ]
   },
   {
@@ -80,7 +81,7 @@
    },
    "outputs": [],
    "source": [
-    "# OPTION 1: DOWNLOAD CORPUS \n",
+    "# OPTION 1: DOWNLOAD CORPUS\n",
     "# UNCOMMENT THESE LINES TO DOWNLOAD CORPUS\n",
     "# DATA_DIR = ''\n",
     "# SW_CORPUS_PATH = download('switchboard-processed-corpus', data_dir=DATA_DIR)\n",
@@ -128,7 +129,7 @@
    },
    "outputs": [],
    "source": [
-    "utt_eg_id = '3496-79'"
+    "utt_eg_id = \"3496-79\""
    ]
   },
   {
@@ -155,7 +156,7 @@
     }
    ],
    "source": [
-    "sw_corpus.get_utterance(utt_eg_id).meta['alpha_text']"
+    "sw_corpus.get_utterance(utt_eg_id).meta[\"alpha_text\"]"
    ]
   },
   {
@@ -195,19 +196,19 @@
    "source": [
     "topic_counts = defaultdict(set)\n",
     "for ut in sw_corpus.iter_utterances():\n",
-    "    topic = sw_corpus.get_conversation(ut.conversation_id).meta['topic']\n",
-    "    for x in set(ut.meta['alpha_text'].lower().split()):\n",
+    "    topic = sw_corpus.get_conversation(ut.conversation_id).meta[\"topic\"]\n",
+    "    for x in set(ut.meta[\"alpha_text\"].lower().split()):\n",
     "        topic_counts[x].add(topic)\n",
     "topic_counts = {x: len(y) for x, y in topic_counts.items()}\n",
     "\n",
     "word_convo_counts = defaultdict(set)\n",
     "for ut in sw_corpus.iter_utterances():\n",
-    "    for x in set(ut.meta['alpha_text'].lower().split()):\n",
+    "    for x in set(ut.meta[\"alpha_text\"].lower().split()):\n",
     "        word_convo_counts[x].add(ut.conversation_id)\n",
-    "word_convo_counts = {x: len(y) for x, y in word_convo_counts.items()} \n",
+    "word_convo_counts = {x: len(y) for x, y in word_convo_counts.items()}\n",
     "\n",
-    "min_topic_words = set(x for x,y in topic_counts.items() if y >= 33)\n",
-    "min_convo_words = set(x for x,y in word_convo_counts.items() if y >= 200)\n",
+    "min_topic_words = set(x for x, y in topic_counts.items() if y >= 33)\n",
+    "min_convo_words = set(x for x, y in word_convo_counts.items() if y >= 200)\n",
     "vocab = sorted(min_topic_words.intersection(min_convo_words))"
    ]
   },
   {
@@ -268,7 +269,9 @@
    },
    "outputs": [],
    "source": [
-    "tfidf_obj = ColNormedTfidfTransformer(input_field='alpha_text', output_field='col_normed_tfidf', binary=True, vocabulary=vocab)\n",
+    "tfidf_obj = ColNormedTfidfTransformer(\n",
+    "    input_field=\"alpha_text\", output_field=\"col_normed_tfidf\", binary=True, vocabulary=vocab\n",
+    ")\n",
     "_ = tfidf_obj.fit(sw_corpus)\n",
     "_ = tfidf_obj.transform(sw_corpus)"
    ]
@@ -296,10 +299,16 @@
    },
    "outputs": [],
    "source": [
-    "dual_context_model = DualContextWrapper(context_fields=['reply_to','next_id'], output_prefixes=['bk','fw'],\n",
-    "                                        vect_field='col_normed_tfidf', context_vect_field='col_normed_tfidf', \n",
-    "                                        n_svd_dims=15, n_clusters=2,\n",
-    "                                        random_state=1000, cluster_random_state=1000)"
+    "dual_context_model = DualContextWrapper(\n",
+    "    context_fields=[\"reply_to\", \"next_id\"],\n",
+    "    output_prefixes=[\"bk\", \"fw\"],\n",
+    "    vect_field=\"col_normed_tfidf\",\n",
+    "    context_vect_field=\"col_normed_tfidf\",\n",
+    "    n_svd_dims=15,\n",
+    "    n_clusters=2,\n",
+    "    random_state=1000,\n",
+    "    cluster_random_state=1000,\n",
+    ")"
    ]
   },
   {
@@ -317,8 +326,11 @@
    },
    "outputs": [],
    "source": [
-    "dual_context_model.fit(sw_corpus,selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>=5, \n",
-    "                       context_selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>= 5)"
+    "dual_context_model.fit(\n",
+    "    sw_corpus,\n",
+    "    selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5,\n",
+    "    context_selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5,\n",
+    ")"
    ]
   },
   {
@@ -625,8 +637,8 @@
    },
    "outputs": [],
    "source": [
-    "dual_context_model.ec_models[0].set_cluster_names(['personal', 'commentary'])\n",
-    "dual_context_model.ec_models[1].set_cluster_names(['commentary', 'personal'])"
+    "dual_context_model.ec_models[0].set_cluster_names([\"personal\", \"commentary\"])\n",
+    "dual_context_model.ec_models[1].set_cluster_names([\"commentary\", \"personal\"])"
    ]
   },
   {
@@ -1159,15 +1171,15 @@
     }
    ],
    "source": [
-    "k=10\n",
-    "print('low orientation')\n",
-    "display(term_df.sort_values('orn').head(k)[['orn']])\n",
-    "print('high orientation')\n",
-    "display(term_df.sort_values('orn').tail(k)[['orn']])\n",
-    "print('\\nlow shift')\n",
-    "display(term_df.sort_values('shift').head(k)[['shift']])\n",
-    "print('high shift')\n",
-    "display(term_df.sort_values('shift').tail(k)[['shift']])"
+    "k = 10\n",
+    "print(\"low orientation\")\n",
+    "display(term_df.sort_values(\"orn\").head(k)[[\"orn\"]])\n",
+    "print(\"high orientation\")\n",
+    "display(term_df.sort_values(\"orn\").tail(k)[[\"orn\"]])\n",
+    "print(\"\\nlow shift\")\n",
+    "display(term_df.sort_values(\"shift\").head(k)[[\"shift\"]])\n",
+    "print(\"high shift\")\n",
+    "display(term_df.sort_values(\"shift\").tail(k)[[\"shift\"]])"
    ]
   },
   {
@@ -1196,7 +1208,9 @@
    },
    "outputs": [],
    "source": [
-    "_ = dual_context_model.transform(sw_corpus, selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>=5)"
+    "_ = dual_context_model.transform(\n",
+    "    sw_corpus, selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5\n",
+    ")"
    ]
   },
   {
@@ -1251,8 +1265,8 @@
    ],
    "source": [
     "eg_ut = sw_corpus.get_utterance(utt_eg_id)\n",
-    "print('Forwards range:', eg_ut.meta['fw_range'])\n",
-    "print('Backwards range:', eg_ut.meta['bk_range'])"
+    "print(\"Forwards range:\", eg_ut.meta[\"fw_range\"])\n",
+    "print(\"Backwards range:\", eg_ut.meta[\"bk_range\"])"
    ]
   },
   {
@@ -1277,8 +1291,8 @@
     }
    ],
    "source": [
-    "print('Forwards cluster:', eg_ut.meta['fw_clustering.cluster'])\n",
-    "print('Backwards cluster:', eg_ut.meta['bk_clustering.cluster'])"
+    "print(\"Forwards cluster:\", eg_ut.meta[\"fw_clustering.cluster\"])\n",
+    "print(\"Backwards cluster:\", eg_ut.meta[\"bk_clustering.cluster\"])"
    ]
   },
   {
@@ -1303,8 +1317,8 @@
     }
    ],
    "source": [
-    "print('shift:', eg_ut.meta['shift'])\n",
-    "print('orientation:', eg_ut.meta['orn'])"
+    "print(\"shift:\", eg_ut.meta[\"shift\"])\n",
+    "print(\"orientation:\", eg_ut.meta[\"orn\"])"
    ]
   },
   {
@@ -1335,10 +1349,10 @@
    },
    "outputs": [],
    "source": [
-    "df = sw_corpus.get_attribute_table('utterance',\n",
-    "        ['bk_clustering.cluster', 'fw_clustering.cluster',\n",
-    "        'orn', 'shift', 'tags'])\n",
-    "df = df[df['bk_clustering.cluster'].notnull()]"
+    "df = sw_corpus.get_attribute_table(\n",
+    "    \"utterance\", [\"bk_clustering.cluster\", \"fw_clustering.cluster\", \"orn\", \"shift\", \"tags\"]\n",
+    ")\n",
+    "df = df[df[\"bk_clustering.cluster\"].notnull()]"
    ]
   },
   {
@@ -1356,9 +1370,9 @@
    },
    "outputs": [],
    "source": [
-    "tag_subset = ['aa', 'b', 'ba', 'h', 'ny', 'qw', 'qy', 'sd', 'sv'] \n",
+    "tag_subset = [\"aa\", \"b\", \"ba\", \"h\", \"ny\", \"qw\", \"qy\", \"sd\", \"sv\"]\n",
     "for tag in tag_subset:\n",
-    "    df['has_' + tag] = df.tags.apply(lambda x: tag in x.split())"
+    "    df[\"has_\" + tag] = df.tags.apply(lambda x: tag in x.split())"
    ]
   },
   {
@@ -1387,7 +1401,9 @@
     "        val_false = sum((col == val) & ~bool_col)\n",
     "        nval_true = sum((col != val) & bool_col)\n",
     "        nval_false = sum((col != val) & ~bool_col)\n",
-    "        log_odds_entries.append({'val': val, 'log_odds': np.log((val_true/val_false)/(nval_true/nval_false))})\n",
+    "        log_odds_entries.append(\n",
+    "            {\"val\": val, \"log_odds\": np.log((val_true / val_false) / (nval_true / nval_false))}\n",
+    "        )\n",
     "    return log_odds_entries"
    ]
   },
   {
@@ -1401,10 +1417,10 @@
    "source": [
     "bk_log_odds = []\n",
     "for tag in tag_subset:\n",
-    "    entry = compute_log_odds(df['bk_clustering.cluster'],df['has_' + tag], ['commentary'])[0]\n",
-    "    entry['tag'] = tag\n",
+    "    entry = compute_log_odds(df[\"bk_clustering.cluster\"], df[\"has_\" + tag], [\"commentary\"])[0]\n",
+    "    entry[\"tag\"] = tag\n",
     "    bk_log_odds.append(entry)\n",
-    "bk_log_odds_df = pd.DataFrame(bk_log_odds).set_index('tag').sort_values('log_odds')[['log_odds']]"
+    "bk_log_odds_df = pd.DataFrame(bk_log_odds).set_index(\"tag\").sort_values(\"log_odds\")[[\"log_odds\"]]"
    ]
   },
   {
@@ -1417,10 +1433,10 @@
    "source": [
     "fw_log_odds = []\n",
     "for tag in tag_subset:\n",
-    "    entry = compute_log_odds(df['fw_clustering.cluster'],df['has_' + tag], ['commentary'])[0]\n",
-    "    entry['tag'] = tag\n",
+    "    entry = compute_log_odds(df[\"fw_clustering.cluster\"], df[\"has_\" + tag], [\"commentary\"])[0]\n",
+    "    entry[\"tag\"] = tag\n",
     "    fw_log_odds.append(entry)\n",
-    "fw_log_odds_df = pd.DataFrame(fw_log_odds).set_index('tag').sort_values('log_odds')[['log_odds']]"
+    "fw_log_odds_df = pd.DataFrame(fw_log_odds).set_index(\"tag\").sort_values(\"log_odds\")[[\"log_odds\"]]"
    ]
   },
   {
@@ -1565,10 +1581,10 @@
     }
    ],
    "source": [
-    "print('forwards types vs labels')\n",
+    "print(\"forwards types vs labels\")\n",
     "display(fw_log_odds_df.T)\n",
-    "print('--------------------------')\n",
-    "print('backwards types vs labels')\n",
+    "print(\"--------------------------\")\n",
+    "print(\"backwards types vs labels\")\n",
     "display(bk_log_odds_df.T)"
    ]
   },
   {
@@ -1620,14 +1636,17 @@
     "    s = np.sqrt(((n1 - 1) * s1 + (n2 - 1) * s2) / (n1 + n2 - 2))\n",
     "    u1, u2 = np.mean(d1), np.mean(d2)\n",
     "    return (u1 - u2) / s\n",
+    "\n",
+    "\n",
     "def get_pstars(p):\n",
-    "    if p < 0.001: \n",
-    "        return '***'\n",
+    "    if p < 0.001:\n",
+    "        return \"***\"\n",
     "    elif p < 0.01:\n",
-    "        return '**'\n",
+    "        return \"**\"\n",
    "    elif p < 0.05:\n",
-    "        return '*'\n",
-    "    else: return ''"
+    "        return \"*\"\n",
+    "    else:\n",
+    "        return \"\""
    ]
   },
   {
@@ -1638,17 +1657,16 @@
    },
    "outputs": [],
    "source": [
-    "stat_col = 'orn'\n",
+    "stat_col = \"orn\"\n",
     "entries = []\n",
     "for tag in tag_subset:\n",
-    "    has = df[df['has_' + 
tag]][stat_col]\n", - " hasnt = df[~df['has_' + tag]][stat_col]\n", - " entry = {'tag': tag, 'pval': stats.mannwhitneyu(has, hasnt)[1],\n", - " 'cd': cohend(has, hasnt)}\n", - " entry['ps'] = get_pstars(entry['pval'] * len(tag_subset))\n", + " has = df[df[\"has_\" + tag]][stat_col]\n", + " hasnt = df[~df[\"has_\" + tag]][stat_col]\n", + " entry = {\"tag\": tag, \"pval\": stats.mannwhitneyu(has, hasnt)[1], \"cd\": cohend(has, hasnt)}\n", + " entry[\"ps\"] = get_pstars(entry[\"pval\"] * len(tag_subset))\n", " entries.append(entry)\n", - "orn_stat_df = pd.DataFrame(entries).set_index('tag').sort_values('cd')\n", - "orn_stat_df = orn_stat_df[np.abs(orn_stat_df.cd) >= .1]" + "orn_stat_df = pd.DataFrame(entries).set_index(\"tag\").sort_values(\"cd\")\n", + "orn_stat_df = orn_stat_df[np.abs(orn_stat_df.cd) >= 0.1]" ] }, { @@ -1659,17 +1677,16 @@ }, "outputs": [], "source": [ - "stat_col = 'shift'\n", + "stat_col = \"shift\"\n", "entries = []\n", "for tag in tag_subset:\n", - " has = df[df['has_' + tag]][stat_col]\n", - " hasnt = df[~df['has_' + tag]][stat_col]\n", - " entry = {'tag': tag, 'pval': stats.mannwhitneyu(has, hasnt)[1],\n", - " 'cd': cohend(has, hasnt)}\n", - " entry['ps'] = get_pstars(entry['pval'] * len(tag_subset))\n", + " has = df[df[\"has_\" + tag]][stat_col]\n", + " hasnt = df[~df[\"has_\" + tag]][stat_col]\n", + " entry = {\"tag\": tag, \"pval\": stats.mannwhitneyu(has, hasnt)[1], \"cd\": cohend(has, hasnt)}\n", + " entry[\"ps\"] = get_pstars(entry[\"pval\"] * len(tag_subset))\n", " entries.append(entry)\n", - "shift_stat_df = pd.DataFrame(entries).set_index('tag').sort_values('cd')\n", - "shift_stat_df = shift_stat_df[np.abs(shift_stat_df.cd) >= .1]" + "shift_stat_df = pd.DataFrame(entries).set_index(\"tag\").sort_values(\"cd\")\n", + "shift_stat_df = shift_stat_df[np.abs(shift_stat_df.cd) >= 0.1]" ] }, { @@ -1840,10 +1857,10 @@ } ], "source": [ - "print('orientation vs labels')\n", + "print(\"orientation vs labels\")\n", "display(orn_stat_df.T)\n", - "print('--------------------------')\n", - "print('shift vs labels')\n", + "print(\"--------------------------\")\n", + "print(\"shift vs labels\")\n", "display(shift_stat_df.T)" ] }, @@ -1889,7 +1906,7 @@ }, "outputs": [], "source": [ - "DUAL_MODEL_PATH = os.path.join(SW_CORPUS_PATH, 'dual_model')" + "DUAL_MODEL_PATH = os.path.join(SW_CORPUS_PATH, \"dual_model\")" ] }, { @@ -1942,11 +1959,17 @@ }, "outputs": [], "source": [ - "dual_model_new = DualContextWrapper(context_fields=['reply_to','next_id'], output_prefixes=['bk_new','fw_new'],\n", - " vect_field='col_normed_tfidf', context_vect_field='col_normed_tfidf', \n", - " wrapper_output_prefix='new',\n", - " n_svd_dims=15, n_clusters=2,\n", - " random_state=1000, cluster_random_state=1000)" + "dual_model_new = DualContextWrapper(\n", + " context_fields=[\"reply_to\", \"next_id\"],\n", + " output_prefixes=[\"bk_new\", \"fw_new\"],\n", + " vect_field=\"col_normed_tfidf\",\n", + " context_vect_field=\"col_normed_tfidf\",\n", + " wrapper_output_prefix=\"new\",\n", + " n_svd_dims=15,\n", + " n_clusters=2,\n", + " random_state=1000,\n", + " cluster_random_state=1000,\n", + ")" ] }, { @@ -1957,7 +1980,7 @@ }, "outputs": [], "source": [ - "dual_model_new.load(DUAL_MODEL_PATH, model_dirs=['bk','fw'])" + "dual_model_new.load(DUAL_MODEL_PATH, model_dirs=[\"bk\", \"fw\"])" ] }, { @@ -1975,7 +1998,9 @@ }, "outputs": [], "source": [ - "_ = dual_model_new.transform(sw_corpus, selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>=5)" + "_ = dual_model_new.transform(\n", + " sw_corpus, 
selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5\n", + ")" ] }, { @@ -2015,7 +2040,7 @@ } ], "source": [ - "np.allclose(sw_corpus.get_vectors('bk_new_repr'), sw_corpus.get_vectors('bk_repr'))" + "np.allclose(sw_corpus.get_vectors(\"bk_new_repr\"), sw_corpus.get_vectors(\"bk_repr\"))" ] }, { @@ -2035,7 +2060,7 @@ } ], "source": [ - "np.allclose(sw_corpus.get_vectors('fw_new_repr'), sw_corpus.get_vectors('fw_repr'))" + "np.allclose(sw_corpus.get_vectors(\"fw_new_repr\"), sw_corpus.get_vectors(\"fw_repr\"))" ] }, { @@ -2046,9 +2071,11 @@ }, "outputs": [], "source": [ - "for ut in sw_corpus.iter_utterances(selector=lambda x: x.meta.get('col_normed_tfidf__n_feats',0)>=5):\n", - " assert ut.meta['orn'] == ut.meta['new_orn']\n", - " assert ut.meta['shift'] == ut.meta['new_shift']" + "for ut in sw_corpus.iter_utterances(\n", + " selector=lambda x: x.meta.get(\"col_normed_tfidf__n_feats\", 0) >= 5\n", + "):\n", + " assert ut.meta[\"orn\"] == ut.meta[\"new_orn\"]\n", + " assert ut.meta[\"shift\"] == ut.meta[\"new_shift\"]" ] }, { @@ -2093,7 +2120,7 @@ "source": [ "# see `demo_text_pipelines.py` in this demo's directory for details\n", "# in short, this pipeline will either output the `alpha_text` metadata field\n", - "# of an utterance, or write the utterance's `text` attribute into the `alpha_text` \n", + "# of an utterance, or write the utterance's `text` attribute into the `alpha_text`\n", "# metadata field\n", "from demo_text_pipelines import switchboard_text_pipeline" ] @@ -2119,13 +2146,18 @@ }, "outputs": [], "source": [ - "pipe_obj = DualContextPipeline(context_fields=['reply_to','next_id'], \n", - " output_prefixes=['bk','fw'],\n", - " text_field='alpha_text', text_pipe=switchboard_text_pipeline(), \n", - " tfidf_params={'binary': True, 'vocabulary': vocab}, \n", - " min_terms=5,\n", - " n_svd_dims=15, n_clusters=2,\n", - " random_state=1000, cluster_random_state=1000)" + "pipe_obj = DualContextPipeline(\n", + " context_fields=[\"reply_to\", \"next_id\"],\n", + " output_prefixes=[\"bk\", \"fw\"],\n", + " text_field=\"alpha_text\",\n", + " text_pipe=switchboard_text_pipeline(),\n", + " tfidf_params={\"binary\": True, \"vocabulary\": vocab},\n", + " min_terms=5,\n", + " n_svd_dims=15,\n", + " n_clusters=2,\n", + " random_state=1000,\n", + " cluster_random_state=1000,\n", + ")" ] }, { @@ -2162,7 +2194,7 @@ }, "outputs": [], "source": [ - "eg_ut_new = pipe_obj.transform_utterance('How old were you when you left ?')" + "eg_ut_new = pipe_obj.transform_utterance(\"How old were you when you left ?\")" ] }, { @@ -2183,8 +2215,8 @@ ], "source": [ "# note these attributes have the exact same values as those of eg_ut, computed above\n", - "print('shift:', eg_ut_new.meta['shift'])\n", - "print('orientation:', eg_ut_new.meta['orn'])" + "print(\"shift:\", eg_ut_new.meta[\"shift\"])\n", + "print(\"orientation:\", eg_ut_new.meta[\"orn\"])" ] }, { diff --git a/convokit/expected_context_framework/demos/wiki_awry_demo.ipynb b/convokit/expected_context_framework/demos/wiki_awry_demo.ipynb index 1f06b867..504e32e4 100644 --- a/convokit/expected_context_framework/demos/wiki_awry_demo.ipynb +++ b/convokit/expected_context_framework/demos/wiki_awry_demo.ipynb @@ -32,7 +32,8 @@ "outputs": [], "source": [ "import warnings\n", - "warnings.filterwarnings('ignore')" + "\n", + "warnings.filterwarnings(\"ignore\")" ] }, { @@ -85,7 +86,7 @@ }, "outputs": [], "source": [ - "# OPTION 1: DOWNLOAD CORPUS \n", + "# OPTION 1: DOWNLOAD CORPUS\n", "# UNCOMMENT THESE LINES TO DOWNLOAD CORPUS\n", "# DATA_DIR = 
''\n", "# WIKI_CORPUS_PATH = download('wiki-corpus', data_dir=DATA_DIR)\n", @@ -140,7 +141,7 @@ }, "outputs": [], "source": [ - "wiki_corpus.load_info('utterance',['arcs_censored'])" + "wiki_corpus.load_info(\"utterance\", [\"arcs_censored\"])" ] }, { @@ -159,8 +160,10 @@ "outputs": [], "source": [ "from convokit.text_processing import TextProcessor\n", - "join_arcs = TextProcessor(input_field='arcs_censored', output_field='arcs',\n", - " proc_fn=lambda sents: '\\n'.join(sents))\n", + "\n", + "join_arcs = TextProcessor(\n", + " input_field=\"arcs_censored\", output_field=\"arcs\", proc_fn=lambda sents: \"\\n\".join(sents)\n", + ")\n", "wiki_corpus = join_arcs.transform(wiki_corpus)" ] }, @@ -180,7 +183,7 @@ "outputs": [], "source": [ "for ut in wiki_corpus.iter_utterances(selector=lambda x: x.reply_to is not None):\n", - " wiki_corpus.get_utterance(ut.reply_to).meta['next_id'] = ut.id" + " wiki_corpus.get_utterance(ut.reply_to).meta[\"next_id\"] = ut.id" ] }, { @@ -207,7 +210,10 @@ }, "outputs": [], "source": [ - "from convokit.expected_context_framework import ColNormedTfidfTransformer, ExpectedContextModelTransformer" + "from convokit.expected_context_framework import (\n", + " ColNormedTfidfTransformer,\n", + " ExpectedContextModelTransformer,\n", + ")" ] }, { @@ -229,11 +235,15 @@ }, "outputs": [], "source": [ - "first_tfidf_obj = ColNormedTfidfTransformer(input_field='arcs', output_field='first_tfidf', binary=True, min_df=50)\n", - "_ = first_tfidf_obj.fit(wiki_corpus, selector=lambda x: x.meta.get('next_id',None) is not None)\n", + "first_tfidf_obj = ColNormedTfidfTransformer(\n", + " input_field=\"arcs\", output_field=\"first_tfidf\", binary=True, min_df=50\n", + ")\n", + "_ = first_tfidf_obj.fit(wiki_corpus, selector=lambda x: x.meta.get(\"next_id\", None) is not None)\n", "_ = first_tfidf_obj.transform(wiki_corpus)\n", "\n", - "second_tfidf_obj = ColNormedTfidfTransformer(input_field='arcs', output_field='second_tfidf', binary=True, min_df=50)\n", + "second_tfidf_obj = ColNormedTfidfTransformer(\n", + " input_field=\"arcs\", output_field=\"second_tfidf\", binary=True, min_df=50\n", + ")\n", "_ = second_tfidf_obj.fit(wiki_corpus, selector=lambda x: x.reply_to is not None)\n", "_ = second_tfidf_obj.transform(wiki_corpus)" ] @@ -258,10 +268,16 @@ "outputs": [], "source": [ "ec_fw = ExpectedContextModelTransformer(\n", - " context_field='next_id', output_prefix='fw', \n", - " vect_field='first_tfidf', context_vect_field='second_tfidf', \n", - " n_svd_dims=25, n_clusters=6, cluster_on='terms',\n", - " random_state=1000, cluster_random_state=1000)" + " context_field=\"next_id\",\n", + " output_prefix=\"fw\",\n", + " vect_field=\"first_tfidf\",\n", + " context_vect_field=\"second_tfidf\",\n", + " n_svd_dims=25,\n", + " n_clusters=6,\n", + " cluster_on=\"terms\",\n", + " random_state=1000,\n", + " cluster_random_state=1000,\n", + ")" ] }, { @@ -279,10 +295,13 @@ }, "outputs": [], "source": [ - "ec_fw.fit(wiki_corpus, selector=lambda x: (x.meta.get('first_tfidf__n_feats',0)>=1)\n", - " and (x.meta.get('next_id',None) is not None), \n", - " context_selector=lambda x: (x.meta.get('second_tfidf__n_feats',0)>= 1)\n", - " and (x.reply_to is not None))" + "ec_fw.fit(\n", + " wiki_corpus,\n", + " selector=lambda x: (x.meta.get(\"first_tfidf__n_feats\", 0) >= 1)\n", + " and (x.meta.get(\"next_id\", None) is not None),\n", + " context_selector=lambda x: (x.meta.get(\"second_tfidf__n_feats\", 0) >= 1)\n", + " and (x.reply_to is not None),\n", + ")" ] }, { @@ -664,7 +683,7 @@ } ], "source": [ - 
"ec_fw.print_clusters(k=10,corpus=wiki_corpus,max_chars=200)" + "ec_fw.print_clusters(k=10, corpus=wiki_corpus, max_chars=200)" ] }, { @@ -689,9 +708,9 @@ }, "outputs": [], "source": [ - "ec_fw.set_cluster_names(['casual', 'coordination', \n", - " 'procedures', 'contention',\n", - " 'editing', 'moderation'])" + "ec_fw.set_cluster_names(\n", + " [\"casual\", \"coordination\", \"procedures\", \"contention\", \"editing\", \"moderation\"]\n", + ")" ] }, { @@ -818,7 +837,7 @@ }, "outputs": [], "source": [ - "# OPTION 1: DOWNLOAD CORPUS \n", + "# OPTION 1: DOWNLOAD CORPUS\n", "# UNCOMMENT THESE LINES TO DOWNLOAD CORPUS\n", "# DATA_DIR = ''\n", "# AWRY_CORPUS_PATH = download('conversations-gone-awry-corpus', data_dir=DATA_DIR)\n", @@ -847,7 +866,9 @@ }, "outputs": [], "source": [ - "awry_corpus = awry_corpus.filter_conversations_by(lambda convo: convo.meta['annotation_year'] == '2018')\n", + "awry_corpus = awry_corpus.filter_conversations_by(\n", + " lambda convo: convo.meta[\"annotation_year\"] == \"2018\"\n", + ")\n", "# here we filter to consider only the conversations from the original paper" ] }, @@ -885,7 +906,7 @@ }, "outputs": [], "source": [ - "awry_corpus.load_info('utterance',['parsed'])" + "awry_corpus.load_info(\"utterance\", [\"parsed\"])" ] }, { @@ -897,6 +918,7 @@ "outputs": [], "source": [ "from demo_text_pipelines import wiki_arc_pipeline\n", + "\n", "# see `demo_text_pipelines.py` in this demo's directory for details\n", "# in short, this pipeline will compute the dependency-parse arcs we use as input features,\n", "# but will skip over utterances for which these attributes already exist\n", @@ -955,10 +977,15 @@ }, "outputs": [], "source": [ - "cluster_assign_df = awry_corpus.get_attribute_table('utterance',['fw_clustering.cluster_id_'])\n", + "cluster_assign_df = awry_corpus.get_attribute_table(\"utterance\", [\"fw_clustering.cluster_id_\"])\n", "type_assignments = np.zeros((len(cluster_assign_df), 6))\n", - "type_assignments[np.arange(len(cluster_assign_df)),cluster_assign_df['fw_clustering.cluster_id_'].values.astype(int)] = 1\n", - "cluster_assign_df = pd.DataFrame(columns=np.arange(6), index=cluster_assign_df.index, data=type_assignments)\n", + "type_assignments[\n", + " np.arange(len(cluster_assign_df)),\n", + " cluster_assign_df[\"fw_clustering.cluster_id_\"].values.astype(int),\n", + "] = 1\n", + "cluster_assign_df = pd.DataFrame(\n", + " columns=np.arange(6), index=cluster_assign_df.index, data=type_assignments\n", + ")\n", "cluster_assign_df.columns = ec_fw.get_cluster_names()" ] }, @@ -1122,7 +1149,9 @@ " convo_ids.append(comment.root)\n", " timestamps.append(comment.timestamp)\n", " page_ids.append(conversation.meta[\"page_id\"])\n", - "comment_df = pd.DataFrame({\"conversation_id\": convo_ids, \"timestamp\": timestamps, \"page_id\": page_ids}, index=comment_ids)\n", + "comment_df = pd.DataFrame(\n", + " {\"conversation_id\": convo_ids, \"timestamp\": timestamps, \"page_id\": page_ids}, index=comment_ids\n", + ")\n", "\n", "# we'll do our construction using awry conversation ID's as the reference key\n", "awry_convo_ids = set()\n", @@ -1130,14 +1159,21 @@ "good_convo_map = {}\n", "page_id_map = {}\n", "for conversation in awry_corpus.iter_conversations():\n", - " if conversation.meta[\"conversation_has_personal_attack\"] and conversation.id not in awry_convo_ids:\n", + " if (\n", + " conversation.meta[\"conversation_has_personal_attack\"]\n", + " and conversation.id not in awry_convo_ids\n", + " ):\n", " awry_convo_ids.add(conversation.id)\n", " 
good_convo_map[conversation.id] = conversation.meta[\"pair_id\"]\n", " page_id_map[conversation.id] = conversation.meta[\"page_id\"]\n", "awry_convo_ids = list(awry_convo_ids)\n", - "pairs_df = pd.DataFrame({\"bad_conversation_id\": awry_convo_ids,\n", - " \"conversation_id\": [good_convo_map[cid] for cid in awry_convo_ids],\n", - " \"page_id\": [page_id_map[cid] for cid in awry_convo_ids]})\n", + "pairs_df = pd.DataFrame(\n", + " {\n", + " \"bad_conversation_id\": awry_convo_ids,\n", + " \"conversation_id\": [good_convo_map[cid] for cid in awry_convo_ids],\n", + " \"page_id\": [page_id_map[cid] for cid in awry_convo_ids],\n", + " }\n", + ")\n", "# finally, we will augment the pairs dataframe with the IDs of the first and second comment for both\n", "# the bad and good conversation. This will come in handy for constructing the feature matrix.\n", "first_ids = []\n", @@ -1146,14 +1182,22 @@ "second_ids_bad = []\n", "for row in pairs_df.itertuples():\n", " # \"first two\" is defined in terms of time of posting\n", - " comments_sorted = comment_df[comment_df.conversation_id==row.conversation_id].sort_values(by=\"timestamp\")\n", + " comments_sorted = comment_df[comment_df.conversation_id == row.conversation_id].sort_values(\n", + " by=\"timestamp\"\n", + " )\n", " first_ids.append(comments_sorted.iloc[0].name)\n", " second_ids.append(comments_sorted.iloc[1].name)\n", - " comments_sorted_bad = comment_df[comment_df.conversation_id==row.bad_conversation_id].sort_values(by=\"timestamp\")\n", + " comments_sorted_bad = comment_df[\n", + " comment_df.conversation_id == row.bad_conversation_id\n", + " ].sort_values(by=\"timestamp\")\n", " first_ids_bad.append(comments_sorted_bad.iloc[0].name)\n", " second_ids_bad.append(comments_sorted_bad.iloc[1].name)\n", - "pairs_df = pairs_df.assign(first_id=first_ids, second_id=second_ids, \n", - " bad_first_id=first_ids_bad, bad_second_id=second_ids_bad)" + "pairs_df = pairs_df.assign(\n", + " first_id=first_ids,\n", + " second_id=second_ids,\n", + " bad_first_id=first_ids_bad,\n", + " bad_second_id=second_ids_bad,\n", + ")" ] }, { @@ -1164,11 +1208,19 @@ }, "outputs": [], "source": [ - "tox_first_comment_features =pairs_df[['bad_first_id']].join(cluster_assign_df, how='left', on='bad_first_id')[cluster_assign_df.columns]\n", - "ntox_first_comment_features =pairs_df[['first_id']].join(cluster_assign_df, how='left', on='first_id')[cluster_assign_df.columns]\n", + "tox_first_comment_features = pairs_df[[\"bad_first_id\"]].join(\n", + " cluster_assign_df, how=\"left\", on=\"bad_first_id\"\n", + ")[cluster_assign_df.columns]\n", + "ntox_first_comment_features = pairs_df[[\"first_id\"]].join(\n", + " cluster_assign_df, how=\"left\", on=\"first_id\"\n", + ")[cluster_assign_df.columns]\n", "\n", - "tox_second_comment_features =pairs_df[['bad_second_id']].join(cluster_assign_df, how='left', on='bad_second_id')[cluster_assign_df.columns]\n", - "ntox_second_comment_features =pairs_df[['second_id']].join(cluster_assign_df, how='left', on='second_id')[cluster_assign_df.columns]" + "tox_second_comment_features = pairs_df[[\"bad_second_id\"]].join(\n", + " cluster_assign_df, how=\"left\", on=\"bad_second_id\"\n", + ")[cluster_assign_df.columns]\n", + "ntox_second_comment_features = pairs_df[[\"second_id\"]].join(\n", + " cluster_assign_df, how=\"left\", on=\"second_id\"\n", + ")[cluster_assign_df.columns]" ] }, { @@ -1207,36 +1259,51 @@ "outputs": [], "source": [ "def get_p_stars(x):\n", - " if x < .001: return '***'\n", - " elif x < .01: return '**'\n", - " elif x < .05: 
return '*'\n", - " else: return ''\n", - "def compare_tox(df_ntox, df_tox, min_n=0):\n", + " if x < 0.001:\n", + " return \"***\"\n", + " elif x < 0.01:\n", + " return \"**\"\n", + " elif x < 0.05:\n", + " return \"*\"\n", + " else:\n", + " return \"\"\n", + "\n", + "\n", + "def compare_tox(df_ntox, df_tox, min_n=0):\n", " cols = df_ntox.columns\n", - " num_feats_in_tox = df_tox[cols].sum().astype(int).rename('num_feat_tox')\n", - " num_nfeats_in_tox = (1 - df_tox[cols]).sum().astype(int).rename('num_nfeat_tox')\n", - " num_feats_in_ntox = df_ntox[cols].sum().astype(int).rename('num_feat_ntox')\n", - " num_nfeats_in_ntox = (1 - df_ntox[cols]).sum().astype(int).rename('num_nfeat_ntox')\n", - " prop_tox = df_tox[cols].mean().rename('prop_tox')\n", - " ref_prop_ntox = df_ntox[cols].mean().rename('prop_ntox')\n", + " num_feats_in_tox = df_tox[cols].sum().astype(int).rename(\"num_feat_tox\")\n", + " num_nfeats_in_tox = (1 - df_tox[cols]).sum().astype(int).rename(\"num_nfeat_tox\")\n", + " num_feats_in_ntox = df_ntox[cols].sum().astype(int).rename(\"num_feat_ntox\")\n", + " num_nfeats_in_ntox = (1 - df_ntox[cols]).sum().astype(int).rename(\"num_nfeat_ntox\")\n", + " prop_tox = df_tox[cols].mean().rename(\"prop_tox\")\n", + " ref_prop_ntox = df_ntox[cols].mean().rename(\"prop_ntox\")\n", " n_tox = len(df_tox)\n", - " df = pd.concat([\n", - " num_feats_in_tox, \n", - " num_nfeats_in_tox,\n", - " num_feats_in_ntox,\n", - " num_nfeats_in_ntox,\n", - " prop_tox,\n", - " ref_prop_ntox,\n", - " ], axis=1)\n", - " df['num_total'] = df.num_feat_tox + df.num_feat_ntox\n", - " df['log_odds'] = np.log(df.num_feat_tox) - np.log(df.num_nfeat_tox) \\\n", - " + np.log(df.num_nfeat_ntox) - np.log(df.num_feat_ntox)\n", - " df['abs_log_odds'] = np.abs(df.log_odds)\n", - " df['binom_p'] = df.apply(lambda x: stats.binom_test(x.num_feat_tox, n_tox, x.prop_ntox), axis=1)#*5\n", + " df = pd.concat(\n", + " [\n", + " num_feats_in_tox,\n", + " num_nfeats_in_tox,\n", + " num_feats_in_ntox,\n", + " num_nfeats_in_ntox,\n", + " prop_tox,\n", + " ref_prop_ntox,\n", + " ],\n", + " axis=1,\n", + " )\n", + " df[\"num_total\"] = df.num_feat_tox + df.num_feat_ntox\n", + " df[\"log_odds\"] = (\n", + " np.log(df.num_feat_tox)\n", + " - np.log(df.num_nfeat_tox)\n", + " + np.log(df.num_nfeat_ntox)\n", + " - np.log(df.num_feat_ntox)\n", + " )\n", + " df[\"abs_log_odds\"] = np.abs(df.log_odds)\n", + " df[\"binom_p\"] = df.apply(\n", + " lambda x: stats.binom_test(x.num_feat_tox, n_tox, x.prop_ntox), axis=1\n", + " ) # *5\n", " df = df[df.num_total >= min_n]\n", - " df['p'] = df['binom_p'].apply(lambda x: '%.3f' % x)\n", - " df['pstars'] = df['binom_p'].apply(get_p_stars)\n", - " return df.sort_values('log_odds', ascending=False)" + " df[\"p\"] = df[\"binom_p\"].apply(lambda x: \"%.3f\" % x)\n", + " df[\"pstars\"] = df[\"binom_p\"].apply(get_p_stars)\n", + " return df.sort_values(\"log_odds\", ascending=False)" ] }, { @@ -1276,6 +1343,7 @@ "outputs": [], "source": [ "from matplotlib import pyplot as plt\n", + "\n", "%matplotlib inline" ] }, @@ -1287,84 +1355,104 @@ }, "outputs": [], "source": [ - "# we are now ready to plot these comparisons. the following (rather intimidating) helper function \n", + "# we are now ready to plot these comparisons. 
the following (rather intimidating) helper function\n", "# produces a nicely-formatted plot:\n", - "def draw_figure(ax, first_cmp, second_cmp, title='', prompt_types=6, min_log_odds=.2, min_freq=50,xlim=.85):\n", + "def draw_figure(\n", + " ax, first_cmp, second_cmp, title=\"\", prompt_types=6, min_log_odds=0.2, min_freq=50, xlim=0.85\n", + "):\n", "\n", " # selecting and sorting the features to plot, given minimum effect sizes and statistical significance\n", - " frequent_feats = first_cmp[first_cmp.num_total >= min_freq].index.union(second_cmp[second_cmp.num_total >= min_freq].index)\n", - " lrg_effect_feats = first_cmp[(first_cmp.abs_log_odds >= .2)\n", - " & (first_cmp.binom_p < .05)].index.union(second_cmp[(second_cmp.abs_log_odds >= .2)\n", - " & (second_cmp.binom_p < .05)].index)\n", - "# feats_to_include = frequent_feats.intersection(lrg_effect_feats)\n", + " frequent_feats = first_cmp[first_cmp.num_total >= min_freq].index.union(\n", + " second_cmp[second_cmp.num_total >= min_freq].index\n", + " )\n", + " lrg_effect_feats = first_cmp[\n", + " (first_cmp.abs_log_odds >= 0.2) & (first_cmp.binom_p < 0.05)\n", + " ].index.union(second_cmp[(second_cmp.abs_log_odds >= 0.2) & (second_cmp.binom_p < 0.05)].index)\n", + " # feats_to_include = frequent_feats.intersection(lrg_effect_feats)\n", " feats_to_include = first_cmp.index\n", " feat_order = sorted(feats_to_include, key=lambda x: first_cmp.loc[x].log_odds, reverse=True)\n", "\n", " # parameters determining the look of the figure\n", - " colors = ['blue', 'grey']\n", - " shapes = ['^', 's'] \n", - " eps = .02\n", - " star_eps = .035\n", + " colors = [\"blue\", \"grey\"]\n", + " shapes = [\"^\", \"s\"]\n", + " eps = 0.02\n", + " star_eps = 0.035\n", " xlim = xlim\n", - " min_log = .2\n", + " min_log = 0.2\n", " gap_prop = 2\n", " label_size = 14\n", - " title_size=18\n", + " title_size = 18\n", " radius = 256\n", " features = feat_order\n", " ax.invert_yaxis()\n", - " ax.plot([0,0], [0, len(features)/gap_prop], color='black')\n", - " \n", - " # for each figure we plot the point according to effect size in the first and second comment, \n", + " ax.plot([0, 0], [0, len(features) / gap_prop], color=\"black\")\n", + "\n", + " # for each figure we plot the point according to effect size in the first and second comment,\n", " # and add axis labels denoting statistical significance\n", " yticks = []\n", " yticklabels = []\n", " for f_idx, feat in enumerate(features):\n", - " curr_y = (f_idx + .5)/gap_prop\n", + " curr_y = (f_idx + 0.5) / gap_prop\n", " yticks.append(curr_y)\n", " try:\n", - " \n", + "\n", " first_p = first_cmp.loc[feat].binom_p\n", - " second_p = second_cmp.loc[feat].binom_p \n", + " second_p = second_cmp.loc[feat].binom_p\n", " if first_cmp.loc[feat].abs_log_odds < min_log:\n", " first_face = \"white\"\n", " elif first_p >= 0.05:\n", - " first_face = 'white'\n", + " first_face = \"white\"\n", " else:\n", " first_face = colors[0]\n", " if second_cmp.loc[feat].abs_log_odds < min_log:\n", " second_face = \"white\"\n", " elif second_p >= 0.05:\n", - " second_face = 'white'\n", + " second_face = \"white\"\n", " else:\n", " second_face = colors[1]\n", - " ax.plot([-1 * xlim, xlim], [curr_y, curr_y], '--', color='grey', zorder=0, linewidth=.5)\n", - " \n", - " ax.scatter([first_cmp.loc[feat].log_odds], [curr_y + eps], s=radius, edgecolor=colors[0], marker=shapes[0],\n", - " zorder=20, facecolors=first_face)\n", - " ax.scatter([second_cmp.loc[feat].log_odds], [curr_y + eps], s=radius, edgecolor=colors[1], marker=shapes[1], \n", - " 
zorder=10, facecolors=second_face)\n", - " \n", + " ax.plot(\n", + " [-1 * xlim, xlim], [curr_y, curr_y], \"--\", color=\"grey\", zorder=0, linewidth=0.5\n", + " )\n", + "\n", + " ax.scatter(\n", + " [first_cmp.loc[feat].log_odds],\n", + " [curr_y + eps],\n", + " s=radius,\n", + " edgecolor=colors[0],\n", + " marker=shapes[0],\n", + " zorder=20,\n", + " facecolors=first_face,\n", + " )\n", + " ax.scatter(\n", + " [second_cmp.loc[feat].log_odds],\n", + " [curr_y + eps],\n", + " s=radius,\n", + " edgecolor=colors[1],\n", + " marker=shapes[1],\n", + " zorder=10,\n", + " facecolors=second_face,\n", + " )\n", + "\n", " first_pstr_len = len(get_p_stars(first_p))\n", " second_pstr_len = len(get_p_stars(second_p))\n", - " p_str = np.array([' '] * 8)\n", + " p_str = np.array([\" \"] * 8)\n", " if first_pstr_len > 0:\n", - " p_str[:first_pstr_len] = '*'\n", + " p_str[:first_pstr_len] = \"*\"\n", " if second_pstr_len > 0:\n", - " p_str[-second_pstr_len:] = '⁺'\n", - " \n", - " feat_str = str(feat) + '\\n' + ''.join(p_str)\n", + " p_str[-second_pstr_len:] = \"⁺\"\n", + "\n", + " feat_str = str(feat) + \"\\n\" + \"\".join(p_str)\n", " yticklabels.append(feat_str)\n", " except Exception as e:\n", - " yticklabels.append('')\n", - " \n", + " yticklabels.append(\"\")\n", + "\n", " # add the axis labels\n", - " ax.set_xlabel('log-odds ratio', fontsize=28)\n", - " ax.set_xticks([-xlim-.05, -.5, 0, .5, xlim])\n", - " ax.set_xticklabels(['on-track', -.5, 0, .5, 'awry'], fontsize=24)\n", + " ax.set_xlabel(\"log-odds ratio\", fontsize=28)\n", + " ax.set_xticks([-xlim - 0.05, -0.5, 0, 0.5, xlim])\n", + " ax.set_xticklabels([\"on-track\", -0.5, 0, 0.5, \"awry\"], fontsize=24)\n", " ax.set_yticks(yticks)\n", " ax.set_yticklabels(yticklabels, fontsize=32)\n", - " ax.tick_params(axis='both', which='both', bottom='off', top='off',left='off')\n", + " ax.tick_params(axis=\"both\", which=\"both\", bottom=\"off\", top=\"off\", left=\"off\")\n", " return feat_order" ] }, @@ -1385,8 +1473,8 @@ } ], "source": [ - "f, ax = plt.subplots(1,1, figsize=(10,10))\n", - "_ = draw_figure(ax, first_comparisons, second_comparisons, '')" + "f, ax = plt.subplots(1, 1, figsize=(10, 10))\n", + "_ = draw_figure(ax, first_comparisons, second_comparisons, \"\")" ] }, { @@ -1443,13 +1531,20 @@ }, "outputs": [], "source": [ - "fw_pipe = ExpectedContextModelPipeline(context_field='next_id', output_prefix='fw',\n", - " text_field='arcs', share_tfidf_models=False,\n", - " text_pipe=wiki_arc_pipeline(), \n", - " tfidf_params={'binary': True, 'min_df': 50}, \n", - " min_terms=1,\n", - " n_svd_dims=25, n_clusters=6, cluster_on='terms',\n", - " random_state=1000, cluster_random_state=1000)" + "fw_pipe = ExpectedContextModelPipeline(\n", + " context_field=\"next_id\",\n", + " output_prefix=\"fw\",\n", + " text_field=\"arcs\",\n", + " share_tfidf_models=False,\n", + " text_pipe=wiki_arc_pipeline(),\n", + " tfidf_params={\"binary\": True, \"min_df\": 50},\n", + " min_terms=1,\n", + " n_svd_dims=25,\n", + " n_clusters=6,\n", + " cluster_on=\"terms\",\n", + " random_state=1000,\n", + " cluster_random_state=1000,\n", + ")" ] }, { @@ -1460,9 +1555,11 @@ }, "outputs": [], "source": [ - "fw_pipe.fit(wiki_corpus,\n", - " selector=lambda x: x.meta.get('next_id',None) is not None,\n", - " context_selector=lambda x: x.reply_to is not None)" + "fw_pipe.fit(\n", + " wiki_corpus,\n", + " selector=lambda x: x.meta.get(\"next_id\", None) is not None,\n", + " context_selector=lambda x: x.reply_to is not None,\n", + ")" ] }, { @@ -1676,9 +1773,9 @@ }, "outputs": [], 
"source": [ - "fw_pipe.set_cluster_names(['casual', 'coordination', \n", - " 'procedures', 'contention',\n", - " 'editing', 'moderation'])" + "fw_pipe.set_cluster_names(\n", + " [\"casual\", \"coordination\", \"procedures\", \"contention\", \"editing\", \"moderation\"]\n", + ")" ] }, { @@ -1696,7 +1793,7 @@ }, "outputs": [], "source": [ - "new_ut = fw_pipe.transform_utterance('Let me help you out with that')" + "new_ut = fw_pipe.transform_utterance(\"Let me help you out with that\")" ] }, { @@ -1713,7 +1810,7 @@ } ], "source": [ - "print('type:', new_ut.meta['fw_clustering.cluster'])" + "print(\"type:\", new_ut.meta[\"fw_clustering.cluster\"])" ] }, { @@ -1765,7 +1862,7 @@ "source": [ "# note that different versions of SpaCy may produce different outputs, since the\n", "# dependency parses may change from version to version\n", - "new_ut.meta['fw_repr']" + "new_ut.meta[\"fw_repr\"]" ] }, { diff --git a/convokit/fighting_words/demos/fightingwords_demo.ipynb b/convokit/fighting_words/demos/fightingwords_demo.ipynb index 2f4b37ff..733f26c2 100644 --- a/convokit/fighting_words/demos/fightingwords_demo.ipynb +++ b/convokit/fighting_words/demos/fightingwords_demo.ipynb @@ -24,7 +24,7 @@ } ], "source": [ - "corpus = Corpus(filename=download('reddit-corpus-small'))" + "corpus = Corpus(filename=download(\"reddit-corpus-small\"))" ] }, { @@ -48,7 +48,7 @@ } ], "source": [ - "fw = FightingWords(ngram_range=(1,1))" + "fw = FightingWords(ngram_range=(1, 1))" ] }, { @@ -78,8 +78,11 @@ } ], "source": [ - "fw.fit(corpus, class1_func=lambda utt: utt.meta['subreddit'] == 'Christianity', \n", - " class2_func=lambda utt: utt.meta['subreddit'] == \"atheism\",)" + "fw.fit(\n", + " corpus,\n", + " class1_func=lambda utt: utt.meta[\"subreddit\"] == \"Christianity\",\n", + " class2_func=lambda utt: utt.meta[\"subreddit\"] == \"atheism\",\n", + ")" ] }, { @@ -101,7 +104,7 @@ } ], "source": [ - "df = fw.summarize(corpus, plot=True, class1_name='r/Christianity', class2_name='r/atheism')" + "df = fw.summarize(corpus, plot=True, class1_name=\"r/Christianity\", class2_name=\"r/atheism\")" ] }, { @@ -938,7 +941,7 @@ } ], "source": [ - "fw.get_zscore('education')" + "fw.get_zscore(\"education\")" ] }, { @@ -958,7 +961,7 @@ } ], "source": [ - "fw.get_zscore('morals')" + "fw.get_zscore(\"morals\")" ] }, { @@ -978,7 +981,7 @@ } ], "source": [ - "fw.transform(corpus, config={'annot_method': 'top_k', 'top_k': 10})" + "fw.transform(corpus, config={\"annot_method\": \"top_k\", \"top_k\": 10})" ] }, { @@ -1038,10 +1041,10 @@ "source": [ "for utt in corpus.iter_utterances():\n", " if utt.meta[\"subreddit\"] in [\"atheism\", \"Christianity\"]:\n", - " if len(utt.meta['fighting_words_class1']) > 0:\n", - " print(utt.meta['subreddit'])\n", - " print(utt.meta['fighting_words_class1'])\n", - " print(utt.meta['fighting_words_class2'])\n", + " if len(utt.meta[\"fighting_words_class1\"]) > 0:\n", + " print(utt.meta[\"subreddit\"])\n", + " print(utt.meta[\"fighting_words_class1\"])\n", + " print(utt.meta[\"fighting_words_class2\"])\n", " print(utt.text)\n", " break" ] diff --git a/convokit/forecaster/CRAFT/demos/craft_demo.ipynb b/convokit/forecaster/CRAFT/demos/craft_demo.ipynb index 2514dd40..079d9fc8 100644 --- a/convokit/forecaster/CRAFT/demos/craft_demo.ipynb +++ b/convokit/forecaster/CRAFT/demos/craft_demo.ipynb @@ -45,7 +45,7 @@ } ], "source": [ - "corpus = Corpus(filename=download('conversations-gone-awry-corpus'))" + "corpus = Corpus(filename=download(\"conversations-gone-awry-corpus\"))" ] }, { @@ -125,16 +125,18 @@ "metadata": 
{}, "outputs": [], "source": [ - "MAX_LENGTH = 80 # this constant controls the maximum number of tokens to consider; it must be set to 80 since that's what CRAFT was trained one.\n", - "forecaster = Forecaster(forecaster_model = craft_model,\n", - " forecast_mode = \"past\",\n", - " convo_structure=\"linear\",\n", - " text_func = lambda utt: utt.meta[\"tokens\"][:(MAX_LENGTH-1)],\n", - " label_func = lambda utt: int(utt.meta['comment_has_personal_attack']),\n", - " forecast_attribute_name=\"prediction\", forecast_prob_attribute_name=\"pred_score\",\n", - " use_last_only = False,\n", - " skip_broken_convos=False\n", - " )" + "MAX_LENGTH = 80 # this constant controls the maximum number of tokens to consider; it must be set to 80 since that's what CRAFT was trained one.\n", + "forecaster = Forecaster(\n", + " forecaster_model=craft_model,\n", + " forecast_mode=\"past\",\n", + " convo_structure=\"linear\",\n", + " text_func=lambda utt: utt.meta[\"tokens\"][: (MAX_LENGTH - 1)],\n", + " label_func=lambda utt: int(utt.meta[\"comment_has_personal_attack\"]),\n", + " forecast_attribute_name=\"prediction\",\n", + " forecast_prob_attribute_name=\"pred_score\",\n", + " use_last_only=False,\n", + " skip_broken_convos=False,\n", + ")" ] }, { @@ -242,7 +244,11 @@ "# comments, but rather the \"section header\" (something akin to a conversation title in Wikipedia talk pages). Since they\n", "# are not real comments, we do not want to include them in forecasting. We use the ignore_utterances parameter to\n", "# specify this behavior.\n", - "forecaster.transform(corpus, selector=lambda convo: convo.meta[\"split\"] in ['test'], ignore_utterances=lambda utt: utt.meta['is_section_header'])" + "forecaster.transform(\n", + " corpus,\n", + " selector=lambda convo: convo.meta[\"split\"] in [\"test\"],\n", + " ignore_utterances=lambda utt: utt.meta[\"is_section_header\"],\n", + ")" ] }, { @@ -276,15 +282,15 @@ "metadata": {}, "outputs": [], "source": [ - "FORECAST_THRESH = 0.570617 # Threshold learned on a validation set. Try playing with this to see how it affects the precision-recall tradeoff!\n", + "FORECAST_THRESH = 0.570617 # Threshold learned on a validation set. 
Try playing with this to see how it affects the precision-recall tradeoff!\n", "preds = []\n", "labels = []\n", "# Iterate at a conversation level and consolidate predictions for each conversation\n", - "for convo in corpus.iter_conversations(selector=lambda c: c.meta['split'] == 'test'):\n", - " labels.append(int(convo.meta['conversation_has_personal_attack']))\n", + "for convo in corpus.iter_conversations(selector=lambda c: c.meta[\"split\"] == \"test\"):\n", + " labels.append(int(convo.meta[\"conversation_has_personal_attack\"]))\n", " prediction = 0\n", " for utt in convo.iter_utterances():\n", - " if utt.meta['pred_score'] is not None and utt.meta['pred_score'] > FORECAST_THRESH:\n", + " if utt.meta[\"pred_score\"] is not None and utt.meta[\"pred_score\"] > FORECAST_THRESH:\n", " prediction = 1\n", " preds.append(prediction)\n", "preds = np.asarray(preds)\n", @@ -307,9 +313,13 @@ "source": [ "# Compute accuracy, precision, recall, F1, and false positive rate\n", "acc = np.mean(preds == labels)\n", - "precision, recall, f1, _ = precision_recall_fscore_support(preds, labels, average='binary')\n", - "fpr = np.mean(preds[labels==0])\n", - "print(\"Accuracy = {:.2%}, Precision = {:.2%}, Recall = {:.2%}, FPR = {:.2%}, F1 = {:.2%}\".format(acc, precision, recall, fpr, f1))" + "precision, recall, f1, _ = precision_recall_fscore_support(preds, labels, average=\"binary\")\n", + "fpr = np.mean(preds[labels == 0])\n", + "print(\n", + " \"Accuracy = {:.2%}, Precision = {:.2%}, Recall = {:.2%}, FPR = {:.2%}, F1 = {:.2%}\".format(\n", + " acc, precision, recall, fpr, f1\n", + " )\n", + ")" ] }, { @@ -326,20 +336,24 @@ "metadata": {}, "outputs": [], "source": [ - "comments_until_derail = {} # store the \"number of comments until derailment\" metric for each conversation\n", + "comments_until_derail = (\n", + " {}\n", + ") # store the \"number of comments until derailment\" metric for each conversation\n", "\n", - "for convo in corpus.iter_conversations(selector=lambda c: c.meta['split'] == 'test' and c.meta['conversation_has_personal_attack']):\n", + "for convo in corpus.iter_conversations(\n", + " selector=lambda c: c.meta[\"split\"] == \"test\" and c.meta[\"conversation_has_personal_attack\"]\n", + "):\n", " # filter out the section header as usual\n", - " utts = [utt for utt in convo.iter_utterances() if not utt.meta['is_section_header']]\n", + " utts = [utt for utt in convo.iter_utterances() if not utt.meta[\"is_section_header\"]]\n", " # by construction, the last comment is the one with the personal attack\n", " derail_idx = len(utts) - 1\n", " # now scan the utterances in order until we find the first derailment prediction (if any)\n", " for idx in range(1, len(utts)):\n", - " if utts[idx].meta['pred_score'] > FORECAST_THRESH:\n", + " if utts[idx].meta[\"pred_score\"] > FORECAST_THRESH:\n", " # recall that the forecast_score meta field specifies what CRAFT thought this comment would look like BEFORE it\n", - " # saw this comment. So the actual CRAFT forecast is made during the previous comment; we account for this by \n", + " # saw this comment. 
So the actual CRAFT forecast is made during the previous comment; we account for this by\n", " # subtracting 1 from idx\n", - " comments_until_derail[convo.id] = derail_idx - (idx-1)\n", + " comments_until_derail[convo.id] = derail_idx - (idx - 1)\n", " break" ] }, @@ -364,12 +378,14 @@ "source": [ "# visualize the distribution of \"number of comments until derailment\" as a histogram (reproducing Figure 4 from the paper)\n", "comments_until_derail_vals = np.asarray(list(comments_until_derail.values()))\n", - "plt.rcParams['figure.figsize'] = (10.0, 5.0)\n", - "plt.rcParams['font.size'] = 24\n", - "plt.hist(comments_until_derail_vals, bins=range(1, np.max(comments_until_derail_vals)), density=True)\n", - "plt.xlim(1,10)\n", - "plt.xticks(np.arange(1,10)+0.5, np.arange(1,10))\n", - "plt.yticks(np.arange(0,0.25,0.05), np.arange(0,25,5))\n", + "plt.rcParams[\"figure.figsize\"] = (10.0, 5.0)\n", + "plt.rcParams[\"font.size\"] = 24\n", + "plt.hist(\n", + " comments_until_derail_vals, bins=range(1, np.max(comments_until_derail_vals)), density=True\n", + ")\n", + "plt.xlim(1, 10)\n", + "plt.xticks(np.arange(1, 10) + 0.5, np.arange(1, 10))\n", + "plt.yticks(np.arange(0, 0.25, 0.05), np.arange(0, 25, 5))\n", "plt.xlabel(\"Number of comments elapsed\")\n", "plt.ylabel(\"% of conversations\")\n", "plt.show()" diff --git a/convokit/forecaster/CRAFT/demos/craft_demo_new.ipynb b/convokit/forecaster/CRAFT/demos/craft_demo_new.ipynb index 2f86af97..43a7bfcf 100644 --- a/convokit/forecaster/CRAFT/demos/craft_demo_new.ipynb +++ b/convokit/forecaster/CRAFT/demos/craft_demo_new.ipynb @@ -84,15 +84,17 @@ "metadata": {}, "outputs": [], "source": [ - "forecaster = Forecaster(forecaster_model = craft_model,\n", - " forecast_mode = \"future\",\n", - " convo_structure=\"linear\",\n", - " text_func = lambda utt: utt.meta[\"tokens\"][:(MAX_LENGTH-1)],\n", - " label_func = lambda utt: int(utt.meta['comment_has_personal_attack']),\n", - " forecast_attribute_name=\"prediction\", forecast_prob_attribute_name=\"pred_score\",\n", - " use_last_only = True,\n", - " skip_broken_convos=False\n", - " )" + "forecaster = Forecaster(\n", + " forecaster_model=craft_model,\n", + " forecast_mode=\"future\",\n", + " convo_structure=\"linear\",\n", + " text_func=lambda utt: utt.meta[\"tokens\"][: (MAX_LENGTH - 1)],\n", + " label_func=lambda utt: int(utt.meta[\"comment_has_personal_attack\"]),\n", + " forecast_attribute_name=\"prediction\",\n", + " forecast_prob_attribute_name=\"pred_score\",\n", + " use_last_only=True,\n", + " skip_broken_convos=False,\n", + ")" ] }, { @@ -208,8 +210,11 @@ } ], "source": [ - "forecaster.transform(corpus, selector=lambda convo: convo.meta[\"split\"] == \"train\",\n", - " ignore_utterances=lambda utt: utt.meta[\"is_section_header\"])" + "forecaster.transform(\n", + " corpus,\n", + " selector=lambda convo: convo.meta[\"split\"] == \"train\",\n", + " ignore_utterances=lambda utt: utt.meta[\"is_section_header\"],\n", + ")" ] }, { diff --git a/convokit/forecaster/CRAFT/demos/craft_demo_original.ipynb b/convokit/forecaster/CRAFT/demos/craft_demo_original.ipynb index 85a91c26..cbe588db 100644 --- a/convokit/forecaster/CRAFT/demos/craft_demo_original.ipynb +++ b/convokit/forecaster/CRAFT/demos/craft_demo_original.ipynb @@ -40,6 +40,7 @@ "import itertools\n", "from urllib.request import urlretrieve\n", "from convokit import download, Corpus\n", + "\n", "%matplotlib inline" ] }, @@ -115,14 +116,20 @@ "\n", " def __init__(self, name, word2index=None, index2word=None):\n", " self.name = name\n", - " self.trimmed = 
False if not word2index else True # if a precomputed vocab is specified assume the user wants to use it as-is\n", + " self.trimmed = (\n", + " False if not word2index else True\n", + " ) # if a precomputed vocab is specified assume the user wants to use it as-is\n", " self.word2index = word2index if word2index else {\"UNK\": UNK_token}\n", " self.word2count = {}\n", - " self.index2word = index2word if index2word else {PAD_token: \"PAD\", SOS_token: \"SOS\", EOS_token: \"EOS\", UNK_token: \"UNK\"}\n", + " self.index2word = (\n", + " index2word\n", + " if index2word\n", + " else {PAD_token: \"PAD\", SOS_token: \"SOS\", EOS_token: \"EOS\", UNK_token: \"UNK\"}\n", + " )\n", " self.num_words = 4 if not index2word else len(index2word) # Count SOS, EOS, PAD, UNK\n", "\n", " def addSentence(self, sentence):\n", - " for word in sentence.split(' '):\n", + " for word in sentence.split(\" \"):\n", " self.addWord(word)\n", "\n", " def addWord(self, word):\n", @@ -146,19 +153,22 @@ " if v >= min_count:\n", " keep_words.append(k)\n", "\n", - " print('keep_words {} / {} = {:.4f}'.format(\n", - " len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)\n", - " ))\n", + " print(\n", + " \"keep_words {} / {} = {:.4f}\".format(\n", + " len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)\n", + " )\n", + " )\n", "\n", " # Reinitialize dictionaries\n", " self.word2index = {\"UNK\": UNK_token}\n", " self.word2count = {}\n", " self.index2word = {PAD_token: \"PAD\", SOS_token: \"SOS\", EOS_token: \"EOS\", UNK_token: \"UNK\"}\n", - " self.num_words = 4 # Count default tokens\n", + " self.num_words = 4 # Count default tokens\n", "\n", " for word in keep_words:\n", " self.addWord(word)\n", "\n", + "\n", "# Create a Voc object from precomputed data structures\n", "def loadPrecomputedVoc(corpus_name, word2index_url, index2word_url):\n", " # load the word-to-index lookup map\n", @@ -182,26 +192,26 @@ "source": [ "# Helper functions for preprocessing and tokenizing text\n", "\n", + "\n", "# Turn a Unicode string to plain ASCII, thanks to\n", "# https://stackoverflow.com/a/518232/2809427\n", "def unicodeToAscii(s):\n", - " return ''.join(\n", - " c for c in unicodedata.normalize('NFD', s)\n", - " if unicodedata.category(c) != 'Mn'\n", - " )\n", + " return \"\".join(c for c in unicodedata.normalize(\"NFD\", s) if unicodedata.category(c) != \"Mn\")\n", + "\n", "\n", "# Tokenize the string using NLTK\n", "def tokenize(text):\n", - " tokenizer = nltk.tokenize.RegexpTokenizer(pattern=r'\\w+|[^\\w\\s]')\n", + " tokenizer = nltk.tokenize.RegexpTokenizer(pattern=r\"\\w+|[^\\w\\s]\")\n", " # simplify the problem space by considering only ASCII data\n", " cleaned_text = unicodeToAscii(text.lower())\n", "\n", " # if the resulting string is empty, nothing else to do\n", " if not cleaned_text.strip():\n", " return []\n", - " \n", + "\n", " return tokenizer.tokenize(cleaned_text)\n", "\n", + "\n", "# Given a ConvoKit conversation, preprocess each utterance's text by tokenizing and truncating.\n", "# Returns the processed dialog entry where text has been replaced with a list of\n", "# tokens, each no longer than MAX_LENGTH - 1 (to leave space for the EOS token)\n", @@ -209,20 +219,27 @@ " processed = []\n", " for utterance in dialog.iter_utterances():\n", " # skip the section header, which does not contain conversational content\n", - " if utterance.meta['is_section_header']:\n", + " if utterance.meta[\"is_section_header\"]:\n", " continue\n", " tokens = tokenize(utterance.text)\n", " # 
replace out-of-vocabulary tokens\n", " for i in range(len(tokens)):\n", " if tokens[i] not in voc.word2index:\n", " tokens[i] = \"UNK\"\n", - " processed.append({\"tokens\": tokens, \"is_attack\": int(utterance.meta['comment_has_personal_attack']), \"id\": utterance.id})\n", + " processed.append(\n", + " {\n", + " \"tokens\": tokens,\n", + " \"is_attack\": int(utterance.meta[\"comment_has_personal_attack\"]),\n", + " \"id\": utterance.id,\n", + " }\n", + " )\n", " return processed\n", "\n", + "\n", "# Load context-reply pairs from the Corpus, optionally filtering to only conversations\n", "# from the specified split (train, val, or test).\n", "# Each conversation, which has N comments (not including the section header) will\n", - "# get converted into N-1 comment-reply pairs, one pair for each reply \n", + "# get converted into N-1 comment-reply pairs, one pair for each reply\n", "# (the first comment does not reply to anything).\n", "# Each comment-reply pair is a tuple consisting of the conversational context\n", "# (that is, all comments prior to the reply), the reply itself, the label (that\n", @@ -233,14 +250,14 @@ " pairs = []\n", " for convo in corpus.iter_conversations():\n", " # consider only conversations in the specified split of the data\n", - " if split is None or convo.meta['split'] == split:\n", + " if split is None or convo.meta[\"split\"] == split:\n", " dialog = processDialog(voc, convo)\n", " for idx in range(1, len(dialog)):\n", - " reply = dialog[idx][\"tokens\"][:(MAX_LENGTH-1)]\n", + " reply = dialog[idx][\"tokens\"][: (MAX_LENGTH - 1)]\n", " label = dialog[idx][\"is_attack\"]\n", " comment_id = dialog[idx][\"id\"]\n", " # gather as context all utterances preceding the reply\n", - " context = [u[\"tokens\"][:(MAX_LENGTH-1)] for u in dialog[:idx]]\n", + " context = [u[\"tokens\"][: (MAX_LENGTH - 1)] for u in dialog[:idx]]\n", " pairs.append((context, reply, label, comment_id))\n", " return pairs" ] @@ -257,12 +274,15 @@ "source": [ "# Helper functions for turning dialog and text sequences into tensors, and manipulating those tensors\n", "\n", + "\n", "def indexesFromSentence(voc, sentence):\n", " return [voc.word2index[word] for word in sentence] + [EOS_token]\n", "\n", + "\n", "def zeroPadding(l, fillvalue=PAD_token):\n", " return list(itertools.zip_longest(*l, fillvalue=fillvalue))\n", "\n", + "\n", "def binaryMatrix(l, value=PAD_token):\n", " m = []\n", " for i, seq in enumerate(l):\n", @@ -274,11 +294,14 @@ " m[i].append(1)\n", " return m\n", "\n", + "\n", "# Takes a batch of dialogs (lists of lists of tokens) and converts it into a\n", "# batch of utterances (lists of tokens) sorted by length, while keeping track of\n", "# the information needed to reconstruct the original batch of dialogs\n", "def dialogBatch2UtteranceBatch(dialog_batch):\n", - " utt_tuples = [] # will store tuples of (utterance, original position in batch, original position in dialog)\n", + " utt_tuples = (\n", + " []\n", + " ) # will store tuples of (utterance, original position in batch, original position in dialog)\n", " for batch_idx in range(len(dialog_batch)):\n", " dialog = dialog_batch[batch_idx]\n", " for dialog_idx in range(len(dialog)):\n", @@ -292,6 +315,7 @@ " dialog_indices = [u[2] for u in utt_tuples]\n", " return utt_batch, batch_indices, dialog_indices\n", "\n", + "\n", "# Returns padded input sequence tensor and lengths\n", "def inputVar(l, voc):\n", " indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]\n", @@ -300,6 +324,7 @@ " padVar = 
torch.LongTensor(padList)\n", " return padVar, lengths\n", "\n", + "\n", "# Returns padded target sequence tensor, padding mask, and max target length\n", "def outputVar(l, voc):\n", " indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]\n", @@ -310,6 +335,7 @@ " padVar = torch.LongTensor(padList)\n", " return padVar, mask, max_target_len\n", "\n", + "\n", "# Returns all items for a given batch of pairs\n", "def batch2TrainData(voc, pair_batch, already_sorted=False):\n", " if not already_sorted:\n", @@ -325,7 +351,19 @@ " inp, utt_lengths = inputVar(input_utterances, voc)\n", " output, mask, max_target_len = outputVar(output_batch, voc)\n", " label_batch = torch.FloatTensor(label_batch) if label_batch[0] is not None else None\n", - " return inp, dialog_lengths, utt_lengths, batch_indices, dialog_indices, label_batch, id_batch, output, mask, max_target_len\n", + " return (\n", + " inp,\n", + " dialog_lengths,\n", + " utt_lengths,\n", + " batch_indices,\n", + " dialog_indices,\n", + " label_batch,\n", + " id_batch,\n", + " output,\n", + " mask,\n", + " max_target_len,\n", + " )\n", + "\n", "\n", "def batchIterator(voc, source_data, batch_size, shuffle=True):\n", " cur_idx = 0\n", @@ -336,7 +374,7 @@ " cur_idx = 0\n", " if shuffle:\n", " random.shuffle(source_data)\n", - " batch = source_data[cur_idx:(cur_idx+batch_size)]\n", + " batch = source_data[cur_idx : (cur_idx + batch_size)]\n", " # the true batch size may be smaller than the given batch size if there is not enough data left\n", " true_batch_size = len(batch)\n", " # ensure that the dialogs in this batch are sorted by length, as expected by the padding module\n", @@ -346,7 +384,7 @@ " batch_labels = [x[2] for x in batch]\n", " # convert batch to tensors\n", " batch_tensors = batch2TrainData(voc, batch, already_sorted=True)\n", - " yield (batch_tensors, batch_dialogs, batch_labels, true_batch_size) \n", + " yield (batch_tensors, batch_dialogs, batch_labels, true_batch_size)\n", " cur_idx += batch_size" ] }, @@ -483,7 +521,9 @@ ], "source": [ "# Inspect the Voc object to make sure it loaded correctly\n", - "print(voc.num_words) # expected vocab size is 50004: it was built using a fixed vocab size of 50k plus 4 spots for special tokens PAD, SOS, EOS, and UNK.\n", + "print(\n", + " voc.num_words\n", + ") # expected vocab size is 50004: it was built using a fixed vocab size of 50k plus 4 spots for special tokens PAD, SOS, EOS, and UNK.\n", "print(list(voc.word2index.items())[:10])\n", "print(list(voc.index2word.items())[:10])" ] @@ -587,7 +627,7 @@ } ], "source": [ - "for token_list in uttid_to_test_pair['201082648.33321.33321'][0]:\n", + "for token_list in uttid_to_test_pair[\"201082648.33321.33321\"][0]:\n", " print(token_list)\n", " print()" ] @@ -616,6 +656,7 @@ "source": [ "class EncoderRNN(nn.Module):\n", " \"\"\"This module represents the utterance encoder component of CRAFT, responsible for creating vector representations of utterances\"\"\"\n", + "\n", " def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):\n", " super(EncoderRNN, self).__init__()\n", " self.n_layers = n_layers\n", @@ -624,8 +665,13 @@ "\n", " # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'\n", " # because our input size is a word embedding with number of features == hidden_size\n", - " self.gru = nn.GRU(hidden_size, hidden_size, n_layers,\n", - " dropout=(0 if n_layers == 1 else dropout), bidirectional=True)\n", + " self.gru = nn.GRU(\n", + " hidden_size,\n", + " hidden_size,\n", + " 
n_layers,\n", + " dropout=(0 if n_layers == 1 else dropout),\n", + " bidirectional=True,\n", + " )\n", "\n", " def forward(self, input_seq, input_lengths, hidden=None):\n", " # Convert word indexes to embeddings\n", @@ -637,21 +683,28 @@ " # Unpack padding\n", " outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)\n", " # Sum bidirectional GRU outputs\n", - " outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]\n", + " outputs = outputs[:, :, : self.hidden_size] + outputs[:, :, self.hidden_size :]\n", " # Return output and final hidden state\n", " return outputs, hidden\n", "\n", + "\n", "class ContextEncoderRNN(nn.Module):\n", " \"\"\"This module represents the context encoder component of CRAFT, responsible for creating an order-sensitive vector representation of conversation context\"\"\"\n", + "\n", " def __init__(self, hidden_size, n_layers=1, dropout=0):\n", " super(ContextEncoderRNN, self).__init__()\n", " self.n_layers = n_layers\n", " self.hidden_size = hidden_size\n", - " \n", + "\n", " # only unidirectional GRU for context encoding\n", - " self.gru = nn.GRU(hidden_size, hidden_size, n_layers,\n", - " dropout=(0 if n_layers == 1 else dropout), bidirectional=False)\n", - " \n", + " self.gru = nn.GRU(\n", + " hidden_size,\n", + " hidden_size,\n", + " n_layers,\n", + " dropout=(0 if n_layers == 1 else dropout),\n", + " bidirectional=False,\n", + " )\n", + "\n", " def forward(self, input_seq, input_lengths, hidden=None):\n", " # Pack padded batch of sequences for RNN module\n", " packed = torch.nn.utils.rnn.pack_padded_sequence(input_seq, input_lengths)\n", @@ -662,13 +715,15 @@ " # return output and final hidden state\n", " return outputs, hidden\n", "\n", + "\n", "class SingleTargetClf(nn.Module):\n", " \"\"\"This module represents the CRAFT classifier head, which takes the context encoding and uses it to make a forecast\"\"\"\n", + "\n", " def __init__(self, hidden_size, dropout=0.1):\n", " super(SingleTargetClf, self).__init__()\n", - " \n", + "\n", " self.hidden_size = hidden_size\n", - " \n", + "\n", " # initialize classifier\n", " self.layer1 = nn.Linear(hidden_size, hidden_size)\n", " self.layer1_act = nn.LeakyReLU()\n", @@ -676,7 +731,7 @@ " self.layer2_act = nn.LeakyReLU()\n", " self.clf = nn.Linear(hidden_size // 2, 1)\n", " self.dropout = nn.Dropout(p=dropout)\n", - " \n", + "\n", " def forward(self, encoder_outputs, encoder_input_lengths):\n", " # from stackoverflow (https://stackoverflow.com/questions/50856936/taking-the-last-state-from-bilstm-bigru-in-pytorch)\n", " # First we unsqueeze seqlengths two times so it has the same number of\n", @@ -684,11 +739,11 @@ " # (batch_size) -> (1, batch_size, 1)\n", " lengths = encoder_input_lengths.unsqueeze(0).unsqueeze(2)\n", " # Then we expand it accordingly\n", - " # (1, batch_size, 1) -> (1, batch_size, hidden_size) \n", + " # (1, batch_size, 1) -> (1, batch_size, hidden_size)\n", " lengths = lengths.expand((1, -1, encoder_outputs.size(2)))\n", "\n", " # take only the last state of the encoder for each batch\n", - " last_outputs = torch.gather(encoder_outputs, 0, lengths-1).squeeze()\n", + " last_outputs = torch.gather(encoder_outputs, 0, lengths - 1).squeeze()\n", " # forward pass through hidden layers\n", " layer1_out = self.layer1_act(self.layer1(self.dropout(last_outputs)))\n", " layer2_out = self.layer2_act(self.layer2(self.dropout(layer1_out)))\n", @@ -696,51 +751,70 @@ " logits = self.clf(self.dropout(layer2_out)).squeeze()\n", " return logits\n", "\n", + "\n", "class 
Predictor(nn.Module):\n", " \"\"\"This helper module encapsulates the CRAFT pipeline, defining the logic of passing an input through each consecutive sub-module.\"\"\"\n", + "\n", " def __init__(self, encoder, context_encoder, classifier):\n", " super(Predictor, self).__init__()\n", " self.encoder = encoder\n", " self.context_encoder = context_encoder\n", " self.classifier = classifier\n", - " \n", - " def forward(self, input_batch, dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, batch_size, max_length):\n", + "\n", + " def forward(\n", + " self,\n", + " input_batch,\n", + " dialog_lengths,\n", + " dialog_lengths_list,\n", + " utt_lengths,\n", + " batch_indices,\n", + " dialog_indices,\n", + " batch_size,\n", + " max_length,\n", + " ):\n", " # Forward input through encoder model\n", " _, utt_encoder_hidden = self.encoder(input_batch, utt_lengths)\n", - " \n", + "\n", " # Convert utterance encoder final states to batched dialogs for use by context encoder\n", - " context_encoder_input = makeContextEncoderInput(utt_encoder_hidden, dialog_lengths_list, batch_size, batch_indices, dialog_indices)\n", - " \n", + " context_encoder_input = makeContextEncoderInput(\n", + " utt_encoder_hidden, dialog_lengths_list, batch_size, batch_indices, dialog_indices\n", + " )\n", + "\n", " # Forward pass through context encoder\n", - " context_encoder_outputs, context_encoder_hidden = self.context_encoder(context_encoder_input, dialog_lengths)\n", - " \n", + " context_encoder_outputs, context_encoder_hidden = self.context_encoder(\n", + " context_encoder_input, dialog_lengths\n", + " )\n", + "\n", " # Forward pass through classifier to get prediction logits\n", " logits = self.classifier(context_encoder_outputs, dialog_lengths)\n", - " \n", + "\n", " # Apply sigmoid activation\n", " predictions = F.sigmoid(logits)\n", " return predictions\n", "\n", - "def makeContextEncoderInput(utt_encoder_hidden, dialog_lengths, batch_size, batch_indices, dialog_indices):\n", + "\n", + "def makeContextEncoderInput(\n", + " utt_encoder_hidden, dialog_lengths, batch_size, batch_indices, dialog_indices\n", + "):\n", " \"\"\"The utterance encoder takes in utterances in combined batches, with no knowledge of which ones go where in which conversation.\n", - " Its output is therefore also unordered. We correct this by using the information computed during tensor conversion to regroup\n", - " the utterance vectors into their proper conversational order.\"\"\"\n", + " Its output is therefore also unordered. 
We correct this by using the information computed during tensor conversion to regroup\n", + " the utterance vectors into their proper conversational order.\"\"\"\n", " # first, sum the forward and backward encoder states\n", - " utt_encoder_summed = utt_encoder_hidden[-2,:,:] + utt_encoder_hidden[-1,:,:]\n", + " utt_encoder_summed = utt_encoder_hidden[-2, :, :] + utt_encoder_hidden[-1, :, :]\n", " # we now have hidden state of shape [utterance_batch_size, hidden_size]\n", " # split it into a list of [hidden_size,] x utterance_batch_size\n", " last_states = [t.squeeze() for t in utt_encoder_summed.split(1, dim=0)]\n", - " \n", + "\n", " # create a placeholder list of tensors to group the states by source dialog\n", " states_dialog_batched = [[None for _ in range(dialog_lengths[i])] for i in range(batch_size)]\n", - " \n", + "\n", " # group the states by source dialog\n", " for hidden_state, batch_idx, dialog_idx in zip(last_states, batch_indices, dialog_indices):\n", " states_dialog_batched[batch_idx][dialog_idx] = hidden_state\n", - " \n", + "\n", " # stack each dialog into a tensor of shape [dialog_length, hidden_size]\n", " states_dialog_batched = [torch.stack(d) for d in states_dialog_batched]\n", - " \n", + "\n", " # finally, condense all the dialog tensors into a single zero-padded tensor\n", " # of shape [max_dialog_length, batch_size, hidden_size]\n", " return torch.nn.utils.rnn.pad_sequence(states_dialog_batched)" @@ -768,36 +842,77 @@ }, "outputs": [], "source": [ - "def _evaluate_batch(encoder, context_encoder, predictor, voc, input_batch, dialog_lengths, \n", - " dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, batch_size, device, max_length=MAX_LENGTH):\n", + "def _evaluate_batch(\n", + " encoder,\n", + " context_encoder,\n", + " predictor,\n", + " voc,\n", + " input_batch,\n", + " dialog_lengths,\n", + " dialog_lengths_list,\n", + " utt_lengths,\n", + " batch_indices,\n", + " dialog_indices,\n", + " batch_size,\n", + " device,\n", + " max_length=MAX_LENGTH,\n", + "):\n", " # Set device options\n", " input_batch = input_batch.to(device)\n", " dialog_lengths = dialog_lengths.to(device)\n", " utt_lengths = utt_lengths.to(device)\n", " # Predict future attack using predictor\n", - " scores = predictor(input_batch, dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices, batch_size, max_length)\n", + " scores = predictor(\n", + " input_batch,\n", + " dialog_lengths,\n", + " dialog_lengths_list,\n", + " utt_lengths,\n", + " batch_indices,\n", + " dialog_indices,\n", + " batch_size,\n", + " max_length,\n", + " )\n", " predictions = (scores > 0.5).float()\n", " return predictions, scores\n", "\n", + "\n", "def _evaluate_dataset(dataset, encoder, context_encoder, predictor, voc, batch_size, device):\n", " # create a batch iterator for the given data\n", " batch_iterator = batchIterator(voc, dataset, batch_size, shuffle=False)\n", " # find out how many iterations we will need to cover the whole dataset\n", " n_iters = len(dataset) // batch_size + int(len(dataset) % batch_size > 0)\n", - " output_df = {\n", - " \"id\": [],\n", - " \"prediction\": [],\n", - " \"score\": []\n", - " }\n", - " for iteration in range(1, n_iters+1):\n", + " output_df = {\"id\": [], \"prediction\": [], \"score\": []}\n", + " for iteration in range(1, n_iters + 1):\n", " batch, batch_dialogs, _, true_batch_size = next(batch_iterator)\n", " # Extract fields from batch\n", - " input_variable, dialog_lengths, utt_lengths, batch_indices, dialog_indices, labels, convo_ids, 
target_variable, mask, max_target_len = batch\n", + " (\n", + " input_variable,\n", + " dialog_lengths,\n", + " utt_lengths,\n", + " batch_indices,\n", + " dialog_indices,\n", + " labels,\n", + " convo_ids,\n", + " target_variable,\n", + " mask,\n", + " max_target_len,\n", + " ) = batch\n", " dialog_lengths_list = [len(x) for x in batch_dialogs]\n", " # run the model\n", - " predictions, scores = _evaluate_batch(encoder, context_encoder, predictor, voc, input_variable,\n", - " dialog_lengths, dialog_lengths_list, utt_lengths, batch_indices, dialog_indices,\n", - " true_batch_size, device)\n", + " predictions, scores = _evaluate_batch(\n", + " encoder,\n", + " context_encoder,\n", + " predictor,\n", + " voc,\n", + " input_variable,\n", + " dialog_lengths,\n", + " dialog_lengths_list,\n", + " utt_lengths,\n", + " batch_indices,\n", + " dialog_indices,\n", + " true_batch_size,\n", + " device,\n", + " )\n", "\n", " # format the output as a dataframe (which we can later re-join with the corpus)\n", " for i in range(true_batch_size):\n", @@ -807,8 +922,10 @@ " output_df[\"id\"].append(convo_id)\n", " output_df[\"prediction\"].append(pred)\n", " output_df[\"score\"].append(score)\n", - " \n", - " print(\"Iteration: {}; Percent complete: {:.1f}%\".format(iteration, iteration / n_iters * 100))\n", + "\n", + " print(\n", + " \"Iteration: {}; Percent complete: {:.1f}%\".format(iteration, iteration / n_iters * 100)\n", + " )\n", "\n", " return pd.DataFrame(output_df).set_index(\"id\")" ] @@ -938,7 +1055,7 @@ "random.seed(2019)\n", "\n", "# Tell torch to use GPU. Note that if you are running this notebook in a non-GPU environment, you can change 'cuda' to 'cpu' to get the code to run.\n", - "device = torch.device('cpu')\n", + "device = torch.device(\"cpu\")\n", "\n", "print(\"Loading saved parameters...\")\n", "if not os.path.isfile(\"model.tar\"):\n", @@ -948,14 +1065,14 @@ "# checkpoint = torch.load(\"model.tar\")\n", "# If running in a non-GPU environment, you need to tell PyTorch to convert the parameters to CPU tensor format.\n", "# To do so, replace the previous line with the following:\n", - "checkpoint = torch.load(\"model.tar\", map_location=torch.device('cpu'))\n", - "encoder_sd = checkpoint['en']\n", - "context_sd = checkpoint['ctx']\n", - "attack_clf_sd = checkpoint['atk_clf']\n", - "embedding_sd = checkpoint['embedding']\n", - "voc.__dict__ = checkpoint['voc_dict']\n", - "\n", - "print('Building encoders, decoder, and classifier...')\n", + "checkpoint = torch.load(\"model.tar\", map_location=torch.device(\"cpu\"))\n", + "encoder_sd = checkpoint[\"en\"]\n", + "context_sd = checkpoint[\"ctx\"]\n", + "attack_clf_sd = checkpoint[\"atk_clf\"]\n", + "embedding_sd = checkpoint[\"embedding\"]\n", + "voc.__dict__ = checkpoint[\"voc_dict\"]\n", + "\n", + "print(\"Building encoders, decoder, and classifier...\")\n", "# Initialize word embeddings\n", "embedding = nn.Embedding(voc.num_words, hidden_size)\n", "embedding.load_state_dict(embedding_sd)\n", @@ -971,7 +1088,7 @@ "encoder = encoder.to(device)\n", "context_encoder = context_encoder.to(device)\n", "attack_clf = attack_clf.to(device)\n", - "print('Models built and ready to go!')\n", + "print(\"Models built and ready to go!\")\n", "\n", "# Set dropout layers to eval mode\n", "encoder.eval()\n", @@ -982,7 +1099,9 @@ "predictor = Predictor(encoder, context_encoder, attack_clf)\n", "\n", "# Run the pipeline!\n", - "forecasts_df = _evaluate_dataset(test_pairs, encoder, context_encoder, predictor, voc, batch_size, device)" + "forecasts_df = 
_evaluate_dataset(\n",
+    "    test_pairs, encoder, context_encoder, predictor, voc, batch_size, device\n",
+    ")"
   ]
  },
 {
@@ -1197,10 +1316,10 @@
    "# prior to actually seeing this utterance, that this utterance *would be* a derailment\".\n",
    "for convo in corpus.iter_conversations():\n",
    "    # only consider test set conversations (we did not make predictions for the other ones)\n",
-    "    if convo.meta['split'] == \"test\":\n",
+    "    if convo.meta[\"split\"] == \"test\":\n",
    "        for utt in convo.iter_utterances():\n",
    "            if utt.id in forecasts_df.index:\n",
-    "                utt.meta['forecast_score'] = forecasts_df.loc[utt.id].score"
+    "                utt.meta[\"forecast_score\"] = forecasts_df.loc[utt.id].score"
   ]
  },
 {
@@ -1236,20 +1355,23 @@
    "# set up to not look at the last comment, meaning that all forecasts we obtained are forecasts made prior to derailment. This simplifies\n",
    "# the computation of forecast accuracy as we now do not need to explicitly consider when a forecast was made.\n",
    "\n",
-    "conversational_forecasts_df = {\n",
-    "    \"convo_id\": [],\n",
-    "    \"label\": [],\n",
-    "    \"score\": [],\n",
-    "    \"prediction\": []\n",
-    "}\n",
+    "conversational_forecasts_df = {\"convo_id\": [], \"label\": [], \"score\": [], \"prediction\": []}\n",
    "\n",
    "for convo in corpus.iter_conversations():\n",
-    "    if convo.meta['split'] == \"test\":\n",
-    "        conversational_forecasts_df['convo_id'].append(convo.id)\n",
-    "        conversational_forecasts_df['label'].append(int(convo.meta['conversation_has_personal_attack']))\n",
-    "        forecast_scores = [utt.meta['forecast_score'] for utt in convo.iter_utterances() if 'forecast_score' in utt.meta]\n",
-    "        conversational_forecasts_df['score'] = np.max(forecast_scores)\n",
-    "        conversational_forecasts_df['prediction'].append(int(np.max(forecast_scores) > FORECAST_THRESH))\n",
+    "    if convo.meta[\"split\"] == \"test\":\n",
+    "        conversational_forecasts_df[\"convo_id\"].append(convo.id)\n",
+    "        conversational_forecasts_df[\"label\"].append(\n",
+    "            int(convo.meta[\"conversation_has_personal_attack\"])\n",
+    "        )\n",
+    "        forecast_scores = [\n",
+    "            utt.meta[\"forecast_score\"]\n",
+    "            for utt in convo.iter_utterances()\n",
+    "            if \"forecast_score\" in utt.meta\n",
+    "        ]\n",
+    "        conversational_forecasts_df[\"score\"].append(np.max(forecast_scores))\n",
+    "        conversational_forecasts_df[\"prediction\"].append(\n",
+    "            int(np.max(forecast_scores) > FORECAST_THRESH)\n",
+    "        )\n",
    "\n",
    "conversational_forecasts_df = pd.DataFrame(conversational_forecasts_df).set_index(\"convo_id\")\n",
    "print((conversational_forecasts_df.label == conversational_forecasts_df.prediction).mean())"
@@ -1281,14 +1403,15 @@
    "source": [
    "# in addition to accuracy, we can also consider applying other metrics at the conversation level, such as precision/recall\n",
    "def get_pr_stats(preds, labels):\n",
-    "    tp = ((labels==1)&(preds==1)).sum()\n",
-    "    fp = ((labels==0)&(preds==1)).sum()\n",
-    "    tn = ((labels==0)&(preds==0)).sum()\n",
-    "    fn = ((labels==1)&(preds==0)).sum()\n",
+    "    tp = ((labels == 1) & (preds == 1)).sum()\n",
+    "    fp = ((labels == 0) & (preds == 1)).sum()\n",
+    "    tn = ((labels == 0) & (preds == 0)).sum()\n",
+    "    fn = ((labels == 1) & (preds == 0)).sum()\n",
    "    print(\"Precision = {0:.4f}, recall = {1:.4f}\".format(tp / (tp + fp), tp / (tp + fn)))\n",
    "    print(\"False positive rate =\", fp / (fp + tn))\n",
    "    print(\"F1 =\", 2 / (((tp + fp) / tp) + ((tp + fn) / tp)))\n",
    "\n",
+    "\n",
    "get_pr_stats(conversational_forecasts_df.prediction, conversational_forecasts_df.label)"
   ]
  },
 {
@@ -1374,23 +1497,25 @@
    },
    "outputs": [],
    "source": [
-    
"comments_until_derail = {} # store the \"number of comments until derailment\" metric for each conversation\n", - "time_until_derail = {} # store the \"time until derailment\" metric for each conversation\n", + "comments_until_derail = (\n", + " {}\n", + ") # store the \"number of comments until derailment\" metric for each conversation\n", + "time_until_derail = {} # store the \"time until derailment\" metric for each conversation\n", "\n", "for convo in corpus.iter_conversations():\n", - " if convo.meta['split'] == \"test\" and convo.meta['conversation_has_personal_attack']:\n", + " if convo.meta[\"split\"] == \"test\" and convo.meta[\"conversation_has_personal_attack\"]:\n", " # filter out the section header as usual\n", - " utts = [utt for utt in convo.iter_utterances() if not utt.meta['is_section_header']]\n", + " utts = [utt for utt in convo.iter_utterances() if not utt.meta[\"is_section_header\"]]\n", " # by construction, the last comment is the one with the personal attack\n", " derail_idx = len(utts) - 1\n", " # now scan the utterances in order until we find the first derailment prediction (if any)\n", " for idx in range(1, len(utts)):\n", - " if utts[idx].meta['forecast_score'] > FORECAST_THRESH:\n", + " if utts[idx].meta[\"forecast_score\"] > FORECAST_THRESH:\n", " # recall that the forecast_score meta field specifies what CRAFT thought this comment would look like BEFORE it\n", - " # saw this comment. So the actual CRAFT forecast is made during the previous comment; we account for this by \n", + " # saw this comment. So the actual CRAFT forecast is made during the previous comment; we account for this by\n", " # subtracting 1 from idx\n", - " comments_until_derail[convo.id] = derail_idx - (idx-1)\n", - " time_until_derail[convo.id] = utts[derail_idx].timestamp - utts[(idx-1)].timestamp\n", + " comments_until_derail[convo.id] = derail_idx - (idx - 1)\n", + " time_until_derail[convo.id] = utts[derail_idx].timestamp - utts[(idx - 1)].timestamp\n", " break" ] }, @@ -1418,7 +1543,12 @@ "source": [ "# compute some quick statistics about the distribution of the \"number of comments until derailment\" metric\n", "comments_until_derail_vals = np.asarray(list(comments_until_derail.values()))\n", - "print(np.min(comments_until_derail_vals), np.max(comments_until_derail_vals), np.median(comments_until_derail_vals), np.mean(comments_until_derail_vals))" + "print(\n", + " np.min(comments_until_derail_vals),\n", + " np.max(comments_until_derail_vals),\n", + " np.median(comments_until_derail_vals),\n", + " np.mean(comments_until_derail_vals),\n", + ")" ] }, { @@ -1446,7 +1576,12 @@ "# compute some quick statistics about the distribution of the \"time until derailment\" metric\n", "# note that since timestamps are in seconds, we convert to hours by dividing by 3600, to make it more human readable\n", "time_until_derail_vals = np.asarray(list(time_until_derail.values())) / 3600\n", - "print(np.min(time_until_derail_vals), np.max(time_until_derail_vals), np.median(time_until_derail_vals), np.mean(time_until_derail_vals))" + "print(\n", + " np.min(time_until_derail_vals),\n", + " np.max(time_until_derail_vals),\n", + " np.median(time_until_derail_vals),\n", + " np.mean(time_until_derail_vals),\n", + ")" ] }, { @@ -1477,12 +1612,14 @@ ], "source": [ "# visualize the distribution of \"number of comments until derailment\" as a histogram (reproducing Figure 4 from the paper)\n", - "plt.rcParams['figure.figsize'] = (10.0, 5.0)\n", - "plt.rcParams['font.size'] = 24\n", - 
"plt.hist(comments_until_derail_vals, bins=range(1, np.max(comments_until_derail_vals)), density=True)\n", - "plt.xlim(1,10)\n", - "plt.xticks(np.arange(1,10)+0.5, np.arange(1,10))\n", - "plt.yticks(np.arange(0,0.25,0.05), np.arange(0,25,5))\n", + "plt.rcParams[\"figure.figsize\"] = (10.0, 5.0)\n", + "plt.rcParams[\"font.size\"] = 24\n", + "plt.hist(\n", + " comments_until_derail_vals, bins=range(1, np.max(comments_until_derail_vals)), density=True\n", + ")\n", + "plt.xlim(1, 10)\n", + "plt.xticks(np.arange(1, 10) + 0.5, np.arange(1, 10))\n", + "plt.yticks(np.arange(0, 0.25, 0.05), np.arange(0, 25, 5))\n", "plt.xlabel(\"Number of comments elapsed\")\n", "plt.ylabel(\"% of conversations\")\n", "plt.show()" diff --git a/convokit/forecaster/CRAFT/demos/craft_demo_training.ipynb b/convokit/forecaster/CRAFT/demos/craft_demo_training.ipynb index 441f18cb..7ebce0f7 100644 --- a/convokit/forecaster/CRAFT/demos/craft_demo_training.ipynb +++ b/convokit/forecaster/CRAFT/demos/craft_demo_training.ipynb @@ -54,9 +54,7 @@ } ], "source": [ - "craft_model = CRAFTModel(device_type=\"cpu\", options={'validation_size': 0.2,\n", - " 'train_epochs': 5\n", - " })" + "craft_model = CRAFTModel(device_type=\"cpu\", options={\"validation_size\": 0.2, \"train_epochs\": 5})" ] }, { @@ -65,15 +63,17 @@ "metadata": {}, "outputs": [], "source": [ - "forecaster = Forecaster(forecaster_model = craft_model,\n", - " forecast_mode = 'past',\n", - " convo_structure=\"linear\",\n", - " text_func = lambda utt: utt.meta[\"tokens\"][:(MAX_LENGTH-1)],\n", - " label_func = lambda utt: int(utt.meta['comment_has_personal_attack']),\n", - " forecast_attribute_name=\"prediction\", forecast_prob_attribute_name=\"pred_score\",\n", - " use_last_only = True,\n", - " skip_broken_convos=False\n", - " )" + "forecaster = Forecaster(\n", + " forecaster_model=craft_model,\n", + " forecast_mode=\"past\",\n", + " convo_structure=\"linear\",\n", + " text_func=lambda utt: utt.meta[\"tokens\"][: (MAX_LENGTH - 1)],\n", + " label_func=lambda utt: int(utt.meta[\"comment_has_personal_attack\"]),\n", + " forecast_attribute_name=\"prediction\",\n", + " forecast_prob_attribute_name=\"pred_score\",\n", + " use_last_only=True,\n", + " skip_broken_convos=False,\n", + ")" ] }, { @@ -200,8 +200,11 @@ } ], "source": [ - "forecaster.fit(corpus, selector = lambda convo: convo.meta[\"split\"] == \"train\",\n", - " ignore_utterances = lambda utt: utt.meta[\"is_section_header\"])" + "forecaster.fit(\n", + " corpus,\n", + " selector=lambda convo: convo.meta[\"split\"] == \"train\",\n", + " ignore_utterances=lambda utt: utt.meta[\"is_section_header\"],\n", + ")" ] }, { diff --git a/convokit/forecaster/tests/cumulativeBoW_demo.ipynb b/convokit/forecaster/tests/cumulativeBoW_demo.ipynb index cbfd7f18..6b08b5ec 100644 --- a/convokit/forecaster/tests/cumulativeBoW_demo.ipynb +++ b/convokit/forecaster/tests/cumulativeBoW_demo.ipynb @@ -23,7 +23,7 @@ } ], "source": [ - "corpus = Corpus(filename=download('subreddit-Cornell'))" + "corpus = Corpus(filename=download(\"subreddit-Cornell\"))" ] }, { @@ -58,7 +58,7 @@ "metadata": {}, "outputs": [], "source": [ - "convo = corpus.get_conversation('o31u0')" + "convo = corpus.get_conversation(\"o31u0\")" ] }, { @@ -224,7 +224,7 @@ "source": [ "# Adding a 'y' feature to fit to\n", "for utt in corpus.iter_utterances():\n", - " utt.add_meta('pos_score', int(utt.meta['score'] > 0))" + " utt.add_meta(\"pos_score\", int(utt.meta[\"score\"] > 0))" ] }, { @@ -243,7 +243,7 @@ } ], "source": [ - "forecaster = Forecaster(label_func=lambda utt: 
utt.meta['pos_score'], skip_broken_convos=True)"
+    "forecaster = Forecaster(label_func=lambda utt: utt.meta[\"pos_score\"], skip_broken_convos=True)"
   ]
  },
 {
@@ -330,7 +330,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "corpus.get_utterance('dpn8e4v')"
+    "corpus.get_utterance(\"dpn8e4v\")"
   ]
  },
 {
@@ -339,7 +339,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "corpus.get_utterance('dpn8e4v').root"
+    "corpus.get_utterance(\"dpn8e4v\").root"
   ]
  },
 {
@@ -348,7 +348,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "corpus.get_conversation(corpus.get_utterance('dpn8e4v').root).print_conversation_structure()"
+    "corpus.get_conversation(corpus.get_utterance(\"dpn8e4v\").root).print_conversation_structure()"
   ]
  },
 {
@@ -364,7 +364,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "corpus.get_conversation(corpus.get_utterance('dpn8e4v').root).print_conversation_structure(lambda utt: str(utt.meta['forecast']))"
+    "corpus.get_conversation(corpus.get_utterance(\"dpn8e4v\").root).print_conversation_structure(\n",
+    "    lambda utt: str(utt.meta[\"forecast\"])\n",
+    ")"
   ]
  },
 {
@@ -380,7 +382,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "corpus.get_conversation(corpus.get_utterance('dpn8e4v').root).print_conversation_structure(lambda utt: str(utt.meta['pos_score']))"
+    "corpus.get_conversation(corpus.get_utterance(\"dpn8e4v\").root).print_conversation_structure(\n",
+    "    lambda utt: str(utt.meta[\"pos_score\"])\n",
+    ")"
   ]
  },
 {
@@ -389,8 +393,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "forecasts = [utt.meta['forecast'] for utt in corpus.iter_utterances()]\n",
-    "actual = [utt.meta['pos_score'] for utt in corpus.iter_utterances()]"
+    "forecasts = [utt.meta[\"forecast\"] for utt in corpus.iter_utterances()]\n",
+    "actual = [utt.meta[\"pos_score\"] for utt in corpus.iter_utterances()]"
   ]
  },
 {
@@ -399,7 +403,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "y_true_pred = [(forecast, actual) for forecast, actual in zip(forecasts, actual) if forecast is not None]"
+    "y_true_pred = [\n",
+    "    (forecast, actual) for forecast, actual in zip(forecasts, actual) if forecast is not None\n",
+    "]"
   ]
  },
 {
diff --git a/convokit/model/corpus.py b/convokit/model/corpus.py
index 7ec3883e..de99b0f3 100644
--- a/convokit/model/corpus.py
+++ b/convokit/model/corpus.py
@@ -406,29 +406,89 @@ def has_speaker(self, speaker_id: str) -> bool:
         """
         return speaker_id in self.speakers

-    def random_utterance(self) -> Utterance:
+    def random_utterance(self, selector: Optional[Callable[[Utterance], bool]] = None) -> Utterance:
         """
-        Get a random Utterance from the Corpus
+        Get a random Utterance from the Corpus, with an optional selector that filters for Utterances that should be considered.

-        :return: a random Utterance
+        :param selector: a (lambda) function that takes an Utterance and returns True or False (i.e. consider / not consider).
+            By default, the selector considers all Utterances in the Corpus.
+        :return: a random Utterance in the Corpus that is considered based on the selector.
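+
+        Illustrative usage (a minimal sketch; assumes ``corpus`` is an already-loaded
+        Corpus, and simply skips Utterances with empty text):
+
+        >>> utt = corpus.random_utterance(selector=lambda u: len(u.text) > 0)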
""" - return random.choice(list(self.utterances.values())) + count = 0 + selected_utterance = None - def random_conversation(self) -> Conversation: - """ - Get a random Conversation from the Corpus + if selector == None: + return random.choice(list(self.utterances.values())) + # Iterate over utterances directly from the generator + for utterance in self.iter_utterances(): + # Apply the filter function if provided + if selector(utterance): + count += 1 + # Reservoir sampling: Replace the current selection with decreasing probability + if random.randint(1, count) == 1: + selected_utterance = utterance + if selected_utterance is None: + raise ValueError("No matching Utterance found in the Corpus.") + + return selected_utterance + + def random_conversation( + self, selector: Optional[Callable[[Conversation], bool]] = None + ) -> Conversation: + """ + Get a random Conversation from the Corpus, with an optional selector that filters for Conversations that should be considered. + + :param selecter: a (lamda) function that takes a Conversation and returns True or False.(i.e. consider / not consider). + By default, the selector considers all Conversations in the Corpus. + :return: a random Conversation that in the Corpus that is considered based on the selector + """ + count = 0 + selected_conversation = None + # if selector is not provided return random object + if selector == None: + return random.choice(list(self.conversations.values())) + + # Iterate over conversations + for conversation in self.iter_conversations(): + # Apply the filter function if provided + if selector(conversation): + count += 1 + # Reservoir sampling: Replace the current selection with decreasing probability + if random.randint(1, count) == 1: + selected_conversation = conversation + if selected_conversation is None: + raise ValueError("No matching Conversation found in the Corpus.") + + return selected_conversation + + def random_speaker(self, selector: Optional[Callable[[Speaker], bool]] = None) -> Speaker: + """ + Get a random Speaker from the Corpus, with an optional selector that filters for Speakers that should be considered. + + :param selector: a (lambda) function that takes an Speaker and returns True or False (i.e. consider / not consider). + By default, the selector considers all Speakers in the Corpus. + :return: a random Speaker that in the Corpus that is considered based on the selector. 
+ """ + count = 0 + selected_speaker = None + # if selector is not provided return random object + if selector == None: + return random.choice(list(self.speakers.values())) + + # iterate over speakers + for speaker in self.iter_speakers(): - :return: a random Conversation - """ - return random.choice(list(self.conversations.values())) + # Apply the filter function if provided + if selector(speaker): + count += 1 + # Reservoir sampling: Replace the current selection with decreasing probability + if random.randint(1, count) == 1: + selected_speaker = speaker - def random_speaker(self) -> Speaker: - """ - Get a random Speaker from the Corpus + if selected_speaker is None: + raise ValueError("No matching Speaker found in the Corpus.") - :return: a random Speaker - """ - return random.choice(list(self.speakers.values())) + return selected_speaker def iter_utterances( self, selector: Optional[Callable[[Utterance], bool]] = lambda utt: True diff --git a/convokit/ranker/demos/ranker_demo.ipynb b/convokit/ranker/demos/ranker_demo.ipynb index aa045f56..d9dc1deb 100644 --- a/convokit/ranker/demos/ranker_demo.ipynb +++ b/convokit/ranker/demos/ranker_demo.ipynb @@ -24,7 +24,7 @@ } ], "source": [ - "corpus = Corpus(filename=download('subreddit-Cornell'))" + "corpus = Corpus(filename=download(\"subreddit-Cornell\"))" ] }, { @@ -187,7 +187,7 @@ "# Sanity check of (rank, score) pairings\n", "utt_sample = list(corpus.iter_utterances())[:10]\n", "\n", - "sorted([(utt.meta['rank'], utt.meta['score']) for utt in utt_sample], key=lambda x: x[0]) " + "sorted([(utt.meta[\"rank\"], utt.meta[\"score\"]) for utt in utt_sample], key=lambda x: x[0])" ] }, { @@ -203,11 +203,12 @@ "metadata": {}, "outputs": [], "source": [ - "ranker = convokit.Ranker(obj_type=\"speaker\", \n", - " score_func=lambda user: len(list(user.iter_utterances())), \n", - " score_attribute_name=\"num_utts\",\n", - " rank_attribute_name=\"num_utts_rank\"\n", - " )" + "ranker = convokit.Ranker(\n", + " obj_type=\"speaker\",\n", + " score_func=lambda user: len(list(user.iter_utterances())),\n", + " score_attribute_name=\"num_utts\",\n", + " rank_attribute_name=\"num_utts_rank\",\n", + ")" ] }, { @@ -343,7 +344,10 @@ "# Sanity check of (rank, score) pairings\n", "speaker_sample = list(corpus.iter_speakers())[:10]\n", "\n", - "sorted([(spkr.meta['num_utts_rank'], spkr.meta['num_utts']) for spkr in speaker_sample], key=lambda x: x[0]) " + "sorted(\n", + " [(spkr.meta[\"num_utts_rank\"], spkr.meta[\"num_utts\"]) for spkr in speaker_sample],\n", + " key=lambda x: x[0],\n", + ")" ] }, { diff --git a/convokit/redirection/redirectionDemo.ipynb b/convokit/redirection/redirectionDemo.ipynb index 98c2dabe..5d96de6e 100644 --- a/convokit/redirection/redirectionDemo.ipynb +++ b/convokit/redirection/redirectionDemo.ipynb @@ -49,7 +49,11 @@ "from convokit.redirection.likelihoodModel import LikelihoodModel\n", "from convokit.redirection.gemmaLikelihoodModel import GemmaLikelihoodModel\n", "from convokit.redirection.redirection import Redirection\n", - "from convokit.redirection.config import DEFAULT_BNB_CONFIG, DEFAULT_LORA_CONFIG, DEFAULT_TRAIN_CONFIG\n", + "from convokit.redirection.config import (\n", + " DEFAULT_BNB_CONFIG,\n", + " DEFAULT_LORA_CONFIG,\n", + " DEFAULT_TRAIN_CONFIG,\n", + ")\n", "import random\n", "from sklearn.model_selection import train_test_split\n", "import numpy as np\n", @@ -76,7 +80,7 @@ "# corpus = Corpus(DATA_DIR)\n", "\n", "# Otherwise download the corpus\n", - "corpus = Corpus(filename=download('supreme-corpus'))\n", + "corpus = 
Corpus(filename=download(\"supreme-corpus\"))\n", "corpus.print_summary_stats()" ] }, @@ -97,7 +101,7 @@ "source": [ "convos = [convo for convo in corpus.iter_conversations()]\n", "sample_convos = random.sample(convos, 50)\n", - "print(len(sample_convos))\n" + "print(len(sample_convos))" ] }, { @@ -108,11 +112,11 @@ "outputs": [], "source": [ "for convo in sample_convos:\n", - " for utt in convo.iter_utterances():\n", - " if utt.speaker.id.startswith(\"j_\"):\n", - " utt.meta[\"role\"] = \"justice\"\n", - " else:\n", - " utt.meta[\"role\"] = \"lawyer\"" + " for utt in convo.iter_utterances():\n", + " if utt.speaker.id.startswith(\"j_\"):\n", + " utt.meta[\"role\"] = \"justice\"\n", + " else:\n", + " utt.meta[\"role\"] = \"lawyer\"" ] }, { @@ -135,11 +139,11 @@ "print(len(train_convos), len(val_convos), len(test_convos))\n", "\n", "for convo in train_convos:\n", - " convo.meta[\"train\"] = True\n", - "for convo in val_convos: \n", - " convo.meta[\"val\"] = True \n", + " convo.meta[\"train\"] = True\n", + "for convo in val_convos:\n", + " convo.meta[\"val\"] = True\n", "for convo in test_convos:\n", - " convo.meta[\"test\"] = True " + " convo.meta[\"test\"] = True" ] }, { @@ -159,14 +163,13 @@ "metadata": {}, "outputs": [], "source": [ - "gemma_likelihood_model = \\\n", - " GemmaLikelihoodModel(\n", - " hf_token = \"TODO: ADD HUGGINGFACE AUTH TOKEN\",\n", - " model_id = \"google/gemma-2b\", \n", - " train_config = DEFAULT_TRAIN_CONFIG,\n", - " bnb_config = DEFAULT_BNB_CONFIG,\n", - " lora_config = DEFAULT_LORA_CONFIG,\n", - " )" + "gemma_likelihood_model = GemmaLikelihoodModel(\n", + " hf_token=\"TODO: ADD HUGGINGFACE AUTH TOKEN\",\n", + " model_id=\"google/gemma-2b\",\n", + " train_config=DEFAULT_TRAIN_CONFIG,\n", + " bnb_config=DEFAULT_BNB_CONFIG,\n", + " lora_config=DEFAULT_LORA_CONFIG,\n", + ")" ] }, { @@ -234,13 +237,12 @@ "metadata": {}, "outputs": [], "source": [ - "redirection = \\\n", - " Redirection(\n", - " likelihood_model = gemma_likelihood_model,\n", - " redirection_attribute_name = \"redirection\"\n", - "# previous_context_selector = , \n", - "# future_context_selector = ,\n", - " )" + "redirection = Redirection(\n", + " likelihood_model=gemma_likelihood_model,\n", + " redirection_attribute_name=\"redirection\",\n", + " # previous_context_selector = ,\n", + " # future_context_selector = ,\n", + ")" ] }, { @@ -258,10 +260,11 @@ "metadata": {}, "outputs": [], "source": [ - "redirection.fit(corpus, \n", - " train_selector=lambda convo: \"train\" in convo.meta, \n", - " val_selector=lambda convo: \"val\" in convo.meta\n", - " )" + "redirection.fit(\n", + " corpus,\n", + " train_selector=lambda convo: \"train\" in convo.meta,\n", + " val_selector=lambda convo: \"val\" in convo.meta,\n", + ")" ] }, { @@ -336,13 +339,13 @@ "justice_utts = []\n", "lawyer_utts = []\n", "\n", - "for convo in test_convos: \n", - " for utt in convo.iter_utterances():\n", - " if \"redirection\" in utt.meta:\n", - " if utt.meta[\"role\"] == \"justice\":\n", - " justice_utts.append(utt)\n", - " else:\n", - " lawyer_utts.append(utt)\n", + "for convo in test_convos:\n", + " for utt in convo.iter_utterances():\n", + " if \"redirection\" in utt.meta:\n", + " if utt.meta[\"role\"] == \"justice\":\n", + " justice_utts.append(utt)\n", + " else:\n", + " lawyer_utts.append(utt)\n", "\n", "justice_utts = sorted(justice_utts, key=lambda utt: utt.meta[\"redirection\"])\n", "lawyer_utts = sorted(lawyer_utts, key=lambda utt: utt.meta[\"redirection\"])\n", @@ -351,14 +354,14 @@ "lawyer_threshold = int(len(lawyer_utts) * 
0.20)\n", "\n", "for utt in justice_utts[:justice_threshold]:\n", - " utt.meta['type'] = \"justice_low\"\n", + " utt.meta[\"type\"] = \"justice_low\"\n", "for utt in justice_utts[-justice_threshold:]:\n", - " utt.meta['type'] = \"justice_high\"\n", + " utt.meta[\"type\"] = \"justice_high\"\n", "\n", "for utt in lawyer_utts[:lawyer_threshold]:\n", - " utt.meta['type'] = \"lawyer_low\"\n", + " utt.meta[\"type\"] = \"lawyer_low\"\n", "for utt in lawyer_utts[-lawyer_threshold:]:\n", - " utt.meta['type'] = \"lawyer_high\"" + " utt.meta[\"type\"] = \"lawyer_high\"" ] }, { @@ -376,11 +379,14 @@ "metadata": {}, "outputs": [], "source": [ - "fw_justice = FightingWords(ngram_range=(2,2))\n", - "class1 = 'justice_high'\n", - "class2 = 'justice_low'\n", - "fw_justice.fit(corpus, class1_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class1, \n", - " class2_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class2)\n", + "fw_justice = FightingWords(ngram_range=(2, 2))\n", + "class1 = \"justice_high\"\n", + "class2 = \"justice_low\"\n", + "fw_justice.fit(\n", + " corpus,\n", + " class1_func=lambda utt: \"type\" in utt.meta and utt.meta[\"type\"] == class1,\n", + " class2_func=lambda utt: \"type\" in utt.meta and utt.meta[\"type\"] == class2,\n", + ")\n", "justice = fw_justice.summarize(corpus, plot=False, class1_name=class1, class2_name=class2)\n", "justice.head(20)" ] @@ -418,11 +424,14 @@ "metadata": {}, "outputs": [], "source": [ - "fw_lawyer = FightingWords(ngram_range=(2,2))\n", - "class1 = 'lawyer_high'\n", - "class2 = 'lawyer_low'\n", - "fw_lawyer.fit(corpus, class1_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class1, \n", - " class2_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class2)\n", + "fw_lawyer = FightingWords(ngram_range=(2, 2))\n", + "class1 = \"lawyer_high\"\n", + "class2 = \"lawyer_low\"\n", + "fw_lawyer.fit(\n", + " corpus,\n", + " class1_func=lambda utt: \"type\" in utt.meta and utt.meta[\"type\"] == class1,\n", + " class2_func=lambda utt: \"type\" in utt.meta and utt.meta[\"type\"] == class2,\n", + ")\n", "lawyer = fw_lawyer.summarize(corpus, plot=False, class1_name=class1, class2_name=class2)\n", "lawyer.head(20)" ] @@ -454,7 +463,7 @@ "source": [ "convo_justices = []\n", "convo_lawyers = []\n", - "for convo in test_convos: \n", + "for convo in test_convos:\n", " justice = []\n", " lawyer = []\n", " for utt in convo.iter_utterances():\n", @@ -465,7 +474,7 @@ " lawyer.append(utt.meta[\"redirection\"])\n", " convo_justices.append(np.mean(justice))\n", " convo_lawyers.append(np.mean(lawyer))\n", - " \n", + "\n", "print(\"Average justice:\", np.mean(convo_justices))\n", "print(\"Average lawyer:\", np.mean(convo_lawyers))\n", "stat, p_value = wilcoxon(convo_justices, convo_lawyers)\n", diff --git a/convokit/surprise/demos/surprise_demo.ipynb b/convokit/surprise/demos/surprise_demo.ipynb index 92946d6c..0cf249f9 100644 --- a/convokit/surprise/demos/surprise_demo.ipynb +++ b/convokit/surprise/demos/surprise_demo.ipynb @@ -49,7 +49,7 @@ } ], "source": [ - "corpus = Corpus(filename=download('subreddit-Cornell'))" + "corpus = Corpus(filename=download(\"subreddit-Cornell\"))" ] }, { @@ -86,7 +86,9 @@ "metadata": {}, "outputs": [], "source": [ - "SPEAKER_BLACKLIST = ['[deleted]', 'DeltaBot', 'AutoModerator']\n", + "SPEAKER_BLACKLIST = [\"[deleted]\", \"DeltaBot\", \"AutoModerator\"]\n", + "\n", + "\n", "def utterance_is_valid(utterance):\n", " return utterance.speaker.id not in SPEAKER_BLACKLIST and utterance.text" ] @@ -117,7 +119,7 @@ 
"metadata": {}, "outputs": [], "source": [ - "speaker_activities = corpus.get_attribute_table('speaker', ['n_convos'])" + "speaker_activities = corpus.get_attribute_table(\"speaker\", [\"n_convos\"])" ] }, { @@ -219,7 +221,7 @@ } ], "source": [ - "speaker_activities.sort_values('n_convos', ascending=False).head(10)" + "speaker_activities.sort_values(\"n_convos\", ascending=False).head(10)" ] }, { @@ -228,7 +230,7 @@ "metadata": {}, "outputs": [], "source": [ - "top_speakers = speaker_activities.sort_values('n_convos', ascending=False).head(100).index" + "top_speakers = speaker_activities.sort_values(\"n_convos\", ascending=False).head(100).index" ] }, { @@ -239,7 +241,10 @@ "source": [ "import itertools\n", "\n", - "subset_utts = [list(corpus.get_speaker(speaker).iter_utterances(selector=utterance_is_valid)) for speaker in top_speakers]\n", + "subset_utts = [\n", + " list(corpus.get_speaker(speaker).iter_utterances(selector=utterance_is_valid))\n", + " for speaker in top_speakers\n", + "]\n", "subset_corpus = Corpus(utterances=list(itertools.chain(*subset_utts)))" ] }, @@ -287,9 +292,9 @@ "source": [ "import spacy\n", "\n", - "spacy_nlp = spacy.load('en_core_web_sm', disable=['ner','parser', 'tagger', 'lemmatizer'])\n", + "spacy_nlp = spacy.load(\"en_core_web_sm\", disable=[\"ner\", \"parser\", \"tagger\", \"lemmatizer\"])\n", "for utt in subset_corpus.iter_utterances():\n", - " utt.meta['joined_tokens'] = [t.text.lower() for t in spacy_nlp(utt.text)]" + " utt.meta[\"joined_tokens\"] = [t.text.lower() for t in spacy_nlp(utt.text)]" ] }, { @@ -298,7 +303,14 @@ "metadata": {}, "outputs": [], "source": [ - "surp = Surprise(tokenizer=lambda x: x, model_key_selector=lambda utt: '_'.join([utt.speaker.id, utt.conversation_id]), target_sample_size=100, context_sample_size=1000, n_samples=50, smooth=True)" + "surp = Surprise(\n", + " tokenizer=lambda x: x,\n", + " model_key_selector=lambda utt: \"_\".join([utt.speaker.id, utt.conversation_id]),\n", + " target_sample_size=100,\n", + " context_sample_size=1000,\n", + " n_samples=50,\n", + " smooth=True,\n", + ")" ] }, { @@ -316,7 +328,20 @@ } ], "source": [ - "surp = surp.fit(subset_corpus, text_func=lambda utt: [list(itertools.chain(*[u.meta['joined_tokens'] for u in utt.speaker.iter_utterances() if u.conversation_id != utt.conversation_id]))])" + "surp = surp.fit(\n", + " subset_corpus,\n", + " text_func=lambda utt: [\n", + " list(\n", + " itertools.chain(\n", + " *[\n", + " u.meta[\"joined_tokens\"]\n", + " for u in utt.speaker.iter_utterances()\n", + " if u.conversation_id != utt.conversation_id\n", + " ]\n", + " )\n", + " )\n", + " ],\n", + ")" ] }, { @@ -344,7 +369,7 @@ } ], "source": [ - "transformed_corpus = surp.transform(subset_corpus, obj_type='speaker')" + "transformed_corpus = surp.transform(subset_corpus, obj_type=\"speaker\")" ] }, { @@ -363,10 +388,16 @@ "source": [ "import pandas as pd\n", "from functools import reduce\n", - "def combine_dicts(x,y):\n", + "\n", + "\n", + "def combine_dicts(x, y):\n", " x.update(y)\n", " return x\n", - "surprise_scores = reduce(combine_dicts, transformed_corpus.get_speakers_dataframe()['meta.surprise'].values)\n", + "\n", + "\n", + "surprise_scores = reduce(\n", + " combine_dicts, transformed_corpus.get_speakers_dataframe()[\"meta.surprise\"].values\n", + ")\n", "suprise_series = pd.Series(surprise_scores).dropna()" ] }, diff --git a/convokit/surprise/demos/tennis_demo.ipynb b/convokit/surprise/demos/tennis_demo.ipynb index a5012807..a3e9135d 100644 --- a/convokit/surprise/demos/tennis_demo.ipynb +++ 
b/convokit/surprise/demos/tennis_demo.ipynb @@ -36,8 +36,8 @@ "metadata": {}, "outputs": [], "source": [ - "PATH = '/home/axl4' # replace with your path to tennis_data directory\n", - "data_dir = f'{PATH}/tennis_data/'" + "PATH = \"/home/axl4\" # replace with your path to tennis_data directory\n", + "data_dir = f\"{PATH}/tennis_data/\"" ] }, { @@ -46,7 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "corpus_speakers = {'COMMENTATOR': Speaker(id = 'COMMENTATOR', meta = {})}" + "corpus_speakers = {\"COMMENTATOR\": Speaker(id=\"COMMENTATOR\", meta={})}" ] }, { @@ -55,7 +55,7 @@ "metadata": {}, "outputs": [], "source": [ - "with open(data_dir + 'text_commentaries.json', 'r') as f:\n", + "with open(data_dir + \"text_commentaries.json\", \"r\") as f:\n", " commentaries = json.load(f)" ] }, @@ -76,9 +76,17 @@ "utterances = []\n", "count = 0\n", "for c in tqdm(commentaries):\n", - " idx = 'c{}'.format(count)\n", - " meta = {'player_gender': c['gender'], 'scoreline': c['scoreline']}\n", - " utterances.append(Utterance(id=idx, speaker=corpus_speakers['COMMENTATOR'], conversation_id=idx, text=c['commentary'], meta=meta))\n", + " idx = \"c{}\".format(count)\n", + " meta = {\"player_gender\": c[\"gender\"], \"scoreline\": c[\"scoreline\"]}\n", + " utterances.append(\n", + " Utterance(\n", + " id=idx,\n", + " speaker=corpus_speakers[\"COMMENTATOR\"],\n", + " conversation_id=idx,\n", + " text=c[\"commentary\"],\n", + " meta=meta,\n", + " )\n", + " )\n", " count += 1" ] }, @@ -113,7 +121,7 @@ } ], "source": [ - "interview_corpus = Corpus(filename=download('tennis-corpus'))" + "interview_corpus = Corpus(filename=download(\"tennis-corpus\"))" ] }, { @@ -148,8 +156,11 @@ "metadata": {}, "outputs": [], "source": [ - "for utt in interview_corpus.iter_utterances(selector=lambda u: u.meta['is_question']):\n", - " utt.add_meta('player_gender', utt.get_conversation().get_utterance(utt.id.replace('q', 'a')).get_speaker().meta['gender'])" + "for utt in interview_corpus.iter_utterances(selector=lambda u: u.meta[\"is_question\"]):\n", + " utt.add_meta(\n", + " \"player_gender\",\n", + " utt.get_conversation().get_utterance(utt.id.replace(\"q\", \"a\")).get_speaker().meta[\"gender\"],\n", + " )" ] }, { @@ -169,10 +180,18 @@ "source": [ "from nltk import word_tokenize\n", "\n", + "\n", "def tokenizer(text):\n", " return list(filter(lambda w: w.isalnum(), word_tokenize(text.lower())))\n", "\n", - "surp = Surprise(model_key_selector=lambda utt: 'corpus', tokenizer=tokenizer, target_sample_size=10, context_sample_size=None, n_samples=3)" + "\n", + "surp = Surprise(\n", + " model_key_selector=lambda utt: \"corpus\",\n", + " tokenizer=tokenizer,\n", + " target_sample_size=10,\n", + " context_sample_size=None,\n", + " n_samples=3,\n", + ")" ] }, { @@ -207,7 +226,10 @@ } ], "source": [ - "surp.fit(game_commentary_corpus, text_func=lambda utt: [' '.join([u.text for u in game_commentary_corpus.iter_utterances()])])" + "surp.fit(\n", + " game_commentary_corpus,\n", + " text_func=lambda utt: [\" \".join([u.text for u in game_commentary_corpus.iter_utterances()])],\n", + ")" ] }, { @@ -226,9 +248,18 @@ "import itertools\n", "\n", "SAMPLE = True\n", - "SAMPLE_SIZE = 10000 # edit this to change the number of interview questions to calculate surprise for\n", + "SAMPLE_SIZE = (\n", + " 10000 # edit this to change the number of interview questions to calculate surprise for\n", + ")\n", "\n", - "subset_utts = [interview_corpus.get_utterance(utt) for utt in interview_corpus.get_utterances_dataframe(selector=lambda utt: 
utt.meta['is_question']).sample(SAMPLE_SIZE).index]\n", + "subset_utts = [\n", + " interview_corpus.get_utterance(utt)\n", + " for utt in interview_corpus.get_utterances_dataframe(\n", + " selector=lambda utt: utt.meta[\"is_question\"]\n", + " )\n", + " .sample(SAMPLE_SIZE)\n", + " .index\n", + "]\n", "subset_corpus = Corpus(utterances=subset_utts) if SAMPLE else interview_corpus" ] }, @@ -263,7 +294,7 @@ } ], "source": [ - "surp.transform(subset_corpus, obj_type='utterance', selector=lambda utt: utt.meta['is_question'])" + "surp.transform(subset_corpus, obj_type=\"utterance\", selector=lambda utt: utt.meta[\"is_question\"])" ] }, { @@ -280,7 +311,7 @@ "metadata": {}, "outputs": [], "source": [ - "utterances = subset_corpus.get_utterances_dataframe(selector=lambda utt: utt.meta['is_question'])" + "utterances = subset_corpus.get_utterances_dataframe(selector=lambda utt: utt.meta[\"is_question\"])" ] }, { @@ -302,7 +333,9 @@ "source": [ "import pandas as pd\n", "\n", - "female_qs = pd.to_numeric(utterances[utterances['meta.player_gender'] == 'F']['meta.surprise']).dropna()\n", + "female_qs = pd.to_numeric(\n", + " utterances[utterances[\"meta.player_gender\"] == \"F\"][\"meta.surprise\"]\n", + ").dropna()\n", "female_qs.median()" ] }, @@ -323,7 +356,9 @@ } ], "source": [ - "male_qs = pd.to_numeric(utterances[utterances['meta.player_gender'] == 'M']['meta.surprise']).dropna()\n", + "male_qs = pd.to_numeric(\n", + " utterances[utterances[\"meta.player_gender\"] == \"M\"][\"meta.surprise\"]\n", + ").dropna()\n", "male_qs.median()" ] }, @@ -454,7 +489,12 @@ "metadata": {}, "outputs": [], "source": [ - "gender_models_surp = Surprise(model_key_selector=lambda utt: utt.meta['player_gender'], target_sample_size=10, context_sample_size=5000, surprise_attr_name='surprise_gender_model')" + "gender_models_surp = Surprise(\n", + " model_key_selector=lambda utt: utt.meta[\"player_gender\"],\n", + " target_sample_size=10,\n", + " context_sample_size=5000,\n", + " surprise_attr_name=\"surprise_gender_model\",\n", + ")" ] }, { @@ -482,7 +522,7 @@ } ], "source": [ - "gender_models_surp.fit(interview_corpus, selector=lambda utt: utt.meta['is_question'])" + "gender_models_surp.fit(interview_corpus, selector=lambda utt: utt.meta[\"is_question\"])" ] }, { @@ -518,7 +558,13 @@ } ], "source": [ - "gender_models_surp.transform(subset_corpus, obj_type='utterance', group_and_models=lambda utt: (utt.id, ['M', 'F']), group_model_attr_key=lambda _, m: m, selector=lambda utt: utt.meta['is_question'])" + "gender_models_surp.transform(\n", + " subset_corpus,\n", + " obj_type=\"utterance\",\n", + " group_and_models=lambda utt: (utt.id, [\"M\", \"F\"]),\n", + " group_model_attr_key=lambda _, m: m,\n", + " selector=lambda utt: utt.meta[\"is_question\"],\n", + ")" ] }, { @@ -535,7 +581,7 @@ "metadata": {}, "outputs": [], "source": [ - "utterances = subset_corpus.get_utterances_dataframe(selector=lambda utt: utt.meta['is_question'])" + "utterances = subset_corpus.get_utterances_dataframe(selector=lambda utt: utt.meta[\"is_question\"])" ] }, { @@ -555,7 +601,9 @@ } ], "source": [ - "utterances[utterances['meta.player_gender'] == 'F']['meta.surprise_gender_model'].map(lambda x: x['M']).dropna().mean()" + "utterances[utterances[\"meta.player_gender\"] == \"F\"][\"meta.surprise_gender_model\"].map(\n", + " lambda x: x[\"M\"]\n", + ").dropna().mean()" ] }, { @@ -575,7 +623,9 @@ } ], "source": [ - "utterances[utterances['meta.player_gender'] == 'F']['meta.surprise_gender_model'].map(lambda x: x['F']).dropna().mean()" + 
"utterances[utterances[\"meta.player_gender\"] == \"F\"][\"meta.surprise_gender_model\"].map(\n", + " lambda x: x[\"F\"]\n", + ").dropna().mean()" ] }, { @@ -595,7 +645,9 @@ } ], "source": [ - "utterances[utterances['meta.player_gender'] == 'M']['meta.surprise_gender_model'].map(lambda x: x['M']).dropna().mean()" + "utterances[utterances[\"meta.player_gender\"] == \"M\"][\"meta.surprise_gender_model\"].map(\n", + " lambda x: x[\"M\"]\n", + ").dropna().mean()" ] }, { @@ -615,7 +667,9 @@ } ], "source": [ - "utterances[utterances['meta.player_gender'] == 'M']['meta.surprise_gender_model'].map(lambda x: x['F']).dropna().mean()" + "utterances[utterances[\"meta.player_gender\"] == \"M\"][\"meta.surprise_gender_model\"].map(\n", + " lambda x: x[\"F\"]\n", + ").dropna().mean()" ] }, { diff --git a/convokit/tests/notebook_testers/convokitIndex_issues_demo.ipynb b/convokit/tests/notebook_testers/convokitIndex_issues_demo.ipynb index 212183fb..fc704139 100644 --- a/convokit/tests/notebook_testers/convokitIndex_issues_demo.ipynb +++ b/convokit/tests/notebook_testers/convokitIndex_issues_demo.ipynb @@ -7,7 +7,8 @@ "outputs": [], "source": [ "import os\n", - "os.chdir('../..')" + "\n", + "os.chdir(\"../..\")" ] }, { @@ -26,7 +27,7 @@ "metadata": {}, "outputs": [], "source": [ - "utterances = [Utterance(id=str(i), speaker=User(id='speaker'+str(i))) for i in range(10)]" + "utterances = [Utterance(id=str(i), speaker=User(id=\"speaker\" + str(i))) for i in range(10)]" ] }, { @@ -157,7 +158,7 @@ "outputs": [], "source": [ "for utt in corpus.iter_utterances():\n", - " utt.meta['good_meta'] = 1" + " utt.meta[\"good_meta\"] = 1" ] }, { @@ -190,8 +191,8 @@ "metadata": {}, "outputs": [], "source": [ - "for utt in corpus.iter_utterances(): # annotate first utt\n", - " utt.meta['okay_meta'] = 1\n", + "for utt in corpus.iter_utterances(): # annotate first utt\n", + " utt.meta[\"okay_meta\"] = 1\n", " break" ] }, @@ -225,10 +226,10 @@ "metadata": {}, "outputs": [], "source": [ - "idx = 1 ## \n", - "for utt in corpus.iter_utterances(): # annotate second utt\n", + "idx = 1 ##\n", + "for utt in corpus.iter_utterances(): # annotate second utt\n", " if idx == 2:\n", - " utt.meta['okay_meta2'] = 1\n", + " utt.meta[\"okay_meta2\"] = 1\n", " idx += 1" ] }, @@ -262,11 +263,11 @@ "metadata": {}, "outputs": [], "source": [ - "for idx, utt in enumerate(corpus.iter_utterances()): # annotate alternating utts\n", + "for idx, utt in enumerate(corpus.iter_utterances()): # annotate alternating utts\n", " if idx % 2:\n", - " utt.meta['bad_meta'] = 1\n", + " utt.meta[\"bad_meta\"] = 1\n", " else:\n", - " utt.meta['bad_meta'] = None\n" + " utt.meta[\"bad_meta\"] = None" ] }, { @@ -318,7 +319,7 @@ "outputs": [], "source": [ "for utt in corpus.iter_utterances():\n", - " utt.meta['to_be_deleted'] = 1" + " utt.meta[\"to_be_deleted\"] = 1" ] }, { @@ -351,7 +352,7 @@ "metadata": {}, "outputs": [], "source": [ - "del corpus.random_utterance().meta['to_be_deleted']" + "del corpus.random_utterance().meta[\"to_be_deleted\"]" ] }, { @@ -378,7 +379,7 @@ ], "source": [ "for utt in corpus.iter_utterances():\n", - " print(utt.meta.get('to_be_deleted', None))" + " print(utt.meta.get(\"to_be_deleted\", None))" ] }, { diff --git a/convokit/tests/notebook_testers/exclude_meta_tests.ipynb b/convokit/tests/notebook_testers/exclude_meta_tests.ipynb index 62c435b0..ce619b81 100644 --- a/convokit/tests/notebook_testers/exclude_meta_tests.ipynb +++ b/convokit/tests/notebook_testers/exclude_meta_tests.ipynb @@ -7,7 +7,8 @@ "outputs": [], "source": [ "import 
os\n", - "os.chdir('../../..')\n", + "\n", + "os.chdir(\"../../..\")\n", "import convokit" ] }, @@ -34,7 +35,7 @@ } ], "source": [ - "corpus = Corpus(filename=download('subreddit-lol'))" + "corpus = Corpus(filename=download(\"subreddit-lol\"))" ] }, { @@ -79,10 +80,13 @@ } ], "source": [ - "corpus2 = Corpus(filename=download('subreddit-lol'), exclude_conversation_meta=['subreddit'],\n", - " exclude_speaker_meta=['num_posts'],\n", - " exclude_utterance_meta=['score'],\n", - " exclude_overall_meta=['num_posts'])" + "corpus2 = Corpus(\n", + " filename=download(\"subreddit-lol\"),\n", + " exclude_conversation_meta=[\"subreddit\"],\n", + " exclude_speaker_meta=[\"num_posts\"],\n", + " exclude_utterance_meta=[\"score\"],\n", + " exclude_overall_meta=[\"num_posts\"],\n", + ")" ] }, { diff --git a/convokit/tests/notebook_testers/reindex_conversations_example.ipynb b/convokit/tests/notebook_testers/reindex_conversations_example.ipynb index 7132649b..18cf2f57 100644 --- a/convokit/tests/notebook_testers/reindex_conversations_example.ipynb +++ b/convokit/tests/notebook_testers/reindex_conversations_example.ipynb @@ -7,7 +7,8 @@ "outputs": [], "source": [ "import os\n", - "os.chdir('..')" + "\n", + "os.chdir(\"..\")" ] }, { @@ -70,28 +71,24 @@ " 4 5 6 7 8 9\n", "10 11\n", "\"\"\"\n", - "corpus = Corpus(utterances = [\n", - " Utterance(id=\"0\", reply_to=None, root=\"0\", user=User(name=\"alice\"), timestamp=0),\n", - "\n", - " Utterance(id=\"2\", reply_to=\"0\", root=\"0\", user=User(name=\"alice\"), timestamp=2),\n", - " Utterance(id=\"1\", reply_to=\"0\", root=\"0\", user=User(name=\"alice\"), timestamp=1),\n", - " Utterance(id=\"3\", reply_to=\"0\", root=\"0\", user=User(name=\"alice\"), timestamp=3),\n", - "\n", - " Utterance(id=\"4\", reply_to=\"1\", root=\"0\", user=User(name=\"alice\"), timestamp=4),\n", - " Utterance(id=\"5\", reply_to=\"1\", root=\"0\", user=User(name=\"alice\"), timestamp=5),\n", - " Utterance(id=\"6\", reply_to=\"1\", root=\"0\", user=User(name=\"alice\"), timestamp=6),\n", - "\n", - " Utterance(id=\"7\", reply_to=\"2\", root=\"0\", user=User(name=\"alice\"), timestamp=4),\n", - " Utterance(id=\"8\", reply_to=\"2\", root=\"0\", user=User(name=\"alice\"), timestamp=5),\n", - "\n", - " Utterance(id=\"9\", reply_to=\"3\", root=\"0\", user=User(name=\"alice\"), timestamp=4),\n", - "\n", - " Utterance(id=\"10\", reply_to=\"4\", root=\"0\", user=User(name=\"alice\"), timestamp=5),\n", - "\n", - " Utterance(id=\"11\", reply_to=\"9\", root=\"0\", user=User(name=\"alice\"), timestamp=10),\n", "\n", - " Utterance(id=\"other\", reply_to=None, root=\"other\", user=User(name=\"alice\"), timestamp=99)\n", - "])" + "corpus = Corpus(\n", + " utterances=[\n", + " Utterance(id=\"0\", reply_to=None, root=\"0\", user=User(name=\"alice\"), timestamp=0),\n", + " Utterance(id=\"2\", reply_to=\"0\", root=\"0\", user=User(name=\"alice\"), timestamp=2),\n", + " Utterance(id=\"1\", reply_to=\"0\", root=\"0\", user=User(name=\"alice\"), timestamp=1),\n", + " Utterance(id=\"3\", reply_to=\"0\", root=\"0\", user=User(name=\"alice\"), timestamp=3),\n", + " Utterance(id=\"4\", reply_to=\"1\", root=\"0\", user=User(name=\"alice\"), timestamp=4),\n", + " Utterance(id=\"5\", reply_to=\"1\", root=\"0\", user=User(name=\"alice\"), timestamp=5),\n", + " Utterance(id=\"6\", reply_to=\"1\", root=\"0\", user=User(name=\"alice\"), timestamp=6),\n", + " Utterance(id=\"7\", reply_to=\"2\", root=\"0\", user=User(name=\"alice\"), timestamp=4),\n", + " Utterance(id=\"8\", reply_to=\"2\", root=\"0\", 
user=User(name=\"alice\"), timestamp=5),\n", + " Utterance(id=\"9\", reply_to=\"3\", root=\"0\", user=User(name=\"alice\"), timestamp=4),\n", + " Utterance(id=\"10\", reply_to=\"4\", root=\"0\", user=User(name=\"alice\"), timestamp=5),\n", + " Utterance(id=\"11\", reply_to=\"9\", root=\"0\", user=User(name=\"alice\"), timestamp=10),\n", + " Utterance(id=\"other\", reply_to=None, root=\"other\", user=User(name=\"alice\"), timestamp=99),\n", + " ]\n", + ")" ] }, { @@ -107,8 +104,8 @@ "metadata": {}, "outputs": [], "source": [ - "corpus.get_conversation(\"0\").meta['hey'] = 'jude'\n", - "corpus.meta['foo'] = 'bar'" + "corpus.get_conversation(\"0\").meta[\"hey\"] = \"jude\"\n", + "corpus.meta[\"foo\"] = \"bar\"" ] }, { diff --git a/convokit/text_processing/demo/cleaning_text.ipynb b/convokit/text_processing/demo/cleaning_text.ipynb index 846ca438..64bc4fb5 100644 --- a/convokit/text_processing/demo/cleaning_text.ipynb +++ b/convokit/text_processing/demo/cleaning_text.ipynb @@ -7,7 +7,8 @@ "outputs": [], "source": [ "import os\n", - "os.chdir('../../..')" + "\n", + "os.chdir(\"../../..\")" ] }, { @@ -58,7 +59,7 @@ } ], "source": [ - "corpus = Corpus(filename=download('subreddit-Cornell'))" + "corpus = Corpus(filename=download(\"subreddit-Cornell\"))" ] }, { @@ -97,7 +98,7 @@ } ], "source": [ - "corpus.get_utterance('15enm8').text" + "corpus.get_utterance(\"15enm8\").text" ] }, { @@ -151,7 +152,7 @@ } ], "source": [ - "corpus.get_utterance('15enm8').text" + "corpus.get_utterance(\"15enm8\").text" ] }, { @@ -179,7 +180,7 @@ } ], "source": [ - "corpus.get_utterance('15enm8').meta" + "corpus.get_utterance(\"15enm8\").meta" ] }, { @@ -222,7 +223,7 @@ ], "source": [ "cleaner = TextCleaner(replace_text=True, save_original=True, verbosity=10000)\n", - "corpus = Corpus(filename=download('subreddit-Cornell'))\n", + "corpus = Corpus(filename=download(\"subreddit-Cornell\"))\n", "cleaner.transform(corpus)" ] }, @@ -243,7 +244,7 @@ } ], "source": [ - "corpus.get_utterance('15enm8').text" + "corpus.get_utterance(\"15enm8\").text" ] }, { @@ -263,7 +264,7 @@ } ], "source": [ - "corpus.get_utterance('15enm8').meta['original']" + "corpus.get_utterance(\"15enm8\").meta[\"original\"]" ] }, { @@ -306,7 +307,7 @@ ], "source": [ "cleaner = TextCleaner(replace_text=False, verbosity=10000)\n", - "corpus = Corpus(filename=download('subreddit-Cornell'))\n", + "corpus = Corpus(filename=download(\"subreddit-Cornell\"))\n", "cleaner.transform(corpus)" ] }, @@ -327,7 +328,7 @@ } ], "source": [ - "corpus.get_utterance('15enm8').text" + "corpus.get_utterance(\"15enm8\").text" ] }, { @@ -347,7 +348,7 @@ } ], "source": [ - "corpus.get_utterance('15enm8').meta['cleaned']" + "corpus.get_utterance(\"15enm8\").meta[\"cleaned\"]" ] }, {