furiosa-ai / torch-fx-rs

torch-fx-rs

Rust APIs to handle PyTorch graph modules and graphs

Where to use

This API can help you write a Python module in Rust using PyO3 when the module needs to handle PyTorch graph modules or graphs.

APIs

pub struct GraphModule

#[repr(transparent)]
pub struct GraphModule(_);

A wrapper for PyTorch's GraphModule class.

The constructor of this type returns a shared reference &GraphModule instead of an owned value. The returned value is a GIL-bound owning reference into Python's heap.

Methods

  • pub fn new<'py>(
        py: Python<'py>,
        nn: &GraphModule,
        graph: &Graph
    ) -> PyResult<&'py Self>

    Create a new instance of the GraphModule PyTorch class using PyTorch's native constructor, with nn as the root and graph as the graph. class_name is not given, so it keeps its default value 'GraphModule'.

    If the new instance is created successfully, returns Ok with a shared reference to the newly created instance in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn new_with_empty_gm<'py>(
        py: Python<'py>,
        graph: &Graph
    ) -> PyResult<&'py Self>

    Create a new instance of the GraphModule PyTorch class using PyTorch's native constructor, where class_name is not given (so it keeps its default value 'GraphModule') and root is a new torch.nn.Module created by torch.nn.Module().

    If the new instance is created successfully, returns Ok with a shared reference to the newly created instance in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn extract_parameters(&self)
        -> PyResult<HashMap<String, &[u8]>>

    Collect all parameters of this GraphModule.

    Make a HashMap which maps the parameter name to a slice representing the underlying storage of the parameter value.

    If this process is successful, returns Ok with the HashMap in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn extract_buffers(&self)
        -> PyResult<HashMap<String, &[u8]>>

    Collect all buffers of this GraphModule.

    Make a HashMap which maps the buffer name to a slice representing the underlying storage of the buffer value.

    If this process is successful, returns Ok with the HashMap in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn graph(&self) -> PyResult<&Graph>

    Retrieve the graph attribute of this GraphModule.

    If the retrieval is done successfully, returns Ok with a shared reference to the graph attribute (&Graph) in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn get_parameter(
        &self,
        name: &str
    ) -> PyResult<Option<&[u8]>>

    Get the underlying storage of the parameter whose name is given by name, in this GraphModule.

    If there is no such parameter, returns Ok(None). If such a parameter exists, returns Ok(Some) with a slice representing the underlying storage of the parameter value. If this process fails, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn count_parameters(&self) -> PyResult<usize>

    Get the number of parameters of this GraphModule.

    If a Python error occurs during this procedure, returns Err with a PyErr in it. PyErr will explain the error. Otherwise, returns Ok with the number of parameters of this GraphModule in it.

  • pub fn get_buffer(
        &self,
        name: &str
    ) -> PyResult<Option<&[u8]>>

    Get the underlying storage of the buffer whose name is given by name, in this GraphModule.

    If there is no such buffer, returns Ok(None). If such a buffer exists, returns Ok(Some) with a slice representing the underlying storage of the buffer value. If this process fails, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn count_buffers(&self) -> PyResult<usize>

    Get the number of buffers of this GraphModule.

    If a Python error occurs during this procedure, returns Err with a PyErr in it. The PyErr will explain the error. Otherwise, returns Ok with the number of buffers of this GraphModule in it.

  • pub fn print_readable(&self) -> PyResult<String>

    Stringify this GraphModule.

    This does the same as the print_readable instance method of the GraphModule PyTorch class, with print_output given as True.

    If stringifying is done successfully, returns Ok with the resulting string in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.
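extract_parameters, extract_buffers, get_parameter and get_buffer all hand back raw storage as &[u8]. The following is a minimal sketch of turning such a slice into f32 values, assuming the tensor's dtype is float32 and its storage is contiguous and little-endian. The helper decode_f32 and the sample data are hypothetical and are not part of this crate.

```rust
use std::collections::HashMap;

/// Hypothetical helper: reinterpret raw tensor storage as little-endian f32 values.
/// Returns None if the byte length is not a multiple of 4 (the size of an f32).
fn decode_f32(bytes: &[u8]) -> Option<Vec<f32>> {
    if bytes.len() % 4 != 0 {
        return None;
    }
    Some(
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect(),
    )
}

fn main() {
    // Stand-in for the map returned by extract_parameters (sample data, not real weights).
    let storage: Vec<u8> = [1.0f32, -2.5, 0.25]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();
    let mut params: HashMap<String, &[u8]> = HashMap::new();
    params.insert("linear.weight".to_string(), storage.as_slice());

    for (name, bytes) in &params {
        let values = decode_f32(bytes).expect("length must be a multiple of 4");
        println!("{}: {:?}", name, values); // linear.weight: [1.0, -2.5, 0.25]
    }
}
```

For any other dtype the chunk size and the from_le_bytes conversion change accordingly.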

pub struct Graph

#[repr(transparent)]
pub struct Graph(_);

A wrapper for PyTorch's Graph class.

The constructor of this type returns a shared reference &Graph instead of an owned value. The returned value is a GIL-bound owning reference into Python's heap.

Methods

  • pub fn new(py: Python<'_>) -> PyResult<&Self>

    Create a new instance of the Graph PyTorch class using PyTorch's native constructor.

    If new instance is created successfully, returns Ok with a shared reference to the newly created instance in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn nodes_iterator(&self) -> PyResult<&PyIterator>

    Retrieve all the Nodes of this Graph as a Python iterator.

    If the retrieval is done successfully, returns Ok with a shared reference to a Python iterator over the Nodes. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn eliminate_dead_code(&self) -> PyResult<()>

    An interface for the eliminate_dead_code instance method of the Graph PyTorch class.

    If the method call is done successfully, returns Ok(()). Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn lint(&self) -> PyResult<()>

    An interface for the lint instance method of the Graph PyTorch class.

    If the method call is done successfully, returns Ok(()). Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn create_node<S: AsRef<str>>(
        &self,
        op: Op,
        target: Target,
        args: impl IntoIterator<IntoIter = impl ExactSizeIterator<Item = Argument>>,
        kwargs: impl IntoIterator<Item = (String, Argument)>,
        name: S,
        meta: Option<HashMap<String, PyObject>>,
    ) -> PyResult<&Node>

    An interface for the create_node instance method of the Graph PyTorch class, except that type_expr is not given (None). Also, if meta is given, the newly created Node will have an attribute meta whose value is the given meta.

    If the method call is done successfully, returns Ok with a shared reference to the newly created Node in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn placeholder<S: AsRef<str>>(
        &self,
        name: S
    ) -> PyResult<&Node>

    Create and insert a placeholder Node into this Graph. A placeholder represents a function input; name is the name of the input value.

    This does the same as the placeholder instance method of the Graph PyTorch class, except that type_expr is None and default_value is inspect.Signature.empty.

    If the creation and insertion of the Node is done successfully, returns Ok with a shared reference to the newly created Node in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn output(
        &self,
        args: Argument
    ) -> PyResult<&Node>

    Create and insert an output Node into this Graph. args is the value that this output Node should return, and it must be an Argument::NodeTuple.

    This does the same as the output instance method of the Graph PyTorch class, except that type_expr is None and the newly created Node is named 'output'.

    If the creation and insertion of the Node is done successfully, returns Ok with a shared reference to the newly created Node in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn call_custom_function<S: AsRef<str>>(
        &self,
        name: S,
        custom_fn: CustomFn,
        args: impl IntoIterator<IntoIter = impl ExactSizeIterator<Item = Argument>>,
        kwargs: impl IntoIterator<Item = (String, Argument)>,
    ) -> PyResult<&Node>

    Create and insert a call_function Node into this Graph. A call_function Node represents a call to a Python callable, here specified by custom_fn.

    This does the same as the call_function instance method of the Graph PyTorch class, except that the the_function parameter is renamed to custom_fn, type_expr is not given (None), and the name of the new Node is given explicitly by name.

    custom_fn must be a CustomFn, a Python callable that actually calls a Rust function.

    If the creation and insertion of the Node is done successfully, returns Ok with a shared reference to the newly created Node in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn call_python_function<S: AsRef<str>>(
        &self,
        name: S,
        the_function: Py<PyAny>,
        args: impl IntoIterator<IntoIter = impl ExactSizeIterator<Item = Argument>>,
        kwargs: impl IntoIterator<Item = (String, Argument)>,
    ) -> PyResult<&Node>

    Create and insert a call_function Node into this Graph. A call_function Node represents a call to a Python callable, here specified by the_function.

    This does the same as the call_function instance method of the Graph PyTorch class, except that type_expr is not given (None) and the name of the new Node is given explicitly by name.

    If the creation and insertion of the Node is done successfully, returns Ok with a shared reference to the newly created Node in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn node_copy(
        &self,
        node: &Node,
        mapper: Option<&HashMap<String, String>>,
    ) -> PyResult<&Node>

    Copy a Node from another Graph into this Graph (self). node is the Node to copy into self. mapper is used to transform the arguments of node from the graph of node to the graph of self.

    This does the same as the node_copy instance method of the Graph PyTorch class.

    If the copying and insertion of the Node is done successfully, returns Ok with a shared reference to the newly created Node in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn flatten_node_args<S: AsRef<str>>(
        &self,
        node_name: S
    ) -> PyResult<Option<Vec<String>>>

    Retrieve the names of the argument Nodes of the Node named node_name in this Graph.

    If this Graph has no Node named node_name, returns Ok(None). If it has such a Node, returns Ok(Some) with a Vec of the names of that Node's argument Nodes. If something fails while looking into this Graph, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn users<S: AsRef<str>>(
        &self,
        node_name: S
    ) -> PyResult<Option<Vec<String>>>

    Retrieve the names of the user Nodes of the Node named node_name in this Graph.

    If this Graph has no Node named node_name, returns Ok(None). If it has such a Node, returns Ok(Some) with a Vec of the names of that Node's user Nodes. If something fails while looking into this Graph, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn graph_to_string(
        &self,
        py: Python<'_>
    ) -> PyResult<String>

    Stringify this Graph.

    This does the same as the __str__ instance method of the Graph PyTorch class.

    If stringifying is done successfully, returns Ok with the resulting string in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn extract_named_nodes(&self)
        -> PyResult<IndexMap<String, &Node>>

    Collect all named Nodes of this Graph.

    Make an IndexMap that maps each Node's name to a shared reference to the Node itself, for every Node in self.

    If this process is successful, returns Ok with the IndexMap in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn lookup_node<S: AsRef<str>>(
        &self,
        name: S
    ) -> PyResult<Option<&Node>>

    Look up a Node by its name (name) in this Graph.

    If there is no Node named name, returns Ok(None). If such a Node exists in this Graph, returns Ok(Some) with a shared reference to the Node. If this process fails, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.
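users is the inverse of the argument relation that flatten_node_args exposes, and eliminate_dead_code depends on it. The sketch below is a simplified pure-Rust model of both ideas, not this crate's implementation. It assumes nodes are stored in topological order as hypothetical SimpleNode values (name, opcode string, argument names). Following torch.fx's rule, placeholder and output Nodes are never removed, and any other Node with no users is erased while scanning in reverse. torch.fx additionally keeps Nodes it considers impure, such as certain side-effecting calls, which this model ignores.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for an FX Node: its name, opcode string and argument Node names.
#[derive(Debug, Clone)]
struct SimpleNode {
    name: String,
    op: &'static str,
    args: Vec<String>,
}

fn node(name: &str, op: &'static str, args: &[&str]) -> SimpleNode {
    SimpleNode {
        name: name.to_string(),
        op,
        args: args.iter().map(|a| a.to_string()).collect(),
    }
}

/// Build the "users" relation by inverting every Node's argument list.
fn users(nodes: &[SimpleNode]) -> HashMap<String, Vec<String>> {
    let mut map: HashMap<String, Vec<String>> =
        nodes.iter().map(|n| (n.name.clone(), Vec::new())).collect();
    for n in nodes {
        for arg in &n.args {
            if let Some(list) = map.get_mut(arg) {
                list.push(n.name.clone());
            }
        }
    }
    map
}

/// Simplified dead-code elimination: walk in reverse topological order and drop any
/// Node that is neither a placeholder nor an output and has no remaining users.
fn eliminate_dead_code(nodes: &mut Vec<SimpleNode>) {
    let mut i = nodes.len();
    while i > 0 {
        i -= 1;
        let n = &nodes[i];
        let keep = n.op == "placeholder" || n.op == "output";
        let used = nodes.iter().any(|other| other.args.contains(&n.name));
        if !keep && !used {
            nodes.remove(i);
        }
    }
}

fn main() {
    let mut graph = vec![
        node("x", "placeholder", &[]),
        node("relu", "call_function", &["x"]),
        node("unused", "call_function", &["x"]),
        node("dead_tail", "call_function", &["unused"]),
        node("output", "output", &["relu"]),
    ];
    println!("users of x: {:?}", users(&graph)["x"]); // ["relu", "unused"]
    eliminate_dead_code(&mut graph);
    let names: Vec<&str> = graph.iter().map(|n| n.name.as_str()).collect();
    println!("{:?}", names); // ["x", "relu", "output"]
}
```

Scanning in reverse matters: removing dead_tail first leaves unused without users, so it is removed on the same pass.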

pub struct Node

#[repr(transparent)]
pub struct Node(_);

A wrapper for PyTorch's Node class.

This appears as a shared reference &Node into Python's heap instead of an owned value.

Methods

  • pub fn flatten_node_args(&self) -> PyResult<Vec<String>>

    Retrieve the names of the argument Nodes of this Node. Although a Node can have multiple arguments and each argument can contain one or more Nodes, the result contains all the argument Nodes' names in a single flat (1-dimensional) vector. (This is why the method is named flatten_node_args.)

    If the retrieval is done successfully, returns Ok with a Vec of names of argument nodes. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn args(&self) -> PyResult<Vec<Argument>>

    Retrieve the arguments of this Node.

    If the retrieval is done successfully, returns Ok with a Vec<Argument> containing the arguments. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn name(&self) -> PyResult<String>

    Retrieve the name of this Node.

    If the retrieval is done successfully, returns Ok with the name in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn op(&self) -> PyResult<Op>

    Retrieve the opcode of this Node.

    If the retrieval is done successfully, returns Ok with the opcode as an Op in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn target(&self) -> PyResult<Target>

    Retrieve the target this Node should call.

    If the retrieval is done successfully, returns Ok with the target as a Target in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn kwargs(&self)
        -> PyResult<HashMap<String, Argument>>

    Retrieve the kwargs to be passed to the target of this Node.

    If the retrieval is done successfully, returns Ok with the kwargs as a HashMap<String, Argument> in it. Otherwise, returns Err with a PyErr in it. The PyErr will explain the cause of the failure.

  • pub fn meta(&self)
        -> PyResult<HashMap<String, PyObject>>

    Retrieve the meta of this Node.

    If this Node has an attribute meta, returns Ok with the meta as a HashMap<String, PyObject> in it. Otherwise, returns Ok(Default::default()), that is, an empty HashMap. This never returns Err.
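Node::flatten_node_args collapses every Node name that appears in a Node's arguments into one flat list. The following is a minimal sketch of that flattening over a reduced, locally defined copy of the Argument enum (only the Node-bearing variants plus two scalars). The flatten function is hypothetical and may differ from the crate's actual behaviour, for example in how it treats Value or kwargs.

```rust
/// Reduced local copy of the crate's Argument enum, for illustration only.
#[derive(Debug, Clone)]
enum Argument {
    Node(String),
    NodeList(Vec<String>),
    NodeTuple(Vec<String>),
    OptionalNodeList(Vec<Option<String>>),
    OptionalNodeTuple(Vec<Option<String>>),
    Int(i64),
    None,
}

/// Collect the names of all Nodes referenced by `args`, in order, into one Vec.
fn flatten(args: &[Argument]) -> Vec<String> {
    let mut names = Vec::new();
    for arg in args {
        match arg {
            Argument::Node(n) => names.push(n.clone()),
            Argument::NodeList(ns) | Argument::NodeTuple(ns) => {
                names.extend(ns.iter().cloned())
            }
            Argument::OptionalNodeList(ns) | Argument::OptionalNodeTuple(ns) => {
                // Skip the empty slots and keep only actual Node names.
                names.extend(ns.iter().flatten().cloned())
            }
            // Non-Node arguments contribute nothing.
            Argument::Int(_) | Argument::None => {}
        }
    }
    names
}

fn main() {
    let args = vec![
        Argument::Node("x".into()),
        Argument::NodeTuple(vec!["a".into(), "b".into()]),
        Argument::OptionalNodeList(vec![Some("c".into()), None]),
        Argument::Int(3),
        Argument::None,
    ];
    println!("{:?}", flatten(&args)); // ["x", "a", "b", "c"]
}
```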

pub type FunctionWrapper

Wrapper for a Rust function, which allows the function to be executed from Python. The wrapped function must therefore take 2 arguments, args as &PyTuple and kwargs as Option<&PyDict>, and return PyResult<PyObject>.

pub struct CustomFn

#[pyclass]
#[derive(Clone)]
pub struct CustomFn {
    pub func_name: String,
    /* private fields */
}

A Python callable object that actually executes a Rust function.

Fields

  • pub func_name: String
    • Name of the custom function

Methods

  • pub fn new<S: AsRef<str>>(
        func_name: S,
        func: FunctionWrapper
    ) -> Self

    Create a new Python callable object named func_name that actually executes the Rust function wrapped in func.
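CustomFn pairs a name with a wrapped Rust function so that the pair can be cloned and invoked later. The pure-Rust sketch below shows that shape with the Python types swapped out. The hypothetical NamedFn holds its function behind an Arc, which is one common way to keep such a struct Clone. The &[i64] to Result<i64, String> signature is a simplified stand-in for the real (&PyTuple, Option<&PyDict>) -> PyResult<PyObject> signature, and NamedFn, Wrapper and call are illustrative names, not this crate's API.

```rust
use std::sync::Arc;

/// Simplified stand-in for FunctionWrapper: a shareable, thread-safe closure.
type Wrapper = Arc<dyn Fn(&[i64]) -> Result<i64, String> + Send + Sync>;

/// Hypothetical analogue of CustomFn: a name plus a wrapped Rust function.
#[derive(Clone)]
struct NamedFn {
    func_name: String,
    func: Wrapper,
}

impl NamedFn {
    fn new<S: AsRef<str>>(func_name: S, func: Wrapper) -> Self {
        NamedFn {
            func_name: func_name.as_ref().to_string(),
            func,
        }
    }

    /// Invoke the wrapped function, as Python would when calling the CustomFn object.
    fn call(&self, args: &[i64]) -> Result<i64, String> {
        (self.func)(args)
    }
}

fn main() {
    let add = NamedFn::new(
        "my_add",
        Arc::new(|args: &[i64]| {
            if args.len() != 2 {
                return Err(format!("expected 2 arguments, got {}", args.len()));
            }
            Ok(args[0] + args[1])
        }),
    );
    // Cloning only bumps the Arc's reference count; both copies share one closure.
    let copy = add.clone();
    println!("{} -> {:?}", copy.func_name, copy.call(&[2, 3])); // my_add -> Ok(5)
}
```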

pub struct TensorMeta

#[derive(Debug, Clone, FromPyObject)]
pub struct TensorMeta {
    pub shape: Vec<usize>,
    pub dtype: Dtype,
    pub requires_grad: bool,
    pub stride: Vec<usize>,
    pub memory_format: Option<MemoryFormat>,
    pub is_quantized: bool,
    pub qparams: HashMap<String, PyObject>,
}

A structure containing pertinent information about a tensor within a PyTorch program.
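The shape and stride fields of TensorMeta are related. For a tensor in the default contiguous (row-major) layout, each stride is the product of all dimensions to its right. The following is a minimal sketch of that relation, using plain Vec<usize> values as TensorMeta stores them. The helper names are hypothetical. The check ignores PyTorch's special handling of size-1 and size-0 dimensions, so it is stricter than torch.Tensor.is_contiguous().

```rust
/// Strides of a contiguous (row-major) tensor with the given shape.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    // Walk from the second-to-last dimension backwards, accumulating the product.
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// Strict check: does `stride` match the row-major strides for `shape`?
fn is_row_major(shape: &[usize], stride: &[usize]) -> bool {
    shape.len() == stride.len() && contiguous_strides(shape) == stride
}

fn main() {
    let shape = vec![2, 3, 4];
    println!("{:?}", contiguous_strides(&shape)); // [12, 4, 1]
    // A transposed view of a 3x4 tensor has shape [4, 3] and strides [1, 4].
    println!("{}", is_row_major(&[4, 3], &[1, 4])); // false
}
```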

(reference)

pub enum Op

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Op {
    Placeholder,
    CallFunction,
    CallMethod,
    CallModule,
    GetAttr,
    Output,
}

A representation of opcodes for Nodes.
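Each Op variant corresponds to one of the opcode strings that torch.fx stores in Node.op. The following sketch shows that correspondence in both directions, with the enum copied locally so the block compiles on its own. as_fx_str and from_fx_str are hypothetical helpers, and the crate may perform this conversion internally in a different way.

```rust
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum Op {
    Placeholder,
    CallFunction,
    CallMethod,
    CallModule,
    GetAttr,
    Output,
}

impl Op {
    /// The torch.fx opcode string for this variant.
    fn as_fx_str(self) -> &'static str {
        match self {
            Op::Placeholder => "placeholder",
            Op::CallFunction => "call_function",
            Op::CallMethod => "call_method",
            Op::CallModule => "call_module",
            Op::GetAttr => "get_attr",
            Op::Output => "output",
        }
    }

    /// Parse a torch.fx opcode string; unknown strings yield None.
    fn from_fx_str(s: &str) -> Option<Op> {
        let all = [
            Op::Placeholder,
            Op::CallFunction,
            Op::CallMethod,
            Op::CallModule,
            Op::GetAttr,
            Op::Output,
        ];
        all.iter().copied().find(|op| op.as_fx_str() == s)
    }
}

fn main() {
    println!("{}", Op::CallFunction.as_fx_str()); // call_function
    println!("{:?}", Op::from_fx_str("get_attr")); // Some(GetAttr)
}
```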

pub enum Target

#[derive(Debug, Clone)]
pub enum Target {
    Str(String),
    TorchOp(String, PyObject),
    BuiltinFn(String, PyObject),
    Callable(PyObject),
    CustomFn(CustomFn),
}

A representation of targets for Nodes.

pub enum Argument

#[derive(Debug, Clone)]
pub enum Argument {
    Node(String),
    NodeList(Vec<String>),
    NodeTuple(Vec<String>),
    OptionalNodeList(Vec<Option<String>>),
    OptionalNodeTuple(Vec<Option<String>>),
    NoneList(usize),
    NoneTuple(usize),
    Bool(bool),
    Int(i64),
    Float(f64),
    VecBool(Vec<bool>),
    VecInt(Vec<i64>),
    VecFloat(Vec<f64>),
    Dtype(Dtype),
    Layout(Layout),
    Device(Device),
    MemoryFormat(MemoryFormat),
    Value(PyObject),
    EmptyList,
    None,
}

A representation of arguments for Nodes.
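create_node, call_custom_function and call_python_function take args as impl IntoIterator<IntoIter = impl ExactSizeIterator<Item = Argument>> and kwargs as impl IntoIterator<Item = (String, Argument)>. The sketch below uses a hypothetical function with the same bounds over a trimmed local copy of Argument to show which ordinary Rust collections satisfy them. For args, a Vec or an array works because their iterators know their length in advance; presumably this is what lets the crate pack the arguments into a fixed-size Python tuple. For kwargs, a Vec of pairs or a HashMap works.

```rust
use std::collections::HashMap;

/// Trimmed local copy of the crate's Argument enum.
#[derive(Debug, Clone, PartialEq)]
enum Argument {
    Node(String),
    Int(i64),
    Float(f64),
}

/// Hypothetical function with the same bounds as create_node's args and kwargs.
/// Returns the positional-argument count and the sorted keyword names.
fn describe_call(
    args: impl IntoIterator<IntoIter = impl ExactSizeIterator<Item = Argument>>,
    kwargs: impl IntoIterator<Item = (String, Argument)>,
) -> (usize, Vec<String>) {
    let args = args.into_iter();
    // ExactSizeIterator::len is available before consuming any element.
    let n_args = args.len();
    let mut keys: Vec<String> = kwargs.into_iter().map(|(k, _)| k).collect();
    keys.sort();
    (n_args, keys)
}

fn main() {
    // Positional arguments from a Vec; keyword arguments from a HashMap.
    let args = vec![Argument::Node("x".into()), Argument::Int(1)];
    let mut kwargs = HashMap::new();
    kwargs.insert("alpha".to_string(), Argument::Float(0.5));
    println!("{:?}", describe_call(args, kwargs)); // (2, ["alpha"])

    // An array for args and an empty Vec of pairs for kwargs work too.
    let no_kwargs = Vec::<(String, Argument)>::new();
    println!("{:?}", describe_call([Argument::Int(7)], no_kwargs)); // (1, [])
}
```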

pub enum Dtype

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Dtype {
    Float32,
    Float64,
    Complex64,
    Complex128,
    Float16,
    Bfloat16,
    Uint8,
    Int8,
    Int16,
    Int32,
    Int64,
    Bool,
}

An enum which represents the data type of a torch.Tensor.
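The parameter and buffer accessors return raw bytes, so a tensor's Dtype (for example, from TensorMeta) determines how many bytes each element occupies. The following sketch shows that table with the enum copied locally. element_size and element_count are hypothetical helpers. The sizes are the standard ones for these PyTorch dtypes, as reported by torch.Tensor.element_size().

```rust
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Dtype {
    Float32,
    Float64,
    Complex64,
    Complex128,
    Float16,
    Bfloat16,
    Uint8,
    Int8,
    Int16,
    Int32,
    Int64,
    Bool,
}

impl Dtype {
    /// Size in bytes of one element of this dtype.
    fn element_size(self) -> usize {
        match self {
            Dtype::Bool | Dtype::Uint8 | Dtype::Int8 => 1,
            Dtype::Float16 | Dtype::Bfloat16 | Dtype::Int16 => 2,
            Dtype::Float32 | Dtype::Int32 => 4,
            Dtype::Float64 | Dtype::Int64 | Dtype::Complex64 => 8,
            Dtype::Complex128 => 16,
        }
    }
}

/// Number of elements stored in `bytes`, or None if the length does not divide evenly.
fn element_count(bytes: &[u8], dtype: Dtype) -> Option<usize> {
    let size = dtype.element_size();
    if bytes.len() % size == 0 {
        Some(bytes.len() / size)
    } else {
        None
    }
}

fn main() {
    let storage = [0u8; 24];
    println!("{:?}", element_count(&storage, Dtype::Float32)); // Some(6)
    println!("{:?}", element_count(&storage, Dtype::Complex128)); // None
}
```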

(reference)

pub enum MemoryFormat

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MemoryFormat {
    ContiguousFormat,
    ChannelsLast,
    ChannelsLast3d,
    PreserveFormat,
}

An enum which represents the memory format on which a torch.Tensor is or will be allocated.
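Each MemoryFormat variant mirrors one of PyTorch's torch.memory_format constants. The channels-last formats apply only to tensors of a fixed rank: torch.channels_last applies to 4-D tensors and torch.channels_last_3d to 5-D tensors. The following sketch shows both facts with the enum copied locally. torch_name and required_rank are hypothetical helpers, and the strings are the Python attribute names, not something this crate is documented to return.

```rust
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum MemoryFormat {
    ContiguousFormat,
    ChannelsLast,
    ChannelsLast3d,
    PreserveFormat,
}

impl MemoryFormat {
    /// Fully qualified name of the matching torch.memory_format constant.
    fn torch_name(self) -> &'static str {
        match self {
            MemoryFormat::ContiguousFormat => "torch.contiguous_format",
            MemoryFormat::ChannelsLast => "torch.channels_last",
            MemoryFormat::ChannelsLast3d => "torch.channels_last_3d",
            MemoryFormat::PreserveFormat => "torch.preserve_format",
        }
    }

    /// Tensor rank a format is restricted to, if any.
    fn required_rank(self) -> Option<usize> {
        match self {
            MemoryFormat::ChannelsLast => Some(4),
            MemoryFormat::ChannelsLast3d => Some(5),
            MemoryFormat::ContiguousFormat | MemoryFormat::PreserveFormat => None,
        }
    }
}

fn main() {
    println!("{}", MemoryFormat::ChannelsLast3d.torch_name()); // torch.channels_last_3d
    println!("{:?}", MemoryFormat::ChannelsLast.required_rank()); // Some(4)
}
```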

(reference)

pub enum Device

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Device {
    Cpu(Option<usize>),
    Cuda(Option<usize>),
    Mps(Option<usize>),
}

An enum which represents the device on which a torch.Tensor is or will be allocated.
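PyTorch device strings have the form "cuda" (no index) or "cuda:0" (with an index). This sketch assumes the Option<usize> in each Device variant is that device index and renders a Device in the same string form, with the enum copied locally. to_torch_string is a hypothetical helper. Reading None as "no index given" is an assumption about how the variants are meant to be interpreted.

```rust
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Device {
    Cpu(Option<usize>),
    Cuda(Option<usize>),
    Mps(Option<usize>),
}

impl Device {
    /// Render as a torch.device-style string, e.g. "cpu" or "cuda:1".
    fn to_torch_string(self) -> String {
        let (kind, index) = match self {
            Device::Cpu(i) => ("cpu", i),
            Device::Cuda(i) => ("cuda", i),
            Device::Mps(i) => ("mps", i),
        };
        match index {
            Some(i) => format!("{}:{}", kind, i),
            None => kind.to_string(),
        }
    }
}

fn main() {
    println!("{}", Device::Cuda(Some(1)).to_torch_string()); // cuda:1
    println!("{}", Device::Cpu(None).to_torch_string()); // cpu
}
```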

(reference)

Documentation

Run the following command to build the documentation for this crate with cargo doc and open it in a browser:

cargo doc --open

For more details on the underlying concepts, you may also need to consult PyTorch's torch.fx documentation.
