From 4897bd096d23930b0e84a664c16fee3b2347309c Mon Sep 17 00:00:00 2001 From: Timofey Malinin Date: Thu, 5 Oct 2023 12:57:56 +0000 Subject: [PATCH] Update tasks.py --- service/tasks.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 56 insertions(+), 2 deletions(-) diff --git a/service/tasks.py b/service/tasks.py index 8fb58da..729eec9 100644 --- a/service/tasks.py +++ b/service/tasks.py @@ -4,6 +4,7 @@ import catboost import geopandas as gpd import numpy as np import pandas as pd +import shap import psycopg2 import sqlalchemy from celery import shared_task @@ -189,11 +190,20 @@ def raschet(table_name='service_placementpoint', need_time=True): status.status = 'Обучение inference 100%' current_pred = sum(y_infers) / 5 + # расчет шапов + explainer = shap.TreeExplainer(model) + shap_values = explainer(X_inf.drop(columns=['id'])) + shap_fields = pd.DataFrame(shap_values.values) + shap_fields.columns = X_inf.drop(columns=['id']).columns + '_shap' + shap_fields = shap_fields.drop(columns = ['age_day_shap']) + shap_fields['sum'] = abs(shap_fields).sum(axis=1) + shap_fields = round(shap_fields.iloc[:,:32].div(shap_fields['sum'], axis=0)*100, 2) + # Обновление полей по результатам работы модели update_fields = pts_inf[ [ 'id', 'age_day_init', 'status', 'fact', 'delta_current', 'delta_first', 'plan_current', 'plan_first', - 'prediction_first', + 'prediction_first', 'target_post_cnt', 'target_dist' ] ] update_fields = update_fields.join( @@ -252,8 +262,52 @@ def raschet(table_name='service_placementpoint', need_time=True): conn2 = None log_to_telegram('Не удалось подключиться к базе данных') - # prediction_current if conn2 is not None: + # апдейт шапов + update_fields_shap = pd.concat([shap_fields, update_fields[['id']]], axis=1) + update_records0 = [] + for i in range(0, len(update_fields_shap)): + update_records1 = [] + for n in list(update_fields_shap): + update_records1.append(int(update_fields_shap[n][i])) + update_records0.append(tuple(update_records1)) + shap_fields_name = str(list(shap_fields))[1:-1].replace("'", "").replace(',', '=%s,') + sql_update_query = f"""Update {table_name} set {shap_fields_name} = %s where id = %s""" + try: + psycopg2.extras.execute_batch(cursor, sql_update_query, update_records0) + conn2.commit() + except: + cursor.execute("ROLLBACK") + psycopg2.extras.execute_batch(cursor, sql_update_query, update_records0) + conn2.commit() + + # target_post_cnt + update_records1 = [] + for i in range(0, len(update_fields)): + update_records1.append((int(update_fields.target_post_cnt[i]), int(update_fields.id[i]))) + sql_update_query = f"""Update {table_name} set target_post_cnt = %s where id = %s""" + try: + psycopg2.extras.execute_batch(cursor, sql_update_query, update_records1) + conn2.commit() + except: + cursor.execute("ROLLBACK") + psycopg2.extras.execute_batch(cursor, sql_update_query, update_records1) + conn2.commit() + + # target_dist + update_records1 = [] + for i in range(0, len(update_fields)): + update_records1.append((int(update_fields.target_dist[i]), int(update_fields.id[i]))) + sql_update_query = f"""Update {table_name} set target_dist = %s where id = %s""" + try: + psycopg2.extras.execute_batch(cursor, sql_update_query, update_records1) + conn2.commit() + except: + cursor.execute("ROLLBACK") + psycopg2.extras.execute_batch(cursor, sql_update_query, update_records1) + conn2.commit() + + # prediction_current update_records1 = [] for i in range(0, len(update_fields)): update_records1.append((int(update_fields.prediction_current[i]), int(update_fields.id[i])))