Natknąłem się na dokładnie ten sam problem, a człowieku, to była królicza nora. Chciałem opublikować moje rozwiązanie tutaj, ponieważ może zaoszczędzić komuś dzień pracy:
Struktury danych TensorFlow specyficzne dla wątków
W TensorFlow istnieją dwie kluczowe struktury danych, które działają za kulisami po wywołaniu model.predict
(lub keras.models.load_model
lub keras.backend.clear_session
lub prawie jakakolwiek inna funkcja współpracująca z backendem TensorFlow):
- Wykres TensorFlow, który przedstawia strukturę Twojego modelu Keras
- Sesja TensorFlow, która jest połączeniem między bieżącym wykresem a środowiskiem wykonawczym TensorFlow
Coś, co nie jest wyraźnie jasne w dokumentach bez trochę kopania, to to, że zarówno sesja, jak i wykres są właściwościami bieżącego wątku . Zobacz dokumentację API tutaj i tutaj.
Korzystanie z modeli TensorFlow w różnych wątkach
To naturalne, że chcesz załadować swój model raz, a następnie wywołać .predict()
na nim wiele razy później:
from keras.models import load_model
MY_MODEL = load_model('path/to/model/file')
def some_worker_function(inputs):
return MY_MODEL.predict(inputs)
W kontekście serwera WWW lub puli procesów roboczych, takich jak Celery, oznacza to, że załadujesz model podczas importowania modułu zawierającego load_model
linii, wtedy inny wątek wykona some_worker_function
, uruchamiając prognozę na zmiennej globalnej zawierającej model Keras. Jednak próba uruchomienia predykcji na modelu załadowanym w innym wątku powoduje błędy typu „tensor nie jest elementem tego wykresu”. Dzięki kilku postom SO, które poruszały ten temat, np. ValueError:Tensor Tensor(...) nie jest elementem tego wykresu. Podczas korzystania z globalnego modelu keras. Aby to zadziałało, musisz trzymać się używanego wykresu TensorFlow — jak widzieliśmy wcześniej, wykres jest własnością bieżącego wątku. Zaktualizowany kod wygląda tak:
from keras.models import load_model
import tensorflow as tf
MY_MODEL = load_model('path/to/model/file')
MY_GRAPH = tf.get_default_graph()
def some_worker_function(inputs):
with MY_GRAPH.as_default():
return MY_MODEL.predict(inputs)
Nieco zaskakującym zwrotem jest tutaj:powyższy kod jest wystarczający, jeśli używasz Thread
s, ale zawiesza się na czas nieokreślony, jeśli używasz Process
es. Domyślnie Celery używa procesów do zarządzania wszystkimi swoimi pulami pracowników. W tym momencie wszystko jest wciąż nie działa na selera.
Dlaczego to działa tylko w Thread
? tak?
W Pythonie Thread
s współdzielą ten sam globalny kontekst wykonania, co proces nadrzędny. Z dokumentacji _wątku Pythona:
Ten moduł zapewnia prymitywy niskiego poziomu do pracy z wieloma wątkami (zwanymi również lekkimi procesami lub zadaniami) — wiele wątków kontroli współdzielących swoją globalną przestrzeń danych.
Ponieważ wątki nie są faktycznie oddzielnymi procesami, używają tego samego interpretera Pythona i dlatego podlegają niesławnej Global Interpeter Lock (GIL). Być może, co ważniejsze dla tego dochodzenia, udostępniają globalna przestrzeń danych z rodzicem.
W przeciwieństwie do tego, Process
es są rzeczywiste nowe procesy zrodzone przez program. Oznacza to:
- Nowa instancja interpretera Pythona (i bez GIL)
- Globalna przestrzeń adresowa jest zduplikowana
Zwróć uwagę na różnicę tutaj. Podczas gdy Thread
mają dostęp do wspólnej pojedynczej globalnej zmiennej sesji (przechowywanej wewnętrznie w tensorflow_backend
moduł Keras), Process
mają duplikaty zmiennej sesji.
Moim najlepszym zrozumieniem tego problemu jest to, że zmienna sesji ma reprezentować unikalne połączenie między klientem (procesem) a środowiskiem wykonawczym TensorFlow, ale z powodu duplikacji w procesie rozwidlenia te informacje o połączeniu nie są odpowiednio dostosowywane. Powoduje to zawieszenie TensorFlow podczas próby użycia sesji utworzonej w innym procesie. Jeśli ktoś ma większy wgląd w to, jak to działa pod maską w TensorFlow, chciałbym to usłyszeć!
Rozwiązanie/obejście
Poszedłem z dostosowaniem selera tak, aby używał Thread
s zamiast Process
es do łączenia. To podejście ma pewne wady (patrz komentarz GIL powyżej), ale pozwala nam to załadować model tylko raz. I tak tak naprawdę nie jesteśmy ograniczeni procesorem, ponieważ środowisko wykonawcze TensorFlow maksymalnie wykorzystuje wszystkie rdzenie procesora (może ominąć GIL, ponieważ nie jest napisane w Pythonie). Musisz dostarczyć Celery z osobną biblioteką do tworzenia puli opartej na wątkach; dokumentacja sugeruje dwie opcje:gevent
lub eventlet
. Następnie przekazujesz wybraną bibliotekę pracownikowi przez --pool
argument wiersza poleceń.
Ewentualnie wydaje się (jak już dowiedziałeś się @pX0r), że inne backendy Keras, takie jak Theano, nie mają tego problemu. Ma to sens, ponieważ kwestie te są ściśle powiązane ze szczegółami implementacji TensorFlow. Osobiście nie próbowałem jeszcze Theano, więc Twój przebieg może się różnić.
Wiem, że to pytanie pojawiło się jakiś czas temu, ale problem nadal istnieje, więc miejmy nadzieję, że to komuś pomoże!