import copy
import json
import os
from datetime import datetime
from decimal import Decimal, InvalidOperation

try:
    # Kept for backward compatibility with the original import list; nothing
    # in this module calls it, so a missing package must not block import.
    import ijson  # noqa: F401
except ImportError:
    ijson = None


class InventoryManager:
    """Inventory ledger backed by a single JSON file.

    The file holds ``products`` (current weight / weighted-average price per
    product, stored as 8-decimal strings) and ``transactions`` (a list kept
    newest-first that mixes real records with snapshot entries flagged by
    ``is_snapshot``).

    Behaviour visible in this class:
    - every operation loads the whole document with ``json.load`` (there is
      no streaming, despite what earlier documentation advertised);
    - a snapshot of ``products`` is inserted every ``snapshot_interval``
      records;
    - out-of-order inserts and deletions trigger a recalculation of stock;
    - long recalculations report through an optional
      ``callback(current, total, message)``.
    """

    def __init__(self, data_path):
        """Bind the manager to ``data_path``, creating the file if missing."""
        self.data_path = data_path
        self.snapshot_interval = 50
        # Records appended since the last snapshot. In-memory only, so the
        # counter restarts at 0 for every new manager instance.
        self.need_calculate = 0

        if not os.path.exists(self.data_path):
            self.initialize_data_file()

    # ------------------------------------------------------------------
    # Persistence helpers
    # ------------------------------------------------------------------
    def _load_data(self):
        """Read and return the whole JSON document from ``data_path``."""
        with open(self.data_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _save_data(self, data):
        """Write ``data`` to ``data_path``.

        The document is serialised to a string *before* the file is opened
        for writing: opening with ``'w'`` truncates immediately, so a
        serialisation error (e.g. a ``Decimal`` inside a record) would
        otherwise leave a destroyed, half-written ledger behind.
        """
        payload = json.dumps(data, indent=2, ensure_ascii=False)
        with open(self.data_path, 'w', encoding='utf-8') as f:
            f.write(payload)

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------
    def process_transaction(self, new_txn, progress_callback=None):
        """Validate and store a new transaction (main entry point).

        Args:
            new_txn: transaction dict with ``product``, ``type``
                (``'入库'`` inbound / ``'出库'`` outbound), ``weight``,
                ``time`` (``YYYY-MM-DD``) and, for inbound, ``price``.
                The dict is mutated: ``time`` is normalised,
                ``daily_sequence`` is added and outbound records get a
                computed ``price``.
            progress_callback: optional ``callback(current, total, message)``.

        Returns:
            ``{"success": bool, "message": str, "recalculated_count": int}``.
            Errors are reported through the dict, never raised.
        """
        try:
            # 1. Validate the incoming record.
            if not self._validate_transaction(new_txn):
                return {"success": False, "message": "交易记录格式错误", "recalculated_count": 0}

            # strptime accepts non-padded dates such as "2025-1-5", but the
            # ledger orders records by *string* comparison, so store the
            # canonical zero-padded form.
            new_txn['time'] = datetime.strptime(
                new_txn['time'], '%Y-%m-%d'
            ).date().isoformat()

            # 2. Assign the per-day sequence number.
            new_txn['daily_sequence'] = self._get_next_sequence(new_txn['time'])

            # 3. Decide between a plain append and an out-of-order insert.
            insert_position = self._find_insert_position(new_txn)

            if insert_position == 0:
                # Newest record: no recalculation needed.
                return self._insert_latest_transaction(new_txn, progress_callback)
            # Historical record: stock must be recalculated.
            return self._insert_historical_transaction(
                new_txn, insert_position, progress_callback
            )

        except Exception as e:
            if "库存不足" in str(e):
                return {"success": False, "message": "库存不足", "recalculated_count": 0}
            return {"success": False, "message": f"处理失败: {str(e)}", "recalculated_count": 0}

    def _validate_transaction(self, txn):
        """Return True when ``txn`` has the required fields in a parseable form.

        Inbound records additionally require ``price``. ``weight``/``price``
        must be convertible to ``Decimal`` and ``time`` must parse as
        ``%Y-%m-%d``.
        """
        # .get() so that a missing 'type' is reported as a format error
        # instead of escaping as a KeyError.
        txn_type = txn.get('type')

        required_fields = ['product', 'type', 'weight', 'time']
        if txn_type == '入库':
            required_fields.append('price')

        if any(field not in txn for field in required_fields):
            return False

        try:
            Decimal(txn['weight'])
            if txn_type == '入库':
                Decimal(txn['price'])
            datetime.strptime(txn['time'], '%Y-%m-%d')
        except (InvalidOperation, ValueError, TypeError):
            return False

        return True

    def _find_insert_position(self, new_txn):
        """Return the list index at which ``new_txn`` must be inserted.

        ``transactions`` is newest-first. The result is ``0`` when the new
        record is newer than every stored record (leading snapshots are
        ignored for that decision); otherwise it is the real index, snapshots
        included, of the first record older than ``new_txn``, or the slot
        right after the last record when all of them are newer.

        Returns ``0`` if the file cannot be read or has an unexpected shape.
        """
        new_key = (new_txn['time'], new_txn['daily_sequence'])

        try:
            data = self._load_data()

            seen_record = False   # a non-snapshot record newer than new_txn exists
            insert_at = 0         # slot just after the last record seen so far
            for index, txn in enumerate(data['transactions']):
                if txn.get('is_snapshot'):
                    continue

                existing_key = (txn.get('time', ''), txn.get('daily_sequence', 0))
                if new_key > existing_key:
                    # Index into the full list (callers pass it to
                    # list.insert); 0 keeps the "latest record" fast path.
                    return index if seen_record else 0

                seen_record = True
                insert_at = index + 1

            return insert_at

        except (OSError, ValueError, KeyError, TypeError, AttributeError):
            return 0

    def _insert_latest_transaction(self, new_txn, progress_callback):
        """Prepend ``new_txn`` as the newest record (no recalculation).

        Outbound records are priced first via ``calculate_outbound_price``.
        Returns the usual result dict; an insufficient-stock outbound is
        reported without touching the file.

        Raises:
            Exception: wrapping any failure, prefixed with "插入最新记录失败".
        """
        try:
            data = self._load_data()

            # Outbound records take their price from the current average.
            if new_txn['type'] == '出库':
                new_txn['price'] = self.calculate_outbound_price(
                    new_txn['product'], new_txn['time']
                )

            if not self._update_product_inventory(data['products'], new_txn):
                return {"success": False, "message": "库存不足", "recalculated_count": 0}

            data['transactions'].insert(0, new_txn)

            # Snapshot every `snapshot_interval` appended records.
            self.need_calculate += 1
            if self.need_calculate >= self.snapshot_interval:
                snapshot = self._create_snapshot(data['products'])
                data['transactions'].insert(0, snapshot)
                self.need_calculate = 0

            self._save_data(data)

            return {"success": True, "message": "记录添加成功", "recalculated_count": 0}

        except Exception as e:
            raise Exception(f"插入最新记录失败: {str(e)}") from e

    def _insert_historical_transaction(self, new_txn, insert_position, progress_callback):
        """Insert ``new_txn`` at ``insert_position`` and recalculate stock.

        Outbound prices are not pre-computed here; the recalculation assigns
        them from the rebuilt state.

        Raises:
            Exception: wrapping any failure, prefixed with "插入历史记录失败"
                (including insufficient stock found during recalculation).
        """
        try:
            data = self._load_data()

            # 1. Pick the snapshot the recalculation starts from.
            base_snapshot = self._find_base_snapshot(new_txn['time'], data['transactions'])

            # 2. Put the record in its slot.
            data['transactions'].insert(insert_position, new_txn)

            # 3. Drop snapshots located after the insertion index.
            self._remove_snapshots_after_position(data['transactions'], insert_position)

            # 4. Rebuild product state.
            recalculated_count = self._recalculate_from_snapshot(
                data, base_snapshot, insert_position, progress_callback
            )

            # 5. Persist.
            self._save_data(data)

            return {
                "success": True,
                "message": f"历史记录插入成功,重计算了{recalculated_count}条记录",
                "recalculated_count": recalculated_count
            }

        except Exception as e:
            raise Exception(f"插入历史记录失败: {str(e)}") from e

    def _find_base_snapshot(self, target_date, transactions):
        """Return a snapshot whose ``timestamp`` is ``<= target_date``.

        The list is scanned from its end (the oldest end of a newest-first
        list) and the first match wins; when none matches an empty snapshot
        dated ``1900-01-01`` is returned.

        NOTE(review): this yields the *earliest* qualifying snapshot, not the
        nearest one. Reversing the scan is only safe once snapshot timestamps
        describe a ledger position; ``_create_snapshot`` stamps the
        wall-clock date instead.
        """
        for txn in reversed(transactions):
            if txn.get('is_snapshot') and txn.get('timestamp', '') <= target_date:
                return txn

        return {
            'is_snapshot': True,
            'timestamp': '1900-01-01',
            'products': {}
        }

    def _remove_snapshots_after_position(self, transactions, position):
        """Remove, in place, every snapshot at an index greater than ``position``.

        NOTE(review): in a newest-first list these are the *older* snapshots;
        confirm that is the intended direction.
        """
        transactions[:] = [
            txn for index, txn in enumerate(transactions)
            if index <= position or not txn.get('is_snapshot')
        ]

    def delete_transaction(self, transaction_index, progress_callback=None):
        """Delete the record at ``transaction_index`` and recalculate stock.

        Args:
            transaction_index: index into the stored ``transactions`` list
                (snapshots included).
            progress_callback: optional ``callback(current, total, message)``.

        Returns:
            ``{"success": bool, "message": str, "recalculated_count": int}``.
            Snapshot entries cannot be deleted. Errors are reported through
            the dict, never raised.
        """
        try:
            data = self._load_data()
            transactions = data['transactions']

            if transaction_index < 0 or transaction_index >= len(transactions):
                return {"success": False, "message": "无效的交易记录索引", "recalculated_count": 0}

            target_txn = transactions[transaction_index]

            if target_txn.get('is_snapshot'):
                return {"success": False, "message": "不能删除快照记录", "recalculated_count": 0}

            # The list is newest-first, so entries *after* the target are the
            # older ones. Searching the newer side could select a snapshot
            # that already contains the deleted record and double-count the
            # remaining ones.
            base_snapshot = self._find_base_snapshot(
                target_txn['time'], transactions[transaction_index + 1:]
            )

            transactions.pop(transaction_index)

            self._remove_snapshots_after_position(transactions, transaction_index - 1)

            recalculated_count = self._recalculate_from_snapshot(
                data, base_snapshot, transaction_index - 1, progress_callback
            )

            self._save_data(data)

            return {
                "success": True,
                "message": f"交易记录删除成功,重计算了{recalculated_count}条记录",
                "recalculated_count": recalculated_count
            }

        except Exception as e:
            return {"success": False, "message": f"删除失败: {str(e)}", "recalculated_count": 0}

    def find_transaction_index(self, product_name, transaction_time, transaction_type, weight):
        """Return the list index of the first record matching the given fields.

        Args:
            product_name: product identifier.
            transaction_time: transaction date string.
            transaction_type: ``'入库'`` or ``'出库'``.
            weight: weight, compared as float with a 1e-8 tolerance.

        Returns:
            The index in ``transactions``, or ``-1`` when nothing matches or
            the lookup fails (the failure is printed).
        """
        try:
            data = self._load_data()

            for index, txn in enumerate(data['transactions']):
                if txn.get('is_snapshot'):
                    continue
                if (txn.get('product') == product_name
                        and txn.get('time') == transaction_time
                        and txn.get('type') == transaction_type
                        and abs(float(txn.get('weight', 0)) - float(weight)) < 0.00000001):
                    return index

            return -1

        except Exception as e:
            print(f"查找交易记录索引失败: {str(e)}")
            return -1

    def delete_product(self, product_name, progress_callback=None):
        """Delete a product together with all of its transaction records.

        Args:
            product_name: product to delete (compared as ``str``).
            progress_callback: optional ``callback(current, total, message)``.

        Returns:
            ``{"success": bool, "message": str, "deleted_transactions": int}``.
            Errors are reported through the dict, never raised.
        """
        try:
            data = self._load_data()

            product_name_str = str(product_name)

            if product_name_str not in data["products"]:
                return {
                    "success": False,
                    "message": f"未找到产品: {product_name_str}",
                    "deleted_transactions": 0
                }

            del data["products"][product_name_str]

            # Drop the product's records; snapshots are kept at this stage.
            original_count = len(data["transactions"])
            data["transactions"] = [
                t for t in data["transactions"]
                if str(t.get("product")) != product_name_str or t.get("is_snapshot")
            ]
            deleted_count = original_count - len(data["transactions"])

            if deleted_count > 0:
                # Existing snapshots still mention the product: discard them
                # all and rebuild state (and snapshots) from scratch.
                data["transactions"] = [
                    t for t in data["transactions"] if not t.get("is_snapshot")
                ]
                self._recalculate_all_products(data, progress_callback)

            self._save_data(data)

            return {
                "success": True,
                "message": f"产品删除成功,删除了{deleted_count}条相关交易记录",
                "deleted_transactions": deleted_count
            }

        except Exception as e:
            return {"success": False, "message": f"删除产品失败: {str(e)}", "deleted_transactions": 0}

    # ------------------------------------------------------------------
    # Recalculation
    # ------------------------------------------------------------------
    def _recalculate_all_products(self, data, progress_callback):
        """Rebuild every product's state by replaying all records in ``data``.

        Mutates ``data['products']`` and ``data['transactions']`` (outbound
        prices are rewritten, snapshots are added, and the list is re-sorted
        newest-first).

        Raises:
            Exception: "重计算过程中发现库存不足" when an outbound record
                exceeds the stock available at that point of the replay.
        """
        # Reset all known products to zero.
        for product_name in data["products"]:
            data["products"][product_name] = {
                "total_weight": "0.00000000",
                "avg_price": "0.00000000"
            }

        # Replay in chronological order.
        transactions = [t for t in data["transactions"] if not t.get("is_snapshot")]
        transactions.sort(key=lambda x: (x.get('time', ''), x.get('daily_sequence', 0)))

        snapshot_count = 0
        total = len(transactions)

        for position, txn in enumerate(transactions, start=1):
            if progress_callback:
                progress_callback(position, total, f"重新计算交易记录: {txn.get('product', '')}")

            if txn['type'] == '入库':
                self._update_product_inventory(data["products"], txn)
            elif txn['type'] == '出库':
                # Price comes from the state rebuilt so far, which avoids
                # re-reading the file that is being recalculated.
                txn['price'] = self._calculate_outbound_price_from_products(
                    data["products"], txn['product'], txn['time']
                )
                if not self._update_product_inventory(data["products"], txn):
                    raise Exception("重计算过程中发现库存不足")

            snapshot_count += 1

            if snapshot_count >= self.snapshot_interval:
                snapshot = self._create_snapshot(data["products"])
                insert_position = self._find_insert_position_for_snapshot(
                    data['transactions'], snapshot
                )
                data['transactions'].insert(insert_position, snapshot)
                snapshot_count = 0

        # Snapshots first, then records, then a stable newest-first sort.
        data['transactions'] = [
            t for t in data['transactions'] if t.get('is_snapshot')
        ] + transactions
        data['transactions'].sort(
            key=lambda x: (x.get('time', x.get('timestamp', '')), x.get('daily_sequence', 0)),
            reverse=True
        )

    def _recalculate_from_snapshot(self, data, base_snapshot, start_position,
                                   progress_callback):
        """Replay records on top of ``base_snapshot`` and store the result.

        Args:
            data: loaded document; ``data['products']`` is replaced and
                snapshots may be inserted into ``data['transactions']``.
            base_snapshot: snapshot dict providing the starting ``products``
                and the ``timestamp`` from which records are replayed
                (records with ``time >= timestamp``).
            start_position: accepted for interface compatibility; unused.
            progress_callback: optional ``callback(current, total, message)``,
                called with a 1-based ``current``.

        Returns:
            Number of records replayed.

        Raises:
            Exception: "重计算过程中发现库存不足" when an outbound record
                exceeds the available stock.
        """
        # Start from a private copy of the snapshot state.
        products = copy.deepcopy(base_snapshot.get('products', {}))

        # Products known to the file but absent from the snapshot start at 0.
        for product_name in data.get('products', {}):
            if product_name not in products:
                products[product_name] = {
                    'total_weight': '0.00000000',
                    'avg_price': '0.00000000'
                }

        base_timestamp = base_snapshot.get('timestamp', '1900-01-01')
        records_to_process = [
            txn for txn in data['transactions']
            if not txn.get('is_snapshot') and txn.get('time', '') >= base_timestamp
        ]

        # Chronological order so stock is rebuilt correctly.
        records_to_process.sort(
            key=lambda txn: (txn.get('time', ''), txn.get('daily_sequence', 0))
        )

        total_records = len(records_to_process)
        processed_count = 0
        snapshot_count = 0

        for txn in records_to_process:
            if progress_callback:
                # 1-based so the final call reports current == total.
                progress_callback(
                    processed_count + 1, total_records,
                    f"重计算记录 {processed_count+1}/{total_records}"
                )

            if txn['type'] == '入库':
                self._update_product_inventory(products, txn)
            elif txn['type'] == '出库':
                # Price from the state rebuilt so far (no file re-read).
                txn['price'] = self._calculate_outbound_price_from_products(
                    products, txn['product'], txn['time']
                )
                if not self._update_product_inventory(products, txn):
                    raise Exception("重计算过程中发现库存不足")

            processed_count += 1
            snapshot_count += 1

            if snapshot_count >= self.snapshot_interval:
                snapshot = self._create_snapshot(products)
                insert_position = self._find_insert_position_for_snapshot(
                    data['transactions'], snapshot
                )
                data['transactions'].insert(insert_position, snapshot)
                snapshot_count = 0

        data['products'] = products

        return total_records

    def _calculate_outbound_price_from_products(self, products, product_name, outbound_date):
        """Return the product's current ``avg_price`` from ``products``.

        ``outbound_date`` is accepted but not used. Unknown products price at
        ``'0.00000000'``.
        """
        if product_name in products:
            return products[product_name].get('avg_price', '0.00000000')
        return '0.00000000'

    def _update_product_inventory(self, products, txn):
        """Apply ``txn`` to ``products`` in place.

        Inbound records update weight and weighted-average price; outbound
        records reduce weight only. An unknown product is created with zero
        stock first.

        Returns:
            False when an outbound record exceeds the available weight
            (state for that product is left unchanged), True otherwise.
        """
        product_name = txn['product']

        if product_name not in products:
            products[product_name] = {
                'total_weight': '0.00000000',
                'avg_price': '0.00000000'
            }

        entry = products[product_name]

        if txn['type'] == '入库':
            self._update_average_price(entry, txn['weight'], txn['price'])
        elif txn['type'] == '出库':
            current_weight = Decimal(entry['total_weight'])
            out_weight = Decimal(txn['weight'])

            if current_weight < out_weight:
                return False

            entry['total_weight'] = f"{current_weight - out_weight:.8f}"

        return True

    def _update_average_price(self, product, weight, price):
        """Fold an inbound ``weight`` at ``price`` into ``product`` in place.

        With existing stock the new price is the weighted average; with no
        stock the inbound weight and price replace the stored values.
        """
        total_weight = Decimal(product['total_weight'])
        avg_price = Decimal(product['avg_price'])
        new_weight = Decimal(weight)
        new_price = Decimal(price)

        if total_weight > 0:
            new_total = total_weight + new_weight
            new_avg = (total_weight * avg_price + new_weight * new_price) / new_total
        else:
            new_total = new_weight
            new_avg = new_price

        product['total_weight'] = f"{new_total:.8f}"
        product['avg_price'] = f"{new_avg:.8f}"

    def _create_snapshot(self, products):
        """Return a snapshot entry holding a deep copy of ``products``.

        NOTE(review): ``timestamp`` is today's wall-clock date, not the date
        of the last record folded into ``products``.
        """
        return {
            "is_snapshot": True,
            "timestamp": datetime.now().strftime('%Y-%m-%d'),
            "daily_sequence": 0,
            "products": copy.deepcopy(products)
        }

    def _find_insert_position_for_snapshot(self, transactions, snapshot):
        """Return the index of the first entry dated before the snapshot.

        Entries are compared by ``time`` (records) or ``timestamp``
        (snapshots). Returns ``len(transactions)`` when no entry is older.
        """
        snapshot_time = snapshot['timestamp']

        for index, txn in enumerate(transactions):
            entry_time = txn.get('time', txn.get('timestamp', ''))
            if entry_time < snapshot_time:
                return index

        return len(transactions)

    def _get_next_sequence(self, timestamp):
        """Return the next ``daily_sequence`` for records dated ``timestamp``.

        This is one more than the highest sequence stored for that date, or
        ``1`` when the date has no records or the file cannot be read.
        """
        try:
            data = self._load_data()

            max_seq = 0
            for txn in data['transactions']:
                if txn.get('time') == timestamp and not txn.get('is_snapshot'):
                    max_seq = max(max_seq, txn.get('daily_sequence', 0))

            return max_seq + 1

        except (OSError, ValueError, KeyError, TypeError, AttributeError):
            return 1

    def calculate_outbound_price(self, product_name, outbound_date):
        """Return the weighted-average price of ``product_name`` at ``outbound_date``.

        Starts from the snapshot chosen by ``_find_base_snapshot`` and replays
        that product's records with ``base timestamp <= time <= outbound_date``
        in chronological order. Stock that would go negative is clamped to
        zero (price reset to zero as well).

        Returns:
            The price as an 8-decimal string; ``'0.00000000'`` on any error.
        """
        try:
            data = self._load_data()

            base_snapshot = self._find_base_snapshot(outbound_date, data['transactions'])

            # Seed from the snapshot; a product it does not list starts at 0.
            snapshot_products = base_snapshot.get('products', {})
            if product_name in snapshot_products:
                product_data = snapshot_products[product_name]
                current_weight = Decimal(product_data.get('total_weight', '0.00000000'))
                current_avg_price = Decimal(product_data.get('avg_price', '0.00000000'))
            else:
                current_weight = Decimal('0.00000000')
                current_avg_price = Decimal('0.00000000')

            base_timestamp = base_snapshot.get('timestamp', '1900-01-01')
            relevant_transactions = [
                txn for txn in data['transactions']
                if (not txn.get('is_snapshot')
                    and txn.get('product') == product_name
                    and txn.get('time', '') >= base_timestamp
                    and txn.get('time', '') <= outbound_date)
            ]

            relevant_transactions.sort(
                key=lambda x: (x.get('time', ''), x.get('daily_sequence', 0))
            )

            for record in relevant_transactions:
                if record['type'] == '入库':
                    record_weight = Decimal(record['weight'])
                    record_price = Decimal(record['price'])

                    if current_weight > 0:
                        new_total_weight = current_weight + record_weight
                        current_avg_price = (
                            current_weight * current_avg_price
                            + record_weight * record_price
                        ) / new_total_weight
                        current_weight = new_total_weight
                    else:
                        current_avg_price = record_price
                        current_weight = record_weight
                elif record['type'] == '出库':
                    current_weight -= Decimal(record['weight'])
                    # Should not happen; clamp rather than carry negatives.
                    if current_weight < 0:
                        current_weight = Decimal('0.00000000')
                        current_avg_price = Decimal('0.00000000')

            return f"{current_avg_price:.8f}"

        except Exception:
            return '0.00000000'

    def get_product_list(self):
        """Return the ``products`` mapping from the file, or ``{}`` on failure."""
        try:
            return self._load_data().get('products', {})
        except (OSError, ValueError, AttributeError):
            return {}

    def initialize_data_file(self):
        """Create the data file with no products and one empty snapshot.

        Raises:
            Exception: prefixed with "初始化数据文件失败" when writing fails.
        """
        initial_snapshot = {
            "is_snapshot": True,
            "timestamp": "1900-01-01",
            "daily_sequence": 0,
            "products": {}
        }

        initial_data = {
            "$schema": "./inventory_schema.json",
            "products": {},
            "transactions": [initial_snapshot]
        }

        try:
            self._save_data(initial_data)
        except Exception as e:
            raise Exception(f"初始化数据文件失败: {str(e)}") from e


if __name__ == '__main__':
    manager = InventoryManager('inventory_data.json')

    # Smoke test: record one inbound transaction and print the result.
    test_txn = {
        'product': 'TEST001',
        'type': '入库',
        'weight': '100.00000000',
        'price': '10.00000000',
        'note': '测试入库',
        'time': '2025-01-15'
    }

    result = manager.process_transaction(test_txn)
    print(result)