使用Python部署机器学习模型的10个实践经验

 有时候,作为数据科学家,我们会忘记公司付钱让我们干什么。我们首先是开发人员,然后是研究人员,然后可能是数学家。我们的首要责任是快速开发无bug的解决方案。

专注于为中小企业提供网站设计制作、网站制作服务,电脑端+手机端+微信端的三站合一,更高效的管理,为中小企业嘉祥免费做网站提供优质的服务。我们立足成都,凝聚了一批互联网行业人才,有力地推动了上1000+企业的稳健成长,帮助中小企业通过网站建设实现规模扩充和转变。

我们能做模型并不意味着我们就是神。它没有给我们写垃圾代码的自由。

从一开始,我就犯了很多错误,我想和大家分享一下我所看到的ML工程中最常见的技能。在我看来,这也是目前这个行业最缺乏的技能。

我称他们为“软件文盲”,因为他们中的很多人都是非计算机科学课程学习平台(Coursera)的工程师。我自己曾经就是

如果要在一个伟大的数据科学家和一个伟大的ML工程师之间招聘,我会选择后者。让我们开始吧。

1. 学会写抽象类

一旦你开始编写抽象类,你就会知道它能给你的代码库带来多大的清晰度。它们执行相同的方法和方法名称。如果很多人都在同一个项目上工作,每个人都会开始使用不同的方法。这可能会造成无效率的混乱。

 
 
 
 
  1. import os 
  2. from abc import ABCMeta, abstractmethod 
  3.  
  4.  
  5. class DataProcessor(metaclass=ABCMeta): 
  6.     """Base processor to be used for all preparation.""" 
  7.     def __init__(self, input_directory, output_directory): 
  8.         self.input_directory = input_directory 
  9.         self.output_directory = output_directory 
  10.  
  11.     @abstractmethod 
  12.     def read(self): 
  13.         """Read raw data.""" 
  14.  
  15.     @abstractmethod 
  16.     def process(self): 
  17.         """Processes raw data. This step should create the raw dataframe with all the required features. Shouldn't implement statistical or text cleaning.""" 
  18.  
  19.     @abstractmethod 
  20.     def save(self): 
  21.         """Saves processed data.""" 
  22.  
  23.  
  24. class Trainer(metaclass=ABCMeta): 
  25.     """Base trainer to be used for all models.""" 
  26.  
  27.     def __init__(self, directory): 
  28.         self.directory = directory 
  29.         self.model_directory = os.path.join(directory, 'models') 
  30.  
  31.     @abstractmethod 
  32.     def preprocess(self): 
  33.         """This takes the preprocessed data and returns clean data. This is more about statistical or text cleaning.""" 
  34.  
  35.     @abstractmethod 
  36.     def set_model(self): 
  37.         """Define model here.""" 
  38.  
  39.     @abstractmethod 
  40.     def fit_model(self): 
  41.         """This takes the vectorised data and returns a trained model.""" 
  42.  
  43.     @abstractmethod 
  44.     def generate_metrics(self): 
  45.         """Generates metric with trained model and test data.""" 
  46.  
  47.     @abstractmethod 
  48.     def save_model(self, model_name): 
  49.         """This method saves the model in our required format.""" 
  50.  
  51.  
  52. class Predict(metaclass=ABCMeta): 
  53.     """Base predictor to be used for all models.""" 
  54.  
  55.     def __init__(self, directory): 
  56.         self.directory = directory 
  57.         self.model_directory = os.path.join(directory, 'models') 
  58.  
  59.     @abstractmethod 
  60.     def load_model(self): 
  61.         """Load model here.""" 
  62.  
  63.     @abstractmethod 
  64.     def preprocess(self): 
  65.         """This takes the raw data and returns clean data for prediction.""" 
  66.  
  67.     @abstractmethod 
  68.     def predict(self): 
  69.         """This is used for prediction.""" 
  70.  
  71.  
  72. class BaseDB(metaclass=ABCMeta): 
  73.     """ Base database class to be used for all DB connectors.""" 
  74.     @abstractmethod 
  75.     def get_connection(self): 
  76.         """This creates a new DB connection.""" 
  77.     @abstractmethod 
  78.     def close_connection(self): 
  79.         """This closes the DB connection.""" 

