Skip to content

Commit ad3488a

Browse files
committed
A simpler titlematch algo
[skip ci]
1 parent ec1aa9b commit ad3488a

1 file changed

Lines changed: 180 additions & 0 deletions

File tree

‎scripts/local_gdrive.py‎

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import sqlite3
44
from pathlib import Path
5+
import re
56
from typing import List, Dict, Any, Optional
67
from time import sleep
78

@@ -14,6 +15,8 @@
1415
import threading
1516
from functools import wraps
1617

18+
from rapidfuzz import fuzz, process
19+
1720
def UTC_NOW():
1821
now_utc = datetime.now(timezone.utc)
1922
return now_utc.isoformat(timespec='milliseconds').replace('+00:00', 'Z')
@@ -76,6 +79,9 @@ def __init__(self, db_path: str | Path):
7679
Args:
7780
db_path: The file path for the SQLite database.
7881
"""
82+
self.MIN_TITLE_LEN = 20 # do ratio matches on strings this long
83+
self.MIN_PARTIAL_LEN = 35 # do partial matches when really long
84+
self._clear_title_cache()
7985
self.db_path = db_path
8086
self.conn = sqlite3.connect(db_path, check_same_thread=False)
8187
self._lock = threading.RLock()
@@ -533,6 +539,115 @@ def files_exactly_named(self, name: str) -> List[Dict[str, Any]]:
533539
def files_originally_named_exactly(self, name: str) -> List[Dict[str, Any]]:
534540
return self.sql_query("original_name = ?", (name,))
535541

542+
def _clear_title_cache(self):
543+
self.title_cache = []
544+
self.title_cache_ids = []
545+
self.partial_title_cache = []
546+
self.partial_title_cache_ids = []
547+
548+
def _add_string_to_title_cache(self, gid: str, needle: str):
549+
if len(needle) < self.MIN_TITLE_LEN:
550+
return
551+
self.title_cache.append(needle)
552+
self.title_cache_ids.append(gid)
553+
if len(needle) >= self.MIN_PARTIAL_LEN:
554+
self.partial_title_cache.append(needle)
555+
self.partial_title_cache_ids.append(gid)
556+
557+
def _add_name_to_title_cache(self, gid: str, name: str):
558+
if name[-4:].lower() != '.pdf':
559+
return # for now we're only interested in pdfs
560+
name = name[:-4]
561+
self._add_string_to_title_cache(gid, name)
562+
if ' - ' in name:
563+
# In our naming convention, " - " comes before the authors
564+
name = name.split(' - ')[0]
565+
self._add_string_to_title_cache(gid, name)
566+
if '(' in name or ']' in name:
567+
# parens hold non-title metadata
568+
name = re.sub(r'\s*[\(\[][^)]*[\)\]]', '', name)
569+
self._add_string_to_title_cache(gid, name)
570+
parts = name.split('_ ')
571+
if len(parts) == 2:
572+
self._add_string_to_title_cache(gid, parts[0])
573+
self._add_string_to_title_cache(gid, parts[1])
574+
575+
def rebuild_title_cache(self):
    """Rebuild the fuzzy-title caches from all owned, non-shortcut PDFs.

    Clears the caches, then feeds every matching row's name through
    _add_name_to_title_cache. Rows are read under the instance lock.
    """
    self._clear_title_cache()
    with self._lock:
        self.cursor.execute(
            "SELECT id, name FROM drive_items WHERE owner = 1 AND mime_type = ? AND shortcut_target IS NULL",
            ('application/pdf',),
        )
        rows = self.cursor.fetchall()
        for row in rows:
            self._add_name_to_title_cache(row['id'], row['name'])
def _title_match(self, needle: str) -> str | None:
    """Fuzzy-search the title caches for ``needle``.

    Lazily (re)builds the cache on first use. Long needles are matched
    with partial_ratio against the long-name cache; shorter ones with
    plain ratio against the full cache.

    Returns the Google ID of the match or None.
    """
    if not self.title_cache:
        self.rebuild_title_cache()
    if len(needle) >= self.MIN_PARTIAL_LEN:
        haystack = self.partial_title_cache
        ids = self.partial_title_cache_ids
        scorer = fuzz.partial_ratio
    else:
        haystack = self.title_cache
        ids = self.title_cache_ids
        scorer = fuzz.ratio
    hit = process.extractOne(
        needle,
        haystack,
        score_cutoff=81,  # see: calculate_ideal_title_match_threshhold()
        scorer=scorer,
    )
    # extractOne returns (choice, score, index) — index into haystack.
    return ids[hit[2]] if hit else None
def max_score_for_title(self, needle: str, gid: str = None) -> float:
    """Best fuzzy score for ``needle`` over the cached titles.

    Has to match the splitting logic of probable_pdf_id_for_title:
    the whole needle is scored, and if it contains ': ' each part is
    scored separately as well.

    Args:
        needle: title text to score.
        gid: if given, only consider strings cached for that Drive ID.

    Returns:
        The maximum score found (0 if nothing qualifies).
    """
    scores = [self._max_score_for_title(needle, gid=gid)]
    if ': ' in needle:
        for part in needle.split(': '):
            # FIX: previously passed `needle` again, so the per-part
            # scoring was a no-op.
            scores.append(self._max_score_for_title(part, gid=gid))
    return max(scores, default=0)
def _max_score_for_title(self, needle: str, gid: str = None) -> float:
    """Best raw fuzzy score of ``needle`` against the cached strings.

    If gid is provided, only consider the strings for that id.
    Needles shorter than MIN_TITLE_LEN score 0; long needles use
    partial_ratio over the long-name cache, shorter ones plain ratio
    over the full cache.
    """
    if len(needle) < self.MIN_TITLE_LEN:
        return 0
    if len(needle) >= self.MIN_PARTIAL_LEN:
        scorer = fuzz.partial_ratio
        haystack = self.partial_title_cache
        id_list = self.partial_title_cache_ids
    else:
        scorer = fuzz.ratio
        haystack = self.title_cache
        id_list = self.title_cache_ids
    return max(
        (scorer(needle, text)
         for text, iid in zip(haystack, id_list)
         if gid is None or iid == gid),
        default=0,
    )
def probable_pdf_id_for_title(self, title: str) -> str | None:
633+
"""Does a fuzzy check for PDF files matching the title
634+
If it found a match, returns the Drive ID of the file
635+
636+
Note: this is about 85% accurate.
637+
About 15% of the time it retrieves a different PDF with a similar title
638+
"""
639+
match = self._title_match(title.replace(':', '_'))
640+
if match:
641+
return match
642+
if ': ' in title:
643+
for part in title.split(': '):
644+
if len(part) < self.MIN_TITLE_LEN:
645+
continue
646+
match = self._title_match(part)
647+
if match:
648+
return match
649+
return None
650+
536651
@locked
537652
def find_duplicate_md5s(self) -> List[str]:
538653
"""
@@ -604,6 +719,7 @@ def rename_file(self, file_id: str, new_name: str):
604719
with self._lock:
605720
self.cursor.execute("UPDATE drive_items SET name = ? WHERE id = ?", (new_name, file_id))
606721
self.conn.commit()
722+
self._clear_title_cache() # maybe be more intelligent if needed
607723

