JasonGross / guarantees-based-mechanistic-interpretability

`poetry run` is very slow

JasonGross opened this issue

No idea why, but, e.g., `poetry run python -m gbmi.exp_max_of_n.train --force load` is very slow (7 to 8 seconds). The profile from cProfile shows nothing taking any significant time, and interrupting with Ctrl+C suggests the time is being spent while poetry loads pytorch?

`time poetry run python -c 'import cProfile; from gbmi.exp_max_of_n import train'` takes about 5 to 6 seconds.
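To tell poetry's own overhead apart from Python import time, it may help to time the same command with and without the `poetry run` prefix. A minimal sketch, assuming a hypothetical helper `time_command` (not part of this repo) that reports the best wall-clock time over a few runs:

```python
import subprocess
import sys
import time


def time_command(cmd, repeats=3):
    """Return the best wall-clock time in seconds over `repeats` runs of `cmd`.

    `time_command` is a hypothetical helper for illustration only.
    """
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # Discard output so terminal I/O does not skew the measurement.
        subprocess.run(
            cmd,
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        best = min(best, time.perf_counter() - start)
    return best


# Baseline: bare interpreter startup, with no project imports at all.
baseline = time_command([sys.executable, "-c", "pass"], repeats=1)
print(f"bare interpreter startup: {baseline:.3f} s")
```

Passing `["poetry", "run", "python", "-c", "from gbmi.exp_max_of_n import train"]` and then the same command with the virtualenv's `python` invoked directly would show how much of the delay belongs to poetry and how much to the imports.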

poetry run python -c 'import cProfile; cProfile.run("from gbmi.exp_max_of_n import train")'
         6390808 function calls (6217572 primitive calls) in 9.675 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 :1(<module>)
        1    0.000    0.000    0.000    0.000 :1(ClientTimeoutAttributes)
        1    0.000    0.000    0.000    0.000 :1(ConnectionKeyAttributes)
        1    0.000    0.000    0.000    0.000 :1(ContentDispositionAttributes)
        1    0.000    0.000    0.000    0.000 :1(ETagAttributes)
        1    0.000    0.000    0.000    0.000 :1(MimeTypeAttributes)


Ah, I guess I was sorting the profile wrong.
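For reference, `cProfile.run` sorts by standard name unless told otherwise, which is why the listing above is alphabetical rather than ordered by cost. A minimal sketch of getting a cumulative-time ordering from within Python; `profile_statement` is a hypothetical helper, and `import json` stands in for the real import:

```python
import cProfile
import io
import pstats


def profile_statement(statement, sort_key="cumulative", limit=10):
    """Profile `statement` and return the top `limit` rows as text.

    `profile_statement` is a hypothetical helper for illustration only.
    """
    profiler = cProfile.Profile()
    profiler.run(statement)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    # Sort by cumulative time so expensive call trees rise to the top.
    stats.sort_stats(sort_key).print_stats(limit)
    return stream.getvalue()


# Stand-in statement; replace with the import under investigation.
report = profile_statement("import json")
print(report)
```

The one-line equivalent is `cProfile.run("from gbmi.exp_max_of_n import train", sort="cumulative")`.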

`poetry run python -m cProfile -s cumtime -m gbmi.exp_max_of_n.train --force load` gives profile.log:

         6396417 function calls (6224556 primitive calls) in 14.744 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      920    0.043    0.000   57.806    0.063 __init__.py:1(<module>)
        2    0.000    0.000   15.955    7.977 train.py:1(<module>)
   7387/1    0.287    0.000   14.755   14.755 {built-in method builtins.exec}
        1    0.000    0.000   14.754   14.754 <string>:1(<module>)
        1    0.000    0.000   14.754   14.754 <frozen runpy>:201(run_module)
        1    0.000    0.000   14.752   14.752 <frozen runpy>:65(_run_code)
   5289/9    0.052    0.000   13.981    1.553 <frozen importlib._bootstrap>:1165(_find_and_load)
   5181/9    0.043    0.000   13.981    1.553 <frozen importlib._bootstrap>:1120(_find_and_load_unlocked)
   4969/9    0.037    0.000   13.977    1.553 <frozen importlib._bootstrap>:666(_load_unlocked)
   4818/9    0.021    0.000   13.977    1.553 <frozen importlib._bootstrap_external>:934(exec_module)
 11562/19    0.012    0.000   13.973    0.735 <frozen importlib._bootstrap>:233(_call_with_frames_removed)
  1593/79    0.007    0.000   11.257    0.142 {built-in method builtins.__import__}
5614/1381    0.016    0.000    9.609    0.007 <frozen importlib._bootstrap>:1207(_handle_fromlist)
       61    0.003    0.000    4.109    0.067 utils.py:1(<module>)
     4819    0.072    0.000    3.109    0.001 <frozen importlib._bootstrap_external>:1007(get_code)
        3    0.000    0.000    2.655    0.885 model.py:1(<module>)
4969/4938    0.015    0.000    2.544    0.001 <frozen importlib._bootstrap>:566(module_from_spec)
  140/104    0.000    0.000    2.514    0.024 __init__.py:108(import_module)
  145/104    0.000    0.000    2.514    0.024 <frozen importlib._bootstrap>:1192(_gcd_import)
99244/61787    0.052    0.000    2.417    0.000 {built-in method builtins.hasattr}
  129/124    0.001    0.000    2.312    0.019 <frozen importlib._bootstrap_external>:1231(create_module)
  129/124    2.299    0.018    2.311    0.019 {built-in method _imp.create_dynamic}
    36/11    0.000    0.000    1.976    0.180 import_utils.py:1366(__getattr__)
    24/12    0.000    0.000    1.976    0.165 import_utils.py:1380(_get_module)
222282/191425    0.117    0.000    1.913    0.000 {built-in method builtins.getattr}
10836/10575    0.220    0.000    1.899    0.000 {built-in method builtins.__build_class__}
     4820    0.029    0.000    1.865    0.000 <frozen importlib._bootstrap_external>:1127(get_data)
        1    0.000    0.000    1.725    1.725 HookedTransformer.py:1(<module>)
        1    0.000    0.000    1.704    1.704 loading_from_pretrained.py:1(<module>)
        1    0.000    0.000    1.700    1.700 modeling_bert.py:1(<module>)
        1    0.000    0.000    1.603    1.603 modeling_utils.py:1(<module>)
        5    0.000    0.000    1.581    0.316 accelerator.py:1(<module>)
     4973    1.580    0.000    1.580    0.000 {method 'read' of '_io.BufferedReader' objects}
        1    0.000    0.000    1.549    1.549 checkpointing.py:1(<module>)
        2    0.000    0.000    1.519    0.760 batch_size_finder.py:1(<module>)
        2    0.000    0.000    1.510    0.755 callback.py:1(<module>)
        9    0.001    0.000    1.505    0.167 types.py:1(<module>)
        1    0.000    0.000    1.463    1.463 fsdp_utils.py:1(<module>)
        1    0.000    0.000    1.374    1.374 state_dict_loader.py:1(<module>)
        1    0.000    0.000    1.365    1.365 default_planner.py:1(<module>)
        4    0.000    0.000    1.347    0.337 embedding_ops.py:1(<module>)
        1    0.000    0.000    1.340    1.340 op_schema.py:1(<module>)
        1    0.000    0.000    1.336    1.336 placement_types.py:1(<module>)
        1    0.000    0.000    1.334    1.334 arrow_dataset.py:1(<module>)
        1    0.000    0.000    1.333    1.333 _functional_collectives.py:1(<module>)
        2    0.000    0.000    1.180    0.590 pit.py:1(<module>)
        1    0.000    0.000    1.177    1.177 checks.py:1(<module>)
        2    0.000    0.000    1.169    0.585 metric.py:1(<module>)
        2    0.000    0.000    1.131    0.565 plot.py:1(<module>)
5242/5161    0.087    0.000    1.119    0.000 <frozen importlib._bootstrap>:1054(_find_spec)
     4818    0.024    0.000    1.012    0.000 <frozen importlib._bootstrap_external>:727(_compile_bytecode)
     4818    0.979    0.000    0.979    0.000 {built-in method marshal.loads}
       24    0.001    0.000    0.909    0.038 api.py:1(<module>)
        2    0.000    0.000    0.898    0.449 _ops.py:1(<module>)
     5234    0.009    0.000    0.846    0.000 <frozen importlib._bootstrap_external>:1496(find_spec)
5235/5234    0.029    0.000    0.837    0.000 <frozen importlib._bootstrap_external>:1464(_get_spec)
        1    0.000    0.000    0.765    0.765 _ipython_extension.py:1(<module>)
        1    0.000    0.000    0.763    0.763 model.py:240(train_or_load_model)
     6829    0.165    0.000    0.758    0.000 <frozen importlib._bootstrap_external>:1604(find_spec)
        1    0.000    0.000    0.748    0.748 normalize.py:37(wrapper)
        1    0.000    0.000    0.748    0.748 api.py:933(artifact)
        3    0.000    0.000    0.742    0.247 retry.py:94(__call__)
        3    0.000    0.000    0.742    0.247 client.py:48(execute)
        3    0.000    0.000    0.742    0.247 client.py:58(_get_result)
        3    0.000    0.000    0.741    0.247 gql_request.py:42(execute)
        1    0.049    0.049    0.739    0.739 _meta_registrations.py:1(<module>)
        3    0.000    0.000    0.736    0.245 sessions.py:626(post)
        3    0.000    0.000    0.736    0.245 sessions.py:502(request)
        3    0.000    0.000    0.722    0.241 sessions.py:673(send)
        3    0.000    0.000    0.719    0.240 adapters.py:434(send)
        3    0.000    0.000    0.715    0.238 connectionpool.py:595(urlopen)
        3    0.000    0.000    0.714    0.238 connectionpool.py:380(_make_request)
       14    0.001    0.000    0.684    0.049 config.py:1(<module>)
        1    0.000    0.000    0.621    0.621 embed.py:1(<module>)
        1    0.000    0.000    0.612    0.612 allowed_functions.py:1(<module>)
        3    0.000    0.000    0.605    0.202 internal.py:1(<module>)
        2    0.000    0.000    0.579    0.289 __init__.py:1820(__getattr__)
    20475    0.036    0.000    0.567    0.000 __init__.py:272(_compile)
    24304    0.543    0.000    0.543    0.000 {built-in method posix.stat}
        5    0.000    0.000    0.533    0.107 __init__.py:342(__init__)
        5    0.533    0.107    0.533    0.107 {built-in method _ctypes.dlopen}
        1    0.000    0.000    0.533    0.533 __init__.py:165(_load_global_deps)
29196/27393    0.047    0.000    0.519    0.000 typing.py:352(inner)
        2    0.000    0.000    0.517    0.258 retry.py:210(wrapped_fn)
        2    0.000    0.000    0.517    0.258 api.py:66(execute)
      677    0.006    0.000    0.510    0.001 _compiler.py:738(compile)
        1    0.000    0.000    0.504    0.504 wandb_init.py:1(<module>)
        1    0.000    0.000    0.494    0.494 artifact.py:1(<module>)
        2    0.001    0.000    0.492    0.246 interactiveshell.py:1(<module>)
      105    0.002    0.000    0.489    0.005 artist.py:159(_update_set_signature_and_docstring)
     7435    0.004    0.000    0.488    0.000 __init__.py:225(compile)
      104    0.001    0.000    0.478    0.005 artist.py:126(__init_subclass__)
        1    0.000    0.000    0.475    0.475 dependency_versions_check.py:1(<module>)
      545    0.001    0.000    0.447    0.001 dataclasses.py:1219(wrap)
      545    0.034    0.000    0.445    0.001 dataclasses.py:884(_process_class)
        3    0.000    0.000    0.442    0.147 _base.py:1(<module>)
        2    0.000    0.000    0.406    0.203 backend.py:1(<module>)
        1    0.000    0.000    0.393    0.393 api.py:615(_parse_artifact_path)
        1    0.000    0.000    0.393    0.393 api.py:491(default_entity)
      136    0.006    0.000    0.383    0.003 {method 'readline' of '_io.BufferedReader' objects}
      117    0.001    0.000    0.383    0.003 artist.py:1837(kwdoc)
        3    0.000    0.000    0.380    0.127 connection.py:435(getresponse)
        6    0.001    0.000    0.379    0.063 functional.py:1(<module>)
        3    0.000    0.000    0.379    0.126 client.py:1334(getresponse)
        3    0.000    0.000    0.378    0.126 client.py:311(begin)
        3    0.000    0.000    0.377    0.126 client.py:278(_read_status)
        3    0.000    0.000    0.377    0.126 socket.py:692(readinto)
        3    0.000    0.000    0.377    0.126 ssl.py:1300(recv_into)
        3    0.000    0.000    0.377    0.126 ssl.py:1158(read)
        3    0.377    0.126    0.377    0.126 {method 'read' of '_ssl._SSLSocket' objects}
    68/64    0.000    0.000    0.375    0.006 __init__.py:365(__getattr__)
        5    0.000    0.000    0.372    0.074 decorators.py:1(<module>)
        1    0.000    0.000    0.370    0.370 offsetbox.py:1(<module>)
        1    0.000    0.000    0.368    0.368 decorators.py:101(inner)
       10    0.000    0.000    0.368    0.037 allowed_functions.py:82(remove)
    28/27    0.001    0.000    0.368    0.014 allowed_functions.py:63(__call__)
        1    0.001    0.001    0.367    0.367 allowed_functions.py:149(_allowed_function_ids)
    629/2    0.076    0.000    0.365    0.182 allowed_functions.py:188(_find_torch_objects)
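Since most of the cumulative time above sits in module-level `<module>` frames, a per-module import breakdown may be easier to read than the full profile. A minimal sketch using CPython's `-X importtime` flag (Python 3.7+); `slowest_imports` is a hypothetical helper, and `import json` is a stand-in statement:

```python
import subprocess
import sys


def slowest_imports(statement, top=10):
    """Run `statement` under -X importtime; return (cumulative_us, module) pairs.

    `slowest_imports` is a hypothetical helper for illustration only.
    """
    result = subprocess.run(
        [sys.executable, "-X", "importtime", "-c", statement],
        capture_output=True,
        text=True,
        check=True,
    )
    prefix = "import time:"
    rows = []
    # -X importtime writes "import time: self | cumulative | name" to stderr.
    for line in result.stderr.splitlines():
        if not line.startswith(prefix):
            continue
        parts = line[len(prefix):].split("|")
        if len(parts) != 3:
            continue
        cumulative_us = parts[1].strip()
        if not cumulative_us.isdigit():
            continue  # skip the header row
        rows.append((int(cumulative_us), parts[2].strip()))
    rows.sort(reverse=True)
    return rows[:top]


# Stand-in statement; replace with the import under investigation.
rows = slowest_imports("import json")
for cumulative_us, name in rows:
    print(f"{cumulative_us / 1e6:8.3f} s  {name}")
```

Calling `slowest_imports("from gbmi.exp_max_of_n import train")` under `poetry run python` would list which packages dominate the startup cost.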