NVIDIA / cuCollections

[FEATURE]: Add host-bulk `retrieve` for hash tables

PointKernel opened this issue

Is your feature request related to a problem? Please describe.

Similar to the current static_multimap::retrieve, the new cuco set/map could provide retrieve operations that return the matched value(s) for the given keys.

Describe the solution you'd like

A device retrieve for static_map could look like:

  /**
   * @brief Retrieves the matched key-value pair of a given key contained in the map
   *
   * If `key` exists in the map, copies the pair of `key` and its matched value, `cuco::pair{k, v}`,
   * to an unspecified location in the output sequence starting at `output_begin`. Does nothing if
   * there is no match.
   *
   * @tparam atomicC Type of atomic counter
   * @tparam OutputIt Device accessible output iterator whose `value_type` is
   *         constructible from `cuco::pair<Key, map_type::value_type>`
   *
   * @param key The key to search for
   * @param num_matches Pointer to the atomic counter tracking the number of matches
   * @param output_begin Beginning of the output sequence
   */
  template <typename atomicC, typename OutputIt>
  __device__ void retrieve(Key const& key, atomicC* num_matches, OutputIt output_begin);

Variants that could be useful for RAPIDS:

  • Output of a set retrieve should be something like: cuco::pair<ProbeKey, MapKey>
  • retrieve_outer returns the slot sentinel for non-matches (needed for outer joins)
  • conditional retrieve: retrieve_if
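
To make the `retrieve_outer` idea concrete, here is a minimal host-only sketch. It is not cuco code: `std::unordered_map` stands in for the cuco container, and the function signature and the `empty_value_sentinel` parameter are hypothetical names chosen for illustration. It only shows the behaviour described above, where a probe key without a match still yields one output row that carries the sentinel.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical host-side stand-in for the proposed `retrieve_outer` variant.
// `std::unordered_map` replaces the cuco container and `empty_value_sentinel`
// plays the role of the slot sentinel; none of these names are cuco's API.
template <typename Key, typename Value>
std::vector<std::pair<Key, Value>> retrieve_outer(
  std::unordered_map<Key, Value> const& map,
  std::vector<Key> const& probe_keys,
  Value empty_value_sentinel)
{
  std::vector<std::pair<Key, Value>> output;
  output.reserve(probe_keys.size());
  for (auto const& key : probe_keys) {
    auto const found = map.find(key);
    // A non-matching key still produces one output row, carrying the sentinel.
    output.emplace_back(key, found != map.end() ? found->second : empty_value_sentinel);
  }
  return output;
}
```

With this shape, an outer join can tell unmatched probe keys apart by comparing the second element of each row against the sentinel.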

Describe alternatives you've considered

No response

Additional context

No response

What is the intention with making atomicC a template parameter?

I think the more canonical thing here would be to make retrieve return the iterator output_begin + N for N matches.

> What is the intention with making atomicC a template parameter?

It doesn't have to be a template parameter. This is mainly to remind myself that we need to atomically count the number of total matches.
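
As an illustration of the atomic counting mentioned here, the following host-only sketch reserves an output slot per match with an atomic fetch-and-add. It is an assumption-laden stand-in, not the proposed device code: `std::atomic` and `std::unordered_map` replace the device-side counter and the cuco map, and the counter type is fixed rather than templated.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical per-key retrieve mirroring the proposed signature on the host.
// On a match it reserves one output slot by atomically incrementing the shared
// counter, then writes the pair into that slot.
template <typename Key, typename Value, typename OutputIt>
void retrieve(std::unordered_map<Key, Value> const& map,
              Key const& key,
              std::atomic<std::size_t>* num_matches,
              OutputIt output_begin)
{
  auto const found = map.find(key);
  if (found == map.end()) { return; }  // no match: nothing is written
  // fetch_add returns the previous value, so every caller gets a distinct slot.
  std::size_t const slot = num_matches->fetch_add(1, std::memory_order_relaxed);
  *(output_begin + slot) = std::make_pair(key, found->second);
}
```

Because each match claims its slot through the counter, concurrent callers never write to the same location, but the order of the output depends on who increments first. After all calls finish, the counter holds the total number of matches.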

> I think the more canonical thing here would be to make retrieve return the iterator output_begin + N for N matches.

Good point. I plan to do so for host-bulk APIs. For device APIs, it's less straightforward, since they normally accumulate or scan over shared memory buffers and don't count the number of matches by themselves. There may be a way to do this, and I will definitely keep it in mind when adding the APIs.
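
For reference, the "return `output_begin + N`" convention for a host-bulk API could look roughly like the sketch below, in the style of `std::copy_if`. It is only a sequential host stand-in under stated assumptions: `bulk_retrieve` is a hypothetical name and `std::unordered_map` replaces the cuco container.

```cpp
#include <cassert>
#include <iterator>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical host-bulk retrieve: writes one pair per matching probe key and
// returns the iterator one past the last written element, i.e.
// `output_begin + N` for N matches.
template <typename InputIt, typename Key, typename Value, typename OutputIt>
OutputIt bulk_retrieve(InputIt first,
                       InputIt last,
                       std::unordered_map<Key, Value> const& map,
                       OutputIt output_begin)
{
  for (; first != last; ++first) {
    auto const found = map.find(*first);
    if (found != map.end()) {
      *output_begin = std::make_pair(found->first, found->second);
      ++output_begin;
    }
  }
  return output_begin;
}
```

The caller recovers the match count as the distance between the original `output_begin` and the returned iterator, so no separate counter argument is needed.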

Superseded by #465