juanmc2005 / diart

A python package to build AI-powered real-time audio applications

Home page: https://diart.readthedocs.io

Feature Request: Implementing Persistent Speaker Embeddings Across Conversations

DmitriyG228 opened this issue

Feature Description

I propose adding a feature to the DIART project that allows speaker embeddings to be persisted and reused across multiple conversations. I am willing to contribute to this feature.

Expected Benefit

It would be particularly useful in scenarios where speakers need to be identified over time, across multiple conversations.

Implementation Feasibility

Given the complexity of the speaker embeddings obtained during a conversation, I seek guidance on the technical feasibility of this feature. Specifically, I'm interested in understanding whether the current architecture and design of DIART can support the persistence of speaker embeddings across conversations.

Suggested Integration Points

Could you provide insights on which parts of the DIART codebase would be most relevant for integrating this mechanism? Any pointers or suggestions on how to approach this enhancement would be greatly appreciated.

Additional Context

I have reviewed the paper implemented by DIART and believe that, although challenging, this feature could be a feasible and valuable addition.

I am eager to contribute to this aspect of the project and align it with DIART's overall goals and design.

Thank you for considering this feature request and for any guidance you can provide.

My own solution is the following:

Patch OnlineSpeakerClustering with the following method:


```python
from typing import Dict, List

# Added as a method of diart's OnlineSpeakerClustering
def get_speaker_id_to_centroid_mapping(self) -> Dict[int, List[float]]:
    """Return the mapping of active global speaker IDs to their centroids.

    Centroids are converted to plain lists so they can be JSON-serialized.
    """
    if self.centers is None:
        return {}
    return {g_spk: self.centers[g_spk].tolist() for g_spk in self.active_centers}
```
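The method converts each centroid with `.tolist()` rather than returning the arrays themselves, because NumPy arrays are not JSON-serializable and the mapping is later passed to `json.dumps` in the Redis writer. A small standalone sketch with made-up centroid values (the `centers` array and `active_centers` set here are hypothetical stand-ins for the clustering state) shows the difference, and also that JSON turns the integer speaker IDs into string keys:

```python
import json

import numpy as np

# Hypothetical values: 3 centroid slots of dimension 2, of which speakers 0 and 1 are active
centers = np.array([[0.5, 0.5], [1.0, 0.0], [0.0, 0.0]])
active_centers = {0, 1}

# Raw ndarray values are rejected by the json module
try:
    json.dumps({g_spk: centers[g_spk] for g_spk in active_centers})
except TypeError as err:
    print("not serializable:", err)

# With .tolist(), the mapping serializes; note the int keys become strings in JSON
mapping = {g_spk: centers[g_spk].tolist() for g_spk in active_centers}
encoded = json.dumps(mapping)
print(encoded)  # {"0": [0.5, 0.5], "1": [1.0, 0.0]}
```

Anyone reading the centroids back therefore has to convert the keys to `int` again.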

and, in SpeakerDiarization.__call__, append the mapping to every output:

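```python
def __call__(self, waveforms):
    ...  # existing code: segmentation and embedding extraction
    for wav, seg, emb in zip(waveforms, segmentations, embeddings):
        ...  # existing code: clustering update and aggregation
        # Snapshot the global centroids after self.clustering has been updated
        speaker_id_to_centroid_mapping = self.clustering.get_speaker_id_to_centroid_mapping()
        # Emit the mapping as a third element of each output tuple
        outputs.append((agg_prediction, agg_waveform, speaker_id_to_centroid_mapping))
```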
    

```python
import json
from typing import Text, Tuple, Union

from pyannote.core import Annotation
from rx.core import Observer


class RedisWriter(Observer):
    """Push each diarized segment, with the current global centroids, to a Redis list."""

    def __init__(self, uri: Text, redis_client, patch_collar: float = 0.05):
        super().__init__()
        self.uri = uri
        self.redis_client = redis_client
        # Assuming the URI is a unique identifier for the conversation
        self.conversation_id = uri
        # Kept for parity with RTTMWriter; not used yet
        self.patch_collar = patch_collar

    def on_next(self, value: Union[Tuple, Annotation]):
        if isinstance(value, tuple):
            # Patched pipeline output: (prediction, waveform, centroids)
            prediction, _, centroids = value
        else:
            # Plain annotation: no centroid information available
            prediction, centroids = value, {}

        # Write one JSON record per segment to the conversation's Redis list
        for segment, _, label in prediction.itertracks(yield_label=True):
            diarization_data = {
                "start": segment.start,
                "end": segment.end,
                "speaker_id": label,
                "centroids": centroids,
            }
            self.redis_client.rpush(
                f"diarization_{self.conversation_id}", json.dumps(diarization_data)
            )

    def on_error(self, error: Exception):
        # Handle error (optional)
        pass

    def on_completed(self):
        # Handle completion (optional)
        pass
```

I run this the following way:

```python
import redis

from diart import SpeakerDiarization
from diart.inference import StreamingInference
from diart.sources import FileAudioSource
from diart.sinks import RedisWriter  # assumes the RedisWriter above was added to diart.sinks

# Connection to a local Redis server (adjust host/port as needed)
redis_client = redis.Redis(host="localhost", port=6379)
audio_file_path = "conversation.wav"  # placeholder path

# Initialize the speaker diarization pipeline
pipeline = SpeakerDiarization()

# Stream the audio file at 16 kHz
sample_rate = 16000
file_source = FileAudioSource(audio_file_path, sample_rate)

# Create a StreamingInference instance with the file source
inference = StreamingInference(pipeline, file_source, do_plot=False)
# Attach the Redis sink instead of RTTMWriter
inference.attach_observers(RedisWriter(file_source.uri, redis_client))

# Run the inference
prediction = inference()
```

The above writes the output to a Redis queue, together with the global speaker embeddings (centroids).
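A consumer can read the queue back with redis-py's `lrange`. The sketch below assumes the key format `diarization_<conversation_id>` and the record fields written by `RedisWriter`; `read_diarization` is a hypothetical helper, and because JSON turns integer keys into strings, it converts the centroid keys back to `int`. Any object with an `lrange(name, start, end)` method works as the client, so the example uses a tiny in-memory stand-in (also hypothetical) instead of a real server.

```python
import json
from typing import Any, Dict, List

import numpy as np


def read_diarization(redis_client: Any, conversation_id: str) -> List[Dict[str, Any]]:
    """Decode all records pushed by RedisWriter for one conversation (hypothetical helper)."""
    raw_records = redis_client.lrange(f"diarization_{conversation_id}", 0, -1)
    records = []
    for raw in raw_records:
        record = json.loads(raw)  # accepts str or bytes (redis-py returns bytes by default)
        # JSON object keys are strings; restore integer speaker IDs and ndarray centroids
        record["centroids"] = {
            int(spk): np.asarray(vec) for spk, vec in record["centroids"].items()
        }
        records.append(record)
    return records


class InMemoryList:
    """Tiny stand-in for a Redis client supporting rpush/lrange (illustration only)."""

    def __init__(self):
        self.lists: Dict[str, List[str]] = {}

    def rpush(self, name: str, value: str) -> int:
        self.lists.setdefault(name, []).append(value)
        return len(self.lists[name])

    def lrange(self, name: str, start: int, end: int) -> List[str]:
        items = self.lists.get(name, [])
        # Redis treats `end` as inclusive; -1 means "up to the last element"
        return items[start:] if end == -1 else items[start:end + 1]


client = InMemoryList()
client.rpush("diarization_conv1", json.dumps(
    {"start": 0.0, "end": 1.5, "speaker_id": "speaker0", "centroids": {"0": [0.1, 0.2]}}
))
records = read_diarization(client, "conv1")
print(records[0]["speaker_id"], records[0]["centroids"][0])
```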

Note that saving the centroids at every iteration is probably suboptimal, since they converge quickly and then barely change.
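One way to reduce this is to include the centroids only when they have moved noticeably since the last write. This is a minimal sketch, assuming a hypothetical helper `centroids_changed` and an arbitrarily chosen tolerance `tol`; segments would still be written every step, with the centroid field simply omitted when nothing changed.

```python
from typing import Dict, List, Optional

import numpy as np

Centroids = Dict[int, List[float]]


def centroids_changed(prev: Optional[Centroids], curr: Centroids, tol: float = 1e-3) -> bool:
    """Return True if speakers were added/removed or any centroid moved more than `tol`."""
    if prev is None or prev.keys() != curr.keys():
        return True
    return any(
        np.linalg.norm(np.asarray(curr[spk]) - np.asarray(prev[spk])) > tol
        for spk in curr
    )


# Usage: keep the last written snapshot and only write centroids when they change
last_written: Optional[Centroids] = None
snapshots = [
    {0: [1.0, 0.0]},
    {0: [1.0, 0.0001]},              # tiny drift: skipped
    {0: [0.9, 0.1]},                 # noticeable move: written
    {0: [0.9, 0.1], 1: [0.0, 1.0]},  # new speaker: written
]
written = []
for centroids in snapshots:
    if centroids_changed(last_written, centroids):
        written.append(centroids)
        last_written = centroids
```

Comparing against the last *written* snapshot, rather than the previous step, means slow drift still gets persisted once it accumulates past the tolerance.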

I would appreciate your feedback!

Hey @DmitriyG228! Thanks for this feature request, your implementation looks very nice! I would only change a few minor things. For example, I would prefer not to have a speaker ID mapping mechanism in the clustering block. If I'm not mistaken, speaker IDs are already numbered according to their centroid (e.g. speaker_0 == centroid 0). However, if we decide to include a mapping structure (I'm open to being persuaded on this, because I see some advantages), I'd prefer to put it in SpeakerDiarization as part of the pipeline state.
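For the cross-conversation part of the request, such a mapping structure could match each conversation's local centroids against centroids persisted from earlier conversations. The sketch below is a hypothetical `SpeakerRegistry`, not part of diart: it uses cosine distance with an arbitrarily chosen `max_distance` threshold, matches greedily (closest pairs first) rather than with an optimal assignment, and registers a new persistent identity when no stored centroid is close enough. Persisting the registry itself (e.g. in Redis) and updating stored centroids after a match are left out.

```python
from typing import Dict, List, Set

import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # Assumes non-zero vectors
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


class SpeakerRegistry:
    """Hypothetical store of persistent speaker centroids shared across conversations."""

    def __init__(self, max_distance: float = 0.3):
        self.max_distance = max_distance
        self.centroids: Dict[str, np.ndarray] = {}

    def match(self, local: Dict[int, List[float]]) -> Dict[int, str]:
        """Map local (per-conversation) speaker IDs to persistent identities."""
        # Candidate pairs sorted by distance, so the closest matches are claimed first
        pairs = sorted(
            (cosine_distance(np.asarray(vec), stored), spk, name)
            for spk, vec in local.items()
            for name, stored in self.centroids.items()
        )
        assignment: Dict[int, str] = {}
        used: Set[str] = set()
        for dist, spk, name in pairs:
            if dist <= self.max_distance and spk not in assignment and name not in used:
                assignment[spk] = name
                used.add(name)
        # Unmatched local speakers become new persistent identities
        for spk, vec in local.items():
            if spk not in assignment:
                name = f"person{len(self.centroids)}"
                self.centroids[name] = np.asarray(vec, dtype=float)
                assignment[spk] = name
        return assignment


registry = SpeakerRegistry(max_distance=0.3)
first = registry.match({0: [1.0, 0.0], 1: [0.0, 1.0]})
# Second conversation: local IDs are renumbered, so local speaker 0 here matches "person1"
second = registry.match({0: [0.1, 1.0], 1: [-1.0, 0.2]})
```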

Apart from that, I also really like the idea of a RedisWriter! Could you open a PR with your code so we can discuss the details there?

I would prefer not to add unnecessary dependencies to diart, so I would implement RedisWriter to raise an error if redis isn't installed. For an example of this, check the imports in models.py.

Thank you!

Hey @juanmc2005, thanks for your feedback! Please find the PR