Skip to content

Commit

Permalink
Try to verify data by checksum before building vocab
Browse files Browse the repository at this point in the history
  • Loading branch information
yethee committed Jul 17, 2024
1 parent 4317a32 commit 94a4ac1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
40 changes: 35 additions & 5 deletions src/Vocab/Loader/DefaultVocabLoader.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@
use function file_exists;
use function fopen;
use function hash_equals;
use function hash_file;
use function hash_final;
use function hash_init;
use function hash_update_file;
use function hash_update_stream;
use function is_dir;
use function is_resource;
use function is_writable;
use function mkdir;
use function rewind;
use function sha1;
use function sprintf;
use function stream_copy_to_stream;
use function stream_get_meta_data;

use const DIRECTORY_SEPARATOR;

Expand Down Expand Up @@ -59,6 +65,17 @@ public function load(string $uri, string|null $checksum = null): Vocab
}

try {
if ($checksum !== null && $this->isRewindable($stream)) {
if (! $this->checkHash($stream, $checksum)) {
throw new RuntimeException(sprintf(
'Checksum failed. Could not load vocab from URI: %s',
$uri,
));
}

rewind($stream);
}

if ($cacheFile !== null) {
$cacheStream = fopen($cacheFile, 'w+');

Expand All @@ -81,18 +98,31 @@ public function load(string $uri, string|null $checksum = null): Vocab
}
}

// NOTE(review): this span appears to be a rendered diff without +/- markers, so the
// removed implementation (string $filename / hash_file) and the added one
// ($resource / hash_init) are interleaved below — confirm against the repository.
private function checkHash(string $filename, string|null $expectedHash): bool
/**
 * Checks data against an expected SHA-256 digest.
 *
 * Returns true without hashing anything when $expectedHash is null.
 *
 * @param string|resource $resource Either a stream resource, which is fed to
 *                                  hash_update_stream(), or any other value, which is
 *                                  passed to hash_update_file() as a filename.
 */
private function checkHash($resource, string|null $expectedHash): bool
{
// No expected digest supplied: there is nothing to verify, so report success.
if ($expectedHash === null) {
return true;
}

$hash = hash_file('sha256', $filename);
// Incremental hashing context, so the stream and filename cases share one code path.
$ctx = hash_init('sha256');

if ($hash === false) {
return false;
// Stream resources are hashed directly; anything else is treated as a filename.
// NOTE(review): the return values of hash_update_stream()/hash_update_file() are
// not inspected here — confirm whether a failed read should be reported.
if (is_resource($resource)) {
hash_update_stream($ctx, $resource);
} else {
hash_update_file($ctx, $resource);
}

$hash = hash_final($ctx);

// Compared with hash_equals() rather than a plain equality operator.
return hash_equals($hash, $expectedHash);
}

/**
 * Reports the "seekable" flag from the stream's metadata.
 *
 * @param resource $stream Stream resource to inspect.
 */
private function isRewindable($stream): bool
{
    return stream_get_meta_data($stream)['seekable'];
}
}
12 changes: 12 additions & 0 deletions tests/Vocab/Loader/DefaultVocabLoaderTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use org\bovigo\vfs\vfsStream;
use org\bovigo\vfs\vfsStreamDirectory;
use PHPUnit\Framework\TestCase;
use RuntimeException;
use Yethee\Tiktoken\Vocab\Loader\DefaultVocabLoader;

use function dirname;
Expand Down Expand Up @@ -50,6 +51,17 @@ public function testInvalidateCacheWhenChecksumMismatch(): void
self::assertFileEquals($vocabUrl, $cacheFile->url());
}

/**
 * Loading a vocab with a checksum that does not match, using a loader built
 * without constructor arguments, must fail with a "Checksum failed" RuntimeException.
 */
public function testChecksumWhenNoCache(): void
{
    $vocabLoader = new DefaultVocabLoader();
    $fixturePath = dirname(__DIR__, 2) . '/Fixtures/p50k_base.tiktoken';

    // Digest of an unrelated string, so it will not match the fixture's contents.
    $mismatchingChecksum = hash('sha256', 'expected hash');

    // Expectations must be registered before the call that throws.
    $this->expectException(RuntimeException::class);
    $this->expectExceptionMessageMatches('/Checksum failed/');

    $vocabLoader->load($fixturePath, $mismatchingChecksum);
}

protected function setUp(): void
{
$this->cacheDir = vfsStream::setup('cache');
Expand Down

0 comments on commit 94a4ac1

Please sign in to comment.