mosaicml / streaming

A Data Streaming Library for Efficient Neural Network Training

Home Page: https://streaming.docs.mosaicml.com

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Is it possible to store shard metadata with the shard itself?

universome opened this issue · comments

For tar files, it's very easy to untar them and inspect the contents. For mds files, I first need to pool the index, locate a given shard inside the index and only then inspect. Is it possible to build a command similar to tar xf ..., which would only take an mds file as an input and be able to unpack it? I need this to easily debug/visualize my data.

Depending on what you mean by 'unpack', we already replicate shard metadata to two places: index.json and inside the shard itself. Although I don't believe this is exposed programmatically at the moment, it is easy enough to access in a few lines of python.

What is your use case, out of curiosity? Normally, MDS files are not intended to be used "bare"

Hi @knighton , thank you for your reply! My use-case is just to explore the contents of the file for debugging purposes only. I am trying to visualize some data inside a particular shard, and now I need to download the entire index file for that, find the shard metadata inside the index file, and only then "unpack" the shard. I would like to write a simple script which would be runnable via unpack_mds /path/to/shard.mds -o /path/to/output/dir which would dump the content of the shard inside it.

@knighton pinging to keep the thread alive.

# reconstruct_index.py

from glob import glob
import json
import numpy as np
import os
from tqdm import tqdm

from streaming.base.hashing import get_hash


# Rebuild index.json for a directory of MDS shard files by reading the
# metadata blob that each shard embeds in its own header.
shard_dir = 'path/to/dir/'
pattern = os.path.join(shard_dir, '*.mds')
shard_files = sorted(glob(pattern))
shard_infos = []
for shard_file in tqdm(shard_files, leave=False):
    # fix: close the file deterministically instead of leaking the handle.
    with open(shard_file, 'rb') as fp:
        data = fp.read()
    # MDS header layout: uint32 sample count `num_samples`, then
    # (num_samples + 1) uint32 sample offsets; the first offset `meta_end`
    # points just past the embedded JSON metadata, which therefore occupies
    # data[meta_start:meta_end].
    num_samples, meta_end = map(int, np.frombuffer(data[:8], np.uint32))
    meta_start = 4 + num_samples * 4 + 4
    info = json.loads(data[meta_start:meta_end].decode('utf-8'))
    # NOTE: capture the hash algorithm names BEFORE overwriting info['hashes'],
    # otherwise the (optional, slow) loop below iterates an empty dict.
    hash_algos = list(info['hashes'])
    hashes = {}
    info['hashes'] = hashes
    # for algo in hash_algos:
    #     hashes[algo] = get_hash(algo, data)  # slow
    info['raw_data'] = {
        'basename': os.path.basename(shard_file),
        'bytes': len(data),
        'hashes': hashes,
    }
    info['samples'] = num_samples
    info['zip_data'] = None
    shard_infos.append(info)
index = {
    'version': 2,
    'shards': shard_infos,
}
index_path = os.path.join(shard_dir, 'index.json')
# fix: the original passed open(f, 'w') directly to json.dump and never
# closed it; use a context manager so the index is flushed reliably.
with open(index_path, 'w') as fp:
    json.dump(index, fp)

Specifically, the bit about x = json.loads(b[a:z].decode('utf-8')) — that's your shard metadata

Thank you, your solution seems to be working for my use case. I am closing the issue then. Just for reference, here is my final script:

import os
import shutil
import argparse
from typing import Any
import json
import numpy as np
from streaming.base.format import reader_from_json
from tqdm import tqdm

# from streaming.base.hashing import get_hash


