main.py · 249 lines (197 loc) · 7.77 KB
"""Deep learning and backend code for dispict.
See the README for more information. You need a Modal account to run this, and
you also need to download the dataset using the notebooks in this repository
first. This program downloads and embeds 25,000 images from the Harvard Art
Museums, then hosts recommendations at a serverless web endpoint.
"""

import json
import os
import time
from dataclasses import dataclass
from typing import Optional

import numpy as np
import modal

app = modal.App("dispict")
app.image = modal.Image.debian_slim(python_version="3.11")


def load_clip():
    import clip

    return clip.load("ViT-B/32")


clip_image = (
    app.image.apt_install("git")
    .pip_install("ftfy", "regex", "tqdm", "numpy", "torch")
    .pip_install("git+https://github.com/openai/CLIP.git")
    .run_function(load_clip)
    .env({"INSIDE": "clip_image"})
)
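
# Added note: `.run_function(load_clip)` executes while the image is being built, so
# the ViT-B/32 weights are downloaded once and baked into `clip_image` instead of
# being fetched on every container start.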

web_image = app.image.pip_install("numpy", "h5py").env({"INSIDE": "web_image"})

if os.environ.get("INSIDE") == "clip_image":
    import clip
    import torch

    model, preprocess = load_clip()
    model.eval()


@app.function(
    image=clip_image,
    keep_warm=1,
)
def run_clip_text(texts: list[str]):
    """Run pretrained CLIP on a list of texts.

    Returns a numpy array containing the concatenated 512-dimensional embedding
    outputs for each provided input, evaluated as a batch.
    """
    text_tokens = clip.tokenize(texts, truncate=True)
    with torch.no_grad():
        return model.encode_text(text_tokens).float().numpy()
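
# Illustrative call (hedged sketch, not part of the original file): from a local
# script, the embedding for a query could be computed roughly like this:
#
#   with app.run():
#       vec = run_clip_text.remote(["a quiet seascape at dusk"])  # shape (1, 512)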

@app.function(
    image=clip_image,
    concurrency_limit=32,
)
def run_clip_images(image_urls: list[str]):
    """Run pretrained CLIP on a list of image URLs.

    Returns a numpy array containing the concatenated 512-dimensional embedding
    outputs for each provided input, evaluated as a batch.

    The first return value is a list of indices that had an error during fetch.
    """
    from io import BytesIO
    from concurrent.futures import ThreadPoolExecutor

    import requests
    from PIL import Image, UnidentifiedImageError

    def get_with_retry(url: str) -> requests.Response:
        request_num = 0
        while request_num < 5:
            try:
                resp = requests.get(url, timeout=8)
            except requests.exceptions.RequestException as exc:
                print("Retrying", url, "due to", exc)
                request_num += 1
                time.sleep(3.0)
                continue
            if resp.status_code not in (200, 404):
                print("Retrying", url, "due to status code", resp.status_code)
                request_num += 1
                time.sleep(0.1)
            else:
                return resp
        return resp  # type: ignore

    with ThreadPoolExecutor(max_workers=10) as executor:
        responses = list(executor.map(get_with_retry, image_urls))

    missing_indices: list[int] = []
    original_images: list[Image.Image] = []
    for i, resp in enumerate(responses):
        if resp.status_code != 200:
            print(f"Received status code {resp.status_code} from URL:", image_urls[i])
            missing_indices.append(i)
            continue
        try:
            original_images.append(Image.open(BytesIO(resp.content)))
        except UnidentifiedImageError:
            print("Failed to load image from URL:", image_urls[i])
            missing_indices.append(i)

    images: list[torch.Tensor] = [preprocess(img) for img in original_images]  # type: ignore
    image_input = torch.stack(images)
    with torch.no_grad():
        return missing_indices, model.encode_image(image_input).float().numpy()
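
# Added note: the returned matrix has one 512-dimensional row per successfully fetched
# URL; entries listed in `missing_indices` have no row at all, which is why
# `embed_images` below re-aligns artwork ids against the missing list.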

@dataclass
class Artwork:
    id: int
    objectnumber: str
    url: str
    image_url: str
    dimensions: str
    dimheight: float
    dimwidth: float

    title: Optional[str]  # plaintext title
    description: Optional[str]  # plaintext description
    labeltext: Optional[str]  # optional label text
    people: list[str]  # information about artists

    dated: str  # "c. 1950" or "1967-68" or "18th century"
    datebegin: int  # numerical year or 0
    dateend: int  # numerical year or 0
    century: Optional[str]  # alternative to "dated" column

    department: str  # categorical, about a dozen departments
    division: Optional[str]  # modern, european/american, or asian/mediterranean
    culture: Optional[str]  # American, Dutch, German, ...
    classification: str  # Photographs or Prints or ...
    technique: Optional[str]  # Lithograph, Etching, Gelatin silver print, ...
    medium: Optional[str]  # Graphite on paper, Oil on canvas, ...

    accessionyear: Optional[int]  # when the item was added
    verificationlevel: int  # How verified a work is(?)
    totaluniquepageviews: int  # a proxy for popularity
    totalpageviews: int  # a proxy for popularity
    copyright: Optional[str]  # copyright status
    creditline: str  # who donated this artwork

def read_data(filename: str) -> list[Artwork]:
    with open(filename, "r") as f:
        return [Artwork(**row) for row in json.load(f)]


def read_embeddings(filename: str):
    import h5py

    print("Loading embeddings")
    with h5py.File(filename, "r") as f:
        ids: h5py.Dataset = f["ids"]  # type: ignore
        matrix: h5py.Dataset = f["embeddings"]  # type: ignore
        embeddings = np.array(matrix)
        embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings_ids = list(ids)
    print("Finished loading embeddings")
    return embeddings, embeddings_ids
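
# Added note: the rows are L2-normalized above, and the query vector is normalized in
# `suggestions`, so the `embeddings @ features` products are cosine similarities in
# [-1, 1]; the endpoint rescales them to a roughly 0-100 score via 50 * (1 + cos).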

@dataclass
class SearchResult:
    score: float
    artwork: Artwork


if os.environ.get("INSIDE") == "web_image":
    data = read_data("/data/catalog.json")
    data_by_id: dict[int, Artwork] = {}
    for row in data:
        data_by_id[row.id] = row
    embeddings, embeddings_ids = read_embeddings("/data/embeddings.hdf5")

if not os.environ.get("SKIP_WEB"):

    @app.function(
        image=web_image,
        mounts=[
            modal.Mount.from_local_file(
                "data/artmuseums-clean.json", "/data/catalog.json"
            ),
            modal.Mount.from_local_file(
                "data/embeddings.hdf5", "/data/embeddings.hdf5"
            ),
        ],
        keep_warm=1,
    )
    @modal.web_endpoint()
    def suggestions(text: str, n: int = 50) -> list:
        """Return a list of artworks that are similar to the given text."""
        features = run_clip_text.remote([text])[0, :]
        features /= np.linalg.norm(features)
        scores = embeddings @ features
        index_array = np.argsort(scores)
        return [
            SearchResult(
                score=50 * float(1 + scores[i]), artwork=data_by_id[embeddings_ids[i]]
            )
            for i in reversed(index_array[-n:])
        ]
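
# Illustrative request (hedged; the hostname is a placeholder, since Modal assigns the
# real endpoint URL at deploy time):
#
#   curl "https://<workspace>--dispict-suggestions.modal.run/?text=impressionist+garden&n=5"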

@app.local_entrypoint()
def embed_images():
    import h5py

    data = read_data("data/artmuseums-clean.json")

    chunk_size = 24
    chunked_ids = []
    chunked_urls = []
    for idx in range(0, len(data), chunk_size):
        chunked_ids.append([row.id for row in data[idx : idx + chunk_size]])
        chunked_urls.append([row.image_url for row in data[idx : idx + chunk_size]])

    results = list(run_clip_images.map(chunked_urls))

    all_embeddings: dict[int, np.ndarray] = {}
    for ids, (missing, embeddings) in zip(chunked_ids, results):
        assert len(ids) == len(missing) + len(embeddings)
        embeddings_idx = 0
        for i, id in enumerate(ids):
            if i not in missing:
                all_embeddings[id] = embeddings[embeddings_idx, :]
                embeddings_idx += 1

    print(f"Finished embedding {len(all_embeddings)} images out of {len(data)}")
    ids, embedding_matrix = zip(*all_embeddings.items())
    embedding_matrix = np.vstack(embedding_matrix)
    with h5py.File("data/embeddings.hdf5", "w") as f:
        f.create_dataset("embeddings", data=embedding_matrix)
        f.create_dataset("ids", data=ids)
    print("Saved to hdf5 file")