Fix passing of device to setup method in /reload, /regenerate API

- Use local variable to pass device to asymmetric.setup method via /reload, /regenerate API
- Set default argument to torch.device('cpu') instead of 'cpu' to be more formal
This commit is contained in:
Debanjum Singh Solanky
2022-06-30 01:32:56 +04:00
parent eda4b65ddb
commit 7677465f23
2 changed files with 10 additions and 5 deletions

View File

@@ -25,7 +25,6 @@ processor_config = ProcessorConfigModel()
config_file = ""
verbose = 0
app = FastAPI()
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
app.mount("/views", StaticFiles(directory="views"), name="views")
templates = Jinja2Templates(directory="views/")
@@ -53,6 +52,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
print(f'No query param (q) passed in API call to initiate search')
return {}
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
user_query = q
results_count = n
@@ -95,6 +95,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
@app.get('/reload')
def regenerate(t: Optional[SearchType] = None):
global model
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = initialize_search(config, regenerate=False, t=t, device=device)
return {'status': 'ok', 'message': 'reload completed'}
@@ -102,6 +103,7 @@ def regenerate(t: Optional[SearchType] = None):
@app.get('/regenerate')
def regenerate(t: Optional[SearchType] = None):
global model
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = initialize_search(config, regenerate=True, t=t, device=device)
return {'status': 'ok', 'message': 'regeneration completed'}
@@ -147,7 +149,7 @@ def chat(q: str):
return {'status': 'ok', 'response': gpt_response}
def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None):
def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None, device=torch.device("cpu")):
# Initialize Org Notes Search
if (t == SearchType.Notes or t == None) and config.content_type.org:
# Extract Entries, Generate Notes Embeddings
@@ -241,8 +243,11 @@ if __name__ == '__main__':
# Store the raw config data.
config = args.config
# Set device to GPU if available
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Initialize the search model from Config
model = initialize_search(args.config, args.regenerate)
model = initialize_search(args.config, args.regenerate, device=device)
# Initialize Processor from Config
processor_config = initialize_processor(args.config)