608724
def create_folder(self, folder_name: str, parent_id: str) -> str:
609725
"""Creates a new folder with the name and parent and rets the new id"""
@@ -667,6 +783,7 @@ def upload_file(self, fp: Path, filename=None, folder_id=None) -> str | None:
667783
)
668784
)
669785
)
786+
self._add_name_to_title_cache(ret, filename or fp.name)
670787
return ret
671788

672789
######
@@ -688,3 +805,66 @@ def __enter__(self):
688805
def __exit__(self, exc_type, exc_value, traceback):
689806
"""Context manager exit. Ensures connection is closed."""
690807
self.close()
808+
809+
810+
def calculate_ideal_title_match_threshhold(minc: int):
    """Empirically pick a score cutoff for _title_match.

    Scores every article (each should match its own PDF) and every AV
    item (none should match anything) against the title cache, then
    reports the cutoff that maximizes F1 and the one that maximizes
    Youden's J statistic on the ROC curve.

    Args:
        minc: value installed as MIN_TITLE_LEN before building the cache.
    """
    import website
    from yaspin import yaspin
    with yaspin(text="Loading website..."):
        website.load()
    import gdrive
    # Positives: articles with a drive link whose primary format is PDF.
    articles = [c for c in website.content
                if c.category == 'articles' and
                c.drive_links and
                c.formats[0] == 'pdf'
                ]
    # Negatives: AV items with no PDF format at all.
    avs = [c for c in website.content
           if c.category == "av" and 'pdf' not in c.get("formats", [])
           ]
    gdrive.gcache.MIN_TITLE_LEN = minc
    print(f"Set MINC={minc}")
    with yaspin(text="Building title cache..."):
        gdrive.gcache.rebuild_title_cache()
    with yaspin(text="Calculating article scores..."):
        article_scores = [
            gdrive.gcache.max_score_for_title(
                c.title,
                gdrive.link_to_id(c.drive_links[0]),
            )
            for c in articles
        ]
    # NOTE(review): scores look like 0-100 ratios, so "> 0.5" means
    # "scored at all" rather than "scored over half" — confirm intent.
    print(f"Found {sum(s > 0.5 for s in article_scores)}/{len(articles)} articles")
    with yaspin(text=f"Scoring {len(avs)} AV items..."):
        av_scores = [
            gdrive.gcache.max_score_for_title(
                c.title
            )
            for c in avs
        ]
    print(f"Found {sum(s > 0.5 for s in av_scores)}/{len(avs)} AVs")
    import numpy as np
    y_scores = np.concatenate([article_scores, av_scores])
    y_true = np.concatenate([
        np.ones(len(articles)),  # we should ideally find all these
        np.zeros(len(avs)),  # we should ideally not find any of these
    ])
    from sklearn.metrics import roc_curve, precision_recall_curve
    fpr, tpr, roc_thresholds = roc_curve(y_true, y_scores)
    precision, recall, pr_thresholds = precision_recall_curve(y_true, y_scores)
    # FIX: dropped the unused `ap = average_precision_score(...)` — its
    # result was never printed or returned.
    # Epsilon avoids 0/0 where precision + recall is zero.
    f1 = 2 * precision * recall / (precision + recall + 1e-12)
    best_idx = np.argmax(f1)
    best_threshold_pr = pr_thresholds[best_idx]
    print(f"Best threshhold value to balance precision and recall is {best_threshold_pr} (F1={f1[best_idx]})")
    # Youden's J: maximize TPR - FPR along the ROC curve.
    j_scores = tpr - fpr
    best_idx = np.argmax(j_scores)
    best_threshold_roc = roc_thresholds[best_idx]
    print(f"Best threshold for J score is {best_threshold_roc} (J={j_scores[best_idx]})")
if __name__ == "__main__":
    import argparse

    # CLI wrapper: tune MIN_TITLE_LEN via --minc, then run the sweep.
    cli = argparse.ArgumentParser()
    cli.add_argument("--minc", type=int, default=20)
    opts = cli.parse_args()
    calculate_ideal_title_match_threshhold(opts.minc)

0 commit comments

Comments
 (0)