2. 在最前面设置你的随机数种子

实验的可重复性是非常重要的,而种子是我们的敌人。抓住它,否则会导致不同的训练/测试数据分割和不同的权值初始化神经网络。这导致了不一致的结果。

 
 
 
 
  1. def set_seed(args): 
  2.     random.seed(args.seed) 
  3.     np.random.seed(args.seed) 
  4.     torch.manual_seed(args.seed) 
  5.     if args.n_gpu > 0: 
  6.         torch.cuda.manual_seed_all(args.seed) 

3. 从几行数据开始

如果你的数据太大,而你的工作是代码的后面的部分,如清理数据或建模,那么可以使用nrows来避免每次加载巨大的数据。当你只想测试代码而不实际运行整个代码时,请使用此方法。

当你的本地PC配置无法加载所有的数据的时候,但你又喜欢在本地开发时,这是非常适用的,

 
 
 
 
  1. df_train = pd.read_csv(‘train.csv’, nrows=1000) 

4. 预见失败(成熟开发人员的标志)

一定要检查数据中的NA,因为这些会给你以后带来问题。即使你当前的数据没有,这并不意味着它不会在未来的再训练循环中发生。所以无论如何继续检查。

 
 
 
 
  1. print(len(df)) 
  2. df.isna().sum() 
  3. df.dropna() 
  4. print(len(df)) 

5. 显示处理进度

当你在处理大数据时,知道它将花费多少时间以及我们在整个处理过程中的位置肯定会让你感觉很好。

选项 1 — tqdm

 
 
 
 
  1. from tqdm import tqdm 
  2. import time 
  3.  
  4. tqdm.pandas() 
  5.  
  6. df['col'] = df['col'].progress_apply(lambda x: x**2) 
  7.  
  8. text = "" 
  9. for char in tqdm(["a", "b", "c", "d"]): 
  10.     time.sleep(0.25) 
  11.     text = text + char 

选项 2 — fastprogress

 
 
 
 
  1. from fastprogress.fastprogress import master_bar, progress_bar 
  2. from time import sleep 
  3. mb = master_bar(range(10)) 
  4. for i in mb: 
  5.     for j in progress_bar(range(100), parent=mb): 
  6.         sleep(0.01) 
  7.         mb.child.comment = f'second bar stat' 
  8.     mb.first_bar.comment = f'first bar stat' 
  9.     mb.write(f'Finished loop {i}.') 

6. Pandas很慢

如果你使用过pandas,你就会知道有时它有多慢 —— 尤其是groupby。不用打破头寻找“伟大的”解决方案加速,只需使用modin改变一行代码就可以了。

 
 
 
 
  1. import modin.pandas as pd 

7. 统计函数的时间

不是所有的函数都是生而平等的

即使整个代码都能工作,也不意味着你写的代码很棒。一些软件bug实际上会使你的代码变慢,所以有必要找到它们。使用这个装饰器来记录函数的时间。

 
 
 
 
  1. import time 
  2.  
  3.  
  4. def timing(f): 
  5.     """Decorator for timing functions 
  6.     Usage: 
  7.     @timing 
  8.     def function(a): 
  9.         pass 
  10.     """ 
  11.  
  12.     @wraps(f) 
  13.     def wrapper(*args, **kwargs): 
  14.         start = time.time() 
  15.         result = f(*args, **kwargs) 
  16.         end = time.time() 
  17.         print('function:%r took: %2.2f sec' % (f.__name__,  end - start)) 
  18.         return result 
  19.     return wrapper 

8. 不要在云上烧钱

没有人喜欢浪费云资源的工程师。

我们的一些实验可以持续几个小时。很难跟踪它并在它完成时关闭云实例。我自己也犯过错误,也见过有人把实例开了好几天。

