diff --git a/service/tasks.py b/service/tasks.py index cc0e4e6..fd316cd 100644 --- a/service/tasks.py +++ b/service/tasks.py @@ -20,6 +20,7 @@ from service.models import PlacementPoint # Запустить scheduler # celery -A postamates beat -l INFO. + @shared_task() def raschet(): conn = sqlalchemy.create_engine( @@ -27,11 +28,26 @@ def raschet(): connect_args={'options': '-csearch_path=public'}, ) query = text('select * from service_placementpoint') - pts = pd.read_sql(query, conn.connect()) - feats = ['id', 'popul_home', 'popul_job', 'other_post_cnt', 'yndxfood_sum', 'target_post_cnt'] + connection = conn.connect() + pts = pd.read_sql(query, connection) + pts.loc[pts.target_dist > 700, 'target_dist'] = 700 + pts = pts.sort_values(by='id').reset_index(drop=True) + + feats = [ + 'id', 'metro_dist', 'target_dist', 'property_price_bargains', 'property_price_offers', + 'property_mean_floor', + 'property_era', 'flats_cnt_2', 'flats_cnt', 'popul_home', 'popul_job', 'other_post_cnt', 'yndxfood_sum', + 'yndxfood_cnt', 'school_cnt', 'kindergar_cnt', 'target_post_cnt', 'public_stop_cnt', 'sport_center_cnt', + 'pharmacy_cnt', 'supermarket_cnt', 'supermarket_premium_cnt', 'clinic_cnt', 'bank_cnt', 'reca_cnt', + 'lab_cnt', 'culture_cnt', 'attraction_cnt', 'mfc_cnt', 'bc_cnt', 'tc_cnt', 'rival_pvz_cnt', + 'rival_post_cnt', + 'business_activity', 'age_day', 'target_age_nearby_mean', 'target_cnt_ao_mean', + # 'target_cnt_nearby_mean' + ] # Записи для обучения pts_trn = pts.loc[pts.sample_trn == True].reset_index(drop=True) + # pts_trn = pts_trn.loc[pts_trn.fact < 450].reset_index(drop=True) X_trn = pts_trn[feats].drop(columns=['id']) Y_trn = pts_trn[['fact']] @@ -39,24 +55,27 @@ def raschet(): pts_inf = pts.loc[(pts.status == 'Pending') | (pts.status == 'Installation') | (pts.status == 'Cancelled') | - ((pts.status == 'Pending') & (pts.sample_trn == False))].reset_index(drop=True) + ((pts.status == 'Working') & (pts.sample_trn == False))].reset_index(drop=True) + pts_inf['age_day'] = 240 X_inf = pts_inf[feats] + seeds = [39, 85, 15, 1, 59] # Обучение, инференс r2_scores = [] mapes = [] y_infers = [] - while len(r2_scores) < 5: - x_trn, x_test, y_trn, y_test = ms.train_test_split(X_trn, Y_trn, test_size=0.2) - model = catboost.CatBoostRegressor(cat_features=[]) + for i in seeds: + x_trn, x_test, y_trn, y_test = ms.train_test_split(X_trn, Y_trn, test_size=0.2, random_state=i) + model = catboost.CatBoostRegressor(cat_features=['property_era'], random_state=i) model.fit(x_trn, y_trn, verbose=False) r2_score = metrics.r2_score(y_test, model.predict(x_test)) mape = metrics.mean_absolute_percentage_error(y_test, model.predict(x_test)) - if ((r2_score > 0) & (mape < 0.5)): + if ((r2_score > 0.45) & (mape < 0.25)): r2_scores.append(r2_score) mapes.append(mape) y_infers.append(model.predict(X_inf.drop(columns=['id']))) + current_pred = sum(y_infers) / 5 # Обновление полей по результатам работы модели @@ -65,7 +84,7 @@ def raschet(): pd.concat( [ X_inf[['id']], - pd.DataFrame({'prediction_current': current_pred}), + pd.DataFrame([{'prediction_current': current_pred}]), ], axis=1, ).set_index('id'), @@ -86,8 +105,13 @@ def raschet(): update_records1.append((int(update_fields.prediction_current[i]), int(update_fields.id[i]))) sql_update_query = """Update service_placementpoint set prediction_current = %s where id = %s""" - cursor.executemany(sql_update_query, update_records1) - conn2.commit() + try: + cursor.executemany(sql_update_query, update_records1) + conn2.commit() + except Exception: + cursor.execute('ROLLBACK') + cursor.executemany(sql_update_query, update_records1) + conn2.commit() @shared_task()