BM25 batch search with multi threads error: java.lang.OutOfMemoryError: Java heap space
zhiyuanpeng opened this issue · comments
I get this error java.lang.OutOfMemoryError: Java heap space
when I do BM25 batch search:
bm25rm3 = BM25(dataset=args.dataset, doc_dir=args.doc_dir, build_index=args.build_index)
results = bm25rm3.batch_search(queries, qids, k=#_of_indexed_documents_which_is_very_large, threads=50)
due to my requirements, I need to return all the matched documents, so k
is set to #_of_indexed_documents_which_is_very_large
and I set os.environ['JAVA_OPTS'] = '-Xmx400g'
before I call from pyserini.search import SimpleSearcher
to allocate 1/2 of the total memory to the JVM in case the retrieved documents are huge. But I still get the error quickly after I run the program even thought the memory usage is about 40G which is far less than 400G
The BM25 class I implemented based on pyserini:
import os
os.environ['JAVA_OPTS'] = '-Xmx400g'
from os.path import join
import random
import os
import random
from os.path import join, exists
from typing import Optional, List, Dict, Union
from pyserini.search import SimpleSearcher
import argparse
import sys
import multiprocessing
# Module-level setup: fix the RNG seed for reproducibility and make the
# working directory plus its "baselines" subdirectory importable.
cwd = os.getcwd()
random.seed(12345)
for extra_path in (cwd, join(cwd, "baselines")):
    if extra_path not in sys.path:
        sys.path.append(extra_path)
data_dir = join(cwd, "data")
bm25rm3_dir = join(cwd, "baselines", "BM25RM3")
def hit_template(hits):
    """Flatten pyserini batch-search output into plain nested dicts.

    Args:
        hits: mapping of query id -> list of hit objects, where each hit
            exposes ``docid`` and ``score`` attributes.

    Returns:
        Dict mapping query id -> {docid: score}. If a docid appears more
        than once for a query, the later hit's score wins (same as the
        original index loop).
    """
    return {
        qid: {hit.docid: hit.score for hit in query_hits}
        for qid, query_hits in hits.items()
    }
class BM25():
    """Thin wrapper around pyserini's ``SimpleSearcher`` for BM25 retrieval.

    Optionally builds a Lucene index over ``<data_dir>/<doc_dir>/doc.json``,
    then exposes single-query and batched search returning
    ``{docid: score}``-style results.
    """

    def __init__(self, dataset, doc_dir, build_index: bool = False,
                 bm25_k1: float = 0.9, bm25_b: float = 0.4):
        """Set up directories, optionally build the index, open the searcher.

        Args:
            dataset: dataset name; used to namespace index/log directories.
            doc_dir: directory under ``data_dir`` holding ``doc.json``.
            build_index: when True, (re)build the Lucene index first.
            bm25_k1: BM25 ``k1`` parameter.
            bm25_b: BM25 ``b`` parameter.

        Raises:
            FileNotFoundError: if ``doc.json`` does not exist.
        """
        self.doc_dir = join(data_dir, doc_dir)
        self.doc_json = join(self.doc_dir, "doc.json")
        # Raise a real exception instead of `assert`, which is stripped
        # when Python runs with -O.
        if not exists(self.doc_json):
            raise FileNotFoundError(f"doc json file does not exist: {self.doc_json}")
        self.dataset_dir = join(bm25rm3_dir, dataset)
        os.makedirs(self.dataset_dir, exist_ok=True)
        self.index_dir = join(self.dataset_dir, "index")
        os.makedirs(self.index_dir, exist_ok=True)
        self.log_dir = join(self.dataset_dir, "logs")
        os.makedirs(self.log_dir, exist_ok=True)
        if build_index:
            self.build_index()
        self.searcher = SimpleSearcher(self.index_dir)
        self.searcher.set_bm25(k1=bm25_k1, b=bm25_b)

    def build_index(self):
        """Build the Lucene index from ``self.doc_dir`` into ``self.index_dir``.

        Raises:
            subprocess.CalledProcessError: if the indexing process fails
                (the old ``os.system`` call silently ignored failures).
        """
        import subprocess  # local import: only needed when indexing

        threads = multiprocessing.cpu_count()
        print(f"Start building index with {threads} threads...")
        # Argument list + shell=False: robust to spaces in paths and not
        # vulnerable to shell injection via directory names. sys.executable
        # guarantees the same interpreter that is running this script.
        command = [
            sys.executable, "-m", "pyserini.index",
            "-collection", "JsonCollection",
            "-generator", "DefaultLuceneDocumentGenerator",
            "-threads", str(threads),
            "-input", self.doc_dir,
            "-index", self.index_dir,
            "-storeRaw", "-storePositions", "-storeDocvectors",
        ]
        subprocess.run(command, check=True)
        print("Index built!")

    def search(self, q: str, k: int = 1000,
               fields: Optional[Dict[str, float]] = None) -> List[Dict]:
        """Run a single BM25 query.

        Args:
            q: query string.
            k: maximum number of hits to return.
            fields: per-field boost weights; defaults to
                ``{"contents": 1.0, "title": 1.0}``. (Default is built per
                call — a mutable default argument would be shared across
                calls.)

        Returns:
            List of ``{'docid': ..., 'score': ...}`` dicts.
        """
        if fields is None:
            fields = {"contents": 1.0, "title": 1.0}
        hits = self.searcher.search(q, k=k, fields=fields)
        return [{'docid': hit.docid, 'score': hit.score} for hit in hits]

    def batch_search(self, queries: List[str], qids: List[str],
                     k: int = 1000, threads: int = 8,
                     fields: Optional[Dict[str, float]] = None) -> Dict:
        """Run many BM25 queries in parallel.

        Args:
            queries: query strings.
            qids: query ids, parallel to ``queries``.
            k: maximum hits per query. NOTE(review): all ``k`` hits per
                query are materialized on the Java heap before being handed
                back to Python, so a very large ``k`` times many queries can
                exhaust JVM memory regardless of host RAM — consider paging
                or a smaller ``k``.
            threads: number of searcher threads.
            fields: per-field boost weights; defaults to
                ``{"contents": 1.0, "title": 1.0}``.

        Returns:
            Dict mapping qid -> {docid: score} (via ``hit_template``).
        """
        if fields is None:
            fields = {"contents": 1.0, "title": 1.0}
        hits = self.searcher.batch_search(
            queries=queries, qids=qids, k=k, threads=threads, fields=fields)
        return hit_template(hits)
Any suggestions? Thanks.
I have tested:
len(queries)=2000, k=10000000, time cost: 26.962780952453613s
len(queries)=2000, k=100000000, java.lang.OutOfMemoryError: Java heap space, the memory usage is about 40G << 400G I set