#!/usr/bin/env python3
"""Match Google Drive PDF filenames to website works with a small MLP classifier."""

from gdrive_base import DRIVE_LINK, link_to_id
import website
from yaspin import yaspin
import pickle
import re
import joblib
import random
from rapidfuzz import fuzz
from itertools import chain
from sklearn.neural_network import MLPClassifier
from typing import Iterable
from functools import cache

# Strips parenthesized/bracketed chunks, e.g. " (2nd ed.)" or " [scan]"
parentheses = re.compile(r'\s*[\(\[][^)\]]*[\)\]]')
CLASSIFIER_FILE = 'titlematch.classifier'
classifier: MLPClassifier | None = None  # loaded from disk on import (see bottom of file)

def probability_filename_matches(
    filename: str | Iterable[str],
    work_title: str | Iterable[str],
    first_author: str | Iterable[str],
) -> float | list[float]:
    """Return the match probability between `filename`(s) and the work(s).

    At most one of (work_title, first_author) and filename may be an
    Iterable; a full 2D matrix of pairs is never computed.

    Returns:
        P(match) in [0, 1].
        The optimal cutoff for Balanced Accuracy (=87.3%) is 0.6629.
    """
    assert type(work_title) == type(first_author), "work_title and first_author must be the same type"
    if isinstance(work_title, str) and isinstance(filename, str):
        # one filename vs. one work
        parsed_name = split_file_name(filename)
        features = extract_feature_vector_for_item_parsed_name_pair(
            work_title,
            first_author,
            parsed_name,
        )
        return float(classifier.predict_proba([features])[0][1])
    if isinstance(filename, str):
        # one filename vs. many works
        parsed_name = split_file_name(filename)
        features = []
        for title, author in zip(work_title, first_author):
            features.append(extract_feature_vector_for_item_parsed_name_pair(
                title,
                author,
                parsed_name,
            ))
        return [float(ps[1]) for ps in classifier.predict_proba(features)]
    if isinstance(first_author, str):
        # many filenames vs. one work
        features = []
        for fname in filename:
            features.append(extract_feature_vector_for_item_parsed_name_pair(
                work_title,
                first_author,
                split_file_name(fname),
            ))
        return [float(ps[1]) for ps in classifier.predict_proba(features)]
    raise ValueError("Unknown type combination")
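
# Usage sketch (illustrative values; the filename and work are made up):
#   p = probability_filename_matches(
#       "Example Title - Jane Doe.pdf", "Example Title", "Jane Doe")
#   if p >= 0.6629:  # cutoff from the docstring above
#       ...treat the file as a match...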

@cache
def split_file_name(filename: str) -> tuple[str, str, str]:
    """Return guessed (title, subtitle, author) strings.

    Based on the naive assumption of a "Title_ Subtitle - Author.pdf" name.
    """
    ret = ['', '', '']
    if filename.lower().endswith('.pdf'):
        filename = filename[:-4]
    filename = parentheses.sub('', filename)
    filename = filename.replace('_-_', ' - ')
    if ' - ' in filename:
        auth_split = filename.split(' - ')
        ret[2] = auth_split[-1]
        # treat multiple ' - 's as ':'s that became '_'s
        filename = '_ '.join(auth_split[:-1])
    filename = filename.replace(': ', '_ ')
    if '_ ' in filename:
        ret[0] = filename.split('_ ')[0]
        ret[1] = filename[len(ret[0]) + 2:]
    else:
        ret[0] = filename
    return tuple(ret)
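
# Worked example (hypothetical filename, for illustration):
#   split_file_name("Deep Learning_ Methods and Applications - Deng.pdf")
#   -> ("Deep Learning", "Methods and Applications", "Deng")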

def extract_feature_vector_for_item_parsed_name_pair(
    true_title: str,
    first_author: str,
    parsed_name: tuple[str, str, str],  # output of split_file_name() above
) -> tuple[float, int, int, float, int, int, float, int]:
    """Build the eight-element feature vector for one (work, filename) pair."""
    assert parsed_name[0], f"No title in {parsed_name}"
    if ': ' in true_title:
        title, _, subtitle = true_title.partition(': ')
        if parsed_name[1]:
            # both the true title and the filename have a subtitle
            return (
                fuzz.partial_ratio(parsed_name[0], title),
                len(parsed_name[0]),
                len(title),
                fuzz.partial_ratio(parsed_name[1], subtitle),
                len(parsed_name[1]),
                len(subtitle),
                fuzz.token_sort_ratio(parsed_name[2], first_author),
                len(parsed_name[2]),
            )
        # the filename has no subtitle: compare its title to both parts
        return (
            fuzz.partial_ratio(parsed_name[0], title),
            len(parsed_name[0]),
            len(title),
            fuzz.partial_ratio(parsed_name[0], subtitle),
            len(parsed_name[0]),
            len(subtitle),
            fuzz.token_sort_ratio(parsed_name[2], first_author),
            len(parsed_name[2]),
        )
    # else there is no ': ' in the true_title
    if parsed_name[1]:
        # but this filename thinks there should be a subtitle
        return (
            fuzz.partial_ratio(parsed_name[0], true_title),
            len(parsed_name[0]),
            len(true_title),
            fuzz.partial_ratio(parsed_name[1], true_title),
            len(parsed_name[1]),
            0,
            fuzz.token_sort_ratio(parsed_name[2], first_author),
            len(parsed_name[2]),
        )
    # else no subtitle and not expecting one either
    return (
        fuzz.partial_ratio(parsed_name[0], true_title),
        len(parsed_name[0]),
        len(true_title),
        100.0,  # '' == '': perfect match!
        0,
        0,
        fuzz.token_sort_ratio(parsed_name[2], first_author),
        len(parsed_name[2]),
    )

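# Feature vector layout (indexed by the training code below):
#   [0] title fuzz score     [1] len(file title)     [2] len(true title)
#   [3] subtitle fuzz score  [4] len(file subtitle)  [5] len(true subtitle)
#   [6] author fuzz score    [7] len(file author)
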
if __name__ == "__main__":
    print("Welcome to the titlematch.py trainer")
    from gdrive import gcache, gcache_folder
    from sklearn.model_selection import GridSearchCV
    import heapq
    from tqdm import tqdm

    with yaspin(text="Loading website..."):
        website.load()
    print("Website loaded")

    disk_memoizer = joblib.Memory(gcache_folder, verbose=0)

    website_content_with_pdfs = [
        c for c in website.content
        if c.formats[0] == 'pdf'
        and c.get('drive_links')
        and str(c.drive_links[0]).startswith(DRIVE_LINK.split('{}')[0])
        and c.get('authors')
    ]
    print(f"Found {len(website_content_with_pdfs)} content items with PDFs")
    drive_file_names = []
    for item in website_content_with_pdfs:
        drive_id = link_to_id(item['drive_links'][0])
        drive_file = gcache.get_item(drive_id)
        assert drive_file is not None, f"No file found in gcache for {DRIVE_LINK.format(drive_id)}"
        assert drive_file['name'].lower().endswith('.pdf'), f"Website calls this a PDF but the Drive file name disagrees: {DRIVE_LINK.format(drive_id)}"
        assert drive_file['name'] not in drive_file_names, f"Multiple files found with name = \"{drive_file['name']}\""
        drive_file_names.append(drive_file['name'])
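
    # These known (website item, Drive file) pairs supply the positive
    # labels below; everything else paired with them is a negative.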

    print("Loading the full feature vector matrix...")

    @disk_memoizer.cache()
    def build_full_feature_vector_matrix_for_items(
        content_paths: list[str],
        drive_file_names: list[str],
    ):
        parsed_file_names = [split_file_name(fn) for fn in drive_file_names]
        # re-expand the pickle-friendly content paths into website objects
        by_path = {c.content_path: c for c in website.content}
        website_content = [by_path[cpath] for cpath in content_paths]
        ret = []
        print("Building the full training data feature matrix...", flush=True)
        for item in tqdm(website_content, unit='i'):
            row = []
            for parsed_name in parsed_file_names:
                row.append(
                    extract_feature_vector_for_item_parsed_name_pair(
                        item.title,
                        website.normalized_author_name(item.authors[0]),
                        parsed_name,
                    )
                )
            ret.append(row)
        return ret

    full_feature_vector_matrix = build_full_feature_vector_matrix_for_items(
        [c.content_path for c in website_content_with_pdfs],  # squish to IDs for pickling
        drive_file_names,
    )
    print("Selecting samples for X and y...")
    y = []
    X = []
    for row_i, row in enumerate(full_feature_vector_matrix):
        X.append(row[row_i])  # self-similarity features
        y.append(1)  # I am myself
        # Now find a few negative examples (don't just append all).
        # We pick randomly among the highest title, subtitle, and author
        # scores, plus random fillers until we have ten negatives.
        highest_titles = []
        highest_title_score = 0
        highest_subtitles = []
        highest_subtitle_score = 0
        highest_authors = []
        highest_author_score = 0
        for col_j, col in chain(enumerate(row[:row_i]), enumerate(row[row_i+1:], start=row_i+1)):
            if col[0] == highest_title_score:
                highest_titles.append(col_j)
            if col[0] > highest_title_score:
                highest_title_score = col[0]
                highest_titles = [col_j]
            if col[3] == highest_subtitle_score:
                highest_subtitles.append(col_j)
            if col[3] > highest_subtitle_score:
                highest_subtitle_score = col[3]
                highest_subtitles = [col_j]
            if col[6] == highest_author_score:
                highest_authors.append(col_j)
            if col[6] > highest_author_score:
                highest_author_score = col[6]
                highest_authors = [col_j]
        to_take = set()
        to_take.add(random.choice(highest_titles))
        random.shuffle(highest_subtitles)
        random.shuffle(highest_authors)
        # take at most one fresh index from each of the other two lists
        while len(highest_subtitles) or len(highest_authors):
            if len(highest_subtitles):
                choice = highest_subtitles.pop()
                if choice not in to_take:
                    to_take.add(choice)
                    highest_subtitles = []
            if len(highest_authors):
                choice = highest_authors.pop()
                if choice not in to_take:
                    to_take.add(choice)
                    highest_authors = []
        # pad with random non-self columns
        while len(to_take) < 10:
            choice = random.randrange(0, len(row))
            if choice != row_i:
                to_take.add(choice)
        for take_it in to_take:
            X.append(row[take_it])
            y.append(0)
    del full_feature_vector_matrix
    print("Adding a bunch of tricky negatives...")
    all_pdf_filenames = {
        f['name'] for f in
        gcache.sql_query(
            "owner = 1 AND mime_type = ? AND shortcut_target IS NULL",
            ('application/pdf',),
        )
    }

    random.shuffle(website.content)

    @disk_memoizer.cache(cache_validation_callback=joblib.expires_after(days=14))
    def find_hard_av_examples():
        ret = []
        for item in tqdm(website.content):
            if item.category != 'av' or 'pdf' in item.formats or not item.get('authors'):
                continue
            all_vecs = [extract_feature_vector_for_item_parsed_name_pair(
                item.title,
                website.normalized_author_name(item.authors[0]),
                split_file_name(filename),
            ) for filename in all_pdf_filenames]
            # tuples compare lexicographically, so nlargest favors the
            # vectors with the highest title fuzz score
            for feature_vec in heapq.nlargest(5, all_vecs):
                ret.append(feature_vec)
                all_vecs.remove(feature_vec)
            # ...plus a 1% random sample of the rest
            for vec in all_vecs:
                if random.random() < 0.01:
                    ret.append(vec)
        return ret

    for feature_vec in find_hard_av_examples():
        X.append(feature_vec)
        y.append(0)

    print("Finding optimal model and params...")
    from sklearn.base import clone
    classifier = MLPClassifier(
        max_iter=300,
    )
    param_grid = {'hidden_layer_sizes': [
        (32, 16, 8, 8),
        (32, 16, 16),
    ]}
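    # Cross-validate both candidate architectures and keep whichever
    # maximizes ROC AUC; the winner is refit on all of X below.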
    classifier = GridSearchCV(
        classifier,
        param_grid=param_grid,
        cv=5,
        scoring='roc_auc',
        n_jobs=8,
    ).fit(X, y)

    print(f"Best params: {classifier.best_params_}")
    print(f"Best score: {classifier.best_score_}")

    print("Fetching additional negative examples based on first-run mistakes...")
    for item in tqdm(website.content):
        if item.category != 'av' or 'pdf' in item.formats or not item.get('authors'):
            continue
        all_vecs = [extract_feature_vector_for_item_parsed_name_pair(
            item.title,
            website.normalized_author_name(item.authors[0]),
            split_file_name(filename),
        ) for filename in all_pdf_filenames]
        all_scores = classifier.predict_proba(all_vecs)
        score_vecs = [(score[1],) + vec for score, vec in zip(all_scores, all_vecs)]
        del all_scores
        del all_vecs
        for score_vec in heapq.nlargest(200, score_vecs):
            if score_vec[0] < 0.3:
                break  # nlargest yields in descending order, so we can stop here
            to_add = score_vec[1:]  # drop the score, keeping the feature tuple
            if to_add not in X:
                X.append(to_add)
                y.append(0)
    print("Training the final classifier...")
    classifier = clone(classifier.best_estimator_)
    classifier.set_params(
        max_iter=1000,
        verbose=True,
    )
    classifier.fit(X, y)
    with open(CLASSIFIER_FILE, 'wb') as f:
        pickle.dump(classifier, f)
    print("Done training! Now testing...")
    del X
    del y
    av_content = [
        c for c in website.content
        if c.category == 'av'
        and 'pdf' not in c.formats
        and c.get('authors')
    ]
    article_scores = []
    av_scores = []
    print("Scoring content with PDFs...")
    for c in tqdm(website_content_with_pdfs):
        article_scores.append(max(probability_filename_matches(
            all_pdf_filenames,
            c.title,
            website.normalized_author_name(c.authors[0]),
        )))
    print("Scoring AV content without PDFs...")
    for c in tqdm(av_content):
        av_scores.append(max(probability_filename_matches(
            all_pdf_filenames,
            c.title,
            website.normalized_author_name(c.authors[0]),
        )))
    import numpy as np
    y_scores = np.concatenate([article_scores, av_scores])
    y_true = np.concatenate([
        np.ones(len(website_content_with_pdfs)),  # we should ideally find all of these
        np.zeros(len(av_content)),  # we should ideally find none of these
    ])
    from sklearn.metrics import roc_curve
    fpr, tpr, roc_thresholds = roc_curve(y_true, y_scores)
    # Youden's J statistic: J = TPR - FPR; balanced accuracy = (J + 1) / 2
    j_scores = tpr - fpr
    best_idx = np.argmax(j_scores)
    best_threshold_roc = roc_thresholds[best_idx]
    print(f"Optimal threshold = {best_threshold_roc:.4f} (with a Balanced Accuracy of {(j_scores[best_idx]+1)*50:.2f}%)")

else:
    with yaspin(text="Loading titlematch classifier..."):
        with open(CLASSIFIER_FILE, 'rb') as f:
            classifier = pickle.load(f)