这种情况发生在星期五,离开后,周一才意识到

只要在执行结束时调用这个函数,你的屁股就再也不会着火了!!

但是将主代码包装在try中,此方法也包装在except中 —— 这样如果发生错误,服务器就不会继续运行。是的,我也处理过这些情况

让我们更负责任一点,不要产生二氧化碳。

 
 
 
 
  1. import os 
  2.  
  3. def run_command(cmd): 
  4.     return os.system(cmd) 
  5.      
  6. def shutdown(seconds=0, os='linux'): 
  7.     """Shutdown system after seconds given. Useful for shutting EC2 to save costs.""" 
  8.     if os == 'linux': 
  9.         run_command('sudo shutdown -h -t sec %s' % seconds) 
  10.     elif os == 'windows': 
  11.         run_command('shutdown -s -t %s' % seconds) 

9. 创建和保存报告

在建模的某个特定点之后,所有伟大的见解都只来自错误和度量分析。确保为自己和你的管理层创建和保存格式良好的报告。

管理层喜欢报告,对吗?

 
 
 
 
  1. import json 
  2. import os 
  3.  
  4. from sklearn.metrics import (accuracy_score, classification_report, 
  5.                              confusion_matrix, f1_score, fbeta_score) 
  6.  
  7. def get_metrics(y, y_pred, beta=2, average_method='macro', y_encoder=None): 
  8.     if y_encoder: 
  9.         y = y_encoder.inverse_transform(y) 
  10.         y_pred = y_encoder.inverse_transform(y_pred) 
  11.     return { 
  12.         'accuracy': round(accuracy_score(y, y_pred), 4), 
  13.         'f1_score_macro': round(f1_score(y, y_pred, average=average_method), 4), 
  14.         'fbeta_score_macro': round(fbeta_score(y, y_pred, beta, average=average_method), 4), 
  15.         'report': classification_report(y, y_pred, output_dict=True), 
  16.         'report_csv': classification_report(y, y_pred, output_dict=False).replace('\n','\r\n') 
  17.     } 
  18.  
  19.  
  20. def save_metrics(metrics: dict, model_directory, file_name): 
  21.     path = os.path.join(model_directory, file_name + '_report.txt') 
  22.     classification_report_to_csv(metrics['report_csv'], path) 
  23.     metrics.pop('report_csv') 
  24.     path = os.path.join(model_directory, file_name + '_metrics.json') 
  25.     json.dump(metrics, open(path, 'w'), indent=4) 

10. 写好APIs

结果不好就是不好。

你可以进行很好的数据清理和建模,但最终仍可能造成巨大的混乱。我与人打交道的经验告诉我,许多人不清楚如何编写好的api、文档和服务器设置。我很快会写另一篇关于这个的文章,但是让我开始吧。

下面是在不太高的负载下(比如1000/min)部署经典的ML和DL的好方法。

fasbut + uvicorn

  • Fastest — 使用fastapi编写API,因为它很快。
  • Documentation — 用fastapi写API让我们不用操心文档。
  • Workers — 使用uvicorn部署API

使用4个worker运行这些命令进行部署。通过负载测试优化workers的数量。

 
 
 
 
  1. pip install fastapi uvicorn 
  2. uvicorn main:app --workers 4 --host 0.0.0.0 --port 8000 

当前名称:使用Python部署机器学习模型的10个实践经验
网页地址:http://www.shufengxianlan.com/qtweb/news31/126631.html

网站建设、网络推广公司-创新互联,是专注品牌与效果的网站制作,网络营销seo公司;服务项目有等

广告

声明:本网站发布的内容(图片、视频和文字)以用户投稿、用户转载内容为主,如果涉及侵权请尽快告知,我们将会在第一时间删除。文章观点不代表本网站立场,如需处理请联系客服。电话:028-86922220;邮箱:631063699@qq.com。内容未经允许不得转载,或转载时需注明来源: 创新互联