def unpack_mds(shard_path: str, output_path: str, overwrite: bool = False) -> None:
    """Dump every sample of a single MDS shard into ``output_path``.

    For each sample, writes the main payload (``video`` -> .mp4, ``image`` ->
    .jpg) plus a ``<filename>.json`` with the remaining (non-numpy) metadata.

    Args:
        shard_path: Path to the ``.mds`` shard file.
        output_path: Directory to create and fill, one data file + one JSON
            metadata file per sample.
        overwrite: If True and ``output_path`` exists, remove it first.

    Raises:
        OSError: If ``output_path`` already exists and ``overwrite`` is False.
        ValueError: If a sample has neither a ``video`` nor an ``image`` key.
    """
    if overwrite and os.path.exists(output_path):
        assert os.path.isdir(output_path), f'Output path {output_path} is not a directory.'
        shutil.rmtree(output_path)
    os.makedirs(output_path, exist_ok=False)

    # Read the whole shard once; everything below works on the in-memory bytes,
    # so the handle can be closed immediately (the original held it open for
    # the entire unpack and shadowed it with the per-sample output handles).
    with open(shard_path, 'rb') as shard_file:
        b = shard_file.read()

    # MDS header: uint32 sample count, then (count + 1) uint32 offsets; the
    # first offset also marks the end of the embedded JSON metadata blob.
    n, z = map(int, np.frombuffer(b[:8], np.uint32))
    a = 4 + n * 4 + 4
    shard_metadata = json.loads(b[a:z].decode('utf-8'))

    # Optionally recompute hashes over the shard bytes (slow). Note: hash over
    # `b`, not a second read of the exhausted file handle.
    # for h in shard_metadata['hashes']:
    #     shard_metadata['hashes'][h] = get_hash(h, b)  # slow

    shard_metadata['raw_data'] = {
        'basename': os.path.basename(shard_path),
        # fix: reuse len(b) instead of a redundant stat of the file just read.
        'bytes': len(b),
        'hashes': shard_metadata['hashes'],
    }
    shard_metadata['samples'] = n
    shard_metadata['zip_data'] = None

    reader = reader_from_json(os.path.dirname(shard_path), split=None, obj=shard_metadata)
    print('num samples', len(reader))

    for shard_sample_idx in tqdm(range(len(reader)), desc=f'Processing shard {shard_path}'):
        sample = reader.get_item(shard_sample_idx)

        if 'video' in sample:
            main_data = sample['video']
            main_data_path = os.path.join(output_path, f'{sample["filename"]}.mp4')
        elif 'image' in sample:
            main_data = sample['image']
            main_data_path = os.path.join(output_path, f'{sample["filename"]}.jpg')
        else:
            raise ValueError('Unknown main data type.', list(sample.keys()))

        with open(main_data_path, 'wb') as data_file:
            data_file.write(main_data)

        metadata_path = os.path.join(output_path, f'{sample["filename"]}.json')
        with open(metadata_path, 'w') as meta_file:
            # fix: `k not in` instead of `not k in`; drop numpy types so the
            # metadata is JSON-serializable.
            metadata = drop_numpy_types({k: v for k, v in sample.items() if k not in ['video', 'image']})
            json.dump(metadata, meta_file)

#----------------------------------------------------------------------------

def drop_numpy_types(obj: Any) -> Any:
    """Recursively replace numpy arrays and scalars with plain Python values.

    Dicts, lists and tuples are walked recursively; tuples come back as lists
    (which keeps the result JSON-serializable). Anything else is returned
    unchanged.
    """
    if isinstance(obj, np.ndarray):
        # tolist() already recurses into array elements.
        return obj.tolist()
    if isinstance(obj, dict):
        return {key: drop_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [drop_numpy_types(item) for item in obj]
    # Numpy scalar types (np.int64, np.float32, ...) expose a real np.dtype.
    is_numpy_scalar = hasattr(obj, 'dtype') and isinstance(obj.dtype, np.dtype)
    return obj.tolist() if is_numpy_scalar else obj

#----------------------------------------------------------------------------

if __name__ == '__main__':
    # Thin CLI wrapper around unpack_mds: positional shard/output paths plus
    # an optional --overwrite flag.
    cli = argparse.ArgumentParser()
    cli.add_argument('shard_path', type=str, help='Path to the shard.')
    cli.add_argument('output_path', type=str, help='Output path to the directory.')
    cli.add_argument('--overwrite', action='store_true', help='Overwrite the output directory if it exists.')
    opts = cli.parse_args()

    unpack_mds(shard_path=opts.shard_path,
               output_path=opts.output_path,
               overwrite=opts.overwrite)

#----------------------------------------------------------------------------