NVIDIA / earth2mip

Earth-2 Model Intercomparison Project (MIP) is a Python framework that enables climate researchers and scientists to inter-compare AI models for weather and climate.

Home Page: https://nvidia.github.io/earth2mip/


Design TimeStepper abstraction

nbren12 opened this issue

To allow custom time looping in an extensible manner, we need a new abstraction. #42, for example, requires touching both model code and inference_ensemble---these routines are too tightly coupled. Exposing an iterator interface, as TimeLoop does, is not enough for use cases like diagnostic models or time-dependent forcing that work across various models.

To control the time step, we should define a TimeStepper abstraction. Does this name seem okay?

Here is a proposed interface:

from datetime import datetime
from typing import Tuple

from torch import Tensor

State = Tuple[Tensor, datetime]
Outputs = Tensor


class TimeStepper:
    """Callable time stepper.

    Updates the state and returns some outputs.

    The state contains time and a tensor of data representing the domain.

    The outputs are diagnostic outputs not required for time-stepping, but useful.

    Attributes:
        state_metadata: information about the state object: dimensions of the
            tensor, channels, grid, etc.
        output_metadata: information about the outputs: dimensions, etc.

    Examples:

        stepper: TimeStepper = ...
        state = get_initial_condition(time)
        channel_to_save = "tcwv"
        for i in range(n):
            state, output = stepper(state)
            idx = stepper.output_metadata.index(channel_to_save)
            tcwv = output[idx]  # index the outputs, not the state
            plt.pcolormesh(tcwv)
            plt.savefig(f"{i}.png")
    """

    def __call__(self, state: State) -> Tuple[State, Outputs]:
        ...

Thoughts, @NickGeneva?

The time looper is a step in the right direction, I think, but it only addresses part of the problem.
IMO, the key point of entanglement in the present structure is that the ensemble loop and the time loop are controlled by two different objects. Although having a separate object take ownership of the time loop gives some greater customization / control, it also means the looper is something that is initialized / set before the inference ensemble driver.

This has the following issues (many of them only showing up now):

  1. It deepens the call stack, which creates confusion when tracing the code.
  2. Key components (such as the model, perturbation, IO, etc.) are segmented, which leads to confusion, more complicated construction, and harder customization.
  3. The model controls the loop instead of a single step, which hard-couples the two and makes universal customization difficult.
  4. While the model gets to control the loop, it rarely needs to, and is instead still forced through multiple wrapper layers.

While somewhat minor, these issues cause more headaches than one would think.

Thus I'm considering the following steps:

  • Create a Model interface centered around a single step: a single model wrapper that drives the forward pass and maintains an internal (model) state if needed. This allows models like RNNs / Pangu to do their fancy stuff while still operating like a single-step model.

  • Create a super lightweight inference/looper base class that looks like the time looper but instead focuses on 1) consuming functional components on construction and 2) controlling both ensembles and time (deterministic is a special case of ensemble).
    (If we want to be super flexible, we can support different constructors or something... some details there to discuss)

  • Interface all the things, and go functional. All major components are callables.

The key is to offload the construction of the different parts to things like inference_ensemble while this looper stays super simple. This better opens up the option of having multiple inference loopers for edge cases.

Basically something like this (this is literally it):

from __future__ import annotations

from typing import Optional, Protocol


class Inferencer(Protocol):
    def run(self) -> None:
        """Runs inference."""
        ...


class InferenceDefault(Inferencer):
    def __init__(
        self,
        ds: DataSource,
        nn: BaseModel,
        io: IOHandler,
        d: Optional[DerivativeModel] = None,
        p: Optional[Perturbation] = None,
        r: Optional[Reduction] = None,
        n_ensemble: int = 1,
        n_timestep: int = 1,
    ):
        self.ds, self.nn, self.io = ds, nn, io
        self.d, self.p, self.r = d, p, r
        self.n_ensemble, self.n_timestep = n_ensemble, n_timestep

    def run(self) -> None:
        # Runs inference (everything in physical units)
        for _ in range(self.n_ensemble):
            x0 = self.ds()
            self.nn.reset()
            for t in range(self.n_timestep):
                x = (x0, t)          # State is tensor + time (index maybe?)
                x = self.p(x)        # Perturbation called every time step; up to it to track if initial step
                x_next = self.nn(x)  # Forward pass
                y = self.d(x_next)   # Derivative model
                self.io(y)           # Intermediate IO (may be an in-memory array; need not be the file system)
                x0 = x_next[0]       # Auto-regress

        self.io.finalize()           # Finalize IO
        self.r()                     # Call reducer (also handles any IO; basically a post-processor)

The other beauty of focusing on functions is that it's now super easy for users to mix and match stuff and be sure things will run (so long as the output is of the correct form). The only edits really required are adding it to the config options, but we can create generalized utils that can scale.

Essentially everything but the two ints is a callable with a set interface. How each one is constructed / what happens behind the scenes doesn't matter; so long as it fits into this adapter, each one of these becomes a sub-folder.

Want to save to IO every two steps? The IO handler figures that out; the inferencer does not care. Only perturb the first time step? The perturbation function figures that out; the inferencer doesn't care. So long as the state contains sufficient info (tensor, time, timestep?), the complexity is offloaded, as in the sketch below.
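
To make that concrete, here is a minimal sketch (all names hypothetical, not part of the codebase) of callables that absorb this kind of logic internally:

from datetime import datetime
from typing import List, Tuple

import torch

State = Tuple[torch.Tensor, datetime, int]  # tensor, valid time, timestep index


class PerturbFirstStepOnly:
    """Perturbation that only acts on the initial condition; the inferencer
    calls it every step and does not care."""

    def __init__(self, amplitude: float = 0.01):
        self.amplitude = amplitude

    def __call__(self, state: State) -> State:
        x, time, step = state
        if step == 0:
            x = x + self.amplitude * torch.randn_like(x)
        return x, time, step


class SaveEveryTwoSteps:
    """IO handler that internally decides to keep every other step in memory."""

    def __init__(self):
        self.saved: List[Tuple[datetime, torch.Tensor]] = []

    def __call__(self, state: State) -> None:
        x, time, step = state
        if step % 2 == 0:
            self.saved.append((time, x.clone()))

    def finalize(self) -> None:
        pass  # e.g. flush to disk here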

The other cool thing here is that this completely decouples configuration from the actual inference loop... meaning inference_ensemble only needs to care about parsing the config and building these callable objects. Once they are created, it can initialize the inferencer and run...

For the inference ensemble this doesn't actually need to be a class, but I think it's worth it for enabling Python API users... Namely, it could allow for swapping different components easily:

DS = DataSource(CDS...)
IO = IONetCDF('netcdf1.nc')
inferencer = Inference(DS, NN, ...., IO)
inferencer.run()

DS = DataSource(GFS)
IO = IONetCDF('netcdf2.nc')
inferencer.set(DS, IO)  # Update callables with some slick update function
inferencer.run()

While perhaps different, I don't think this would actually be too much effort. Additionally, since it doesn't touch configuration, the CLI / configs would be largely the same. However, this would create an opportunity to better modularize the configuration into these distinct components instead of always adding additional parameters to the base config. But that's a completely separate effort!

I've been thinking about this for a while, and right now I see the hardest model to fit into this being DLWP, with its multiple time steps... However, as long as the data source is set up to provide two time steps, the perturbation is agnostic to that dimension, and the model can handle the change between one- and two-timestep inputs, it's possible. The complexity is largely placed on the model, which is what we want. Worst case, we make another inferencer and roll with it instead of squeezing things into the same box. A rough sketch of how this could look is below.
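
For example, a hypothetical sketch of how a DLWP-style, two-time-step model could still present a single-step interface (names and shapes are assumptions, not existing code):

from datetime import datetime, timedelta

import torch


class TwoStepStepper:
    """Hypothetical wrapper for a model that consumes two consecutive time
    steps and predicts the next one (DLWP-style)."""

    def __init__(self, model: torch.nn.Module, time_step: timedelta = timedelta(hours=6)):
        self.model = model
        self.time_step = time_step

    def __call__(self, state):
        x, time = state    # x: (2, channels, lat, lon), two consecutive inputs
        y = self.model(x)  # predicted next step: (channels, lat, lon)
        # slide the two-step window: drop the oldest step, append the prediction
        x_next = torch.cat([x[1:], y.unsqueeze(0)], dim=0)
        return (x_next, time + self.time_step), y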

Looking for thoughts on this, holes to be spotted, and edge cases to be exposed. I think the time looper idea is in the same thought stream as this, but this takes it even further.
@yairchn please give your opinion too.

Thanks. I think we mostly agree, but it would help to use the language of the concepts already existing in the code. The task I see is to define a concrete interface for the idea of "Module" in this list: https://nvidia.github.io/earth2mip/concepts.html#id2.

Create a Model interface centered around a single step: a single model wrapper that drives the forward pass and maintains an internal (model) state if needed.

This is what I intended with the TimeStepper proposal in my comment above. Let's use the term "TimeStepper" unless you have another idea.

Create a super lightweight inference/looper base class

This is basically what earth2mip.networks.Inference does now: it wraps a torch module into a TimeLoop. We can implement a similar version of it without the legacy, and can potentially add features like diagnostic models, etc. Also see this section of the docs: https://nvidia.github.io/earth2mip/concepts.html#translating-between-model-wrappers.

Finally, TimeLoop already supports I/O and reductions by exposing an iterator interface. The consuming CLIs/inference functions do different kinds of reductions and I/O with this (see the sketch below). I don't think our model wrapping layers should know about these concerns.
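
For reference, a consumer of that iterator can do its own reductions without the model wrapper knowing; a hedged sketch (the exact tuple and call signature of a TimeLoop may differ from what is assumed here):

def running_mean_of_forecast(time_loop, start_time, initial_condition, n_steps: int):
    # Consume the iterator and compute a running mean over the forecast;
    # the model wrapper never learns that a reduction is happening.
    mean, count = None, 0
    for time, data, restart in time_loop(start_time, initial_condition):
        count += 1
        mean = data.clone() if mean is None else mean + (data - mean) / count
        if count >= n_steps:
            break
    return mean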

So basically, the work I see here is

  1. Define a TimeStepper "functional" interface. This cannot be a function directly since it also needs metadata like .grid to be useful.
  2. Write a class analogous to earth2mip.networks.Inference which wraps a TimeStepper and implements the earth2mip.time_loop.TimeLoop interface (a sketch follows below).
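
A rough sketch of item 2, assuming the TimeStepper interface proposed above (the class name here is hypothetical):

from datetime import datetime
from typing import Iterator, Tuple

import torch


class StepperTimeLoop:
    """Hypothetical adapter: drives a TimeStepper to present an
    iterator-style TimeLoop interface."""

    def __init__(self, stepper: "TimeStepper"):
        self.stepper = stepper

    def __call__(self, time: datetime, x: torch.Tensor) -> Iterator[Tuple[datetime, torch.Tensor]]:
        state = (x, time)
        while True:
            yield state[1], state[0]  # current time and data; consumer decides when to stop
            state, _ = self.stepper(state)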

Agreed?

This is what I intended with the TimeStepper proposal in my comment above. Let's use the term "TimeStepper" unless you have another idea.

Yes, I think so; I think I got a little mixed up with TimeStepper/TimeLooper :). Generally I think the TimeStepper should be the single wrapper that encompasses a base model (NN from the definitions above) and is the foundational object returned by model.load. This can absorb all the metadata the module currently needs to have.

I think prioritizing a shallow call stack will greatly improve contribution / debugging. Thus I want to minimize wrappers.

This is basically what earth2mip.networks.Inference does now: it wraps a torch module into a TimeLoop.

Inference still operates with the iterator approach and is still coupled to Model.load. I would want this to be detached from the model and have the loops explicitly running inside. This object would be where all functional components get combined.

If needed (hopefully not), these new inferencer classes, like the sample above, could then be extended for specific tasks. For example, if you want to parallelize / get fancy with optimization you could, but the default would encapsulate ensembling, forecasting, and scoring.

Presently each task (inference ensemble, inference basic, medium range) has its own specific loop scattered around. This would at the least allow us to centralize structure and move functionality to a common location. It would also eliminate the hard-coded functions for, say, medium range with specific error metrics fixed to that file; rather, move those into a derivative function and give that to the manager of the time loop.

Finally, TimeLoop already supports I/O and reductions by exposing an iterator interface.

This is true, but I think the iterator approach lowers traceability and increases the coupling between the actual inference run and the loading of the model. Right now I don't see an advantage of this approach that outweighs the drawbacks.

1. Define a TimeStepper "functional" interface. This cannot be a function directly since it also needs metadata like .grid to be useful.

Yes, this is to be the Model wrapper: give it an input state and it produces an output state.

2. Write a class analogous to earth2mip.networks.Inference which wraps a TimeStepper and implements the earth2mip.time_loop.TimeLoop interface.

Not wrapping; I think the TimeStepper should be a "functional" (callable object) input to the Inferencer, along with all the other goodies you're wanting to add. Then the inferencer drives.

I think we agree there is a need for a TimeStepper. Can we stick to this terminology, please? I think step vs. loop is a useful distinction: a loop is many steps.

It's definitely a tradeoff. If one is to write a new scoring/inference function, it is often nice to use the higher-level TimeLoop interface, since one typically just needs an iterator. A step-based interface instead puts some burden on the user to write their own time-looping logic or use the built-in earth2mip.networks.Inference class.

inference_ensemble is an outlier in the amount of stuff it needs to know about the model; let's not over-optimize or change existing abstractions for this one use case.

Happy to discuss more.

What different types of model "State" are there? graphcast, pangu, etc have some differences.

What different types of model "State" are there? graphcast, pangu, etc have some differences.

I imagine the model's state could be pretty diverse if it's not a basic auto-regressive model. For example, an RNN may be the most complex, or Pangu with its interweaving of models at different time steps. I don't think State needs to be explicitly defined; it's just what the model stores between timesteps.

I think at minimum there just needs to be some reset state / init state callback that the default looper can hit; a sketch is below. I think this should enable the overwhelming majority of TimeStepper use cases, even with these complicated models.
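
For instance (hypothetical names; this assumes the wrapped model returns its updated hidden state alongside the prediction):

class StatefulStepper:
    """Hypothetical TimeStepper with internal hidden state (e.g. an RNN)
    that the default looper can reset between forecasts / ensemble members."""

    def __init__(self, model, time_step):
        self.model = model
        self.time_step = time_step
        self.hidden = None

    def reset(self) -> None:
        # Called by the looper before each new initial condition.
        self.hidden = None

    def __call__(self, state):
        x, time = state
        y, self.hidden = self.model(x, self.hidden)  # RNN-style forward pass
        return (y, time + self.time_step), y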

I think at minimum there just needs to be some reset state / init state callback

Agreed, though I am feeling increasingly skeptical about this since it will require adding some more abstractions that the model developer will need to incorporate.

What is state vs not may actually change depending on the configuration (e.g. for a prescribed SST run).

I'm starting to think we should just provide the lowest-level interface possible: provide your ML model as-is and label what the inputs/outputs are. The framework can figure out what is state vs. not from that (a sketch of the idea is below).
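
A sketch of that labeling idea (field names hypothetical): the developer declares input/output channels, and the framework infers that whatever is both consumed and produced is state, while inputs that are never produced (e.g. prescribed SST) are forcings:

from dataclasses import dataclass
from typing import List


@dataclass
class ModelSpec:
    """Hypothetical input/output declaration for a model."""

    input_channels: List[str]
    output_channels: List[str]

    def state_channels(self) -> List[str]:
        # Channels the model both consumes and produces must be carried
        # between steps, i.e. they are state; the rest are diagnostics/forcings.
        return [c for c in self.output_channels if c in self.input_channels]


spec = ModelSpec(
    input_channels=["t2m", "tcwv", "sst"],      # sst could be prescribed forcing
    output_channels=["t2m", "tcwv", "precip"],  # precip is diagnostic-only
)
assert spec.state_channels() == ["t2m", "tcwv"]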

See TimeStepper in #134. Worked pretty nicely IMO.

There hasn't been any more discussion, so closing. earth2mip.time_loop.TimeStepper seems good enough for now. We could consider refactoring some of the existing models to use that interface.