huggingface / datatrove

Freeing data processing from scripting madness by providing a set of platform-agnostic customizable pipeline processing blocks.

Minhash deduplication rate different from other implementation

jordane95 opened this issue

Hi,

I'm using datatrove to run minhash dedup on one dataset. However, I find that the deduplication rate of datatrove is a little lower than that of other implementations, such as this one based on Spark.

More precisely, I'm using 13 bands each with 60 hash codes. The previous Spark code removes 88% of the raw content, while datatrove only removes 60%. Based on my understanding, the union set and GraphFrame implementations are equivalent in principle, so there shouldn't be such a large difference...

Can you share the datatrove script you used?

from datatrove.executor.base import PipelineExecutor
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.pipeline.dedup import MinhashDedupSignature
from datatrove.pipeline.dedup.minhash import (
    MinhashConfig,
    MinhashDedupBuckets,
    MinhashDedupCluster,
    MinhashDedupFilter,
)
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.tokens import TokensCounter
from datatrove.pipeline.writers.jsonl import JsonlWriter
from datatrove.data import Document

import os
import argparse

def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--stage", type=int, default=1, help="stage to run")

    parser.add_argument("--logging_dir", type=str, default=None, help="logging dir")
    
    parser.add_argument("--input_path", type=str, default=None, help="input path")
    parser.add_argument("--minhash_path", type=str, default=None, help="minhash path")

    parser.add_argument("--workers", type=int, default=100, help="number of workers per node")

    parser.add_argument("--tasks", type=int, default=100, help="total parallelism")

    parser.add_argument("--b", type=int, default=13, help="number of hash per buckt")
    parser.add_argument("--r", type=int, default=60, help="number of buckets")

    args = parser.parse_args()

    return args

def document_uf_adapter(data: dict, path: str, id_in_file: int | str):
    item = data.pop("data")[0]
    text = item.pop("text", "")
    meta = item.pop("meta", {})
    return {
        "text": text,
        "id": data.pop("id", f"{path}/{id_in_file}"),
        "media": data.pop("media", []),
        "metadata": {'meta': meta, **data},  # remaining data goes into metadata
    }


def write_adapter(document: Document):
    meta = document.metadata.pop("meta", {})
    return {
        "id": document.id,
        "data": [{"meta": meta, "text": document.text}],
        **document.metadata,
    }

def main():

    args = get_args()
    global_rank = int(os.environ.get("RANK", 0))
    global_world_size = int(os.environ.get("WORLD_SIZE", 1))

    stage = args.stage

    # assign tasks to each node, taking into account that args.tasks may not be divisible by global_world_size
    base_tasks = args.tasks // global_world_size
    remainder = args.tasks % global_world_size
    local_tasks = base_tasks + (1 if global_rank < remainder else 0)
    # ranks below the remainder each run one extra task, so the offset must account for them
    local_rank_offset = global_rank * base_tasks + min(global_rank, remainder)

    # you can also change ngrams or the number of buckets and their size here
    minhash_config = MinhashConfig(use_64bit_hashes=False, num_buckets=args.r, hashes_per_bucket=args.b)  # better precision -> fewer false positives (collisions)

    MINHASH_BASE_PATH = args.minhash_path

    LOGS_FOLDER = args.logging_dir

    TOTAL_TASKS = args.tasks

    # this is the original data that we want to deduplicate
    INPUT_READER = JsonlReader(
        data_folder=args.input_path,
        adapter=document_uf_adapter,
    )

    if stage == 1:
        # stage 1 computes minhash signatures for each task (each task gets a set of files)
        stage1 = LocalPipelineExecutor(
            pipeline=[
                INPUT_READER,
                MinhashDedupSignature(output_folder=f"{MINHASH_BASE_PATH}/signatures", config=minhash_config),
            ],
            tasks=TOTAL_TASKS,
            logging_dir=f"{LOGS_FOLDER}/signatures",
            workers=args.workers,
            local_rank_offset=local_rank_offset,
            local_tasks=local_tasks,
        )
        stage1.run()

    elif stage == 2:
        assert args.tasks % minhash_config.num_buckets == 0, "number of tasks must be divisible by number of buckets"
        # stage 2 finds matches between signatures in each bucket
        stage2 = LocalPipelineExecutor(
            pipeline=[
                MinhashDedupBuckets(
                    input_folder=f"{MINHASH_BASE_PATH}/signatures",
                    output_folder=f"{MINHASH_BASE_PATH}/buckets",
                    config=minhash_config,
                ),
            ],
            tasks=args.tasks,
            logging_dir=f"{LOGS_FOLDER}/buckets",
            local_tasks=local_tasks,
            local_rank_offset=local_rank_offset,
        )
        stage2.run()

    elif stage == 3:
        # stage 3 creates clusters of duplicates using the results from all buckets
        stage3 = LocalPipelineExecutor(
            pipeline=[
                MinhashDedupCluster(
                    input_folder=f"{MINHASH_BASE_PATH}/buckets",
                    output_folder=f"{MINHASH_BASE_PATH}/remove_ids",
                    config=minhash_config,
                ),
            ],
            tasks=1,
            logging_dir=f"{LOGS_FOLDER}/clusters",
        )
        stage3.run()
    
    elif stage == 4:

        # stage 4 reads the original input data and removes all but 1 sample per duplicate cluster
        # the data must match exactly stage 1, so number of tasks and the input source must be the same
        stage4 = LocalPipelineExecutor(
            pipeline=[
                INPUT_READER,
                TokensCounter(),  # nice way to see how many tokens we had before and after deduplication
                MinhashDedupFilter(
                    input_folder=f"{MINHASH_BASE_PATH}/remove_ids",
                    exclusion_writer=JsonlWriter(f"{MINHASH_BASE_PATH}/removed", adapter=write_adapter, compression=None),
                ),
                JsonlWriter(output_folder=f"{MINHASH_BASE_PATH}/deduplicated_output", adapter=write_adapter, compression=None),
            ],
            tasks=TOTAL_TASKS,
            logging_dir=f"{LOGS_FOLDER}/filter",
            workers=args.workers,
            local_rank_offset=local_rank_offset,
            local_tasks=local_tasks,
        )
        stage4.run()



if __name__ == "__main__":
    main()
  • are you sure you are using the same n for n-grams, and the same number of buckets and bucket size in the two scripts?
  • are you sure you really used 13 bands each with size 60? Those numbers would make it almost impossible to find any matches: for instance, the probability of finding two documents whose true minhash overlap is 0.8 would be 1-(1-0.8^60)^13 ≈ 0.002%
  • how did you measure your percentages? number of tokens (with which tokenizer on spark)? disk size (are the documents in the same format? our jsonl is gzip compressed)?

Yes, I used n=5.
Actually, b=13 is the band size and r=60 is the number of bands, so the two should be swapped in the equation.
I measure them in terms of disk size, which is comparable since the data is saved in the same format. I haven't used compression.
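For reference, here is a quick sketch (plain Python, with an assumed true similarity of 0.8 used purely for illustration) that evaluates the standard LSH detection probability 1-(1-s^r)^b under both readings of the parameters; it suggests the 60-bands-of-13-hashes configuration actually used in the script above detects near-duplicates quite aggressively:

def detection_probability(s: float, bands: int, hashes_per_band: int) -> float:
    # probability that two documents with true minhash similarity s share at
    # least one band, given `bands` bands of `hashes_per_band` hashes each
    return 1 - (1 - s ** hashes_per_band) ** bands

s = 0.8  # assumed true similarity, for illustration only
print(detection_probability(s, bands=13, hashes_per_band=60))  # ~2e-5, i.e. ~0.002%
print(detection_probability(s, bands=60, hashes_per_band=13))  # ~0.97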

Maybe take a look at the implementation of the union set? @guipenedo I suspect that the current implementation might remove fewer documents if duplicates are not correctly clustered?

Union set is a really simple algorithm and the implementation has been tested. Going back to the format: datatrove adds some metadata fields, such as the filepath for example. Can you open some of the output files and really compare the format?
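For context, the union set (i.e., union-find / disjoint-set) clustering step boils down to something like the sketch below; this is only an illustration of the idea, not datatrove's actual code:

def make_union_find(n: int):
    # one entry per document; initially every document is its own cluster
    parent = list(range(n))

    def find(x: int) -> int:
        # path compression: point visited nodes closer to the root as we walk up
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a: int, b: int) -> None:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[ra] = rb

    return find, union

# usage: union every duplicate pair reported by the bucket stage, then keep
# exactly one representative per resulting cluster
find, union = make_union_find(5)
for a, b in [(0, 1), (1, 2), (3, 4)]:  # hypothetical duplicate pairs
    union(a, b)
kept = [doc for doc in range(5) if find(doc) == doc]
print(kept)  # one representative per connected component

Each connected component of the duplicate graph keeps exactly one representative, which is the same outcome a GraphFrames connected-components job should produce in principle.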

You are bound to have some small differences just from the size of whichever document is chosen to be kept in each cluster, but they shouldn't be as extreme as in your case. The easiest way to really be sure of the difference would be to tokenize your Spark output (maybe even with datatrove if you would like: TokensCounter).
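A minimal sketch of what such a token count could look like, reusing only components already imported in the script above (the paths are placeholders, and the Spark output may need a custom adapter such as document_uf_adapter if its jsonl layout differs from the default):

from datatrove.executor.local import LocalPipelineExecutor
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.tokens import TokensCounter

# count tokens in the Spark-deduplicated output so both pipelines can be
# compared on the same metric; the folders below are placeholders
LocalPipelineExecutor(
    pipeline=[
        JsonlReader(data_folder="/path/to/spark_deduplicated_output"),
        TokensCounter(),  # token totals show up in the executor's stage statistics
    ],
    tasks=10,
    logging_dir="/path/to/logs/spark_token_count",
).run()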

The computation of the deduplication rate is not the problem. For datatrove, I use deduplicated_output / (deduplicated_output + removed); for Spark, I calculated output / input.

Are you able to look at the clusters identified by the Spark code? I would be interested to know what they look like if you apply the Spark code to the final datatrove deduplicated data.

Actually, Spark cannot find any more duplicates in the deduplicated_output produced by datatrove...

Interesting. Then I would recommend actually counting the number of documents in each one (instead of comparing file sizes) to get a clear idea of whether there really is a significant difference.
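Since the outputs are uncompressed jsonl, a trivial sketch for such a count could look like this (the glob patterns are placeholders):

import glob

def count_documents(pattern: str) -> int:
    # each non-empty line of a jsonl file is one document
    total = 0
    for path in glob.glob(pattern):
        with open(path, "r", encoding="utf-8") as f:
            total += sum(1 for line in f if line.strip())
    return total

print("datatrove kept:", count_documents("/path/to/minhash/deduplicated_output/*.jsonl"))
print("spark kept:    ", count_documents("/path/to/spark_output/*.jsonl"))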

I think the implementation of datatrove is correct. It seems that the GraphFrame used in the Spark code has some subtle bugs...