年始からお手軽ベクトル検索を作る

あけましておめでとうございます。

今年は、もうちょっとblogを書こうかなということで、三が日のうちに1つ出してみようと思います。

さて、2023年はLLMの利用と同時にベクトル検索が急に利用されるようになった年でした。 Retrieval-Augmented Generation(RAG)をみんな使い出したのと、OpenAI の embedding APIの性能が思った以上に良かったことが主な理由だと思います。

ベクトル検索は faiss でもchromaでも、qdrant でも何を使ってもよいと思いますが、numpy を使えば数行で実装できるし性能も悪くないことがわかったので書き残しておきます。

import numpy as np
class SimpleVecSearch():
    def add(self, ndarray):
        self._ndarray = ndarray

    def search(self, vec, topk=10):
        scores = np.matmul(self._ndarray, vec)
        idx = np.argsort(scores)[-1:-(topk+1):-1]
        return [scores[i].item() for i in idx], list(idx)

if __name__ == '__main__':
    def gen_vecs(n, dtype='float32'):
        tmp = np.random.rand(n, 1536).astype(dtype)
        return tmp / np.array([[i] for i in np.linalg.norm(tmp, axis=1)])

    raw_index = gen_vecs(10**5)
    simple_index = SimpleVecSearch()
    simple_index.add(raw_index)

    vec = np.random.rand(1536)
    d, i = mlx_index.search(vec, 1)

if 文以下はテスト用コードなので、実装自体は8行です。faiss に同じ入力をして結果が一致することも確認しています。インターフェースも faiss にそろえています。

性能比較

faiss vs numpy
クエリに一番近いベクトルを見つけるのに必要だった時間を、データ量毎に計測しました。100回の検索を行った平均値で単位は秒です。 総当り探索なので検索対象のデータ量(横軸のデータ数)が増えると線形に計算時間が増えます。n=1000までは numpy の方が速く、その後 faiss に抜かれますが実用の範囲だと思います。どのみち、query 用のベクトルを取得するのに OpenAI の API を呼び出すのに数10msecから数100msecかかるはずなので。 まあそれに誤差もデカいのであまり差が無いという以上のことが言える差では無さそうです。

テストコードの全体はこちらにあります。

背景

OpenAI の embedding API は、文章を 1536次元のベクトルに変換して返してくれます。 大雑把に言うと、このベクトルを比較して文章が似ているかどうかを判定し一番似ている文章を返すというのがベクトル検索のキモです。

似ている度合いの計算方法は色々ありますが、OpenAI は コサイン類似度の利用を推奨しています。 コサイン類似度の定義は以下のとおりです。

 \cos \theta =  \frac{\langle x,y\rangle}{||x|| ~||y||}

x, y はベクトルで、 \langle x,y\rangle は xとy の内積 ||x|| はベクトル x の大きさです。 そして、Open AI の embedding API はベクトルを大きさ1に正規化して返すので、コサイン類似度の分母は常に1になります。 ということは、ベクトルの内積を計算するだけで類似度が計算できます。

コサイン類似度はベクトルとベクトルの内積によって計算されますが、ベクトル検索で行いたいのは多数のベクトル(本稿ではindexと呼ぶ)から、あるベクトル(本稿ではqueryと呼ぶ)に一番近いベクトルを見つけることです。なので、index の方を行列として保持しておいて、行列とベクトルの掛け算で一気に類似度を計算してしまおうというのが最初のコードの内容です。

高速化

上記の通りindex にある全てのベクトルとqueyrのベクトル比較する総当りアルゴリズムなので、正確ですが速くはないはずです。 が、faiss の IndexFlatIP インデックスも同じ手法なので速度的にはあまり変わりません。

nが大きいとfaiss のほうがちょっと速いんだし、faiss でええやんという話ではあるので高速化を試みます。 ちょうど昨年末に appleMLX という apple silicon 向けの numpy 互換ライブラリを出しました。apple silicon の GPU を使って線形代数計算を高速に計算するものですので、これを使ってみます。 numpy を mlx.core に書き換えるだけなので、コードはほぼ変わりません。

import mlx.core as mx
class MLXVecSearch():
    def __init__(self, stream=mx.gpu):
        self._stream = stream

    def add(self, ndarray):
        self._ndarray = mx.array(ndarray)

    def search(self, _vec, topk=10):
        vec = mx.array(_vec)
        scores = mx.matmul(self._ndarray, vec, stream=self._stream)
        idx = mx.argsort(scores, stream=self._stream)[-1:-(topk+1):-1]
        return [scores[i].item() for i in idx], [i.item() for i in idx]

コンストラクタで stream を指定できるようにしていて、mx.gpu と mx.cpu を切り替えると GPU計算 CPU 計算を切り替えられます。 で、肝心の性能はこんな感じ

faiss vs numpy vs mlx

n=105 以上ではfaissを抜いて最速です。(ただし精度的な問題があるので後述します。) では、n=107 以上では?と思うところですが、メモリ不足のため私の環境では実験出来ませんでした。 これは faiss でも同じで、そこそこ巨大な行列をメモリ上に展開するので、メモリが確保できなくて落ちます。 OpenAI の返すベクトルの次元を大雑把に 2*103 次元と考えると、1つの次元を表すのに、float32 = 32bit = 4 Byte が必要ですから、ベクトル全体では 4 Byte * 2 * 103 で 8 kB です。 そのベクトルが 107 個有ると 8kB * 10 ^7 = 80 GB となるので、24GB しか積んでない私の MacBook Air M2 では確保できないのは道理ですね。

107 を超えるような文書数があるならば、もっと高尚なアルゴリズムを使うべきで、この資料の案内が頼りになります。

speakerdeck.com

実用的にはRAG に入れたい文書数が1000万文書もあるというのは相当稀なことで、あって数千というところなのではと想像するところです。 その程度であれば最初に掲載した numpy 実装でも十分なはずで、使う機会も有るかなと思います。

MLX の謎挙動

手軽に numpy を高速化してくれるので、MLX いいね!と言いたいところなのですが、微妙に計算結果が numpy と異なることがあります。 これは stream を GPU にした時だけに起こるもので、CPU を指定したときには起きません。追いかけると楽しいかもですが、現時点ではよく分からず。 精度の違いも疑ったのですが、numpy も MLX も float32 を利用しており精度の問題では無さそう。アーキテクチャが違う計算機なんだからこのぐらい誤差も有るでしょという話かもしれません。 RAG で使う分には問題にはならない気もしますが、numpy や faiss と一致しないことは覚えて置いたほうが良いでしょう。

計算が一致しない例の再現コード

matrix = [[0.58551717, 0.13656957],
                 [0.12464499, 0.09199308]]

np_matrix = np.array(matrix, dtype='float32')
mx_matrix = mx.array(matrix, dtype=mx.float32)

a = np.matmul(np_matrix[0],np_matrix[1]).item()
b = mx.matmul(mx_matrix[0], mx_matrix[1],stream=mx.gpu).item()
print(a) #  => 0.08554524183273315
print(b) #  => 0.08554523438215256

本稿の詰めが甘い部分

numpy が利用している BLAS ライブラリがもっと高速なら、numpy でも fiass に勝てるのでは? とか、ARM64 の SIMD 命令を直接呼べばもっと高速なのでは? など追いかけると面白そうなところは色々あるのですが、今回はお手軽に実装してみるのがテーマなので深掘りはやめておきます。

まとめ

  • 数行のコードで実用的なベクトル検索が実装出来ることを示しました
  • MLX は速いけど謎挙動があります。まだ0.0.6だしね。
  • ことしはもうちょい blog 書くぞ