File size: 5,295 Bytes
361ee5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import { pipeline, AutoProcessor, ClapAudioModelWithProjection } from '@xenova/transformers';

/**
 * Zero-shot audio tagger built on the CLAP (Contrastive Language-Audio
 * Pretraining) model via @xenova/transformers, running fully client-side.
 *
 * Usage: `const p = new CLAPProcessor(); const tags = await p.processAudio(buf);`
 * Models are lazy-loaded on first use and cached on the instance.
 */
class CLAPProcessor {
  constructor() {
    this.model = null;
    this.processor = null;
    // Candidate labels scored against the audio embedding; top matches are
    // returned from processAudio(). Extend this list to widen tag coverage.
    this.defaultLabels = [
      'speech', 'music', 'singing', 'guitar', 'piano', 'drums', 'violin',
      'trumpet', 'saxophone', 'flute', 'classical music', 'rock music',
      'pop music', 'jazz', 'electronic music', 'ambient', 'nature sounds',
      'rain', 'wind', 'ocean waves', 'birds chirping', 'dog barking',
      'cat meowing', 'car engine', 'traffic', 'footsteps', 'door closing',
      'applause', 'laughter', 'crying', 'coughing', 'sneezing',
      'telephone ringing', 'alarm clock', 'typing', 'water running',
      'fire crackling', 'thunder', 'helicopter', 'airplane', 'train',
      'motorcycle', 'bell ringing', 'whistle', 'horn', 'siren',
      'explosion', 'gunshot', 'silence', 'noise', 'distortion'
    ];
  }

  /**
   * Lazily loads the CLAP processor and model. Idempotent: subsequent calls
   * return immediately once both are cached.
   * @throws re-throws any model-download/initialization failure after logging.
   */
  async initialize() {
    if (this.model && this.processor) return;

    try {
      this.processor = await AutoProcessor.from_pretrained('Xenova/clap-htsat-unfused');
      this.model = await ClapAudioModelWithProjection.from_pretrained('Xenova/clap-htsat-unfused');

      console.log('CLAP model loaded successfully');
    } catch (error) {
      console.error('Failed to load CLAP model:', error);
      throw error;
    }
  }

  /**
   * Tags an audio clip with the best-matching labels from defaultLabels.
   * @param {AudioBuffer} audioBuffer - decoded Web Audio buffer (any rate/channels).
   * @returns {Promise<Array<{label: string, confidence: number}>>} top-5 tags,
   *   highest confidence first.
   * @throws re-throws any model inference failure after logging.
   */
  async processAudio(audioBuffer) {
    if (!this.model || !this.processor) {
      await this.initialize();
    }

    try {
      // Downmix to mono and resample to the 48 kHz rate CLAP expects.
      const audio = await this.preprocessAudio(audioBuffer);

      const audioInputs = await this.processor(audio);
      const audioFeatures = await this.model.get_audio_features(audioInputs);

      // NOTE(review): ClapAudioModelWithProjection is the audio-only tower;
      // text embeddings normally come from ClapTextModelWithProjection plus a
      // tokenizer. Verify that `processor.text` / `get_text_features` exist on
      // these objects in the pinned @xenova/transformers version.
      const textInputs = await this.processor.text(this.defaultLabels);
      const textFeatures = await this.model.get_text_features(textInputs);

      const similarities = await this.calculateSimilarities(audioFeatures, textFeatures);

      return this.getTopTags(similarities, 5);
    } catch (error) {
      console.error('Error processing audio:', error);
      throw error;
    }
  }

  /**
   * Converts an AudioBuffer to a mono Float32Array at 48 kHz.
   * @param {AudioBuffer} audioBuffer - source buffer.
   * @returns {Promise<Float32Array>} mono samples ready for the CLAP processor.
   *   For mono 48 kHz input this is the buffer's live channel view, not a copy.
   */
  async preprocessAudio(audioBuffer) {
    // Downmix to mono by averaging channels.
    let audioData;
    if (audioBuffer.numberOfChannels > 1) {
      audioData = new Float32Array(audioBuffer.length);
      for (let i = 0; i < audioBuffer.length; i++) {
        let sum = 0;
        for (let channel = 0; channel < audioBuffer.numberOfChannels; channel++) {
          sum += audioBuffer.getChannelData(channel)[i];
        }
        audioData[i] = sum / audioBuffer.numberOfChannels;
      }
    } else {
      audioData = audioBuffer.getChannelData(0);
    }

    // CLAP (clap-htsat-unfused) expects 48 kHz input.
    const targetSampleRate = 48000;
    if (audioBuffer.sampleRate !== targetSampleRate) {
      audioData = await this.resampleAudio(audioData, audioBuffer.sampleRate, targetSampleRate);
    }

    return audioData;
  }

