google / seqio

Task-based datasets, preprocessing, and evaluation for sequence models.

Using a registered task to add another

AkshitaB opened this issue

Suppose I have a task registered as follows:

import seqio

seqio.TaskRegistry.add(
    "task_1",
    source=seqio.TfdsDataSource(tfds_name="c4/en:3.0.1", splits=["train", "validation"]),
    preprocessors=[
        preprocess1,  # preprocess1..preprocess3 are user-defined
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        preprocess2,
        preprocess3,
    ],
    output_features=...,
)

Is it possible to add another task that starts from the cached output of task_1, i.e., the result of the preprocessors before seqio.CacheDatasetPlaceholder(), and only varies preprocess2 and preprocess3?
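To make the placeholder's role concrete, here is a toy, pure-Python model (not seqio code) of how a preprocessor list splits at a cache placeholder: the steps before it run once, offline, and their output is cached; the steps after it run every time the dataset is loaded. `CachePlaceholder`, `split_at_placeholder`, `run`, and the toy preprocessors are hypothetical stand-ins, and plain lists stand in for `tf.data` datasets.

```python
from typing import Any, Callable, List, Sequence, Tuple

Step = Callable[[Any], Any]


class CachePlaceholder:
    """Hypothetical stand-in for seqio.CacheDatasetPlaceholder (not the real class)."""


def split_at_placeholder(preprocessors: Sequence[Any]) -> Tuple[List[Step], List[Step]]:
    """Split a preprocessor list into (cached_steps, online_steps)."""
    positions = [i for i, p in enumerate(preprocessors) if isinstance(p, CachePlaceholder)]
    if len(positions) > 1:
        raise ValueError("at most one cache placeholder is allowed")
    if not positions:
        # No placeholder: nothing is cached, so every step runs at load time.
        return [], list(preprocessors)
    i = positions[0]
    return list(preprocessors[:i]), list(preprocessors[i + 1:])


def run(steps: Sequence[Step], examples: Any) -> Any:
    """Apply each step in order."""
    for step in steps:
        examples = step(examples)
    return examples


# Toy preprocessors mirroring the names in the question.
def preprocess1(ds):
    return [s.lower() for s in ds]


def tokenize(ds):
    return [s.split() for s in ds]


def preprocess2(ds):
    return [toks[:3] for toks in ds]  # keep the first 3 tokens


def preprocess3(ds):
    return [" ".join(toks) for toks in ds]


pipeline = [preprocess1, tokenize, CachePlaceholder(), preprocess2, preprocess3]
cached_steps, online_steps = split_at_placeholder(pipeline)

# Runs once, offline: this is the output a second task would want to reuse.
cached = run(cached_steps, ["The Quick Brown Fox", "Hello World"])
# Runs every time the dataset is loaded.
result = run(online_steps, cached)
print(result)  # ['the quick brown', 'hello world']
```

Under this model, the question amounts to keeping `cached` fixed and swapping only `online_steps`.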

I'm looking for something like this:

seqio.TaskRegistry.add(
    "task_2",
    source=seqio.CachedTaskSource("task_1", ...),  # hypothetical API being asked about
    preprocessors=[
        preprocess2_modified,
        preprocess3_modified,
    ],
    # ...
)

Hi @AkshitaB, this is doable by using seqio's _CachedDataSource (https://github.com/google/seqio/blob/main/seqio/dataset_providers.py#L683) as task_2's source and passing it the cache_dir of task_1, as outlined in your snippet. One caveat: this creates an implicit dependency on task_1, so any change to it (modified preprocessors, re-caching, etc.) would silently affect task_2.
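The reply's approach, and its caveat, can be sketched as a toy, pure-Python analogue. It does not call seqio: `cache_task`, `CachedSource`, `make_task_2`, the modified preprocessors, and the one-JSON-file-per-split cache layout are all hypothetical. task_2 reads whatever is currently in task_1's cache directory and applies only its own post-cache steps, so re-caching task_1 changes task_2's output even though task_2's definition never changed.

```python
import json
import os
import tempfile
from typing import Any, Callable, List, Sequence


def cache_task(cache_dir: str, split: str, examples: List[Any]) -> None:
    """Simulate offline caching: write one split's preprocessed examples to disk."""
    os.makedirs(cache_dir, exist_ok=True)
    with open(os.path.join(cache_dir, f"{split}.json"), "w") as f:
        json.dump(examples, f)


class CachedSource:
    """Hypothetical reader over another task's cache directory."""

    def __init__(self, cache_dir: str, split: str):
        self.cache_dir = cache_dir
        self.split = split

    def get_dataset(self) -> List[Any]:
        # Reads whatever is on disk at call time; nothing ties it to task_1's definition.
        with open(os.path.join(self.cache_dir, f"{self.split}.json")) as f:
            return json.load(f)


def make_task_2(cache_dir: str, split: str, steps: Sequence[Callable[[Any], Any]]):
    """Build a loader that reads task_1's cache and applies only task_2's steps."""
    source = CachedSource(cache_dir, split)

    def get_dataset() -> List[Any]:
        ds = source.get_dataset()
        for step in steps:
            ds = step(ds)
        return ds

    return get_dataset


def preprocess2_modified(ds):
    return [toks[-2:] for toks in ds]  # keep the last 2 tokens


def preprocess3_modified(ds):
    return ["|".join(toks) for toks in ds]


with tempfile.TemporaryDirectory() as root:
    task_1_cache = os.path.join(root, "task_1")
    cache_task(task_1_cache, "train", [["the", "quick", "brown", "fox"], ["hello", "world"]])
    task_2 = make_task_2(task_1_cache, "train", [preprocess2_modified, preprocess3_modified])
    before = task_2()  # ['brown|fox', 'hello|world']

    # Re-caching task_1 (e.g. after editing its preprocessors) silently changes task_2.
    cache_task(task_1_cache, "train", [["a", "b", "c"]])
    after = task_2()  # ['b|c']
```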