castorini / pyserini

Pyserini is a Python toolkit for reproducible information retrieval research with sparse and dense representations.

Home Page: http://pyserini.io/


BM25 batch search with multiple threads fails with java.lang.OutOfMemoryError: Java heap space

zhiyuanpeng opened this issue · comments

I get the error java.lang.OutOfMemoryError: Java heap space when I run a BM25 batch search:

bm25rm3 = BM25(dataset=args.dataset, doc_dir=args.doc_dir, build_index=args.build_index)
results = bm25rm3.batch_search(queries, qids, k=#_of_indexed_documents_which_is_very_large, threads=50)

Due to my requirements, I need to return all matched documents, so k is set to #_of_indexed_documents_which_is_very_large, and I set os.environ['JAVA_OPTS'] = '-Xmx400g' before calling from pyserini.search import SimpleSearcher, allocating half of the machine's total memory to the JVM in case the retrieved documents are huge. But the error still appears shortly after the program starts, even though memory usage is only about 40G, far below the 400G limit.
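As far as I can tell, Pyserini starts its JVM through pyjnius, so an alternative is to pass the heap option via jnius_config before anything from pyserini is imported; whether JAVA_OPTS in os.environ is honored at all is an assumption I have not verified. A minimal sketch of the jnius_config route (the -Xmx value is illustrative):

import jnius_config
# Assumption: pyserini launches its JVM via pyjnius, so options added here
# are applied when the JVM starts; they have no effect once it is running.
jnius_config.add_options('-Xmx400g')  # illustrative heap size

from pyserini.search import SimpleSearcher  # first import starts the JVM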

The BM25 class I implemented on top of Pyserini:

import os
os.environ['JAVA_OPTS'] = '-Xmx400g'  # must be set before pyserini is imported

import argparse
import multiprocessing
import random
import sys
from os.path import join, exists
from typing import Dict, List, Optional

from pyserini.search import SimpleSearcher

cwd = os.getcwd()
random.seed(12345)
if cwd not in sys.path:
    sys.path.append(cwd)
if join(cwd, "baselines") not in sys.path:
    sys.path.append(join(cwd, "baselines"))
data_dir = join(cwd, "data")
bm25rm3_dir = join(cwd, "baselines", "BM25RM3")

def hit_template(hits):
    # Flatten batch_search results into {qid: {docid: score}}.
    results = {}
    for qid, hit in hits.items():
        results[qid] = {h.docid: h.score for h in hit}
    return results

class BM25:

    def __init__(self, dataset, doc_dir, build_index: bool = False, bm25_k1: float = 0.9, bm25_b: float = 0.4):
        self.doc_dir = join(data_dir, doc_dir)
        self.doc_json = join(self.doc_dir, "doc.json")
        assert exists(self.doc_json), "doc json file does not exist!"
        self.dataset_dir = join(bm25rm3_dir, dataset)
        os.makedirs(self.dataset_dir, exist_ok=True)
        self.index_dir = join(self.dataset_dir, "index")
        os.makedirs(self.index_dir, exist_ok=True)
        self.log_dir = join(self.dataset_dir, "logs")
        os.makedirs(self.log_dir, exist_ok=True)
        if build_index:
            self.build_index()
        self.searcher = SimpleSearcher(self.index_dir)
        self.searcher.set_bm25(k1=bm25_k1, b=bm25_b) 

    def build_index(self):
        # Shell out to Pyserini's indexer on the JSON collection.
        threads = multiprocessing.cpu_count()
        print(f"Start building index with {threads} threads...")
        command = f"python -m pyserini.index -collection JsonCollection \
        -generator DefaultLuceneDocumentGenerator -threads {threads} \
        -input {self.doc_dir} -index {self.index_dir} -storeRaw \
        -storePositions -storeDocvectors"
        os.system(command)
        print("Index built!")
    
    def search(self, q: str, k: int = 1000, fields: Optional[Dict[str, float]] = None):
        # Avoid a mutable default argument; fall back to the intended fields.
        if fields is None:
            fields = {"contents": 1.0, "title": 1.0}
        hits = self.searcher.search(q, k=k, fields=fields)
        return [{'docid': hit.docid, 'score': hit.score} for hit in hits]

    def batch_search(self, queries: List[str], qids: List[str], k: int = 1000, threads: int = 8, fields: Optional[Dict[str, float]] = None):
        # Avoid a mutable default argument here as well.
        if fields is None:
            fields = {"contents": 1.0, "title": 1.0}
        hits = self.searcher.batch_search(queries=queries, qids=qids, k=k, threads=threads, fields=fields)
        return hit_template(hits)
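One way to check whether the -Xmx option actually took effect is to ask the running JVM for its heap limit. A minimal diagnostic sketch, assuming the JVM is already up via pyjnius (java.lang.Runtime is the standard JDK class):

from jnius import autoclass

# Ask the JVM that pyserini started for its effective maximum heap.
Runtime = autoclass('java.lang.Runtime')
max_heap_gb = Runtime.getRuntime().maxMemory() / 1e9
print(f"JVM max heap: {max_heap_gb:.1f} GB")
# A value far below 400 GB would mean the JAVA_OPTS setting was ignored.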

Any suggestions? Thanks.

I have tested:

len(queries)=2000, k=10000000: time cost 26.96 s
len(queries)=2000, k=100000000: java.lang.OutOfMemoryError: Java heap space; memory usage is about 40G, far below the 400G I set
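For scale: if Lucene's top-k collector pre-allocates a priority queue of k entries per in-flight query (my understanding of TopScoreDocCollector, stated here as an assumption), heap demand grows with k × threads regardless of how many documents actually match. A rough back-of-envelope sketch, where the per-entry size is a guess:

k = 100_000_000
threads = 50
bytes_per_entry = 28  # assumed size of one ScoreDoc entry, not measured
print(f"~{k * threads * bytes_per_entry / 1e9:.0f} GB for the queues alone")
# -> ~140 GB at k=100_000_000, versus ~14 GB at k=10_000_000; the smaller
#    estimate would fit in the observed ~40 GB while the larger would not.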