Using a registered task to add another
AkshitaB opened this issue
Suppose I have a task registered as follows:
    seqio.TaskRegistry.add(
        "task_1",
        source=seqio.TfdsDataSource(tfds_name="c4/en:3.0.1", splits=["train", "validation"]),
        preprocessors=[
            preprocess1,
            seqio.preprocessors.tokenize,
            seqio.CacheDatasetPlaceholder(),
            preprocess2,
            preprocess3,
        ],
        output_features=...
    )
Is it possible to add another task that starts with the cached part of task_1, i.e., the part before seqio.CacheDatasetPlaceholder(), and only vary preprocess2 and preprocess3?
I'm looking for something like:
    seqio.TaskRegistry.add(
        "task_2",
        source=seqio.CachedTaskSource("task_1", ...),
        preprocessors=[
            preprocess2_modified,
            preprocess3_modified,
        ],
        ...
    )
Hi @AkshitaB, this is doable using seqio's _CachedDataSource (https://github.com/google/seqio/blob/main/seqio/dataset_providers.py#L683): pass it the cache_dir from task_1 and use it as the source of task_2, as outlined in your snippet. One caveat: this creates an implicit dependency on task_1, so any change to task_1 (re-caching, for example) would silently affect task_2.
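A rough sketch of what that could look like, not a tested recipe. It assumes task_1 has already been cached offline and that _CachedDataSource, a private class whose signature may change, can be constructed as _CachedDataSource(cache_dir, split), which also means it wraps a single split. The two *_modified functions are hypothetical identity stand-ins, and register_task_2 is an illustrative helper name, not part of seqio.

```python
"""Sketch: register task_2 on top of task_1's offline cache (illustrative names)."""


def preprocess2_modified(dataset):
    # Hypothetical stand-in for the modified post-cache step; identity here.
    return dataset


def preprocess3_modified(dataset):
    # Hypothetical stand-in for the second modified step; identity here.
    return dataset


def register_task_2(split="train"):
    # Imported inside the function so the sketch can be loaded without seqio.
    import seqio
    # Private API: assumed constructor is _CachedDataSource(cache_dir, split).
    from seqio.dataset_providers import _CachedDataSource

    task_1 = seqio.TaskRegistry.get("task_1")
    task_1.assert_cached()  # fails early if task_1 has no offline cache

    seqio.TaskRegistry.add(
        "task_2",
        # Reads the examples task_1 wrote at its CacheDatasetPlaceholder().
        source=_CachedDataSource(task_1.cache_dir, split),
        preprocessors=[preprocess2_modified, preprocess3_modified],
        # Reuse task_1's features so vocabularies match the cached tokens.
        output_features=task_1.output_features,
    )
    return seqio.TaskRegistry.get("task_2")
```

Because the source is built per split under this assumption, task_2 only exposes the split passed in. The caveat above still applies: re-caching task_1 changes what task_2 reads.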