Is it possible to store shard metadata with the shard itself?
universome opened this issue · comments
For tar files, it's very easy to untar them and inspect the contents. For mds files, I first need to pull the index, locate a given shard inside the index and only then inspect. Is it possible to build a command similar to tar xf ...
, which would only take an mds file as an input and be able to unpack it? I need this to easily debug/visualize my data.
Depending on what you mean by 'unpack', we already replicate shard metadata to two places: index.json and inside the shard itself. Although I don't believe this is exposed programmatically at the moment, it is easy enough to access in a few lines of python.
What is your use case, out of curiosity? Normally, MDS files are not intended to be used "bare"
Hi @knighton , thank you for your reply! My use case is just to explore the contents of the file for debugging purposes only. I am trying to visualize some data inside a particular shard, and now I need to download the entire index file for that, find the shard metadata inside the index file, and only then "unpack" the shard. I would like to write a simple script which would be runnable via unpack_mds /path/to/shard.mds -o /path/to/output/dir
which would dump the content of the shard inside it.
@knighton pinging to keep the thread alive.
# reconstruct_index.py
#
# Rebuild a Streaming `index.json` for a directory of raw `.mds` shard
# files, using the copy of the shard metadata that is embedded inside
# each shard itself (so no pre-existing index is required).
from glob import glob
import json
import numpy as np
import os
from tqdm import tqdm
from streaming.base.hashing import get_hash

shard_dir = 'path/to/dir/'
pattern = os.path.join(shard_dir, '*.mds')
shard_files = sorted(glob(pattern))
shards = []
for shard_file in tqdm(shard_files, leave=False):
    # Read the whole shard and close the handle promptly.
    with open(shard_file, 'rb') as fp:
        data = fp.read()
    # First two uint32s of the shard: the sample count, and the byte
    # offset at which the embedded JSON metadata ends.
    num_samples, meta_end = map(int, np.frombuffer(data[:8], np.uint32))
    # The metadata starts after the 4-byte count, the per-sample uint32
    # offsets, and 4 more bytes.
    meta_start = 4 + num_samples * 4 + 4
    info = json.loads(data[meta_start:meta_end].decode('utf-8'))
    hashes = {}
    info['hashes'] = hashes
    # Hash computation is skipped because it is slow. NOTE(review): to
    # actually compute hashes, capture the algorithm list from
    # info['hashes'] BEFORE overwriting it above, then for each algo:
    #     hashes[algo] = get_hash(algo, data)
    info['raw_data'] = {
        'basename': os.path.basename(shard_file),
        'bytes': len(data),
        'hashes': hashes,
    }
    info['samples'] = num_samples
    info['zip_data'] = None
    shards.append(info)
index = {
    'version': 2,
    'shards': shards,
}
with open(os.path.join(shard_dir, 'index.json'), 'w') as fp:
    json.dump(index, fp)
Specifically, the bit about x = json.loads(b[a:z].decode('utf-8'))
, that's your shard metadata
Thank you, your solution seems to be working for my use case. I am closing the issue then. Just for reference, here is my final script:
import os
import shutil
import argparse
from typing import Any
import json
import numpy as np
from streaming.base.format import reader_from_json
from tqdm import tqdm
# from streaming.base.hashing import get_hash
def unpack_mds(shard_path: str, output_path: str, overwrite: bool = False) -> None:
    """Unpack a single MDS shard file into a directory of per-sample files.

    The shard metadata embedded in the `.mds` file itself is parsed (no
    `index.json` needed), a reader is built from it, and every sample is
    dumped: the main payload ('video' -> `.mp4`, 'image' -> `.jpg`) plus a
    `.json` sidecar with the remaining per-sample fields.

    Args:
        shard_path: Path to the `.mds` shard file.
        output_path: Directory to write the unpacked samples into.
        overwrite: If True, delete a pre-existing output directory first.

    Raises:
        FileExistsError: If `output_path` already exists and `overwrite` is False.
        ValueError: If a sample contains neither a 'video' nor an 'image' field.
    """
    if overwrite and os.path.exists(output_path):
        assert os.path.isdir(output_path), f'Output path {output_path} is not a directory.'
        shutil.rmtree(output_path)
    os.makedirs(output_path, exist_ok=False)

    with open(shard_path, 'rb') as f:
        b = f.read()

    # Shard header: two uint32s — the sample count, and the byte offset at
    # which the embedded JSON metadata ends.
    n, z = map(int, np.frombuffer(b[:8], np.uint32))
    # Metadata begins after the 4-byte count, the per-sample uint32
    # offsets, and 4 more bytes.
    a = 4 + n * 4 + 4
    shard_metadata = json.loads(b[a:z].decode('utf-8'))
    # Hash verification is intentionally skipped (slow); the shard's own
    # 'hashes' entry is passed through unchanged.
    shard_metadata['raw_data'] = {
        'basename': os.path.basename(shard_path),
        'bytes': len(b),  # we already read the whole file; no extra stat needed
        'hashes': shard_metadata['hashes'],
    }
    shard_metadata['samples'] = n
    shard_metadata['zip_data'] = None

    reader = reader_from_json(os.path.dirname(shard_path), split=None, obj=shard_metadata)
    print('num samples', len(reader))
    for shard_sample_idx in tqdm(range(len(reader)), desc=f'Processing shard {shard_path}'):
        sample = reader.get_item(shard_sample_idx)
        if 'video' in sample:
            main_data = sample['video']
            main_data_path = os.path.join(output_path, f'{sample["filename"]}.mp4')
        elif 'image' in sample:
            main_data = sample['image']
            main_data_path = os.path.join(output_path, f'{sample["filename"]}.jpg')
        else:
            raise ValueError('Unknown main data type.', list(sample.keys()))
        with open(main_data_path, 'wb') as f:
            f.write(main_data)
        with open(os.path.join(output_path, f'{sample["filename"]}.json'), 'w') as f:
            metadata = drop_numpy_types({k: v for k, v in sample.items() if k not in ['video', 'image']})
            json.dump(metadata, f)
#----------------------------------------------------------------------------
def drop_numpy_types(obj: Any) -> Any:
    """Recursively convert numpy values into plain JSON-serializable types.

    Arrays become nested lists; dicts and lists are rebuilt with converted
    values (tuples come back as lists); anything else carrying a numpy
    dtype is converted via ``tolist()``; all other objects pass through.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[key] = drop_numpy_types(value)
        return converted
    if isinstance(obj, (list, tuple)):
        return [drop_numpy_types(item) for item in obj]
    looks_like_numpy_scalar = hasattr(obj, 'dtype') and isinstance(obj.dtype, np.dtype)
    return obj.tolist() if looks_like_numpy_scalar else obj
#----------------------------------------------------------------------------
if __name__ == '__main__':
    # Command-line entry point: unpack one shard into a directory.
    cli = argparse.ArgumentParser()
    cli.add_argument('shard_path', type=str, help='Path to the shard.')
    cli.add_argument('output_path', type=str, help='Output path to the directory.')
    cli.add_argument('--overwrite', action='store_true', help='Overwrite the output directory if it exists.')
    opts = cli.parse_args()
    unpack_mds(shard_path=opts.shard_path, output_path=opts.output_path, overwrite=opts.overwrite)