  /**
   * Resamples audio with linear interpolation. Adequate for tagging; a
   * windowed-sinc filter would be needed for anti-aliased downsampling.
   * @param {Float32Array} audioData - source samples.
   * @param {number} originalRate - source sample rate in Hz.
   * @param {number} targetRate - desired sample rate in Hz.
   * @returns {Promise<Float32Array>} resampled samples.
   */
  async resampleAudio(audioData, originalRate, targetRate) {
    const ratio = originalRate / targetRate;
    const newLength = Math.round(audioData.length / ratio);
    const resampled = new Float32Array(newLength);
    const lastIndex = audioData.length - 1;

    for (let i = 0; i < newLength; i++) {
      const originalIndex = i * ratio;
      // Clamp both neighbors: i * ratio can reach audioData.length for the
      // final output samples, and an unclamped floor would read undefined
      // (yielding NaN in the output).
      const indexFloor = Math.min(Math.floor(originalIndex), lastIndex);
      const indexCeil = Math.min(indexFloor + 1, lastIndex);
      const fraction = originalIndex - indexFloor;

      resampled[i] = audioData[indexFloor] * (1 - fraction) + audioData[indexCeil] * fraction;
    }

    return resampled;
  }

  /**
   * Scores every default label against the audio embedding.
   * Assumes textFeatures.data is a flat row-major array of one embedding per
   * label, each the same length as the audio embedding — TODO confirm against
   * the model output layout.
   * @param {{data: Float32Array}} audioFeatures - audio embedding tensor.
   * @param {{data: Float32Array}} textFeatures - stacked text embedding tensor.
   * @returns {Promise<number[]>} cosine similarity per label, in label order.
   */
  async calculateSimilarities(audioFeatures, textFeatures) {
    const audioVector = audioFeatures.data;
    const dim = audioVector.length;
    const similarities = [];

    for (let i = 0; i < this.defaultLabels.length; i++) {
      const textVector = textFeatures.data.slice(i * dim, (i + 1) * dim);
      similarities.push(this.cosineSimilarity(audioVector, textVector));
    }

    return similarities;
  }

  /**
   * Cosine similarity of two equal-length vectors.
   * @param {ArrayLike<number>} vecA
   * @param {ArrayLike<number>} vecB
   * @returns {number} similarity in [-1, 1]; 0 if either vector is all zeros
   *   (avoids a 0/0 NaN for e.g. silent input, which would poison sorting).
   */
  cosineSimilarity(vecA, vecB) {
    let dotProduct = 0;
    let normA = 0;
    let normB = 0;

    for (let i = 0; i < vecA.length; i++) {
      dotProduct += vecA[i] * vecB[i];
      normA += vecA[i] * vecA[i];
      normB += vecB[i] * vecB[i];
    }

    const denominator = Math.sqrt(normA) * Math.sqrt(normB);
    return denominator === 0 ? 0 : dotProduct / denominator;
  }

  /**
   * Pairs labels with their scores and returns the strongest matches.
   * @param {number[]} similarities - one score per defaultLabels entry.
   * @param {number} [topK=5] - number of tags to return.
   * @returns {Array<{label: string, confidence: number}>} top tags, descending.
   */
  getTopTags(similarities, topK = 5) {
    const tagged = this.defaultLabels.map((label, index) => ({
      label,
      confidence: Math.max(0, similarities[index]) // clamp negatives to 0
    }));

    return tagged
      .sort((a, b) => b.confidence - a.confidence)
      .slice(0, topK);
  }

  /**
   * Decodes an audio File/Blob into an AudioBuffer.
   * @param {Blob} file - audio file from e.g. an <input type="file">.
   * @returns {Promise<AudioBuffer>} decoded audio.
   */
  async fileToAudioBuffer(file) {
    const arrayBuffer = await file.arrayBuffer();
    const audioContext = new (window.AudioContext || window.webkitAudioContext)();
    try {
      return await audioContext.decodeAudioData(arrayBuffer);
    } finally {
      // Browsers cap the number of live AudioContexts; release this one so
      // repeated calls don't exhaust the limit.
      await audioContext.close();
    }
  }
}

export default CLAPProcessor;